diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index dd4e19b47..1689295ef 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,9 +1,38 @@
-name: Test Server Core
+name: Lint & Test
on: [push, pull_request]
jobs:
- test-core:
- name: Run Core Tests
+ lint:
+ name: Lint
+ runs-on: ubuntu-latest
+
+ # Run on pushes and on pull requests opened from forks. Pull requests from branches
+ # in this repository are skipped because the push event already triggers the same
+ # checks; without this guard every internal PR would run them twice. Background:
+ # https://github.community/t/duplicate-checks-on-push-and-pull-request-simultaneous-event/18012
+ if: github.event_name == 'push' || github.event.pull_request.head.repo.full_name != github.repository
+
+ steps:
+ - uses: actions/checkout@v2
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install Python Packages
+ run: |
+ pip install --upgrade pip
+ pip install tox
+
+ - name: Run isort
+ run: tox -e isort
+
+ - name: Run Black
+ run: tox -e black
+
+ test:
+ name: Test
runs-on: ubuntu-latest
strategy:
fail-fast: false
diff --git a/analytics.py b/analytics.py
index 7813a9b59..30d6b542a 100644
--- a/analytics.py
+++ b/analytics.py
@@ -1,14 +1,15 @@
-
-import importlib
import contextlib
+import importlib
import os
from collections import defaultdict
+
from sqlalchemy.orm.session import Session
-from .model import ExternalIntegration
from .config import CannotLoadConfiguration
+from .model import ExternalIntegration
from .util.datetime_helpers import utc_now
+
class Analytics(object):
GLOBAL_ENABLED = None
@@ -22,15 +23,17 @@ def __init__(self, _db):
Analytics.LIBRARY_ENABLED = set()
# Find a list of all the ExternalIntegrations set up with a
# goal of analytics.
- integrations = _db.query(ExternalIntegration).filter(ExternalIntegration.goal==ExternalIntegration.ANALYTICS_GOAL)
+ integrations = _db.query(ExternalIntegration).filter(
+ ExternalIntegration.goal == ExternalIntegration.ANALYTICS_GOAL
+ )
# Turn each integration into an analytics provider.
for integration in integrations:
kwargs = {}
module = integration.protocol
- if module.startswith('.'):
+ if module.startswith("."):
# This is a relative import. Import it relative to
# this module. This should only happen during tests.
- kwargs['package'] =__name__
+ kwargs["package"] = __name__
else:
# This is an absolute import. Trust sys.path to find it.
pass
@@ -48,7 +51,9 @@ def __init__(self, _db):
self.library_providers[library.id].append(provider)
Analytics.LIBRARY_ENABLED.add(library.id)
else:
- self.initialization_exceptions[integration.id] = "Module %s does not have Provider defined." % module
+ self.initialization_exceptions[integration.id] = (
+ "Module %s does not have Provider defined." % module
+ )
except (ImportError, CannotLoadConfiguration) as e:
self.initialization_exceptions[integration.id] = e
diff --git a/app_server.py b/app_server.py
index 656a9b41b..5e745ac9d 100644
--- a/app_server.py
+++ b/app_server.py
@@ -1,51 +1,36 @@
"""Implement logic common to more than one of the Simplified applications."""
-from psycopg2 import DatabaseError
-import flask
import gzip
import json
+import logging
import os
-import sys
import subprocess
-from lxml import etree
+import sys
+import traceback
from functools import wraps
-from flask import url_for, make_response
-from flask_babel import lazy_gettext as _
from io import BytesIO
-from .util.flask_util import problem
-from .util.problem_detail import ProblemDetail
-import traceback
-import logging
-from .entrypoint import EntryPoint
-from .opds import (
- AcquisitionFeed,
- LookupAcquisitionFeed,
-)
-from .util.flask_util import OPDSFeedResponse
-from .util.opds_writer import (
- OPDSFeed,
- OPDSMessage,
-)
+
+import flask
+from flask import make_response, url_for
+from flask_babel import lazy_gettext as _
+from lxml import etree
+from psycopg2 import DatabaseError
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session
-from sqlalchemy.orm.exc import (
- NoResultFound,
-)
-from .log import LogConfiguration
-from .model import (
- get_one,
- Complaint,
- Identifier,
- Patron,
-)
+
from .cdn import cdnify
from .classifier import Classifier
from .config import Configuration
-from .lane import (
- Facets,
- Pagination,
-)
+from .entrypoint import EntryPoint
+from .lane import Facets, Pagination
+from .log import LogConfiguration
+from .model import Complaint, Identifier, Patron, get_one
+from .opds import AcquisitionFeed, LookupAcquisitionFeed
from .problem_details import *
+from .util.flask_util import OPDSFeedResponse, problem
+from .util.opds_writer import OPDSFeed, OPDSMessage
+from .util.problem_detail import ProblemDetail
def cdn_url_for(*args, **kwargs):
@@ -54,8 +39,11 @@ def cdn_url_for(*args, **kwargs):
def load_facets_from_request(
- facet_config=None, worklist=None, base_class=Facets,
- base_class_constructor_kwargs=None, default_entrypoint=None
+ facet_config=None,
+ worklist=None,
+ base_class=Facets,
+ base_class_constructor_kwargs=None,
+ default_entrypoint=None,
):
"""Figure out which faceting object this request is asking for.
@@ -77,14 +65,18 @@ def load_facets_from_request(
library = flask.request.library
facet_config = facet_config or library
return base_class.from_request(
- library, facet_config, get_arg, get_header, worklist,
- default_entrypoint, **kwargs
+ library,
+ facet_config,
+ get_arg,
+ get_header,
+ worklist,
+ default_entrypoint,
+ **kwargs
)
def load_pagination_from_request(
- base_class=Pagination, base_class_constructor_kwargs=None,
- default_size=None
+ base_class=Pagination, base_class_constructor_kwargs=None, default_size=None
):
"""Figure out which Pagination object this request is asking for.
@@ -107,6 +99,7 @@ def decorated(*args, **kwargs):
if isinstance(v, ProblemDetail):
return v.response
return v
+
return decorated
@@ -123,20 +116,23 @@ def compressible(f):
though I don't know if that's the original source; it shows up in
a lot of places.
"""
+
@wraps(f)
def compressor(*args, **kwargs):
@flask.after_this_request
def compress(response):
- if (response.status_code < 200 or
- response.status_code >= 300 or
- 'Content-Encoding' in response.headers):
+ if (
+ response.status_code < 200
+ or response.status_code >= 300
+ or "Content-Encoding" in response.headers
+ ):
# Don't encode anything other than a 2xx response
# code. Don't encode a response that's
# already been encoded.
return response
- accept_encoding = flask.request.headers.get('Accept-Encoding', '')
- if not 'gzip' in accept_encoding.lower():
+ accept_encoding = flask.request.headers.get("Accept-Encoding", "")
+ if not "gzip" in accept_encoding.lower():
return response
# At this point we know we're going to be changing the
@@ -149,18 +145,19 @@ def compress(response):
response.direct_passthrough = False
buffer = BytesIO()
- gzipped = gzip.GzipFile(mode='wb', fileobj=buffer)
+ gzipped = gzip.GzipFile(mode="wb", fileobj=buffer)
gzipped.write(response.data)
gzipped.close()
response.data = buffer.getvalue()
- response.headers['Content-Encoding'] = 'gzip'
- response.vary.add('Accept-Encoding')
- response.headers['Content-Length'] = len(response.data)
+ response.headers["Content-Encoding"] = "gzip"
+ response.vary.add("Accept-Encoding")
+ response.headers["Content-Length"] = len(response.data)
return response
return f(*args, **kwargs)
+
return compressor
@@ -180,9 +177,9 @@ def handle(self, exception):
"""Something very bad has happened. Notify the client."""
# By default, when reporting errors, err on the side of
# terseness, to avoid leaking sensitive information.
- debug = self.app.config['DEBUG'] or self.debug
+ debug = self.app.config["DEBUG"] or self.debug
- if hasattr(self.app, 'manager') and hasattr(self.app.manager, '_db'):
+ if hasattr(self.app, "manager") and hasattr(self.app.manager, "_db"):
# There is an active database session.
# Use it to determine whether we are in debug mode, in
@@ -192,10 +189,12 @@ def handle(self, exception):
_db = self.app.manager._db
try:
LogConfiguration.from_configuration(_db)
- (log_level, database_log_level, handlers,
- errors) = LogConfiguration.from_configuration(
- self.app.manager._db
- )
+ (
+ log_level,
+ database_log_level,
+ handlers,
+ errors,
+ ) = LogConfiguration.from_configuration(self.app.manager._db)
debug = debug or (
LogConfiguration.DEBUG in (log_level, database_log_level)
)
@@ -215,9 +214,10 @@ def handle(self, exception):
# and let uwsgi restart it.
logging.error(
"Database error: %s Treating as fatal to avoid holding on to a tainted session!",
- exception, exc_info=exception
+ exception,
+ exc_info=exception,
)
- shutdown = flask.request.environ.get('werkzeug.server.shutdown')
+ shutdown = flask.request.environ.get("werkzeug.server.shutdown")
if shutdown:
shutdown()
else:
@@ -228,7 +228,7 @@ def handle(self, exception):
# Okay, it's not a database error. Turn it into a useful HTTP error
# response.
- if hasattr(exception, 'as_problem_detail_document'):
+ if hasattr(exception, "as_problem_detail_document"):
# This exception can be turned directly into a problem
# detail document.
document = exception.as_problem_detail_document(debug)
@@ -253,7 +253,7 @@ def handle(self, exception):
if debug:
body = tb
else:
- body = _('An internal error occured')
+ body = _("An internal error occured")
response = make_response(str(body), 500, {"Content-Type": "text/plain"})
log_method("Exception in web app: %s", exception, exc_info=exception)
@@ -262,17 +262,17 @@ def handle(self, exception):
class HeartbeatController(object):
- HEALTH_CHECK_TYPE = 'application/vnd.health+json'
- VERSION_FILENAME = '.version'
+ HEALTH_CHECK_TYPE = "application/vnd.health+json"
+ VERSION_FILENAME = ".version"
def heartbeat(self, conf_class=None):
- health_check_object = dict(status='pass')
+ health_check_object = dict(status="pass")
Conf = conf_class or Configuration
app_version = Conf.app_version()
if app_version and app_version != Conf.NO_APP_VERSION_FOUND:
- health_check_object['releaseID'] = app_version
- health_check_object['version'] = app_version.split('-')[0]
+ health_check_object["releaseID"] = app_version
+ health_check_object["version"] = app_version.split("-")[0]
data = json.dumps(health_check_object)
return make_response(data, 200, {"Content-Type": self.HEALTH_CHECK_TYPE})
@@ -290,9 +290,9 @@ def __init__(self, _db):
"""
self._db = _db
- def work_lookup(self, annotator, route_name='lookup', **process_urn_kwargs):
+ def work_lookup(self, annotator, route_name="lookup", **process_urn_kwargs):
"""Generate an OPDS feed describing works identified by identifier."""
- urns = flask.request.args.getlist('urn')
+ urns = flask.request.args.getlist("urn")
this_url = cdn_url_for(route_name, _external=True, urn=urns)
handler = self.process_urns(urns, **process_urn_kwargs)
@@ -302,7 +302,11 @@ def work_lookup(self, annotator, route_name='lookup', **process_urn_kwargs):
return handler
opds_feed = LookupAcquisitionFeed(
- self._db, "Lookup results", this_url, handler.works, annotator,
+ self._db,
+ "Lookup results",
+ this_url,
+ handler.works,
+ annotator,
precomposed_entries=handler.precomposed_entries,
)
return OPDSFeedResponse(str(opds_feed))
@@ -321,7 +325,7 @@ def process_urns(self, urns, **process_urn_kwargs):
handler.process_urns(urns, **process_urn_kwargs)
return handler
- def permalink(self, urn, annotator, route_name='work'):
+ def permalink(self, urn, annotator, route_name="work"):
"""Look up a single identifier and generate an OPDS feed.
TODO: This method is tested, but it seems unused and it
@@ -336,8 +340,12 @@ def permalink(self, urn, annotator, route_name='work'):
# list of works.
works = [work for (identifier, work) in handler.works]
opds_feed = AcquisitionFeed(
- self._db, urn, this_url, works, annotator,
- precomposed_entries=handler.precomposed_entries
+ self._db,
+ urn,
+ this_url,
+ works,
+ annotator,
+ precomposed_entries=handler.precomposed_entries,
)
return OPDSFeedResponse(str(opds_feed))
@@ -379,8 +387,7 @@ def add_urn_failure_messages(self, failures):
self.add_message(urn, 400, INVALID_URN.detail)
def process_identifier(self, identifier, urn, **kwargs):
- """Turn a URN into a Work suitable for use in an OPDS feed.
- """
+ """Turn a URN into a Work suitable for use in an OPDS feed."""
if not identifier.licensed_through:
# The default URNLookupHandler cannot look up an
# Identifier that has no associated LicensePool.
@@ -409,9 +416,7 @@ def add_entry(self, entry):
def add_message(self, urn, status_code, message):
"""An identifier lookup resulted in the creation of a message."""
- self.precomposed_entries.append(
- OPDSMessage(urn, status_code, message)
- )
+ self.precomposed_entries.append(OPDSMessage(urn, status_code, message))
def post_lookup_hook(self):
"""Run after looking up a number of Identifiers.
@@ -435,13 +440,15 @@ def register(self, license_pool, raw_data):
except ValueError as e:
return problem(None, 400, _("Invalid problem detail document"))
- type = data.get('type')
- source = data.get('source')
- detail = data.get('detail')
+ type = data.get("type")
+ source = data.get("source")
+ detail = data.get("detail")
if not type:
return problem(None, 400, _("No problem type specified."))
if type not in Complaint.VALID_TYPES:
- return problem(None, 400, _("Unrecognized problem type: %(type)s", type=type))
+ return problem(
+ None, 400, _("Unrecognized problem type: %(type)s", type=type)
+ )
complaint = None
try:
diff --git a/bin/initialize_database b/bin/initialize_database
index 117e0dc83..383e4f1ec 100755
--- a/bin/initialize_database
+++ b/bin/initialize_database
@@ -2,4 +2,5 @@
"""Run any new migrations from the database"""
import startup
from core.scripts import DatabaseMigrationInitializationScript
+
DatabaseMigrationInitializationScript().run()
diff --git a/bin/migrate_database b/bin/migrate_database
index 18ac82035..048708c0d 100755
--- a/bin/migrate_database
+++ b/bin/migrate_database
@@ -2,4 +2,5 @@
"""Run any new migrations from the database"""
import startup
from core.scripts import DatabaseMigrationScript
+
DatabaseMigrationScript().run()
diff --git a/bin/repair/check_contributor_names b/bin/repair/check_contributor_names
index 2e2d7973b..6abbebc67 100755
--- a/bin/repair/check_contributor_names
+++ b/bin/repair/check_contributor_names
@@ -14,4 +14,5 @@ were fixed or complained about.
"""
import startup
from core.scripts import CheckContributorNamesInDB
+
CheckContributorNamesInDB().run()
diff --git a/bin/repair/startup.py b/bin/repair/startup.py
index 12d408d78..cef296ce7 100644
--- a/bin/repair/startup.py
+++ b/bin/repair/startup.py
@@ -1,11 +1,13 @@
-from os import sys, path
+from os import path, sys
# Good overview of what is going on here:
# https://stackoverflow.com/questions/11536764/how-to-fix-attempted-relative-import-in-non-package-even-with-init-py
# Once we have a stable package name for core, it should be easier to do away with something like this
# for now we add the core component path to the sys.path when we are running these scripts
-component_dir = path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
+component_dir = path.dirname(
+ path.dirname(path.dirname(path.dirname(path.abspath(__file__))))
+)
# Load the 'core' module as though this script were being run from
# the parent component (either circulation or metadata).
-sys.path.append(component_dir)
\ No newline at end of file
+sys.path.append(component_dir)
diff --git a/bin/startup.py b/bin/startup.py
index 6fe7fa42c..5d06f31ba 100644
--- a/bin/startup.py
+++ b/bin/startup.py
@@ -1,4 +1,4 @@
-from os import sys, path
+from os import path, sys
# Good overview of what is going on here:
# https://stackoverflow.com/questions/11536764/how-to-fix-attempted-relative-import-in-non-package-even-with-init-py
@@ -8,4 +8,4 @@
# Load the 'core' module as though this script were being run from
# the parent component (either circulation or metadata).
-sys.path.append(component_dir)
\ No newline at end of file
+sys.path.append(component_dir)
diff --git a/cdn.py b/cdn.py
index fe490e2d3..e6ed0de93 100644
--- a/cdn.py
+++ b/cdn.py
@@ -1,7 +1,7 @@
"""Turn local URLs into CDN URLs."""
from urllib.parse import urlsplit, urlunsplit
-from .config import Configuration, CannotLoadConfiguration
+from .config import CannotLoadConfiguration, Configuration
def cdnify(url, cdns=None):
diff --git a/classifier/__init__.py b/classifier/__init__.py
index f06dd5f1b..7fafc3f59 100644
--- a/classifier/__init__.py
+++ b/classifier/__init__.py
@@ -12,16 +12,13 @@
# SQL to find commonly used classifications not assigned to a genre
# select count(identifiers.id) as c, subjects.type, substr(subjects.identifier, 0, 20) as i, substr(subjects.name, 0, 20) as n from workidentifiers join classifications on workidentifiers.id=classifications.work_identifier_id join subjects on classifications.subject_id=subjects.id where subjects.genre_id is null and subjects.fiction is null group by subjects.type, i, n order by c desc;
-import logging
import json
+import logging
import os
import pkgutil
import re
+from collections import Counter, defaultdict
from urllib.parse import urlparse
-from collections import (
- Counter,
- defaultdict,
-)
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import and_
@@ -32,6 +29,7 @@
NO_VALUE = "NONE"
NO_NUMBER = -1
+
class ClassifierConstants(object):
DDC = "DDC"
LCC = "LCC"
@@ -40,13 +38,13 @@ class ClassifierConstants(object):
OVERDRIVE = "Overdrive"
BISAC = "BISAC"
BIC = "BIC"
- TAG = "tag" # Folksonomic tags.
+ TAG = "tag" # Folksonomic tags.
# Appeal controlled vocabulary developed by NYPL
NYPL_APPEAL = "NYPL Appeal"
- GRADE_LEVEL = "Grade level" # "1-2", "Grade 4", "Kindergarten", etc.
- AGE_RANGE = "schema:typicalAgeRange" # "0-2", etc.
+ GRADE_LEVEL = "Grade level" # "1-2", "Grade 4", "Kindergarten", etc.
+ AGE_RANGE = "schema:typicalAgeRange" # "0-2", etc.
AXIS_360_AUDIENCE = "Axis 360 Audience"
# We know this says something about the audience but we're not sure what.
@@ -82,19 +80,30 @@ class ClassifierConstants(object):
AUDIENCES_YOUNG_CHILDREN = [AUDIENCE_CHILDREN, AUDIENCE_ALL_AGES]
AUDIENCES_JUVENILE = AUDIENCES_YOUNG_CHILDREN + [AUDIENCE_YOUNG_ADULT]
AUDIENCES_ADULT = [AUDIENCE_ADULT, AUDIENCE_ADULTS_ONLY, AUDIENCE_ALL_AGES]
- AUDIENCES = set([AUDIENCE_ADULT, AUDIENCE_ADULTS_ONLY, AUDIENCE_YOUNG_ADULT,
- AUDIENCE_CHILDREN, AUDIENCE_ALL_AGES, AUDIENCE_RESEARCH])
+ AUDIENCES = set(
+ [
+ AUDIENCE_ADULT,
+ AUDIENCE_ADULTS_ONLY,
+ AUDIENCE_YOUNG_ADULT,
+ AUDIENCE_CHILDREN,
+ AUDIENCE_ALL_AGES,
+ AUDIENCE_RESEARCH,
+ ]
+ )
SIMPLIFIED_GENRE = "http://librarysimplified.org/terms/genres/Simplified/"
SIMPLIFIED_FICTION_STATUS = "http://librarysimplified.org/terms/fiction/"
+
class Classifier(ClassifierConstants):
"""Turn an external classification into an internal genre, an
audience, an age level, and a fiction status.
"""
AUDIENCES_NO_RESEARCH = [
- x for x in ClassifierConstants.AUDIENCES if x != ClassifierConstants.AUDIENCE_RESEARCH
+ x
+ for x in ClassifierConstants.AUDIENCES
+ if x != ClassifierConstants.AUDIENCE_RESEARCH
]
classifiers = dict()
@@ -136,11 +145,12 @@ def classify(cls, subject):
if target_age == cls.range_tuple(None, None):
target_age = cls.default_target_age_for_audience(audience)
- return (cls.genre(identifier, name, fiction, audience),
- audience,
- target_age,
- fiction,
- )
+ return (
+ cls.genre(identifier, name, fiction, audience),
+ audience,
+ target_age,
+ fiction,
+ )
@classmethod
def scrub_identifier_and_name(cls, identifier, name):
@@ -174,7 +184,6 @@ def scrub_name(cls, name):
return None
return Lowercased(name)
-
@classmethod
def genre(cls, identifier, name, fiction=None, audience=None):
"""Is this identifier associated with a particular Genre?"""
@@ -202,9 +211,9 @@ def audience(cls, identifier, name):
"""What does this identifier+name say about the audience for
this book?
"""
- if 'juvenile' in name:
+ if "juvenile" in name:
return cls.AUDIENCE_CHILDREN
- elif 'young adult' in name or "YA" in name.original:
+ elif "young adult" in name or "YA" in name.original:
return cls.AUDIENCE_YOUNG_ADULT
return None
@@ -232,9 +241,7 @@ def default_target_age_for_audience(cls, audience):
"""
if audience == Classifier.AUDIENCE_YOUNG_ADULT:
return cls.range_tuple(14, 17)
- elif audience in (
- Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_ADULTS_ONLY
- ):
+ elif audience in (Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_ADULTS_ONLY):
return cls.range_tuple(18, None)
return cls.range_tuple(None, None)
@@ -294,11 +301,7 @@ def and_up(cls, young, keyword):
"""
if young is None:
return None
- if not any(
- [keyword.endswith(x) for x in
- ("and up", "and up.", "+", "+.")
- ]
- ):
+ if not any([keyword.endswith(x) for x in ("and up", "and up.", "+", "+.")]):
return None
if young >= 18:
@@ -317,47 +320,46 @@ def and_up(cls, young, keyword):
old = young + 2
return old
+
class GradeLevelClassifier(Classifier):
# How old a kid is when they start grade N in the US.
american_grade_to_age = {
# Preschool: 3-4 years
- 'preschool' : 3,
- 'pre-school' : 3,
- 'p' : 3,
- 'pk' : 4,
-
+ "preschool": 3,
+ "pre-school": 3,
+ "p": 3,
+ "pk": 4,
# Easy readers
- 'kindergarten' : 5,
- 'k' : 5,
- '0' : 5,
- 'first' : 6,
- '1' : 6,
- 'second' : 7,
- '2' : 7,
-
+ "kindergarten": 5,
+ "k": 5,
+ "0": 5,
+ "first": 6,
+ "1": 6,
+ "second": 7,
+ "2": 7,
# Chapter Books
- 'third' : 8,
- '3' : 8,
- 'fourth' : 9,
- '4' : 9,
- 'fifth' : 10,
- '5' : 10,
- 'sixth' : 11,
- '6' : 11,
- '7' : 12,
- '8' : 13,
-
+ "third": 8,
+ "3": 8,
+ "fourth": 9,
+ "4": 9,
+ "fifth": 10,
+ "5": 10,
+ "sixth": 11,
+ "6": 11,
+ "7": 12,
+ "8": 13,
# YA
- '9' : 14,
- '10' : 15,
- '11' : 16,
- '12': 17,
+ "9": 14,
+ "10": 15,
+ "11": 16,
+ "12": 17,
}
# Regular expressions that match common ways of expressing grade
# levels.
grade_res = [
- re.compile(x, re.I) for x in [
+ re.compile(x, re.I)
+ for x in [
"grades? ([kp0-9]+) to ([kp0-9]+)?",
"grades? ([kp0-9]+) ?-? ?([kp0-9]+)?",
"gr\.? ([kp0-9]+) ?-? ?([kp0-9]+)?",
@@ -366,7 +368,7 @@ class GradeLevelClassifier(Classifier):
"gr\.? ([kp0-9]+)",
"([0-9]+)[tnsr][hdt] grade",
"([a-z]+) grade",
- r'\b(kindergarten|preschool)\b',
+ r"\b(kindergarten|preschool)\b",
]
]
@@ -382,15 +384,14 @@ def audience(cls, identifier, name, require_explicit_age_marker=False):
target_age = cls.target_age(identifier, name, require_explicit_age_marker)
return cls.default_audience_for_target_age(target_age)
-
@classmethod
def target_age(cls, identifier, name, require_explicit_grade_marker=False):
- if (identifier and "education" in identifier) or (name and 'education' in name):
+ if (identifier and "education" in identifier) or (name and "education" in name):
# This is a book about teaching, e.g. fifth grade.
return cls.range_tuple(None, None)
- if (identifier and 'grader' in identifier) or (name and 'grader' in name):
+ if (identifier and "grader" in identifier) or (name and "grader" in name):
# This is a book about, e.g. fifth graders.
return cls.range_tuple(None, None)
@@ -413,9 +414,9 @@ def target_age(cls, identifier, name, require_explicit_grade_marker=False):
young, old = gr
# Strip leading zeros
- if young and young.lstrip('0'):
+ if young and young.lstrip("0"):
young = young.lstrip("0")
- if old and old.lstrip('0'):
+ if old and old.lstrip("0"):
old = old.lstrip("0")
young = cls.american_grade_to_age.get(young)
@@ -434,7 +435,7 @@ def target_age(cls, identifier, name, require_explicit_grade_marker=False):
old = young
if young is None and old is not None:
young = old
- if old and young and old < young:
+ if old and young and old < young:
young, old = old, young
return cls.range_tuple(young, old)
return cls.range_tuple(None, None)
@@ -452,32 +453,33 @@ def target_age_match(cls, query):
break
return (target_age, grade_words)
-class InterestLevelClassifier(Classifier):
+class InterestLevelClassifier(Classifier):
@classmethod
def audience(cls, identifier, name):
- if identifier in ('lg', 'mg+', 'mg'):
+ if identifier in ("lg", "mg+", "mg"):
return cls.AUDIENCE_CHILDREN
- elif identifier == 'ug':
+ elif identifier == "ug":
return cls.AUDIENCE_YOUNG_ADULT
else:
return None
@classmethod
def target_age(cls, identifier, name):
- if identifier == 'lg':
- return cls.range_tuple(5,8)
- if identifier in ('mg+', 'mg'):
- return cls.range_tuple(9,13)
- if identifier == 'ug':
- return cls.range_tuple(14,17)
+ if identifier == "lg":
+ return cls.range_tuple(5, 8)
+ if identifier in ("mg+", "mg"):
+ return cls.range_tuple(9, 13)
+ if identifier == "ug":
+ return cls.range_tuple(14, 17)
return None
class AgeClassifier(Classifier):
# Regular expressions that match common ways of expressing ages.
age_res = [
- re.compile(x, re.I) for x in [
+ re.compile(x, re.I)
+ for x in [
"age ([0-9]+) ?-? ?([0-9]+)?",
"age: ([0-9]+) ?-? ?([0-9]+)?",
"age: ([0-9]+) to ([0-9]+)",
@@ -577,54 +579,70 @@ def target_age_match(cls, query):
COMICS_AND_GRAPHIC_NOVELS,
"Drama",
dict(name="Erotica", audiences=Classifier.AUDIENCE_ADULTS_ONLY),
- dict(name="Fantasy", subgenres=[
- "Epic Fantasy",
- "Historical Fantasy",
- "Urban Fantasy",
- ]),
+ dict(
+ name="Fantasy",
+ subgenres=[
+ "Epic Fantasy",
+ "Historical Fantasy",
+ "Urban Fantasy",
+ ],
+ ),
"Folklore",
"Historical Fiction",
- dict(name="Horror", subgenres=[
- "Gothic Horror",
- "Ghost Stories",
- "Vampires",
- "Werewolves",
- "Occult Horror",
- ]),
+ dict(
+ name="Horror",
+ subgenres=[
+ "Gothic Horror",
+ "Ghost Stories",
+ "Vampires",
+ "Werewolves",
+ "Occult Horror",
+ ],
+ ),
"Humorous Fiction",
"Literary Fiction",
"LGBTQ Fiction",
- dict(name="Mystery", subgenres=[
- "Crime & Detective Stories",
- "Hard-Boiled Mystery",
- "Police Procedural",
- "Cozy Mystery",
- "Historical Mystery",
- "Paranormal Mystery",
- "Women Detectives",
- ]),
+ dict(
+ name="Mystery",
+ subgenres=[
+ "Crime & Detective Stories",
+ "Hard-Boiled Mystery",
+ "Police Procedural",
+ "Cozy Mystery",
+ "Historical Mystery",
+ "Paranormal Mystery",
+ "Women Detectives",
+ ],
+ ),
"Poetry",
"Religious Fiction",
- dict(name="Romance", subgenres=[
- "Contemporary Romance",
- "Gothic Romance",
- "Historical Romance",
- "Paranormal Romance",
- "Western Romance",
- "Romantic Suspense",
- ]),
- dict(name="Science Fiction", subgenres=[
- "Dystopian SF",
- "Space Opera",
- "Cyberpunk",
- "Military SF",
- "Alternative History",
- "Steampunk",
- "Romantic SF",
- "Media Tie-in SF",
- ]),
+ dict(
+ name="Romance",
+ subgenres=[
+ "Contemporary Romance",
+ "Gothic Romance",
+ "Historical Romance",
+ "Paranormal Romance",
+ "Western Romance",
+ "Romantic Suspense",
+ ],
+ ),
+ dict(
+ name="Science Fiction",
+ subgenres=[
+ "Dystopian SF",
+ "Space Opera",
+ "Cyberpunk",
+ "Military SF",
+ "Alternative History",
+ "Steampunk",
+ "Romantic SF",
+ "Media Tie-in SF",
+ ],
+ ),
"Short Stories",
- dict(name="Suspense/Thriller",
+ dict(
+ name="Suspense/Thriller",
subgenres=[
"Historical Thriller",
"Espionage",
@@ -643,92 +661,122 @@ def target_age_match(cls, query):
]
nonfiction_genres = [
- dict(name="Art & Design", subgenres=[
- "Architecture",
- "Art",
- "Art Criticism & Theory",
- "Art History",
- "Design",
- "Fashion",
- "Photography",
- ]),
+ dict(
+ name="Art & Design",
+ subgenres=[
+ "Architecture",
+ "Art",
+ "Art Criticism & Theory",
+ "Art History",
+ "Design",
+ "Fashion",
+ "Photography",
+ ],
+ ),
"Biography & Memoir",
"Education",
- dict(name="Personal Finance & Business", subgenres=[
- "Business",
- "Economics",
- "Management & Leadership",
- "Personal Finance & Investing",
- "Real Estate",
- ]),
- dict(name="Parenting & Family", subgenres=[
- "Family & Relationships",
- "Parenting",
- ]),
- dict(name="Food & Health", subgenres=[
- "Bartending & Cocktails",
- "Cooking",
- "Health & Diet",
- "Vegetarian & Vegan",
- ]),
- dict(name="History", subgenres=[
- "African History",
- "Ancient History",
- "Asian History",
- "Civil War History",
- "European History",
- "Latin American History",
- "Medieval History",
- "Middle East History",
- "Military History",
- "Modern History",
- "Renaissance & Early Modern History",
- "United States History",
- "World History",
- ]),
- dict(name="Hobbies & Home", subgenres=[
- "Antiques & Collectibles",
- "Crafts & Hobbies",
- "Gardening",
- "Games",
- "House & Home",
- "Pets",
- ]),
+ dict(
+ name="Personal Finance & Business",
+ subgenres=[
+ "Business",
+ "Economics",
+ "Management & Leadership",
+ "Personal Finance & Investing",
+ "Real Estate",
+ ],
+ ),
+ dict(
+ name="Parenting & Family",
+ subgenres=[
+ "Family & Relationships",
+ "Parenting",
+ ],
+ ),
+ dict(
+ name="Food & Health",
+ subgenres=[
+ "Bartending & Cocktails",
+ "Cooking",
+ "Health & Diet",
+ "Vegetarian & Vegan",
+ ],
+ ),
+ dict(
+ name="History",
+ subgenres=[
+ "African History",
+ "Ancient History",
+ "Asian History",
+ "Civil War History",
+ "European History",
+ "Latin American History",
+ "Medieval History",
+ "Middle East History",
+ "Military History",
+ "Modern History",
+ "Renaissance & Early Modern History",
+ "United States History",
+ "World History",
+ ],
+ ),
+ dict(
+ name="Hobbies & Home",
+ subgenres=[
+ "Antiques & Collectibles",
+ "Crafts & Hobbies",
+ "Gardening",
+ "Games",
+ "House & Home",
+ "Pets",
+ ],
+ ),
"Humorous Nonfiction",
- dict(name="Entertainment", subgenres=[
- "Film & TV",
- "Music",
- "Performing Arts",
- ]),
+ dict(
+ name="Entertainment",
+ subgenres=[
+ "Film & TV",
+ "Music",
+ "Performing Arts",
+ ],
+ ),
"Life Strategies",
"Literary Criticism",
"Periodicals",
"Philosophy",
"Political Science",
- dict(name="Reference & Study Aids", subgenres=[
- "Dictionaries",
- "Foreign Language Study",
- "Law",
- "Study Aids",
- ]),
- dict(name="Religion & Spirituality", subgenres=[
- "Body, Mind & Spirit",
- "Buddhism",
- "Christianity",
- "Hinduism",
- "Islam",
- "Judaism",
- ]),
- dict(name="Science & Technology", subgenres=[
- "Computers",
- "Mathematics",
- "Medical",
- "Nature",
- "Psychology",
- "Science",
- "Social Sciences",
- "Technology",
- ]),
+ dict(
+ name="Reference & Study Aids",
+ subgenres=[
+ "Dictionaries",
+ "Foreign Language Study",
+ "Law",
+ "Study Aids",
+ ],
+ ),
+ dict(
+ name="Religion & Spirituality",
+ subgenres=[
+ "Body, Mind & Spirit",
+ "Buddhism",
+ "Christianity",
+ "Hinduism",
+ "Islam",
+ "Judaism",
+ ],
+ ),
+ dict(
+ name="Science & Technology",
+ subgenres=[
+ "Computers",
+ "Mathematics",
+ "Medical",
+ "Nature",
+ "Psychology",
+ "Science",
+ "Social Sciences",
+ "Technology",
+ ],
+ ),
"Self-Help",
"Sports",
"Travel",
@@ -778,7 +826,15 @@ def has_subgenre(self, subgenre):
@property
def variable_name(self):
- return self.name.replace("-", "_").replace(", & ", "_").replace(", ", "_").replace(" & ", "_").replace(" ", "_").replace("/", "_").replace("'", "")
+ return (
+ self.name.replace("-", "_")
+ .replace(", & ", "_")
+ .replace(", ", "_")
+ .replace(" & ", "_")
+ .replace(" ", "_")
+ .replace("/", "_")
+ .replace("'", "")
+ )
@classmethod
def populate(cls, namespace, genres, fiction_source, nonfiction_source):
@@ -786,28 +842,35 @@ def populate(cls, namespace, genres, fiction_source, nonfiction_source):
list of fiction and nonfiction genres.
"""
for source, default_fiction in (
- (fiction_source, True),
- (nonfiction_source, False)):
+ (fiction_source, True),
+ (nonfiction_source, False),
+ ):
for item in source:
subgenres = []
audience_restriction = None
name = item
fiction = default_fiction
if isinstance(item, dict):
- name = item['name']
- subgenres = item.get('subgenres', [])
- audience_restriction = item.get('audience_restriction')
- fiction = item.get('fiction', default_fiction)
+ name = item["name"]
+ subgenres = item.get("subgenres", [])
+ audience_restriction = item.get("audience_restriction")
+ fiction = item.get("fiction", default_fiction)
cls.add_genre(
- namespace, genres, name, subgenres, fiction,
- None, audience_restriction)
+ namespace,
+ genres,
+ name,
+ subgenres,
+ fiction,
+ None,
+ audience_restriction,
+ )
@classmethod
- def add_genre(cls, namespace, genres, name, subgenres, fiction,
- parent, audience_restriction):
- """Create a GenreData object. Add it to a dictionary and a namespace.
- """
+ def add_genre(
+ cls, namespace, genres, name, subgenres, fiction, parent, audience_restriction
+ ):
+ """Create a GenreData object. Add it to a dictionary and a namespace."""
if isinstance(name, tuple):
name, default_fiction = name
default_fiction = None
@@ -817,10 +880,10 @@ def add_genre(cls, namespace, genres, name, subgenres, fiction,
default_audience = parent.audience_restriction
if isinstance(name, dict):
data = name
- subgenres = data.get('subgenres', [])
- name = data['name']
- fiction = data.get('fiction', default_fiction)
- audience_restriction = data.get('audience', default_audience)
+ subgenres = data.get("subgenres", [])
+ name = data["name"]
+ fiction = data.get("fiction", default_fiction)
+ audience_restriction = data.get("audience", default_audience)
if name in genres:
raise ValueError("Duplicate genre name! %s" % name)
@@ -838,14 +901,18 @@ def add_genre(cls, namespace, genres, name, subgenres, fiction,
# Do the same for subgenres.
for sub in subgenres:
- cls.add_genre(namespace, genres, sub, [], fiction,
- genre_data, audience_restriction)
+ cls.add_genre(
+ namespace, genres, sub, [], fiction, genre_data, audience_restriction
+ )
+
genres = dict()
GenreData.populate(globals(), genres, fiction_genres, nonfiction_genres)
+
class Lowercased(str):
"""A lowercased string that remembers its original value."""
+
def __new__(cls, value):
if isinstance(value, Lowercased):
# Nothing to do.
@@ -853,7 +920,7 @@ def __new__(cls, value):
if not isinstance(value, str):
value = str(value)
new_value = value.lower()
- if new_value.endswith('.'):
+ if new_value.endswith("."):
new_value = new_value[:-1]
o = super(Lowercased, cls).__new__(cls, new_value)
o.original = value
@@ -864,8 +931,8 @@ def scrub_identifier(cls, identifier):
if not identifier:
return identifier
-class AgeOrGradeClassifier(Classifier):
+class AgeOrGradeClassifier(Classifier):
@classmethod
def audience(cls, identifier, name):
audience = AgeClassifier.audience(identifier, name)
@@ -886,6 +953,7 @@ def target_age(cls, identifier, name):
age = GradeLevelClassifier.target_age(identifier, name, True)
return age
+
class FreeformAudienceClassifier(AgeOrGradeClassifier):
# NOTE: In practice, subjects like "books for all ages" tend to be
# more like advertising slogans than reliable indicators of an
@@ -895,33 +963,36 @@ class FreeformAudienceClassifier(AgeOrGradeClassifier):
@classmethod
def audience(cls, identifier, name):
- if identifier in ('children', 'pre-adolescent', 'beginning reader'):
+ if identifier in ("children", "pre-adolescent", "beginning reader"):
return cls.AUDIENCE_CHILDREN
- elif identifier in ('young adult', 'ya', 'teenagers', 'adolescent',
- 'early adolescents'):
+ elif identifier in (
+ "young adult",
+ "ya",
+ "teenagers",
+ "adolescent",
+ "early adolescents",
+ ):
return cls.AUDIENCE_YOUNG_ADULT
- elif identifier == 'adult':
+ elif identifier == "adult":
return cls.AUDIENCE_ADULT
- elif identifier == 'adults only':
+ elif identifier == "adults only":
return cls.AUDIENCE_ADULTS_ONLY
- elif identifier == 'all ages':
+ elif identifier == "all ages":
return cls.AUDIENCE_ALL_AGES
- elif identifier == 'research':
+ elif identifier == "research":
return cls.AUDIENCE_RESEARCH
return AgeOrGradeClassifier.audience(identifier, name)
@classmethod
def target_age(cls, identifier, name):
- if identifier == 'beginning reader':
- return cls.range_tuple(5,8)
- if identifier == 'pre-adolescent':
+ if identifier == "beginning reader":
+ return cls.range_tuple(5, 8)
+ if identifier == "pre-adolescent":
return cls.range_tuple(9, 12)
- if identifier == 'early adolescents':
+ if identifier == "early adolescents":
return cls.range_tuple(13, 15)
- if identifier == 'all ages':
- return cls.range_tuple(
- cls.ALL_AGES_AGE_CUTOFF, None
- )
+ if identifier == "all ages":
+ return cls.range_tuple(cls.ALL_AGES_AGE_CUTOFF, None)
strict_age = AgeClassifier.target_age(identifier, name, True)
if strict_age[0] or strict_age[1]:
return strict_age
@@ -939,52 +1010,56 @@ class WorkClassifier(object):
# TODO: This needs a lot of additions.
genre_publishers = {
- "Harlequin" : Romance,
- "Pocket Books/Star Trek" : Media_Tie_in_SF,
- "Kensington" : Urban_Fiction,
- "Fodor's Travel Publications" : Travel,
- "Marvel Entertainment, LLC" : Comics_Graphic_Novels,
+ "Harlequin": Romance,
+ "Pocket Books/Star Trek": Media_Tie_in_SF,
+ "Kensington": Urban_Fiction,
+ "Fodor's Travel Publications": Travel,
+ "Marvel Entertainment, LLC": Comics_Graphic_Novels,
}
genre_imprints = {
- "Harlequin Intrigue" : Romantic_Suspense,
- "Love Inspired Suspense" : Romantic_Suspense,
- "Harlequin Historical" : Historical_Romance,
- "Harlequin Historical Undone" : Historical_Romance,
- "Frommers" : Travel,
+ "Harlequin Intrigue": Romantic_Suspense,
+ "Love Inspired Suspense": Romantic_Suspense,
+ "Harlequin Historical": Historical_Romance,
+ "Harlequin Historical Undone": Historical_Romance,
+ "Frommers": Travel,
"LucasBooks": Media_Tie_in_SF,
}
audience_imprints = {
- "Harlequin Teen" : Classifier.AUDIENCE_YOUNG_ADULT,
- "HarperTeen" : Classifier.AUDIENCE_YOUNG_ADULT,
- "Open Road Media Teen & Tween" : Classifier.AUDIENCE_YOUNG_ADULT,
- "Rosen Young Adult" : Classifier.AUDIENCE_YOUNG_ADULT,
+ "Harlequin Teen": Classifier.AUDIENCE_YOUNG_ADULT,
+ "HarperTeen": Classifier.AUDIENCE_YOUNG_ADULT,
+ "Open Road Media Teen & Tween": Classifier.AUDIENCE_YOUNG_ADULT,
+ "Rosen Young Adult": Classifier.AUDIENCE_YOUNG_ADULT,
}
- not_adult_publishers = set([
- "Scholastic Inc.",
- "Random House Children's Books",
- "Little, Brown Books for Young Readers",
- "Penguin Young Readers Group",
- "Hachette Children's Books",
- "Nickelodeon Publishing",
- ])
-
- not_adult_imprints = set([
- "Scholastic",
- "Scholastic Paperbacks",
- "Random House Books for Young Readers",
- "HMH Books for Young Readers",
- "Knopf Books for Young Readers",
- "Delacorte Books for Young Readers",
- "Open Road Media Young Readers",
- "Macmillan Young Listeners",
- "Bloomsbury Childrens",
- "NYR Children's Collection",
- "Bloomsbury USA Childrens",
- "National Geographic Children's Books",
- ])
+ not_adult_publishers = set(
+ [
+ "Scholastic Inc.",
+ "Random House Children's Books",
+ "Little, Brown Books for Young Readers",
+ "Penguin Young Readers Group",
+ "Hachette Children's Books",
+ "Nickelodeon Publishing",
+ ]
+ )
+
+ not_adult_imprints = set(
+ [
+ "Scholastic",
+ "Scholastic Paperbacks",
+ "Random House Books for Young Readers",
+ "HMH Books for Young Readers",
+ "Knopf Books for Young Readers",
+ "Delacorte Books for Young Readers",
+ "Open Road Media Young Readers",
+ "Macmillan Young Listeners",
+ "Bloomsbury Childrens",
+ "NYR Children's Collection",
+ "Bloomsbury USA Childrens",
+ "National Geographic Children's Books",
+ ]
+ )
fiction_imprints = set(["Del Rey"])
nonfiction_imprints = set(["Harlequin Nonfiction"])
@@ -1036,7 +1111,7 @@ def add(self, classification):
self.classifications.append(classification)
# Make sure the Subject is ready to be used in calculations.
- if not classification.subject.checked: # or self.debug
+ if not classification.subject.checked: # or self.debug
classification.subject.assign_to_genre()
if classification.comes_from_license_source:
@@ -1061,7 +1136,11 @@ def add(self, classification):
# if classification is genre or NONE from staff, ignore all non-staff genres
is_genre = subject.genre != None
- is_none = (from_staff and subject.type == Subject.SIMPLIFIED_GENRE and subject.identifier == SimplifiedGenreClassifier.NONE)
+ is_none = (
+ from_staff
+ and subject.type == Subject.SIMPLIFIED_GENRE
+ and subject.identifier == SimplifiedGenreClassifier.NONE
+ )
if is_genre or is_none:
if not from_staff and self.using_staff_genres:
return
@@ -1100,8 +1179,10 @@ def add(self, classification):
# weight this way, we're also going to treat this
# classification as evidence _against_ an 'adult'
# classification.
- self.audience_weights[Classifier.AUDIENCE_YOUNG_ADULT] += (weight * 0.6)
- self.audience_weights[Classifier.AUDIENCE_CHILDREN] += (weight * 0.4)
+ self.audience_weights[Classifier.AUDIENCE_YOUNG_ADULT] += (
+ weight * 0.6
+ )
+ self.audience_weights[Classifier.AUDIENCE_CHILDREN] += weight * 0.4
for audience in Classifier.AUDIENCES_ADULT:
if audience != Classifier.AUDIENCE_ALL_AGES:
# 'All Ages' is considered an adult audience,
@@ -1132,9 +1213,12 @@ def add(self, classification):
self.target_age_upper_weights[target_max] += scaled_weight
if not self.using_staff_audience and not self.using_staff_target_age:
- if subject.type=='Overdrive' and subject.audience==Classifier.AUDIENCE_CHILDREN:
+ if (
+ subject.type == "Overdrive"
+ and subject.audience == Classifier.AUDIENCE_CHILDREN
+ ):
if subject.target_age and (
- subject.target_age.lower or subject.target_age.upper
+ subject.target_age.lower or subject.target_age.upper
):
# This is a juvenile classification like "Picture
# Books" which implies a target age.
@@ -1152,20 +1236,21 @@ def weigh_metadata(self):
This is basic stuff, like: Harlequin tends to publish
romances.
"""
- if self.work.title and ('Star Trek:' in self.work.title
- or 'Star Wars:' in self.work.title
- or ('Jedi' in self.work.title
- and self.work.imprint=='Del Rey')
+ if self.work.title and (
+ "Star Trek:" in self.work.title
+ or "Star Wars:" in self.work.title
+ or ("Jedi" in self.work.title and self.work.imprint == "Del Rey")
):
self.weigh_genre(Media_Tie_in_SF, 100)
publisher = self.work.publisher
imprint = self.work.imprint
- if (imprint in self.nonfiction_imprints
- or publisher in self.nonfiction_publishers):
+ if (
+ imprint in self.nonfiction_imprints
+ or publisher in self.nonfiction_publishers
+ ):
self.fiction_weights[False] = 100
- elif (imprint in self.fiction_imprints
- or publisher in self.fiction_publishers):
+ elif imprint in self.fiction_imprints or publisher in self.fiction_publishers:
self.fiction_weights[True] = 100
if imprint in self.genre_imprints:
@@ -1175,10 +1260,13 @@ def weigh_metadata(self):
if imprint in self.audience_imprints:
self.audience_weights[self.audience_imprints[imprint]] += 100
- elif (publisher in self.not_adult_publishers
- or imprint in self.not_adult_imprints):
- for audience in [Classifier.AUDIENCE_ADULT,
- Classifier.AUDIENCE_ADULTS_ONLY]:
+ elif (
+ publisher in self.not_adult_publishers or imprint in self.not_adult_imprints
+ ):
+ for audience in [
+ Classifier.AUDIENCE_ADULT,
+ Classifier.AUDIENCE_ADULTS_ONLY,
+ ]:
self.audience_weights[audience] -= 100
def prepare_to_classify(self):
@@ -1190,17 +1278,22 @@ def prepare_to_classify(self):
explicitly_indicated_audiences = (
Classifier.AUDIENCE_CHILDREN,
Classifier.AUDIENCE_YOUNG_ADULT,
- Classifier.AUDIENCE_ADULTS_ONLY)
+ Classifier.AUDIENCE_ADULTS_ONLY,
+ )
audiences_from_license_source = set(
- [classification.subject.audience
- for classification in self.direct_from_license_source]
+ [
+ classification.subject.audience
+ for classification in self.direct_from_license_source
+ ]
)
- if (self.direct_from_license_source
+ if (
+ self.direct_from_license_source
and not self.using_staff_audience
and not any(
audience in explicitly_indicated_audiences
for audience in audiences_from_license_source
- )):
+ )
+ ):
# If this was erotica, or a book for children or young
# adults, the distributor would have given some indication
# of that fact. In the absense of any such indication, we
@@ -1213,8 +1306,10 @@ def prepare_to_classify(self):
# classifications.
self.audience_weights[Classifier.AUDIENCE_ADULT] += 500
- if (self.overdrive_juvenile_generic
- and not self.overdrive_juvenile_with_target_age):
+ if (
+ self.overdrive_juvenile_generic
+ and not self.overdrive_juvenile_with_target_age
+ ):
# This book is classified under 'Juvenile Fiction' but not
# under 'Picture Books' or 'Beginning Readers'. The
# implicit target age here is 9-12 (the portion of
@@ -1313,12 +1408,16 @@ def audience(self, genres=[], default_audience=None):
# If the 'children' weight passes the threshold on its own
# we go with 'children'.
total_juvenile_weight = children_weight + ya_weight
- if (research_weight > (total_adult_weight + all_ages_weight) and
- research_weight > (total_juvenile_weight + all_ages_weight) and
- research_weight > threshold):
+ if (
+ research_weight > (total_adult_weight + all_ages_weight)
+ and research_weight > (total_juvenile_weight + all_ages_weight)
+ and research_weight > threshold
+ ):
audience = Classifier.AUDIENCE_RESEARCH
- elif (all_ages_weight > total_adult_weight and
- all_ages_weight > total_juvenile_weight):
+ elif (
+ all_ages_weight > total_adult_weight
+ and all_ages_weight > total_juvenile_weight
+ ):
audience = Classifier.AUDIENCE_ALL_AGES
elif children_weight > threshold and children_weight > ya_weight:
audience = Classifier.AUDIENCE_CHILDREN
@@ -1336,8 +1435,10 @@ def audience(self, genres=[], default_audience=None):
# weight, classify as 'adults only' to be safe.
#
# TODO: This has not been calibrated.
- if (audience==Classifier.AUDIENCE_ADULT
- and adults_only_weight > total_adult_weight/4):
+ if (
+ audience == Classifier.AUDIENCE_ADULT
+ and adults_only_weight > total_adult_weight / 4
+ ):
audience = Classifier.AUDIENCE_ADULTS_ONLY
return audience
@@ -1363,7 +1464,8 @@ def top_tier_values(self, counter):
def target_age(self, audience):
"""Derive a target age from the gathered data."""
if audience not in (
- Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT
+ Classifier.AUDIENCE_CHILDREN,
+ Classifier.AUDIENCE_YOUNG_ADULT,
):
# This is not a children's or YA book. Assertions about
# target age are irrelevant and the default value rules.
@@ -1458,16 +1560,14 @@ def weigh_genre(self, genre_data, weight):
self.genre_weights[genre] += weight
@classmethod
- def consolidate_genre_weights(
- cls, weights, subgenre_swallows_parent_at=0.03
- ):
+ def consolidate_genre_weights(cls, weights, subgenre_swallows_parent_at=0.03):
"""If a genre and its subgenres both show up, examine the subgenre
with the highest weight. If its weight exceeds a certain
proportion of the weight of the parent genre, assign the
parent's weight to the subgenre and remove the parent.
"""
- #print("Before consolidation:")
- #for genre, weight in weights.items():
+ # print("Before consolidation:")
+ # for genre, weight in weights.items():
# print("", genre, weight)
# Convert Genre objects to GenreData.
@@ -1481,11 +1581,12 @@ def consolidate_genre_weights(
for genre, weight in list(consolidated.items()):
for parent in genre.parents:
if parent in consolidated:
- if ((not parent in heaviest_child)
- or weight > heaviest_child[parent][1]):
+ if (not parent in heaviest_child) or weight > heaviest_child[
+ parent
+ ][1]:
heaviest_child[parent] = (genre, weight)
- #print("Heaviest child:")
- #for parent, (genre, weight) in heaviest_child.items():
+ # print("Heaviest child:")
+ # for parent, (genre, weight) in heaviest_child.items():
# print("", parent, genre, weight)
made_it = False
while not made_it:
@@ -1505,39 +1606,33 @@ def consolidate_genre_weights(
break
# We made it all the way through the dict without changing it.
made_it = True
- #print("Final heaviest child:")
- #for parent, (genre, weight) in heaviest_child.items():
+ # print("Final heaviest child:")
+ # for parent, (genre, weight) in heaviest_child.items():
# print("", parent, genre, weight)
- #print("After consolidation:")
- #for genre, weight in consolidated.items():
+ # print("After consolidation:")
+ # for genre, weight in consolidated.items():
# print("", genre, weight)
return consolidated
+
# Make a dictionary of classification schemes to classifiers.
Classifier.classifiers[Classifier.FREEFORM_AUDIENCE] = FreeformAudienceClassifier
Classifier.classifiers[Classifier.AXIS_360_AUDIENCE] = AgeOrGradeClassifier
# Finally, import classifiers described in submodules.
-from .age import (
- GradeLevelClassifier,
- InterestLevelClassifier,
- AgeClassifier,
-)
+from .age import AgeClassifier, GradeLevelClassifier, InterestLevelClassifier
+from .bic import BICClassifier
from .bisac import BISACClassifier
from .ddc import DeweyDecimalClassifier
-from .lcc import LCCClassifier
from .gutenberg import GutenbergBookshelfClassifier
-from .bic import BICClassifier
-from .simplified import (
- SimplifiedFictionClassifier,
- SimplifiedGenreClassifier,
-)
-from .overdrive import OverdriveClassifier
from .keyword import (
+ Eg,
+ FASTClassifier,
KeywordBasedClassifier,
LCSHClassifier,
- FASTClassifier,
TAGClassifier,
- Eg,
)
+from .lcc import LCCClassifier
+from .overdrive import OverdriveClassifier
+from .simplified import SimplifiedFictionClassifier, SimplifiedGenreClassifier
diff --git a/classifier/age.py b/classifier/age.py
index 6608cec8f..244967d3e 100644
--- a/classifier/age.py
+++ b/classifier/age.py
@@ -2,53 +2,52 @@
from . import Classifier
+
class GradeLevelClassifier(Classifier):
# How old a kid is when they start grade N in the US.
american_grade_to_age = {
# Preschool: 3-4 years
- 'preschool' : 3,
- 'pre-school' : 3,
- 'p' : 3,
- 'pk' : 4,
-
+ "preschool": 3,
+ "pre-school": 3,
+ "p": 3,
+ "pk": 4,
# Easy readers
- 'kindergarten' : 5,
- 'k' : 5,
- '0' : 5,
- 'first' : 6,
- '1' : 6,
- 'second' : 7,
- '2' : 7,
-
+ "kindergarten": 5,
+ "k": 5,
+ "0": 5,
+ "first": 6,
+ "1": 6,
+ "second": 7,
+ "2": 7,
# Chapter Books
- 'third' : 8,
- '3' : 8,
- 'fourth' : 9,
- '4' : 9,
- 'fifth' : 10,
- '5' : 10,
- 'sixth' : 11,
- '6' : 11,
- '7' : 12,
- 'seventh' : 12,
- '8' : 13,
- 'eighth' : 13,
-
+ "third": 8,
+ "3": 8,
+ "fourth": 9,
+ "4": 9,
+ "fifth": 10,
+ "5": 10,
+ "sixth": 11,
+ "6": 11,
+ "7": 12,
+ "seventh": 12,
+ "8": 13,
+ "eighth": 13,
# YA
- '9' : 14,
- 'ninth' : 14,
- '10' : 15,
- 'tenth': 15,
- '11' : 16,
- 'eleventh' : 17,
- '12': 17,
- 'twelfth': 17,
+ "9": 14,
+ "ninth": 14,
+ "10": 15,
+ "tenth": 15,
+ "11": 16,
+ "eleventh": 17,
+ "12": 17,
+ "twelfth": 17,
}
# Regular expressions that match common ways of expressing grade
# levels.
grade_res = [
- re.compile(x, re.I) for x in [
+ re.compile(x, re.I)
+ for x in [
"grades? ([kp0-9]+) to ([kp0-9]+)?",
"grades? ([kp0-9]+) ?-? ?([kp0-9]+)?",
"gr\.? ([kp0-9]+) ?-? ?([kp0-9]+)?",
@@ -57,7 +56,7 @@ class GradeLevelClassifier(Classifier):
"gr\.? ([kp0-9]+)",
"([0-9]+)[tnsr][hdt] grade",
"([a-z]+) grade",
- r'\b(kindergarten|preschool)\b',
+ r"\b(kindergarten|preschool)\b",
]
]
@@ -73,15 +72,14 @@ def audience(cls, identifier, name, require_explicit_age_marker=False):
target_age = cls.target_age(identifier, name, require_explicit_age_marker)
return cls.default_audience_for_target_age(target_age)
-
@classmethod
def target_age(cls, identifier, name, require_explicit_grade_marker=False):
- if (identifier and "education" in identifier) or (name and 'education' in name):
+ if (identifier and "education" in identifier) or (name and "education" in name):
# This is a book about teaching, e.g. fifth grade.
return cls.range_tuple(None, None)
- if (identifier and 'grader' in identifier) or (name and 'grader' in name):
+ if (identifier and "grader" in identifier) or (name and "grader" in name):
# This is a book about, e.g. fifth graders.
return cls.range_tuple(None, None)
@@ -104,9 +102,9 @@ def target_age(cls, identifier, name, require_explicit_grade_marker=False):
young, old = gr
# Strip leading zeros
- if young and young.lstrip('0'):
+ if young and young.lstrip("0"):
young = young.lstrip("0")
- if old and old.lstrip('0'):
+ if old and old.lstrip("0"):
old = old.lstrip("0")
young = cls.american_grade_to_age.get(young)
@@ -125,7 +123,7 @@ def target_age(cls, identifier, name, require_explicit_grade_marker=False):
old = young
if young is None and old is not None:
young = old
- if old and young and old < young:
+ if old and young and old < young:
young, old = old, young
return cls.range_tuple(young, old)
return cls.range_tuple(None, None)
@@ -143,32 +141,33 @@ def target_age_match(cls, query):
break
return (target_age, grade_words)
-class InterestLevelClassifier(Classifier):
+class InterestLevelClassifier(Classifier):
@classmethod
def audience(cls, identifier, name):
- if identifier in ('lg', 'mg+', 'mg'):
+ if identifier in ("lg", "mg+", "mg"):
return cls.AUDIENCE_CHILDREN
- elif identifier == 'ug':
+ elif identifier == "ug":
return cls.AUDIENCE_YOUNG_ADULT
else:
return None
@classmethod
def target_age(cls, identifier, name):
- if identifier == 'lg':
- return cls.range_tuple(5,8)
- if identifier in ('mg+', 'mg'):
- return cls.range_tuple(9,13)
- if identifier == 'ug':
- return cls.range_tuple(14,17)
+ if identifier == "lg":
+ return cls.range_tuple(5, 8)
+ if identifier in ("mg+", "mg"):
+ return cls.range_tuple(9, 13)
+ if identifier == "ug":
+ return cls.range_tuple(14, 17)
return None
class AgeClassifier(Classifier):
# Regular expressions that match common ways of expressing ages.
age_res = [
- re.compile(x, re.I) for x in [
+ re.compile(x, re.I)
+ for x in [
"age ([0-9]+) ?-? ?([0-9]+)?",
"age: ([0-9]+) ?-? ?([0-9]+)?",
"age: ([0-9]+) to ([0-9]+)",
@@ -230,7 +229,7 @@ def target_age(cls, identifier, name, require_explicit_age_marker=False):
if young > 99:
# This is not an age at all.
young = None
- if (young is not None and old is not None and young > old):
+ if young is not None and old is not None and young > old:
young, old = old, young
return cls.range_tuple(young, old)
return cls.range_tuple(None, None)
@@ -248,6 +247,7 @@ def target_age_match(cls, query):
break
return (target_age, age_words)
+
Classifier.classifiers[Classifier.AGE_RANGE] = AgeClassifier
Classifier.classifiers[Classifier.GRADE_LEVEL] = GradeLevelClassifier
Classifier.classifiers[Classifier.INTEREST_LEVEL] = InterestLevelClassifier
diff --git a/classifier/bic.py b/classifier/bic.py
index c59b61864..bd4631712 100644
--- a/classifier/bic.py
+++ b/classifier/bic.py
@@ -1,119 +1,125 @@
from . import *
+
class BICClassifier(Classifier):
# These prefixes came from from http://editeur.dyndns.org/bic_categories
LEVEL_1_PREFIXES = {
- Art_Design: 'A',
- Biography_Memoir: 'B',
- Foreign_Language_Study: 'C',
- Literary_Criticism: 'D',
- Reference_Study_Aids: 'G',
- Social_Sciences: 'J',
- Personal_Finance_Business: 'K',
- Law: 'L',
- Medical: 'M',
- Science_Technology: 'P',
- Technology: 'T',
- Computers: 'U',
+ Art_Design: "A",
+ Biography_Memoir: "B",
+ Foreign_Language_Study: "C",
+ Literary_Criticism: "D",
+ Reference_Study_Aids: "G",
+ Social_Sciences: "J",
+ Personal_Finance_Business: "K",
+ Law: "L",
+ Medical: "M",
+ Science_Technology: "P",
+ Technology: "T",
+ Computers: "U",
}
LEVEL_2_PREFIXES = {
- Art_History: 'AC',
- Photography: 'AJ',
- Design: 'AK',
- Architecture: 'AM',
- Film_TV: 'AP',
- Performing_Arts: 'AS',
- Music: 'AV',
- Poetry: 'DC',
- Drama: 'DD',
- Classics: 'FC',
- Mystery: 'FF',
- Suspense_Thriller: 'FH',
- Adventure: 'FJ',
- Horror: 'FK',
- Science_Fiction: 'FL',
- Fantasy: 'FM',
- Erotica: 'FP',
- Romance: 'FR',
- Historical_Fiction: 'FV',
- Religious_Fiction: 'FW',
- Comics_Graphic_Novels: 'FX',
- History: 'HB',
- Philosophy: 'HP',
- Religion_Spirituality: 'HR',
- Psychology: 'JM',
- Education: 'JN',
- Political_Science: 'JP',
- Economics: 'KC',
- Business: 'KJ',
- Mathematics: 'PB',
- Science: 'PD',
- Self_Help: 'VS',
- Body_Mind_Spirit: 'VX',
- Food_Health: 'WB',
- Antiques_Collectibles: 'WC',
- Crafts_Hobbies: 'WF',
- Humorous_Nonfiction: 'WH',
- House_Home: 'WK',
- Gardening: 'WM',
- Nature: 'WN',
- Sports: 'WS',
- Travel: 'WT',
+ Art_History: "AC",
+ Photography: "AJ",
+ Design: "AK",
+ Architecture: "AM",
+ Film_TV: "AP",
+ Performing_Arts: "AS",
+ Music: "AV",
+ Poetry: "DC",
+ Drama: "DD",
+ Classics: "FC",
+ Mystery: "FF",
+ Suspense_Thriller: "FH",
+ Adventure: "FJ",
+ Horror: "FK",
+ Science_Fiction: "FL",
+ Fantasy: "FM",
+ Erotica: "FP",
+ Romance: "FR",
+ Historical_Fiction: "FV",
+ Religious_Fiction: "FW",
+ Comics_Graphic_Novels: "FX",
+ History: "HB",
+ Philosophy: "HP",
+ Religion_Spirituality: "HR",
+ Psychology: "JM",
+ Education: "JN",
+ Political_Science: "JP",
+ Economics: "KC",
+ Business: "KJ",
+ Mathematics: "PB",
+ Science: "PD",
+ Self_Help: "VS",
+ Body_Mind_Spirit: "VX",
+ Food_Health: "WB",
+ Antiques_Collectibles: "WC",
+ Crafts_Hobbies: "WF",
+ Humorous_Nonfiction: "WH",
+ House_Home: "WK",
+ Gardening: "WM",
+ Nature: "WN",
+ Sports: "WS",
+ Travel: "WT",
}
LEVEL_3_PREFIXES = {
- Historical_Mystery: 'FFH',
- Espionage: 'FHD',
- Westerns: 'FJW',
- Space_Opera: 'FLS',
- Historical_Romance: 'FRH',
- Short_Stories: 'FYB',
- World_History: 'HBG',
- Military_History: 'HBW',
- Christianity: 'HRC',
- Buddhism: 'HRE',
- Hinduism: 'HRG',
- Islam: 'HRH',
- Judaism: 'HRJ',
- Fashion: 'WJF',
- Poetry: 'YDP',
- Adventure: 'YFC',
- Horror: 'YFD',
- Science_Fiction: 'YFG',
- Fantasy: 'YFH',
- Romance: 'YFM',
- Humorous_Fiction: 'YFQ',
- Historical_Fiction: 'YFT',
- Comics_Graphic_Novels: 'YFW',
- Art: 'YNA',
- Music: 'YNC',
- Performing_Arts: 'YND',
- Film_TV: 'YNF',
- History: 'YNH',
- Nature: 'YNN',
- Religion_Spirituality: 'YNR',
- Science_Technology: 'YNT',
- Humorous_Nonfiction: 'YNU',
- Sports: 'YNW',
+ Historical_Mystery: "FFH",
+ Espionage: "FHD",
+ Westerns: "FJW",
+ Space_Opera: "FLS",
+ Historical_Romance: "FRH",
+ Short_Stories: "FYB",
+ World_History: "HBG",
+ Military_History: "HBW",
+ Christianity: "HRC",
+ Buddhism: "HRE",
+ Hinduism: "HRG",
+ Islam: "HRH",
+ Judaism: "HRJ",
+ Fashion: "WJF",
+ Poetry: "YDP",
+ Adventure: "YFC",
+ Horror: "YFD",
+ Science_Fiction: "YFG",
+ Fantasy: "YFH",
+ Romance: "YFM",
+ Humorous_Fiction: "YFQ",
+ Historical_Fiction: "YFT",
+ Comics_Graphic_Novels: "YFW",
+ Art: "YNA",
+ Music: "YNC",
+ Performing_Arts: "YND",
+ Film_TV: "YNF",
+ History: "YNH",
+ Nature: "YNN",
+ Religion_Spirituality: "YNR",
+ Science_Technology: "YNT",
+ Humorous_Nonfiction: "YNU",
+ Sports: "YNW",
}
LEVEL_4_PREFIXES = {
- European_History: 'HBJD',
- Asian_History: 'HBJF',
- African_History: 'HBJH',
- Ancient_History: 'HBLA',
- Modern_History: 'HBLL',
- Drama: 'YNDS',
- Comics_Graphic_Novels: 'YNUC',
+ European_History: "HBJD",
+ Asian_History: "HBJF",
+ African_History: "HBJH",
+ Ancient_History: "HBLA",
+ Modern_History: "HBLL",
+ Drama: "YNDS",
+ Comics_Graphic_Novels: "YNUC",
}
- PREFIX_LISTS = [LEVEL_4_PREFIXES, LEVEL_3_PREFIXES, LEVEL_2_PREFIXES, LEVEL_1_PREFIXES]
+ PREFIX_LISTS = [
+ LEVEL_4_PREFIXES,
+ LEVEL_3_PREFIXES,
+ LEVEL_2_PREFIXES,
+ LEVEL_1_PREFIXES,
+ ]
@classmethod
def is_fiction(cls, identifier, name):
- if identifier.startswith('f') or identifier.startswith('yf'):
+ if identifier.startswith("f") or identifier.startswith("yf"):
return True
return False
@@ -133,4 +139,5 @@ def genre(cls, identifier, name, fiction=None, audience=None):
return l
return None
+
Classifier.classifiers[Classifier.BIC] = BICClassifier
diff --git a/classifier/bisac.py b/classifier/bisac.py
index 542caffd5..97da1813c 100644
--- a/classifier/bisac.py
+++ b/classifier/bisac.py
@@ -3,37 +3,47 @@
import os
import re
import string
+
from . import *
from .keyword import KeywordBasedClassifier
+
class CustomMatchToken(object):
"""A custom token used in matching rules."""
+
def matches(self, subject_token):
"""Does the given token match this one?"""
raise NotImplementedError()
+
class Something(CustomMatchToken):
"""A CustomMatchToken that will match any single token."""
+
def matches(self, subject_token):
return True
+
class RE(CustomMatchToken):
"""A CustomMatchToken that performs a regular expression search."""
+
def __init__(self, pattern):
self.re = re.compile(pattern, re.I)
def matches(self, subject_token):
return self.re.search(subject_token)
+
class Interchangeable(CustomMatchToken):
"""A token that matches a list of strings."""
+
def __init__(self, *choices):
"""All of these strings are interchangeable for matching purposes."""
self.choices = set([Lowercased(x) for x in choices])
- def matches(self,subject_token):
+ def matches(self, subject_token):
return Lowercased(subject_token) in self.choices
+
# Special tokens for use in matching rules.
something = Something()
fiction = Interchangeable("Juvenile Fiction", "Young Adult Fiction", "Fiction")
@@ -60,10 +70,13 @@ def matches(self,subject_token):
# If these variables are used in a rule, they must be the first token in
# that rule.
-special_variables = { nonfiction : "nonfiction",
- fiction : "fiction",
- juvenile : "juvenile",
- ya : "ya",}
+special_variables = {
+ nonfiction: "nonfiction",
+ fiction: "fiction",
+ juvenile: "juvenile",
+ ya: "ya",
+}
+
class MatchingRule(object):
"""A rule that takes a list of subject parts and returns
@@ -170,9 +183,7 @@ def _consume(self, rules, subject):
# rule token.
while subject:
subject_token = subject.pop(0)
- submatch, ignore1, ignore2 = self._consume(
- [next_rule], [subject_token]
- )
+ submatch, ignore1, ignore2 = self._consume([next_rule], [subject_token])
if submatch:
# We had to remove some number of subject tokens,
# but we found one that matches the next rule.
@@ -193,10 +204,13 @@ def _consume(self, rules, subject):
# This is too complex to be a CustomMatchToken because
# we may be modifying the subject token list.
match = subject_token not in (
- 'juvenile fiction', 'young adult fiction', 'fiction'
+ "juvenile fiction",
+ "young adult fiction",
+ "fiction",
)
if match and subject_token not in (
- 'juvenile nonfiction', 'young adult nonfiction'
+ "juvenile nonfiction",
+ "young adult nonfiction",
):
# The implicit top-level lane is 'nonfiction',
# which means we popped a token like 'History' that
@@ -272,269 +286,275 @@ class BISACClassifier(Classifier):
]
TARGET_AGE = [
- m((0,4), juvenile, anything, "Readers", "Beginner") ,
- m((5,7), juvenile, anything, "Readers", "Intermediate"),
- m((5,7), juvenile, anything, "Early Readers"),
- m((8,13), juvenile, anything, "Chapter Books")
+ m((0, 4), juvenile, anything, "Readers", "Beginner"),
+ m((5, 7), juvenile, anything, "Readers", "Intermediate"),
+ m((5, 7), juvenile, anything, "Early Readers"),
+ m((8, 13), juvenile, anything, "Chapter Books"),
]
GENRE = [
-
# Put all erotica in Erotica, to keep the other lanes at
# "Adult" level or lower.
- m(Erotica, anything, 'Erotica'),
-
+ m(Erotica, anything, "Erotica"),
# Put all non-erotica comics into the same bucket, regardless
# of their content.
- m(Comics_Graphic_Novels, 'Comics & Graphic Novels'),
- m(Comics_Graphic_Novels, nonfiction, 'Comics & Graphic Novels'),
- m(Comics_Graphic_Novels, fiction, 'Comics & Graphic Novels'),
-
+ m(Comics_Graphic_Novels, "Comics & Graphic Novels"),
+ m(Comics_Graphic_Novels, nonfiction, "Comics & Graphic Novels"),
+ m(Comics_Graphic_Novels, fiction, "Comics & Graphic Novels"),
# "Literary Criticism / Foo" implies Literary Criticism, not Foo.
- m(Literary_Criticism, anything, literary_criticism),
-
+ m(Literary_Criticism, anything, literary_criticism),
# "Fiction / Christian / Foo" implies Religious Fiction
# more strongly than it implies Foo.
- m(Religious_Fiction, fiction, anything, 'Christian'),
-
+ m(Religious_Fiction, fiction, anything, "Christian"),
# "Fiction / Foo / Short Stories" implies Short Stories more
# strongly than it implies Foo. This assumes that a short
# story collection within a genre will also be classified
# separately under that genre. This could definitely be
# improved but would require a Subject to map to multiple
# Genres.
- m(Short_Stories, fiction, anything, RE('^Anthologies')),
- m(Short_Stories, fiction, anything, RE('^Short Stories')),
- m(Short_Stories, 'Literary Collections'),
- m(Short_Stories, fiction, anything, 'Collections & Anthologies'),
-
+ m(Short_Stories, fiction, anything, RE("^Anthologies")),
+ m(Short_Stories, fiction, anything, RE("^Short Stories")),
+ m(Short_Stories, "Literary Collections"),
+ m(Short_Stories, fiction, anything, "Collections & Anthologies"),
# Classify top-level fiction categories into fiction genres.
#
# First, handle large overarching genres that have subgenres
# and adjacent genres.
#
-
# Fantasy
- m(Epic_Fantasy, fiction, 'Fantasy', 'Epic'),
- m(Historical_Fantasy, fiction, 'Fantasy', 'Historical'),
- m(Urban_Fantasy, fiction, 'Fantasy', 'Urban'),
- m(Fantasy, fiction, 'Fantasy'),
- m(Fantasy, fiction, 'Romance', 'Fantasy'),
- m(Fantasy, fiction, 'Sagas'),
-
+ m(Epic_Fantasy, fiction, "Fantasy", "Epic"),
+ m(Historical_Fantasy, fiction, "Fantasy", "Historical"),
+ m(Urban_Fantasy, fiction, "Fantasy", "Urban"),
+ m(Fantasy, fiction, "Fantasy"),
+ m(Fantasy, fiction, "Romance", "Fantasy"),
+ m(Fantasy, fiction, "Sagas"),
# Mystery
# n.b. no BISAC for Paranormal_Mystery
- m(Crime_Detective_Stories, fiction, 'Mystery & Detective', 'Private Investigators'),
- m(Crime_Detective_Stories, fiction, 'Crime'),
- m(Crime_Detective_Stories, fiction, 'Thrillers', 'Crime'),
- m(Hard_Boiled_Mystery, fiction, 'Mystery & Detective', 'Hard-Boiled'),
- m(Police_Procedural, fiction, 'Mystery & Detective', 'Police Procedural'),
- m(Cozy_Mystery, fiction, 'Mystery & Detective', 'Cozy'),
- m(Historical_Mystery, fiction, 'Mystery & Detective', 'Historical'),
- m(Women_Detectives, fiction, 'Mystery & Detective', 'Women Sleuths'),
- m(Mystery, fiction, anything, 'Mystery & Detective'),
-
+ m(
+ Crime_Detective_Stories,
+ fiction,
+ "Mystery & Detective",
+ "Private Investigators",
+ ),
+ m(Crime_Detective_Stories, fiction, "Crime"),
+ m(Crime_Detective_Stories, fiction, "Thrillers", "Crime"),
+ m(Hard_Boiled_Mystery, fiction, "Mystery & Detective", "Hard-Boiled"),
+ m(Police_Procedural, fiction, "Mystery & Detective", "Police Procedural"),
+ m(Cozy_Mystery, fiction, "Mystery & Detective", "Cozy"),
+ m(Historical_Mystery, fiction, "Mystery & Detective", "Historical"),
+ m(Women_Detectives, fiction, "Mystery & Detective", "Women Sleuths"),
+ m(Mystery, fiction, anything, "Mystery & Detective"),
# Horror
- m(Ghost_Stories, fiction, 'Ghost'),
- m(Occult_Horror, fiction, 'Occult & Supernatural'),
- m(Gothic_Horror, fiction, 'Gothic'),
- m(Horror, fiction, 'Horror'),
-
+ m(Ghost_Stories, fiction, "Ghost"),
+ m(Occult_Horror, fiction, "Occult & Supernatural"),
+ m(Gothic_Horror, fiction, "Gothic"),
+ m(Horror, fiction, "Horror"),
# Romance
# n.b. no BISAC for Gothic Romance
- m(Contemporary_Romance, fiction, 'Romance', 'Contemporary'),
- m(Historical_Romance, fiction, 'Romance', 'Historical'),
- m(Paranormal_Romance, fiction, 'Romance', 'Paranormal'),
- m(Western_Romance, fiction, 'Romance', 'Western'),
- m(Romantic_Suspense, fiction, 'Romance', 'Suspense'),
- m(Romantic_SF, fiction, 'Romance', 'Time Travel'),
- m(Romantic_SF, fiction, 'Romance', 'Science Fiction'),
- m(Romance, fiction, 'Romance'),
-
+ m(Contemporary_Romance, fiction, "Romance", "Contemporary"),
+ m(Historical_Romance, fiction, "Romance", "Historical"),
+ m(Paranormal_Romance, fiction, "Romance", "Paranormal"),
+ m(Western_Romance, fiction, "Romance", "Western"),
+ m(Romantic_Suspense, fiction, "Romance", "Suspense"),
+ m(Romantic_SF, fiction, "Romance", "Time Travel"),
+ m(Romantic_SF, fiction, "Romance", "Science Fiction"),
+ m(Romance, fiction, "Romance"),
# Science fiction
# n.b. no BISAC for Cyberpunk
- m(Dystopian_SF, fiction, 'Dystopian'),
- m(Space_Opera, fiction, 'Science Fiction', 'Space Opera'),
- m(Military_SF, fiction, 'Science Fiction', 'Military'),
- m(Alternative_History, fiction, 'Alternative History'),
+ m(Dystopian_SF, fiction, "Dystopian"),
+ m(Space_Opera, fiction, "Science Fiction", "Space Opera"),
+ m(Military_SF, fiction, "Science Fiction", "Military"),
+ m(Alternative_History, fiction, "Alternative History"),
# Juvenile steampunk is classified directly beneath 'fiction'.
- m(Steampunk, fiction, anything, 'Steampunk'),
- m(Science_Fiction, fiction, 'Science Fiction'),
-
+ m(Steampunk, fiction, anything, "Steampunk"),
+ m(Science_Fiction, fiction, "Science Fiction"),
# Thrillers
# n.b. no BISAC for Supernatural_Thriller
- m(Historical_Thriller, fiction, 'Thrillers', 'Historical'),
- m(Espionage, fiction, 'Thrillers', 'Espionage'),
- m(Medical_Thriller, fiction, 'Thrillers', 'Medical'),
- m(Political_Thriller, fiction, 'Thrillers', 'Political'),
- m(Legal_Thriller, fiction, 'Thrillers', 'Legal'),
- m(Technothriller, fiction, 'Thrillers', 'Technological'),
- m(Military_Thriller, fiction, 'Thrillers', 'Military'),
- m(Suspense_Thriller, fiction, 'Thrillers'),
-
+ m(Historical_Thriller, fiction, "Thrillers", "Historical"),
+ m(Espionage, fiction, "Thrillers", "Espionage"),
+ m(Medical_Thriller, fiction, "Thrillers", "Medical"),
+ m(Political_Thriller, fiction, "Thrillers", "Political"),
+ m(Legal_Thriller, fiction, "Thrillers", "Legal"),
+ m(Technothriller, fiction, "Thrillers", "Technological"),
+ m(Military_Thriller, fiction, "Thrillers", "Military"),
+ m(Suspense_Thriller, fiction, "Thrillers"),
# Then handle the less complicated genres of fiction.
- m(Adventure, fiction, 'Action & Adventure'),
- m(Adventure, fiction, 'Sea Stories'),
- m(Adventure, fiction, 'War & Military'),
- m(Classics, fiction, 'Classics'),
- m(Folklore, fiction, 'Fairy Tales, Folk Tales, Legends & Mythology'),
- m(Historical_Fiction, anything, 'Historical'),
- m(Humorous_Fiction, fiction, 'Humorous'),
- m(Humorous_Fiction, fiction, 'Satire'),
- m(Literary_Fiction, fiction, 'Literary'),
- m(LGBTQ_Fiction, fiction, 'Gay'),
- m(LGBTQ_Fiction, fiction, 'Lesbian'),
- m(LGBTQ_Fiction, fiction, 'Gay & Lesbian'),
- m(Religious_Fiction, fiction, 'Religious'),
- m(Religious_Fiction, fiction, 'Jewish'),
- m(Religious_Fiction, fiction, 'Visionary & Metaphysical'),
- m(Womens_Fiction, fiction, anything, 'Contemporary Women'),
- m(Westerns, fiction, 'Westerns'),
-
+ m(Adventure, fiction, "Action & Adventure"),
+ m(Adventure, fiction, "Sea Stories"),
+ m(Adventure, fiction, "War & Military"),
+ m(Classics, fiction, "Classics"),
+ m(Folklore, fiction, "Fairy Tales, Folk Tales, Legends & Mythology"),
+ m(Historical_Fiction, anything, "Historical"),
+ m(Humorous_Fiction, fiction, "Humorous"),
+ m(Humorous_Fiction, fiction, "Satire"),
+ m(Literary_Fiction, fiction, "Literary"),
+ m(LGBTQ_Fiction, fiction, "Gay"),
+ m(LGBTQ_Fiction, fiction, "Lesbian"),
+ m(LGBTQ_Fiction, fiction, "Gay & Lesbian"),
+ m(Religious_Fiction, fiction, "Religious"),
+ m(Religious_Fiction, fiction, "Jewish"),
+ m(Religious_Fiction, fiction, "Visionary & Metaphysical"),
+ m(Womens_Fiction, fiction, anything, "Contemporary Women"),
+ m(Westerns, fiction, "Westerns"),
# n.b. BISAC "Fiction / Urban" is distinct from "Fiction /
# African-American / Urban", and does not map to any of our
# genres.
- m(Urban_Fiction, fiction, 'African American', 'Urban'),
-
+ m(Urban_Fiction, fiction, "African American", "Urban"),
# BISAC classifies these genres at the top level, which we
# treat as 'nonfiction', but we classify them as fiction. It
# doesn't matter because they're neither, really.
- m(Drama, nonfiction, 'Drama'),
- m(Poetry, nonfiction, 'Poetry'),
-
+ m(Drama, nonfiction, "Drama"),
+ m(Poetry, nonfiction, "Poetry"),
# Now on to nonfiction.
-
# Classify top-level nonfiction categories into fiction genres.
#
# First, handle large overarching genres that have subgenres
# and adjacent genres.
#
-
# Art & Design
- m(Architecture, nonfiction, 'Architecture'),
- m(Art_Criticism_Theory, nonfiction, 'Art', 'Criticism & Theory'),
- m(Art_History, nonfiction, 'Art', 'History'),
- m(Fashion, nonfiction, 'Design', 'Fashion'),
- m(Design, nonfiction, 'Design'),
- m(Art_Design, nonfiction, 'Art'),
- m(Photography, nonfiction, 'Photography'),
-
+ m(Architecture, nonfiction, "Architecture"),
+ m(Art_Criticism_Theory, nonfiction, "Art", "Criticism & Theory"),
+ m(Art_History, nonfiction, "Art", "History"),
+ m(Fashion, nonfiction, "Design", "Fashion"),
+ m(Design, nonfiction, "Design"),
+ m(Art_Design, nonfiction, "Art"),
+ m(Photography, nonfiction, "Photography"),
# Personal Finance & Business
- m(Business, nonfiction, 'Business & Economics', RE('^Business.*')),
- m(Business, nonfiction, 'Business & Economics', 'Accounting'),
- m(Economics, nonfiction, 'Business & Economics', 'Economics'),
-
- m(Economics, nonfiction, 'Business & Economics', 'Environmental Economics'),
- m(Economics, nonfiction, 'Business & Economics', RE('^Econo.*')),
- m(Management_Leadership, nonfiction, 'Business & Economics', 'Management'),
- m(Management_Leadership, nonfiction, 'Business & Economics', 'Management Science'),
- m(Management_Leadership, nonfiction, 'Business & Economics', 'Leadership'),
- m(Personal_Finance_Investing, nonfiction, 'Business & Economics', 'Personal Finance'),
- m(Personal_Finance_Investing, nonfiction, 'Business & Economics', 'Personal Success'),
- m(Personal_Finance_Investing, nonfiction, 'Business & Economics', 'Investments & Securities'),
- m(Real_Estate, nonfiction, 'Business & Economics', 'Real Estate'),
- m(Personal_Finance_Business, nonfiction, 'Business & Economics'),
-
+ m(Business, nonfiction, "Business & Economics", RE("^Business.*")),
+ m(Business, nonfiction, "Business & Economics", "Accounting"),
+ m(Economics, nonfiction, "Business & Economics", "Economics"),
+ m(Economics, nonfiction, "Business & Economics", "Environmental Economics"),
+ m(Economics, nonfiction, "Business & Economics", RE("^Econo.*")),
+ m(Management_Leadership, nonfiction, "Business & Economics", "Management"),
+ m(
+ Management_Leadership,
+ nonfiction,
+ "Business & Economics",
+ "Management Science",
+ ),
+ m(Management_Leadership, nonfiction, "Business & Economics", "Leadership"),
+ m(
+ Personal_Finance_Investing,
+ nonfiction,
+ "Business & Economics",
+ "Personal Finance",
+ ),
+ m(
+ Personal_Finance_Investing,
+ nonfiction,
+ "Business & Economics",
+ "Personal Success",
+ ),
+ m(
+ Personal_Finance_Investing,
+ nonfiction,
+ "Business & Economics",
+ "Investments & Securities",
+ ),
+ m(Real_Estate, nonfiction, "Business & Economics", "Real Estate"),
+ m(Personal_Finance_Business, nonfiction, "Business & Economics"),
# Parenting & Family
- m(Parenting, nonfiction, 'Family & Relationships', 'Parenting'),
- m(Family_Relationships, nonfiction, 'Family & Relationships'),
-
+ m(Parenting, nonfiction, "Family & Relationships", "Parenting"),
+ m(Family_Relationships, nonfiction, "Family & Relationships"),
# Food & Health
- m(Bartending_Cocktails, nonfiction, 'Cooking', 'Beverages'),
- m(Health_Diet, nonfiction, 'Cooking', 'Health & Healing'),
- m(Health_Diet, nonfiction, 'Health & Fitness'),
- m(Vegetarian_Vegan, nonfiction, 'Cooking', 'Vegetarian & Vegan'),
- m(Cooking, nonfiction, 'Cooking'),
-
+ m(Bartending_Cocktails, nonfiction, "Cooking", "Beverages"),
+ m(Health_Diet, nonfiction, "Cooking", "Health & Healing"),
+ m(Health_Diet, nonfiction, "Health & Fitness"),
+ m(Vegetarian_Vegan, nonfiction, "Cooking", "Vegetarian & Vegan"),
+ m(Cooking, nonfiction, "Cooking"),
# History
- m(African_History, nonfiction, 'History', 'Africa'),
- m(Ancient_History, nonfiction, 'History', 'Ancient'),
- m(Asian_History, nonfiction, 'History', 'Asia'),
- m(Civil_War_History, nonfiction, 'History', 'United States', RE('^Civil War')),
- m(European_History, nonfiction, 'History', 'Europe'),
- m(Latin_American_History, nonfiction, 'History', 'Latin America'),
- m(Medieval_History, nonfiction, 'History', 'Medieval'),
- m(Military_History, nonfiction, 'History', 'Military'),
- m(Renaissance_Early_Modern_History, nonfiction, 'History', 'Renaissance'),
- m(Renaissance_Early_Modern_History, nonfiction, 'History', 'Modern', RE('^1[678]th Century')),
- m(Modern_History, nonfiction, 'History', 'Modern'),
- m(United_States_History, nonfiction, 'History', 'Native American'),
- m(United_States_History, nonfiction, 'History', 'United States'),
- m(World_History, nonfiction, 'History', 'World'),
- m(World_History, nonfiction, 'History', 'Civilization'),
- m(History, nonfiction, 'History'),
-
+ m(African_History, nonfiction, "History", "Africa"),
+ m(Ancient_History, nonfiction, "History", "Ancient"),
+ m(Asian_History, nonfiction, "History", "Asia"),
+ m(Civil_War_History, nonfiction, "History", "United States", RE("^Civil War")),
+ m(European_History, nonfiction, "History", "Europe"),
+ m(Latin_American_History, nonfiction, "History", "Latin America"),
+ m(Medieval_History, nonfiction, "History", "Medieval"),
+ m(Military_History, nonfiction, "History", "Military"),
+ m(Renaissance_Early_Modern_History, nonfiction, "History", "Renaissance"),
+ m(
+ Renaissance_Early_Modern_History,
+ nonfiction,
+ "History",
+ "Modern",
+ RE("^1[678]th Century"),
+ ),
+ m(Modern_History, nonfiction, "History", "Modern"),
+ m(United_States_History, nonfiction, "History", "Native American"),
+ m(United_States_History, nonfiction, "History", "United States"),
+ m(World_History, nonfiction, "History", "World"),
+ m(World_History, nonfiction, "History", "Civilization"),
+ m(History, nonfiction, "History"),
# Hobbies & Home
- m(Antiques_Collectibles, nonfiction, 'Antiques & Collectibles'),
- m(Crafts_Hobbies, nonfiction, 'Crafts & Hobbies'),
- m(Gardening, nonfiction, 'Gardening'),
- m(Games, nonfiction, 'Games'),
- m(House_Home, nonfiction, 'House & Home'),
- m(Pets, nonfiction, 'Pets'),
-
+ m(Antiques_Collectibles, nonfiction, "Antiques & Collectibles"),
+ m(Crafts_Hobbies, nonfiction, "Crafts & Hobbies"),
+ m(Gardening, nonfiction, "Gardening"),
+ m(Games, nonfiction, "Games"),
+ m(House_Home, nonfiction, "House & Home"),
+ m(Pets, nonfiction, "Pets"),
# Entertainment
- m(Film_TV, nonfiction, 'Performing Arts', 'Film & Video'),
- m(Film_TV, nonfiction, 'Performing Arts', 'Television'),
- m(Music, nonfiction, 'Music'),
- m(Performing_Arts, nonfiction, 'Performing Arts'),
-
+ m(Film_TV, nonfiction, "Performing Arts", "Film & Video"),
+ m(Film_TV, nonfiction, "Performing Arts", "Television"),
+ m(Music, nonfiction, "Music"),
+ m(Performing_Arts, nonfiction, "Performing Arts"),
# Reference & Study Aids
- m(Dictionaries, nonfiction, 'Reference', 'Dictionaries'),
- m(Foreign_Language_Study, nonfiction, 'Foreign Language Study'),
- m(Law, nonfiction, 'Law'),
- m(Study_Aids, nonfiction, 'Study Aids'),
- m(Reference_Study_Aids, nonfiction, 'Reference'),
- m(Reference_Study_Aids, nonfiction, 'Language Arts & Disciplines'),
-
+ m(Dictionaries, nonfiction, "Reference", "Dictionaries"),
+ m(Foreign_Language_Study, nonfiction, "Foreign Language Study"),
+ m(Law, nonfiction, "Law"),
+ m(Study_Aids, nonfiction, "Study Aids"),
+ m(Reference_Study_Aids, nonfiction, "Reference"),
+ m(Reference_Study_Aids, nonfiction, "Language Arts & Disciplines"),
# Religion & Spirituality
m(Body_Mind_Spirit, nonfiction, body_mind_spirit),
- m(Buddhism, nonfiction, 'Religion', 'Buddhism'),
- m(Christianity, nonfiction, 'Religion', RE('^Biblical')),
- m(Christianity, nonfiction, 'Religion', RE('^Christian')),
- m(Christianity, nonfiction, 'Bibles'),
- m(Hinduism, nonfiction, 'Religion', 'Hinduism'),
- m(Islam, nonfiction, 'Religion', 'Islam'),
- m(Judaism, nonfiction, 'Religion', 'Judaism'),
- m(Religion_Spirituality, nonfiction, 'Religion'),
-
+ m(Buddhism, nonfiction, "Religion", "Buddhism"),
+ m(Christianity, nonfiction, "Religion", RE("^Biblical")),
+ m(Christianity, nonfiction, "Religion", RE("^Christian")),
+ m(Christianity, nonfiction, "Bibles"),
+ m(Hinduism, nonfiction, "Religion", "Hinduism"),
+ m(Islam, nonfiction, "Religion", "Islam"),
+ m(Judaism, nonfiction, "Religion", "Judaism"),
+ m(Religion_Spirituality, nonfiction, "Religion"),
# Science & Technology
- m(Computers, nonfiction, 'Computers'),
- m(Mathematics, nonfiction, 'Mathematics'),
- m(Medical, nonfiction, 'Medical'),
- m(Nature, nonfiction, 'Nature'),
+ m(Computers, nonfiction, "Computers"),
+ m(Mathematics, nonfiction, "Mathematics"),
+ m(Medical, nonfiction, "Medical"),
+ m(Nature, nonfiction, "Nature"),
m(Psychology, nonfiction, psychology),
- m(Political_Science, nonfiction, 'Social Science', 'Politics & Government'),
- m(Social_Sciences, nonfiction, 'Social Science'),
+ m(Political_Science, nonfiction, "Social Science", "Politics & Government"),
+ m(Social_Sciences, nonfiction, "Social Science"),
m(Technology, nonfiction, technology),
- m(Technology, nonfiction, 'Transportation'),
- m(Science, nonfiction, 'Science'),
-
+ m(Technology, nonfiction, "Transportation"),
+ m(Science, nonfiction, "Science"),
# Then handle the less complicated genres of nonfiction.
# n.b. no BISAC for Periodicals.
# n.b. no BISAC for Humorous Nonfiction per se.
- m(Music, nonfiction, 'Biography & Autobiography', 'Composers & Musicians'),
- m(Entertainment, nonfiction, 'Biography & Autobiography', 'Entertainment & Performing Arts'),
- m(Biography_Memoir, nonfiction, 'Biography & Autobiography'),
+ m(Music, nonfiction, "Biography & Autobiography", "Composers & Musicians"),
+ m(
+ Entertainment,
+ nonfiction,
+ "Biography & Autobiography",
+ "Entertainment & Performing Arts",
+ ),
+ m(Biography_Memoir, nonfiction, "Biography & Autobiography"),
m(Education, nonfiction, "Education"),
- m(Philosophy, nonfiction, 'Philosophy'),
- m(Political_Science, nonfiction, 'Political Science'),
- m(Self_Help, nonfiction, 'Self-Help'),
- m(Sports, nonfiction, 'Sports & Recreation'),
- m(Travel, nonfiction, 'Travel'),
- m(True_Crime, nonfiction, 'True Crime'),
-
+ m(Philosophy, nonfiction, "Philosophy"),
+ m(Political_Science, nonfiction, "Political Science"),
+ m(Self_Help, nonfiction, "Self-Help"),
+ m(Sports, nonfiction, "Sports & Recreation"),
+ m(Travel, nonfiction, "Travel"),
+ m(True_Crime, nonfiction, "True Crime"),
# Handle cases where Juvenile/YA uses different terms than
# would be used for the same books for adults.
- m(Business, nonfiction, 'Careers'),
+ m(Business, nonfiction, "Careers"),
m(Christianity, nonfiction, "Religious", "Christian"),
m(Cooking, nonfiction, "Cooking & Food"),
m(Education, nonfiction, "School & Education"),
m(Family_Relationships, nonfiction, "Family"),
m(Fantasy, fiction, "Fantasy & Magic"),
- m(Ghost_Stories, fiction, 'Ghost Stories'),
- m(Fantasy, fiction, 'Magical Realism'),
- m(Fantasy, fiction, 'Mermaids'),
- m(Fashion, nonfiction, 'Fashion'),
+ m(Ghost_Stories, fiction, "Ghost Stories"),
+ m(Fantasy, fiction, "Magical Realism"),
+ m(Fantasy, fiction, "Mermaids"),
+ m(Fashion, nonfiction, "Fashion"),
m(Folklore, fiction, "Fairy Tales & Folklore"),
m(Folklore, fiction, "Legends, Myths, Fables"),
m(Games, nonfiction, "Games & Activities"),
@@ -542,41 +562,38 @@ class BISACClassifier(Classifier):
m(Horror, fiction, "Horror & Ghost Stories"),
m(Horror, fiction, "Monsters"),
m(Horror, fiction, "Paranormal"),
- m(Horror, fiction, 'Paranormal, Occult & Supernatural'),
- m(Horror, fiction, 'Vampires'),
- m(Horror, fiction, 'Werewolves & Shifters'),
- m(Horror, fiction, 'Zombies'),
+ m(Horror, fiction, "Paranormal, Occult & Supernatural"),
+ m(Horror, fiction, "Vampires"),
+ m(Horror, fiction, "Werewolves & Shifters"),
+ m(Horror, fiction, "Zombies"),
m(Humorous_Fiction, fiction, "Humorous Stories"),
m(Humorous_Nonfiction, "Young Adult Nonfiction", "Humor"),
- m(LGBTQ_Fiction, fiction, 'LGBT'),
+ m(LGBTQ_Fiction, fiction, "LGBT"),
m(Law, nonfiction, "Law & Crime"),
m(Mystery, fiction, "Mysteries & Detective Stories"),
m(Nature, nonfiction, "Animals"),
- m(Personal_Finance_Investing, nonfiction, 'Personal Finance'),
+ m(Personal_Finance_Investing, nonfiction, "Personal Finance"),
m(Poetry, fiction, "Nursery Rhymes"),
m(Poetry, fiction, "Stories in Verse"),
- m(Poetry, fiction, 'Novels in Verse'),
- m(Poetry, fiction, 'Poetry'),
+ m(Poetry, fiction, "Novels in Verse"),
+ m(Poetry, fiction, "Poetry"),
m(Reference_Study_Aids, nonfiction, "Language Arts"),
m(Romance, fiction, "Love & Romance"),
m(Science_Fiction, fiction, "Robots"),
m(Science_Fiction, fiction, "Time Travel"),
m(Social_Sciences, nonfiction, "Media Studies"),
- m(Suspense_Thriller, fiction, 'Superheroes'),
- m(Suspense_Thriller, fiction, 'Thrillers & Suspense'),
-
+ m(Suspense_Thriller, fiction, "Superheroes"),
+ m(Suspense_Thriller, fiction, "Thrillers & Suspense"),
# Most of the subcategories of 'Science & Nature' go into Nature,
# but these go into Science.
- m(Science, nonfiction, 'Science & Nature', 'Discoveries'),
- m(Science, nonfiction, 'Science & Nature', 'Experiments & Projects'),
- m(Science, nonfiction, 'Science & Nature', 'History of Science'),
- m(Science, nonfiction, 'Science & Nature', 'Physics'),
- m(Science, nonfiction, 'Science & Nature', 'Weights & Measures'),
- m(Science, nonfiction, 'Science & Nature', 'General'),
-
+ m(Science, nonfiction, "Science & Nature", "Discoveries"),
+ m(Science, nonfiction, "Science & Nature", "Experiments & Projects"),
+ m(Science, nonfiction, "Science & Nature", "History of Science"),
+ m(Science, nonfiction, "Science & Nature", "Physics"),
+ m(Science, nonfiction, "Science & Nature", "Weights & Measures"),
+ m(Science, nonfiction, "Science & Nature", "General"),
# Any other subcategory of 'Science & Nature' goes under Nature
- m(Nature, nonfiction, 'Science & Nature', something),
-
+ m(Nature, nonfiction, "Science & Nature", something),
# Life Strategies is juvenile/YA-specific, and contains both
# fiction and nonfiction. It's called "Social Issues" for
# juvenile fiction/nonfiction, and "Social Topics" for YA
@@ -633,19 +650,17 @@ def genre(cls, identifier, name, fiction, audience):
# If all else fails, try a keyword-based classifier.
keyword = "/".join(name)
- return KeywordBasedClassifier.genre(
- identifier, keyword, fiction, audience
- )
+ return KeywordBasedClassifier.genre(identifier, keyword, fiction, audience)
# A BISAC name copied from the BISAC website may end with this
# human-readable note, which is not part of the official name.
- see_also = re.compile('\(see also .*')
+ see_also = re.compile("\(see also .*")
@classmethod
def scrub_identifier(cls, identifier):
if not identifier:
return identifier
- if identifier.startswith('FB'):
+ if identifier.startswith("FB"):
identifier = identifier[2:]
if identifier in cls.NAMES:
# We know the canonical name for this BISAC identifier,
@@ -670,7 +685,7 @@ def scrub_name(cls, name):
name = name.replace(" ", ", ")
# The name may be enclosed in an extra set of quotes.
- for quote in ("'\""):
+ for quote in "'\"":
if name.startswith(quote):
name = name[1:]
if name.endswith(quote):
@@ -678,23 +693,23 @@ def scrub_name(cls, name):
# The name may end with an extraneous marker character or
# (if it was copied from the BISAC website) an asterisk.
- for separator in '|/*':
+ for separator in "|/*":
if name.endswith(separator):
name = name[:-1]
# A name copied from the BISAC website may end with a
# human-readable cross-reference.
- name = cls.see_also.sub('', name)
+ name = cls.see_also.sub("", name)
# The canonical separator character is a slash, but a pipe
# has also been used.
- for separator in '|/':
+ for separator in "|/":
if separator in name:
- parts = [name.strip() for name in name.split(separator)
- if name.strip()]
+ parts = [name.strip() for name in name.split(separator) if name.strip()]
break
else:
parts = [name]
return parts
+
Classifier.classifiers[Classifier.BISAC] = BISACClassifier
diff --git a/classifier/ddc.py b/classifier/ddc.py
index 4bb0f458e..95073eb0e 100644
--- a/classifier/ddc.py
+++ b/classifier/ddc.py
@@ -1,14 +1,15 @@
import json
import os
+
from . import *
base_dir = os.path.split(__file__)[0]
resource_dir = os.path.join(base_dir, "..", "resources")
+
class DeweyDecimalClassifier(Classifier):
- NAMES = json.load(
- open(os.path.join(resource_dir, "dewey_1000.json")))
+ NAMES = json.load(open(os.path.join(resource_dir, "dewey_1000.json")))
# Add some other values commonly found in MARC records.
NAMES["B"] = "Biography"
@@ -32,53 +33,75 @@ class DeweyDecimalClassifier(Classifier):
# 398.7 Jokes and jests
GENRES = {
- African_History : list(range(960, 970)),
- Architecture : list(range(710, 720)) + list(range(720, 730)),
- Art : list(range(700, 710)) + list(range(730, 770)) + [774, 776],
- Art_Criticism_Theory : [701],
- Asian_History : list(range(950, 960)) + [995, 996, 997],
- Biography_Memoir : ["B", 920],
- Economics : list(range(330, 340)),
- Christianity : [list(range(220, 230)) + list(range(230, 290))],
- Cooking : [list(range(640, 642))],
- Performing_Arts : [790, 791, 792],
- Entertainment : 790,
- Games : [793, 794, 795],
- Drama : [812, 822, 832, 842, 852, 862, 872, 882],
- Education : list(range(370,380)) + [707],
- European_History : list(range(940, 950)),
- Folklore : [398],
- History : [900],
- Islam : [297],
- Judaism : [296],
- Latin_American_History : list(range(981, 990)),
- Law : list(range(340, 350)) + [364],
- Management_Leadership : [658],
- Mathematics : list(range(510, 520)),
- Medical : list(range(610, 620)),
- Military_History : list(range(355, 360)),
- Music : list(range(780, 789)),
- Periodicals : list(range(50, 60)) + [105, 405, 505, 605, 705, 805, 905],
- Philosophy : list(range(160, 200)),
- Photography : [771, 772, 773, 775, 778, 779],
- Poetry : [811, 821, 831, 841, 851, 861, 871, 874, 881, 884],
- Political_Science : list(range(320, 330)) + list(range(351, 355)),
- Psychology : list(range(150, 160)),
- Foreign_Language_Study : list(range(430,500)),
- Reference_Study_Aids : list(range(10, 20)) + list(range(30, 40)) + [103, 203, 303, 403, 503, 603, 703, 803, 903] + list(range(410, 430)),
- Religion_Spirituality : list(range(200, 220)) + [290, 292, 293, 294, 295, 299],
- Science : ([500, 501, 502] + list(range(506, 510)) + list(range(520, 530))
- + list(range(530, 540)) + list(range(540, 550)) + list(range(550, 560))
- + list(range(560, 570)) + list(range(570, 580)) + list(range(580, 590))
- + list(range(590, 600))),
- Social_Sciences : (list(range(300, 310)) + list(range(360, 364)) + list(range(390,397)) + [399]),
- Sports : list(range(796, 800)),
- Technology : (
- [600, 601, 602, 604] + list(range(606, 610)) + list(range(610, 640))
- + list(range(660, 670)) + list(range(670, 680)) + list(range(681, 690)) + list(range(690, 700))),
- Travel : list(range(910, 920)),
- United_States_History : list(range(973,980)),
- World_History : [909],
+ African_History: list(range(960, 970)),
+ Architecture: list(range(710, 720)) + list(range(720, 730)),
+ Art: list(range(700, 710)) + list(range(730, 770)) + [774, 776],
+ Art_Criticism_Theory: [701],
+ Asian_History: list(range(950, 960)) + [995, 996, 997],
+ Biography_Memoir: ["B", 920],
+ Economics: list(range(330, 340)),
+ Christianity: [list(range(220, 230)) + list(range(230, 290))],
+ Cooking: [list(range(640, 642))],
+ Performing_Arts: [790, 791, 792],
+ Entertainment: 790,
+ Games: [793, 794, 795],
+ Drama: [812, 822, 832, 842, 852, 862, 872, 882],
+ Education: list(range(370, 380)) + [707],
+ European_History: list(range(940, 950)),
+ Folklore: [398],
+ History: [900],
+ Islam: [297],
+ Judaism: [296],
+ Latin_American_History: list(range(981, 990)),
+ Law: list(range(340, 350)) + [364],
+ Management_Leadership: [658],
+ Mathematics: list(range(510, 520)),
+ Medical: list(range(610, 620)),
+ Military_History: list(range(355, 360)),
+ Music: list(range(780, 789)),
+ Periodicals: list(range(50, 60)) + [105, 405, 505, 605, 705, 805, 905],
+ Philosophy: list(range(160, 200)),
+ Photography: [771, 772, 773, 775, 778, 779],
+ Poetry: [811, 821, 831, 841, 851, 861, 871, 874, 881, 884],
+ Political_Science: list(range(320, 330)) + list(range(351, 355)),
+ Psychology: list(range(150, 160)),
+ Foreign_Language_Study: list(range(430, 500)),
+ Reference_Study_Aids: list(range(10, 20))
+ + list(range(30, 40))
+ + [103, 203, 303, 403, 503, 603, 703, 803, 903]
+ + list(range(410, 430)),
+ Religion_Spirituality: list(range(200, 220)) + [290, 292, 293, 294, 295, 299],
+ Science: (
+ [500, 501, 502]
+ + list(range(506, 510))
+ + list(range(520, 530))
+ + list(range(530, 540))
+ + list(range(540, 550))
+ + list(range(550, 560))
+ + list(range(560, 570))
+ + list(range(570, 580))
+ + list(range(580, 590))
+ + list(range(590, 600))
+ ),
+ Social_Sciences: (
+ list(range(300, 310))
+ + list(range(360, 364))
+ + list(range(390, 397))
+ + [399]
+ ),
+ Sports: list(range(796, 800)),
+ Technology: (
+ [600, 601, 602, 604]
+ + list(range(606, 610))
+ + list(range(610, 640))
+ + list(range(660, 670))
+ + list(range(670, 680))
+ + list(range(681, 690))
+ + list(range(690, 700))
+ ),
+ Travel: list(range(910, 920)),
+ United_States_History: list(range(973, 980)),
+ World_History: [909],
}
@classmethod
@@ -94,11 +117,11 @@ def scrub_identifier(cls, identifier):
identifier = identifier.upper()
- if identifier.startswith('[') and identifier.endswith(']'):
+ if identifier.startswith("[") and identifier.endswith("]"):
# This is just bad data.
identifier = identifier[1:-1]
- if identifier.startswith('C') or identifier.startswith('A'):
+ if identifier.startswith("C") or identifier.startswith("A"):
# A work from our Canadian neighbors or our Australian
# friends.
identifier = identifier[1:]
@@ -108,8 +131,8 @@ def scrub_identifier(cls, identifier):
# Trim everything after the first period. We don't know how to
# deal with it.
- if '.' in identifier:
- identifier = identifier.split('.')[0]
+ if "." in identifier:
+ identifier = identifier.split(".")[0]
try:
identifier = int(identifier)
except ValueError:
@@ -122,13 +145,14 @@ def scrub_identifier(cls, identifier):
@classmethod
def is_fiction(cls, identifier, name):
"""Is the given DDC classification likely to contain fiction?"""
- if identifier == 'Y':
+ if identifier == "Y":
# Inconsistently used for young adult fiction and
# young adult nonfiction.
return None
- if (isinstance(identifier, (bytes, str)) and (
- identifier.startswith('Y') or identifier.startswith('J'))):
+ if isinstance(identifier, (bytes, str)) and (
+ identifier.startswith("Y") or identifier.startswith("J")
+ ):
# Young adult/children's literature--not necessarily fiction
identifier = identifier[1:]
try:
@@ -148,17 +172,17 @@ def is_fiction(cls, identifier, name):
@classmethod
def audience(cls, identifier, name):
- if identifier == 'E':
+ if identifier == "E":
# Juvenile fiction
return cls.AUDIENCE_CHILDREN
- if isinstance(identifier, (bytes, str)) and identifier.startswith('J'):
+ if isinstance(identifier, (bytes, str)) and identifier.startswith("J"):
return cls.AUDIENCE_CHILDREN
- if isinstance(identifier, (bytes, str)) and identifier.startswith('Y'):
+ if isinstance(identifier, (bytes, str)) and identifier.startswith("Y"):
return cls.AUDIENCE_YOUNG_ADULT
- if isinstance(identifier, (bytes, str)) and identifier=='FIC':
+ if isinstance(identifier, (bytes, str)) and identifier == "FIC":
# FIC is used for all types of fiction.
return None
@@ -170,9 +194,10 @@ def audience(cls, identifier, name):
def genre(cls, identifier, name, fiction=None, audience=None):
for genre, identifiers in list(cls.GENRES.items()):
if identifier == identifiers or (
- isinstance(identifiers, list)
- and identifier in identifiers):
+ isinstance(identifiers, list) and identifier in identifiers
+ ):
return genre
return None
+
Classifier.classifiers[Classifier.DDC] = DeweyDecimalClassifier
diff --git a/classifier/gutenberg.py b/classifier/gutenberg.py
index 5f74b2534..bb3c5d9ac 100644
--- a/classifier/gutenberg.py
+++ b/classifier/gutenberg.py
@@ -2,21 +2,24 @@
from . import *
+
class GutenbergBookshelfClassifier(Classifier):
# Any classification that includes the string "Fiction" will be
# counted as fiction. This is just the leftovers.
- FICTION = set([
- "Bestsellers, American, 1895-1923",
- "Adventure",
- "Fantasy",
- "Horror",
- "Mystery",
- "Western",
- "Suspense",
- "Thriller",
- "Children's Anthologies",
- ])
+ FICTION = set(
+ [
+ "Bestsellers, American, 1895-1923",
+ "Adventure",
+ "Fantasy",
+ "Horror",
+ "Mystery",
+ "Western",
+ "Suspense",
+ "Thriller",
+ "Children's Anthologies",
+ ]
+ )
GENRES = {
Adventure: [
@@ -25,60 +28,60 @@ class GutenbergBookshelfClassifier(Classifier):
],
# African_American : ["African American Writers"],
Ancient_History: ["Classical Antiquity"],
- Architecture : [
+ Architecture: [
"Architecture",
"The American Architect and Building News",
],
- Art : ["Art"],
- Biography_Memoir : [
+ Art: ["Art"],
+ Biography_Memoir: [
"Biographies",
"Children's Biography",
],
- Christianity : ["Christianity"],
+ Christianity: ["Christianity"],
Civil_War_History: "US Civil War",
- Classics : [
+ Classics: [
"Best Books Ever Listings",
"Harvard Classics",
],
- Cooking : [
+ Cooking: [
"Armour's Monthly Cook Book",
"Cookery",
],
- Drama : [
+ Drama: [
"One Act Plays",
"Opera",
"Plays",
],
- Erotica : "Erotic Fiction",
- Fantasy : "Fantasy",
- Foreign_Language_Study : [
+ Erotica: "Erotic Fiction",
+ Fantasy: "Fantasy",
+ Foreign_Language_Study: [
"Language Education",
],
- Gardening : [
+ Gardening: [
"Garden and Forest",
"Horticulture",
],
- Historical_Fiction : "Historical Fiction",
- History : [
+ Historical_Fiction: "Historical Fiction",
+ History: [
"Children's History",
],
- Horror : ["Gothic Fiction", "Horror"],
- Humorous_Fiction : ["Humor"],
- Islam : "Islam",
- Judaism : "Judaism",
- Law : [
+ Horror: ["Gothic Fiction", "Horror"],
+ Humorous_Fiction: ["Humor"],
+ Islam: "Islam",
+ Judaism: "Judaism",
+ Law: [
"British Law",
"Noteworthy Trials",
"United States Law",
],
- Literary_Criticism : ["Bibliomania"],
- Mathematics : "Mathematics",
- Medical : [
+ Literary_Criticism: ["Bibliomania"],
+ Mathematics: "Mathematics",
+ Medical: [
"Medicine",
"The North American Medical and Surgical Journal",
"Physiology",
],
- Military_History : [
+ Military_History: [
"American Revolutionary War",
"World War I",
"World War II",
@@ -87,22 +90,21 @@ class GutenbergBookshelfClassifier(Classifier):
"Napoleonic",
],
Modern_History: "Current History",
- Music : [
+ Music: [
"Music",
"Child's Own Book of Great Musicians",
],
- Mystery : [
+ Mystery: [
"Crime Fiction",
"Detective Fiction",
"Mystery Fiction",
],
- Nature : [
+ Nature: [
"Animal",
"Animals-Wild",
- "Bird-Lore"
- "Birds, Illustrated by Color Photography",
+ "Bird-Lore", "Birds, Illustrated by Color Photography",
],
- Periodicals : [
+ Periodicals: [
"Ainslee's",
"Prairie Farmer",
"Blackwood's Edinburgh Magazine",
@@ -171,31 +173,31 @@ class GutenbergBookshelfClassifier(Classifier):
"The Yellow Book",
"Women's Travel Journals",
],
- Pets : ["Animals-Domestic"],
- Philosophy : ["Philosophy"],
- Photography : "Photography",
- Poetry : [
+ Pets: ["Animals-Domestic"],
+ Philosophy: ["Philosophy"],
+ Photography: "Photography",
+ Poetry: [
"Poetry",
"Poetry, A Magazine of Verse",
"Children's Verse",
],
- Political_Science : [
+ Political_Science: [
"Anarchism",
"Politics",
],
- Psychology : ["Psychology"],
- Reference_Study_Aids : [
+ Psychology: ["Psychology"],
+ Reference_Study_Aids: [
"Reference",
"CIA World Factbooks",
],
- Religion_Spirituality : [
+ Religion_Spirituality: [
"Atheism",
"Bahá'í Faith",
"Hinduism",
"Paganism",
"Children's Religion",
],
- Science : [
+ Science: [
"Astronomy",
"Biology",
"Botany",
@@ -211,30 +213,30 @@ class GutenbergBookshelfClassifier(Classifier):
"Physics",
"Scientific American",
],
- Science_Fiction : [
+ Science_Fiction: [
"Astounding Stories",
"Precursors of Science Fiction",
"The Galaxy",
"Science Fiction",
],
- Social_Sciences : [
+ Social_Sciences: [
"Anthropology",
"Archaeology",
"The American Journal of Archaeology",
"Sociology",
],
- Suspense_Thriller : [
+ Suspense_Thriller: [
"Suspense",
"Thriller",
],
- Technology : [
+ Technology: [
"Engineering",
"Technology",
"Transportation",
],
- Travel : "Travel",
- True_Crime : "Crime Nonfiction",
- Westerns : "Western",
+ Travel: "Travel",
+ True_Crime: "Crime Nonfiction",
+ Westerns: "Western",
}
@classmethod
@@ -243,14 +245,17 @@ def scrub_identifier(cls, identifier):
@classmethod
def is_fiction(cls, identifier, name):
- if (identifier in cls.FICTION
- or "Fiction" in identifier or "Stories" in identifier):
+ if (
+ identifier in cls.FICTION
+ or "Fiction" in identifier
+ or "Stories" in identifier
+ ):
return True
return None
@classmethod
def audience(cls, identifier, name):
- if ("Children's" in identifier):
+ if "Children's" in identifier:
return cls.AUDIENCE_CHILDREN
return cls.AUDIENCE_ADULT
@@ -261,4 +266,5 @@ def genre(cls, identifier, name, fiction=None, audience=None):
return l
return None
+
Classifier.classifiers[Classifier.GUTENBERG_BOOKSHELF] = GutenbergBookshelfClassifier
diff --git a/classifier/keyword.py b/classifier/keyword.py
index bbf31a132..28530b7c7 100644
--- a/classifier/keyword.py
+++ b/classifier/keyword.py
@@ -1,11 +1,13 @@
from . import *
+
def match_kw(*l):
"""Turn a list of strings into a function which uses a regular expression
to match any of those strings, so long as there's a word boundary on both ends.
The function will match all the strings by default, or can exclude the strings
that are examples of the classification.
"""
+
def match_term(term, exclude_examples=False):
if not l:
return None
@@ -17,13 +19,13 @@ def match_term(term, exclude_examples=False):
if not keywords:
return None
any_keyword = "|".join(keywords)
- with_boundaries = r'\b(%s)\b' % any_keyword
+ with_boundaries = r"\b(%s)\b" % any_keyword
return re.compile(with_boundaries, re.I).search(term)
-
# This is a dictionary so it can be used as a class variable
return {"search": match_term}
+
class Eg(object):
"""Mark this string as an example of a classification, rather than
an exact identifier for that classification. For example, basketball
@@ -37,28 +39,39 @@ def __init__(self, term):
def __str__(self):
return self.term
+
class KeywordBasedClassifier(AgeOrGradeClassifier):
"""Classify a book based on keywords."""
# We have to handle these first because otherwise '\bfiction\b'
# will match it.
- LEVEL_1_NONFICTION_INDICATORS = match_kw(
- "non-fiction", "non fiction"
- )
+ LEVEL_1_NONFICTION_INDICATORS = match_kw("non-fiction", "non fiction")
LEVEL_2_FICTION_INDICATORS = match_kw(
- "fiction", Eg("stories"), Eg("tales"), Eg("literature"),
- Eg("bildungsromans"), "fictitious",
+ "fiction",
+ Eg("stories"),
+ Eg("tales"),
+ Eg("literature"),
+ Eg("bildungsromans"),
+ "fictitious",
)
LEVEL_2_NONFICTION_INDICATORS = match_kw(
- Eg("history"), Eg("biography"), Eg("histories"),
- Eg("biographies"), Eg("autobiography"), Eg("autobiographies"),
- "nonfiction", Eg("essays"), Eg("letters"), Eg("true story"),
- Eg("personal memoirs"))
+ Eg("history"),
+ Eg("biography"),
+ Eg("histories"),
+ Eg("biographies"),
+ Eg("autobiography"),
+ Eg("autobiographies"),
+ "nonfiction",
+ Eg("essays"),
+ Eg("letters"),
+ Eg("true story"),
+ Eg("personal memoirs"),
+ )
JUVENILE_INDICATORS = match_kw(
- "for children", "children's", "juvenile",
- Eg("nursery rhymes"), Eg("9-12"))
+ "for children", "children's", "juvenile", Eg("nursery rhymes"), Eg("9-12")
+ )
YOUNG_ADULT_INDICATORS = match_kw(
"young adult",
"ya",
@@ -72,27 +85,31 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
# Children's books don't generally deal with romance, so although
# "Juvenile Fiction" generally refers to children's fiction,
# "Juvenile Fiction / Love & Romance" is almost certainly YA.
- JUVENILE_TERMS_THAT_IMPLY_YOUNG_ADULT = set([
- "love & romance",
- "romance",
- "romantic",
- ])
+ JUVENILE_TERMS_THAT_IMPLY_YOUNG_ADULT = set(
+ [
+ "love & romance",
+ "romance",
+ "romantic",
+ ]
+ )
# These identifiers indicate that the string "children" or
# "juvenile" in the identifier does not actually mean the work is
# _for_ children.
- JUVENILE_BLACKLIST = set([
- "military participation",
- "services",
- "children's accidents",
- "children's voices",
- "juvenile delinquency",
- "children's television workshop",
- "missing children",
- ])
+ JUVENILE_BLACKLIST = set(
+ [
+ "military participation",
+ "services",
+ "children's accidents",
+ "children's voices",
+ "juvenile delinquency",
+ "children's television workshop",
+ "missing children",
+ ]
+ )
CATCHALL_KEYWORDS = {
- Adventure : match_kw(
+ Adventure: match_kw(
"adventure",
"adventurers",
"adventure stories",
@@ -102,32 +119,27 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
Eg("war stories"),
Eg("men's adventure"),
),
-
African_History: match_kw(
"african history",
"history.*africa",
),
-
Ancient_History: match_kw(
"ancient.*history",
"history.*ancient",
"civilization, classical",
),
-
Antiques_Collectibles: match_kw(
"antiques",
"collectibles",
"collectors",
"collecting",
),
-
Architecture: match_kw(
"architecture",
"architectural",
"architect",
"architects",
),
-
Art: match_kw(
"art",
"arts",
@@ -135,22 +147,18 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
"artists",
"artistic",
),
-
Art_Criticism_Theory: match_kw(
"art criticism",
"art / criticism & theory",
),
-
Art_History: match_kw(
"art.*history",
),
-
Asian_History: match_kw(
"asian history",
"history.*asia",
"australasian & pacific history",
),
-
Bartending_Cocktails: match_kw(
"cocktail",
"cocktails",
@@ -161,224 +169,200 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
Eg("wine & spirits"),
"spirits & cocktails",
),
-
- Biography_Memoir : match_kw(
- "autobiographies",
- "autobiography",
- "biographies",
- "biography",
- "biographical",
- "personal memoirs",
- ),
-
- Body_Mind_Spirit: match_kw(
- "body, mind & spirit",
- ),
-
- Buddhism: match_kw(
- "buddhism",
- "buddhist",
- "buddha",
- ),
-
- Business: match_kw(
- "business",
- "businesspeople",
- "businesswomen",
- "businessmen",
- "business & economics",
- "business & financial",
- "commerce",
- "sales",
- "selling",
- "sales & selling",
- Eg("nonprofit"),
- ),
-
- Christianity : match_kw(
- Eg("schema:creativework:bible"),
- Eg("baptist"),
- Eg("bible"),
- Eg("sermons"),
- Eg("devotional"),
- Eg("theological"),
- Eg("theology"),
- Eg('biblical'),
- "christian",
- "christianity",
- Eg("catholic"),
- Eg("protestant"),
- Eg("catholicism"),
- Eg("protestantism"),
- Eg("church"),
- Eg("christmas & advent"),
- ),
-
- Civil_War_History: match_kw(
- "american civil war",
- "1861-1865",
- "civil war period",
- ),
-
- Classics: match_kw(
- 'classics',
- ),
-
- Computers : match_kw(
- "computer",
- "computer science",
- "computational",
- "computers",
- "computing",
- Eg("data"),
- Eg("database"),
- Eg("hardware"),
- Eg("software"),
- Eg("software development"),
- Eg("information technology"),
- Eg("web"),
- Eg("world wide web"),
- ),
-
- Contemporary_Romance: match_kw(
- "contemporary romance",
- "romance--contemporary",
- "romance / contemporary",
- "romance - contemporary",
- ),
-
- Cooking : match_kw(
- Eg("non-alcoholic"),
- Eg("baking"),
- "cookbook",
- "cooking",
- "food",
- Eg("health & healing"),
- "home economics",
- "cuisine",
- ),
-
- Crafts_Hobbies: match_kw(
- "arts & crafts",
- "arts, crafts",
- Eg("beadwork"),
- Eg("candle crafts"),
- Eg("candle making"),
- Eg("carving"),
- Eg("ceramics"),
- "crafts & hobbies",
- "crafts",
- Eg("crochet"),
- Eg("crocheting"),
- Eg("cross-stitch"),
- "decorative arts",
- Eg("flower arranging"),
- "folkcrafts",
- "handicrafts",
- "hobbies",
- "hobby",
- "hobbyist",
- "hobbyists",
- Eg("jewelry"),
- Eg("knitting"),
- Eg("metal work"),
- Eg("needlework"),
- Eg("origami"),
- Eg("paper crafts"),
- Eg("pottery"),
- Eg("quilting"),
- Eg("quilts"),
- Eg("scrapbooking"),
- Eg("sewing"),
- Eg("soap making"),
- Eg("stamping"),
- Eg("stenciling"),
- Eg("textile crafts"),
- Eg("toymaking"),
- Eg("weaving"),
- Eg("woodwork"),
- ),
-
- Design: match_kw(
- "design",
- "designer",
- "designers",
- Eg("graphic design"),
- Eg("typography")
- ),
-
- Dictionaries: match_kw(
- "dictionaries",
- "dictionary",
- ),
-
- Drama : match_kw(
- Eg("comedies"),
- "drama",
- "dramatist",
- "dramatists",
- Eg("operas"),
- Eg("plays"),
- Eg("shakespeare"),
- Eg("tragedies"),
- Eg("tragedy"),
- ),
-
- Economics: match_kw(
- Eg("banking"),
- "economy",
- "economies",
- "economic",
- "economics",
- ),
-
- Education: match_kw(
- # TODO: a lot of these don't work well because of
- # the huge amount of fiction about students. This
- # will be fixed when we institute the
- # fiction/nonfiction split.
- "education",
- "educational",
- "educator",
- "educators",
- Eg("principals"),
- "teacher",
- "teachers",
- "teaching",
- #"schools",
- #"high school",
- "schooling",
- #"student",
- #"students",
- #"college",
- Eg("university"),
- Eg("universities"),
- ),
-
- Epic_Fantasy: match_kw(
- "epic fantasy",
- "fantasy - epic",
- "fantasy / epic",
- "fantasy--epic",
- "fantasy/epic",
- ),
-
- Espionage: match_kw(
- "espionage",
- "intrigue",
- "spies",
- "spy stories",
- "spy novels",
- "spy fiction",
- "spy thriller",
- ),
-
- Erotica : match_kw(
- 'erotic',
- 'erotica',
- ),
-
- # TODO: history _plus_ a place
+ Biography_Memoir: match_kw(
+ "autobiographies",
+ "autobiography",
+ "biographies",
+ "biography",
+ "biographical",
+ "personal memoirs",
+ ),
+ Body_Mind_Spirit: match_kw(
+ "body, mind & spirit",
+ ),
+ Buddhism: match_kw(
+ "buddhism",
+ "buddhist",
+ "buddha",
+ ),
+ Business: match_kw(
+ "business",
+ "businesspeople",
+ "businesswomen",
+ "businessmen",
+ "business & economics",
+ "business & financial",
+ "commerce",
+ "sales",
+ "selling",
+ "sales & selling",
+ Eg("nonprofit"),
+ ),
+ Christianity: match_kw(
+ Eg("schema:creativework:bible"),
+ Eg("baptist"),
+ Eg("bible"),
+ Eg("sermons"),
+ Eg("devotional"),
+ Eg("theological"),
+ Eg("theology"),
+ Eg("biblical"),
+ "christian",
+ "christianity",
+ Eg("catholic"),
+ Eg("protestant"),
+ Eg("catholicism"),
+ Eg("protestantism"),
+ Eg("church"),
+ Eg("christmas & advent"),
+ ),
+ Civil_War_History: match_kw(
+ "american civil war",
+ "1861-1865",
+ "civil war period",
+ ),
+ Classics: match_kw(
+ "classics",
+ ),
+ Computers: match_kw(
+ "computer",
+ "computer science",
+ "computational",
+ "computers",
+ "computing",
+ Eg("data"),
+ Eg("database"),
+ Eg("hardware"),
+ Eg("software"),
+ Eg("software development"),
+ Eg("information technology"),
+ Eg("web"),
+ Eg("world wide web"),
+ ),
+ Contemporary_Romance: match_kw(
+ "contemporary romance",
+ "romance--contemporary",
+ "romance / contemporary",
+ "romance - contemporary",
+ ),
+ Cooking: match_kw(
+ Eg("non-alcoholic"),
+ Eg("baking"),
+ "cookbook",
+ "cooking",
+ "food",
+ Eg("health & healing"),
+ "home economics",
+ "cuisine",
+ ),
+ Crafts_Hobbies: match_kw(
+ "arts & crafts",
+ "arts, crafts",
+ Eg("beadwork"),
+ Eg("candle crafts"),
+ Eg("candle making"),
+ Eg("carving"),
+ Eg("ceramics"),
+ "crafts & hobbies",
+ "crafts",
+ Eg("crochet"),
+ Eg("crocheting"),
+ Eg("cross-stitch"),
+ "decorative arts",
+ Eg("flower arranging"),
+ "folkcrafts",
+ "handicrafts",
+ "hobbies",
+ "hobby",
+ "hobbyist",
+ "hobbyists",
+ Eg("jewelry"),
+ Eg("knitting"),
+ Eg("metal work"),
+ Eg("needlework"),
+ Eg("origami"),
+ Eg("paper crafts"),
+ Eg("pottery"),
+ Eg("quilting"),
+ Eg("quilts"),
+ Eg("scrapbooking"),
+ Eg("sewing"),
+ Eg("soap making"),
+ Eg("stamping"),
+ Eg("stenciling"),
+ Eg("textile crafts"),
+ Eg("toymaking"),
+ Eg("weaving"),
+ Eg("woodwork"),
+ ),
+ Design: match_kw(
+ "design", "designer", "designers", Eg("graphic design"), Eg("typography")
+ ),
+ Dictionaries: match_kw(
+ "dictionaries",
+ "dictionary",
+ ),
+ Drama: match_kw(
+ Eg("comedies"),
+ "drama",
+ "dramatist",
+ "dramatists",
+ Eg("operas"),
+ Eg("plays"),
+ Eg("shakespeare"),
+ Eg("tragedies"),
+ Eg("tragedy"),
+ ),
+ Economics: match_kw(
+ Eg("banking"),
+ "economy",
+ "economies",
+ "economic",
+ "economics",
+ ),
+ Education: match_kw(
+ # TODO: a lot of these don't work well because of
+ # the huge amount of fiction about students. This
+ # will be fixed when we institute the
+ # fiction/nonfiction split.
+ "education",
+ "educational",
+ "educator",
+ "educators",
+ Eg("principals"),
+ "teacher",
+ "teachers",
+ "teaching",
+ # "schools",
+ # "high school",
+ "schooling",
+ # "student",
+ # "students",
+ # "college",
+ Eg("university"),
+ Eg("universities"),
+ ),
+ Epic_Fantasy: match_kw(
+ "epic fantasy",
+ "fantasy - epic",
+ "fantasy / epic",
+ "fantasy--epic",
+ "fantasy/epic",
+ ),
+ Espionage: match_kw(
+ "espionage",
+ "intrigue",
+ "spies",
+ "spy stories",
+ "spy novels",
+ "spy fiction",
+ "spy thriller",
+ ),
+ Erotica: match_kw(
+ "erotic",
+ "erotica",
+ ),
+ # TODO: history _plus_ a place
European_History: match_kw(
"europe.*history",
"history.*europe",
@@ -392,159 +376,139 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
Eg("history.*germany"),
# etc. etc. etc.
),
-
- Family_Relationships: match_kw(
- "family & relationships",
- "relationships",
- "family relationships",
- "human sexuality",
- "sexuality",
- ),
-
- Fantasy : match_kw(
- "fantasy",
- Eg("magic"),
- Eg("wizards"),
- Eg("fairies"),
- Eg("witches"),
- Eg("dragons"),
- Eg("sorcery"),
- Eg("witchcraft"),
- Eg("wizardry"),
- Eg("unicorns"),
- ),
-
- Fashion: match_kw(
- "fashion",
- "fashion design",
- "fashion designers",
- ),
-
- Film_TV: match_kw(
- Eg("director"),
- Eg("directors"),
- "film",
- "films",
- "movies",
- "movie",
- "motion picture",
- "motion pictures",
- "moviemaker",
- "moviemakers",
- Eg("producer"),
- Eg("producers"),
- "television",
- "tv",
- "video",
- ),
-
- Foreign_Language_Study: match_kw(
- Eg("english as a foreign language"),
- Eg("english as a second language"),
- Eg("esl"),
- "foreign language study",
- Eg("multi-language dictionaries"),
- ),
-
- Games : match_kw(
- "games",
- Eg("video games"),
- "gaming",
- Eg("gambling"),
- ),
-
- Gardening: match_kw(
- "gardening",
- "horticulture",
- ),
-
- Comics_Graphic_Novels: match_kw(
- "comics",
- "comic strip",
- "comic strips",
- "comic book",
- "comic books",
- "graphic novel",
- "graphic novels",
-
- # Formerly in 'Superhero'
- Eg("superhero"),
- Eg("superheroes"),
-
- # Formerly in 'Manga'
- Eg("japanese comic books"),
- Eg("japanese comics"),
- Eg("manga"),
- Eg("yaoi"),
-
- ),
-
- Hard_Boiled_Mystery: match_kw(
- "hard-boiled",
- "noir",
- ),
-
- Health_Diet: match_kw(
- # ! "health services" ?
- "fitness",
- "health",
- "health aspects",
- "health & fitness",
- "hygiene",
- "nutrition",
- "diet",
- "diets",
- "weight loss",
- ),
-
- Hinduism: match_kw(
- "hinduism",
- "hindu",
- "hindus",
- ),
-
- Historical_Fiction : match_kw(
- "historical fiction",
- "fiction.*historical",
- "^historical$",
- ),
-
- Historical_Romance: match_kw(
- "historical romance",
- Eg("regency romance"),
- Eg("romance.*regency"),
- ),
-
- History : match_kw(
- "histories",
- "history",
- "historiography",
- "historical period",
- Eg("pre-confederation"),
- ),
-
- Horror : match_kw(
- "horror",
- Eg("occult"),
- Eg("ghost"),
- Eg("ghost stories"),
- Eg("vampires"),
- Eg("paranormal fiction"),
- Eg("occult fiction"),
- Eg("supernatural"),
- "scary",
- ),
-
- House_Home: match_kw(
- "house and home",
- "house & home",
- Eg("remodeling"),
- Eg("renovation"),
- Eg("caretaking"),
- Eg("interior decorating"),
- ),
-
- Humorous_Fiction : match_kw(
+ Family_Relationships: match_kw(
+ "family & relationships",
+ "relationships",
+ "family relationships",
+ "human sexuality",
+ "sexuality",
+ ),
+ Fantasy: match_kw(
+ "fantasy",
+ Eg("magic"),
+ Eg("wizards"),
+ Eg("fairies"),
+ Eg("witches"),
+ Eg("dragons"),
+ Eg("sorcery"),
+ Eg("witchcraft"),
+ Eg("wizardry"),
+ Eg("unicorns"),
+ ),
+ Fashion: match_kw(
+ "fashion",
+ "fashion design",
+ "fashion designers",
+ ),
+ Film_TV: match_kw(
+ Eg("director"),
+ Eg("directors"),
+ "film",
+ "films",
+ "movies",
+ "movie",
+ "motion picture",
+ "motion pictures",
+ "moviemaker",
+ "moviemakers",
+ Eg("producer"),
+ Eg("producers"),
+ "television",
+ "tv",
+ "video",
+ ),
+ Foreign_Language_Study: match_kw(
+ Eg("english as a foreign language"),
+ Eg("english as a second language"),
+ Eg("esl"),
+ "foreign language study",
+ Eg("multi-language dictionaries"),
+ ),
+ Games: match_kw(
+ "games",
+ Eg("video games"),
+ "gaming",
+ Eg("gambling"),
+ ),
+ Gardening: match_kw(
+ "gardening",
+ "horticulture",
+ ),
+ Comics_Graphic_Novels: match_kw(
+ "comics",
+ "comic strip",
+ "comic strips",
+ "comic book",
+ "comic books",
+ "graphic novel",
+ "graphic novels",
+ # Formerly in 'Superhero'
+ Eg("superhero"),
+ Eg("superheroes"),
+ # Formerly in 'Manga'
+ Eg("japanese comic books"),
+ Eg("japanese comics"),
+ Eg("manga"),
+ Eg("yaoi"),
+ ),
+ Hard_Boiled_Mystery: match_kw(
+ "hard-boiled",
+ "noir",
+ ),
+ Health_Diet: match_kw(
+ # ! "health services" ?
+ "fitness",
+ "health",
+ "health aspects",
+ "health & fitness",
+ "hygiene",
+ "nutrition",
+ "diet",
+ "diets",
+ "weight loss",
+ ),
+ Hinduism: match_kw(
+ "hinduism",
+ "hindu",
+ "hindus",
+ ),
+ Historical_Fiction: match_kw(
+ "historical fiction",
+ "fiction.*historical",
+ "^historical$",
+ ),
+ Historical_Romance: match_kw(
+ "historical romance",
+ Eg("regency romance"),
+ Eg("romance.*regency"),
+ ),
+ History: match_kw(
+ "histories",
+ "history",
+ "historiography",
+ "historical period",
+ Eg("pre-confederation"),
+ ),
+ Horror: match_kw(
+ "horror",
+ Eg("occult"),
+ Eg("ghost"),
+ Eg("ghost stories"),
+ Eg("vampires"),
+ Eg("paranormal fiction"),
+ Eg("occult fiction"),
+ Eg("supernatural"),
+ "scary",
+ ),
+ House_Home: match_kw(
+ "house and home",
+ "house & home",
+ Eg("remodeling"),
+ Eg("renovation"),
+ Eg("caretaking"),
+ Eg("interior decorating"),
+ ),
+ Humorous_Fiction: match_kw(
"comedy",
"funny",
"humor",
@@ -554,7 +518,7 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
Eg("satire"),
"wit",
),
- Humorous_Nonfiction : match_kw(
+ Humorous_Nonfiction: match_kw(
"comedy",
"funny",
"humor",
@@ -563,146 +527,134 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
"humourous",
"wit",
),
-
- Entertainment: match_kw(
- # Almost a pure top-level category
- "entertainment",
- ),
-
- # These might be a problem because they might pick up
+ Entertainment: match_kw(
+ # Almost a pure top-level category
+ "entertainment",
+ ),
+ # These might be a problem because they might pick up
# hateful books. Not sure if this will be a problem.
- Islam : match_kw(
- 'islam', 'islamic', 'muslim', 'muslims', Eg('halal'),
- 'islamic studies',
+ Islam: match_kw(
+ "islam",
+ "islamic",
+ "muslim",
+ "muslims",
+ Eg("halal"),
+ "islamic studies",
),
-
- Judaism: match_kw(
- 'judaism', 'jewish', Eg('kosher'), 'jews',
- 'jewish studies',
- ),
-
- LGBTQ_Fiction: match_kw(
- 'lgbt',
- 'lgbtq',
- Eg('lesbian'),
- Eg('lesbians'),
- 'gay',
- Eg('bisexual'),
- Eg('transgender'),
- Eg('transsexual'),
- Eg('transsexuals'),
- 'homosexual',
- 'homosexuals',
- 'homosexuality',
- 'queer',
- ),
-
- Latin_American_History: match_kw(
- ),
-
- Law: match_kw(
- "court",
- "judicial",
- "law",
- "laws",
- "legislation",
- "legal",
- ),
-
- Legal_Thriller: match_kw(
- "legal thriller",
- "legal thrillers",
- ),
-
- Literary_Criticism: match_kw(
- "criticism, interpretation",
- ),
-
- Literary_Fiction: match_kw(
- "literary",
- "literary fiction",
- "general fiction",
- "fiction[^a-z]+general",
- "fiction[^a-z]+literary",
- ),
-
- Management_Leadership: match_kw(
- "management",
- "business & economics / leadership",
- "business & economics -- leadership",
- "management science",
- ),
-
- Mathematics : match_kw(
- Eg("algebra"),
- Eg("arithmetic"),
- Eg("calculus"),
- Eg("chaos theory"),
- Eg("game theory"),
- Eg("geometry"),
- Eg("group theory"),
- Eg("logic"),
- "math",
- "mathematical",
- "mathematician",
- "mathematicians",
- "mathematics",
- Eg("probability"),
- Eg("statistical"),
- Eg("statistics"),
- Eg("trigonometry"),
- ),
-
- Medical : match_kw(
- Eg("anatomy"),
- Eg("disease"),
- Eg("diseases"),
- Eg("disorders"),
- Eg("epidemiology"),
- Eg("illness"),
- Eg("illnesses"),
- "medical",
- "medicine",
- Eg("neuroscience"),
- Eg("ophthalmology"),
- Eg("physiology"),
- Eg("vaccines"),
- Eg("virus"),
- ),
-
- Medieval_History: match_kw(
- "civilization, medieval",
- "medieval period",
- "history.*medieval",
- ),
-
- Middle_East_History: match_kw(
- "middle east.*history",
- "history.*middle east",
- ),
-
-
- Military_History : match_kw(
- "military science",
- "warfare",
- "military",
- Eg("1914-1918"),
- Eg("1939-1945"),
- Eg("world war"),
- ),
-
- Modern_History: match_kw(
- Eg("1900 - 1999"),
- Eg("2000-2099"),
- "modern history",
- "history, modern",
- "history (modern)",
- "history--modern",
- Eg("history.*20th century"),
- Eg("history.*21st century"),
- ),
-
- # This is SF movie tie-ins, not movies & gaming per se.
+ Judaism: match_kw(
+ "judaism",
+ "jewish",
+ Eg("kosher"),
+ "jews",
+ "jewish studies",
+ ),
+ LGBTQ_Fiction: match_kw(
+ "lgbt",
+ "lgbtq",
+ Eg("lesbian"),
+ Eg("lesbians"),
+ "gay",
+ Eg("bisexual"),
+ Eg("transgender"),
+ Eg("transsexual"),
+ Eg("transsexuals"),
+ "homosexual",
+ "homosexuals",
+ "homosexuality",
+ "queer",
+ ),
+ Latin_American_History: match_kw(),
+ Law: match_kw(
+ "court",
+ "judicial",
+ "law",
+ "laws",
+ "legislation",
+ "legal",
+ ),
+ Legal_Thriller: match_kw(
+ "legal thriller",
+ "legal thrillers",
+ ),
+ Literary_Criticism: match_kw(
+ "criticism, interpretation",
+ ),
+ Literary_Fiction: match_kw(
+ "literary",
+ "literary fiction",
+ "general fiction",
+ "fiction[^a-z]+general",
+ "fiction[^a-z]+literary",
+ ),
+ Management_Leadership: match_kw(
+ "management",
+ "business & economics / leadership",
+ "business & economics -- leadership",
+ "management science",
+ ),
+ Mathematics: match_kw(
+ Eg("algebra"),
+ Eg("arithmetic"),
+ Eg("calculus"),
+ Eg("chaos theory"),
+ Eg("game theory"),
+ Eg("geometry"),
+ Eg("group theory"),
+ Eg("logic"),
+ "math",
+ "mathematical",
+ "mathematician",
+ "mathematicians",
+ "mathematics",
+ Eg("probability"),
+ Eg("statistical"),
+ Eg("statistics"),
+ Eg("trigonometry"),
+ ),
+ Medical: match_kw(
+ Eg("anatomy"),
+ Eg("disease"),
+ Eg("diseases"),
+ Eg("disorders"),
+ Eg("epidemiology"),
+ Eg("illness"),
+ Eg("illnesses"),
+ "medical",
+ "medicine",
+ Eg("neuroscience"),
+ Eg("ophthalmology"),
+ Eg("physiology"),
+ Eg("vaccines"),
+ Eg("virus"),
+ ),
+ Medieval_History: match_kw(
+ "civilization, medieval",
+ "medieval period",
+ "history.*medieval",
+ ),
+ Middle_East_History: match_kw(
+ "middle east.*history",
+ "history.*middle east",
+ ),
+ Military_History: match_kw(
+ "military science",
+ "warfare",
+ "military",
+ Eg("1914-1918"),
+ Eg("1939-1945"),
+ Eg("world war"),
+ ),
+ Modern_History: match_kw(
+ Eg("1900 - 1999"),
+ Eg("2000-2099"),
+ "modern history",
+ "history, modern",
+ "history (modern)",
+ "history--modern",
+ Eg("history.*20th century"),
+ Eg("history.*21st century"),
+ ),
+ # This is SF movie tie-ins, not movies & gaming per se.
# This one is difficult because it takes effect if book
# has subject "media tie-in" *and* "science fiction" or
# "fantasy"
@@ -712,504 +664,445 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
Eg("star wars"),
Eg("jedi"),
),
-
- Music: match_kw(
- "music",
- "musician",
- "musicians",
- "musical",
- Eg("genres & styles"),
- Eg("blues"),
- Eg("jazz"),
- Eg("rap"),
- Eg("hip-hop"),
- Eg("rock.*roll"),
- Eg("rock music"),
- Eg("punk rock"),
- ),
-
- Mystery : match_kw(
- Eg("crime"),
- Eg("detective"),
- Eg("murder"),
- "mystery",
- "mysteries",
- Eg("private investigators"),
- Eg("holmes, sherlock"),
- Eg("poirot, hercule"),
- Eg("schema:person:holmes, sherlock"),
- ),
-
- Nature : match_kw(
- # TODO: not sure about this one
- "nature",
- ),
-
- Body_Mind_Spirit: match_kw(
- "new age",
- ),
-
- Paranormal_Romance : match_kw(
- "paranormal romance",
- "romance.*paranormal",
- ),
-
- Parenting : match_kw(
- # "children" isn't here because the vast majority of
- # "children" tags indicate books _for_ children.
-
- # "family" isn't here because the vast majority
- # of "family" tags deal with specific families, e.g.
- # the Kennedys.
-
- "parenting",
- "parent",
- "parents",
- Eg("motherhood"),
- Eg("fatherhood"),
- ),
-
- Parenting_Family: match_kw(
- # Pure top-level category
- ),
-
- Performing_Arts: match_kw(
- "theatre",
- "theatrical",
- "performing arts",
- "entertainers",
- Eg("farce"),
- Eg("tragicomedy"),
- ),
-
- Periodicals : match_kw(
- "periodicals",
- "periodical",
- ),
-
- Personal_Finance_Investing: match_kw(
- "personal finance",
- "financial planning",
- "investing",
- Eg("retirement planning"),
- "money management",
- ),
-
- Pets: match_kw(
- "pets",
- Eg("dogs"),
- Eg("cats"),
- ),
-
- Philosophy : match_kw(
- "philosophy",
- "philosophical",
- "philosopher",
- "philosophers",
- Eg("epistemology"),
- Eg("metaphysics"),
- ),
-
- Photography: match_kw(
- "photography",
- "photographer",
- "photographers",
- "photographic",
- ),
-
- Police_Procedural: match_kw(
- "police[^a-z]+procedural",
- "police[^a-z]+procedurals",
- ),
-
- Poetry : match_kw(
- "poetry",
- "poet",
- "poets",
- "poem",
- "poems",
- Eg("sonnet"),
- Eg("sonnets"),
- ),
-
- Political_Science : match_kw(
- Eg("american government"),
- Eg("anarchism"),
- Eg("censorship"),
- Eg("citizenship"),
- Eg("civics"),
- Eg("communism"),
- Eg("corruption"),
- Eg("corrupt practices"),
- Eg("democracy"),
- Eg("geopolitics"),
- "government",
- Eg("human rights"),
- Eg("international relations"),
- Eg("political economy"),
- "political ideologies",
- "political process",
- "political science",
- Eg("public affairs"),
- Eg("public policy"),
- "politics",
- "political",
- Eg("current events"),
- ),
-
- Psychology: match_kw(
- "psychology",
- Eg("psychiatry"),
- "psychological aspects",
- Eg("psychiatric"),
- Eg("psychoanalysis"),
- ),
-
- Real_Estate: match_kw(
- "real estate",
- ),
-
- Reference_Study_Aids : match_kw(
- Eg("catalogs"),
- Eg("handbooks"),
- Eg("manuals"),
- Eg("reference"),
-
- # Formerly in 'Encyclopedias'
- Eg("encyclopaedias"),
- Eg("encyclopaedia"),
- Eg("encyclopedias"),
- Eg("encyclopedia"),
-
- # Formerly in 'Language Arts & Disciplines'
- Eg("alphabets"),
- Eg("communication studies"),
- Eg("composition"),
- Eg("creative writing"),
- Eg("grammar"),
- Eg("handwriting"),
- Eg("information sciences"),
- Eg("journalism"),
- Eg("library & information sciences"),
- Eg("linguistics"),
- Eg("literacy"),
- Eg("public speaking"),
- Eg("rhetoric"),
- Eg("sign language"),
- Eg("speech"),
- Eg("spelling"),
- Eg("style manuals"),
- Eg("syntax"),
- Eg("vocabulary"),
- Eg("writing systems"),
- ),
-
- Religion_Spirituality : match_kw(
- "religion",
- "religious",
- Eg("taoism"),
- Eg("taoist"),
- Eg("confucianism"),
- Eg("inspirational nonfiction"),
- ),
-
- Renaissance_Early_Modern_History: match_kw(
- "early modern period",
- "early modern history",
- "early modern, 1500-1700",
- "history.*early modern",
- "renaissance.*history",
- "history.*renaissance",
- ),
-
- Romance : match_kw(
- "love stories",
- "romance",
- "love & romance",
- "romances",
- ),
-
- Science : match_kw(
- Eg("aeronautics"),
- Eg("astronomy"),
- Eg("biology"),
- Eg("biophysics"),
- Eg("biochemistry"),
- Eg("botany"),
- Eg("chemistry"),
- Eg("earth sciences"),
- Eg("ecology"),
- Eg("entomology"),
- Eg("evolution"),
- Eg("geology"),
- Eg("genetics"),
- Eg("genetic engineering"),
- Eg("genomics"),
- Eg("ichthyology"),
- Eg("herpetology"),
- Eg("life sciences"),
- Eg("microbiology"),
- Eg("microscopy"),
- Eg("mycology"),
- Eg("ornithology"),
- Eg("natural history"),
- Eg("natural history"),
- Eg("physics"),
- "science",
- "scientist",
- "scientists",
- Eg("zoology"),
- Eg("virology"),
- Eg("cytology"),
- ),
-
- Science_Fiction : match_kw(
- "speculative fiction",
- "sci-fi",
- "sci fi",
- Eg("time travel"),
- ),
-
- #Science_Fiction_Fantasy: match_kw(
+ Music: match_kw(
+ "music",
+ "musician",
+ "musicians",
+ "musical",
+ Eg("genres & styles"),
+ Eg("blues"),
+ Eg("jazz"),
+ Eg("rap"),
+ Eg("hip-hop"),
+ Eg("rock.*roll"),
+ Eg("rock music"),
+ Eg("punk rock"),
+ ),
+ Mystery: match_kw(
+ Eg("crime"),
+ Eg("detective"),
+ Eg("murder"),
+ "mystery",
+ "mysteries",
+ Eg("private investigators"),
+ Eg("holmes, sherlock"),
+ Eg("poirot, hercule"),
+ Eg("schema:person:holmes, sherlock"),
+ ),
+ Nature: match_kw(
+ # TODO: not sure about this one
+ "nature",
+ ),
+ Body_Mind_Spirit: match_kw(
+ "new age",
+ ),
+ Paranormal_Romance: match_kw(
+ "paranormal romance",
+ "romance.*paranormal",
+ ),
+ Parenting: match_kw(
+ # "children" isn't here because the vast majority of
+ # "children" tags indicate books _for_ children.
+ # "family" isn't here because the vast majority
+ # of "family" tags deal with specific families, e.g.
+ # the Kennedys.
+ "parenting",
+ "parent",
+ "parents",
+ Eg("motherhood"),
+ Eg("fatherhood"),
+ ),
+ Parenting_Family: match_kw(
+ # Pure top-level category
+ ),
+ Performing_Arts: match_kw(
+ "theatre",
+ "theatrical",
+ "performing arts",
+ "entertainers",
+ Eg("farce"),
+ Eg("tragicomedy"),
+ ),
+ Periodicals: match_kw(
+ "periodicals",
+ "periodical",
+ ),
+ Personal_Finance_Investing: match_kw(
+ "personal finance",
+ "financial planning",
+ "investing",
+ Eg("retirement planning"),
+ "money management",
+ ),
+ Pets: match_kw(
+ "pets",
+ Eg("dogs"),
+ Eg("cats"),
+ ),
+ Philosophy: match_kw(
+ "philosophy",
+ "philosophical",
+ "philosopher",
+ "philosophers",
+ Eg("epistemology"),
+ Eg("metaphysics"),
+ ),
+ Photography: match_kw(
+ "photography",
+ "photographer",
+ "photographers",
+ "photographic",
+ ),
+ Police_Procedural: match_kw(
+ "police[^a-z]+procedural",
+ "police[^a-z]+procedurals",
+ ),
+ Poetry: match_kw(
+ "poetry",
+ "poet",
+ "poets",
+ "poem",
+ "poems",
+ Eg("sonnet"),
+ Eg("sonnets"),
+ ),
+ Political_Science: match_kw(
+ Eg("american government"),
+ Eg("anarchism"),
+ Eg("censorship"),
+ Eg("citizenship"),
+ Eg("civics"),
+ Eg("communism"),
+ Eg("corruption"),
+ Eg("corrupt practices"),
+ Eg("democracy"),
+ Eg("geopolitics"),
+ "government",
+ Eg("human rights"),
+ Eg("international relations"),
+ Eg("political economy"),
+ "political ideologies",
+ "political process",
+ "political science",
+ Eg("public affairs"),
+ Eg("public policy"),
+ "politics",
+ "political",
+ Eg("current events"),
+ ),
+ Psychology: match_kw(
+ "psychology",
+ Eg("psychiatry"),
+ "psychological aspects",
+ Eg("psychiatric"),
+ Eg("psychoanalysis"),
+ ),
+ Real_Estate: match_kw(
+ "real estate",
+ ),
+ Reference_Study_Aids: match_kw(
+ Eg("catalogs"),
+ Eg("handbooks"),
+ Eg("manuals"),
+ Eg("reference"),
+ # Formerly in 'Encyclopedias'
+ Eg("encyclopaedias"),
+ Eg("encyclopaedia"),
+ Eg("encyclopedias"),
+ Eg("encyclopedia"),
+ # Formerly in 'Language Arts & Disciplines'
+ Eg("alphabets"),
+ Eg("communication studies"),
+ Eg("composition"),
+ Eg("creative writing"),
+ Eg("grammar"),
+ Eg("handwriting"),
+ Eg("information sciences"),
+ Eg("journalism"),
+ Eg("library & information sciences"),
+ Eg("linguistics"),
+ Eg("literacy"),
+ Eg("public speaking"),
+ Eg("rhetoric"),
+ Eg("sign language"),
+ Eg("speech"),
+ Eg("spelling"),
+ Eg("style manuals"),
+ Eg("syntax"),
+ Eg("vocabulary"),
+ Eg("writing systems"),
+ ),
+ Religion_Spirituality: match_kw(
+ "religion",
+ "religious",
+ Eg("taoism"),
+ Eg("taoist"),
+ Eg("confucianism"),
+ Eg("inspirational nonfiction"),
+ ),
+ Renaissance_Early_Modern_History: match_kw(
+ "early modern period",
+ "early modern history",
+ "early modern, 1500-1700",
+ "history.*early modern",
+ "renaissance.*history",
+ "history.*renaissance",
+ ),
+ Romance: match_kw(
+ "love stories",
+ "romance",
+ "love & romance",
+ "romances",
+ ),
+ Science: match_kw(
+ Eg("aeronautics"),
+ Eg("astronomy"),
+ Eg("biology"),
+ Eg("biophysics"),
+ Eg("biochemistry"),
+ Eg("botany"),
+ Eg("chemistry"),
+ Eg("earth sciences"),
+ Eg("ecology"),
+ Eg("entomology"),
+ Eg("evolution"),
+ Eg("geology"),
+ Eg("genetics"),
+ Eg("genetic engineering"),
+ Eg("genomics"),
+ Eg("ichthyology"),
+ Eg("herpetology"),
+ Eg("life sciences"),
+ Eg("microbiology"),
+ Eg("microscopy"),
+ Eg("mycology"),
+ Eg("ornithology"),
+ Eg("natural history"),
+ Eg("natural history"),
+ Eg("physics"),
+ "science",
+ "scientist",
+ "scientists",
+ Eg("zoology"),
+ Eg("virology"),
+ Eg("cytology"),
+ ),
+ Science_Fiction: match_kw(
+ "speculative fiction",
+ "sci-fi",
+ "sci fi",
+ Eg("time travel"),
+ ),
+ # Science_Fiction_Fantasy: match_kw(
# "science fiction.*fantasy",
- #),
-
- Self_Help: match_kw(
- "self help",
- "self-help",
- "self improvement",
- "self-improvement",
- ),
- Folklore : match_kw(
- "fables",
- "folklore",
- "folktales",
- "folk tales",
- "myth",
- "legends",
- ),
-
- Short_Stories: match_kw(
- "short stories",
- Eg("literary collections"),
- ),
-
- Social_Sciences: match_kw(
- Eg("anthropology"),
- Eg("archaeology"),
- Eg("sociology"),
- Eg("ethnic studies"),
- Eg("feminism & feminist theory"),
- Eg("gender studies"),
- Eg("media studies"),
- Eg("minority studies"),
- Eg("men's studies"),
- Eg("regional studies"),
- Eg("women's studies"),
- Eg("demography"),
- Eg('lesbian studies'),
- Eg('gay studies'),
- Eg("black studies"),
- Eg("african-american studies"),
- Eg("customs & traditions"),
- Eg("criminology"),
- ),
-
- Sports: match_kw(
- # Ton of specific sports here since 'players'
- # doesn't work. TODO: Why? I don't remember.
- "sports",
- Eg("baseball"),
- Eg("football"),
- Eg("hockey"),
- Eg("soccer"),
- Eg("skating"),
- ),
-
- Study_Aids: match_kw(
- Eg("act"),
- Eg("advanced placement"),
- Eg("bar exam"),
- Eg("clep"),
- Eg("college entrance"),
- Eg("college guides"),
- Eg("financial aid"),
- Eg("certification"),
- Eg("ged"),
- Eg("gmat"),
- Eg("gre"),
- Eg("lsat"),
- Eg("mat"),
- Eg("mcat"),
- Eg("nmsqt"),
- Eg("nte"),
- Eg("psat"),
- Eg("sat"),
- "school guides",
- "study guide",
- "study guides",
- "study aids",
- Eg("toefl"),
- "workbooks",
- ),
-
- Romantic_Suspense : match_kw(
- "romantic.*suspense",
- "suspense.*romance",
- "romance.*suspense",
- "romantic.*thriller",
- "romance.*thriller",
- "thriller.*romance",
- ),
-
- Technology: match_kw(
- "technology",
- Eg("engineering"),
- Eg("bioengineering"),
- Eg("mechanics"),
-
- # Formerly in 'Transportation'
- Eg("transportation"),
- Eg("railroads"),
- Eg("trains"),
- Eg("automotive"),
- Eg("ships & shipbuilding"),
- Eg("cars & trucks"),
- ),
-
- Suspense_Thriller: match_kw(
- "thriller",
- "thrillers",
- "suspense",
- ),
-
- Technothriller : match_kw(
- "techno-thriller",
- "technothriller",
- "technothrillers",
- ),
-
- Travel : match_kw(
- Eg("discovery"),
- "exploration",
- "travel",
- "travels.*voyages",
- "voyage.*travels",
- "voyages",
- "travelers",
- "description.*travel",
- ),
-
- United_States_History: match_kw(
- "united states history",
- "u.s. history",
- Eg("american revolution"),
- Eg("1775-1783"),
- Eg("revolutionary period"),
- ),
-
- Urban_Fantasy: match_kw(
- "urban fantasy",
- "fantasy.*urban",
- ),
-
- Urban_Fiction: match_kw(
- "urban fiction",
- Eg("fiction.*african american.*urban"),
- ),
-
- Vegetarian_Vegan: match_kw(
- "vegetarian",
- Eg("vegan"),
- Eg("veganism"),
- "vegetarianism",
- ),
-
- Westerns : match_kw(
- "western stories",
- "westerns",
- ),
-
- Women_Detectives : match_kw(
- "women detectives",
- "women detective",
- "women private investigators",
- "women private investigator",
- "women sleuths",
- "women sleuth",
- ),
-
- Womens_Fiction : match_kw(
- "contemporary women",
- "chick lit",
- "womens fiction",
- "women's fiction",
- ),
-
- World_History: match_kw(
- "world history",
- "history[^a-z]*world",
- ),
+ # ),
+ Self_Help: match_kw(
+ "self help",
+ "self-help",
+ "self improvement",
+ "self-improvement",
+ ),
+ Folklore: match_kw(
+ "fables",
+ "folklore",
+ "folktales",
+ "folk tales",
+ "myth",
+ "legends",
+ ),
+ Short_Stories: match_kw(
+ "short stories",
+ Eg("literary collections"),
+ ),
+ Social_Sciences: match_kw(
+ Eg("anthropology"),
+ Eg("archaeology"),
+ Eg("sociology"),
+ Eg("ethnic studies"),
+ Eg("feminism & feminist theory"),
+ Eg("gender studies"),
+ Eg("media studies"),
+ Eg("minority studies"),
+ Eg("men's studies"),
+ Eg("regional studies"),
+ Eg("women's studies"),
+ Eg("demography"),
+ Eg("lesbian studies"),
+ Eg("gay studies"),
+ Eg("black studies"),
+ Eg("african-american studies"),
+ Eg("customs & traditions"),
+ Eg("criminology"),
+ ),
+ Sports: match_kw(
+ # Ton of specific sports here since 'players'
+ # doesn't work. TODO: Why? I don't remember.
+ "sports",
+ Eg("baseball"),
+ Eg("football"),
+ Eg("hockey"),
+ Eg("soccer"),
+ Eg("skating"),
+ ),
+ Study_Aids: match_kw(
+ Eg("act"),
+ Eg("advanced placement"),
+ Eg("bar exam"),
+ Eg("clep"),
+ Eg("college entrance"),
+ Eg("college guides"),
+ Eg("financial aid"),
+ Eg("certification"),
+ Eg("ged"),
+ Eg("gmat"),
+ Eg("gre"),
+ Eg("lsat"),
+ Eg("mat"),
+ Eg("mcat"),
+ Eg("nmsqt"),
+ Eg("nte"),
+ Eg("psat"),
+ Eg("sat"),
+ "school guides",
+ "study guide",
+ "study guides",
+ "study aids",
+ Eg("toefl"),
+ "workbooks",
+ ),
+ Romantic_Suspense: match_kw(
+ "romantic.*suspense",
+ "suspense.*romance",
+ "romance.*suspense",
+ "romantic.*thriller",
+ "romance.*thriller",
+ "thriller.*romance",
+ ),
+ Technology: match_kw(
+ "technology",
+ Eg("engineering"),
+ Eg("bioengineering"),
+ Eg("mechanics"),
+ # Formerly in 'Transportation'
+ Eg("transportation"),
+ Eg("railroads"),
+ Eg("trains"),
+ Eg("automotive"),
+ Eg("ships & shipbuilding"),
+ Eg("cars & trucks"),
+ ),
+ Suspense_Thriller: match_kw(
+ "thriller",
+ "thrillers",
+ "suspense",
+ ),
+ Technothriller: match_kw(
+ "techno-thriller",
+ "technothriller",
+ "technothrillers",
+ ),
+ Travel: match_kw(
+ Eg("discovery"),
+ "exploration",
+ "travel",
+ "travels.*voyages",
+ "voyage.*travels",
+ "voyages",
+ "travelers",
+ "description.*travel",
+ ),
+ United_States_History: match_kw(
+ "united states history",
+ "u.s. history",
+ Eg("american revolution"),
+ Eg("1775-1783"),
+ Eg("revolutionary period"),
+ ),
+ Urban_Fantasy: match_kw(
+ "urban fantasy",
+ "fantasy.*urban",
+ ),
+ Urban_Fiction: match_kw(
+ "urban fiction",
+ Eg("fiction.*african american.*urban"),
+ ),
+ Vegetarian_Vegan: match_kw(
+ "vegetarian",
+ Eg("vegan"),
+ Eg("veganism"),
+ "vegetarianism",
+ ),
+ Westerns: match_kw(
+ "western stories",
+ "westerns",
+ ),
+ Women_Detectives: match_kw(
+ "women detectives",
+ "women detective",
+ "women private investigators",
+ "women private investigator",
+ "women sleuths",
+ "women sleuth",
+ ),
+ Womens_Fiction: match_kw(
+ "contemporary women",
+ "chick lit",
+ "womens fiction",
+ "women's fiction",
+ ),
+ World_History: match_kw(
+ "world history",
+ "history[^a-z]*world",
+ ),
}
LEVEL_2_KEYWORDS = {
- Reference_Study_Aids : match_kw(
+ Reference_Study_Aids: match_kw(
# Formerly in 'Language Arts & Disciplines'
Eg("language arts & disciplines"),
Eg("language arts and disciplines"),
Eg("language arts"),
),
- Design : match_kw(
+ Design: match_kw(
"arts and crafts movement",
),
- Drama : match_kw(
+ Drama: match_kw(
Eg("opera"),
),
-
- Erotica : match_kw(
+ Erotica: match_kw(
Eg("erotic poetry"),
Eg("gay erotica"),
Eg("lesbian erotica"),
Eg("erotic photography"),
),
-
- Games : match_kw(
- Eg("games.*fantasy")
- ),
-
- Historical_Fiction : match_kw(
- Eg("arthurian romance.*"), # This is "romance" in the old
- # sense of a story.
- ),
-
- Literary_Criticism : match_kw(
- Eg("literary history"), # Not History
- Eg("romance language"), # Not Romance
+ Games: match_kw(Eg("games.*fantasy")),
+ Historical_Fiction: match_kw(
+ Eg("arthurian romance.*"), # This is "romance" in the old
+ # sense of a story.
),
-
- Media_Tie_in_SF : match_kw(
- 'tv, movie, video game adaptations' # Not Film & TV
+ Literary_Criticism: match_kw(
+ Eg("literary history"), # Not History
+ Eg("romance language"), # Not Romance
),
-
+ Media_Tie_in_SF: match_kw("tv, movie, video game adaptations"), # Not Film & TV
# We need to match these first so that the 'military'/'warfare'
# part doesn't match Military History.
Military_SF: match_kw(
"science fiction.*military",
"military.*science fiction",
- Eg("space warfare"), # Thankfully
+ Eg("space warfare"), # Thankfully
Eg("interstellar warfare"),
),
Military_Thriller: match_kw(
"military thrillers",
"thrillers.*military",
),
- Pets : match_kw(
+ Pets: match_kw(
"human-animal relationships",
),
- Political_Science : match_kw(
+ Political_Science: match_kw(
Eg("health care reform"),
),
-
# Stop the 'religious' from matching Religion/Spirituality.
Religious_Fiction: match_kw(
Eg("christian fiction"),
@@ -1217,10 +1110,9 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
Eg("fiction.*christian"),
"religious fiction",
"fiction.*religious",
- Eg("Oriental religions and wisdom")
+ Eg("Oriental religions and wisdom"),
),
-
- Romantic_Suspense : match_kw(
+ Romantic_Suspense: match_kw(
"romantic.*suspense",
"suspense.*romance",
"romance.*suspense",
@@ -1228,38 +1120,32 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
"romance.*thriller",
"thriller.*romance",
),
-
# Stop from showing up as 'science'
- Social_Sciences : match_kw(
+ Social_Sciences: match_kw(
"social sciences",
"social science",
"human science",
),
-
- Science_Fiction : match_kw(
+ Science_Fiction: match_kw(
"science fiction",
"science fiction.*general",
),
-
Supernatural_Thriller: match_kw(
"thriller.*supernatural",
"supernatural.*thriller",
),
-
# Stop from going into Mystery due to 'crime'
True_Crime: match_kw(
"true crime",
),
-
# Otherwise fiction.*urban turns Urban Fantasy into Urban Fiction
- Urban_Fantasy : match_kw(
+ Urban_Fantasy: match_kw(
"fiction.*fantasy.*urban",
),
-
# Stop the 'children' in 'children of' from matching Parenting.
- None : match_kw(
+ None: match_kw(
"children of",
- )
+ ),
}
LEVEL_3_KEYWORDS = {
@@ -1268,16 +1154,15 @@ class KeywordBasedClassifier(AgeOrGradeClassifier):
),
}
-
@classmethod
def is_fiction(cls, identifier, name, exclude_examples=False):
if not name:
return None
- if (cls.LEVEL_1_NONFICTION_INDICATORS["search"](name, exclude_examples)):
+ if cls.LEVEL_1_NONFICTION_INDICATORS["search"](name, exclude_examples):
return False
- if (cls.LEVEL_2_FICTION_INDICATORS["search"](name, exclude_examples)):
+ if cls.LEVEL_2_FICTION_INDICATORS["search"](name, exclude_examples):
return True
- if (cls.LEVEL_2_NONFICTION_INDICATORS["search"](name, exclude_examples)):
+ if cls.LEVEL_2_NONFICTION_INDICATORS["search"](name, exclude_examples):
return False
return None
@@ -1310,7 +1195,10 @@ def audience_match(cls, query):
audience_words = None
audience = cls.audience(None, query, exclude_examples=True)
if audience:
- for audience_keywords in [cls.JUVENILE_INDICATORS, cls.YOUNG_ADULT_INDICATORS]:
+ for audience_keywords in [
+ cls.JUVENILE_INDICATORS,
+ cls.YOUNG_ADULT_INDICATORS,
+ ]:
match = audience_keywords["search"](query, exclude_examples=True)
if match:
audience_words = match.group()
@@ -1318,15 +1206,21 @@ def audience_match(cls, query):
return (audience, audience_words)
@classmethod
- def genre(cls, identifier, name, fiction=None, audience=None, exclude_examples=False):
+ def genre(
+ cls, identifier, name, fiction=None, audience=None, exclude_examples=False
+ ):
matches = Counter()
match_against = [name]
for l in [cls.LEVEL_3_KEYWORDS, cls.LEVEL_2_KEYWORDS, cls.CATCHALL_KEYWORDS]:
for genre, keywords in list(l.items()):
if genre and fiction is not None and genre.is_fiction != fiction:
continue
- if (genre and audience and genre.audience_restriction
- and audience not in genre.audience_restriction):
+ if (
+ genre
+ and audience
+ and genre.audience_restriction
+ and audience not in genre.audience_restriction
+ ):
continue
if keywords and keywords["search"](name, exclude_examples):
matches[genre] += 1
@@ -1338,8 +1232,9 @@ def genre(cls, identifier, name, fiction=None, audience=None, exclude_examples=F
# because it's more specific.
for genre, count in matches.most_common():
if not most_specific_genre or (
- most_specific_genre.has_subgenre(genre)
- and count >= most_specific_count):
+ most_specific_genre.has_subgenre(genre)
+ and count >= most_specific_count
+ ):
most_specific_genre = genre
most_specific_count = count
if most_specific_genre:
@@ -1352,7 +1247,11 @@ def genre_match(cls, query):
genre_words = None
genre = cls.genre(None, query, exclude_examples=True)
if genre:
- for kwlist in [cls.LEVEL_3_KEYWORDS, cls.LEVEL_2_KEYWORDS, cls.CATCHALL_KEYWORDS]:
+ for kwlist in [
+ cls.LEVEL_3_KEYWORDS,
+ cls.LEVEL_2_KEYWORDS,
+ cls.CATCHALL_KEYWORDS,
+ ]:
if genre in list(kwlist.keys()):
genre_keywords = kwlist[genre]
match = genre_keywords["search"](query, exclude_examples=True)
@@ -1365,12 +1264,15 @@ def genre_match(cls, query):
class LCSHClassifier(KeywordBasedClassifier):
pass
+
class FASTClassifier(KeywordBasedClassifier):
pass
+
class TAGClassifier(KeywordBasedClassifier):
pass
+
Classifier.classifiers[Classifier.FAST] = FASTClassifier
Classifier.classifiers[Classifier.LCSH] = LCSHClassifier
Classifier.classifiers[Classifier.TAG] = TAGClassifier
diff --git a/classifier/lcc.py b/classifier/lcc.py
index 0aad74431..d67b8f16c 100644
--- a/classifier/lcc.py
+++ b/classifier/lcc.py
@@ -1,5 +1,6 @@
from . import *
+
class LCCClassifier(Classifier):
TOP_LEVEL = re.compile("^([A-Z]{1,2})")
@@ -7,7 +8,6 @@ class LCCClassifier(Classifier):
JUVENILE = set(["PZ"])
GENRES = {
-
# Unclassified/complicated stuff.
# "America": E11-E143
# Ancient_History: D51-D90
@@ -23,40 +23,79 @@ class LCCClassifier(Classifier):
# Sports: GV557-1198.995
# TODO: E and F are actually "the Americas".
# United_States_History is E151-E909, F1-F975 but not E456-E655
- African_History : ["DT"],
- Ancient_History : ["DE"],
- Architecture : ["NA"],
- Art_Criticism_Theory : ["BH"],
- Asian_History : ["DS", "DU"],
- Biography_Memoir : ["CT"],
- Business : ["HC", "HF", "HJ"],
- Christianity : ["BR", "BS", "BT", "BV", "BX"],
- Cooking : ["TX"],
- Crafts_Hobbies : ["TT"],
- Economics : ["HB"],
- Education : ["L"],
- European_History : ["DA", "DAW", "DB", "DD", "DF", "DG", "DH", "DJ", "DK", "DL", "DP", "DQ", "DR"],
- Folklore : ["GR"],
- Games : ["GV"],
- Islam : ["BP"],
- Judaism : ["BM"],
- Literary_Criticism : ["Z"],
- Mathematics : ["QA", "HA", "GA"],
+ African_History: ["DT"],
+ Ancient_History: ["DE"],
+ Architecture: ["NA"],
+ Art_Criticism_Theory: ["BH"],
+ Asian_History: ["DS", "DU"],
+ Biography_Memoir: ["CT"],
+ Business: ["HC", "HF", "HJ"],
+ Christianity: ["BR", "BS", "BT", "BV", "BX"],
+ Cooking: ["TX"],
+ Crafts_Hobbies: ["TT"],
+ Economics: ["HB"],
+ Education: ["L"],
+ European_History: [
+ "DA",
+ "DAW",
+ "DB",
+ "DD",
+ "DF",
+ "DG",
+ "DH",
+ "DJ",
+ "DK",
+ "DL",
+ "DP",
+ "DQ",
+ "DR",
+ ],
+ Folklore: ["GR"],
+ Games: ["GV"],
+ Islam: ["BP"],
+ Judaism: ["BM"],
+ Literary_Criticism: ["Z"],
+ Mathematics: ["QA", "HA", "GA"],
Medical: ["QM", "R"],
Military_History: ["U", "V"],
Music: ["M"],
- Parenting_Family : ["HQ"],
- Periodicals : ["AP", "AN"],
- Philosophy : ["BC", "BD", "BJ"],
+ Parenting_Family: ["HQ"],
+ Periodicals: ["AP", "AN"],
+ Philosophy: ["BC", "BD", "BJ"],
Photography: ["TR"],
- Political_Science : ["J", "HX"],
- Psychology : ["BF"],
- Reference_Study_Aids : ["AE", "AG", "AI"],
- Religion_Spirituality : ["BL", "BQ"],
- Science : ["QB", "QC", "QD", "QE", "QH", "QK", "QL", "QR", "CC", "GB", "GC", "QP"],
- Social_Sciences : ["HD", "HE", "HF", "HM", "HN", "HS", "HT", "HV", "GN", "GF", "GT"],
+ Political_Science: ["J", "HX"],
+ Psychology: ["BF"],
+ Reference_Study_Aids: ["AE", "AG", "AI"],
+ Religion_Spirituality: ["BL", "BQ"],
+ Science: [
+ "QB",
+ "QC",
+ "QD",
+ "QE",
+ "QH",
+ "QK",
+ "QL",
+ "QR",
+ "CC",
+ "GB",
+ "GC",
+ "QP",
+ ],
+ Social_Sciences: [
+ "HD",
+ "HE",
+ "HF",
+ "HM",
+ "HN",
+ "HS",
+ "HT",
+ "HV",
+ "GN",
+ "GF",
+ "GT",
+ ],
Sports: ["SK"],
- World_History : ["CB"],
+ World_History: ["CB"],
}
LEFTOVERS = dict(
@@ -87,9 +126,9 @@ def name_for(cls, identifier):
@classmethod
def is_fiction(cls, identifier, name):
- if identifier == 'P':
+ if identifier == "P":
return True
- if not identifier.startswith('P'):
+ if not identifier.startswith("P"):
return False
for i in cls.FICTION:
if identifier.startswith(i):
@@ -116,4 +155,5 @@ def audience(cls, identifier, name):
# trust that assumption.
return None
+
Classifier.classifiers[Classifier.LCC] = LCCClassifier
diff --git a/classifier/overdrive.py b/classifier/overdrive.py
index 72e86c4ce..b5411ad2e 100644
--- a/classifier/overdrive.py
+++ b/classifier/overdrive.py
@@ -2,6 +2,7 @@
from . import *
+
class OverdriveClassifier(Classifier):
# These genres are only used to describe video titles.
@@ -24,7 +25,7 @@ class OverdriveClassifier(Classifier):
"Stage Production",
"Theater",
"TV Series",
- "Young Adult Video"
+ "Young Adult Video",
]
# These genres are only used to describe music titles.
@@ -58,99 +59,113 @@ class OverdriveClassifier(Classifier):
"Rock",
"Soundtrack",
"Vocal",
- "World Music"
+ "World Music",
]
# Any classification that includes the string "Fiction" will be
# counted as fiction. This is just the leftovers.
- FICTION = set([
- "Fantasy",
- "Horror",
- "Literary Anthologies",
- "Mystery",
- "Romance",
- "Short Stories",
- "Suspense",
- "Thriller",
- "Western",
- ])
-
- NEITHER_FICTION_NOR_NONFICTION = [
- "Drama", "Poetry", "Latin",
- ] + MUSIC_GENRES + VIDEO_GENRES
+ FICTION = set(
+ [
+ "Fantasy",
+ "Horror",
+ "Literary Anthologies",
+ "Mystery",
+ "Romance",
+ "Short Stories",
+ "Suspense",
+ "Thriller",
+ "Western",
+ ]
+ )
+
+ NEITHER_FICTION_NOR_NONFICTION = (
+ [
+ "Drama",
+ "Poetry",
+ "Latin",
+ ]
+ + MUSIC_GENRES
+ + VIDEO_GENRES
+ )
GENRES = {
- Antiques_Collectibles : "Antiques",
- Architecture : "Architecture",
- Art : "Art",
- Biography_Memoir : "Biography & Autobiography",
- Business : ["Business", "Marketing & Sales", "Careers"],
- Christianity : "Christian Nonfiction",
- Computers : ["Computer Technology", "Social Media"],
- Classics : "Classic Literature",
- Cooking : "Cooking & Food",
- Crafts_Hobbies : "Crafts",
- Games : "Games",
- Drama : "Drama",
- Economics : "Economics",
- Education : "Education",
- Erotica : "Erotic Literature",
- Fantasy : "Fantasy",
- Folklore : ["Folklore", "Mythology"],
- Foreign_Language_Study : "Foreign Language Study",
- Gardening : "Gardening",
- Comics_Graphic_Novels : "Comic and Graphic Books",
- Health_Diet : "Health & Fitness",
- Historical_Fiction : ["Historical Fiction", "Antiquarian"],
- History : "History",
- Horror : "Horror",
- House_Home : "Home Design & Décor",
- Humorous_Fiction : "Humor (Fiction)",
- Humorous_Nonfiction : "Humor (Nonfiction)",
- Entertainment : "Entertainment",
- Judaism : "Judaica",
- Law : "Law",
- Literary_Criticism : [
- "Literary Criticism", "Criticism", "Language Arts", "Writing",
+ Antiques_Collectibles: "Antiques",
+ Architecture: "Architecture",
+ Art: "Art",
+ Biography_Memoir: "Biography & Autobiography",
+ Business: ["Business", "Marketing & Sales", "Careers"],
+ Christianity: "Christian Nonfiction",
+ Computers: ["Computer Technology", "Social Media"],
+ Classics: "Classic Literature",
+ Cooking: "Cooking & Food",
+ Crafts_Hobbies: "Crafts",
+ Games: "Games",
+ Drama: "Drama",
+ Economics: "Economics",
+ Education: "Education",
+ Erotica: "Erotic Literature",
+ Fantasy: "Fantasy",
+ Folklore: ["Folklore", "Mythology"],
+ Foreign_Language_Study: "Foreign Language Study",
+ Gardening: "Gardening",
+ Comics_Graphic_Novels: "Comic and Graphic Books",
+ Health_Diet: "Health & Fitness",
+ Historical_Fiction: ["Historical Fiction", "Antiquarian"],
+ History: "History",
+ Horror: "Horror",
+ House_Home: "Home Design & Décor",
+ Humorous_Fiction: "Humor (Fiction)",
+ Humorous_Nonfiction: "Humor (Nonfiction)",
+ Entertainment: "Entertainment",
+ Judaism: "Judaica",
+ Law: "Law",
+ Literary_Criticism: [
+ "Literary Criticism",
+ "Criticism",
+ "Language Arts",
+ "Writing",
],
- Management_Leadership : "Management",
- Mathematics : "Mathematics",
- Medical : "Medical",
- Military_History : "Military",
- Music : ["Music", "Songbook"],
- Mystery : "Mystery",
- Nature : "Nature",
- Body_Mind_Spirit : "New Age",
- Parenting_Family : ["Family & Relationships", "Child Development"],
- Performing_Arts : "Performing Arts",
- Personal_Finance_Investing : "Finance",
- Pets : "Pets",
- Philosophy : ["Philosophy", "Ethics"],
- Photography : "Photography",
- Poetry : "Poetry",
- Political_Science : ["Politics", "Current Events"],
- Psychology : ["Psychology", "Psychiatry", "Psychiatry & Psychology"],
- Reference_Study_Aids : ["Reference", "Grammar & Language Usage"],
- Religious_Fiction : ["Christian Fiction"],
- Religion_Spirituality : "Religion & Spirituality",
- Romance : "Romance",
- Science : ["Science", "Physics", "Chemistry", "Biology"],
- Science_Fiction : "Science Fiction",
+ Management_Leadership: "Management",
+ Mathematics: "Mathematics",
+ Medical: "Medical",
+ Military_History: "Military",
+ Music: ["Music", "Songbook"],
+ Mystery: "Mystery",
+ Nature: "Nature",
+ Body_Mind_Spirit: "New Age",
+ Parenting_Family: ["Family & Relationships", "Child Development"],
+ Performing_Arts: "Performing Arts",
+ Personal_Finance_Investing: "Finance",
+ Pets: "Pets",
+ Philosophy: ["Philosophy", "Ethics"],
+ Photography: "Photography",
+ Poetry: "Poetry",
+ Political_Science: ["Politics", "Current Events"],
+ Psychology: ["Psychology", "Psychiatry", "Psychiatry & Psychology"],
+ Reference_Study_Aids: ["Reference", "Grammar & Language Usage"],
+ Religious_Fiction: ["Christian Fiction"],
+ Religion_Spirituality: "Religion & Spirituality",
+ Romance: "Romance",
+ Science: ["Science", "Physics", "Chemistry", "Biology"],
+ Science_Fiction: "Science Fiction",
# Science_Fiction_Fantasy : "Science Fiction & Fantasy",
- Self_Help : ["Self-Improvement", "Self-Help", "Self Help", "Recovery"],
- Short_Stories : ["Literary Anthologies", "Short Stories"],
- Social_Sciences : [
- "Sociology", "Gender Studies",
- "Genealogy", "Media Studies", "Social Studies",
+ Self_Help: ["Self-Improvement", "Self-Help", "Self Help", "Recovery"],
+ Short_Stories: ["Literary Anthologies", "Short Stories"],
+ Social_Sciences: [
+ "Sociology",
+ "Gender Studies",
+ "Genealogy",
+ "Media Studies",
+ "Social Studies",
],
- Sports : "Sports & Recreations",
- Study_Aids : ["Study Aids & Workbooks", "Text Book"],
- Technology : ["Technology", "Engineering", "Transportation"],
- Suspense_Thriller : ["Suspense", "Thriller"],
- Travel : ["Travel", "Travel Literature", "Outdoor Recreation"],
- True_Crime : "True Crime",
+ Sports: "Sports & Recreations",
+ Study_Aids: ["Study Aids & Workbooks", "Text Book"],
+ Technology: ["Technology", "Engineering", "Transportation"],
+ Suspense_Thriller: ["Suspense", "Thriller"],
+ Travel: ["Travel", "Travel Literature", "Outdoor Recreation"],
+ True_Crime: "True Crime",
Urban_Fiction: ["African American Fiction", "Urban Fiction"],
- Westerns : "Western",
+ Westerns: "Western",
Womens_Fiction: "Chick Lit Fiction",
}
@@ -158,15 +173,17 @@ class OverdriveClassifier(Classifier):
def scrub_identifier(cls, identifier):
if not identifier:
return identifier
- if identifier.startswith('Foreign Language Study'):
- return 'Foreign Language Study'
+ if identifier.startswith("Foreign Language Study"):
+ return "Foreign Language Study"
return identifier
@classmethod
def is_fiction(cls, identifier, name):
- if (identifier in cls.FICTION
+ if (
+ identifier in cls.FICTION
or "Fiction" in identifier
- or "Literature" in identifier):
+ or "Literature" in identifier
+ ):
# "Literature" on Overdrive seems to be synonymous with fiction,
# but not necessarily "Literary Fiction".
return True
@@ -179,24 +196,28 @@ def is_fiction(cls, identifier, name):
@classmethod
def audience(cls, identifier, name):
- if ("Juvenile" in identifier or "Picture Book" in identifier
- or "Beginning Reader" in identifier or "Children's" in identifier):
+ if (
+ "Juvenile" in identifier
+ or "Picture Book" in identifier
+ or "Beginning Reader" in identifier
+ or "Children's" in identifier
+ ):
return cls.AUDIENCE_CHILDREN
elif "Young Adult" in identifier:
return cls.AUDIENCE_YOUNG_ADULT
- elif identifier in ('Fiction', 'Nonfiction'):
+ elif identifier in ("Fiction", "Nonfiction"):
return cls.AUDIENCE_ADULT
- elif identifier == 'Erotic Literature':
+ elif identifier == "Erotic Literature":
return cls.AUDIENCE_ADULTS_ONLY
return None
@classmethod
def target_age(cls, identifier, name):
- if identifier.startswith('Picture Book'):
+ if identifier.startswith("Picture Book"):
return cls.range_tuple(0, 4)
- elif identifier.startswith('Beginning Reader'):
- return cls.range_tuple(5,8)
- elif 'Young Adult' in identifier:
+ elif identifier.startswith("Beginning Reader"):
+ return cls.range_tuple(5, 8)
+ elif "Young Adult" in identifier:
# Internally we believe that 'Young Adult' means ages
# 14-17, but after looking at a large number of Overdrive
# books classified as 'Young Adult' we think that
@@ -209,8 +230,9 @@ def genre(cls, identifier, name, fiction=None, audience=None):
for l, v in list(cls.GENRES.items()):
if identifier == v or (isinstance(v, list) and identifier in v):
return l
- if identifier == 'Gay/Lesbian' and fiction:
+ if identifier == "Gay/Lesbian" and fiction:
return LGBTQ_Fiction
return None
+
Classifier.classifiers[Classifier.OVERDRIVE] = OverdriveClassifier
diff --git a/classifier/simplified.py b/classifier/simplified.py
index eecb9050b..d5d8addc9 100644
--- a/classifier/simplified.py
+++ b/classifier/simplified.py
@@ -1,6 +1,8 @@
-from . import *
from urllib.parse import unquote
+from . import *
+
+
class SimplifiedGenreClassifier(Classifier):
NONE = NO_VALUE
@@ -12,7 +14,7 @@ def scrub_identifier(cls, identifier):
if not identifier:
return identifier
if identifier.startswith(cls.SIMPLIFIED_GENRE):
- identifier = identifier[len(cls.SIMPLIFIED_GENRE):]
+ identifier = identifier[len(cls.SIMPLIFIED_GENRE) :]
identifier = unquote(identifier)
return Lowercased(identifier)
@@ -44,7 +46,6 @@ def _genre_by_name(cls, name, genres):
class SimplifiedFictionClassifier(Classifier):
-
@classmethod
def scrub_identifier(cls, identifier):
# If the identifier is a URI identifying a Simplified genre,
@@ -52,7 +53,7 @@ def scrub_identifier(cls, identifier):
if not identifier:
return identifier
if identifier.startswith(cls.SIMPLIFIED_FICTION_STATUS):
- identifier = identifier[len(cls.SIMPLIFIED_FICTION_STATUS):]
+ identifier = identifier[len(cls.SIMPLIFIED_FICTION_STATUS) :]
identifier = unquote(identifier)
return Lowercased(identifier)
@@ -65,5 +66,8 @@ def is_fiction(cls, identifier, name):
else:
return None
+
Classifier.classifiers[Classifier.SIMPLIFIED_GENRE] = SimplifiedGenreClassifier
-Classifier.classifiers[Classifier.SIMPLIFIED_FICTION_STATUS] = SimplifiedFictionClassifier
+Classifier.classifiers[
+ Classifier.SIMPLIFIED_FICTION_STATUS
+] = SimplifiedFictionClassifier
diff --git a/config.py b/config.py
index af20c7e8d..132c09b29 100644
--- a/config.py
+++ b/config.py
@@ -1,22 +1,23 @@
import contextlib
-import os
+import copy
import json
import logging
-import copy
+import os
+
+from flask_babel import lazy_gettext as _
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import ArgumentError
from sqlalchemy.orm.session import Session
-from flask_babel import lazy_gettext as _
-from .facets import FacetConstants
from .entrypoint import EntryPoint
+from .facets import FacetConstants
from .util import LanguageCodes
-from .util.datetime_helpers import utc_now
+from .util.datetime_helpers import to_utc, utc_now
+
# It's convenient for other modules import IntegrationException
# from this module, alongside CannotLoadConfiguration.
from .util.http import IntegrationException
-from .util.datetime_helpers import to_utc
class CannotLoadConfiguration(IntegrationException):
@@ -28,6 +29,7 @@ class CannotLoadConfiguration(IntegrationException):
configuration, with no need to actually talk to the foreign
server.
"""
+
pass
@@ -45,6 +47,7 @@ def temp_config(new_config=None, replacement_classes=None):
for c in replacement_classes:
c.instance = old_config
+
@contextlib.contextmanager
def empty_config(replacement_classes=None):
with temp_config({}, replacement_classes) as i:
@@ -58,7 +61,7 @@ class ConfigurationConstants(object):
# one configuring which facet is the default.
ENABLED_FACETS_KEY_PREFIX = "facets_enabled_"
DEFAULT_FACET_KEY_PREFIX = "facets_default_"
-
+
# The "level" property determines which admins will be able to modify the setting. Level 1 settings can be modified by anyone.
# Level 2 settings can be modified only by library managers and system admins (i.e. not by librarians). Level 3 settings can be changed only by system admins.
# If no level is specified, the setting will be treated as Level 1 by default.
@@ -66,6 +69,7 @@ class ConfigurationConstants(object):
SYS_ADMIN_OR_MANAGER = 2
SYS_ADMIN_ONLY = 3
+
class Configuration(ConfigurationConstants):
log = logging.getLogger("Configuration file loader")
@@ -75,14 +79,13 @@ class Configuration(ConfigurationConstants):
# this class is defined.
instance = None
-
# Environment variables that contain URLs to the database
- DATABASE_TEST_ENVIRONMENT_VARIABLE = 'SIMPLIFIED_TEST_DATABASE'
- DATABASE_PRODUCTION_ENVIRONMENT_VARIABLE = 'SIMPLIFIED_PRODUCTION_DATABASE'
+ DATABASE_TEST_ENVIRONMENT_VARIABLE = "SIMPLIFIED_TEST_DATABASE"
+ DATABASE_PRODUCTION_ENVIRONMENT_VARIABLE = "SIMPLIFIED_PRODUCTION_DATABASE"
# The version of the app.
- APP_VERSION = 'app_version'
- VERSION_FILENAME = '.version'
+ APP_VERSION = "app_version"
+ VERSION_FILENAME = ".version"
NO_APP_VERSION_FOUND = object()
# Logging stuff
@@ -101,10 +104,10 @@ class Configuration(ConfigurationConstants):
DATA_DIRECTORY = "data_directory"
# ConfigurationSetting key for the base url of the app.
- BASE_URL_KEY = 'base_url'
+ BASE_URL_KEY = "base_url"
# ConfigurationSetting to enable the MeasurementReaper script
- MEASUREMENT_REAPER = 'measurement_reaper_enabled'
+ MEASUREMENT_REAPER = "measurement_reaper_enabled"
# Policies, mostly circulation specific
POLICIES = "policies"
@@ -133,7 +136,7 @@ class Configuration(ConfigurationConstants):
THREEM_INTEGRATION = "3M"
# ConfigurationSetting key for a CDN's mirror domain
- CDN_MIRRORED_DOMAIN_KEY = 'mirrored_domain'
+ CDN_MIRRORED_DOMAIN_KEY = "mirrored_domain"
# The name of the per-library configuration policy that controls whether
# books may be put on hold.
@@ -151,11 +154,11 @@ class Configuration(ConfigurationConstants):
# The name of the per-library per-patron authentication integration
# regular expression used to derive a patron's external_type from
# their authorization_identifier.
- EXTERNAL_TYPE_REGULAR_EXPRESSION = 'external_type_regular_expression'
+ EXTERNAL_TYPE_REGULAR_EXPRESSION = "external_type_regular_expression"
- WEBSITE_URL = 'website'
- NAME = 'name'
- SHORT_NAME = 'short_name'
+ WEBSITE_URL = "website"
+ NAME = "name"
+ SHORT_NAME = "short_name"
DEBUG = "DEBUG"
INFO = "INFO"
@@ -164,20 +167,20 @@ class Configuration(ConfigurationConstants):
# The default value to put into the 'app' field of JSON-format logs,
# unless LOG_APP_NAME overrides it.
- DEFAULT_APP_NAME = 'simplified'
+ DEFAULT_APP_NAME = "simplified"
# Settings for the integration with protocol=INTERNAL_LOGGING
- LOG_LEVEL = 'log_level'
- LOG_APP_NAME = 'log_app'
- DATABASE_LOG_LEVEL = 'database_log_level'
+ LOG_LEVEL = "log_level"
+ LOG_APP_NAME = "log_app"
+ DATABASE_LOG_LEVEL = "database_log_level"
LOG_LEVEL_UI = [
- { "key": DEBUG, "label": _("Debug") },
- { "key": INFO, "label": _("Info") },
- { "key": WARN, "label": _("Warn") },
- { "key": ERROR, "label": _("Error") },
+ {"key": DEBUG, "label": _("Debug")},
+ {"key": INFO, "label": _("Info")},
+ {"key": WARN, "label": _("Warn")},
+ {"key": ERROR, "label": _("Error")},
]
- EXCLUDED_AUDIO_DATA_SOURCES = 'excluded_audio_data_sources'
+ EXCLUDED_AUDIO_DATA_SOURCES = "excluded_audio_data_sources"
SITEWIDE_SETTINGS = [
{
@@ -187,141 +190,181 @@ class Configuration(ConfigurationConstants):
"format": "url",
},
{
- "key": LOG_LEVEL, "label": _("Log Level"), "type": "select",
- "options": LOG_LEVEL_UI, "default": INFO,
+ "key": LOG_LEVEL,
+ "label": _("Log Level"),
+ "type": "select",
+ "options": LOG_LEVEL_UI,
+ "default": INFO,
},
{
- "key": LOG_APP_NAME, "label": _("Application name"),
- "description": _("Log messages originating from this application will be tagged with this name. If you run multiple instances, giving each one a different application name will help you determine which instance is having problems."),
+ "key": LOG_APP_NAME,
+ "label": _("Application name"),
+ "description": _(
+ "Log messages originating from this application will be tagged with this name. If you run multiple instances, giving each one a different application name will help you determine which instance is having problems."
+ ),
"default": DEFAULT_APP_NAME,
"required": True,
},
{
- "key": DATABASE_LOG_LEVEL, "label": _("Database Log Level"),
- "type": "select", "options": LOG_LEVEL_UI,
- "description": _("Database logs are extremely verbose, so unless you're diagnosing a database-related problem, it's a good idea to set a higher log level for database messages."),
+ "key": DATABASE_LOG_LEVEL,
+ "label": _("Database Log Level"),
+ "type": "select",
+ "options": LOG_LEVEL_UI,
+ "description": _(
+ "Database logs are extremely verbose, so unless you're diagnosing a database-related problem, it's a good idea to set a higher log level for database messages."
+ ),
"default": WARN,
},
{
"key": EXCLUDED_AUDIO_DATA_SOURCES,
"label": _("Excluded audiobook sources"),
- "description": _("Audiobooks from these data sources will be hidden from the collection, even if they would otherwise show up as available."),
+ "description": _(
+ "Audiobooks from these data sources will be hidden from the collection, even if they would otherwise show up as available."
+ ),
"default": None,
"required": True,
},
{
"key": MEASUREMENT_REAPER,
- "label": _("Cleanup old measurement data"), "type": "select",
- "description": _("If this settings is 'true' old book measurement data will be cleaned out of the database. Some sites may want to keep this data for later analysis."),
- "options": { "true": "true", "false": "false" }, "default": "true",
- },
- ]
-
- LIBRARY_SETTINGS = [
- {
- "key": NAME,
- "label": _("Name"),
- "description": _("The human-readable name of this library."),
- "category": "Basic Information",
- "level": ConfigurationConstants.SYS_ADMIN_ONLY,
- "required": True
- },
- {
- "key": SHORT_NAME,
- "label": _("Short name"),
- "description": _("A short name of this library, to use when identifying it in scripts or URLs, e.g. 'NYPL'."),
- "category": "Basic Information",
- "level": ConfigurationConstants.SYS_ADMIN_ONLY,
- "required": True
- },
- {
- "key": WEBSITE_URL,
- "label": _("URL of the library's website"),
- "description": _("The library's main website, e.g. \"https://www.nypl.org/\" (not this Circulation Manager's URL)."),
- "required": True,
- "format": "url",
- "level": ConfigurationConstants.SYS_ADMIN_ONLY,
- "category": "Basic Information"
- },
- {
- "key": ALLOW_HOLDS,
- "label": _("Allow books to be put on hold"),
+ "label": _("Cleanup old measurement data"),
"type": "select",
- "options": [
- { "key": "true", "label": _("Allow holds") },
- { "key": "false", "label": _("Disable holds") },
- ],
+ "description": _(
+ "If this settings is 'true' old book measurement data will be cleaned out of the database. Some sites may want to keep this data for later analysis."
+ ),
+ "options": {"true": "true", "false": "false"},
"default": "true",
- "category": "Loans, Holds, & Fines",
- "level": ConfigurationConstants.SYS_ADMIN_ONLY
- },
- { "key": EntryPoint.ENABLED_SETTING,
- "label": _("Enabled entry points"),
- "description": _("Patrons will see the selected entry points at the top level and in search results.
Currently supported audiobook vendors: Bibliotheca, Axis 360"),
- "type": "list",
- "options": [
- { "key": entrypoint.INTERNAL_NAME,
- "label": EntryPoint.DISPLAY_TITLES.get(entrypoint) }
- for entrypoint in EntryPoint.ENTRY_POINTS
- ],
- "default": [x.INTERNAL_NAME for x in EntryPoint.DEFAULT_ENABLED],
- "category": "Lanes & Filters",
- # Renders a component with options that get narrowed down as the user makes selections.
- "format": "narrow",
- # Renders an input field that cannot be edited.
- "readOnly": True,
- "level": ConfigurationConstants.SYS_ADMIN_ONLY
- },
- {
- "key": FEATURED_LANE_SIZE,
- "label": _("Maximum number of books in the 'featured' lanes"),
- "type": "number",
- "default": 15,
- "category": "Lanes & Filters",
- "level": ConfigurationConstants.ALL_ACCESS
-
- },
- {
- "key": MINIMUM_FEATURED_QUALITY,
- "label": _("Minimum quality for books that show up in 'featured' lanes"),
- "description": _("Between 0 and 1."),
- "type": "number",
- "max": 1,
- "default": DEFAULT_MINIMUM_FEATURED_QUALITY,
- "category": "Lanes & Filters",
- "level": ConfigurationConstants.ALL_ACCESS
},
- ] + [
- { "key": ConfigurationConstants.ENABLED_FACETS_KEY_PREFIX + group,
- "label": description,
- "type": "list",
- "options": [
- { "key": facet, "label": FacetConstants.FACET_DISPLAY_TITLES.get(facet) }
- for facet in FacetConstants.FACETS_BY_GROUP.get(group)
- ],
- "default": FacetConstants.FACETS_BY_GROUP.get(group),
- "category": "Lanes & Filters",
- # Tells the front end that each of these settings is related to the corresponding default setting.
- "paired": ConfigurationConstants.DEFAULT_FACET_KEY_PREFIX + group,
- "level": ConfigurationConstants.SYS_ADMIN_OR_MANAGER
- } for group, description in FacetConstants.GROUP_DESCRIPTIONS.items()
- ] + [
- { "key": ConfigurationConstants.DEFAULT_FACET_KEY_PREFIX + group,
- "label": _("Default %(group)s", group=display_name),
- "type": "select",
- "options": [
- { "key": facet, "label": FacetConstants.FACET_DISPLAY_TITLES.get(facet) }
- for facet in FacetConstants.FACETS_BY_GROUP.get(group)
- ],
- "default": FacetConstants.DEFAULT_FACET.get(group),
- "category": "Lanes & Filters",
- "skip": True
- } for group, display_name in FacetConstants.GROUP_DISPLAY_TITLES.items()
]
+ LIBRARY_SETTINGS = (
+ [
+ {
+ "key": NAME,
+ "label": _("Name"),
+ "description": _("The human-readable name of this library."),
+ "category": "Basic Information",
+ "level": ConfigurationConstants.SYS_ADMIN_ONLY,
+ "required": True,
+ },
+ {
+ "key": SHORT_NAME,
+ "label": _("Short name"),
+ "description": _(
+ "A short name of this library, to use when identifying it in scripts or URLs, e.g. 'NYPL'."
+ ),
+ "category": "Basic Information",
+ "level": ConfigurationConstants.SYS_ADMIN_ONLY,
+ "required": True,
+ },
+ {
+ "key": WEBSITE_URL,
+ "label": _("URL of the library's website"),
+ "description": _(
+ "The library's main website, e.g. \"https://www.nypl.org/\" (not this Circulation Manager's URL)."
+ ),
+ "required": True,
+ "format": "url",
+ "level": ConfigurationConstants.SYS_ADMIN_ONLY,
+ "category": "Basic Information",
+ },
+ {
+ "key": ALLOW_HOLDS,
+ "label": _("Allow books to be put on hold"),
+ "type": "select",
+ "options": [
+ {"key": "true", "label": _("Allow holds")},
+ {"key": "false", "label": _("Disable holds")},
+ ],
+ "default": "true",
+ "category": "Loans, Holds, & Fines",
+ "level": ConfigurationConstants.SYS_ADMIN_ONLY,
+ },
+ {
+ "key": EntryPoint.ENABLED_SETTING,
+ "label": _("Enabled entry points"),
+ "description": _(
+ "Patrons will see the selected entry points at the top level and in search results.
Currently supported audiobook vendors: Bibliotheca, Axis 360"
+ ),
+ "type": "list",
+ "options": [
+ {
+ "key": entrypoint.INTERNAL_NAME,
+ "label": EntryPoint.DISPLAY_TITLES.get(entrypoint),
+ }
+ for entrypoint in EntryPoint.ENTRY_POINTS
+ ],
+ "default": [x.INTERNAL_NAME for x in EntryPoint.DEFAULT_ENABLED],
+ "category": "Lanes & Filters",
+ # Renders a component with options that get narrowed down as the user makes selections.
+ "format": "narrow",
+ # Renders an input field that cannot be edited.
+ "readOnly": True,
+ "level": ConfigurationConstants.SYS_ADMIN_ONLY,
+ },
+ {
+ "key": FEATURED_LANE_SIZE,
+ "label": _("Maximum number of books in the 'featured' lanes"),
+ "type": "number",
+ "default": 15,
+ "category": "Lanes & Filters",
+ "level": ConfigurationConstants.ALL_ACCESS,
+ },
+ {
+ "key": MINIMUM_FEATURED_QUALITY,
+ "label": _(
+ "Minimum quality for books that show up in 'featured' lanes"
+ ),
+ "description": _("Between 0 and 1."),
+ "type": "number",
+ "max": 1,
+ "default": DEFAULT_MINIMUM_FEATURED_QUALITY,
+ "category": "Lanes & Filters",
+ "level": ConfigurationConstants.ALL_ACCESS,
+ },
+ ]
+ + [
+ {
+ "key": ConfigurationConstants.ENABLED_FACETS_KEY_PREFIX + group,
+ "label": description,
+ "type": "list",
+ "options": [
+ {
+ "key": facet,
+ "label": FacetConstants.FACET_DISPLAY_TITLES.get(facet),
+ }
+ for facet in FacetConstants.FACETS_BY_GROUP.get(group)
+ ],
+ "default": FacetConstants.FACETS_BY_GROUP.get(group),
+ "category": "Lanes & Filters",
+ # Tells the front end that each of these settings is related to the corresponding default setting.
+ "paired": ConfigurationConstants.DEFAULT_FACET_KEY_PREFIX + group,
+ "level": ConfigurationConstants.SYS_ADMIN_OR_MANAGER,
+ }
+ for group, description in FacetConstants.GROUP_DESCRIPTIONS.items()
+ ]
+ + [
+ {
+ "key": ConfigurationConstants.DEFAULT_FACET_KEY_PREFIX + group,
+ "label": _("Default %(group)s", group=display_name),
+ "type": "select",
+ "options": [
+ {
+ "key": facet,
+ "label": FacetConstants.FACET_DISPLAY_TITLES.get(facet),
+ }
+ for facet in FacetConstants.FACETS_BY_GROUP.get(group)
+ ],
+ "default": FacetConstants.DEFAULT_FACET.get(group),
+ "category": "Lanes & Filters",
+ "skip": True,
+ }
+ for group, display_name in FacetConstants.GROUP_DISPLAY_TITLES.items()
+ ]
+ )
+
# This is set once CDN data is loaded from the database and
# inserted into the Configuration object.
- CDNS_LOADED_FROM_DATABASE = 'loaded_from_database'
+ CDNS_LOADED_FROM_DATABASE = "loaded_from_database"
@classmethod
def load(cls, _db=None):
@@ -335,15 +378,13 @@ def load(cls, _db=None):
cls.load_cdns(_db)
cls.app_version()
for parent in cls.__bases__:
- if parent.__name__.endswith('Configuration'):
+ if parent.__name__.endswith("Configuration"):
parent.load(_db)
@classmethod
def cdns_loaded_from_database(cls):
"""Has the site configuration been loaded from the database yet?"""
- return cls.instance and cls.instance.get(
- cls.CDNS_LOADED_FROM_DATABASE, False
- )
+ return cls.instance and cls.instance.get(cls.CDNS_LOADED_FROM_DATABASE, False)
# General getters
@@ -363,9 +404,7 @@ def required(cls, key):
value = cls.get(key)
if value is not None:
return value
- raise ValueError(
- "Required configuration variable %s was not defined!" % key
- )
+ raise ValueError("Required configuration variable %s was not defined!" % key)
@classmethod
def integration(cls, name, required=False):
@@ -374,9 +413,8 @@ def integration(cls, name, required=False):
v = integrations.get(name, {})
if not v and required:
raise ValueError(
- "Required integration '%s' was not defined! I see: %r" % (
- name, ", ".join(sorted(integrations.keys()))
- )
+ "Required integration '%s' was not defined! I see: %r"
+ % (name, ", ".join(sorted(integrations.keys())))
)
return v
@@ -386,9 +424,7 @@ def integration_url(cls, name, required=False):
integration = cls.integration(name, required=required)
v = integration.get(cls.URL, None)
if not v and required:
- raise ValueError(
- "Integration '%s' did not define a required 'url'!" % name
- )
+ raise ValueError("Integration '%s' did not define a required 'url'!" % name)
return v
@classmethod
@@ -401,11 +437,13 @@ def cdns(cls):
# Create a new database connection and find that
# information now.
from .model import SessionManager
+
url = cls.database_url()
_db = SessionManager.session(url)
cls.load_cdns(_db)
from .model import ExternalIntegration
+
return cls.integration(ExternalIntegration.CDN)
@classmethod
@@ -413,9 +451,7 @@ def policy(cls, name, default=None, required=False):
"""Find a policy configuration by name."""
v = cls.get(cls.POLICIES, {}).get(name, default)
if not v and required:
- raise ValueError(
- "Required policy %s was not defined!" % name
- )
+ raise ValueError("Required policy %s was not defined!" % name)
return v
# More specific getters.
@@ -436,7 +472,7 @@ def database_url(cls):
# controls which database is used, and it's set by the
# package_setup() function called in every component's
# tests/__init__.py.
- test = os.environ.get('TESTING', False)
+ test = os.environ.get("TESTING", False)
if test:
config_key = cls.DATABASE_TEST_URL
environment_variable = cls.DATABASE_TEST_ENVIRONMENT_VARIABLE
@@ -447,7 +483,8 @@ def database_url(cls):
url = os.environ.get(environment_variable)
if not url:
raise CannotLoadConfiguration(
- "Database URL was not defined in environment variable (%s)." % environment_variable
+ "Database URL was not defined in environment variable (%s)."
+ % environment_variable
)
url_obj = None
@@ -457,8 +494,8 @@ def database_url(cls):
# Improve the error message by giving a guide as to what's
# likely to work.
raise ArgumentError(
- "Bad format for database URL (%s). Expected something like postgres://[username]:[password]@[hostname]:[port]/[database name]" %
- url
+ "Bad format for database URL (%s). Expected something like postgres://[username]:[password]@[hostname]:[port]/[database name]"
+ % url
)
# Calling __to_string__ will hide the password.
@@ -492,7 +529,8 @@ def data_directory(cls):
@classmethod
def load_cdns(cls, _db, config_instance=None):
from .model import ExternalIntegration as EI
- cdns = _db.query(EI).filter(EI.goal==EI.CDN_GOAL).all()
+
+ cdns = _db.query(EI).filter(EI.goal == EI.CDN_GOAL).all()
cdn_integration = dict()
for cdn in cdns:
cdn_integration[cdn.setting(cls.CDN_MIRRORED_DOMAIN_KEY).value] = cdn.url
@@ -512,7 +550,9 @@ def localization_languages(cls):
# The last time we *checked* whether the database configuration had
# changed.
- LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE = "last_checked_for_site_configuration_update"
+ LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE = (
+ "last_checked_for_site_configuration_update"
+ )
# A sitewide configuration setting controlling *how often* to check
# whether the database configuration has changed.
@@ -520,7 +560,7 @@ def localization_languages(cls):
# NOTE: This setting is currently not used; the most reliable
# value seems to be zero. Assuming that's true, this whole
# subsystem can be removed.
- SITE_CONFIGURATION_TIMEOUT = 'site_configuration_timeout'
+ SITE_CONFIGURATION_TIMEOUT = "site_configuration_timeout"
# The name of the service associated with a Timestamp that tracks
# the last time the site's configuration changed in the database.
@@ -531,13 +571,10 @@ def last_checked_for_site_configuration_update(cls):
"""When was the last time we actually checked when the database
was updated?
"""
- return cls.instance.get(
- cls.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE, None
- )
+ return cls.instance.get(cls.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE, None)
@classmethod
- def site_configuration_last_update(cls, _db, known_value=None,
- timeout=0):
+ def site_configuration_last_update(cls, _db, known_value=None, timeout=0):
"""Check when the site configuration was last updated.
Updates Configuration.instance[Configuration.SITE_CONFIGURATION_LAST_UPDATE].
@@ -564,6 +601,7 @@ def site_configuration_last_update(cls, _db, known_value=None,
# never set to None). This code will hopefully be removed soon.
if _db and timeout is None:
from .model import ConfigurationSetting
+
timeout = ConfigurationSetting.sitewide(
_db, cls.SITE_CONFIGURATION_TIMEOUT
).int_value
@@ -575,12 +613,13 @@ def site_configuration_last_update(cls, _db, known_value=None,
# None.
timeout = 60
- last_check = cls.instance.get(
- cls.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
- )
+ last_check = cls.instance.get(cls.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE)
- if (not known_value
- and last_check and (now - last_check).total_seconds() < timeout):
+ if (
+ not known_value
+ and last_check
+ and (now - last_check).total_seconds() < timeout
+ ):
# We went to the database less than [timeout] seconds ago.
# Assume there has been no change.
return cls._site_configuration_last_update()
@@ -591,9 +630,9 @@ def site_configuration_last_update(cls, _db, known_value=None,
# called.
if not known_value:
from .model import Timestamp
+
known_value = Timestamp.value(
- _db, cls.SITE_CONFIGURATION_CHANGED, service_type=None,
- collection=None
+ _db, cls.SITE_CONFIGURATION_CHANGED, service_type=None, collection=None
)
if not known_value:
# The site configuration has never changed.
@@ -626,7 +665,7 @@ def load_from_file(cls):
This is being phased out in favor of taking all configuration from a
database.
"""
- cfv = 'SIMPLIFIED_CONFIGURATION_FILE'
+ cfv = "SIMPLIFIED_CONFIGURATION_FILE"
config_path = os.environ.get(cfv)
if config_path:
try:
@@ -634,19 +673,22 @@ def load_from_file(cls):
configuration = cls._load(open(config_path).read())
except Exception as e:
raise CannotLoadConfiguration(
- "Error loading configuration file %s: %s" % (
- config_path, e)
+ "Error loading configuration file %s: %s" % (config_path, e)
)
else:
- configuration = cls._load('{}')
+ configuration = cls._load("{}")
return configuration
@classmethod
def _load(cls, str):
- lines = [x for x in str.split("\n")
- if not (x.strip().startswith("#") or x.strip().startswith("//"))]
+ lines = [
+ x
+ for x in str.split("\n")
+ if not (x.strip().startswith("#") or x.strip().startswith("//"))
+ ]
return json.loads("\n".join(lines))
+
# Immediately load the configuration file (if any).
Configuration.instance = Configuration.load_from_file()
diff --git a/coverage.py b/coverage.py
index ff43229a6..d21d9aa3c 100644
--- a/coverage.py
+++ b/coverage.py
@@ -4,9 +4,9 @@
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.functions import func
+from . import log # This sets the appropriate log format.
+from .metadata_layer import ReplacementPolicy, TimestampData
from .model import (
- get_one,
- get_one_or_create,
BaseCoverageRecord,
Collection,
CollectionMissing,
@@ -20,20 +20,19 @@
Timestamp,
Work,
WorkCoverageRecord,
+ get_one,
+ get_one_or_create,
)
-from .metadata_layer import (
- ReplacementPolicy,
- TimestampData,
-)
-from .util.worker_pools import DatabaseJob
from .util.datetime_helpers import utc_now
-from . import log # This sets the appropriate log format.
+from .util.worker_pools import DatabaseJob
+
class CoverageFailure(object):
"""Object representing the failure to provide coverage."""
- def __init__(self, obj, exception, data_source=None, transient=True,
- collection=None):
+ def __init__(
+ self, obj, exception, data_source=None, transient=True, collection=None
+ ):
self.obj = obj
self.data_source = data_source
self.exception = exception
@@ -46,7 +45,10 @@ def __repr__(self):
else:
data_source = None
return "" % (
- self.obj, data_source, self.transient, self.exception
+ self.obj,
+ data_source,
+ self.transient,
+ self.exception,
)
def to_coverage_record(self, operation=None):
@@ -57,8 +59,7 @@ def to_coverage_record(self, operation=None):
)
record, ignore = CoverageRecord.add_for(
- self.obj, self.data_source, operation=operation,
- collection=self.collection
+ self.obj, self.data_source, operation=operation, collection=self.collection
)
record.exception = self.exception
if self.transient:
@@ -69,9 +70,7 @@ def to_coverage_record(self, operation=None):
def to_work_coverage_record(self, operation):
"""Convert this failure into a WorkCoverageRecord."""
- record, ignore = WorkCoverageRecord.add_for(
- self.obj, operation=operation
- )
+ record, ignore = WorkCoverageRecord.add_for(self.obj, operation=operation)
record.exception = self.exception
if self.transient:
record.status = CoverageRecord.TRANSIENT_FAILURE
@@ -85,6 +84,7 @@ class CoverageProviderProgress(TimestampData):
"""A TimestampData optimized for the special needs of
CoverageProviders.
"""
+
def __init__(self, *args, **kwargs):
super(CoverageProviderProgress, self).__init__(*args, **kwargs)
@@ -103,11 +103,12 @@ def achievements(self):
human-readable string.
"""
template = "Items processed: %d. Successes: %d, transient failures: %d, persistent failures: %d"
- total = (self.successes + self.transient_failures
- + self.persistent_failures)
+ total = self.successes + self.transient_failures + self.persistent_failures
return template % (
- total, self.successes, self.transient_failures,
- self.persistent_failures
+ total,
+ self.successes,
+ self.transient_failures,
+ self.persistent_failures,
)
@achievements.setter
@@ -152,7 +153,11 @@ class BaseCoverageProvider(object):
# doing this.
DEFAULT_BATCH_SIZE = 100
- def __init__(self, _db, batch_size=None, cutoff_time=None,
+ def __init__(
+ self,
+ _db,
+ batch_size=None,
+ cutoff_time=None,
registered_only=False,
):
"""Constructor.
@@ -170,13 +175,11 @@ def __init__(self, _db, batch_size=None, cutoff_time=None,
"""
self._db = _db
if not self.__class__.SERVICE_NAME:
- raise ValueError(
- "%s must define SERVICE_NAME." % self.__class__.__name__
- )
+ raise ValueError("%s must define SERVICE_NAME." % self.__class__.__name__)
service_name = self.__class__.SERVICE_NAME
operation = self.operation
if operation:
- service_name += ' (%s)' % operation
+ service_name += " (%s)" % operation
self.service_name = service_name
if not batch_size or batch_size < 0:
batch_size = self.DEFAULT_BATCH_SIZE
@@ -187,7 +190,7 @@ def __init__(self, _db, batch_size=None, cutoff_time=None,
@property
def log(self):
- if not hasattr(self, '_log'):
+ if not hasattr(self, "_log"):
self._log = logging.getLogger(self.service_name)
return self._log
@@ -222,7 +225,7 @@ def run_once_and_update_timestamp(self):
# previous attempt.
covered_status_lists = [
BaseCoverageRecord.PREVIOUSLY_ATTEMPTED,
- BaseCoverageRecord.DEFAULT_COUNT_AS_COVERED
+ BaseCoverageRecord.DEFAULT_COUNT_AS_COVERED,
]
start_time = utc_now()
timestamp = self.timestamp
@@ -257,10 +260,11 @@ def run_once_and_update_timestamp(self):
except Exception as e:
logging.error(
"CoverageProvider %s raised uncaught exception.",
- self.service_name, exc_info=e
+ self.service_name,
+ exc_info=e,
)
- progress.exception=traceback.format_exc()
- progress.finish=utc_now()
+ progress.exception = traceback.format_exc()
+ progress.finish = utc_now()
# The next run_once() call might raise an exception,
# so let's write the work to the database as it's
@@ -285,8 +289,10 @@ def run_once_and_update_timestamp(self):
def timestamp(self):
"""Look up the Timestamp object for this CoverageProvider."""
return Timestamp.lookup(
- self._db, self.service_name, Timestamp.COVERAGE_PROVIDER_TYPE,
- collection=self.collection
+ self._db,
+ self.service_name,
+ Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=self.collection,
)
def finalize_timestampdata(self, timestamp, **kwargs):
@@ -294,8 +300,10 @@ def finalize_timestampdata(self, timestamp, **kwargs):
database.
"""
timestamp.finalize(
- self.service_name, Timestamp.COVERAGE_PROVIDER_TYPE,
- collection=self.collection, **kwargs
+ self.service_name,
+ Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=self.collection,
+ **kwargs
)
timestamp.apply(self._db)
self._db.commit()
@@ -322,14 +330,17 @@ def run_once(self, progress, count_as_covered=None):
:return: A CoverageProviderProgress representing whatever
additional progress has been made.
"""
- count_as_covered = count_as_covered or BaseCoverageRecord.DEFAULT_COUNT_AS_COVERED
+ count_as_covered = (
+ count_as_covered or BaseCoverageRecord.DEFAULT_COUNT_AS_COVERED
+ )
# Make it clear which class of items we're covering on this
# run.
- count_as_covered_message = ' (counting %s as covered)' % (', '.join(count_as_covered))
+ count_as_covered_message = " (counting %s as covered)" % (
+ ", ".join(count_as_covered)
+ )
qu = self.items_that_need_coverage(count_as_covered=count_as_covered)
- self.log.info("%d items need coverage%s", qu.count(),
- count_as_covered_message)
+ self.log.info("%d items need coverage%s", qu.count(), count_as_covered_message)
batch = qu.limit(self.batch_size).offset(progress.offset)
if not batch.count():
@@ -337,9 +348,11 @@ def run_once(self, progress, count_as_covered=None):
progress.finish = utc_now()
return progress
- (successes, transient_failures, persistent_failures), results = (
- self.process_batch_and_handle_results(batch)
- )
+ (
+ successes,
+ transient_failures,
+ persistent_failures,
+ ), results = self.process_batch_and_handle_results(batch)
# Update the running totals so that the service's eventual timestamp
# will have a useful .achievements.
@@ -400,15 +413,13 @@ def process_batch_and_handle_results(self, batch):
record = self.record_failure_as_coverage_record(item)
if item.transient:
self.log.warn(
- "Transient failure covering %r: %s",
- item.obj, item.exception
+ "Transient failure covering %r: %s", item.obj, item.exception
)
record.status = BaseCoverageRecord.TRANSIENT_FAILURE
transient_failures += 1
else:
self.log.error(
- "Persistent failure covering %r: %s",
- item.obj, item.exception
+ "Persistent failure covering %r: %s", item.obj, item.exception
)
record.status = BaseCoverageRecord.PERSISTENT_FAILURE
persistent_failures += 1
@@ -428,7 +439,8 @@ def process_batch_and_handle_results(self, batch):
# failed. Treat them as transient failures.
for item in unhandled_items:
self.log.warn(
- "%r was ignored by a coverage provider that was supposed to cover it.", item
+ "%r was ignored by a coverage provider that was supposed to cover it.",
+ item,
)
failure = self.failure_for_ignored_item(item)
record = self.record_failure_as_coverage_record(failure)
@@ -438,7 +450,10 @@ def process_batch_and_handle_results(self, batch):
self.log.info(
"Batch processed with %d successes, %d transient failures, %d persistent failures, %d ignored.",
- successes, transient_failures, persistent_failures, num_ignored
+ successes,
+ transient_failures,
+ persistent_failures,
+ num_ignored,
)
# Finalize this batch before moving on to the next one.
@@ -483,7 +498,7 @@ def should_update(self, coverage_record):
# so we need to do the work.
return True
- if coverage_record.status==BaseCoverageRecord.REGISTERED:
+ if coverage_record.status == BaseCoverageRecord.REGISTERED:
# There's a CoverageRecord, but coverage hasn't actually
# been attempted. Try to get covered.
return True
@@ -567,6 +582,7 @@ class IdentifierCoverageProvider(BaseCoverageProvider):
in BaseCoverageProvider; the rest are described in appropriate
comments in this class.
"""
+
# In your subclass, set this to the name of the data source you
# consult when providing coverage, e.g. DataSource.OVERDRIVE.
DATA_SOURCE_NAME = None
@@ -588,8 +604,13 @@ class IdentifierCoverageProvider(BaseCoverageProvider):
# Collections the Identifier belongs to.
COVERAGE_COUNTS_FOR_EVERY_COLLECTION = True
- def __init__(self, _db, collection=None, input_identifiers=None,
- replacement_policy=None, **kwargs
+ def __init__(
+ self,
+ _db,
+ collection=None,
+ input_identifiers=None,
+ replacement_policy=None,
+ **kwargs
):
"""Constructor.
@@ -657,7 +678,8 @@ def _input_identifier_types(cls):
# INPUT_IDENTIFIER_TYPES.
if value is cls.NO_SPECIFIED_TYPES:
raise ValueError(
- "%s must define INPUT_IDENTIFIER_TYPES, even if the value is None." % (cls.__name__)
+ "%s must define INPUT_IDENTIFIER_TYPES, even if the value is None."
+ % (cls.__name__)
)
if not value:
@@ -675,8 +697,13 @@ def _input_identifier_types(cls):
return value
@classmethod
- def register(cls, identifier, data_source=None, collection=None,
- force=False, autocreate=False
+ def register(
+ cls,
+ identifier,
+ data_source=None,
+ collection=None,
+ force=False,
+ autocreate=False,
):
"""Registers an identifier for future coverage.
@@ -687,8 +714,11 @@ def register(cls, identifier, data_source=None, collection=None,
log = logging.getLogger(name)
new_records, ignored_identifiers = cls.bulk_register(
- [identifier], data_source=data_source, collection=collection,
- force=force, autocreate=autocreate
+ [identifier],
+ data_source=data_source,
+ collection=collection,
+ force=force,
+ autocreate=autocreate,
)
was_registered = identifier not in ignored_identifiers
@@ -697,7 +727,7 @@ def register(cls, identifier, data_source=None, collection=None,
[new_record] = new_records
if was_registered and new_record:
- log.info('CREATED %r' % new_record)
+ log.info("CREATED %r" % new_record)
return new_record, was_registered
_db = Session.object_session(identifier)
@@ -713,12 +743,17 @@ def register(cls, identifier, data_source=None, collection=None,
existing_record = CoverageRecord.lookup(
identifier, data_source, cls.OPERATION, collection=collection
)
- log.info('FOUND %r' % existing_record)
+ log.info("FOUND %r" % existing_record)
return existing_record, was_registered
@classmethod
- def bulk_register(cls, identifiers, data_source=None, collection=None,
- force=False, autocreate=False
+ def bulk_register(
+ cls,
+ identifiers,
+ data_source=None,
+ collection=None,
+ force=False,
+ autocreate=False,
):
"""Registers identifiers for future coverage.
@@ -754,8 +789,11 @@ def bulk_register(cls, identifiers, data_source=None, collection=None,
collection = None
new_records, ignored_identifiers = CoverageRecord.bulk_add(
- identifiers, data_source, operation=cls.OPERATION,
- status=CoverageRecord.REGISTERED, collection=collection,
+ identifiers,
+ data_source,
+ operation=cls.OPERATION,
+ status=CoverageRecord.REGISTERED,
+ collection=collection,
force=force,
)
@@ -788,7 +826,8 @@ def data_source(self):
def failure(self, identifier, error, transient=True):
"""Create a CoverageFailure object to memorialize an error."""
return CoverageFailure(
- identifier, error,
+ identifier,
+ error,
data_source=self.data_source,
transient=transient,
collection=self.collection_or_not,
@@ -802,8 +841,10 @@ def can_cover(self, identifier):
caller may need to decide whether to pass an Identifier
into ensure_coverage() or register().
"""
- return (not self.input_identifier_types
- or identifier.type in self.input_identifier_types)
+ return (
+ not self.input_identifier_types
+ or identifier.type in self.input_identifier_types
+ )
def run_on_specific_identifiers(self, identifiers):
"""Split a specific set of Identifiers into batches and process one
@@ -838,7 +879,7 @@ def run_on_specific_identifiers(self, identifiers):
# Iterate over any items that were not automatic
# successes.
while index < len(need_coverage):
- batch = need_coverage[index:index+self.batch_size]
+ batch = need_coverage[index : index + self.batch_size]
(s, t, p), r = self.process_batch_and_handle_results(batch)
successes += s
transient_failures += t
@@ -875,19 +916,18 @@ def ensure_coverage(self, item, force=False):
collection = self.collection
coverage_record = get_one(
- self._db, CoverageRecord,
+ self._db,
+ CoverageRecord,
identifier=identifier,
collection=collection,
data_source=self.data_source,
operation=self.operation,
- on_multiple='interchangeable',
+ on_multiple="interchangeable",
)
if not force and not self.should_update(coverage_record):
return coverage_record
- counts, records = self.process_batch_and_handle_results(
- [identifier]
- )
+ counts, records = self.process_batch_and_handle_results([identifier])
if records:
coverage_record = records[0]
else:
@@ -899,8 +939,7 @@ def edition(self, identifier):
view of a given Identifier.
"""
edition, ignore = Edition.for_foreign_id(
- self._db, self.data_source, identifier.type,
- identifier.identifier
+ self._db, self.data_source, identifier.type, identifier.identifier
)
return edition
@@ -928,13 +967,13 @@ def set_metadata(self, identifier, metadata):
# The metadata layer will not use the collection when creating
# CoverageRecords for the metadata actions.
metadata.apply(
- edition, collection=self.collection,
+ edition,
+ collection=self.collection,
replace=self.replacement_policy,
)
except Exception as e:
self.log.warn(
- "Error applying metadata to edition %d: %s",
- edition.id, e, exc_info=e
+ "Error applying metadata to edition %d: %s", edition.id, e, exc_info=e
)
return self.failure(identifier, repr(e), transient=True)
@@ -960,9 +999,13 @@ def items_that_need_coverage(self, identifiers=None, **kwargs):
parameters.
"""
qu = Identifier.missing_coverage_from(
- self._db, self.input_identifier_types, self.data_source,
- count_as_missing_before=self.cutoff_time, operation=self.operation,
- identifiers=self.input_identifiers, collection=self.collection_or_not,
+ self._db,
+ self.input_identifier_types,
+ self.data_source,
+ count_as_missing_before=self.cutoff_time,
+ operation=self.operation,
+ identifiers=self.input_identifiers,
+ collection=self.collection_or_not,
**kwargs
)
@@ -970,7 +1013,7 @@ def items_that_need_coverage(self, identifiers=None, **kwargs):
qu = qu.filter(Identifier.id.in_([x.id for x in identifiers]))
if not identifiers and identifiers != None:
# An empty list was provided. The returned query should be empty.
- qu = qu.filter(Identifier.id==None)
+ qu = qu.filter(Identifier.id == None)
if self.registered_only:
# Return Identifiers that have been "registered" for coverage
@@ -984,8 +1027,10 @@ def add_coverage_record_for(self, item):
Edition/Identifier, as a CoverageRecord.
"""
record, is_new = CoverageRecord.add_for(
- item, data_source=self.data_source, operation=self.operation,
- collection=self.collection_or_not
+ item,
+ data_source=self.data_source,
+ operation=self.operation,
+ collection=self.collection_or_not,
)
record.status = CoverageRecord.SUCCESS
record.exception = None
@@ -999,9 +1044,7 @@ def failure_for_ignored_item(self, item):
"""Create a CoverageFailure recording the CoverageProvider's
failure to even try to process an item.
"""
- return self.failure(
- item, "Was ignored by CoverageProvider.", transient=True
- )
+ return self.failure(item, "Was ignored by CoverageProvider.", transient=True)
class CollectionCoverageProvider(IdentifierCoverageProvider):
@@ -1034,6 +1077,7 @@ class CollectionCoverageProvider(IdentifierCoverageProvider):
`ExternalIntegration` class, such as
ExternalIntegration.OPDS_IMPORT or ExternalIntegration.OVERDRIVE.
"""
+
# By default, this type of CoverageProvider will provide coverage to
# all Identifiers in the given Collection, regardless of their type.
INPUT_IDENTIFIER_TYPES = None
@@ -1058,21 +1102,16 @@ def __init__(self, collection, **kwargs):
"""
if not isinstance(collection, Collection):
raise CollectionMissing(
- "%s must be instantiated with a Collection." % (
- self.__class__.__name__
- )
+ "%s must be instantiated with a Collection." % (self.__class__.__name__)
)
if self.PROTOCOL and collection.protocol != self.PROTOCOL:
raise ValueError(
- "Collection protocol (%s) does not match CoverageProvider protocol (%s)" % (
- collection.protocol, self.PROTOCOL
- )
+ "Collection protocol (%s) does not match CoverageProvider protocol (%s)"
+ % (collection.protocol, self.PROTOCOL)
)
_db = Session.object_session(collection)
- super(CollectionCoverageProvider, self).__init__(
- _db, collection, **kwargs
- )
+ super(CollectionCoverageProvider, self).__init__(_db, collection, **kwargs)
def _default_replacement_policy(self, _db):
"""Unless told otherwise, assume that we are getting
@@ -1108,9 +1147,7 @@ def all(cls, _db, **kwargs):
def run_once(self, *args, **kwargs):
self.log.info("Considering collection %s", self.collection.name)
- return super(CollectionCoverageProvider, self).run_once(
- *args, **kwargs
- )
+ return super(CollectionCoverageProvider, self).run_once(*args, **kwargs)
def items_that_need_coverage(self, identifiers=None, **kwargs):
"""Find all Identifiers associated with this Collection but lacking
@@ -1120,7 +1157,7 @@ def items_that_need_coverage(self, identifiers=None, **kwargs):
identifiers, **kwargs
)
qu = qu.join(Identifier.licensed_through).filter(
- LicensePool.collection_id==self.collection_id
+ LicensePool.collection_id == self.collection_id
)
return qu
@@ -1134,8 +1171,7 @@ def license_pool(self, identifier, data_source=None):
should only be needed by the metadata wrangler.
"""
license_pools = [
- p for p in identifier.licensed_through
- if self.collection==p.collection
+ p for p in identifier.licensed_through if self.collection == p.collection
]
if license_pools:
@@ -1158,8 +1194,11 @@ def license_pool(self, identifier, data_source=None):
# which typically has to manage information about books it has no
# rights to.
license_pool, ignore = LicensePool.for_foreign_id(
- self._db, data_source, identifier.type,
- identifier.identifier, collection=self.collection
+ self._db,
+ data_source,
+ identifier.type,
+ identifier.identifier,
+ collection=self.collection,
)
return license_pool
@@ -1202,26 +1241,24 @@ def work(self, identifier, license_pool=None, **calculate_work_kwargs):
error = None
if not license_pool:
license_pool, ignore = LicensePool.for_foreign_id(
- self._db, self.data_source, identifier.type,
- identifier.identifier, collection=self.collection,
- autocreate=False
+ self._db,
+ self.data_source,
+ identifier.type,
+ identifier.identifier,
+ collection=self.collection,
+ autocreate=False,
)
if license_pool:
- if (not license_pool.work
- or not license_pool.work.presentation_ready):
- for (v, default) in (
- ('exclude_search', self.EXCLUDE_SEARCH_INDEX),
- ):
+ if not license_pool.work or not license_pool.work.presentation_ready:
+ for (v, default) in (("exclude_search", self.EXCLUDE_SEARCH_INDEX),):
if not v in calculate_work_kwargs:
calculate_work_kwargs[v] = default
# Calling calculate_work will calculate the work's
# presentation and make it presentation-ready if
# possible.
- work, created = license_pool.calculate_work(
- **calculate_work_kwargs
- )
+ work, created = license_pool.calculate_work(**calculate_work_kwargs)
if not work:
error = "Work could not be calculated"
else:
@@ -1232,7 +1269,10 @@ def work(self, identifier, license_pool=None, **calculate_work_kwargs):
return work
def set_metadata_and_circulation_data(
- self, identifier, metadata, circulationdata,
+ self,
+ identifier,
+ metadata,
+ circulationdata,
):
"""Makes sure that the given Identifier has a Work, Edition (in the
context of this Collection), and LicensePool (ditto), and that
@@ -1293,8 +1333,7 @@ def _set_circulationdata(self, identifier, circulationdata):
else:
collection_name = ""
self.log.warn(
- "Error applying circulationdata%s: %s",
- collection_name, e, exc_info=e
+ "Error applying circulationdata%s: %s", collection_name, e, exc_info=e
)
return self.failure(identifier, repr(e), transient=True)
@@ -1310,10 +1349,7 @@ def set_presentation_ready(self, identifier):
class CollectionCoverageProviderJob(DatabaseJob):
-
- def __init__(self, collection, provider_class, progress,
- **provider_kwargs
- ):
+ def __init__(self, collection, provider_class, progress, **provider_kwargs):
self.collection = collection
self.progress = progress
self.provider_class = provider_class
@@ -1341,9 +1377,7 @@ def items_that_need_coverage(self, identifiers=None, **kwargs):
qu = super(CollectionCoverageProvider, self).items_that_need_coverage(
identifiers, **kwargs
)
- qu = qu.join(Identifier.collections).filter(
- Collection.id==self.collection_id
- )
+ qu = qu.join(Identifier.collections).filter(Collection.id == self.collection_id)
return qu
@@ -1411,15 +1445,14 @@ def items_that_need_coverage(self, identifiers=None, **kwargs):
are chosen.
"""
qu = Work.missing_coverage_from(
- self._db, operation=self.operation,
+ self._db,
+ operation=self.operation,
count_as_missing_before=self.cutoff_time,
**kwargs
)
if identifiers:
ids = [x.id for x in identifiers]
- qu = qu.join(Work.license_pools).filter(
- LicensePool.identifier_id.in_(ids)
- )
+ qu = qu.join(Work.license_pools).filter(LicensePool.identifier_id.in_(ids))
if self.registered_only:
# Return Identifiers that have been "registered" for coverage
@@ -1444,9 +1477,7 @@ def add_coverage_records_for(self, works):
"""Add WorkCoverageRecords for a group of works from a batch,
each of which was successful.
"""
- WorkCoverageRecord.bulk_add(
- works, operation=self.operation
- )
+ WorkCoverageRecord.bulk_add(works, operation=self.operation)
# We can't return the specific WorkCoverageRecords that were
# created, but it doesn't matter because they're not used except
@@ -1465,13 +1496,13 @@ def record_failure_as_coverage_record(self, failure):
class PresentationReadyWorkCoverageProvider(WorkCoverageProvider):
- """A WorkCoverageProvider that only covers presentation-ready works.
- """
+ """A WorkCoverageProvider that only covers presentation-ready works."""
+
def items_that_need_coverage(self, identifiers=None, **kwargs):
- qu = super(PresentationReadyWorkCoverageProvider, self).items_that_need_coverage(
- identifiers, **kwargs
-)
- qu = qu.filter(Work.presentation_ready==True)
+ qu = super(
+ PresentationReadyWorkCoverageProvider, self
+ ).items_that_need_coverage(identifiers, **kwargs)
+ qu = qu.filter(Work.presentation_ready == True)
return qu
@@ -1487,6 +1518,7 @@ class WorkPresentationProvider(PresentationReadyWorkCoverageProvider):
needs to have some aspect of its presentation recalculated. These
providers give back the 'missing' coverage.
"""
+
DEFAULT_BATCH_SIZE = 100
@@ -1498,6 +1530,7 @@ class OPDSEntryWorkCoverageProvider(WorkPresentationProvider):
over all presentation-ready works, even ones which are already
covered.
"""
+
SERVICE_NAME = "OPDS Entry Work Coverage Provider"
OPERATION = WorkCoverageRecord.GENERATE_OPDS_OPERATION
DEFAULT_BATCH_SIZE = 1000
@@ -1511,6 +1544,7 @@ class MARCRecordWorkCoverageProvider(WorkPresentationProvider):
"""Make sure all presentation-ready works have an up-to-date MARC
record.
"""
+
SERVICE_NAME = "MARC Record Work Coverage Provider"
OPERATION = WorkCoverageRecord.GENERATE_MARC_OPERATION
DEFAULT_BATCH_SIZE = 1000
@@ -1529,28 +1563,29 @@ class WorkPresentationEditionCoverageProvider(WorkPresentationProvider):
Expensive operations -- calculating work quality, summary, and genre
classification -- are reserved for WorkClassificationCoverageProvider
"""
- SERVICE_NAME = 'Calculated presentation coverage provider'
+
+ SERVICE_NAME = "Calculated presentation coverage provider"
OPERATION = WorkCoverageRecord.CHOOSE_EDITION_OPERATION
POLICY = PresentationCalculationPolicy(
- choose_edition=True, set_edition_metadata=True, verbose=True,
-
+ choose_edition=True,
+ set_edition_metadata=True,
+ verbose=True,
# These are the expensive ones, and they're covered by
# WorkSummaryQualityClassificationCoverageProvider.
- classify=False, choose_summary=False, calculate_quality=False,
-
+ classify=False,
+ choose_summary=False,
+ calculate_quality=False,
# It would be better if there were a separate class for this
# operation (COVER_OPERATION), but it's a little complicated because
# that's not a WorkCoverageRecord operation.
choose_cover=True,
-
# We do this even though it's redundant with
# OPDSEntryWorkCoverageProvider. If you change a
# Work's presentation edition but don't update its OPDS entry,
# it effectively didn't happen.
regenerate_opds_entries=True,
-
# Same logic for the search index. This will flag the Work as
# needing a search index update, and SearchIndexCoverageProvider
# will take care of it.
@@ -1569,9 +1604,7 @@ def process_item(self, work):
return work
-class WorkClassificationCoverageProvider(
- WorkPresentationEditionCoverageProvider
-):
+class WorkClassificationCoverageProvider(WorkPresentationEditionCoverageProvider):
"""Calculates the 'expensive' parts of a work's presentation:
classifications, summary, and quality.
@@ -1584,6 +1617,7 @@ class WorkClassificationCoverageProvider(
works get their summaries recalculated, you need to remember that
the coverage record to delete is CLASSIFY_OPERATION.
"""
+
SERVICE_NAME = "Work classification coverage provider"
DEFAULT_BATCH_SIZE = 20
diff --git a/entrypoint.py b/entrypoint.py
index afdc62062..abfea4753 100644
--- a/entrypoint.py
+++ b/entrypoint.py
@@ -1,5 +1,3 @@
-
-
class EntryPoint(object):
"""A EntryPoint is a top-level entry point into a library's Lane structure
@@ -43,19 +41,16 @@ def register(cls, entrypoint_class, display_title, default_enabled=False):
:param default_enabled: New libraries should have this entry point
enabled by default.
"""
- value = getattr(entrypoint_class, 'INTERNAL_NAME', None)
+ value = getattr(entrypoint_class, "INTERNAL_NAME", None)
if not value:
raise ValueError(
- "EntryPoint class %s must define INTERNAL_NAME." % entrypoint_class.__name__
+ "EntryPoint class %s must define INTERNAL_NAME."
+ % entrypoint_class.__name__
)
if value in cls.BY_INTERNAL_NAME:
- raise ValueError(
- "Duplicate entry point internal name: %s" % value
- )
+ raise ValueError("Duplicate entry point internal name: %s" % value)
if display_title in list(cls.DISPLAY_TITLES.values()):
- raise ValueError(
- "Duplicate entry point display name: %s" % display_title
- )
+ raise ValueError("Duplicate entry point display name: %s" % display_title)
cls.DISPLAY_TITLES[entrypoint_class] = display_title
cls.BY_INTERNAL_NAME[value] = entrypoint_class
cls.ENTRY_POINTS.append(entrypoint_class)
@@ -98,8 +93,11 @@ def modify_database_query(cls, _db, qu):
class EverythingEntryPoint(EntryPoint):
"""An entry point that has everything."""
+
INTERNAL_NAME = "All"
URI = "http://schema.org/CreativeWork"
+
+
EntryPoint.register(EverythingEntryPoint, "All")
@@ -118,7 +116,8 @@ def modify_database_query(cls, _db, qu):
to match only items with the right medium.
"""
from .model import Edition
- return qu.filter(Edition.medium==cls.INTERNAL_NAME)
+
+ return qu.filter(Edition.medium == cls.INTERNAL_NAME)
@classmethod
def modify_search_filter(cls, filter):
@@ -133,9 +132,14 @@ def modify_search_filter(cls, filter):
class EbooksEntryPoint(MediumEntryPoint):
INTERNAL_NAME = "Book"
URI = "http://schema.org/EBook"
+
+
EntryPoint.register(EbooksEntryPoint, "eBooks", default_enabled=True)
+
class AudiobooksEntryPoint(MediumEntryPoint):
INTERNAL_NAME = "Audio"
URI = "http://bib.schema.org/Audiobook"
+
+
EntryPoint.register(AudiobooksEntryPoint, "Audiobooks")
diff --git a/exceptions.py b/exceptions.py
index 9c73d1eae..f1d8fc0c7 100644
--- a/exceptions.py
+++ b/exceptions.py
@@ -13,7 +13,7 @@ def __init__(self, message=None, inner_exception=None):
super(BaseError, self).__init__(message)
self._inner_exception = str(inner_exception) if inner_exception else None
-
+
def __hash__(self):
return hash(str(self))
@@ -41,8 +41,6 @@ def __eq__(self, other):
return str(self) == str(other)
def __repr__(self):
- return ''.format(
- (self),
- self.inner_exception
+ return "".format(
+ (self), self.inner_exception
)
-
diff --git a/external_list.py b/external_list.py
index b5ea41e4a..8c4dfe795 100644
--- a/external_list.py
+++ b/external_list.py
@@ -1,21 +1,16 @@
# encoding: utf-8
-from collections import defaultdict
-from dateutil.parser import parse
import csv
+import logging
import os
+from collections import defaultdict
+
+from dateutil.parser import parse
from sqlalchemy import or_
from sqlalchemy.orm.session import Session
-from .opds_import import SimplifiedOPDSLookup
-import logging
from .config import Configuration
-from .metadata_layer import (
- CSVMetadataImporter,
- ReplacementPolicy,
-)
+from .metadata_layer import CSVMetadataImporter, ReplacementPolicy
from .model import (
- get_one,
- get_one_or_create,
Classification,
CustomList,
CustomListEntry,
@@ -24,30 +19,37 @@
Identifier,
Subject,
Work,
+ get_one,
+ get_one_or_create,
)
+from .opds_import import SimplifiedOPDSLookup
from .util import LanguageCodes
from .util.datetime_helpers import utc_now
+
class CustomListFromCSV(CSVMetadataImporter):
"""Create a CustomList, with entries, from a CSV file."""
- def __init__(self, data_source_name, list_name, metadata_client=None,
- overwrite_old_data=False,
- annotation_field='text',
- annotation_author_name_field='name',
- annotation_author_affiliation_field='location',
- first_appearance_field='timestamp',
- **kwargs
- ):
+ def __init__(
+ self,
+ data_source_name,
+ list_name,
+ metadata_client=None,
+ overwrite_old_data=False,
+ annotation_field="text",
+ annotation_author_name_field="name",
+ annotation_author_affiliation_field="location",
+ first_appearance_field="timestamp",
+ **kwargs
+ ):
super(CustomListFromCSV, self).__init__(data_source_name, **kwargs)
self.foreign_identifier = list_name
self.list_name = list_name
- self.overwrite_old_data=overwrite_old_data
+ self.overwrite_old_data = overwrite_old_data
if not metadata_client:
metadata_url = Configuration.integration_url(
- Configuration.METADATA_WRANGLER_INTEGRATION,
- required=True
+ Configuration.METADATA_WRANGLER_INTEGRATION, required=True
)
metadata_client = SimplifiedOPDSLookup(metadata_url)
self.metadata_client = metadata_client
@@ -73,17 +75,16 @@ def to_customlist(self, _db, dictreader):
CustomList,
data_source=data_source,
foreign_identifier=self.foreign_identifier,
- create_method_kwargs = dict(
+ create_method_kwargs=dict(
created=now,
- )
+ ),
)
custom_list.updated = now
# Turn the rows of the CSV file into a sequence of Metadata
# objects, then turn each Metadata into a CustomListEntry object.
for metadata in self.to_metadata(dictreader):
- entry = self.metadata_to_list_entry(
- custom_list, data_source, now, metadata)
+ entry = self.metadata_to_list_entry(custom_list, data_source, now, metadata)
def metadata_to_list_entry(self, custom_list, data_source, now, metadata):
"""Convert a Metadata object to a CustomListEntry."""
@@ -91,7 +92,8 @@ def metadata_to_list_entry(self, custom_list, data_source, now, metadata):
title_from_external_list = self.metadata_to_title(now, metadata)
list_entry, was_new = title_from_external_list.to_custom_list_entry(
- custom_list, self.metadata_client, self.overwrite_old_data)
+ custom_list, self.metadata_client, self.overwrite_old_data
+ )
e = list_entry.edition
if not e:
@@ -99,16 +101,17 @@ def metadata_to_list_entry(self, custom_list, data_source, now, metadata):
# couldn't find a useful Identifier.
self.log.info("Could not create edition for %s", metadata.title)
else:
- q = _db.query(Work).join(Work.presentation_edition).filter(
- Edition.permanent_work_id==e.permanent_work_id)
+ q = (
+ _db.query(Work)
+ .join(Work.presentation_edition)
+ .filter(Edition.permanent_work_id == e.permanent_work_id)
+ )
if q.count() > 0:
- self.log.info("Found matching work in collection for %s",
- metadata.title
+ self.log.info(
+ "Found matching work in collection for %s", metadata.title
)
else:
- self.log.info("No matching work found for %s",
- metadata.title
- )
+ self.log.info("No matching work found for %s", metadata.title)
return list_entry
def metadata_to_title(self, now, metadata):
@@ -124,23 +127,24 @@ def metadata_to_title(self, now, metadata):
metadata=metadata,
first_appearance=first_appearance,
most_recent_appearance=now,
- annotation=annotation
+ annotation=annotation,
)
def annotation_citation(self, row):
"""Extract a citation for an annotation from a row of a CSV file."""
annotation_author = self._field(row, self.annotation_author_name_field)
annotation_author_affiliation = self._field(
- row, self.annotation_author_affiliation_field)
+ row, self.annotation_author_affiliation_field
+ )
if annotation_author_affiliation == annotation_author:
annotation_author_affiliation = None
- annotation_extra = ''
+ annotation_extra = ""
if annotation_author:
annotation_extra = annotation_author
if annotation_author_affiliation:
- annotation_extra += ', ' + annotation_author_affiliation
+ annotation_extra += ", " + annotation_author_affiliation
if annotation_extra:
- return ' —' + annotation_extra
+ return " —" + annotation_extra
return None
@@ -150,18 +154,16 @@ class TitleFromExternalList(object):
Edition and CustomListEntry objects.
"""
- def __init__(self, metadata, first_appearance, most_recent_appearance,
- annotation):
+ def __init__(self, metadata, first_appearance, most_recent_appearance, annotation):
self.log = logging.getLogger("Title from external list")
self.metadata = metadata
self.first_appearance = first_appearance or most_recent_appearance
- self.most_recent_appearance = (
- most_recent_appearance or utc_now()
- )
+ self.most_recent_appearance = most_recent_appearance or utc_now()
self.annotation = annotation
- def to_custom_list_entry(self, custom_list, metadata_client,
- overwrite_old_data=False):
+ def to_custom_list_entry(
+ self, custom_list, metadata_client, overwrite_old_data=False
+ ):
"""Turn this object into a CustomListEntry with associated Edition."""
_db = Session.object_session(custom_list)
edition = self.to_edition(_db, metadata_client, overwrite_old_data)
@@ -170,23 +172,29 @@ def to_custom_list_entry(self, custom_list, metadata_client,
_db, CustomListEntry, edition=edition, customlist=custom_list
)
- if (not list_entry.first_appearance
- or list_entry.first_appearance > self.first_appearance):
+ if (
+ not list_entry.first_appearance
+ or list_entry.first_appearance > self.first_appearance
+ ):
if list_entry.first_appearance:
self.log.info(
"I thought %s first showed up at %s, but then I saw it earlier, at %s!",
- self.metadata.title, list_entry.first_appearance,
- self.first_appearance
+ self.metadata.title,
+ list_entry.first_appearance,
+ self.first_appearance,
)
list_entry.first_appearance = self.first_appearance
- if (not list_entry.most_recent_appearance
- or list_entry.most_recent_appearance < self.most_recent_appearance):
+ if (
+ not list_entry.most_recent_appearance
+ or list_entry.most_recent_appearance < self.most_recent_appearance
+ ):
if list_entry.most_recent_appearance:
self.log.info(
"I thought %s most recently showed up at %s, but then I saw it later, at %s!",
- self.metadata.title, list_entry.most_recent_appearance,
- self.most_recent_appearance
+ self.metadata.title,
+ list_entry.most_recent_appearance,
+ self.most_recent_appearance,
)
list_entry.most_recent_appearance = self.most_recent_appearance
@@ -218,8 +226,7 @@ def to_edition(self, _db, metadata_client, overwrite_old_data=False):
Edition's primary identifier may be associated with the other
Editions' primary identifiers. (p=0.85)
"""
- self.log.info("Converting %s to an Edition object.",
- self.metadata.title)
+ self.log.info("Converting %s to an Edition object.", self.metadata.title)
# Make sure the Metadata object's view of the book is present
# as an Edition. This will also associate all its identifiers
@@ -228,18 +235,14 @@ def to_edition(self, _db, metadata_client, overwrite_old_data=False):
try:
edition, is_new = self.metadata.edition(_db)
except ValueError as e:
- self.log.info(
- "Ignoring %s, no corresponding edition.", self.metadata.title
- )
+ self.log.info("Ignoring %s, no corresponding edition.", self.metadata.title)
return None
if overwrite_old_data:
policy = ReplacementPolicy.from_metadata_source(
even_if_not_apparently_updated=True
)
else:
- policy = ReplacementPolicy.append_only(
- even_if_not_apparently_updated=True
- )
+ policy = ReplacementPolicy.append_only(even_if_not_apparently_updated=True)
self.metadata.apply(
edition=edition,
collection=None,
@@ -308,6 +311,7 @@ class ClassificationBasedMembershipManager(MembershipManager):
"""Manage a custom list containing all Editions whose primary
Identifier is classified under one of the given subject fragments.
"""
+
def __init__(self, custom_list, subject_fragments):
super(ClassificationBasedMembershipManager, self).__init__(custom_list)
self.subject_fragments = subject_fragments
@@ -320,18 +324,17 @@ def new_membership(self):
"""
subject_clause = None
for i in self.subject_fragments:
- c = Subject.identifier.ilike('%' + i + '%')
+ c = Subject.identifier.ilike("%" + i + "%")
if subject_clause is None:
subject_clause = c
else:
subject_clause = or_(subject_clause, c)
- qu = self._db.query(Edition).distinct(Edition.id).join(
- Edition.primary_identifier
- ).join(
- Identifier.classifications
- ).join(
- Classification.subject
+ qu = (
+ self._db.query(Edition)
+ .distinct(Edition.id)
+ .join(Edition.primary_identifier)
+ .join(Identifier.classifications)
+ .join(Classification.subject)
)
qu = qu.filter(subject_clause)
return qu
-
diff --git a/external_search.py b/external_search.py
index 40bc75475..75a31bd30 100644
--- a/external_search.py
+++ b/external_search.py
@@ -1,20 +1,16 @@
-from collections import defaultdict
import contextlib
import datetime
-
import json
+import logging
+import os
+import re
+import time
+from collections import defaultdict
+
from elasticsearch import Elasticsearch
+from elasticsearch.exceptions import ElasticsearchException, RequestError
from elasticsearch.helpers import bulk as elasticsearch_bulk
-from elasticsearch.exceptions import (
- RequestError,
- ElasticsearchException,
-)
-from elasticsearch_dsl import (
- Index,
- MultiSearch,
- Search,
- SF,
-)
+from elasticsearch_dsl import SF, Index, MultiSearch, Search
from elasticsearch_dsl.query import (
Bool,
DisMax,
@@ -26,31 +22,27 @@
MatchPhrase,
MultiMatch,
Nested,
- Query as BaseQuery,
- SimpleQueryString,
- Term,
- Terms,
)
+from elasticsearch_dsl.query import Query as BaseQuery
+from elasticsearch_dsl.query import SimpleQueryString, Term, Terms
+from flask_babel import lazy_gettext as _
from spellchecker import SpellChecker
-from flask_babel import lazy_gettext as _
-from .config import (
- Configuration,
- CannotLoadConfiguration,
-)
from .classifier import (
- KeywordBasedClassifier,
- GradeLevelClassifier,
AgeClassifier,
Classifier,
+ GradeLevelClassifier,
+ KeywordBasedClassifier,
)
+from .config import CannotLoadConfiguration, Configuration
+from .coverage import CoverageFailure, WorkPresentationProvider
from .facets import FacetConstants
+from .lane import Pagination
from .metadata_layer import IdentifierData
from .model import (
- numericrange_to_tuple,
Collection,
- Contributor,
ConfigurationSetting,
+ Contributor,
DataSource,
Edition,
ExternalIntegration,
@@ -58,27 +50,16 @@
Library,
Work,
WorkCoverageRecord,
+ numericrange_to_tuple,
)
-from .lane import Pagination
from .monitor import WorkSweepMonitor
-from .coverage import (
- CoverageFailure,
- WorkPresentationProvider,
-)
from .problem_details import INVALID_INPUT
-from .selftest import (
- HasSelfTests,
- SelfTestResult,
-)
+from .selftest import HasSelfTests, SelfTestResult
+from .util.datetime_helpers import from_timestamp
from .util.personal_names import display_name_to_sort_name
from .util.problem_detail import ProblemDetail
from .util.stopwords import ENGLISH_STOPWORDS
-from .util.datetime_helpers import from_timestamp
-import os
-import logging
-import re
-import time
@contextlib.contextmanager
def mock_search_index(mock=None):
@@ -101,30 +82,40 @@ class ExternalSearchIndex(HasSelfTests):
# instantiating new ExternalSearchIndex objects.
MOCK_IMPLEMENTATION = None
- WORKS_INDEX_PREFIX_KEY = 'works_index_prefix'
- DEFAULT_WORKS_INDEX_PREFIX = 'circulation-works'
+ WORKS_INDEX_PREFIX_KEY = "works_index_prefix"
+ DEFAULT_WORKS_INDEX_PREFIX = "circulation-works"
- TEST_SEARCH_TERM_KEY = 'test_search_term'
- DEFAULT_TEST_SEARCH_TERM = 'test'
+ TEST_SEARCH_TERM_KEY = "test_search_term"
+ DEFAULT_TEST_SEARCH_TERM = "test"
- work_document_type = 'work-type'
+ work_document_type = "work-type"
__client = None
- CURRENT_ALIAS_SUFFIX = 'current'
- VERSION_RE = re.compile('-v([0-9]+)$')
+ CURRENT_ALIAS_SUFFIX = "current"
+ VERSION_RE = re.compile("-v([0-9]+)$")
SETTINGS = [
- { "key": ExternalIntegration.URL, "label": _("URL"), "required": True, "format": "url" },
- { "key": WORKS_INDEX_PREFIX_KEY, "label": _("Index prefix"),
- "default": DEFAULT_WORKS_INDEX_PREFIX,
- "required": True,
- "description": _("Any Elasticsearch indexes needed for this application will be created with this unique prefix. In most cases, the default will work fine. You may need to change this if you have multiple application servers using a single Elasticsearch server.")
+ {
+ "key": ExternalIntegration.URL,
+ "label": _("URL"),
+ "required": True,
+ "format": "url",
+ },
+ {
+ "key": WORKS_INDEX_PREFIX_KEY,
+ "label": _("Index prefix"),
+ "default": DEFAULT_WORKS_INDEX_PREFIX,
+ "required": True,
+ "description": _(
+ "Any Elasticsearch indexes needed for this application will be created with this unique prefix. In most cases, the default will work fine. You may need to change this if you have multiple application servers using a single Elasticsearch server."
+ ),
+ },
+ {
+ "key": TEST_SEARCH_TERM_KEY,
+ "label": _("Test search term"),
+ "default": DEFAULT_TEST_SEARCH_TERM,
+ "description": _("Self tests will use this value as the search term."),
},
- { "key": TEST_SEARCH_TERM_KEY,
- "label": _("Test search term"),
- "default": DEFAULT_TEST_SEARCH_TERM,
- "description": _("Self tests will use this value as the search term.")
- }
]
SITEWIDE = True
@@ -142,8 +133,7 @@ def reset(cls):
def search_integration(cls, _db):
"""Look up the ExternalIntegration for ElasticSearch."""
return ExternalIntegration.lookup(
- _db, ExternalIntegration.ELASTICSEARCH,
- goal=ExternalIntegration.SEARCH_GOAL
+ _db, ExternalIntegration.ELASTICSEARCH, goal=ExternalIntegration.SEARCH_GOAL
)
@classmethod
@@ -159,7 +149,7 @@ def works_prefixed(cls, _db, value):
return None
setting = integration.setting(cls.WORKS_INDEX_PREFIX_KEY)
prefix = setting.value_or_default(cls.DEFAULT_WORKS_INDEX_PREFIX)
- return prefix + '-' + value
+ return prefix + "-" + value
@classmethod
def works_index_name(cls, _db):
@@ -184,8 +174,15 @@ def load(cls, _db, *args, **kwargs):
return cls.MOCK_IMPLEMENTATION
return cls(_db, *args, **kwargs)
- def __init__(self, _db, url=None, works_index=None, test_search_term=None,
- in_testing=False, mapping=None):
+ def __init__(
+ self,
+ _db,
+ url=None,
+ works_index=None,
+ test_search_term=None,
+ in_testing=False,
+ mapping=None,
+ ):
"""Constructor
:param in_testing: Set this to true if you don't want an
@@ -221,21 +218,17 @@ def __init__(self, _db, url=None, works_index=None, test_search_term=None,
url = url or integration.url
if not works_index:
works_index = self.works_index_name(_db)
- test_search_term = integration.setting(
- self.TEST_SEARCH_TERM_KEY).value
+ test_search_term = integration.setting(self.TEST_SEARCH_TERM_KEY).value
if not url:
- raise CannotLoadConfiguration(
- "No URL configured to Elasticsearch server."
- )
- self.test_search_term = (
- test_search_term or self.DEFAULT_TEST_SEARCH_TERM
- )
+ raise CannotLoadConfiguration("No URL configured to Elasticsearch server.")
+ self.test_search_term = test_search_term or self.DEFAULT_TEST_SEARCH_TERM
if not in_testing:
if not ExternalSearchIndex.__client:
- use_ssl = url.startswith('https://')
+ use_ssl = url.startswith("https://")
self.log.info(
"Connecting to index %s in Elasticsearch cluster at %s",
- works_index, url
+ works_index,
+ url,
)
ExternalSearchIndex.__client = Elasticsearch(
url, use_ssl=use_ssl, timeout=20, maxsize=25
@@ -259,14 +252,14 @@ def __init__(self, _db, url=None, works_index=None, test_search_term=None,
raise
except ElasticsearchException as e:
raise CannotLoadConfiguration(
- "Exception communicating with Elasticsearch server: %s" %
- repr(e)
+ "Exception communicating with Elasticsearch server: %s" % repr(e)
)
self.search = Search(using=self.__client, index=self.works_alias)
def bulk(docs, **kwargs):
return elasticsearch_bulk(self.__client, docs, **kwargs)
+
self.bulk = bulk
def set_works_index_and_alias(self, _db):
@@ -320,10 +313,8 @@ def _use_as_works_alias(name):
return
# Create the alias and search against it.
- response = self.indices.put_alias(
- index=self.works_index, name=alias_name
- )
- if not response.get('acknowledged'):
+ response = self.indices.put_alias(index=self.works_index, name=alias_name)
+ if not response.get("acknowledged"):
self.log.error("Alias '%s' could not be created", alias_name)
# Work against the index instead of an alias.
_use_as_works_alias(self.works_index)
@@ -345,7 +336,7 @@ def setup_index(self, new_index=None, **index_settings):
self.log.info("Creating index %s", index_name)
body = self.mapping.body()
- body.setdefault('settings', {}).update(index_settings)
+ body.setdefault("settings", {}).update(index_settings)
index = self.indices.create(index=index_name, body=body)
def set_stored_scripts(self):
@@ -365,17 +356,20 @@ def set_stored_scripts(self):
def transfer_current_alias(self, _db, new_index):
"""Force -current alias onto a new index"""
if not self.indices.exists(index=new_index):
- raise ValueError(
- "Index '%s' does not exist on this client." % new_index)
+ raise ValueError("Index '%s' does not exist on this client." % new_index)
current_base_name = self.base_index_name(self.works_index)
new_base_name = self.base_index_name(new_index)
if new_base_name != current_base_name:
raise ValueError(
- ("Index '%s' is not in series with current index '%s'. "
- "Confirm the base name (without version number) of both indices"
- "is the same.") % (new_index, self.works_index))
+ (
+ "Index '%s' is not in series with current index '%s'. "
+ "Confirm the base name (without version number) of both indices"
+ "is the same."
+ )
+ % (new_index, self.works_index)
+ )
self.works_index = self.__client.works_index = new_index
alias_name = self.works_alias_name(_db)
@@ -401,24 +395,21 @@ def transfer_current_alias(self, _db, new_index):
# The alias exists on one or more other indices. Remove
# the alias altogether, then put it back on the works
# index.
- self.indices.delete_alias(index='_all', name=alias_name)
- self.indices.put_alias(
- index=self.works_index, name=alias_name
- )
+ self.indices.delete_alias(index="_all", name=alias_name)
+ self.indices.put_alias(index=self.works_index, name=alias_name)
self.works_alias = self.__client.works_alias = alias_name
def base_index_name(self, index_or_alias):
"""Removes version or current suffix from base index name"""
- current_re = re.compile(self.CURRENT_ALIAS_SUFFIX+'$')
- base_works_index = re.sub(current_re, '', index_or_alias)
- base_works_index = re.sub(self.VERSION_RE, '', base_works_index)
+ current_re = re.compile(self.CURRENT_ALIAS_SUFFIX + "$")
+ base_works_index = re.sub(current_re, "", index_or_alias)
+ base_works_index = re.sub(self.VERSION_RE, "", base_works_index)
return base_works_index
- def create_search_doc(self, query_string, filter, pagination,
- debug):
+ def create_search_doc(self, query_string, filter, pagination, debug):
query = Query(query_string, filter)
search = query.build(self.search, pagination)
@@ -433,7 +424,7 @@ def create_search_doc(self, query_string, filter, pagination,
# Don't restrict the fields at all -- get everything.
# This makes it easy to investigate everything about the
# results we do get.
- fields = ['*']
+ fields = ["*"]
else:
# All we absolutely need is the work ID, which is a
# key into the database, plus the values of any script fields,
@@ -448,8 +439,7 @@ def create_search_doc(self, query_string, filter, pagination,
search = search.source(fields)
return search
- def query_works(self, query_string, filter=None, pagination=None,
- debug=False):
+ def query_works(self, query_string, filter=None, pagination=None, debug=False):
"""Run a search query.
This works by calling query_works_multi().
@@ -513,7 +503,7 @@ def query_works_multi(self, queries, debug=False):
function_score = FunctionScore(
query=dict(match_all=dict()),
functions=function_scores,
- score_mode="sum"
+ score_mode="sum",
)
search = search.query(function_score)
multi = multi.add(search)
@@ -526,15 +516,18 @@ def query_works_multi(self, queries, debug=False):
if debug:
b = time.time()
self.log.debug(
- "Elasticsearch query %r completed in %.3fsec",
- query_string, b-a
+ "Elasticsearch query %r completed in %.3fsec", query_string, b - a
)
for results in resultset:
for i, result in enumerate(results):
self.log.debug(
'%02d "%s" (%s) work=%s score=%.3f shard=%s',
- i, result.sort_title, result.sort_author, result.meta['id'],
- result.meta.explanation['value'] or 0, result.meta['shard']
+ i,
+ result.sort_title,
+ result.sort_author,
+ result.meta["id"],
+ result.meta.explanation["value"] or 0,
+ result.meta["shard"],
)
for i, results in enumerate(resultset):
@@ -595,26 +588,34 @@ def bulk_update(self, works, retry_on_batch_failure=True):
docs = []
time3 = time.time()
- self.log.info("Created %i search documents in %.2f seconds" % (len(docs), time2 - time1))
- self.log.info("Uploaded %i search documents in %.2f seconds" % (len(docs), time3 - time2))
+ self.log.info(
+ "Created %i search documents in %.2f seconds" % (len(docs), time2 - time1)
+ )
+ self.log.info(
+ "Uploaded %i search documents in %.2f seconds" % (len(docs), time3 - time2)
+ )
- doc_ids = [d['_id'] for d in docs]
+ doc_ids = [d["_id"] for d in docs]
# We weren't able to create search documents for these works, maybe
# because they don't have presentation editions yet.
def get_error_id(error):
- return error.get('data', {}).get('_id', None) or error.get('index', {}).get('_id', None)
+ return error.get("data", {}).get("_id", None) or error.get("index", {}).get(
+ "_id", None
+ )
+
error_ids = [get_error_id(error) for error in errors]
missing_works = [
- work for work in works
- if work.id not in doc_ids and work.id not in error_ids
+ work
+ for work in works
+ if work.id not in doc_ids
+ and work.id not in error_ids
and work not in successes
]
successes.extend(
- [work for work in works
- if work.id in doc_ids and work.id not in error_ids]
+ [work for work in works if work.id in doc_ids and work.id not in error_ids]
)
failures = []
@@ -629,22 +630,25 @@ def get_error_id(error):
if works_with_error:
work = works_with_error[0]
- exception = error.get('exception', None)
- error_message = error.get('error', None)
+ exception = error.get("exception", None)
+ error_message = error.get("error", None)
if not error_message:
- error_message = error.get('index', {}).get('error', None)
+ error_message = error.get("index", {}).get("error", None)
failures.append((work, error_message))
- self.log.info("Successfully indexed %i documents, failed to index %i." % (success_count, len(failures)))
+ self.log.info(
+ "Successfully indexed %i documents, failed to index %i."
+ % (success_count, len(failures))
+ )
return successes, failures
def remove_work(self, work):
- """Remove the search document for `work` from the search index.
- """
- args = dict(index=self.works_index, doc_type=self.work_document_type,
- id=work.id)
+ """Remove the search document for `work` from the search index."""
+ args = dict(
+ index=self.works_index, doc_type=self.work_document_type, id=work.id
+ )
if self.exists(**args):
self.delete(**args)
@@ -653,25 +657,22 @@ def _run_self_tests(self, _db, in_testing=False):
def _search():
return self.create_search_doc(
- self.test_search_term, filter=None,
- pagination=None, debug=True
+ self.test_search_term, filter=None, pagination=None, debug=True
)
def _works():
return self.query_works(
- self.test_search_term, filter=None, pagination=None,
- debug=True
+ self.test_search_term, filter=None, pagination=None, debug=True
)
# The self-tests:
def _search_for_term():
- titles = [("%s (%s)" %(x.sort_title, x.sort_author)) for x in _works()]
+ titles = [("%s (%s)" % (x.sort_title, x.sort_author)) for x in _works()]
return titles
yield self.run_test(
- ("Search results for '%s':" %(self.test_search_term)),
- _search_for_term
+ ("Search results for '%s':" % (self.test_search_term)), _search_for_term
)
def _get_raw_doc():
@@ -683,16 +684,14 @@ def _get_raw_doc():
return json.dumps(search.to_dict(), indent=1)
yield self.run_test(
- ("Search document for '%s':" %(self.test_search_term)),
- _get_raw_doc
+ ("Search document for '%s':" % (self.test_search_term)), _get_raw_doc
)
def _get_raw_results():
return [json.dumps(x.to_dict(), indent=1) for x in _works()]
yield self.run_test(
- ("Raw search results for '%s':" %(self.test_search_term)),
- _get_raw_results
+ ("Raw search results for '%s':" % (self.test_search_term)), _get_raw_results
)
def _count_docs():
@@ -702,16 +701,15 @@ def _count_docs():
return str(self.search.count())
yield self.run_test(
- ("Total number of search results for '%s':" %(self.test_search_term)),
- _count_docs
+ ("Total number of search results for '%s':" % (self.test_search_term)),
+ _count_docs,
)
def _total_count():
return str(self.count_works(None))
yield self.run_test(
- "Total number of documents in this search index:",
- _total_count
+ "Total number of documents in this search index:", _total_count
)
def _collections():
@@ -724,10 +722,7 @@ def _collections():
return json.dumps(result, indent=1)
- yield self.run_test(
- "Total number of documents per collection:",
- _collections
- )
+ yield self.run_test("Total number of documents per collection:", _collections)
class MappingDocument(object):
@@ -755,7 +750,7 @@ def add_property(self, name, type, **description):
# side-by-side comparison.
defaults = dict(index=True, store=False)
- description['type'] = type
+ description["type"] = type
for default_name, default_value in list(defaults.items()):
if default_name not in description:
description[default_name] = default_value
@@ -798,16 +793,13 @@ def basic_text_property_hook(self, description):
using an analyzer that leaves stopwords in place, for searches
that rely on stopwords.
"""
- description['type'] = 'text'
- description['analyzer'] = 'en_default_text_analyzer'
- description['fields'] = {
- "minimal": {
- "type": "text",
- "analyzer": "en_minimal_text_analyzer"
- },
+ description["type"] = "text"
+ description["analyzer"] = "en_default_text_analyzer"
+ description["fields"] = {
+ "minimal": {"type": "text", "analyzer": "en_minimal_text_analyzer"},
"with_stopwords": {
"type": "text",
- "analyzer": "en_with_stopwords_text_analyzer"
+ "analyzer": "en_with_stopwords_text_analyzer",
},
}
@@ -846,8 +838,8 @@ def version_name(cls):
version = cls.VERSION_NAME
if not version:
raise NotImplementedError("VERSION_NAME not defined")
- if not version.startswith('v'):
- version = 'v%s' % version
+ if not version.startswith("v"):
+ version = "v%s" % version
return version
@classmethod
@@ -872,7 +864,7 @@ def create(self, search_client, base_index_name):
:return: True or False, indicating whether the index was created new.
"""
- versioned_index = base_index_name+'-'+self.version_name()
+ versioned_index = base_index_name + "-" + self.version_name()
if search_client.indices.exists(index=versioned_index):
return False
else:
@@ -881,9 +873,9 @@ def create(self, search_client, base_index_name):
def sort_author_keyword_property_hook(self, description):
"""Give the `sort_author` property its custom analyzer."""
- description['type'] = 'text'
- description['analyzer'] = 'en_sort_author_analyzer'
- description['fielddata'] = True
+ description["type"] = "text"
+ description["analyzer"] = "en_sort_author_analyzer"
+ description["fielddata"] = True
def body(self):
"""Generate the body of the mapping document for this version of the
@@ -894,7 +886,7 @@ def body(self):
filter=self.filters,
char_filter=self.char_filters,
normalizer=self.normalizers,
- analyzer=self.analyzers
+ analyzer=self.analyzers,
)
)
@@ -903,13 +895,9 @@ def body(self):
# Add subdocuments as additional properties.
for name, subdocument in list(self.subdocuments.items()):
- properties[name] = dict(
- type="nested", properties=subdocument.properties
- )
+ properties[name] = dict(type="nested", properties=subdocument.properties)
- mappings = {
- ExternalSearchIndex.work_document_type : dict(properties=properties)
- }
+ mappings = {ExternalSearchIndex.work_document_type: dict(properties=properties)}
return dict(settings=settings, mappings=mappings)
@@ -935,7 +923,8 @@ class CurrentMapping(Mapping):
# becomes "H G Wells" becomes "HG Wells".
CHAR_FILTERS = {
"remove_apostrophes": dict(
- type="pattern_replace", pattern="'",
+ type="pattern_replace",
+ pattern="'",
replacement="",
)
}
@@ -944,25 +933,21 @@ class CurrentMapping(Mapping):
# The special author name "[Unknown]" should sort after everything
# else. REPLACEMENT CHARACTER is the final valid Unicode character.
("unknown_author", "\[Unknown\]", "\N{REPLACEMENT CHARACTER}"),
-
# Works by a given primary author should be secondarily sorted
# by title, not by the other contributors.
("primary_author_only", "\s+;.*", ""),
-
# Remove parentheticals (e.g. the full name of someone who
# goes by initials).
("strip_parentheticals", "\s+\([^)]+\)", ""),
-
# Remove periods from consideration.
("strip_periods", "\.", ""),
-
# Collapse spaces for people whose sort names end with initials.
("collapse_three_initials", " ([A-Z]) ([A-Z]) ([A-Z])$", " $1$2$3"),
("collapse_two_initials", " ([A-Z]) ([A-Z])$", " $1$2"),
]:
- normalizer = dict(type="pattern_replace",
- pattern=pattern,
- replacement=replacement)
+ normalizer = dict(
+ type="pattern_replace", pattern=pattern, replacement=replacement
+ )
CHAR_FILTERS[name] = normalizer
AUTHOR_CHAR_FILTER_NAMES.append(name)
@@ -978,7 +963,7 @@ def __init__(self):
# e.g. ignore capitalization when considering whether
# two books belong to the same series or whether two
# author names are the same.
- self.normalizers['filterable_string'] = dict(
+ self.normalizers["filterable_string"] = dict(
type="custom", filter=["lowercase", "asciifolding"]
)
@@ -1011,22 +996,18 @@ def __init__(self):
# analyzer. We'll be using these to build our own analyzers.
# Filter out English stopwords.
- self.filters['english_stop'] = dict(
- type="stop", stopwords=["_english_"]
- )
+ self.filters["english_stop"] = dict(type="stop", stopwords=["_english_"])
# The default English stemmer, used in the en_default analyzer.
- self.filters['english_stemmer'] = dict(
- type="stemmer", language="english"
- )
+ self.filters["english_stemmer"] = dict(type="stemmer", language="english")
# A less aggressive English stemmer, used in the en_minimal analyzer.
- self.filters['minimal_english_stemmer'] = dict(
+ self.filters["minimal_english_stemmer"] = dict(
type="stemmer", language="minimal_english"
)
# A filter that removes English posessives such as "'s"
- self.filters['english_posessive_stemmer'] = dict(
+ self.filters["english_posessive_stemmer"] = dict(
type="stemmer", language="possessive_english"
)
@@ -1042,34 +1023,36 @@ def __init__(self):
# default configuration for the English analyzer.
common_text_analyzer = dict(
type="custom",
- char_filter=["html_strip", "remove_apostrophes"], # NEW
+ char_filter=["html_strip", "remove_apostrophes"], # NEW
tokenizer="standard",
)
common_filter = [
"lowercase",
- "asciifolding", # NEW
+ "asciifolding", # NEW
]
# The default_text_analyzer uses Elasticsearch's standard
# English stemmer and removes stopwords.
- self.analyzers['en_default_text_analyzer'] = dict(common_text_analyzer)
- self.analyzers['en_default_text_analyzer']['filter'] = (
- common_filter + ["english_stop", 'english_stemmer']
- )
+ self.analyzers["en_default_text_analyzer"] = dict(common_text_analyzer)
+ self.analyzers["en_default_text_analyzer"]["filter"] = common_filter + [
+ "english_stop",
+ "english_stemmer",
+ ]
# The minimal_text_analyzer uses a less aggressive English
# stemmer, and removes stopwords.
- self.analyzers['en_minimal_text_analyzer'] = dict(common_text_analyzer)
- self.analyzers['en_minimal_text_analyzer']['filter'] = (
- common_filter + ['english_stop', 'minimal_english_stemmer']
- )
+ self.analyzers["en_minimal_text_analyzer"] = dict(common_text_analyzer)
+ self.analyzers["en_minimal_text_analyzer"]["filter"] = common_filter + [
+ "english_stop",
+ "minimal_english_stemmer",
+ ]
# The en_with_stopwords_text_analyzer uses the less aggressive
# stemmer and does not remove stopwords.
- self.analyzers['en_with_stopwords_text_analyzer'] = dict(common_text_analyzer)
- self.analyzers['en_with_stopwords_text_analyzer']['filter'] = (
- common_filter + ['minimal_english_stemmer']
- )
+ self.analyzers["en_with_stopwords_text_analyzer"] = dict(common_text_analyzer)
+ self.analyzers["en_with_stopwords_text_analyzer"]["filter"] = common_filter + [
+ "minimal_english_stemmer"
+ ]
# Now we need to define a special analyzer used only by the
# 'sort_author' property.
@@ -1077,7 +1060,7 @@ def __init__(self):
# Here's a special filter used only by that analyzer. It
# duplicates the filter used by the icu_collation_keyword data
# type.
- self.filters['en_sortable_filter'] = dict(
+ self.filters["en_sortable_filter"] = dict(
type="icu_collation", language="en", country="US"
)
@@ -1088,63 +1071,66 @@ def __init__(self):
#
# This is necessary because normal icu_collation_keyword
# fields can't specify char_filter.
- self.analyzers['en_sort_author_analyzer'] = dict(
+ self.analyzers["en_sort_author_analyzer"] = dict(
tokenizer="keyword",
- filter = ["en_sortable_filter"],
- char_filter = self.AUTHOR_CHAR_FILTER_NAMES,
+ filter=["en_sortable_filter"],
+ char_filter=self.AUTHOR_CHAR_FILTER_NAMES,
)
# Now, the main event. Set up the field properties for the
# base document.
fields_by_type = {
- "basic_text": ['summary'],
- 'filterable_text': [
- 'title', 'subtitle', 'series', 'classifications.term',
- 'author', 'publisher', 'imprint'
+ "basic_text": ["summary"],
+ "filterable_text": [
+ "title",
+ "subtitle",
+ "series",
+ "classifications.term",
+ "author",
+ "publisher",
+ "imprint",
],
- 'boolean': ['presentation_ready'],
- 'icu_collation_keyword': ['sort_title'],
- 'sort_author_keyword' : ['sort_author'],
- 'integer': ['series_position', 'work_id'],
- 'long': ['last_update_time'],
+ "boolean": ["presentation_ready"],
+ "icu_collation_keyword": ["sort_title"],
+ "sort_author_keyword": ["sort_author"],
+ "integer": ["series_position", "work_id"],
+ "long": ["last_update_time"],
}
self.add_properties(fields_by_type)
# Set up subdocuments.
contributors = self.subdocument("contributors")
contributor_fields = {
- 'filterable_text' : ['sort_name', 'display_name', 'family_name'],
- 'keyword': ['role', 'lc', 'viaf'],
+ "filterable_text": ["sort_name", "display_name", "family_name"],
+ "keyword": ["role", "lc", "viaf"],
}
contributors.add_properties(contributor_fields)
licensepools = self.subdocument("licensepools")
licensepool_fields = {
- 'integer': ['collection_id', 'data_source_id'],
- 'long': ['availability_time'],
- 'boolean': ['available', 'open_access', 'suppressed', 'licensed'],
- 'keyword': ['medium'],
+ "integer": ["collection_id", "data_source_id"],
+ "long": ["availability_time"],
+ "boolean": ["available", "open_access", "suppressed", "licensed"],
+ "keyword": ["medium"],
}
licensepools.add_properties(licensepool_fields)
identifiers = self.subdocument("identifiers")
- identifier_fields = {
- 'keyword': ['identifier', 'type']
- }
+ identifier_fields = {"keyword": ["identifier", "type"]}
identifiers.add_properties(identifier_fields)
genres = self.subdocument("genres")
genre_fields = {
- 'keyword': ['scheme', 'name', 'term'],
- 'float': ['weight'],
+ "keyword": ["scheme", "name", "term"],
+ "float": ["weight"],
}
genres.add_properties(genre_fields)
customlists = self.subdocument("customlists")
customlist_fields = {
- 'integer': ['list_id'],
- 'long': ['first_appearance'],
- 'boolean': ['featured'],
+ "integer": ["list_id"],
+ "long": ["first_appearance"],
+ "boolean": ["featured"],
}
customlists.add_properties(customlist_fields)
@@ -1241,11 +1227,11 @@ def _nest(cls, subdocument, query):
@classmethod
def _nestable(cls, field, query):
"""Make a query against a field nestable, if necessary."""
- if 's.' in field:
+ if "s." in field:
# This is a query against a field from a subdocument. We
# can't run it against the top-level document; it has to
# be run in the context of its subdocument.
- subdocument = field.split('.', 1)[0]
+ subdocument = field.split(".", 1)[0]
query = cls._nest(subdocument, query)
return query
@@ -1265,7 +1251,7 @@ def _match_range(cls, field, operation, value):
e.g. _match_range("field.name", "gte", 5) will match
any value for field.name greater than 5.
"""
- match = {field : {operation: value}}
+ match = {field: {operation: value}}
return dict(range=match)
@classmethod
@@ -1282,7 +1268,7 @@ def make_target_age_query(cls, target_age, boost=1.1):
# There must be _some_ overlap with the provided range.
must = [
cls._match_range("target_age.upper", "gte", lower),
- cls._match_range("target_age.lower", "lte", upper)
+ cls._match_range("target_age.lower", "lte", upper),
]
# Results with ranges contained within the query range are
@@ -1311,6 +1297,7 @@ def _combine_hypotheses(self, hypotheses):
qu = MatchAll()
return qu
+
class Query(SearchBase):
"""An attempt to find something in the search index."""
@@ -1331,8 +1318,8 @@ class Query(SearchBase):
)
# The contributor names in the contributors sub-document have the
# same weight as the 'author' field in the main document.
- for field in ['contributors.sort_name', 'contributors.display_name']:
- WEIGHT_FOR_FIELD[field] = WEIGHT_FOR_FIELD['author']
+ for field in ["contributors.sort_name", "contributors.display_name"]:
+ WEIGHT_FOR_FIELD[field] = WEIGHT_FOR_FIELD["author"]
# When someone searches for a person's name, they're most likely
# searching for that person's contributions in one of these roles.
@@ -1380,9 +1367,7 @@ class Query(SearchBase):
# For each of these fields, we're going to test the hypothesis
# that the query string is nothing but an attempt to match this
# field.
- SIMPLE_MATCH_FIELDS = [
- 'title', 'subtitle', 'series', 'publisher', 'imprint'
- ]
+ SIMPLE_MATCH_FIELDS = ["title", "subtitle", "series", "publisher", "imprint"]
# For each of these fields, we're going to test the hypothesis
# that the query string contains words from the book's title
@@ -1392,18 +1377,18 @@ class Query(SearchBase):
# looking at the .author field -- the display name of the primary
# author associated with the Work's presentation Editon -- not
# the .display_names in the 'contributors' subdocument.
- MULTI_MATCH_FIELDS = ['subtitle', 'series', 'author']
+ MULTI_MATCH_FIELDS = ["subtitle", "series", "author"]
# For each of these fields, we're going to test the hypothesis
# that the query string is a good match for an aggressively
# stemmed version of this field.
- STEMMABLE_FIELDS = ['title', 'subtitle', 'series']
+ STEMMABLE_FIELDS = ["title", "subtitle", "series"]
# Although we index all text fields using an analyzer that
# preserves stopwords, these are the only fields where we
# currently think it's worth testing a hypothesis that stopwords
# in a query string are _important_.
- STOPWORD_FIELDS = ['title', 'subtitle', 'series']
+ STOPWORD_FIELDS = ["title", "subtitle", "series"]
# SpellChecker is expensive to initialize, so keep around
# a class-level instance.
@@ -1499,9 +1484,7 @@ def build(self, elasticsearch, pagination=None):
# filter -- works must be presentation-ready, etc.
universal_base_filter = Filter.universal_base_filter()
if universal_base_filter:
- query_filter = Filter._chain_filters(
- base_filter, universal_base_filter
- )
+ query_filter = Filter._chain_filters(base_filter, universal_base_filter)
else:
query_filter = base_filter
if query_filter:
@@ -1527,7 +1510,7 @@ def build(self, elasticsearch, pagination=None):
# filter context rather than query context.
subquery = Bool(filter=subfilter)
search = search.filter(
- name_or_query='nested', path=path, query=subquery
+ name_or_query="nested", path=path, query=subquery
)
if self.filter:
@@ -1633,8 +1616,11 @@ def elasticsearch_query(self):
# better results by filtering out junk.
boost = self.SLIGHTLY_ABOVE_BASELINE
self._hypothesize(
- hypotheses, sub_hypotheses, boost, all_must_match=True,
- filters=filters
+ hypotheses,
+ sub_hypotheses,
+ boost,
+ all_must_match=True,
+ filters=filters,
)
# That's it!
@@ -1666,23 +1652,19 @@ def match_one_field_hypotheses(self, base_field, query_string=None):
query_string = query_string or self.query_string
- keyword_match_coefficient = (
- self.KEYWORD_MATCH_COEFFICIENT_FOR_FIELD.get(
- base_field,
- self.DEFAULT_KEYWORD_MATCH_COEFFICIENT
- )
+ keyword_match_coefficient = self.KEYWORD_MATCH_COEFFICIENT_FOR_FIELD.get(
+ base_field, self.DEFAULT_KEYWORD_MATCH_COEFFICIENT
)
fields = [
# A keyword match means the field value is a near-exact
# match for the query string. This is one of the best
# search results we can possibly return.
- ('keyword', keyword_match_coefficient, Term),
-
+ ("keyword", keyword_match_coefficient, Term),
# This is the baseline query -- a phrase match against a
# single field. Most queries turn out to represent
# consecutive words from a single field.
- ('minimal', self.BASELINE_COEFFICIENT, MatchPhrase)
+ ("minimal", self.BASELINE_COEFFICIENT, MatchPhrase),
]
if self.contains_stopwords and base_field in self.STOPWORD_FIELDS:
@@ -1691,9 +1673,7 @@ def match_one_field_hypotheses(self, base_field, query_string=None):
#
# Boost this slightly above the baseline so that if
# it matches, it'll beat out baseline queries.
- fields.append(
- ('with_stopwords', self.SLIGHTLY_ABOVE_BASELINE, MatchPhrase)
- )
+ fields.append(("with_stopwords", self.SLIGHTLY_ABOVE_BASELINE, MatchPhrase))
if base_field in self.STEMMABLE_FIELDS:
# This query might benefit from a non-phrase Match against
@@ -1707,7 +1687,7 @@ def match_one_field_hypotheses(self, base_field, query_string=None):
for subfield, match_type_coefficient, query_class in fields:
if subfield:
- field_name = base_field + '.' + subfield
+ field_name = base_field + "." + subfield
else:
field_name = base_field
@@ -1739,7 +1719,7 @@ def match_one_field_hypotheses(self, base_field, query_string=None):
qu = query_class(**kwargs)
yield qu, field_weight
- if self.fuzzy_coefficient and subfield == 'minimal':
+ if self.fuzzy_coefficient and subfield == "minimal":
# Trying one or more fuzzy versions of this hypothesis
# would also be appropriate. We only do fuzzy searches
# on the subfield with minimal stemming, because we
@@ -1764,9 +1744,7 @@ def match_author_hypotheses(self):
# Ask Elasticsearch to match what was typed against
# contributors.display_name.
- for x in self._author_field_must_match(
- 'display_name', self.query_string
- ):
+ for x in self._author_field_must_match("display_name", self.query_string):
yield x
# Although almost nobody types a sort name into a search box,
@@ -1776,7 +1754,7 @@ def match_author_hypotheses(self):
# that against contributors.sort_name.
sort_name = display_name_to_sort_name(self.query_string)
if sort_name:
- for x in self._author_field_must_match('sort_name', sort_name):
+ for x in self._author_field_must_match("sort_name", sort_name):
yield x
def _author_field_must_match(self, base_field, query_string=None):
@@ -1791,7 +1769,7 @@ def _author_field_must_match(self, base_field, query_string=None):
:param must_match: The query string to match against.
"""
query_string = query_string or self.query_string
- field_name = 'contributors.%s' % base_field
+ field_name = "contributors.%s" % base_field
for author_matches, weight in self.match_one_field_hypotheses(
field_name, query_string
):
@@ -1817,7 +1795,7 @@ def _role_must_also_match(cls, base_query):
"""
match_role = Terms(**{"contributors.role": cls.SEARCH_RELEVANT_ROLES})
match_both = Bool(must=[base_query, match_role])
- return cls._nest('contributors', match_both)
+ return cls._nest("contributors", match_both)
@property
def match_topic_hypotheses(self):
@@ -1836,7 +1814,7 @@ def match_topic_hypotheses(self):
fields=["summary", "classifications.term"],
type="best_fields",
)
- yield qu, self.WEIGHT_FOR_FIELD['summary']
+ yield qu, self.WEIGHT_FOR_FIELD["summary"]
def title_multi_match_for(self, other_field):
"""Helper method to create a MultiMatch hypothesis that crosses
@@ -1855,20 +1833,19 @@ def title_multi_match_for(self, other_field):
return
# We only search the '.minimal' variants of these fields.
- field_names = ['title.minimal', other_field + ".minimal"]
+ field_names = ["title.minimal", other_field + ".minimal"]
# The weight of this hypothesis should be somewhere between
# the weight of a pure title match, and the weight of a pure
# match against the field we're checking.
- title_weight = self.WEIGHT_FOR_FIELD['title']
+ title_weight = self.WEIGHT_FOR_FIELD["title"]
other_weight = self.WEIGHT_FOR_FIELD[other_field]
- combined_weight = other_weight * (other_weight/title_weight)
+ combined_weight = other_weight * (other_weight / title_weight)
hypothesis = MultiMatch(
query=self.query_string,
- fields = field_names,
+ fields=field_names,
type="cross_fields",
-
# This hypothesis must be able to explain the entire query
# string. Otherwise the weight contributed by the title
# will boost _partial_ title matches over better matches
@@ -1904,15 +1881,15 @@ def _fuzzy_matches(self, field_name, **kwargs):
# max_expansions limits the number of possible alternates
# Elasticsearch will consider for any given word.
kwargs.update(fuzziness="AUTO", max_expansions=2)
- yield Match(**{field_name : kwargs}), self.fuzzy_coefficient * 0.50
+ yield Match(**{field_name: kwargs}), self.fuzzy_coefficient * 0.50
# Assuming that no typoes were made in the first
# character of a word (usually a safe assumption) we
# can bump the score up to 75% of the non-fuzzy
# hypothesis.
kwargs = dict(kwargs)
- kwargs['prefix_length'] = 1
- yield Match(**{field_name : kwargs}), self.fuzzy_coefficient * 0.75
+ kwargs["prefix_length"] = 1
+ yield Match(**{field_name: kwargs}), self.fuzzy_coefficient * 0.75
@classmethod
def _hypothesize(cls, hypotheses, query, boost, filters=None, **kwargs):
@@ -1979,17 +1956,17 @@ def __init__(self, query_string, query_class=Query):
genre, genre_match = KeywordBasedClassifier.genre_match(query_string)
if genre:
query_string = self.add_match_term_filter(
- genre.name, 'genres.name', query_string, genre_match
+ genre.name, "genres.name", query_string, genre_match
)
# Handle the 'young adult' part of 'young adult romance'
- audience, audience_match = KeywordBasedClassifier.audience_match(
- query_string
- )
+ audience, audience_match = KeywordBasedClassifier.audience_match(query_string)
if audience:
query_string = self.add_match_term_filter(
- audience.replace(" ", "").lower(), 'audience', query_string,
- audience_match
+ audience.replace(" ", "").lower(),
+ "audience",
+ query_string,
+ audience_match,
)
# Handle the 'nonfiction' part of 'asteroids nonfiction'
@@ -1999,7 +1976,7 @@ def __init__(self, query_string, query_class=Query):
elif re.compile(r"\bfiction\b", re.IGNORECASE).search(query_string):
fiction = "fiction"
query_string = self.add_match_term_filter(
- fiction, 'fiction', query_string, fiction
+ fiction, "fiction", query_string, fiction
)
# Handle the 'grade 5' part of 'grade 5 dogs'
age_from_grade, grade_match = GradeLevelClassifier.target_age_match(
@@ -2036,8 +2013,10 @@ def __init__(self, query_string, query_class=Query):
# different hypotheses, fuzzy matches, etc. So the simplest thing
# to do is to create a Query object for the smaller search query
# and see what its .elasticsearch_query is.
- if (self.final_query_string
- and self.final_query_string != self.original_query_string):
+ if (
+ self.final_query_string
+ and self.final_query_string != self.original_query_string
+ ):
recursive = self.query_class(
self.final_query_string, use_query_parser=False
).elasticsearch_query
@@ -2093,9 +2072,9 @@ def _without_match(cls, query_string, match):
# dash.
word_boundary_pattern = r"\b%s[\w'\-]*\b"
- return re.compile(
- word_boundary_pattern % match.strip(), re.IGNORECASE
- ).sub("", query_string)
+ return re.compile(word_boundary_pattern % match.strip(), re.IGNORECASE).sub(
+ "", query_string
+ )
class Filter(SearchBase):
@@ -2116,13 +2095,15 @@ class Filter(SearchBase):
# When search results include known script fields, we need to
# wrap the works we would be returning in WorkSearchResults so
# the useful information from the search engine isn't lost.
- KNOWN_SCRIPT_FIELDS = ['last_update']
+ KNOWN_SCRIPT_FIELDS = ["last_update"]
# In general, someone looking for things "by this person" is
# probably looking for one of these roles.
AUTHOR_MATCH_ROLES = list(Contributor.AUTHOR_ROLES) + [
- Contributor.NARRATOR_ROLE, Contributor.EDITOR_ROLE,
- Contributor.DIRECTOR_ROLE, Contributor.ACTOR_ROLE
+ Contributor.NARRATOR_ROLE,
+ Contributor.EDITOR_ROLE,
+ Contributor.DIRECTOR_ROLE,
+ Contributor.ACTOR_ROLE,
]
@classmethod
@@ -2137,46 +2118,58 @@ def from_worklist(cls, _db, worklist, facets):
# For most configuration settings there is a single value --
# either defined on the WorkList or defined by its parent.
inherit_one = worklist.inherited_value
- media = inherit_one('media')
- languages = inherit_one('languages')
- fiction = inherit_one('fiction')
- audiences = inherit_one('audiences')
- target_age = inherit_one('target_age')
- collections = inherit_one('collection_ids') or library
+ media = inherit_one("media")
+ languages = inherit_one("languages")
+ fiction = inherit_one("fiction")
+ audiences = inherit_one("audiences")
+ target_age = inherit_one("target_age")
+ collections = inherit_one("collection_ids") or library
- license_datasource_id = inherit_one('license_datasource_id')
+ license_datasource_id = inherit_one("license_datasource_id")
# For genre IDs and CustomList IDs, we might get a separate
# set of restrictions from every item in the WorkList hierarchy.
# _All_ restrictions must be met for a work to match the filter.
inherit_some = worklist.inherited_values
- genre_id_restrictions = inherit_some('genre_ids')
- customlist_id_restrictions = inherit_some('customlist_ids')
+ genre_id_restrictions = inherit_some("genre_ids")
+ customlist_id_restrictions = inherit_some("customlist_ids")
# See if there are any excluded audiobook sources on this
# site.
- excluded = (
- ConfigurationSetting.excluded_audio_data_sources(_db)
- )
- excluded_audiobook_data_sources = [
- DataSource.lookup(_db, x) for x in excluded
- ]
+ excluded = ConfigurationSetting.excluded_audio_data_sources(_db)
+ excluded_audiobook_data_sources = [DataSource.lookup(_db, x) for x in excluded]
if library is None:
allow_holds = True
else:
allow_holds = library.allow_holds
return cls(
- collections, media, languages, fiction, audiences,
- target_age, genre_id_restrictions, customlist_id_restrictions,
+ collections,
+ media,
+ languages,
+ fiction,
+ audiences,
+ target_age,
+ genre_id_restrictions,
+ customlist_id_restrictions,
facets,
excluded_audiobook_data_sources=excluded_audiobook_data_sources,
- allow_holds=allow_holds, license_datasource=license_datasource_id
+ allow_holds=allow_holds,
+ license_datasource=license_datasource_id,
)
- def __init__(self, collections=None, media=None, languages=None,
- fiction=None, audiences=None, target_age=None,
- genre_restriction_sets=None, customlist_restriction_sets=None,
- facets=None, script_fields=None, **kwargs
+ def __init__(
+ self,
+ collections=None,
+ media=None,
+ languages=None,
+ fiction=None,
+ audiences=None,
+ target_age=None,
+ genre_restriction_sets=None,
+ customlist_restriction_sets=None,
+ facets=None,
+ script_fields=None,
+ **kwargs
):
"""Constructor.
@@ -2291,27 +2284,27 @@ def __init__(self, collections=None, media=None, languages=None,
# Pull less-important values out of the keyword arguments.
excluded_audiobook_data_sources = kwargs.pop(
- 'excluded_audiobook_data_sources', []
+ "excluded_audiobook_data_sources", []
)
self.excluded_audiobook_data_sources = self._filter_ids(
excluded_audiobook_data_sources
)
- self.allow_holds = kwargs.pop('allow_holds', True)
+ self.allow_holds = kwargs.pop("allow_holds", True)
- self.updated_after = kwargs.pop('updated_after', None)
+ self.updated_after = kwargs.pop("updated_after", None)
- self.series = kwargs.pop('series', None)
+ self.series = kwargs.pop("series", None)
- self.author = kwargs.pop('author', None)
+ self.author = kwargs.pop("author", None)
- self.min_score = kwargs.pop('min_score', None)
+ self.min_score = kwargs.pop("min_score", None)
- self.match_nothing = kwargs.pop('match_nothing', False)
+ self.match_nothing = kwargs.pop("match_nothing", False)
- license_datasources = kwargs.pop('license_datasource', None)
+ license_datasources = kwargs.pop("license_datasource", None)
self.license_datasources = self._filter_ids(license_datasources)
- identifiers = kwargs.pop('identifiers', [])
+ identifiers = kwargs.pop("identifiers", [])
self.identifiers = list(self._scrub_identifiers(identifiers))
# At this point there should be no keyword arguments -- you can't pass
@@ -2363,8 +2356,10 @@ def audiences(self):
# If YOUNG_ADULT or ADULT is an audience, then ALL_AGES is
# always going to be an additional audience.
- if any(x in as_is for x in [Classifier.AUDIENCE_YOUNG_ADULT,
- Classifier.AUDIENCE_ADULT]):
+ if any(
+ x in as_is
+ for x in [Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_ADULT]
+ ):
return with_all_ages
# At this point, if CHILDREN is _not_ included, we know that
@@ -2375,8 +2370,11 @@ def audiences(self):
# Now we know that CHILDREN is an audience. It's going to come
# down to the upper bound on the target age.
- if (self.target_age and self.target_age[1] is not None
- and self.target_age[1] < Classifier.ALL_AGES_AGE_CUTOFF):
+ if (
+ self.target_age
+ and self.target_age[1] is not None
+ and self.target_age[1] < Classifier.ALL_AGES_AGE_CUTOFF
+ ):
# The audience for this query does not include any kids
# who are expected to have the reading fluency necessary
# for ALL_AGES books.
@@ -2414,20 +2412,18 @@ def build(self, _chain_filters=None):
collection_ids = filter_ids(self.collection_ids)
if collection_ids:
- collection_match = Terms(
- **{'licensepools.collection_id' : collection_ids}
- )
- nested_filters['licensepools'].append(collection_match)
+ collection_match = Terms(**{"licensepools.collection_id": collection_ids})
+ nested_filters["licensepools"].append(collection_match)
license_datasources = filter_ids(self.license_datasources)
if license_datasources:
datasource_match = Terms(
- **{'licensepools.data_source_id' : license_datasources}
+ **{"licensepools.data_source_id": license_datasources}
)
- nested_filters['licensepools'].append(datasource_match)
+ nested_filters["licensepools"].append(datasource_match)
if self.author is not None:
- nested_filters['contributors'].append(self.author_filter)
+ nested_filters["contributors"].append(self.author_filter)
if self.media:
f = chain(f, Terms(medium=scrub_list(self.media)))
@@ -2437,9 +2433,9 @@ def build(self, _chain_filters=None):
if self.fiction is not None:
if self.fiction:
- value = 'fiction'
+ value = "fiction"
else:
- value = 'nonfiction'
+ value = "nonfiction"
f = chain(f, Term(fiction=value))
if self.series:
@@ -2465,41 +2461,39 @@ def build(self, _chain_filters=None):
for genre_ids in self.genre_restriction_sets:
ids = filter_ids(genre_ids)
- nested_filters['genres'].append(
- Terms(**{'genres.term' : filter_ids(genre_ids)})
+ nested_filters["genres"].append(
+ Terms(**{"genres.term": filter_ids(genre_ids)})
)
for customlist_ids in self.customlist_restriction_sets:
ids = filter_ids(customlist_ids)
- nested_filters['customlists'].append(
- Terms(**{'customlists.list_id' : ids})
- )
+ nested_filters["customlists"].append(Terms(**{"customlists.list_id": ids}))
- open_access = Term(**{'licensepools.open_access' : True})
- if self.availability==FacetConstants.AVAILABLE_NOW:
+ open_access = Term(**{"licensepools.open_access": True})
+ if self.availability == FacetConstants.AVAILABLE_NOW:
# Only open-access books and books with currently available
# copies should be displayed.
- available = Term(**{'licensepools.available' : True})
- nested_filters['licensepools'].append(
+ available = Term(**{"licensepools.available": True})
+ nested_filters["licensepools"].append(
Bool(should=[open_access, available], minimum_should_match=1)
)
- elif self.availability==FacetConstants.AVAILABLE_OPEN_ACCESS:
+ elif self.availability == FacetConstants.AVAILABLE_OPEN_ACCESS:
# Only open-access books should be displayed.
- nested_filters['licensepools'].append(open_access)
- elif self.availability==FacetConstants.AVAILABLE_NOT_NOW:
+ nested_filters["licensepools"].append(open_access)
+ elif self.availability == FacetConstants.AVAILABLE_NOT_NOW:
# Only books that are _not_ currently available should be displayed.
- not_open_access = Term(**{'licensepools.open_access' : False})
- licensed = Term(**{'licensepools.licensed' : True})
- not_available = Term(**{'licensepools.available' : False})
- nested_filters['licensepools'].append(
+ not_open_access = Term(**{"licensepools.open_access": False})
+ licensed = Term(**{"licensepools.licensed": True})
+ not_available = Term(**{"licensepools.available": False})
+ nested_filters["licensepools"].append(
Bool(must=[not_open_access, licensed, not_available])
)
- if self.subcollection==FacetConstants.COLLECTION_FEATURED:
+ if self.subcollection == FacetConstants.COLLECTION_FEATURED:
# Exclude books with a quality of less than the library's
# minimum featured quality.
range_query = self._match_range(
- 'quality', 'gte', self.minimum_featured_quality
+ "quality", "gte", self.minimum_featured_quality
)
f = chain(f, Bool(must=range_query))
@@ -2511,38 +2505,34 @@ def build(self, _chain_filters=None):
# Both identifier and type must match for the match
# to count.
for name, value in (
- ('identifier', identifier.identifier),
- ('type', identifier.type),
+ ("identifier", identifier.identifier),
+ ("type", identifier.type),
):
- subclauses.append(
- Term(**{'identifiers.%s' % name : value})
- )
+ subclauses.append(Term(**{"identifiers.%s" % name: value}))
clauses.append(Bool(must=subclauses))
# At least one the identifiers must match for the work to
# match.
identifier_f = Bool(should=clauses, minimum_should_match=1)
- nested_filters['identifiers'].append(identifier_f)
+ nested_filters["identifiers"].append(identifier_f)
# Some sources of audiobooks may be excluded because the
# server can't fulfill them or the anticipated client can't
# play them.
excluded = self.excluded_audiobook_data_sources
if excluded:
- audio = Term(**{'licensepools.medium': Edition.AUDIO_MEDIUM})
- excluded_audio_source = Terms(
- **{'licensepools.data_source_id' : excluded}
- )
+ audio = Term(**{"licensepools.medium": Edition.AUDIO_MEDIUM})
+ excluded_audio_source = Terms(**{"licensepools.data_source_id": excluded})
excluded_audio = Bool(must=[audio, excluded_audio_source])
not_excluded_audio = Bool(must_not=excluded_audio)
- nested_filters['licensepools'].append(not_excluded_audio)
+ nested_filters["licensepools"].append(not_excluded_audio)
# If holds are not allowed, only license pools that are
# currently available should be considered.
if not self.allow_holds:
- licenses_available = Term(**{'licensepools.available' : True})
+ licenses_available = Term(**{"licensepools.available": True})
currently_available = Bool(should=[licenses_available, open_access])
- nested_filters['licensepools'].append(currently_available)
+ nested_filters["licensepools"].append(currently_available)
# Perhaps only books whose bibliographic metadata was updated
# recently should be included.
@@ -2551,11 +2541,9 @@ def build(self, _chain_filters=None):
# .last_update is probably a datetime. Convert it here.
updated_after = self.updated_after
if isinstance(updated_after, datetime.datetime):
- updated_after = (
- updated_after - from_timestamp(0)
- ).total_seconds()
+ updated_after = (updated_after - from_timestamp(0)).total_seconds()
last_update_time_query = self._match_range(
- 'last_update_time', 'gte', updated_after
+ "last_update_time", "gte", updated_after
)
f = chain(f, Bool(must=last_update_time_query))
@@ -2578,9 +2566,7 @@ def universal_base_filter(cls, _chain_filters=None):
base_filter = None
# We only want to show works that are presentation-ready.
- base_filter = _chain_filters(
- base_filter, Term(**{"presentation_ready":True})
- )
+ base_filter = _chain_filters(base_filter, Term(**{"presentation_ready": True}))
return base_filter
@@ -2604,13 +2590,13 @@ def universal_nested_filters(cls):
# It's easier to stay consistent by indexing all Works and
# filtering them out later, than to do it by adding and
# removing works from the index.
- not_suppressed = Term(**{'licensepools.suppressed' : False})
- nested_filters['licensepools'].append(not_suppressed)
+ not_suppressed = Term(**{"licensepools.suppressed": False})
+ nested_filters["licensepools"].append(not_suppressed)
- owns_licenses = Term(**{'licensepools.licensed' : True})
- open_access = Term(**{'licensepools.open_access' : True})
+ owns_licenses = Term(**{"licensepools.licensed": True})
+ open_access = Term(**{"licensepools.open_access": True})
currently_owned = Bool(should=[owns_licenses, open_access])
- nested_filters['licensepools'].append(currently_owned)
+ nested_filters["licensepools"].append(currently_owned)
return nested_filters
@@ -2634,14 +2620,12 @@ def sort_order(self):
# as long as possible. For example, a feed sorted by author
# will be secondarily sorted by title and work ID, not just by
# work ID.
- default_sort_order = ['sort_author', 'sort_title', 'work_id']
+ default_sort_order = ["sort_author", "sort_title", "work_id"]
order_field_keys = self.order
if not isinstance(order_field_keys, list):
order_field_keys = [order_field_keys]
- order_fields = [
- self._make_order_field(key) for key in order_field_keys
- ]
+ order_fields = [self._make_order_field(key) for key in order_field_keys]
# Apply any parts of the default sort order not yet covered,
# concluding (in most cases) with work_id, the tiebreaker field.
@@ -2659,7 +2643,7 @@ def asc(self):
return "asc"
def _make_order_field(self, key):
- if key == 'last_update_time':
+ if key == "last_update_time":
# Sorting by last_update_time may be very simple or very
# complex, depending on whether or not the filter
# involves collection or list membership.
@@ -2670,22 +2654,20 @@ def _make_order_field(self, key):
# The simple case, handled below.
pass
- if '.' not in key:
+ if "." not in key:
# A simple case.
- return { key : self.asc }
+ return {key: self.asc}
# At this point we're sorting by a nested field.
nested = None
- if key == 'licensepools.availability_time':
+ if key == "licensepools.availability_time":
nested, mode = self._availability_time_sort_order
else:
- raise ValueError(
- "I don't know how to sort by %s." % key
- )
+ raise ValueError("I don't know how to sort by %s." % key)
sort_description = dict(order=self.asc, mode=mode)
if nested:
- sort_description['nested']=nested
- return { key : sort_description }
+ sort_description["nested"] = nested
+ return {key: sort_description}
@property
def _availability_time_sort_order(self):
@@ -2698,15 +2680,11 @@ def _availability_time_sort_order(self):
if collection_ids:
nested = dict(
path="licensepools",
- filter=dict(
- terms={
- "licensepools.collection_id": collection_ids
- }
- ),
+ filter=dict(terms={"licensepools.collection_id": collection_ids}),
)
# If a book shows up in multiple collections, we're only
# interested in the collection that had it the earliest.
- mode = 'min'
+ mode = "min"
return nested, mode
@property
@@ -2731,18 +2709,12 @@ def last_update_time_script_field(self):
all_list_ids.update(self._filter_ids(restriction))
nested = dict(
path="customlists",
- filter=dict(
- terms={"customlists.list_id": list(all_list_ids)}
- )
- )
- params = dict(
- collection_ids=collection_ids,
- list_ids=list(all_list_ids)
+ filter=dict(terms={"customlists.list_id": list(all_list_ids)}),
)
+ params = dict(collection_ids=collection_ids, list_ids=list(all_list_ids))
return dict(
script=dict(
- stored=CurrentMapping.script_name("work_last_update"),
- params=params
+ stored=CurrentMapping.script_name("work_last_update"), params=params
)
)
@@ -2756,12 +2728,12 @@ def _last_update_time_order_by(self):
time as the script to use for a sort value.
"""
field = self.last_update_time_script_field
- if not 'last_update' in self.script_fields:
- self.script_fields['last_update'] = field
+ if not "last_update" in self.script_fields:
+ self.script_fields["last_update"] = field
return dict(
_script=dict(
type="number",
- script=field['script'],
+ script=field["script"],
order=self.asc,
),
)
@@ -2781,7 +2753,9 @@ def _last_update_time_order_by(self):
# Below that point, we prefer higher-quality works to
# lower-quality works, such that a work's score is proportional to
# the square of its quality.
- FEATURABLE_SCRIPT = "Math.pow(Math.min(%(cutoff).5f, doc['quality'].value), %(exponent).5f) * 5"
+ FEATURABLE_SCRIPT = (
+ "Math.pow(Math.min(%(cutoff).5f, doc['quality'].value), %(exponent).5f) * 5"
+ )
# Used in tests to deactivate the random component of
# featurability_scoring_functions.
@@ -2793,15 +2767,13 @@ def featurability_scoring_functions(self, random_seed):
"""
exponent = 2
- cutoff = (self.minimum_featured_quality ** exponent)
- script = self.FEATURABLE_SCRIPT % dict(
- cutoff=cutoff, exponent=exponent
- )
- quality_field = SF('script_score', script=dict(source=script))
+ cutoff = self.minimum_featured_quality ** exponent
+ script = self.FEATURABLE_SCRIPT % dict(cutoff=cutoff, exponent=exponent)
+ quality_field = SF("script_score", script=dict(source=script))
# Currently available works are more featurable.
- available = Term(**{'licensepools.available' : True})
- nested = Nested(path='licensepools', query=available)
+ available = Term(**{"licensepools.available": True})
+ nested = Nested(path="licensepools", query=available)
available_now = dict(filter=nested, weight=5)
function_scores = [quality_field, available_now]
@@ -2811,10 +2783,10 @@ def featurability_scoring_functions(self, random_seed):
# books every time.
if random_seed != self.DETERMINISTIC:
random = SF(
- 'random_score',
+ "random_score",
seed=random_seed or int(time.time()),
field="work_id",
- weight=1.1
+ weight=1.1,
)
function_scores.append(random)
@@ -2825,10 +2797,10 @@ def featurability_scoring_functions(self, random_seed):
# We're looking for works on certain custom lists. A work
# that's _featured_ on one of these lists will be boosted
# quite a lot versus one that's not.
- featured = Term(**{'customlists.featured' : True})
- on_list = Terms(**{'customlists.list_id' : list(list_ids)})
+ featured = Term(**{"customlists.featured": True})
+ on_list = Terms(**{"customlists.list_id": list(list_ids)})
featured_on_list = Bool(must=[featured, on_list])
- nested = Nested(path='customlists', query=featured_on_list)
+ nested = Nested(path="customlists", query=featured_on_list)
featured_on_relevant_list = dict(filter=nested, weight=11)
function_scores.append(featured_on_relevant_list)
return function_scores
@@ -2846,6 +2818,7 @@ def target_age_filter(self):
lower, upper = self.target_age
if lower is None and upper is None:
return None
+
def does_not_exist(field):
"""A filter that matches if there is no value for `field`."""
return Bool(must_not=[Exists(field=field)])
@@ -2854,8 +2827,7 @@ def or_does_not_exist(clause, field):
"""Either the given `clause` matches or the given field
does not exist.
"""
- return Bool(should=[clause, does_not_exist(field)],
- minimum_should_match=1)
+ return Bool(should=[clause, does_not_exist(field)], minimum_should_match=1)
clauses = []
@@ -2889,26 +2861,21 @@ def author_filter(self):
"""
if not self.author:
return None
- authorship_role = Terms(
- **{'contributors.role' : self.AUTHOR_MATCH_ROLES}
- )
+ authorship_role = Terms(**{"contributors.role": self.AUTHOR_MATCH_ROLES})
clauses = []
for field, value in [
- ('sort_name.keyword', self.author.sort_name),
- ('display_name.keyword', self.author.display_name),
- ('viaf', self.author.viaf),
- ('lc', self.author.lc)
+ ("sort_name.keyword", self.author.sort_name),
+ ("display_name.keyword", self.author.display_name),
+ ("viaf", self.author.viaf),
+ ("lc", self.author.lc),
]:
if not value or value == Edition.UNKNOWN_AUTHOR:
continue
- clauses.append(
- Term(**{'contributors.%s' % field : value})
- )
+ clauses.append(Term(**{"contributors.%s" % field: value}))
same_person = Bool(should=clauses, minimum_should_match=1)
return Bool(must=[authorship_role, same_person])
-
@classmethod
def _scrub(cls, s):
"""Modify a string for use in a filter match.
@@ -2986,8 +2953,7 @@ class SortKeyPagination(Pagination):
list.
"""
- def __init__(self, last_item_on_previous_page=None,
- size=Pagination.DEFAULT_SIZE):
+ def __init__(self, last_item_on_previous_page=None, size=Pagination.DEFAULT_SIZE):
self.size = size
self.last_item_on_previous_page = last_item_on_previous_page
@@ -3003,7 +2969,7 @@ def from_request(cls, get_arg, default_size=None):
size = cls.size_from_request(get_arg, default_size)
if isinstance(size, ProblemDetail):
return size
- pagination_key = get_arg('key', None)
+ pagination_key = get_arg("key", None)
if pagination_key:
try:
pagination_key = json.loads(pagination_key)
@@ -3019,8 +2985,8 @@ def items(self):
"""
pagination_key = self.pagination_key
if pagination_key:
- yield("key", self.pagination_key)
- yield("size", self.size)
+ yield ("key", self.pagination_key)
+ yield ("size", self.size)
@property
def pagination_key(self):
@@ -3119,6 +3085,7 @@ class WorkSearchResult(object):
obtained through Elasticsearch, such as its 'last modified' date
the context of a specific lane.
"""
+
def __init__(self, work, hit):
self._work = work
self._hit = hit
@@ -3129,7 +3096,7 @@ def __getattr__(self, k):
class MockExternalSearchIndex(ExternalSearchIndex):
- work_document_type = 'work-type'
+ work_document_type = "work-type"
def __init__(self, url=None):
self.url = url
@@ -3156,7 +3123,9 @@ def delete(self, index, doc_type, id):
def exists(self, index, doc_type, id):
return self._key(index, doc_type, id) in self.docs
- def create_search_doc(self, query_string, filter=None, pagination=None, debug=False):
+ def create_search_doc(
+ self, query_string, filter=None, pagination=None, debug=False
+ ):
return list(self.docs.values())
def query_works(self, query_string, filter, pagination, debug=False):
@@ -3170,7 +3139,8 @@ def sort_key(x):
if isinstance(x, MockSearchResult):
return x.work_id
else:
- return x['_id']
+ return x["_id"]
+
docs = sorted(list(self.docs.values()), key=sort_key)
if pagination:
start_at = 0
@@ -3180,7 +3150,7 @@ def sort_key(x):
if pagination.last_item_on_previous_page:
look_for = pagination.last_item_on_previous_page[-1]
for i, x in enumerate(docs):
- if x['_id'] == look_for:
+ if x["_id"] == look_for:
start_at = i + 1
break
else:
@@ -3193,9 +3163,7 @@ def sort_key(x):
if isinstance(x, MockSearchResult):
results.append(x)
else:
- results.append(
- MockSearchResult(x["title"], x["author"], {}, x['_id'])
- )
+ results.append(MockSearchResult(x["title"], x["author"], {}, x["_id"]))
if pagination:
pagination.page_loaded(results)
@@ -3214,20 +3182,22 @@ def count_works(self, filter):
def bulk(self, docs, **kwargs):
for doc in docs:
- self.index(doc['_index'], doc['_type'], doc['_id'], doc)
+ self.index(doc["_index"], doc["_type"], doc["_id"], doc)
return len(docs), []
+
class MockMeta(dict):
"""Mock the .meta object associated with an Elasticsearch search
result. This is necessary to get SortKeyPagination to work with
MockExternalSearchIndex.
"""
+
@property
def sort(self):
- return self['_sort']
+ return self["_sort"]
-class MockSearchResult(object):
+class MockSearchResult(object):
def __init__(self, sort_title, sort_author, meta, id):
self.sort_title = sort_title
self.sort_author = sort_author
@@ -3253,18 +3223,16 @@ class SearchIndexCoverageProvider(WorkPresentationProvider):
search index.
"""
- SERVICE_NAME = 'Search index coverage provider'
+ SERVICE_NAME = "Search index coverage provider"
DEFAULT_BATCH_SIZE = 500
OPERATION = WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
def __init__(self, *args, **kwargs):
- search_index_client = kwargs.pop('search_index_client', None)
+ search_index_client = kwargs.pop("search_index_client", None)
super(SearchIndexCoverageProvider, self).__init__(*args, **kwargs)
- self.search_index_client = (
- search_index_client or ExternalSearchIndex(self._db)
- )
+ self.search_index_client = search_index_client or ExternalSearchIndex(self._db)
def process_batch(self, works):
"""
diff --git a/facets.py b/facets.py
index 343d768f0..bda9dcb31 100644
--- a/facets.py
+++ b/facets.py
@@ -1,18 +1,19 @@
from flask_babel import lazy_gettext as _
+
class FacetConstants(object):
# A special constant, basically an additional rel, indicating that
# an OPDS facet group represents different entry points into a
# WorkList.
- ENTRY_POINT_REL = 'http://librarysimplified.org/terms/rel/entrypoint'
- ENTRY_POINT_FACET_GROUP_NAME = 'entrypoint'
+ ENTRY_POINT_REL = "http://librarysimplified.org/terms/rel/entrypoint"
+ ENTRY_POINT_FACET_GROUP_NAME = "entrypoint"
# Query arguments can change how long a feed is to be cached.
- MAX_CACHE_AGE_NAME = 'max_age'
+ MAX_CACHE_AGE_NAME = "max_age"
# Subset the collection, roughly, by quality.
- COLLECTION_FACET_GROUP_NAME = 'collection'
+ COLLECTION_FACET_GROUP_NAME = "collection"
COLLECTION_FULL = "full"
COLLECTION_FEATURED = "featured"
COLLECTION_FACETS = [
@@ -21,12 +22,12 @@ class FacetConstants(object):
]
# Subset the collection by availability.
- AVAILABILITY_FACET_GROUP_NAME = 'available'
+ AVAILABILITY_FACET_GROUP_NAME = "available"
AVAILABLE_NOW = "now"
AVAILABLE_ALL = "all"
AVAILABLE_OPEN_ACCESS = "always"
- AVAILABLE_NOT_NOW = "not_now" # Used only in QA jackpot feeds -- real patrons don't
- # want to see this.
+ AVAILABLE_NOT_NOW = "not_now" # Used only in QA jackpot feeds -- real patrons don't
+ # want to see this.
AVAILABILITY_FACETS = [
AVAILABLE_NOW,
AVAILABLE_ALL,
@@ -34,14 +35,14 @@ class FacetConstants(object):
]
# The names of the order facets.
- ORDER_FACET_GROUP_NAME = 'order'
- ORDER_TITLE = 'title'
- ORDER_AUTHOR = 'author'
- ORDER_LAST_UPDATE = 'last_update'
- ORDER_ADDED_TO_COLLECTION = 'added'
- ORDER_SERIES_POSITION = 'series'
- ORDER_WORK_ID = 'work_id'
- ORDER_RANDOM = 'random'
+ ORDER_FACET_GROUP_NAME = "order"
+ ORDER_TITLE = "title"
+ ORDER_AUTHOR = "author"
+ ORDER_LAST_UPDATE = "last_update"
+ ORDER_ADDED_TO_COLLECTION = "added"
+ ORDER_SERIES_POSITION = "series"
+ ORDER_WORK_ID = "work_id"
+ ORDER_RANDOM = "random"
# Some order facets, like series and work id,
# only make sense in certain contexts.
# These are the options that can be enabled
@@ -58,9 +59,7 @@ class FacetConstants(object):
# Most facets should be ordered in ascending order by default (A>-Z), but
# these dates should be ordered descending by default (new->old).
- ORDER_DESCENDING_BY_DEFAULT = [
- ORDER_ADDED_TO_COLLECTION, ORDER_LAST_UPDATE
- ]
+ ORDER_DESCENDING_BY_DEFAULT = [ORDER_ADDED_TO_COLLECTION, ORDER_LAST_UPDATE]
FACETS_BY_GROUP = {
COLLECTION_FACET_GROUP_NAME: COLLECTION_FACETS,
@@ -69,64 +68,60 @@ class FacetConstants(object):
}
GROUP_DISPLAY_TITLES = {
- ORDER_FACET_GROUP_NAME : _("Sort by"),
- AVAILABILITY_FACET_GROUP_NAME : _("Availability"),
- COLLECTION_FACET_GROUP_NAME : _('Collection'),
+ ORDER_FACET_GROUP_NAME: _("Sort by"),
+ AVAILABILITY_FACET_GROUP_NAME: _("Availability"),
+ COLLECTION_FACET_GROUP_NAME: _("Collection"),
}
GROUP_DESCRIPTIONS = {
- ORDER_FACET_GROUP_NAME : _("Allow patrons to sort by"),
- AVAILABILITY_FACET_GROUP_NAME : _("Allow patrons to filter availability to"),
- COLLECTION_FACET_GROUP_NAME : _('Allow patrons to filter collection to'),
+ ORDER_FACET_GROUP_NAME: _("Allow patrons to sort by"),
+ AVAILABILITY_FACET_GROUP_NAME: _("Allow patrons to filter availability to"),
+ COLLECTION_FACET_GROUP_NAME: _("Allow patrons to filter collection to"),
}
FACET_DISPLAY_TITLES = {
- ORDER_TITLE : _('Title'),
- ORDER_AUTHOR : _('Author'),
- ORDER_LAST_UPDATE : _('Last Update'),
- ORDER_ADDED_TO_COLLECTION : _('Recently Added'),
- ORDER_SERIES_POSITION: _('Series Position'),
- ORDER_WORK_ID : _('Work ID'),
- ORDER_RANDOM : _('Random'),
-
- AVAILABLE_NOW : _("Available now"),
- AVAILABLE_ALL : _("All"),
- AVAILABLE_OPEN_ACCESS : _("Yours to keep"),
-
- COLLECTION_FULL : _("Everything"),
- COLLECTION_FEATURED : _("Popular Books"),
+ ORDER_TITLE: _("Title"),
+ ORDER_AUTHOR: _("Author"),
+ ORDER_LAST_UPDATE: _("Last Update"),
+ ORDER_ADDED_TO_COLLECTION: _("Recently Added"),
+ ORDER_SERIES_POSITION: _("Series Position"),
+ ORDER_WORK_ID: _("Work ID"),
+ ORDER_RANDOM: _("Random"),
+ AVAILABLE_NOW: _("Available now"),
+ AVAILABLE_ALL: _("All"),
+ AVAILABLE_OPEN_ACCESS: _("Yours to keep"),
+ COLLECTION_FULL: _("Everything"),
+ COLLECTION_FEATURED: _("Popular Books"),
}
# Unless a library offers an alternate configuration, patrons will
# see these facet groups.
DEFAULT_ENABLED_FACETS = {
- ORDER_FACET_GROUP_NAME : [
- ORDER_AUTHOR, ORDER_TITLE, ORDER_ADDED_TO_COLLECTION
+ ORDER_FACET_GROUP_NAME: [ORDER_AUTHOR, ORDER_TITLE, ORDER_ADDED_TO_COLLECTION],
+ AVAILABILITY_FACET_GROUP_NAME: [
+ AVAILABLE_ALL,
+ AVAILABLE_NOW,
+ AVAILABLE_OPEN_ACCESS,
],
- AVAILABILITY_FACET_GROUP_NAME : [
- AVAILABLE_ALL, AVAILABLE_NOW, AVAILABLE_OPEN_ACCESS
- ],
- COLLECTION_FACET_GROUP_NAME : [
- COLLECTION_FULL, COLLECTION_FEATURED
- ]
+ COLLECTION_FACET_GROUP_NAME: [COLLECTION_FULL, COLLECTION_FEATURED],
}
# Unless a library offers an alternate configuration, these
# facets will be the default selection for the facet groups.
DEFAULT_FACET = {
- ORDER_FACET_GROUP_NAME : ORDER_AUTHOR,
- AVAILABILITY_FACET_GROUP_NAME : AVAILABLE_ALL,
- COLLECTION_FACET_GROUP_NAME : COLLECTION_FULL,
+ ORDER_FACET_GROUP_NAME: ORDER_AUTHOR,
+ AVAILABILITY_FACET_GROUP_NAME: AVAILABLE_ALL,
+ COLLECTION_FACET_GROUP_NAME: COLLECTION_FULL,
}
SORT_ORDER_TO_ELASTICSEARCH_FIELD_NAME = {
- ORDER_TITLE : "sort_title",
- ORDER_AUTHOR : "sort_author",
- ORDER_LAST_UPDATE : 'last_update_time',
- ORDER_ADDED_TO_COLLECTION : 'licensepools.availability_time',
- ORDER_SERIES_POSITION : ['series_position', 'sort_title'],
- ORDER_WORK_ID : '_id',
- ORDER_RANDOM : 'random',
+ ORDER_TITLE: "sort_title",
+ ORDER_AUTHOR: "sort_author",
+ ORDER_LAST_UPDATE: "last_update_time",
+ ORDER_ADDED_TO_COLLECTION: "licensepools.availability_time",
+ ORDER_SERIES_POSITION: ["series_position", "sort_title"],
+ ORDER_WORK_ID: "_id",
+ ORDER_RANDOM: "random",
}
@@ -137,6 +132,7 @@ class FacetConfig(FacetConstants):
use a facet configuration different from the site-wide
facets.
"""
+
@classmethod
def from_library(cls, library):
diff --git a/lane.py b/lane.py
index 19494d90b..6a582f5fe 100644
--- a/lane.py
+++ b/lane.py
@@ -1,30 +1,31 @@
# encoding: utf-8
-from collections import defaultdict
import datetime
import logging
import time
+from collections import defaultdict
from urllib.parse import quote_plus
-from psycopg2.extras import NumericRange
-from sqlalchemy.sql import select
-from sqlalchemy.sql.expression import Select
-from sqlalchemy.dialects.postgresql import JSON
+
+import elasticsearch
from flask_babel import lazy_gettext as _
+from psycopg2.extras import NumericRange
from sqlalchemy import (
- and_,
- case,
- or_,
- not_,
+ Boolean,
+ Column,
+ ForeignKey,
Integer,
Table,
Unicode,
+ UniqueConstraint,
+ and_,
+ case,
+ event,
+ not_,
+ or_,
text,
)
-from sqlalchemy.ext.associationproxy import (
- association_proxy,
-)
-from sqlalchemy.ext.hybrid import (
- hybrid_property,
-)
+from sqlalchemy.dialects.postgresql import ARRAY, INT4RANGE, JSON
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
aliased,
backref,
@@ -34,37 +35,15 @@
lazyload,
relationship,
)
-from sqlalchemy.sql.expression import literal
-import elasticsearch
-from sqlalchemy import (
- event,
- Boolean,
- Column,
- ForeignKey,
- Integer,
- UniqueConstraint,
-)
-from sqlalchemy.dialects.postgresql import (
- ARRAY,
- INT4RANGE,
-)
+from sqlalchemy.sql import select
+from sqlalchemy.sql.expression import Select, literal
from . import classifier
+from .classifier import Classifier, GenreData
from .config import Configuration
-from .classifier import (
- Classifier,
- GenreData,
-)
-from .entrypoint import (
- EntryPoint,
- EverythingEntryPoint,
-)
+from .entrypoint import EntryPoint, EverythingEntryPoint
+from .facets import FacetConstants
from .model import (
- directly_modified,
- get_one_or_create,
- numericrange_to_tuple,
- site_configuration_has_changed,
- tuple_to_numericrange,
Base,
CachedFeed,
Collection,
@@ -74,25 +53,26 @@
DeliveryMechanism,
Edition,
Genre,
- get_one,
Library,
LicensePool,
LicensePoolDeliveryMechanism,
Session,
Work,
WorkGenre,
+ directly_modified,
+ get_one,
+ get_one_or_create,
+ numericrange_to_tuple,
+ site_configuration_has_changed,
+ tuple_to_numericrange,
)
from .model.constants import EditionConstants
-from .facets import FacetConstants
from .problem_details import *
-from .util import (
- fast_query_count,
- LanguageCodes,
-)
-from .util.problem_detail import ProblemDetail
+from .util import LanguageCodes, fast_query_count
from .util.accept_language import parse_accept_language
-from .util.opds_writer import OPDSFeed
from .util.datetime_helpers import utc_now
+from .util.opds_writer import OPDSFeed
+from .util.problem_detail import ProblemDetail
class BaseFacets(FacetConstants):
@@ -129,7 +109,7 @@ def cached(self):
"""
if self.max_cache_age is None:
return None
- return (self.max_cache_age != 0)
+ return self.max_cache_age != 0
@property
def query_string(self):
@@ -182,8 +162,10 @@ class FacetsWithEntryPoint(BaseFacets):
"""Basic Facets class that knows how to filter a query based on a
selected EntryPoint.
"""
- def __init__(self, entrypoint=None, entrypoint_is_default=False,
- max_cache_age=None, **kwargs):
+
+ def __init__(
+ self, entrypoint=None, entrypoint_is_default=False, max_cache_age=None, **kwargs
+ ):
"""Constructor.
:param entrypoint: An EntryPoint (optional).
@@ -221,15 +203,22 @@ def navigate(self, entrypoint):
a different EntryPoint.
"""
return self.__class__(
- entrypoint=entrypoint, entrypoint_is_default=False,
+ entrypoint=entrypoint,
+ entrypoint_is_default=False,
max_cache_age=self.max_cache_age,
**self.constructor_kwargs
)
@classmethod
def from_request(
- cls, library, facet_config, get_argument, get_header, worklist,
- default_entrypoint=None, **extra_kwargs
+ cls,
+ library,
+ facet_config,
+ get_argument,
+ get_header,
+ worklist,
+ default_entrypoint=None,
+ **extra_kwargs
):
"""Load a faceting object from an HTTP request.
@@ -259,14 +248,23 @@ def from_request(
a problem with the input from the request.
"""
return cls._from_request(
- facet_config, get_argument, get_header, worklist,
- default_entrypoint, **extra_kwargs
+ facet_config,
+ get_argument,
+ get_header,
+ worklist,
+ default_entrypoint,
+ **extra_kwargs
)
@classmethod
def _from_request(
- cls, facet_config, get_argument, get_header, worklist,
- default_entrypoint=None, **extra_kwargs
+ cls,
+ facet_config,
+ get_argument,
+ get_header,
+ worklist,
+ default_entrypoint=None,
+ **extra_kwargs
):
"""Load a faceting object from an HTTP request.
@@ -274,9 +272,7 @@ def _from_request(
but call this method to load the EntryPoint and actually
instantiate the faceting class.
"""
- entrypoint_name = get_argument(
- Facets.ENTRY_POINT_FACET_GROUP_NAME, None
- )
+ entrypoint_name = get_argument(Facets.ENTRY_POINT_FACET_GROUP_NAME, None)
valid_entrypoints = list(cls.selectable_entrypoints(facet_config))
entrypoint = cls.load_entrypoint(
entrypoint_name, valid_entrypoints, default=default_entrypoint
@@ -285,16 +281,16 @@ def _from_request(
return entrypoint
entrypoint, is_default = entrypoint
- max_cache_age = get_argument(
- Facets.MAX_CACHE_AGE_NAME, None
- )
+ max_cache_age = get_argument(Facets.MAX_CACHE_AGE_NAME, None)
max_cache_age = cls.load_max_cache_age(max_cache_age)
if isinstance(max_cache_age, ProblemDetail):
return max_cache_age
return cls(
- entrypoint=entrypoint, entrypoint_is_default=is_default,
- max_cache_age=max_cache_age, **extra_kwargs
+ entrypoint=entrypoint,
+ entrypoint_is_default=is_default,
+ max_cache_age=max_cache_age,
+ **extra_kwargs
)
@classmethod
@@ -358,8 +354,7 @@ def items(self):
In this class that just means the entrypoint and any max_cache_age.
"""
if self.entrypoint:
- yield (self.ENTRY_POINT_FACET_GROUP_NAME,
- self.entrypoint.INTERNAL_NAME)
+ yield (self.ENTRY_POINT_FACET_GROUP_NAME, self.entrypoint.INTERNAL_NAME)
if self.max_cache_age not in (None, CachedFeed.CACHE_FOREVER):
if self.max_cache_age == CachedFeed.IGNORE_CACHE:
value = 0
@@ -395,10 +390,16 @@ class Facets(FacetsWithEntryPoint):
ORDER_BY_RELEVANCE = "relevance"
@classmethod
- def default(cls, library, collection=None, availability=None, order=None,
- entrypoint=None):
- return cls(library, collection=collection, availability=availability,
- order=order, entrypoint=entrypoint)
+ def default(
+ cls, library, collection=None, availability=None, order=None, entrypoint=None
+ ):
+ return cls(
+ library,
+ collection=collection,
+ availability=availability,
+ order=order,
+ entrypoint=entrypoint,
+ )
@classmethod
def available_facets(cls, config, facet_group_name):
@@ -440,8 +441,7 @@ def _values_from_request(cls, config, get_argument, get_header):
order_facets = cls.available_facets(config, g)
if order and not order in order_facets:
return INVALID_INPUT.detailed(
- _("I don't know how to order a feed by '%(order)s'", order=order),
- 400
+ _("I don't know how to order a feed by '%(order)s'", order=order), 400
)
g = Facets.AVAILABILITY_FACET_GROUP_NAME
@@ -449,8 +449,11 @@ def _values_from_request(cls, config, get_argument, get_header):
availability_facets = cls.available_facets(config, g)
if availability and not availability in availability_facets:
return INVALID_INPUT.detailed(
- _("I don't understand the availability term '%(availability)s'", availability=availability),
- 400
+ _(
+ "I don't understand the availability term '%(availability)s'",
+ availability=availability,
+ ),
+ 400,
)
g = Facets.COLLECTION_FACET_GROUP_NAME
@@ -458,38 +461,61 @@ def _values_from_request(cls, config, get_argument, get_header):
collection_facets = cls.available_facets(config, g)
if collection and not collection in collection_facets:
return INVALID_INPUT.detailed(
- _("I don't understand what '%(collection)s' refers to.", collection=collection),
- 400
+ _(
+ "I don't understand what '%(collection)s' refers to.",
+ collection=collection,
+ ),
+ 400,
)
enabled = {
- Facets.ORDER_FACET_GROUP_NAME : order_facets,
- Facets.AVAILABILITY_FACET_GROUP_NAME : availability_facets,
- Facets.COLLECTION_FACET_GROUP_NAME : collection_facets,
+ Facets.ORDER_FACET_GROUP_NAME: order_facets,
+ Facets.AVAILABILITY_FACET_GROUP_NAME: availability_facets,
+ Facets.COLLECTION_FACET_GROUP_NAME: collection_facets,
}
return dict(
- order=order, availability=availability, collection=collection,
- enabled_facets=enabled
+ order=order,
+ availability=availability,
+ collection=collection,
+ enabled_facets=enabled,
)
@classmethod
- def from_request(cls, library, config, get_argument, get_header, worklist,
- default_entrypoint=None, **extra):
+ def from_request(
+ cls,
+ library,
+ config,
+ get_argument,
+ get_header,
+ worklist,
+ default_entrypoint=None,
+ **extra
+ ):
"""Load a faceting object from an HTTP request."""
values = cls._values_from_request(config, get_argument, get_header)
if isinstance(values, ProblemDetail):
return values
extra.update(values)
- extra['library'] = library
+ extra["library"] = library
- return cls._from_request(config, get_argument, get_header, worklist,
- default_entrypoint, **extra)
+ return cls._from_request(
+ config, get_argument, get_header, worklist, default_entrypoint, **extra
+ )
- def __init__(self, library, collection, availability, order,
- order_ascending=None, enabled_facets=None, entrypoint=None,
- entrypoint_is_default=False, **constructor_kwargs):
+ def __init__(
+ self,
+ library,
+ collection,
+ availability,
+ order,
+ order_ascending=None,
+ enabled_facets=None,
+ entrypoint=None,
+ entrypoint_is_default=False,
+ **constructor_kwargs
+ ):
"""Constructor.
:param collection: This is not a Collection object; it's a value for
@@ -515,8 +541,14 @@ def __init__(self, library, collection, availability, order,
else:
order_ascending = self.ORDER_ASCENDING
- if (availability == self.AVAILABLE_ALL and (library and not library.allow_holds)
- and (self.AVAILABLE_NOW in self.available_facets(library, self.AVAILABILITY_FACET_GROUP_NAME))):
+ if (
+ availability == self.AVAILABLE_ALL
+ and (library and not library.allow_holds)
+ and (
+ self.AVAILABLE_NOW
+ in self.available_facets(library, self.AVAILABILITY_FACET_GROUP_NAME)
+ )
+ ):
# Under normal circumstances we would show all works, but
# library configuration says to hide books that aren't
# available.
@@ -533,8 +565,7 @@ def __init__(self, library, collection, availability, order,
self.order_ascending = order_ascending
self.facets_enabled_at_init = enabled_facets
- def navigate(self, collection=None, availability=None, order=None,
- entrypoint=None):
+ def navigate(self, collection=None, availability=None, order=None, entrypoint=None):
"""Create a slightly different Facets object from this one."""
return self.__class__(
library=self.library,
@@ -544,17 +575,16 @@ def navigate(self, collection=None, availability=None, order=None,
enabled_facets=self.facets_enabled_at_init,
entrypoint=(entrypoint or self.entrypoint),
entrypoint_is_default=False,
- max_cache_age=self.max_cache_age
+ max_cache_age=self.max_cache_age,
)
-
def items(self):
- for k,v in list(super(Facets, self).items()):
+ for k, v in list(super(Facets, self).items()):
yield k, v
if self.order:
yield (self.ORDER_FACET_GROUP_NAME, self.order)
if self.availability:
- yield (self.AVAILABILITY_FACET_GROUP_NAME, self.availability)
+ yield (self.AVAILABILITY_FACET_GROUP_NAME, self.availability)
if self.collection:
yield (self.COLLECTION_FACET_GROUP_NAME, self.collection)
@@ -572,7 +602,7 @@ def enabled_facets(self):
facet_types = [
self.ORDER_FACET_GROUP_NAME,
self.AVAILABILITY_FACET_GROUP_NAME,
- self.COLLECTION_FACET_GROUP_NAME
+ self.COLLECTION_FACET_GROUP_NAME,
]
for facet_type in facet_types:
yield self.facets_enabled_at_init.get(facet_type, [])
@@ -581,7 +611,7 @@ def enabled_facets(self):
for group_name in (
Facets.ORDER_FACET_GROUP_NAME,
Facets.AVAILABILITY_FACET_GROUP_NAME,
- Facets.COLLECTION_FACET_GROUP_NAME
+ Facets.COLLECTION_FACET_GROUP_NAME,
):
yield self.available_facets(self.library, group_name)
@@ -601,7 +631,7 @@ def dy(new_value):
group = self.ORDER_FACET_GROUP_NAME
current_value = self.order
facets = self.navigate(order=new_value)
- return (group, new_value, facets, current_value==new_value)
+ return (group, new_value, facets, current_value == new_value)
# First, the order facets.
if len(order_facets) > 1:
@@ -613,7 +643,7 @@ def dy(new_value):
group = self.AVAILABILITY_FACET_GROUP_NAME
current_value = self.availability
facets = self.navigate(availability=new_value)
- return (group, new_value, facets, new_value==current_value)
+ return (group, new_value, facets, new_value == current_value)
if len(availability_facets) > 1:
for facet in availability_facets:
@@ -624,7 +654,7 @@ def dy(new_value):
group = self.COLLECTION_FACET_GROUP_NAME
current_value = self.collection
facets = self.navigate(collection=new_value)
- return (group, new_value, facets, new_value==current_value)
+ return (group, new_value, facets, new_value == current_value)
if len(collection_facets) > 1:
for facet in collection_facets:
@@ -671,7 +701,7 @@ def modify_database_query(self, _db, qu):
LicensePool.open_access == True,
LicensePool.self_hosted == True,
LicensePool.unlimited_access,
- LicensePool.licenses_available > 0
+ LicensePool.licenses_available > 0,
)
if self.availability == self.AVAILABLE_NOW:
@@ -681,7 +711,7 @@ def modify_database_query(self, _db, qu):
LicensePool.open_access == True,
LicensePool.self_hosted == True,
LicensePool.licenses_owned > 0,
- LicensePool.unlimited_access
+ LicensePool.unlimited_access,
)
elif self.availability == self.AVAILABLE_OPEN_ACCESS:
# TODO: self-hosted content could be allowed here
@@ -690,8 +720,7 @@ def modify_database_query(self, _db, qu):
elif self.availability == self.AVAILABLE_NOT_NOW:
# The book must be licensed but currently unavailable.
availability_clause = and_(
- not_(available_now),
- LicensePool.licenses_owned > 0
+ not_(available_now), LicensePool.licenses_owned > 0
)
qu = qu.filter(availability_clause)
@@ -702,9 +731,7 @@ def modify_database_query(self, _db, qu):
elif self.collection == self.COLLECTION_FEATURED:
# Exclude books with a quality of less than the library's
# minimum featured quality.
- qu = qu.filter(
- Work.quality >= self.library.minimum_featured_quality
- )
+ qu = qu.filter(Work.quality >= self.library.minimum_featured_quality)
return qu
@@ -730,7 +757,7 @@ def available_facets(cls, config, facet_group_name):
# adding it if necessary.
order = cls.DEFAULT_SORT_ORDER
if order in default:
- default = [x for x in default if x!=order]
+ default = [x for x in default if x != order]
return [order] + default
@classmethod
@@ -752,10 +779,10 @@ class DatabaseBackedFacets(Facets):
# -- they map directly onto a field of one of the tables we're
# querying.
ORDER_FACET_TO_DATABASE_FIELD = {
- FacetConstants.ORDER_WORK_ID : Work.id,
- FacetConstants.ORDER_TITLE : Edition.sort_title,
- FacetConstants.ORDER_AUTHOR : Edition.sort_author,
- FacetConstants.ORDER_LAST_UPDATE : Work.last_update_time,
+ FacetConstants.ORDER_WORK_ID: Work.id,
+ FacetConstants.ORDER_TITLE: Edition.sort_title,
+ FacetConstants.ORDER_AUTHOR: Edition.sort_author,
+ FacetConstants.ORDER_LAST_UPDATE: Work.last_update_time,
}
@classmethod
@@ -764,8 +791,9 @@ def available_facets(cls, config, facet_group_name):
standard = config.enabled_facets(facet_group_name)
if facet_group_name != cls.ORDER_FACET_GROUP_NAME:
return standard
- return [order for order in standard
- if order in cls.ORDER_FACET_TO_DATABASE_FIELD]
+ return [
+ order for order in standard if order in cls.ORDER_FACET_TO_DATABASE_FIELD
+ ]
@classmethod
def default_facet(cls, config, facet_group_name):
@@ -793,9 +821,7 @@ def order_by(self):
"""Given these Facets, create a complete ORDER BY clause for queries
against WorkModelWithGenre.
"""
- default_sort_order = [
- Edition.sort_author, Edition.sort_title, Work.id
- ]
+ default_sort_order = [Edition.sort_author, Edition.sort_title, Work.id]
primary_order_by = self.ORDER_FACET_TO_DATABASE_FIELD.get(self.order)
if primary_order_by is not None:
@@ -846,8 +872,9 @@ class FeaturedFacets(FacetsWithEntryPoint):
# This Facets class is used exclusively for grouped feeds.
CACHED_FEED_TYPE = CachedFeed.GROUPS_TYPE
- def __init__(self, minimum_featured_quality, entrypoint=None,
- random_seed=None, **kwargs):
+ def __init__(
+ self, minimum_featured_quality, entrypoint=None, random_seed=None, **kwargs
+ ):
"""Set up an object that finds featured books in a given
WorkList.
@@ -856,7 +883,7 @@ def __init__(self, minimum_featured_quality, entrypoint=None,
"""
super(FeaturedFacets, self).__init__(entrypoint=entrypoint, **kwargs)
self.minimum_featured_quality = minimum_featured_quality
- self.random_seed=random_seed
+ self.random_seed = random_seed
@classmethod
def default(cls, lane, **kwargs):
@@ -877,9 +904,13 @@ def navigate(self, minimum_featured_quality=None, entrypoint=None):
"""Create a slightly different FeaturedFacets object based on this
one.
"""
- minimum_featured_quality = minimum_featured_quality or self.minimum_featured_quality
+ minimum_featured_quality = (
+ minimum_featured_quality or self.minimum_featured_quality
+ )
entrypoint = entrypoint or self.entrypoint
- return self.__class__(minimum_featured_quality, entrypoint, max_cache_age=self.max_cache_age)
+ return self.__class__(
+ minimum_featured_quality, entrypoint, max_cache_age=self.max_cache_age
+ )
def modify_search_filter(self, filter):
super(FeaturedFacets, self).modify_search_filter(filter)
@@ -907,8 +938,8 @@ class SearchFacets(Facets):
DEFAULT_MIN_SCORE = 500
def __init__(self, **kwargs):
- languages = kwargs.pop('languages', None)
- media = kwargs.pop('media', None)
+ languages = kwargs.pop("languages", None)
+ media = kwargs.pop("media", None)
# Our default_facets implementation will fill in values for
# the facet groups defined by the Facets class. This
@@ -917,10 +948,10 @@ def __init__(self, **kwargs):
# SearchFacets itself doesn't need one. However, in real
# usage, a Library will be provided via
# SearchFacets.from_request.
- kwargs.setdefault('library', None)
- kwargs.setdefault('collection', None)
- kwargs.setdefault('availability', None)
- order = kwargs.setdefault('order', None)
+ kwargs.setdefault("library", None)
+ kwargs.setdefault("collection", None)
+ kwargs.setdefault("availability", None)
+ order = kwargs.setdefault("order", None)
if order in (None, self.ORDER_BY_RELEVANCE):
# Search results are ordered by score, so there is no
@@ -928,7 +959,7 @@ def __init__(self, **kwargs):
default_min_score = None
else:
default_min_score = self.DEFAULT_MIN_SCORE
- self.min_score = kwargs.pop('min_score', default_min_score)
+ self.min_score = kwargs.pop("min_score", default_min_score)
super(SearchFacets, self).__init__(**kwargs)
if media == Edition.ALL_MEDIUM:
@@ -966,14 +997,22 @@ def _ensure_list(self, x):
return [x]
@classmethod
- def from_request(cls, library, config, get_argument, get_header, worklist,
- default_entrypoint=EverythingEntryPoint, **extra):
+ def from_request(
+ cls,
+ library,
+ config,
+ get_argument,
+ get_header,
+ worklist,
+ default_entrypoint=EverythingEntryPoint,
+ **extra
+ ):
values = cls._values_from_request(config, get_argument, get_header)
if isinstance(values, ProblemDetail):
return values
extra.update(values)
- extra['library'] = library
+ extra["library"] = library
# Searches against a WorkList will use the union of the
# languages allowed by the WorkList and the languages found in
# the client's Accept-Language header.
@@ -995,7 +1034,7 @@ def from_request(cls, library, config, get_argument, get_header, worklist,
except ValueError as e:
min_score = None
if min_score is not None:
- extra['min_score'] = min_score
+ extra["min_score"] = min_score
# The client can request an additional restriction on
# the media types to be returned by searches.
@@ -1003,17 +1042,16 @@ def from_request(cls, library, config, get_argument, get_header, worklist,
media = get_argument("media", None)
if media not in EditionConstants.KNOWN_MEDIA:
media = None
- extra['media'] = media
+ extra["media"] = media
languageQuery = get_argument("language", None)
# Currently, the only value passed to the language query from the client is
# `all`. This will remove the default browser's Accept-Language header value
# in the search request.
- if languageQuery != "all" :
- extra['languages'] = languages
+ if languageQuery != "all":
+ extra["languages"] = languages
return cls._from_request(
- config, get_argument, get_header, worklist, default_entrypoint,
- **extra
+ config, get_argument, get_header, worklist, default_entrypoint, **extra
)
@classmethod
@@ -1083,14 +1121,15 @@ def items(self):
yield ("media", self.media_argument)
if self.min_score is not None:
- yield ('min_score', str(self.min_score))
+ yield ("min_score", str(self.min_score))
def navigate(self, **kwargs):
- min_score = kwargs.pop('min_score', self.min_score)
+ min_score = kwargs.pop("min_score", self.min_score)
new_facets = super(SearchFacets, self).navigate(**kwargs)
new_facets.min_score = min_score
return new_facets
+
class Pagination(object):
DEFAULT_SIZE = 50
@@ -1139,11 +1178,9 @@ def _int_from_request(cls, key, get_arg, make_detail, default):
@classmethod
def size_from_request(cls, get_arg, default):
- make_detail = lambda size: (
- _("Invalid page size: %(size)s", size=size)
- )
+ make_detail = lambda size: (_("Invalid page size: %(size)s", size=size))
size = cls._int_from_request(
- 'size', get_arg, make_detail, default or cls.DEFAULT_SIZE
+ "size", get_arg, make_detail, default or cls.DEFAULT_SIZE
)
if isinstance(size, ProblemDetail):
return size
@@ -1157,21 +1194,22 @@ def from_request(cls, get_arg, default_size=None):
if isinstance(size, ProblemDetail):
return size
offset = cls._int_from_request(
- 'after', get_arg,
+ "after",
+ get_arg,
lambda offset: _("Invalid offset: %(offset)s", offset=offset),
- 0
+ 0,
)
if isinstance(offset, ProblemDetail):
return offset
return cls(offset, size)
def items(self):
- yield("after", self.offset)
- yield("size", self.size)
+ yield ("after", self.offset)
+ yield ("size", self.size)
@property
def query_string(self):
- return "&".join("=".join(map(str, x)) for x in list(self.items()))
+ return "&".join("=".join(map(str, x)) for x in list(self.items()))
@property
def first_page(self):
@@ -1181,7 +1219,7 @@ def first_page(self):
def next_page(self):
if not self.has_next_page:
return None
- return Pagination(self.offset+self.size, self.size)
+ return Pagination(self.offset + self.size, self.size)
@property
def previous_page(self):
@@ -1224,7 +1262,7 @@ def modify_search_query(self, search):
:return: A Search object.
"""
- return search[self.offset:self.offset+self.size]
+ return search[self.offset : self.offset + self.size]
def page_loaded(self, page):
"""An actual page of results has been fetched. Keep any internal state
@@ -1279,15 +1317,14 @@ def top_level_for_library(self, _db, library):
"""
# Load all of this Library's visible top-level Lane objects
# from the database.
- top_level_lanes = _db.query(Lane).filter(
- Lane.library==library
- ).filter(
- Lane.parent==None
- ).filter(
- Lane._visible==True
- ).order_by(
- Lane.priority
- ).all()
+ top_level_lanes = (
+ _db.query(Lane)
+ .filter(Lane.library == library)
+ .filter(Lane.parent == None)
+ .filter(Lane._visible == True)
+ .order_by(Lane.priority)
+ .all()
+ )
if len(top_level_lanes) == 1:
# The site configuration includes a single top-level lane;
@@ -1299,18 +1336,31 @@ def top_level_for_library(self, _db, library):
wl = TopLevelWorkList()
wl.initialize(
- library, display_name=library.name, children=top_level_lanes,
- media=Edition.FULFILLABLE_MEDIA, entrypoints=library.entrypoints
+ library,
+ display_name=library.name,
+ children=top_level_lanes,
+ media=Edition.FULFILLABLE_MEDIA,
+ entrypoints=library.entrypoints,
)
return wl
- def initialize(self, library, display_name=None, genres=None,
- audiences=None, languages=None, media=None,
- customlists=None, list_datasource=None,
- list_seen_in_previous_days=None,
- children=None, priority=None, entrypoints=None,
- fiction=None, license_datasource=None,
- target_age=None,
+ def initialize(
+ self,
+ library,
+ display_name=None,
+ genres=None,
+ audiences=None,
+ languages=None,
+ media=None,
+ customlists=None,
+ list_datasource=None,
+ list_seen_in_previous_days=None,
+ children=None,
+ priority=None,
+ entrypoints=None,
+ fiction=None,
+ license_datasource=None,
+ target_age=None,
):
"""Initialize with basic data.
@@ -1458,7 +1508,11 @@ def get_library(self, _db):
def get_customlists(self, _db):
"""Get customlists associated with the Worklist."""
if hasattr(self, "_customlist_ids") and self._customlist_ids is not None:
- return _db.query(CustomList).filter(CustomList.id.in_(self._customlist_ids)).all()
+ return (
+ _db.query(CustomList)
+ .filter(CustomList.id.in_(self._customlist_ids))
+ .all()
+ )
return []
@property
@@ -1475,7 +1529,7 @@ def visible_children(self):
"""
return sorted(
[x for x in self.children if x.visible],
- key = lambda x: (x.priority, x.display_name or "")
+ key=lambda x: (x.priority, x.display_name or ""),
)
@property
@@ -1571,14 +1625,13 @@ def full_identifier(self):
captures its position within the heirarchy.
"""
full_parentage = [str(x.display_name) for x in self.hierarchy]
- if getattr(self, 'library', None):
+ if getattr(self, "library", None):
# This WorkList is associated with a specific library.
# incorporate the library's name to distinguish between it
# and other lanes in the same position in another library.
full_parentage.insert(0, self.library.short_name)
return " / ".join(full_parentage)
-
@property
def language_key(self):
"""Return a string identifying the languages used in this WorkList.
@@ -1592,13 +1645,12 @@ def language_key(self):
@property
def audience_key(self):
"""Translates audiences list into url-safe string"""
- key = ''
- if (self.audiences and
- Classifier.AUDIENCES.difference(self.audiences)):
+ key = ""
+ if self.audiences and Classifier.AUDIENCES.difference(self.audiences):
# There are audiences and they're not the default
# "any audience", so add them to the URL.
audiences = [quote_plus(a) for a in sorted(self.audiences)]
- key += ','.join(audiences)
+ key += ",".join(audiences)
return key
@property
@@ -1609,9 +1661,7 @@ def unique_key(self):
This is used when caching feeds for this WorkList. For Lanes,
the lane_id is used instead.
"""
- return "%s-%s-%s" % (
- self.display_name, self.language_key, self.audience_key
- )
+ return "%s-%s-%s" % (self.display_name, self.language_key, self.audience_key)
def accessible_to(self, patron):
"""As a matter of library policy, is the given `Patron` allowed
@@ -1661,8 +1711,15 @@ def overview_facets(self, _db, facets):
"""
return facets
- def groups(self, _db, include_sublanes=True, pagination=None, facets=None,
- search_engine=None, debug=False):
+ def groups(
+ self,
+ _db,
+ include_sublanes=True,
+ pagination=None,
+ facets=None,
+ search_engine=None,
+ debug=False,
+ ):
"""Extract a list of samples from each child of this WorkList. This
can be used to create a grouped acquisition feed for the WorkList.
@@ -1713,13 +1770,25 @@ def groups(self, _db, include_sublanes=True, pagination=None, facets=None,
# for any children that are Lanes, and call groups()
# recursively for any children that are not.
for work, worklist in self._groups_for_lanes(
- _db, relevant_children, relevant_lanes, pagination=pagination,
- facets=facets, search_engine=search_engine, debug=debug
+ _db,
+ relevant_children,
+ relevant_lanes,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
+ debug=debug,
):
yield work, worklist
- def works(self, _db, facets=None, pagination=None, search_engine=None,
- debug=False, **kwargs):
+ def works(
+ self,
+ _db,
+ facets=None,
+ pagination=None,
+ search_engine=None,
+ debug=False,
+ **kwargs
+ ):
"""Use a search engine to obtain Work or Work-like objects that belong
in this WorkList.
@@ -1740,15 +1809,12 @@ def works(self, _db, facets=None, pagination=None, search_engine=None,
that generates such a list when executed.
"""
- from .external_search import (
- Filter,
- ExternalSearchIndex,
- )
+ from .external_search import ExternalSearchIndex, Filter
+
search_engine = search_engine or ExternalSearchIndex.load(_db)
filter = self.filter(_db, facets)
hits = search_engine.query_works(
- query_string=None, filter=filter, pagination=pagination,
- debug=debug
+ query_string=None, filter=filter, pagination=pagination, debug=debug
)
return self.works_for_hits(_db, hits, facets=facets)
@@ -1759,6 +1825,7 @@ def filter(self, _db, facets):
called.
"""
from .external_search import Filter
+
filter = Filter.from_worklist(_db, self, facets)
modified = self.modify_search_filter_hook(filter)
if modified is None:
@@ -1794,10 +1861,7 @@ def works_for_resultsets(self, _db, resultsets, facets=None):
"""Convert a list of lists of Hit objects into a list
of lists of Work objects.
"""
- from .external_search import (
- Filter,
- WorkSearchResult,
- )
+ from .external_search import Filter, WorkSearchResult
has_script_fields = None
work_ids = set()
@@ -1808,10 +1872,8 @@ def works_for_resultsets(self, _db, resultsets, facets=None):
# We don't know whether any script fields were
# included, and now we're in a position to find
# out.
- has_script_fields = (
- any(
- x in result for x in Filter.KNOWN_SCRIPT_FIELDS
- )
+ has_script_fields = any(
+ x in result for x in Filter.KNOWN_SCRIPT_FIELDS
)
if has_script_fields is None:
@@ -1855,9 +1917,7 @@ def works_for_resultsets(self, _db, resultsets, facets=None):
works.append(work)
b = time.time()
- logging.info(
- "Obtained %sxWork in %.2fsec", len(all_works), b-a
- )
+ logging.info("Obtained %sxWork in %.2fsec", len(all_works), b - a)
return work_lists
@property
@@ -1865,8 +1925,9 @@ def search_target(self):
"""By default, a WorkList is searchable."""
return self
- def search(self, _db, query, search_client, pagination=None, facets=None,
- debug=False):
+ def search(
+ self, _db, query, search_client, pagination=None, facets=None, debug=False
+ ):
"""Find works in this WorkList that match a search query.
:param _db: A database connection.
@@ -1884,19 +1945,15 @@ def search(self, _db, query, search_client, pagination=None, facets=None,
return results
if not pagination:
- pagination = Pagination(
- offset=0, size=Pagination.DEFAULT_SEARCH_SIZE
- )
+ pagination = Pagination(offset=0, size=Pagination.DEFAULT_SEARCH_SIZE)
filter = self.filter(_db, facets)
try:
- hits = search_client.query_works(
- query, filter, pagination, debug
- )
+ hits = search_client.query_works(query, filter, pagination, debug)
except elasticsearch.exceptions.ElasticsearchException as e:
logging.error(
"Problem communicating with ElasticSearch. Returning empty list of search results.",
- exc_info=e
+ exc_info=e,
)
if hits:
results = self.works_for_hits(_db, hits)
@@ -1904,8 +1961,14 @@ def search(self, _db, query, search_client, pagination=None, facets=None,
return results
def _groups_for_lanes(
- self, _db, relevant_lanes, queryable_lanes, pagination, facets,
- search_engine=None, debug=False
+ self,
+ _db,
+ relevant_lanes,
+ queryable_lanes,
+ pagination,
+ facets,
+ search_engine=None,
+ debug=False,
):
"""Ask the search engine for groups of featurable works in the
given lanes. Fill in gaps as necessary.
@@ -1936,12 +1999,13 @@ def _groups_for_lanes(
# We ask for a few extra works for each lane, to reduce the
# risk that we'll end up reusing a book in two different
# lanes.
- ask_for_size = max(target_size+1, int(target_size * 1.10))
+ ask_for_size = max(target_size + 1, int(target_size * 1.10))
pagination = Pagination(size=ask_for_size)
else:
target_size = pagination.size
from .external_search import ExternalSearchIndex
+
search_engine = search_engine or ExternalSearchIndex.load(_db)
if isinstance(self, Lane):
@@ -1952,8 +2016,12 @@ def _groups_for_lanes(
queryable_lane_set = set(queryable_lanes)
works_and_lanes = list(
self._featured_works_with_lanes(
- _db, queryable_lanes, pagination=pagination,
- facets=facets, search_engine=search_engine, debug=debug
+ _db,
+ queryable_lanes,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
+ debug=debug,
)
)
@@ -1962,14 +2030,12 @@ def _done_with_lane(lane):
the lane changes or we've reached the end of the list.
"""
# Did we get enough items?
- num_missing = target_size-len(by_lane[lane])
+ num_missing = target_size - len(by_lane[lane])
if num_missing > 0 and might_need_to_reuse:
# No, we need to use some works we used in a
# previous lane to fill out this lane. Stick
# them at the end.
- by_lane[lane].extend(
- list(might_need_to_reuse.values())[:num_missing]
- )
+ by_lane[lane].extend(list(might_need_to_reuse.values())[:num_missing])
used_works = set()
by_lane = defaultdict(list)
@@ -2014,8 +2080,10 @@ def _done_with_lane(lane):
# Lane at all. Do a whole separate query and plug it
# in at this point.
for x in lane.groups(
- _db, include_sublanes=False,
- pagination=pagination, facets=facets,
+ _db,
+ include_sublanes=False,
+ pagination=pagination,
+ facets=facets,
):
yield x
@@ -2059,6 +2127,7 @@ def _featured_works_with_lanes(
for lane in lanes:
overview_facets = lane.overview_facets(_db, facets)
from .external_search import Filter
+
filter = Filter.from_worklist(_db, lane, overview_facets)
queries.append((None, filter, pagination))
resultsets = list(search_engine.query_works_multi(queries))
@@ -2111,6 +2180,7 @@ class TopLevelWorkList(HierarchyWorkList):
"""A special WorkList representing the top-level view of
a library's collection.
"""
+
pass
@@ -2183,14 +2253,11 @@ def base_query(cls, _db):
"""Return a query that contains the joins set up as necessary to
create OPDS feeds.
"""
- qu = _db.query(
- Work
- ).join(
- Work.license_pools
- ).join(
- Work.presentation_edition
- ).filter(
- LicensePool.superceded==False
+ qu = (
+ _db.query(Work)
+ .join(Work.license_pools)
+ .join(Work.presentation_edition)
+ .filter(LicensePool.superceded == False)
)
# Apply optimizations.
@@ -2208,7 +2275,7 @@ def _modify_loading(cls, qu):
contains_eager(Work.presentation_edition),
contains_eager(Work.license_pools),
)
- license_pool_name = 'license_pools'
+ license_pool_name = "license_pools"
# Load some objects that wouldn't normally be loaded, but
# which are necessary when generating OPDS feeds.
@@ -2222,19 +2289,17 @@ def _modify_loading(cls, qu):
# These speed up the process of generating acquisition links.
joinedload(license_pool_name, "delivery_mechanisms"),
joinedload(license_pool_name, "delivery_mechanisms", "delivery_mechanism"),
-
joinedload(license_pool_name, "identifier"),
-
# These speed up the process of generating the open-access link
# for open-access works.
joinedload(license_pool_name, "delivery_mechanisms", "resource"),
- joinedload(license_pool_name, "delivery_mechanisms", "resource", "representation"),
+ joinedload(
+ license_pool_name, "delivery_mechanisms", "resource", "representation"
+ ),
)
return qu
- def only_show_ready_deliverable_works(
- self, _db, query, show_suppressed=False
- ):
+ def only_show_ready_deliverable_works(self, _db, query, show_suppressed=False):
"""Restrict a query to show only presentation-ready works present in
an appropriate collection which the default client can
fulfill.
@@ -2243,8 +2308,7 @@ def only_show_ready_deliverable_works(
LicensePool.
"""
return Collection.restrict_to_ready_deliverable_works(
- query, show_suppressed=show_suppressed,
- collection_ids=self.collection_ids
+ query, show_suppressed=show_suppressed, collection_ids=self.collection_ids
)
@classmethod
@@ -2278,11 +2342,9 @@ def bibliographic_filter_clauses(self, _db, qu):
if self.media:
clauses.append(Edition.medium.in_(self.media))
if self.fiction is not None:
- clauses.append(Work.fiction==self.fiction)
+ clauses.append(Work.fiction == self.fiction)
if self.license_datasource_id:
- clauses.append(
- LicensePool.data_source_id==self.license_datasource_id
- )
+ clauses.append(LicensePool.data_source_id == self.license_datasource_id)
if self.genre_ids:
qu, clause = self.genre_filter_clause(qu)
@@ -2299,9 +2361,7 @@ def bibliographic_filter_clauses(self, _db, qu):
# In addition to the other any other restrictions, books
# will show up here only if they would also show up in the
# parent WorkList.
- qu, parent_clauses = self.parent.bibliographic_filter_clauses(
- _db, qu
- )
+ qu, parent_clauses = self.parent.bibliographic_filter_clauses(_db, qu)
if parent_clauses:
clauses.extend(parent_clauses)
@@ -2339,7 +2399,7 @@ def customlist_filter_clauses(self, qu):
# the table every time.
a_entry = aliased(CustomListEntry)
- clause = a_entry.work_id==Work.id
+ clause = a_entry.work_id == Work.id
qu = qu.join(a_entry, clause)
# Actually build the restriction clauses.
@@ -2350,24 +2410,21 @@ def customlist_filter_clauses(self, qu):
# CustomLists from this DataSource. This is significantly
# simpler than adding a join against CustomList.
customlist_ids = Select(
- [CustomList.id],
- CustomList.data_source_id==self.list_datasource_id
+ [CustomList.id], CustomList.data_source_id == self.list_datasource_id
)
else:
customlist_ids = self.customlist_ids
if customlist_ids is not None:
clauses.append(a_entry.list_id.in_(customlist_ids))
if self.list_seen_in_previous_days:
- cutoff = utc_now() - datetime.timedelta(
- self.list_seen_in_previous_days
- )
- clauses.append(a_entry.most_recent_appearance >=cutoff)
+ cutoff = utc_now() - datetime.timedelta(self.list_seen_in_previous_days)
+ clauses.append(a_entry.most_recent_appearance >= cutoff)
return qu, clauses
def genre_filter_clause(self, qu):
wg = aliased(WorkGenre)
- qu = qu.join(wg, wg.work_id==Work.id)
+ qu = qu.join(wg, wg.work_id == Work.id)
return qu, wg.genre_id.in_(self.genre_ids)
def age_range_filter_clauses(self):
@@ -2384,12 +2441,8 @@ def age_range_filter_clauses(self):
target_age = tuple_to_numericrange(target_age)
audiences = self.audiences or []
- adult_audiences = [
- Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_ADULTS_ONLY
- ]
- if (target_age.upper >= 18 or (
- any(x in audiences for x in adult_audiences))
- ):
+ adult_audiences = [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_ADULTS_ONLY]
+ if target_age.upper >= 18 or (any(x in audiences for x in adult_audiences)):
# Books for adults don't have target ages. If we're
# including books for adults, either due to the audience
# setting or the target age setting, allow the target age
@@ -2402,12 +2455,7 @@ def age_range_filter_clauses(self):
# set_target_age makes sure of that. The work's target age
# must overlap that of the lane.
- return [
- or_(
- Work.target_age.overlaps(target_age),
- audience_has_no_target_age
- )
- ]
+ return [or_(Work.target_age.overlaps(target_age), audience_has_no_target_age)]
def modify_database_query_hook(self, _db, qu):
"""A hook method allowing subclasses to modify a database query
@@ -2421,6 +2469,7 @@ def modify_database_query_hook(self, _db, qu):
class SpecificWorkList(DatabaseBackedWorkList):
"""A WorkList that only finds specific works, identified by ID."""
+
def __init__(self, work_ids):
super(SpecificWorkList, self).__init__()
self.work_ids = work_ids
@@ -2428,19 +2477,18 @@ def __init__(self, work_ids):
def modify_database_query_hook(self, _db, qu):
qu = qu.filter(
Work.id.in_(self.work_ids),
- LicensePool.work_id.in_(self.work_ids), # Query optimization
+ LicensePool.work_id.in_(self.work_ids), # Query optimization
)
return qu
class LaneGenre(Base):
"""Relationship object between Lane and Genre."""
- __tablename__ = 'lanes_genres'
+
+ __tablename__ = "lanes_genres"
id = Column(Integer, primary_key=True)
- lane_id = Column(Integer, ForeignKey('lanes.id'), index=True,
- nullable=False)
- genre_id = Column(Integer, ForeignKey('genres.id'), index=True,
- nullable=False)
+ lane_id = Column(Integer, ForeignKey("lanes.id"), index=True, nullable=False)
+ genre_id = Column(Integer, ForeignKey("genres.id"), index=True, nullable=False)
# An inclusive relationship means that books classified under the
# genre are included in the lane. An exclusive relationship means
@@ -2453,9 +2501,7 @@ class LaneGenre(Base):
# means that only the genre itself is affected.
recursive = Column(Boolean, default=True, nullable=False)
- __table_args__ = (
- UniqueConstraint('lane_id', 'genre_id'),
- )
+ __table_args__ = (UniqueConstraint("lane_id", "genre_id"),)
@classmethod
def from_genre(cls, genre):
@@ -2464,6 +2510,7 @@ def from_genre(cls, genre):
lg.genre = genre
return lg
+
Genre.lane_genres = relationship(
"LaneGenre", foreign_keys=LaneGenre.genre_id, backref="genre"
)
@@ -2481,14 +2528,12 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):
# The set of Works in a standard Lane is cacheable for twenty
# minutes. Note that this only applies to paginated feeds --
# grouped feeds are cached indefinitely.
- MAX_CACHE_AGE = 20*60
+ MAX_CACHE_AGE = 20 * 60
- __tablename__ = 'lanes'
+ __tablename__ = "lanes"
id = Column(Integer, primary_key=True)
- library_id = Column(Integer, ForeignKey('libraries.id'), index=True,
- nullable=False)
- parent_id = Column(Integer, ForeignKey('lanes.id'), index=True,
- nullable=True)
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True, nullable=False)
+ parent_id = Column(Integer, ForeignKey("lanes.id"), index=True, nullable=True)
priority = Column(Integer, index=True, nullable=False, default=0)
# How many titles are in this lane? This is periodically
@@ -2502,16 +2547,17 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):
# A lane may have one parent lane and many sublanes.
sublanes = relationship(
"Lane",
- backref=backref("parent", remote_side = [id]),
+ backref=backref("parent", remote_side=[id]),
)
# A lane may have multiple associated LaneGenres. For most lanes,
# this is how the contents of the lanes are defined.
- genres = association_proxy('lane_genres', 'genre',
- creator=LaneGenre.from_genre)
+ genres = association_proxy("lane_genres", "genre", creator=LaneGenre.from_genre)
lane_genres = relationship(
- "LaneGenre", foreign_keys="LaneGenre.lane_id", backref="lane",
- cascade='all, delete-orphan'
+ "LaneGenre",
+ foreign_keys="LaneGenre.lane_id",
+ backref="lane",
+ cascade="all, delete-orphan",
)
# display_name is the name of the lane as shown to patrons. It's
@@ -2530,7 +2576,7 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):
# A lane may be restricted to works classified for specific audiences
# (e.g. only Young Adult works).
- _audiences = Column(ARRAY(Unicode), name='audiences')
+ _audiences = Column(ARRAY(Unicode), name="audiences")
# A lane may further be restricted to works classified as suitable
# for a specific age range.
@@ -2548,21 +2594,18 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):
# Only books licensed through this DataSource will be shown.
license_datasource_id = Column(
- Integer, ForeignKey('datasources.id'), index=True,
- nullable=True
+ Integer, ForeignKey("datasources.id"), index=True, nullable=True
)
# Only books on one or more CustomLists obtained from this
# DataSource will be shown.
_list_datasource_id = Column(
- Integer, ForeignKey('datasources.id'), index=True,
- nullable=True
+ Integer, ForeignKey("datasources.id"), index=True, nullable=True
)
# Only the books on these specific CustomLists will be shown.
customlists = relationship(
- "CustomList", secondary=lambda: lanes_customlists,
- backref="lane"
+ "CustomList", secondary=lambda: lanes_customlists, backref="lane"
)
# This has no effect unless list_datasource_id or
@@ -2590,9 +2633,7 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):
# one would want to see a big list containing everything, and b)
# the sublanes are exhaustive of the Lane's content, so there's
# nothing new to be seen by going into that big list.
- include_self_in_grouped_feed = Column(
- Boolean, default=True, nullable=False
- )
+ include_self_in_grouped_feed = Column(Boolean, default=True, nullable=False)
# Only a visible lane will show up in the user interface. The
# admin interface can see all the lanes, visible or not.
@@ -2600,19 +2641,19 @@ class Lane(Base, DatabaseBackedWorkList, HierarchyWorkList):
# A Lane may have many CachedFeeds.
cachedfeeds = relationship(
- "CachedFeed", backref="lane",
+ "CachedFeed",
+ backref="lane",
cascade="all, delete-orphan",
)
# A Lane may have many CachedMARCFiles.
cachedmarcfiles = relationship(
- "CachedMARCFile", backref="lane",
+ "CachedMARCFile",
+ backref="lane",
cascade="all, delete-orphan",
)
- __table_args__ = (
- UniqueConstraint('parent_id', 'display_name'),
- )
+ __table_args__ = (UniqueConstraint("parent_id", "display_name"),)
def get_library(self, _db):
"""For compatibility with WorkList.get_library()."""
@@ -2671,7 +2712,10 @@ def is_self_or_descendant(self, ancestor):
# A TopLevelWorkList won't show up in a Lane's parentage,
# because it's not a Lane, but if they share the same library
# it can be presumed to be the lane's ultimate ancestor.
- if isinstance(ancestor, TopLevelWorkList) and self.library_id==ancestor.library_id:
+ if (
+ isinstance(ancestor, TopLevelWorkList)
+ and self.library_id == ancestor.library_id
+ ):
return True
return False
@@ -2714,7 +2758,9 @@ def audiences(self, value):
contradicts the current value to the `target_age` field.
"""
if self._audiences and self._target_age and value != self._audiences:
- raise ValueError("Cannot modify Lane.audiences when Lane.target_age is set!")
+ raise ValueError(
+ "Cannot modify Lane.audiences when Lane.target_age is set!"
+ )
if isinstance(value, (bytes, str)):
value = [value]
self._audiences = value
@@ -2750,13 +2796,9 @@ def target_age(self, value):
if value.lower >= Classifier.ADULT_AGE_CUTOFF:
# Adults are adults and there's no point in tracking
# precise age gradations for them.
- value = tuple_to_numericrange(
- (Classifier.ADULT_AGE_CUTOFF, value.upper)
- )
+ value = tuple_to_numericrange((Classifier.ADULT_AGE_CUTOFF, value.upper))
if value.upper >= Classifier.ADULT_AGE_CUTOFF:
- value = tuple_to_numericrange(
- (value.lower, Classifier.ADULT_AGE_CUTOFF)
- )
+ value = tuple_to_numericrange((value.lower, Classifier.ADULT_AGE_CUTOFF))
self._target_age = value
if value.upper >= Classifier.ADULT_AGE_CUTOFF:
@@ -2778,7 +2820,7 @@ def list_datasource(self, value):
"""
if value:
self.customlists = []
- if hasattr(self, '_customlist_ids'):
+ if hasattr(self, "_customlist_ids"):
# The next time someone asks for .customlist_ids,
# the list will be refreshed.
del self._customlist_ids
@@ -2801,8 +2843,11 @@ def uses_customlists(self):
"""
if self.customlists or self.list_datasource:
return True
- if (self.parent and self.inherit_parent_restrictions
- and self.parent.uses_customlists):
+ if (
+ self.parent
+ and self.inherit_parent_restrictions
+ and self.parent.uses_customlists
+ ):
return True
return False
@@ -2826,15 +2871,18 @@ def update_size(self, _db, search_engine=None):
"""Update the stored estimate of the number of Works in this Lane."""
library = self.get_library(_db)
from .external_search import ExternalSearchIndex
+
search_engine = search_engine or ExternalSearchIndex.load(_db)
# Do the estimate for every known entry point.
by_entrypoint = dict()
for entrypoint in EntryPoint.ENTRY_POINTS:
facets = DatabaseBackedFacets(
- library, FacetConstants.COLLECTION_FULL,
+ library,
+ FacetConstants.COLLECTION_FULL,
FacetConstants.AVAILABLE_ALL,
- order=FacetConstants.ORDER_WORK_ID, entrypoint=entrypoint
+ order=FacetConstants.ORDER_WORK_ID,
+ entrypoint=entrypoint,
)
filter = self.filter(_db, facets)
by_entrypoint[entrypoint.URI] = search_engine.count_works(filter)
@@ -2849,7 +2897,7 @@ def genre_ids(self):
:return: A list of genre IDs, or None if this Lane does not
consider genres at all.
"""
- if not hasattr(self, '_genre_ids'):
+ if not hasattr(self, "_genre_ids"):
self._genre_ids = self._gather_genre_ids()
return self._genre_ids
@@ -2866,8 +2914,15 @@ def _gather_genre_ids(self):
bucket = included_ids
else:
bucket = excluded_ids
- if self.fiction != None and genre.default_fiction != None and self.fiction != genre.default_fiction:
- logging.error("Lane %s has a genre %s that does not match its fiction restriction.", (self.full_identifier, genre.name))
+ if (
+ self.fiction != None
+ and genre.default_fiction != None
+ and self.fiction != genre.default_fiction
+ ):
+ logging.error(
+ "Lane %s has a genre %s that does not match its fiction restriction.",
+ (self.full_identifier, genre.name),
+ )
bucket.add(genre.id)
if lanegenre.recursive:
for subgenre in genre.subgenres:
@@ -2883,8 +2938,7 @@ def _gather_genre_ids(self):
# Fantasy' is included but 'Fantasy' and its subgenres are
# excluded.
logging.error(
- "Lane %s has a self-negating set of genre IDs.",
- self.full_identifier
+ "Lane %s has a self-negating set of genre IDs.", self.full_identifier
)
return genre_ids
@@ -2895,7 +2949,7 @@ def customlist_ids(self):
:return: A list of CustomList IDs, possibly empty.
"""
- if not hasattr(self, '_customlist_ids'):
+ if not hasattr(self, "_customlist_ids"):
self._customlist_ids = self._gather_customlist_ids()
return self._customlist_ids
@@ -2906,8 +2960,7 @@ def _gather_customlist_ids(self):
# DataSource.
_db = Session.object_session(self)
query = select(
- [CustomList.id],
- CustomList.data_source_id==self.list_datasource.id
+ [CustomList.id], CustomList.data_source_id == self.list_datasource.id
)
ids = [x[0] for x in _db.execute(query)]
else:
@@ -2933,13 +2986,13 @@ def affected_by_customlist(self, customlist):
# Either the data source must match, or there must be a specific link
# between the Lane and the CustomList.
- data_source_matches = (
- Lane._list_datasource_id==customlist.data_source_id
- )
- specific_link = CustomList.id==customlist.id
+ data_source_matches = Lane._list_datasource_id == customlist.data_source_id
+ specific_link = CustomList.id == customlist.id
- return _db.query(Lane).outerjoin(Lane.customlists).filter(
- or_(data_source_matches, specific_link)
+ return (
+ _db.query(Lane)
+ .outerjoin(Lane.customlists)
+ .filter(or_(data_source_matches, specific_link))
)
def add_genre(self, genre, inclusive=True, recursive=True):
@@ -2951,11 +3004,9 @@ def add_genre(self, genre, inclusive=True, recursive=True):
_db = Session.object_session(self)
if isinstance(genre, (bytes, str)):
genre, ignore = Genre.lookup(_db, genre)
- lanegenre, is_new = get_one_or_create(
- _db, LaneGenre, lane=self, genre=genre
- )
- lanegenre.inclusive=inclusive
- lanegenre.recursive=recursive
+ lanegenre, is_new = get_one_or_create(_db, LaneGenre, lane=self, genre=genre)
+ lanegenre.inclusive = inclusive
+ lanegenre.recursive = recursive
self._genre_ids = self._gather_genre_ids()
return lanegenre, is_new
@@ -2979,7 +3030,10 @@ def search_target(self):
languages = self.languages
media = self.media
audiences = None
- if Classifier.AUDIENCE_YOUNG_ADULT in self.audiences or Classifier.AUDIENCE_CHILDREN in self.audiences:
+ if (
+ Classifier.AUDIENCE_YOUNG_ADULT in self.audiences
+ or Classifier.AUDIENCE_CHILDREN in self.audiences
+ ):
audiences = self.audiences
# If there are too many languages or audiences, the description
@@ -2997,8 +3051,13 @@ def search_target(self):
display_name = " ".join(display_name_parts)
wl = WorkList()
- wl.initialize(self.library, display_name=display_name,
- languages=languages, media=media, audiences=audiences)
+ wl.initialize(
+ self.library,
+ display_name=display_name,
+ languages=languages,
+ media=media,
+ audiences=audiences,
+ )
return wl
def _size_for_facets(self, facets):
@@ -3014,13 +3073,19 @@ def _size_for_facets(self, facets):
if facets and facets.entrypoint:
entrypoint_name = facets.entrypoint.URI
- if (self.size_by_entrypoint
- and entrypoint_name in self.size_by_entrypoint):
+ if self.size_by_entrypoint and entrypoint_name in self.size_by_entrypoint:
size = self.size_by_entrypoint[entrypoint_name]
return size
- def groups(self, _db, include_sublanes=True, pagination=None, facets=None,
- search_engine=None, debug=False):
+ def groups(
+ self,
+ _db,
+ include_sublanes=True,
+ pagination=None,
+ facets=None,
+ search_engine=None,
+ debug=False,
+ ):
"""Return a list of (Work, Lane) 2-tuples
describing a sequence of featured items for this lane and
(optionally) its children.
@@ -3046,15 +3111,20 @@ def groups(self, _db, include_sublanes=True, pagination=None, facets=None,
# lane's restrictions. Lanes that don't inherit this lane's
# restrictions will need to be handled in a separate call to
# groups().
- queryable_lanes = [x for x in relevant_lanes
- if x == self or x.inherit_parent_restrictions]
+ queryable_lanes = [
+ x for x in relevant_lanes if x == self or x.inherit_parent_restrictions
+ ]
return self._groups_for_lanes(
- _db, relevant_lanes, queryable_lanes, pagination=pagination,
- facets=facets, search_engine=search_engine, debug=debug
+ _db,
+ relevant_lanes,
+ queryable_lanes,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
+ debug=debug,
)
- def search(self, _db, query_string, search_client, pagination=None,
- facets=None):
+ def search(self, _db, query_string, search_client, pagination=None, facets=None):
"""Find works in this lane that also match a search query.
:param _db: A database connection.
@@ -3073,8 +3143,7 @@ def search(self, _db, query_string, search_client, pagination=None,
# Tell that object to run the search.
m = search_target.search
- return m(_db, query_string, search_client, pagination,
- facets=facets)
+ return m(_db, query_string, search_client, pagination, facets=facets)
def explain(self):
"""Create a series of human-readable strings to explain a lane's settings."""
@@ -3082,39 +3151,53 @@ def explain(self):
lines.append("ID: %s" % self.id)
lines.append("Library: %s" % self.library.short_name)
if self.parent:
- lines.append("Parent ID: %s (%s)" % (self.parent.id, self.parent.display_name))
+ lines.append(
+ "Parent ID: %s (%s)" % (self.parent.id, self.parent.display_name)
+ )
lines.append("Priority: %s" % self.priority)
lines.append("Display name: %s" % self.display_name)
return lines
-Library.lanes = relationship("Lane", backref="library", foreign_keys=Lane.library_id, cascade='all, delete-orphan')
-DataSource.list_lanes = relationship("Lane", backref="_list_datasource", foreign_keys=Lane._list_datasource_id)
-DataSource.license_lanes = relationship("Lane", backref="license_datasource", foreign_keys=Lane.license_datasource_id)
+
+Library.lanes = relationship(
+ "Lane",
+ backref="library",
+ foreign_keys=Lane.library_id,
+ cascade="all, delete-orphan",
+)
+DataSource.list_lanes = relationship(
+ "Lane", backref="_list_datasource", foreign_keys=Lane._list_datasource_id
+)
+DataSource.license_lanes = relationship(
+ "Lane", backref="license_datasource", foreign_keys=Lane.license_datasource_id
+)
lanes_customlists = Table(
- 'lanes_customlists', Base.metadata,
- Column(
- 'lane_id', Integer, ForeignKey('lanes.id'),
- index=True, nullable=False
- ),
+ "lanes_customlists",
+ Base.metadata,
+ Column("lane_id", Integer, ForeignKey("lanes.id"), index=True, nullable=False),
Column(
- 'customlist_id', Integer, ForeignKey('customlists.id'),
- index=True, nullable=False
+ "customlist_id",
+ Integer,
+ ForeignKey("customlists.id"),
+ index=True,
+ nullable=False,
),
- UniqueConstraint('lane_id', 'customlist_id'),
+ UniqueConstraint("lane_id", "customlist_id"),
)
-@event.listens_for(Lane, 'after_insert')
-@event.listens_for(Lane, 'after_delete')
-@event.listens_for(LaneGenre, 'after_insert')
-@event.listens_for(LaneGenre, 'after_delete')
+
+@event.listens_for(Lane, "after_insert")
+@event.listens_for(Lane, "after_delete")
+@event.listens_for(LaneGenre, "after_insert")
+@event.listens_for(LaneGenre, "after_delete")
def configuration_relevant_lifecycle_event(mapper, connection, target):
site_configuration_has_changed(target)
-@event.listens_for(Lane, 'after_update')
-@event.listens_for(LaneGenre, 'after_update')
+@event.listens_for(Lane, "after_update")
+@event.listens_for(LaneGenre, "after_update")
def configuration_relevant_update(mapper, connection, target):
if directly_modified(target):
site_configuration_has_changed(target)
diff --git a/lcp/credential.py b/lcp/credential.py
index 709ca09f8..e4200bd98 100644
--- a/lcp/credential.py
+++ b/lcp/credential.py
@@ -1,17 +1,16 @@
import logging
-
from enum import Enum
-from .exceptions import LCPError
from ..model import Credential, DataSource
+from .exceptions import LCPError
class LCPCredentialType(Enum):
"""Contains an enumeration of different LCP credential types"""
- PATRON_ID = 'Patron ID passed to the LCP License Server'
- LCP_PASSPHRASE = 'LCP Passphrase passed to the LCP License Server'
- LCP_HASHED_PASSPHRASE = 'Hashed LCP Passphrase passed to the LCP License Server'
+ PATRON_ID = "Patron ID passed to the LCP License Server"
+ LCP_PASSPHRASE = "LCP Passphrase passed to the LCP License Server"
+ LCP_HASHED_PASSPHRASE = "Hashed LCP Passphrase passed to the LCP License Server"
class LCPCredentialFactory(object):
@@ -21,7 +20,9 @@ def __init__(self):
"""Initializes a new instance of LCPCredentialFactory class"""
self._logger = logging.getLogger(__name__)
- def _get_or_create_persistent_token(self, db, patron, data_source_type, credential_type, value=None):
+ def _get_or_create_persistent_token(
+ self, db, patron, data_source_type, credential_type, value=None
+ ):
"""Gets or creates a new persistent token
:param db: Database session
@@ -42,17 +43,19 @@ def _get_or_create_persistent_token(self, db, patron, data_source_type, credenti
data_source = DataSource.lookup(db, data_source_type)
transaction = db.begin_nested()
- credential, is_new = Credential.persistent_token_create(db, data_source, credential_type, patron, value)
+ credential, is_new = Credential.persistent_token_create(
+ db, data_source, credential_type, patron, value
+ )
transaction.commit()
self._logger.info(
'Successfully {0} "{1}" {2} for {3} in "{4}" data source with value "{5}"'.format(
- 'created new' if is_new else 'fetched existing',
+ "created new" if is_new else "fetched existing",
credential_type,
credential,
patron,
data_source_type,
- value
+ value,
)
)
@@ -71,7 +74,11 @@ def get_patron_id(self, db, patron):
:rtype: string
"""
patron_id, _ = self._get_or_create_persistent_token(
- db, patron, DataSource.INTERNAL_PROCESSING, LCPCredentialType.PATRON_ID.value)
+ db,
+ patron,
+ DataSource.INTERNAL_PROCESSING,
+ LCPCredentialType.PATRON_ID.value,
+ )
return patron_id
@@ -88,7 +95,11 @@ def get_patron_passphrase(self, db, patron):
:rtype: string
"""
patron_passphrase, _ = self._get_or_create_persistent_token(
- db, patron, DataSource.INTERNAL_PROCESSING, LCPCredentialType.LCP_PASSPHRASE.value)
+ db,
+ patron,
+ DataSource.INTERNAL_PROCESSING,
+ LCPCredentialType.LCP_PASSPHRASE.value,
+ )
return patron_passphrase
@@ -105,10 +116,14 @@ def get_hashed_passphrase(self, db, patron):
:rtype: string
"""
hashed_passphrase, is_new = self._get_or_create_persistent_token(
- db, patron, DataSource.INTERNAL_PROCESSING, LCPCredentialType.LCP_HASHED_PASSPHRASE.value)
+ db,
+ patron,
+ DataSource.INTERNAL_PROCESSING,
+ LCPCredentialType.LCP_HASHED_PASSPHRASE.value,
+ )
if is_new:
- raise LCPError('Passphrase have to be explicitly set')
+ raise LCPError("Passphrase have to be explicitly set")
return hashed_passphrase
@@ -125,4 +140,9 @@ def set_hashed_passphrase(self, db, patron, hashed_passphrase):
:type hashed_passphrase: string
"""
self._get_or_create_persistent_token(
- db, patron, DataSource.INTERNAL_PROCESSING, LCPCredentialType.LCP_HASHED_PASSPHRASE.value, hashed_passphrase)
+ db,
+ patron,
+ DataSource.INTERNAL_PROCESSING,
+ LCPCredentialType.LCP_HASHED_PASSPHRASE.value,
+ hashed_passphrase,
+ )
diff --git a/local_analytics_provider.py b/local_analytics_provider.py
index b55c9e94a..28bea1aea 100644
--- a/local_analytics_provider.py
+++ b/local_analytics_provider.py
@@ -1,11 +1,7 @@
from flask_babel import lazy_gettext as _
-from .model import (
- Session,
- CirculationEvent,
- ExternalIntegration,
- get_one,
- create
-)
+
+from .model import CirculationEvent, ExternalIntegration, Session, create, get_one
+
class LocalAnalyticsProvider(object):
NAME = _("Local Analytics")
@@ -29,28 +25,42 @@ class LocalAnalyticsProvider(object):
{
"key": LOCATION_SOURCE,
"label": _("Geographic location of events"),
- "description": _("Local analytics events may have a geographic location associated with them. How should the location be determined?
Note: to use the patron's neighborhood as the event location, you must also tell your patron authentication mechanism how to gather a patron's neighborhood information."),
+ "description": _(
+ "Local analytics events may have a geographic location associated with them. How should the location be determined?
Note: to use the patron's neighborhood as the event location, you must also tell your patron authentication mechanism how to gather a patron's neighborhood information."
+ ),
"default": LOCATION_SOURCE_DISABLED,
"type": "select",
"options": [
- { "key": LOCATION_SOURCE_DISABLED, "label": _("Disable this feature.") },
- { "key": LOCATION_SOURCE_NEIGHBORHOOD, "label": _("Use the patron's neighborhood as the event location.") },
+ {"key": LOCATION_SOURCE_DISABLED, "label": _("Disable this feature.")},
+ {
+ "key": LOCATION_SOURCE_NEIGHBORHOOD,
+ "label": _("Use the patron's neighborhood as the event location."),
+ },
],
},
]
def __init__(self, integration, library=None):
self.integration_id = integration.id
- self.location_source = integration.setting(
- self.LOCATION_SOURCE
- ).value or self.LOCATION_SOURCE_DISABLED
+ self.location_source = (
+ integration.setting(self.LOCATION_SOURCE).value
+ or self.LOCATION_SOURCE_DISABLED
+ )
if library:
self.library_id = library.id
else:
self.library_id = None
- def collect_event(self, library, license_pool, event_type, time,
- old_value=None, new_value=None, **kwargs):
+ def collect_event(
+ self,
+ library,
+ license_pool,
+ event_type,
+ time,
+ old_value=None,
+ new_value=None,
+ **kwargs
+ ):
if not library and not license_pool:
raise ValueError("Either library or license_pool must be provided.")
if library:
@@ -65,20 +75,26 @@ def collect_event(self, library, license_pool, event_type, time,
neighborhood = kwargs.pop("neighborhood", None)
return CirculationEvent.log(
- _db, license_pool, event_type, old_value, new_value, start=time,
- library=library, location=neighborhood
+ _db,
+ license_pool,
+ event_type,
+ old_value,
+ new_value,
+ start=time,
+ library=library,
+ location=neighborhood,
)
@classmethod
def initialize(cls, _db):
- """Find or create a local analytics service.
- """
+ """Find or create a local analytics service."""
# If a local analytics service already exists, return it.
local_analytics = get_one(
- _db, ExternalIntegration,
+ _db,
+ ExternalIntegration,
protocol=cls.__module__,
- goal=ExternalIntegration.ANALYTICS_GOAL
+ goal=ExternalIntegration.ANALYTICS_GOAL,
)
# If a local analytics service already exists, don't create a
@@ -86,11 +102,13 @@ def initialize(cls, _db):
# "Local Analytics".
if not local_analytics:
local_analytics, ignore = create(
- _db, ExternalIntegration,
+ _db,
+ ExternalIntegration,
protocol=cls.__module__,
goal=ExternalIntegration.ANALYTICS_GOAL,
- name=str(cls.NAME)
+ name=str(cls.NAME),
)
return local_analytics
+
Provider = LocalAnalyticsProvider
diff --git a/log.py b/log.py
index eb07a1757..266a7915e 100644
--- a/log.py
+++ b/log.py
@@ -1,19 +1,19 @@
-
-import logging
import json
+import logging
import os
import socket
-from flask_babel import lazy_gettext as _
from io import StringIO
+
+from boto3.session import Session as AwsSession
+from flask_babel import lazy_gettext as _
from loggly.handlers import HTTPSHandler as LogglyHandler
from watchtower import CloudWatchLogHandler
-from boto3.session import Session as AwsSession
-from .config import Configuration
-from .config import CannotLoadConfiguration
-from .model import ExternalIntegration, ConfigurationSetting
+from .config import CannotLoadConfiguration, Configuration
+from .model import ConfigurationSetting, ExternalIntegration
from .util.datetime_helpers import utc_now
+
class JSONFormatter(logging.Formatter):
hostname = socket.gethostname()
fqdn = socket.getfqdn()
@@ -33,12 +33,11 @@ def ensure_str(s):
if isinstance(s, bytes):
s = s.decode("utf-8")
return s
+
message = ensure_str(record.msg)
if record.args:
- record_args = tuple(
- [ensure_str(arg) for arg in record.args]
- )
+ record_args = tuple([ensure_str(arg) for arg in record.args])
try:
message = message % record_args
except Exception as e:
@@ -47,8 +46,9 @@ def ensure_str(s):
# code shouldn't break the code that actually does the
# work, but we can't just let this slide -- we need to
# report the problem so it can be fixed.
- message = "Log message could not be formatted. Exception: %r. Original message: message=%r args=%r" % (
- e, message, record_args
+ message = (
+ "Log message could not be formatted. Exception: %r. Original message: message=%r args=%r"
+ % (e, message, record_args)
)
data = dict(
host=self.hostname,
@@ -57,16 +57,16 @@ def ensure_str(s):
level=record.levelname,
filename=record.filename,
message=message,
- timestamp=utc_now().isoformat()
+ timestamp=utc_now().isoformat(),
)
if record.exc_info:
- data['traceback'] = self.formatException(record.exc_info)
+ data["traceback"] = self.formatException(record.exc_info)
return json.dumps(data)
class StringFormatter(logging.Formatter):
- """Encode all output as a string.
- """
+ """Encode all output as a string."""
+
def format(self, record):
data = super(StringFormatter, self).format(record)
return str(data)
@@ -75,14 +75,18 @@ def format(self, record):
class Logger(object):
"""Abstract base class for logging"""
- DEFAULT_APP_NAME = 'simplified'
+ DEFAULT_APP_NAME = "simplified"
- JSON_LOG_FORMAT = 'json'
- TEXT_LOG_FORMAT = 'text'
- DEFAULT_MESSAGE_TEMPLATE = "%(asctime)s:%(name)s:%(levelname)s:%(filename)s:%(message)s"
+ JSON_LOG_FORMAT = "json"
+ TEXT_LOG_FORMAT = "text"
+ DEFAULT_MESSAGE_TEMPLATE = (
+ "%(asctime)s:%(name)s:%(levelname)s:%(filename)s:%(message)s"
+ )
@classmethod
- def set_formatter(cls, handler, app_name=None, log_format=None, message_template=None):
+ def set_formatter(
+ cls, handler, app_name=None, log_format=None, message_template=None
+ ):
"""Tell the given `handler` to format its log messages in a
certain way.
"""
@@ -103,27 +107,31 @@ def from_configuration(cls, _db, testing=False):
"""Should be implemented in each logging class."""
raise NotImplementedError()
+
class SysLogger(Logger):
- NAME = 'sysLog'
+ NAME = "sysLog"
# Settings for the integration with protocol=INTERNAL_LOGGING
- LOG_FORMAT = 'log_format'
- LOG_MESSAGE_TEMPLATE = 'message_template'
+ LOG_FORMAT = "log_format"
+ LOG_MESSAGE_TEMPLATE = "message_template"
SETTINGS = [
{
- "key": LOG_FORMAT, "label": _("Log Format"), "type": "select",
+ "key": LOG_FORMAT,
+ "label": _("Log Format"),
+ "type": "select",
"options": [
- { "key": Logger.JSON_LOG_FORMAT, "label": _("json") },
- { "key": Logger.TEXT_LOG_FORMAT, "label": _("text") }
- ]
+ {"key": Logger.JSON_LOG_FORMAT, "label": _("json")},
+ {"key": Logger.TEXT_LOG_FORMAT, "label": _("text")},
+ ],
},
{
- "key": LOG_MESSAGE_TEMPLATE, "label": _("template"),
+ "key": LOG_MESSAGE_TEMPLATE,
+ "label": _("template"),
"default": Logger.DEFAULT_MESSAGE_TEMPLATE,
"required": True,
- }
+ },
]
SITEWIDE = True
@@ -151,32 +159,39 @@ def from_configuration(cls, _db, testing=False):
if internal:
internal_log_format = (
- internal.setting(cls.LOG_FORMAT).value
- or internal_log_format
+ internal.setting(cls.LOG_FORMAT).value or internal_log_format
)
message_template = (
- internal.setting(cls.LOG_MESSAGE_TEMPLATE).value
- or message_template
+ internal.setting(cls.LOG_MESSAGE_TEMPLATE).value or message_template
+ )
+ app_name = (
+ ConfigurationSetting.sitewide(_db, Configuration.LOG_APP_NAME).value
+ or app_name
)
- app_name = ConfigurationSetting.sitewide(_db, Configuration.LOG_APP_NAME).value or app_name
handler = logging.StreamHandler()
- cls.set_formatter(handler, log_format=internal_log_format, message_template=message_template, app_name=app_name)
+ cls.set_formatter(
+ handler,
+ log_format=internal_log_format,
+ message_template=message_template,
+ app_name=app_name,
+ )
return handler
+
class Loggly(Logger):
NAME = "Loggly"
DEFAULT_LOGGLY_URL = "https://logs-01.loggly.com/inputs/%(token)s/tag/python/"
- USER = 'user'
- PASSWORD = 'password'
- URL = 'url'
+ USER = "user"
+ PASSWORD = "password"
+ URL = "url"
SETTINGS = [
- { "key": USER, "label": _("Username"), "required": True },
- { "key": PASSWORD, "label": _("Password"), "required": True },
- { "key": URL, "label": _("URL"), "required": True, "format": "url" },
+ {"key": USER, "label": _("Username"), "required": True},
+ {"key": PASSWORD, "label": _("Password"), "required": True},
+ {"key": URL, "label": _("URL"), "required": True, "format": "url"},
]
SITEWIDE = True
@@ -184,15 +199,16 @@ class Loggly(Logger):
@classmethod
def from_configuration(cls, _db, testing=False):
loggly = None
- from .model import (ExternalIntegration, ConfigurationSetting)
+ from .model import ConfigurationSetting, ExternalIntegration
app_name = cls.DEFAULT_APP_NAME
if _db and not testing:
goal = ExternalIntegration.LOGGING_GOAL
- loggly = ExternalIntegration.lookup(
- _db, ExternalIntegration.LOGGLY, goal
+ loggly = ExternalIntegration.lookup(_db, ExternalIntegration.LOGGLY, goal)
+ app_name = (
+ ConfigurationSetting.sitewide(_db, Configuration.LOG_APP_NAME).value
+ or app_name
)
- app_name = ConfigurationSetting.sitewide(_db, Configuration.LOG_APP_NAME).value or app_name
if loggly:
loggly = Loggly.loggly_handler(loggly)
@@ -202,8 +218,7 @@ def from_configuration(cls, _db, testing=False):
@classmethod
def loggly_handler(cls, externalintegration):
- """Turn a Loggly ExternalIntegration into a log handler.
- """
+ """Turn a Loggly ExternalIntegration into a log handler."""
token = externalintegration.password
url = externalintegration.url or cls.DEFAULT_LOGGLY_URL
if not url:
@@ -214,17 +229,19 @@ def loggly_handler(cls, externalintegration):
url = cls._interpolate_loggly_url(url, token)
except (TypeError, KeyError) as e:
raise CannotLoadConfiguration(
- "Cannot interpolate token %s into loggly URL %s" % (
- token, url,
+ "Cannot interpolate token %s into loggly URL %s"
+ % (
+ token,
+ url,
)
)
return LogglyHandler(url)
@classmethod
def _interpolate_loggly_url(cls, url, token):
- if '%s' in url:
+ if "%s" in url:
return url % token
- if '%(' in url:
+ if "%(" in url:
return url % dict(token=token)
# Assume the token is already in the URL.
@@ -238,38 +255,39 @@ def set_formatter(cls, handler, app_name):
formatter = JSONFormatter(app_name)
handler.setFormatter(formatter)
+
class CloudwatchLogs(Logger):
NAME = "AWS Cloudwatch Logs"
- GROUP = 'group'
- STREAM = 'stream'
- INTERVAL = 'interval'
- CREATE_GROUP = 'create_group'
- REGION = 'region'
- DEFAULT_REGION = 'us-west-2'
+ GROUP = "group"
+ STREAM = "stream"
+ INTERVAL = "interval"
+ CREATE_GROUP = "create_group"
+ REGION = "region"
+ DEFAULT_REGION = "us-west-2"
DEFAULT_INTERVAL = 60
DEFAULT_CREATE_GROUP = "TRUE"
# https://docs.aws.amazon.com/general/latest/gr/rande.html#cwl_region
REGIONS = [
- {"key": "us-east-2", "label": _("US East (Ohio)")},
- {"key": "us-east-1", "label": _("US East (N. Virginia)")},
- {"key": "us-west-1", "label": _("US West (N. California)")},
- {"key": "us-west-2", "label": _("US West (Oregon)")},
- {"key": "ap-south-1", "label": _("Asia Pacific (Mumbai)")},
+ {"key": "us-east-2", "label": _("US East (Ohio)")},
+ {"key": "us-east-1", "label": _("US East (N. Virginia)")},
+ {"key": "us-west-1", "label": _("US West (N. California)")},
+ {"key": "us-west-2", "label": _("US West (Oregon)")},
+ {"key": "ap-south-1", "label": _("Asia Pacific (Mumbai)")},
{"key": "ap-northeast-3", "label": _("Asia Pacific (Osaka-Local)")},
{"key": "ap-northeast-2", "label": _("Asia Pacific (Seoul)")},
{"key": "ap-southeast-1", "label": _("Asia Pacific (Singapore)")},
{"key": "ap-southeast-2", "label": _("Asia Pacific (Sydney)")},
{"key": "ap-northeast-1", "label": _("Asia Pacific (Tokyo)")},
- {"key": "ca-central-1", "label": _("Canada (Central)")},
- {"key": "cn-north-1", "label": _("China (Beijing)")},
+ {"key": "ca-central-1", "label": _("Canada (Central)")},
+ {"key": "cn-north-1", "label": _("China (Beijing)")},
{"key": "cn-northwest-1", "label": _("China (Ningxia)")},
- {"key": "eu-central-1", "label": _("EU (Frankfurt)")},
- {"key": "eu-west-1", "label": _("EU (Ireland)")},
- {"key": "eu-west-2", "label": _("EU (London)")},
- {"key": "eu-west-3", "label": _("EU (Paris)")},
- {"key": "sa-east-1", "label": _("South America (Sao Paulo)")},
+ {"key": "eu-central-1", "label": _("EU (Frankfurt)")},
+ {"key": "eu-west-1", "label": _("EU (Ireland)")},
+ {"key": "eu-west-2", "label": _("EU (London)")},
+ {"key": "eu-west-3", "label": _("EU (Paris)")},
+ {"key": "sa-east-1", "label": _("South America (Sao Paulo)")},
]
SETTINGS = [
@@ -304,8 +322,8 @@ class CloudwatchLogs(Logger):
"label": _("Automatically Create Log Group"),
"type": "select",
"options": [
- { "key": "TRUE", "label": _("Yes") },
- { "key": "FALSE", "label": _("No") },
+ {"key": "TRUE", "label": _("Yes")},
+ {"key": "FALSE", "label": _("No")},
],
"default": True,
"required": True,
@@ -325,7 +343,10 @@ def from_configuration(cls, _db, testing=False):
settings = ExternalIntegration.lookup(
_db, ExternalIntegration.CLOUDWATCH, goal
)
- app_name = ConfigurationSetting.sitewide(_db, Configuration.LOG_APP_NAME).value or app_name
+ app_name = (
+ ConfigurationSetting.sitewide(_db, Configuration.LOG_APP_NAME).value
+ or app_name
+ )
if settings:
cloudwatch = cls.get_handler(settings, testing)
@@ -335,13 +356,14 @@ def from_configuration(cls, _db, testing=False):
@classmethod
def get_handler(cls, settings, testing=False):
- """Turn ExternalIntegration into a log handler.
- """
+ """Turn ExternalIntegration into a log handler."""
group = settings.setting(cls.GROUP).value or cls.DEFAULT_APP_NAME
stream = settings.setting(cls.STREAM).value or cls.DEFAULT_APP_NAME
interval = settings.setting(cls.INTERVAL).value or cls.DEFAULT_INTERVAL
region = settings.setting(cls.REGION).value or cls.DEFAULT_REGION
- create_group = settings.setting(cls.CREATE_GROUP).value or cls.DEFAULT_CREATE_GROUP
+ create_group = (
+ settings.setting(cls.CREATE_GROUP).value or cls.DEFAULT_CREATE_GROUP
+ )
try:
interval = int(interval)
@@ -359,16 +381,18 @@ def get_handler(cls, settings, testing=False):
stream_name=stream,
send_interval=interval,
boto3_session=session,
- create_log_group=create_group == "TRUE"
+ create_log_group=create_group == "TRUE",
)
# Add a filter that makes sure no messages from botocore are processed by
# the cloudwatch logs integration, as these messages can lead to an infinite loop.
class BotoFilter(logging.Filter):
def filter(self, record):
- return not record.name.startswith('botocore')
+ return not record.name.startswith("botocore")
+
handler.addFilter(BotoFilter())
return handler
+
class LogConfiguration(object):
"""Configures the active Python logging handlers based on logging
configuration from the database.
@@ -381,34 +405,47 @@ class LogConfiguration(object):
# The default value to put into the 'app' field of JSON-format logs,
# unless LOG_APP_NAME overrides it.
- DEFAULT_APP_NAME = 'simplified'
- LOG_APP_NAME = 'log_app'
+ DEFAULT_APP_NAME = "simplified"
+ LOG_APP_NAME = "log_app"
DEFAULT_LOG_LEVEL = INFO
DEFAULT_DATABASE_LOG_LEVEL = WARN
# Settings for the integration with protocol=INTERNAL_LOGGING
- LOG_LEVEL = 'log_level'
- DATABASE_LOG_LEVEL = 'database_log_level'
+ LOG_LEVEL = "log_level"
+ DATABASE_LOG_LEVEL = "database_log_level"
LOG_LEVEL_UI = [
- { "key": DEBUG, "value": _("Debug") },
- { "key": INFO, "value": _("Info") },
- { "key": WARN, "value": _("Warn") },
- { "key": ERROR, "value": _("Error") },
+ {"key": DEBUG, "value": _("Debug")},
+ {"key": INFO, "value": _("Info")},
+ {"key": WARN, "value": _("Warn")},
+ {"key": ERROR, "value": _("Error")},
]
SITEWIDE_SETTINGS = [
- { "key": LOG_LEVEL, "label": _("Log Level"), "type": "select",
- "options": LOG_LEVEL_UI, "default": INFO,
+ {
+ "key": LOG_LEVEL,
+ "label": _("Log Level"),
+ "type": "select",
+ "options": LOG_LEVEL_UI,
+ "default": INFO,
},
- { "key": LOG_APP_NAME, "label": _("Log Application name"),
- "description": _("Log messages originating from this application will be tagged with this name. If you run multiple instances, giving each one a different application name will help you determine which instance is having problems."),
- "default": DEFAULT_APP_NAME,
+ {
+ "key": LOG_APP_NAME,
+ "label": _("Log Application name"),
+ "description": _(
+ "Log messages originating from this application will be tagged with this name. If you run multiple instances, giving each one a different application name will help you determine which instance is having problems."
+ ),
+ "default": DEFAULT_APP_NAME,
},
- { "key": DATABASE_LOG_LEVEL, "label": _("Database Log Level"),
- "type": "select", "options": LOG_LEVEL_UI,
- "description": _("Database logs are extremely verbose, so unless you're diagnosing a database-related problem, it's a good idea to set a higher log level for database messages."),
- "default": WARN,
+ {
+ "key": DATABASE_LOG_LEVEL,
+ "label": _("Database Log Level"),
+ "type": "select",
+ "options": LOG_LEVEL_UI,
+ "description": _(
+ "Database logs are extremely verbose, so unless you're diagnosing a database-related problem, it's a good idea to set a higher log level for database messages."
+ ),
+ "default": WARN,
},
]
@@ -422,8 +459,8 @@ def initialize(cls, _db, testing=False):
:param testing: True if unit tests are currently running; otherwise False.
"""
- log_level, database_log_level, new_handlers, errors = (
- cls.from_configuration(_db, testing)
+ log_level, database_log_level, new_handlers, errors = cls.from_configuration(
+ _db, testing
)
# Replace the set of handlers associated with the root logger.
@@ -439,9 +476,10 @@ def initialize(cls, _db, testing=False):
# Set the loggers for various verbose libraries to the database
# log level, which is probably higher than the normal log level.
for logger in (
- 'sqlalchemy.engine', 'elasticsearch',
- 'requests.packages.urllib3.connectionpool',
- 'botocore'
+ "sqlalchemy.engine",
+ "elasticsearch",
+ "requests.packages.urllib3.connectionpool",
+ "botocore",
):
logging.getLogger(logger).setLevel(database_log_level)
@@ -453,7 +491,7 @@ def initialize(cls, _db, testing=False):
loop_prevention_log_level = cls.ERROR
else:
loop_prevention_log_level = cls.WARN
- for logger in ['urllib3.connectionpool']:
+ for logger in ["urllib3.connectionpool"]:
logging.getLogger(logger).setLevel(loop_prevention_log_level)
# If we had an error creating any log handlers report it
@@ -491,7 +529,9 @@ def from_configuration(cls, _db, testing=False):
or log_level
)
database_log_level = (
- ConfigurationSetting.sitewide(_db, Configuration.DATABASE_LOG_LEVEL).value
+ ConfigurationSetting.sitewide(
+ _db, Configuration.DATABASE_LOG_LEVEL
+ ).value
or database_log_level
)
@@ -505,8 +545,6 @@ def from_configuration(cls, _db, testing=False):
if handler:
handlers.append(handler)
except Exception as e:
- errors.append(
- "Error creating logger %s %s" % (logger.NAME, str(e))
- )
+ errors.append("Error creating logger %s %s" % (logger.NAME, str(e)))
return log_level, database_log_level, handlers, errors
diff --git a/marc.py b/marc.py
index 87057bf0d..4700f76ad 100644
--- a/marc.py
+++ b/marc.py
@@ -1,25 +1,15 @@
-
+import re
from io import BytesIO
+
from flask_babel import lazy_gettext as _
-import re
-from pymarc import (
- Field,
- Record,
- MARCWriter
-)
+from pymarc import Field, MARCWriter, Record
-from .config import (
- Configuration,
- CannotLoadConfiguration,
-)
-from .lane import BaseFacets
-from .external_search import (
- ExternalSearchIndex,
- SortKeyPagination,
-)
+from .classifier import Classifier
+from .config import CannotLoadConfiguration, Configuration
+from .external_search import ExternalSearchIndex, SortKeyPagination
+from .lane import BaseFacets, Lane
+from .mirror import MirrorUploader
from .model import (
- get_one,
- get_one_or_create,
CachedMARCFile,
Collection,
ConfigurationSetting,
@@ -30,14 +20,14 @@
Representation,
Session,
Work,
+ get_one,
+ get_one_or_create,
)
-from .classifier import Classifier
-from .mirror import MirrorUploader
from .s3 import S3Uploader
-from .lane import Lane
from .util import LanguageCodes
from .util.datetime_helpers import utc_now
+
class Annotator(object):
"""The Annotator knows how to add information about a Work to
a MARC record."""
@@ -57,13 +47,24 @@ class Annotator(object):
# There doesn't seem to be any particular vocabulary for this.
FORMAT_TERMS = {
(Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM): "EPUB eBook",
- (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM): "Adobe EPUB eBook",
+ (
+ Representation.EPUB_MEDIA_TYPE,
+ DeliveryMechanism.ADOBE_DRM,
+ ): "Adobe EPUB eBook",
(Representation.PDF_MEDIA_TYPE, DeliveryMechanism.NO_DRM): "PDF eBook",
(Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM): "Adobe PDF eBook",
}
- def annotate_work_record(self, work, active_license_pool, edition,
- identifier, record, integration=None, updated=None):
+ def annotate_work_record(
+ self,
+ work,
+ active_license_pool,
+ edition,
+ identifier,
+ record,
+ integration=None,
+ updated=None,
+ ):
"""Add metadata from this work to a MARC record.
:work: The Work whose record is being annotated.
@@ -83,16 +84,18 @@ def leader(cls, work):
# The record length is automatically updated once fields are added.
initial_record_length = "00000"
- record_status = "n" # New record
+ record_status = "n" # New record
if getattr(work, cls.marc_cache_field):
- record_status = "c" # Corrected or revised
+ record_status = "c" # Corrected or revised
# Distributors consistently seem to use type "a" - language material - for
# ebooks, though there is also type "m" for computer files.
record_type = "a"
- bibliographic_level = "m" # Monograph/item
-
- leader = initial_record_length + record_status + record_type + bibliographic_level
+ bibliographic_level = "m" # Monograph/item
+
+ leader = (
+ initial_record_length + record_status + record_type + bibliographic_level
+ )
# Additional information about the record that's always the same.
leader += " 2200000 4500"
return leader
@@ -100,17 +103,14 @@ def leader(cls, work):
@classmethod
def add_control_fields(cls, record, identifier, pool, edition):
# Unique identifier for this record.
- record.add_field(
- Field(tag="001", data=identifier.urn))
+ record.add_field(Field(tag="001", data=identifier.urn))
# Field 003 (MARC organization code) is library-specific, so it's added separately.
- record.add_field(
- Field(tag="005", data=utc_now().strftime("%Y%m%d%H%M%S.0")))
+ record.add_field(Field(tag="005", data=utc_now().strftime("%Y%m%d%H%M%S.0")))
# Field 006: m = computer file, d = the file is a document
- record.add_field(
- Field(tag="006", data="m d "))
+ record.add_field(Field(tag="006", data="m d "))
# Field 007: more details about electronic resource
# Since this depends on the pool, it might be better not to cache it.
@@ -121,17 +121,18 @@ def add_control_fields(cls, record, identifier, pool, edition):
else:
file_formats_code = "m"
record.add_field(
- Field(tag="007", data="cr cn ---" + file_formats_code + "nuuu"))
+ Field(tag="007", data="cr cn ---" + file_formats_code + "nuuu")
+ )
# Field 008 (fixed-length data elements):
data = utc_now().strftime("%y%m%d")
publication_date = edition.issued or edition.published
if publication_date:
- date_type = "s" # single known date
+ date_type = "s" # single known date
# Not using strftime because some years are pre-1900.
date_value = "%04i" % publication_date.year
else:
- date_type = "n" # dates unknown
+ date_type = "n" # dates unknown
date_value = " "
data += date_type + date_value
data += " "
@@ -144,13 +145,11 @@ def add_control_fields(cls, record, identifier, pool, edition):
language = LanguageCodes.string_to_alpha_3(edition.language)
data += language
data += " "
- record.add_field(
- Field(tag="008", data=data))
+ record.add_field(Field(tag="008", data=data))
@classmethod
def add_marc_organization_code(cls, record, marc_org):
- record.add_field(
- Field(tag="003", data=marc_org))
+ record.add_field(Field(tag="003", data=marc_org))
@classmethod
def add_isbn(cls, record, identifier):
@@ -161,18 +160,24 @@ def add_isbn(cls, record, identifier):
if not isbn:
_db = Session.object_session(identifier)
identifier_ids = identifier.equivalent_identifier_ids()[identifier.id]
- isbn = _db.query(Identifier).filter(
- Identifier.type==Identifier.ISBN).filter(
- Identifier.id.in_(identifier_ids)).order_by(
- Identifier.id).first()
+ isbn = (
+ _db.query(Identifier)
+ .filter(Identifier.type == Identifier.ISBN)
+ .filter(Identifier.id.in_(identifier_ids))
+ .order_by(Identifier.id)
+ .first()
+ )
if isbn:
record.add_field(
Field(
tag="020",
- indicators=[" "," "],
+ indicators=[" ", " "],
subfields=[
- "a", isbn.identifier,
- ]))
+ "a",
+ isbn.identifier,
+ ],
+ )
+ )
@classmethod
def add_title(cls, record, edition):
@@ -181,7 +186,7 @@ def add_title(cls, record, edition):
# the title and the sort_title.
non_filing_characters = 0
if edition.title != edition.sort_title and ("," in edition.sort_title):
- stemmed = edition.sort_title[:edition.sort_title.rindex(",")]
+ stemmed = edition.sort_title[: edition.sort_title.rindex(",")]
non_filing_characters = edition.title.index(stemmed)
# MARC only supports up to 9 non-filing characters, but if we got more
# something is probably wrong anyway.
@@ -198,7 +203,8 @@ def add_title(cls, record, edition):
tag="245",
indicators=["0", non_filing_characters],
subfields=subfields,
- ))
+ )
+ )
@classmethod
def add_contributors(cls, record, edition):
@@ -213,10 +219,13 @@ def add_contributors(cls, record, edition):
record.add_field(
Field(
tag="100",
- indicators=["1"," "],
+ indicators=["1", " "],
subfields=[
- "a", str(edition.sort_author),
- ]))
+ "a",
+ str(edition.sort_author),
+ ],
+ )
+ )
if len(edition.contributions) > 1:
for contribution in edition.contributions:
@@ -226,9 +235,13 @@ def add_contributors(cls, record, edition):
tag="700",
indicators=["1", " "],
subfields=[
- "a", str(contributor.sort_name),
- "e", contribution.role,
- ]))
+ "a",
+ str(contributor.sort_name),
+ "e",
+ contribution.role,
+ ],
+ )
+ )
@classmethod
def add_publisher(cls, record, edition):
@@ -242,10 +255,15 @@ def add_publisher(cls, record, edition):
tag="264",
indicators=[" ", "1"],
subfields=[
- "a", "[Place of publication not identified]",
- "b", str(edition.publisher or ""),
- "c", year,
- ]))
+ "a",
+ "[Place of publication not identified]",
+ "b",
+ str(edition.publisher or ""),
+ "c",
+ year,
+ ],
+ )
+ )
@classmethod
def add_distributor(cls, record, pool):
@@ -255,8 +273,11 @@ def add_distributor(cls, record, pool):
tag="264",
indicators=[" ", "2"],
subfields=[
- "b", str(pool.data_source.name),
- ]))
+ "b",
+ str(pool.data_source.name),
+ ],
+ )
+ )
@classmethod
def add_physical_description(cls, record, edition):
@@ -267,58 +288,63 @@ def add_physical_description(cls, record, edition):
tag="300",
indicators=[" ", " "],
subfields=[
- "a", "1 online resource",
- ]))
+ "a",
+ "1 online resource",
+ ],
+ )
+ )
record.add_field(
Field(
tag="336",
indicators=[" ", " "],
- subfields=[
- "a", "text",
- "b", "txt",
- "2", "rdacontent"
- ]))
+ subfields=["a", "text", "b", "txt", "2", "rdacontent"],
+ )
+ )
elif edition.medium == Edition.AUDIO_MEDIUM:
record.add_field(
Field(
tag="300",
indicators=[" ", " "],
subfields=[
- "a", "1 sound file",
- "b", "digital",
- ]))
+ "a",
+ "1 sound file",
+ "b",
+ "digital",
+ ],
+ )
+ )
record.add_field(
Field(
tag="336",
indicators=[" ", " "],
- subfields=[
- "a", "spoken word",
- "b", "spw",
- "2", "rdacontent"
- ]))
+ subfields=["a", "spoken word", "b", "spw", "2", "rdacontent"],
+ )
+ )
record.add_field(
Field(
tag="337",
indicators=[" ", " "],
- subfields=[
- "a", "computer",
- "b", "c",
- "2", "rdamedia"
- ]))
+ subfields=["a", "computer", "b", "c", "2", "rdamedia"],
+ )
+ )
record.add_field(
Field(
tag="338",
indicators=[" ", " "],
subfields=[
- "a", "online resource",
- "b", "cr",
- "2", "rdacarrier",
- ]))
-
+ "a",
+ "online resource",
+ "b",
+ "cr",
+ "2",
+ "rdacarrier",
+ ],
+ )
+ )
file_type = None
if edition.medium == Edition.BOOK_MEDIUM:
@@ -331,9 +357,13 @@ def add_physical_description(cls, record, edition):
tag="347",
indicators=[" ", " "],
subfields=[
- "a", file_type,
- "2", "rda",
- ]))
+ "a",
+ file_type,
+ "2",
+ "rda",
+ ],
+ )
+ )
# Form of work
form = None
@@ -348,9 +378,13 @@ def add_physical_description(cls, record, edition):
tag="380",
indicators=[" ", " "],
subfields=[
- "a", "eBook",
- "2", "tlcgt",
- ]))
+ "a",
+ "eBook",
+ "2",
+ "tlcgt",
+ ],
+ )
+ )
@classmethod
def add_audience(cls, record, work):
@@ -358,11 +392,15 @@ def add_audience(cls, record, work):
record.add_field(
Field(
tag="385",
- indicators=[" ", " "],
+ indicators=[" ", " "],
subfields=[
- "a", audience,
- "2", "tlctarget",
- ]))
+ "a",
+ audience,
+ "2",
+ "tlctarget",
+ ],
+ )
+ )
@classmethod
def add_series(cls, record, edition):
@@ -375,7 +413,8 @@ def add_series(cls, record, edition):
tag="490",
indicators=["0", " "],
subfields=subfields,
- ))
+ )
+ )
@classmethod
def add_system_details(cls, record):
@@ -383,9 +422,9 @@ def add_system_details(cls, record):
Field(
tag="538",
indicators=[" ", " "],
- subfields=[
- "a", "Mode of access: World Wide Web."
- ]))
+ subfields=["a", "Mode of access: World Wide Web."],
+ )
+ )
@classmethod
def add_formats(cls, record, pool):
@@ -398,24 +437,29 @@ def add_formats(cls, record, pool):
record.add_field(
Field(
tag="538",
- indicators=[" "," "],
+ indicators=[" ", " "],
subfields=[
- "a", format,
- ]))
-
+ "a",
+ format,
+ ],
+ )
+ )
@classmethod
def add_summary(cls, record, work):
summary = work.summary_text
if summary:
- stripped = re.sub('<[^>]+?>', ' ', summary)
+ stripped = re.sub("<[^>]+?>", " ", summary)
record.add_field(
Field(
tag="520",
indicators=[" ", " "],
subfields=[
- "a", stripped.encode('ascii', 'ignore'),
- ]))
+ "a",
+ stripped.encode("ascii", "ignore"),
+ ],
+ )
+ )
@classmethod
def add_simplified_genres(cls, record, work):
@@ -429,9 +473,13 @@ def add_simplified_genres(cls, record, work):
tag="650",
indicators=["0", "7"],
subfields=[
- "a", genre.name,
- "2", "Library Simplified",
- ]))
+ "a",
+ genre.name,
+ "2",
+ "Library Simplified",
+ ],
+ )
+ )
@classmethod
def add_ebooks_subject(cls, record):
@@ -441,8 +489,11 @@ def add_ebooks_subject(cls, record):
tag="655",
indicators=[" ", "0"],
subfields=[
- "a", "Electronic books.",
- ]))
+ "a",
+ "Electronic books.",
+ ],
+ )
+ )
class MARCExporterFacets(BaseFacets):
@@ -466,7 +517,9 @@ class MARCExporter(object):
NAME = ExternalIntegration.MARC_EXPORT
- DESCRIPTION = _("Export metadata into MARC files that can be imported into an ILS manually.")
+ DESCRIPTION = _(
+ "Export metadata into MARC files that can be imported into an ILS manually."
+ )
# This setting (in days) controls how often MARC files should be
# automatically updated. Since the crontab in docker isn't easily
@@ -480,65 +533,85 @@ class MARCExporter(object):
# http://www.loc.gov/marc/organizations/org-search.php
MARC_ORGANIZATION_CODE = "marc_organization_code"
- WEB_CLIENT_URL = 'marc_web_client_url'
- INCLUDE_SUMMARY = 'include_summary'
- INCLUDE_SIMPLIFIED_GENRES = 'include_simplified_genres'
+ WEB_CLIENT_URL = "marc_web_client_url"
+ INCLUDE_SUMMARY = "include_summary"
+ INCLUDE_SIMPLIFIED_GENRES = "include_simplified_genres"
LIBRARY_SETTINGS = [
- { "key": UPDATE_FREQUENCY,
- "label": _("Update frequency (in days)"),
- "description": _("The circulation manager will wait this number of days between generating MARC files."),
- "type": "number",
- "default": DEFAULT_UPDATE_FREQUENCY,
+ {
+ "key": UPDATE_FREQUENCY,
+ "label": _("Update frequency (in days)"),
+ "description": _(
+ "The circulation manager will wait this number of days between generating MARC files."
+ ),
+ "type": "number",
+ "default": DEFAULT_UPDATE_FREQUENCY,
},
- { "key": MARC_ORGANIZATION_CODE,
- "label": _("The MARC organization code for this library (003 field)."),
- "description": _("MARC organization codes are assigned by the Library of Congress."),
+ {
+ "key": MARC_ORGANIZATION_CODE,
+ "label": _("The MARC organization code for this library (003 field)."),
+ "description": _(
+ "MARC organization codes are assigned by the Library of Congress."
+ ),
},
{
- "key": WEB_CLIENT_URL,
- "label": _("The base URL for the web catalog for this library, for the 856 field."),
- "description": _("If using a library registry that provides a web catalog, this can be left blank."),
+ "key": WEB_CLIENT_URL,
+ "label": _(
+ "The base URL for the web catalog for this library, for the 856 field."
+ ),
+ "description": _(
+ "If using a library registry that provides a web catalog, this can be left blank."
+ ),
},
- { "key": INCLUDE_SUMMARY,
- "label": _("Include summaries in MARC records (520 field)"),
- "type": "select",
- "options": [
- { "key": "false", "label": _("Do not include summaries") },
- { "key": "true", "label": _("Include summaries") },
- ],
- "default": "false",
+ {
+ "key": INCLUDE_SUMMARY,
+ "label": _("Include summaries in MARC records (520 field)"),
+ "type": "select",
+ "options": [
+ {"key": "false", "label": _("Do not include summaries")},
+ {"key": "true", "label": _("Include summaries")},
+ ],
+ "default": "false",
},
- { "key": INCLUDE_SIMPLIFIED_GENRES,
- "label": _("Include Library Simplified genres in MARC records (650 fields)"),
- "type": "select",
- "options": [
- { "key": "false", "label": _("Do not include Library Simplified genres") },
- { "key": "true", "label": _("Include Library Simplified genres") },
- ],
- "default": "false",
+ {
+ "key": INCLUDE_SIMPLIFIED_GENRES,
+ "label": _(
+ "Include Library Simplified genres in MARC records (650 fields)"
+ ),
+ "type": "select",
+ "options": [
+ {
+ "key": "false",
+ "label": _("Do not include Library Simplified genres"),
+ },
+ {"key": "true", "label": _("Include Library Simplified genres")},
+ ],
+ "default": "false",
},
]
NO_MIRROR_INTEGRATION = "NO_MIRROR"
DEFAULT_MIRROR_INTEGRATION = dict(
- key=NO_MIRROR_INTEGRATION,
- label=_("None - Do not mirror MARC files")
+ key=NO_MIRROR_INTEGRATION, label=_("None - Do not mirror MARC files")
)
SETTING = {
"key": "mirror_integration_id",
"label": _("MARC Mirror"),
- "description": _("Storage protocol to use for uploading generated MARC files. The service must already be configured under 'Storage Services'."),
+ "description": _(
+ "Storage protocol to use for uploading generated MARC files. The service must already be configured under 'Storage Services'."
+ ),
"type": "select",
- "options" : [DEFAULT_MIRROR_INTEGRATION]
+ "options": [DEFAULT_MIRROR_INTEGRATION],
}
@classmethod
def from_config(cls, library):
_db = Session.object_session(library)
integration = ExternalIntegration.lookup(
- _db, ExternalIntegration.MARC_EXPORT,
- ExternalIntegration.CATALOG_GOAL, library=library
+ _db,
+ ExternalIntegration.MARC_EXPORT,
+ ExternalIntegration.CATALOG_GOAL,
+ library=library,
)
if not integration:
raise CannotLoadConfiguration(
@@ -550,26 +623,27 @@ def __init__(self, _db, library, integration):
self._db = _db
self.library = library
self.integration = integration
-
+
@classmethod
def get_storage_settings(cls, _db):
integrations = ExternalIntegration.for_goal(
_db, ExternalIntegration.STORAGE_GOAL
)
- cls.SETTING['options'] = [cls.DEFAULT_MIRROR_INTEGRATION]
+ cls.SETTING["options"] = [cls.DEFAULT_MIRROR_INTEGRATION]
for integration in integrations:
- # Only add an integration to choose from if it has a
+ # Only add an integration to choose from if it has a
# MARC File Bucket field in its settings.
- configuration_settings = [s for s in integration.settings if s.key=="marc_bucket"]
+ configuration_settings = [
+ s for s in integration.settings if s.key == "marc_bucket"
+ ]
if configuration_settings:
if configuration_settings[0].value:
- cls.SETTING['options'].append(
+ cls.SETTING["options"].append(
dict(key=str(integration.id), label=integration.name)
)
-
- return cls.SETTING
+ return cls.SETTING
@classmethod
def create_record(cls, work, annotator, force_create=False, integration=None):
@@ -613,12 +687,22 @@ def create_record(cls, work, annotator, force_create=False, integration=None):
setattr(work, annotator.marc_cache_field, data.decode("utf8"))
# Add additional fields that should not be cached.
- annotator.annotate_work_record(work, pool, edition, identifier, record, integration)
+ annotator.annotate_work_record(
+ work, pool, edition, identifier, record, integration
+ )
return record
- def records(self, lane, annotator, mirror_integration, start_time=None,
- force_refresh=False, mirror=None, search_engine=None,
- query_batch_size=500, upload_batch_size=7500,
+ def records(
+ self,
+ lane,
+ annotator,
+ mirror_integration,
+ start_time=None,
+ force_refresh=False,
+ mirror=None,
+ search_engine=None,
+ query_batch_size=500,
+ upload_batch_size=7500,
):
"""
Create and export a MARC file for the books in a lane.
@@ -642,7 +726,9 @@ def records(self, lane, annotator, mirror_integration, start_time=None,
storage_protocol = mirror_integration.protocol
mirror = MirrorUploader.implementation(mirror_integration)
if mirror.NAME != storage_protocol:
- raise Exception("Mirror integration does not match configured storage protocol")
+ raise Exception(
+ "Mirror integration does not match configured storage protocol"
+ )
if not mirror:
raise Exception("No mirror integration is configured")
@@ -659,8 +745,7 @@ def records(self, lane, annotator, mirror_integration, start_time=None,
url = mirror.marc_file_url(self.library, lane, end_time, start_time)
representation, ignore = get_one_or_create(
- self._db, Representation, url=url,
- media_type=Representation.MARC_MEDIA_TYPE
+ self._db, Representation, url=url, media_type=Representation.MARC_MEDIA_TYPE
)
with mirror.multipart_upload(representation, url) as upload:
@@ -669,8 +754,10 @@ def records(self, lane, annotator, mirror_integration, start_time=None,
while pagination is not None:
# Retrieve one 'page' of works from the search index.
works = lane.works(
- self._db, pagination=pagination, facets=facets,
- search_engine=search_engine
+ self._db,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
)
for work in works:
# Create a record for each work and add it to the
@@ -696,10 +783,13 @@ def records(self, lane, annotator, mirror_integration, start_time=None,
representation.fetched_at = end_time
if not representation.mirror_exception:
cached, is_new = get_one_or_create(
- self._db, CachedMARCFile, library=self.library,
+ self._db,
+ CachedMARCFile,
+ library=self.library,
lane=(lane if isinstance(lane, Lane) else None),
start_time=start_time,
- create_method_kwargs=dict(representation=representation))
+ create_method_kwargs=dict(representation=representation),
+ )
if not is_new:
cached.representation = representation
cached.end_time = end_time
diff --git a/metadata_layer.py b/metadata_layer.py
index 61d32f742..95e7d082c 100644
--- a/metadata_layer.py
+++ b/metadata_layer.py
@@ -6,28 +6,22 @@
model. Doing a third-party integration should be as simple as putting
the information into this format.
"""
-from collections import defaultdict
-from sqlalchemy.orm.session import Session
-from dateutil.parser import parse
-from sqlalchemy.sql.expression import and_, or_
-from sqlalchemy.orm.exc import (
- NoResultFound,
-)
-from sqlalchemy.orm import aliased
import csv
import datetime
import logging
import re
+from collections import defaultdict
+
+from dateutil.parser import parse
from pymarc import MARCReader
+from sqlalchemy.orm import aliased
+from sqlalchemy.orm.exc import NoResultFound
+from sqlalchemy.orm.session import Session
+from sqlalchemy.sql.expression import and_, or_
-from .classifier import Classifier
-from .util import LanguageCodes
-from .util.http import RemoteIntegrationException
-from .util.personal_names import name_tidy
-from .util.median import median
+from .analytics import Analytics
+from .classifier import NO_NUMBER, NO_VALUE, Classifier
from .model import (
- get_one,
- get_one_or_create,
CirculationEvent,
Classification,
Collection,
@@ -43,40 +37,44 @@
LicensePool,
LicensePoolDeliveryMechanism,
LinkRelations,
- Subject,
- Hyperlink,
PresentationCalculationPolicy,
- RightsStatus,
Representation,
Resource,
+ RightsStatus,
+ Subject,
Timestamp,
Work,
+ get_one,
+ get_one_or_create,
)
from .model.configuration import ExternalIntegrationLink
-from .classifier import NO_VALUE, NO_NUMBER
-from .analytics import Analytics
-from .util.personal_names import display_name_to_sort_name
+from .util import LanguageCodes
from .util.datetime_helpers import strptime_utc, to_utc, utc_now
+from .util.http import RemoteIntegrationException
+from .util.median import median
+from .util.personal_names import display_name_to_sort_name, name_tidy
+
class ReplacementPolicy(object):
"""How serious should we be about overwriting old metadata with
this new metadata?
"""
+
def __init__(
- self,
- identifiers=False,
- subjects=False,
- contributions=False,
- links=False,
- formats=False,
- rights=False,
- link_content=False,
- mirrors=None,
- content_modifier=None,
- analytics=None,
- http_get=None,
- even_if_not_apparently_updated=False,
- presentation_calculation_policy=None
+ self,
+ identifiers=False,
+ subjects=False,
+ contributions=False,
+ links=False,
+ formats=False,
+ rights=False,
+ link_content=False,
+ mirrors=None,
+ content_modifier=None,
+ analytics=None,
+ http_get=None,
+ even_if_not_apparently_updated=False,
+ presentation_calculation_policy=None,
):
self.identifiers = identifiers
self.subjects = subjects
@@ -91,8 +89,7 @@ def __init__(
self.analytics = analytics
self.http_get = http_get
self.presentation_calculation_policy = (
- presentation_calculation_policy or
- PresentationCalculationPolicy()
+ presentation_calculation_policy or PresentationCalculationPolicy()
)
@classmethod
@@ -147,6 +144,7 @@ def append_only(cls, **args):
**args
)
+
class SubjectData(object):
def __init__(self, type, identifier, name=None, weight=1):
self.type = type
@@ -170,15 +168,27 @@ def key(self):
def __repr__(self):
return '' % (
- self.type, self.identifier, self.name, self.weight
+ self.type,
+ self.identifier,
+ self.name,
+ self.weight,
)
class ContributorData(object):
-
- def __init__(self, sort_name=None, display_name=None,
- family_name=None, wikipedia_name=None, roles=None,
- lc=None, viaf=None, biography=None, aliases=None, extra=None):
+ def __init__(
+ self,
+ sort_name=None,
+ display_name=None,
+ family_name=None,
+ wikipedia_name=None,
+ roles=None,
+ lc=None,
+ viaf=None,
+ biography=None,
+ aliases=None,
+ extra=None,
+ ):
self.sort_name = sort_name
self.display_name = display_name
self.family_name = family_name
@@ -196,10 +206,16 @@ def __init__(self, sort_name=None, display_name=None,
self.extra = extra or dict()
# TODO: consider if it's time for ContributorData to connect back to Contributions
-
def __repr__(self):
- return '' % (self.sort_name, self.display_name, self.family_name, self.wikipedia_name, self.roles, self.lc, self.viaf)
-
+ return '' % (
+ self.sort_name,
+ self.display_name,
+ self.family_name,
+ self.wikipedia_name,
+ self.roles,
+ self.lc,
+ self.viaf,
+ )
@classmethod
def from_contribution(cls, contribution):
@@ -216,12 +232,11 @@ def from_contribution(cls, contribution):
viaf=c.viaf,
biography=c.biography,
aliases=c.aliases,
- roles=[contribution.role]
+ roles=[contribution.role],
)
@classmethod
- def lookup(cls, _db, sort_name=None, display_name=None, lc=None,
- viaf=None):
+ def lookup(cls, _db, sort_name=None, display_name=None, lc=None, viaf=None):
"""Create a (potentially synthetic) ContributorData based on
the best available information in the database.
@@ -229,13 +244,13 @@ def lookup(cls, _db, sort_name=None, display_name=None, lc=None,
"""
clauses = []
if sort_name:
- clauses.append(Contributor.sort_name==sort_name)
+ clauses.append(Contributor.sort_name == sort_name)
if display_name:
- clauses.append(Contributor.display_name==display_name)
+ clauses.append(Contributor.display_name == display_name)
if lc:
- clauses.append(Contributor.lc==lc)
+ clauses.append(Contributor.lc == lc)
if viaf:
- clauses.append(Contributor.viaf==viaf)
+ clauses.append(Contributor.viaf == viaf)
if not clauses:
raise ValueError("No Contributor information provided!")
@@ -256,23 +271,18 @@ def lookup(cls, _db, sort_name=None, display_name=None, lc=None,
# name doesn't match.
for c in contributors:
if c.sort_name:
- values_by_field['sort_name'].add(c.sort_name)
+ values_by_field["sort_name"].add(c.sort_name)
if c.display_name:
- values_by_field['display_name'].add(c.display_name)
+ values_by_field["display_name"].add(c.display_name)
if c.lc:
- values_by_field['lc'].add(c.lc)
+ values_by_field["lc"].add(c.lc)
if c.viaf:
- values_by_field['viaf'].add(c.viaf)
+ values_by_field["viaf"].add(c.viaf)
# Use any passed-in values as default values for the
# ContributorData. Below, missing values may be filled in and
# inaccurate values may be replaced.
- kwargs = dict(
- sort_name=sort_name,
- display_name=display_name,
- lc=lc,
- viaf=viaf
- )
+ kwargs = dict(sort_name=sort_name, display_name=display_name, lc=lc, viaf=viaf)
for k, values in list(values_by_field.items()):
if len(values) == 1:
# All the Contributors we found have the same
@@ -282,7 +292,7 @@ def lookup(cls, _db, sort_name=None, display_name=None, lc=None,
return ContributorData(roles=[], **kwargs)
def apply(self, destination, replace=None):
- """ Update the passed-in Contributor-type object with this
+ """Update the passed-in Contributor-type object with this
ContributorData's information.
:param: destination -- the Contributor or ContributorData object to
@@ -292,7 +302,13 @@ def apply(self, destination, replace=None):
:return: the possibly changed Contributor object and a flag of whether it's been changed.
"""
log = logging.getLogger("Abstract metadata layer")
- log.debug("Applying %r (%s) into %r (%s)", self, self.viaf, destination, destination.viaf)
+ log.debug(
+ "Applying %r (%s) into %r (%s)",
+ self,
+ self.viaf,
+ destination,
+ destination.viaf,
+ )
made_changes = False
@@ -320,21 +336,17 @@ def apply(self, destination, replace=None):
if self.viaf and self.viaf != destination.viaf:
destination.viaf = self.viaf
made_changes = True
- if (self.family_name and
- self.family_name != destination.family_name):
+ if self.family_name and self.family_name != destination.family_name:
destination.family_name = self.family_name
made_changes = True
- if (self.display_name and
- self.display_name != destination.display_name):
+ if self.display_name and self.display_name != destination.display_name:
destination.display_name = self.display_name
made_changes = True
- if (self.wikipedia_name and
- self.wikipedia_name != destination.wikipedia_name):
+ if self.wikipedia_name and self.wikipedia_name != destination.wikipedia_name:
destination.wikipedia_name = self.wikipedia_name
made_changes = True
- if (self.biography and
- self.biography != destination.biography):
+ if self.biography and self.biography != destination.biography:
destination.biography = self.biography
made_changes = True
@@ -344,11 +356,9 @@ def apply(self, destination, replace=None):
return destination, made_changes
-
def find_sort_name(self, _db, identifiers, metadata_client):
- """Try as hard as possible to find this person's sort name.
- """
+ """Try as hard as possible to find this person's sort name."""
log = logging.getLogger("Abstract metadata layer")
if self.sort_name:
# log.debug(
@@ -367,7 +377,8 @@ def find_sort_name(self, _db, identifiers, metadata_client):
# exact sort name? If so, use their display name.
# If not, take our best guess based on the display name.
sort_name = self.display_name_to_sort_name_from_existing_contributor(
- _db, self.display_name)
+ _db, self.display_name
+ )
if sort_name:
self.sort_name = sort_name
return True
@@ -386,7 +397,8 @@ def find_sort_name(self, _db, identifiers, metadata_client):
log = logging.getLogger("Abstract metadata layer")
log.error(
"Metadata client exception while determining sort name for %s",
- self.display_name, exc_info=e
+ self.display_name,
+ exc_info=e,
)
if sort_name:
self.sort_name = sort_name
@@ -396,7 +408,7 @@ def find_sort_name(self, _db, identifiers, metadata_client):
# on the display name.
self.sort_name = display_name_to_sort_name(self.display_name)
- return (self.sort_name is not None)
+ return self.sort_name is not None
@classmethod
def display_name_to_sort_name_from_existing_contributor(self, _db, display_name):
@@ -412,46 +424,53 @@ def display_name_to_sort_name_from_existing_contributor(self, _db, display_name)
time an external list item is relevant), this will probably be
easy.
"""
- contributors = _db.query(Contributor).filter(
- Contributor.display_name==display_name).filter(
- Contributor.sort_name != None).all()
+ contributors = (
+ _db.query(Contributor)
+ .filter(Contributor.display_name == display_name)
+ .filter(Contributor.sort_name != None)
+ .all()
+ )
if contributors:
log = logging.getLogger("Abstract metadata layer")
log.debug(
"Determined that sort name of %s is %s based on previously existing contributor",
display_name,
- contributors[0].sort_name
+ contributors[0].sort_name,
)
return contributors[0].sort_name
return None
- def _display_name_to_sort_name(
- self, _db, metadata_client, identifier_obj
- ):
+ def _display_name_to_sort_name(self, _db, metadata_client, identifier_obj):
response = metadata_client.canonicalize_author_name(
- identifier_obj, self.display_name)
+ identifier_obj, self.display_name
+ )
sort_name = None
if isinstance(response, (bytes, str)):
sort_name = response
else:
log = logging.getLogger("Abstract metadata layer")
- if (response.status_code == 200
- and response.headers['Content-Type'].startswith('text/plain')):
+ if response.status_code == 200 and response.headers[
+ "Content-Type"
+ ].startswith("text/plain"):
sort_name = response.content
log.info(
"Canonicalizer found sort name for %r: %s => %s",
- identifier_obj, self.display_name, sort_name
+ identifier_obj,
+ self.display_name,
+ sort_name,
)
else:
log.warn(
"Canonicalizer could not find sort name for %r/%s",
- identifier_obj, self.display_name
+ identifier_obj,
+ self.display_name,
)
return sort_name
def display_name_to_sort_name_through_canonicalizer(
- self, _db, identifiers, metadata_client):
+ self, _db, identifiers, metadata_client
+ ):
sort_name = None
for identifier in identifiers:
if identifier.type != Identifier.ISBN:
@@ -464,9 +483,7 @@ def display_name_to_sort_name_through_canonicalizer(
break
if not sort_name:
- sort_name = self._display_name_to_sort_name(
- _db, metadata_client, None
- )
+ sort_name = self._display_name_to_sort_name(_db, metadata_client, None)
return sort_name
@@ -478,19 +495,28 @@ def __init__(self, type, identifier, weight=1):
def __repr__(self):
return '' % (
- self.type, self.identifier, self.weight
+ self.type,
+ self.identifier,
+ self.weight,
)
def load(self, _db):
- return Identifier.for_foreign_id(
- _db, self.type, self.identifier
- )
+ return Identifier.for_foreign_id(_db, self.type, self.identifier)
class LinkData(object):
- def __init__(self, rel, href=None, media_type=None, content=None,
- thumbnail=None, rights_uri=None, rights_explanation=None,
- original=None, transformation_settings=None):
+ def __init__(
+ self,
+ rel,
+ href=None,
+ media_type=None,
+ content=None,
+ thumbnail=None,
+ rights_uri=None,
+ rights_explanation=None,
+ original=None,
+ transformation_settings=None,
+ ):
if not rel:
raise ValueError("rel is required")
@@ -530,19 +556,21 @@ def __repr__(self):
if self.content:
content = ", %d bytes content" % len(self.content)
else:
- content = ''
+ content = ""
if self.thumbnail:
- thumbnail = ', has thumbnail'
+ thumbnail = ", has thumbnail"
else:
- thumbnail = ''
+ thumbnail = ""
return '' % (
- self.rel, self.href, self.media_type, thumbnail,
- content
+ self.rel,
+ self.href,
+ self.media_type,
+ thumbnail,
+ content,
)
def mirror_type(self):
- """Returns the type of mirror that should be used for the link.
- """
+ """Returns the type of mirror that should be used for the link."""
if self.rel in [Hyperlink.IMAGE, Hyperlink.THUMBNAIL_IMAGE]:
return ExternalIntegrationLink.COVERS
@@ -553,11 +581,7 @@ def mirror_type(self):
class MeasurementData(object):
- def __init__(self,
- quantity_measured,
- value,
- weight=1,
- taken_at=None):
+ def __init__(self, quantity_measured, value, weight=1, taken_at=None):
if not quantity_measured:
raise ValueError("quantity_measured is required.")
if value is None:
@@ -571,7 +595,10 @@ def __init__(self,
def __repr__(self):
return '' % (
- self.quantity_measured, self.value, self.weight, self.taken_at
+ self.quantity_measured,
+ self.value,
+ self.weight,
+ self.taken_at,
)
@@ -580,18 +607,23 @@ def __init__(self, content_type, drm_scheme, link=None, rights_uri=None):
self.content_type = content_type
self.drm_scheme = drm_scheme
if link and not isinstance(link, LinkData):
- raise TypeError(
- "Expected LinkData object, got %s" % type(link)
- )
+ raise TypeError("Expected LinkData object, got %s" % type(link))
self.link = link
self.rights_uri = rights_uri
- if ((not self.rights_uri) and self.link and self.link.rights_uri):
+ if (not self.rights_uri) and self.link and self.link.rights_uri:
self.rights_uri = self.link.rights_uri
class LicenseData(object):
- def __init__(self, identifier, checkout_url, status_url, expires=None, remaining_checkouts=None,
- concurrent_checkouts=None):
+ def __init__(
+ self,
+ identifier,
+ checkout_url,
+ status_url,
+ expires=None,
+ remaining_checkouts=None,
+ concurrent_checkouts=None,
+ ):
self.identifier = identifier
self.checkout_url = checkout_url
self.status_url = status_url
@@ -604,8 +636,9 @@ class TimestampData(object):
CLEAR_VALUE = Timestamp.CLEAR_VALUE
- def __init__(self, start=None, finish=None, achievements=None,
- counter=None, exception=None):
+ def __init__(
+ self, start=None, finish=None, achievements=None, counter=None, exception=None
+ ):
"""A constructor intended to be used by a service to customize its
eventual Timestamp.
@@ -652,9 +685,17 @@ def is_complete(self):
"""
return self.is_failure or self.finish not in (None, self.CLEAR_VALUE)
- def finalize(self, service, service_type, collection, start=None,
- finish=None, achievements=None, counter=None,
- exception=None):
+ def finalize(
+ self,
+ service,
+ service_type,
+ collection,
+ start=None,
+ finish=None,
+ achievements=None,
+ counter=None,
+ exception=None,
+ ):
"""Finalize any values that were not set during the constructor.
This is intended to be run by the code that originally ran the
@@ -693,9 +734,15 @@ def apply(self, _db):
)
return Timestamp.stamp(
- _db, self.service, self.service_type, self.collection(_db),
- self.start, self.finish, self.achievements, self.counter,
- self.exception
+ _db,
+ self.service,
+ self.service_type,
+ self.collection(_db),
+ self.start,
+ self.finish,
+ self.achievements,
+ self.counter,
+ self.exception,
)
@@ -721,12 +768,9 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
self.log.info("Not mirroring %s: rel=%s", link.href, link_obj.rel)
return
- if (link.rights_uri
- and link.rights_uri == RightsStatus.IN_COPYRIGHT):
+ if link.rights_uri and link.rights_uri == RightsStatus.IN_COPYRIGHT:
self.log.info(
- "Not mirroring %s: rights status=%s" % (
- link.href, link.rights_uri
- )
+ "Not mirroring %s: rights status=%s" % (link.href, link.rights_uri)
)
return
@@ -737,9 +781,7 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
if not mirror:
return
else:
- self.log.info(
- "No mirror uploader with key %s found" % mirror_type
- )
+ self.log.info("No mirror uploader with key %s found" % mirror_type)
return
http_get = policy.http_get
@@ -757,7 +799,11 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
pools = [model_object]
identifier = model_object.identifier
- if (identifier and identifier.primarily_identifies and identifier.primarily_identifies[0]):
+ if (
+ identifier
+ and identifier.primarily_identifies
+ and identifier.primarily_identifies[0]
+ ):
edition = identifier.primarily_identifies[0]
elif isinstance(model_object, Edition):
pools = model_object.license_pools
@@ -766,11 +812,15 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
if edition and edition.title:
title = edition.title
else:
- title = getattr(self, 'title', None) or None
+ title = getattr(self, "title", None) or None
- if ((not identifier) or (link_obj.identifier and identifier != link_obj.identifier)):
+ if (not identifier) or (
+ link_obj.identifier and identifier != link_obj.identifier
+ ):
# insanity found
- self.log.warn("Tried to mirror a link with an invalid identifier %r" % identifier)
+ self.log.warn(
+ "Tried to mirror a link with an invalid identifier %r" % identifier
+ )
return
max_age = None
@@ -783,7 +833,9 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
# This will fetch a representation of the original and
# store it in the database.
representation, is_new = Representation.get(
- _db, link.href, do_get=http_get,
+ _db,
+ link.href,
+ do_get=http_get,
presumed_media_type=link.media_type,
max_age=max_age,
)
@@ -801,7 +853,9 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
if pools and link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD:
for pool in pools:
pool.suppressed = True
- pool.license_exception = "Fetch exception: %s" % representation.fetch_exception
+ pool.license_exception = (
+ "Fetch exception: %s" % representation.fetch_exception
+ )
self.log.error(pool.license_exception)
return
@@ -810,14 +864,16 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
# again.
if representation.status_code == 304 and representation.mirror_url:
self.log.info(
- "Representation has not changed, assuming mirror at %s is up to date.", representation.mirror_url
+ "Representation has not changed, assuming mirror at %s is up to date.",
+ representation.mirror_url,
)
return
- if representation.status_code // 100 not in (2,3):
+ if representation.status_code // 100 not in (2, 3):
self.log.info(
"Representation %s gave %s status code, not mirroring.",
- representation.url, representation.status_code
+ representation.url,
+ representation.status_code,
)
return
@@ -835,18 +891,27 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
if not representation.mirrorable_media_type:
if link.media_type:
- self.log.info("Saw unsupported media type for %s: %s. Assuming original media type %s is correct",
- representation.url, representation.media_type, link.media_type)
+ self.log.info(
+ "Saw unsupported media type for %s: %s. Assuming original media type %s is correct",
+ representation.url,
+ representation.media_type,
+ link.media_type,
+ )
representation.media_type = link.media_type
else:
- self.log.info("Not mirroring %s: unsupported media type %s",
- representation.url, representation.media_type)
+ self.log.info(
+ "Not mirroring %s: unsupported media type %s",
+ representation.url,
+ representation.media_type,
+ )
return
# Determine the best URL to use when mirroring this
# representation.
- if link.media_type in Representation.BOOK_MEDIA_TYPES or \
- link.media_type in Representation.AUDIOBOOK_MEDIA_TYPES:
+ if (
+ link.media_type in Representation.BOOK_MEDIA_TYPES
+ or link.media_type in Representation.AUDIOBOOK_MEDIA_TYPES
+ ):
url_title = title or identifier.identifier
extension = representation.extension()
mirror_url = mirror.book_url(
@@ -854,15 +919,13 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
data_source=data_source,
title=url_title,
extension=extension,
- open_access=link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD
+ open_access=link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD,
)
else:
filename = representation.default_filename(
link_obj, representation.media_type
)
- mirror_url = mirror.cover_image_url(
- data_source, identifier, filename
- )
+ mirror_url = mirror.cover_image_url(data_source, identifier, filename)
# Mirror it.
collection = pools[0].collection if pools else None
@@ -874,7 +937,9 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
if pools and link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD:
for pool in pools:
pool.suppressed = True
- pool.license_exception = "Mirror exception: %s" % representation.mirror_exception
+ pool.license_exception = (
+ "Mirror exception: %s" % representation.mirror_exception
+ )
self.log.error(pool.license_exception)
if link_obj.rel == Hyperlink.IMAGE:
@@ -883,20 +948,24 @@ def mirror_link(self, model_object, data_source, link, link_obj, policy):
link_obj, Representation.PNG_MEDIA_TYPE
)
thumbnail_url = mirror.cover_image_url(
- data_source, identifier, thumbnail_filename,
- Edition.MAX_THUMBNAIL_HEIGHT
+ data_source,
+ identifier,
+ thumbnail_filename,
+ Edition.MAX_THUMBNAIL_HEIGHT,
)
thumbnail, is_new = representation.scale(
max_height=Edition.MAX_THUMBNAIL_HEIGHT,
max_width=Edition.MAX_THUMBNAIL_WIDTH,
destination_url=thumbnail_url,
destination_media_type=Representation.PNG_MEDIA_TYPE,
- force=True
+ force=True,
)
if is_new:
# A thumbnail was created distinct from the original
# image. Mirror it as well.
- mirror.mirror_one(thumbnail, mirror_to=thumbnail_url, collection=collection)
+ mirror.mirror_one(
+ thumbnail, mirror_to=thumbnail_url, collection=collection
+ )
if link_obj.rel in Hyperlink.SELF_HOSTED_BOOKS:
# If we mirrored book content successfully, remove it from
@@ -917,23 +986,21 @@ class CirculationData(MetaToModelUtility):
Metadata : Edition :: CirculationData : Licensepool
"""
- log = logging.getLogger(
- "Abstract metadata layer - Circulation data"
- )
+ log = logging.getLogger("Abstract metadata layer - Circulation data")
def __init__(
- self,
- data_source,
- primary_identifier,
- licenses_owned=None,
- licenses_available=None,
- licenses_reserved=None,
- patrons_in_hold_queue=None,
- formats=None,
- default_rights_uri=None,
- links=None,
- licenses=None,
- last_checked=None,
+ self,
+ data_source,
+ primary_identifier,
+ licenses_owned=None,
+ licenses_available=None,
+ licenses_reserved=None,
+ patrons_in_hold_queue=None,
+ formats=None,
+ default_rights_uri=None,
+ links=None,
+ licenses=None,
+ last_checked=None,
):
"""Constructor.
@@ -971,7 +1038,10 @@ def __init__(
self.formats = formats or []
self.default_rights_uri = None
- self.set_default_rights_uri(data_source_name=self.data_source_name, default_rights_uri=default_rights_uri)
+ self.set_default_rights_uri(
+ data_source_name=self.data_source_name,
+ default_rights_uri=default_rights_uri,
+ )
self.__links = None
self.links = links
@@ -985,8 +1055,8 @@ def links(self):
@links.setter
def links(self, arg_links):
- """ If got passed all links, undiscriminately, filter out to only those relevant to
- pools (the rights-related links).
+ """If got passed all links, undiscriminately, filter out to only those relevant to
+ pools (the rights-related links).
"""
# start by deleting any old links
self.__links = []
@@ -1001,9 +1071,11 @@ def links(self, arg_links):
self.__links.append(link)
# An open-access link or open-access rights implies a FormatData object.
- open_access_link = (link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD and link.href)
+ open_access_link = (
+ link.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD and link.href
+ )
# try to deduce if the link is open-access, even if it doesn't explicitly say it is
- rights_uri = link.rights_uri or self.default_rights_uri
+ rights_uri = link.rights_uri or self.default_rights_uri
open_access_rights_link = (
link.media_type in Representation.BOOK_MEDIA_TYPES
and link.href
@@ -1011,9 +1083,11 @@ def links(self, arg_links):
)
if open_access_link or open_access_rights_link:
- if (open_access_link
+ if (
+ open_access_link
and rights_uri != RightsStatus.IN_COPYRIGHT
- and not rights_uri in RightsStatus.OPEN_ACCESS):
+ and not rights_uri in RightsStatus.OPEN_ACCESS
+ ):
# We don't know exactly what's going on here but
# the link said it was an open-access book
# and the rights URI doesn't contradict it,
@@ -1036,23 +1110,23 @@ def links(self, arg_links):
)
)
-
def __repr__(self):
- description_string = ''
-
+ description_string = ""
+ )
- description_data = {'licenses_owned':self.licenses_owned}
+ description_data = {"licenses_owned": self.licenses_owned}
if self._primary_identifier:
- description_data['primary_identifier'] = self._primary_identifier
+ description_data["primary_identifier"] = self._primary_identifier
else:
- description_data['primary_identifier'] = self.primary_identifier_obj
- description_data['licenses_available'] = self.licenses_available
- description_data['default_rights_uri'] = self.default_rights_uri
- description_data['links'] = self.links
- description_data['formats'] = self.formats
- description_data['data_source'] = self.data_source_name
+ description_data["primary_identifier"] = self.primary_identifier_obj
+ description_data["licenses_available"] = self.licenses_available
+ description_data["default_rights_uri"] = self.default_rights_uri
+ description_data["links"] = self.links
+ description_data["formats"] = self.formats
+ description_data["data_source"] = self.data_source_name
return description_string % description_data
@@ -1088,9 +1162,7 @@ def license_pool(self, _db, collection, analytics=None):
will be tracked with this.
"""
if not collection:
- raise ValueError(
- "Cannot find license pool: no collection provided."
- )
+ raise ValueError("Cannot find license pool: no collection provided.")
identifier = self.primary_identifier(_db)
if not identifier:
raise ValueError(
@@ -1099,10 +1171,11 @@ def license_pool(self, _db, collection, analytics=None):
data_source_obj = self.data_source(_db)
license_pool, is_new = LicensePool.for_foreign_id(
- _db, data_source=data_source_obj,
+ _db,
+ data_source=data_source_obj,
foreign_id_type=identifier.type,
foreign_id=identifier.identifier,
- collection=collection
+ collection=collection,
)
if is_new:
@@ -1113,35 +1186,39 @@ def license_pool(self, _db, collection, analytics=None):
if analytics:
for library in collection.libraries:
analytics.collect_event(
- library, license_pool,
+ library,
+ license_pool,
CirculationEvent.DISTRIBUTOR_TITLE_ADD,
self.last_checked,
- old_value=0, new_value=1,
+ old_value=0,
+ new_value=1,
)
license_pool.last_checked = self.last_checked
return license_pool, is_new
-
@property
def has_open_access_link(self):
"""Does this Circulation object have an associated open-access link?"""
return any(
- [x for x in self.links
- if x.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD
- and x.href
- and x.rights_uri != RightsStatus.IN_COPYRIGHT
+ [
+ x
+ for x in self.links
+ if x.rel == Hyperlink.OPEN_ACCESS_DOWNLOAD
+ and x.href
+ and x.rights_uri != RightsStatus.IN_COPYRIGHT
]
)
-
def set_default_rights_uri(self, data_source_name, default_rights_uri=None):
if default_rights_uri:
self.default_rights_uri = default_rights_uri
elif data_source_name:
# We didn't get rights passed in, so use the default rights for the data source if any.
- default = RightsStatus.DATA_SOURCE_DEFAULT_RIGHTS_STATUS.get(data_source_name, None)
+ default = RightsStatus.DATA_SOURCE_DEFAULT_RIGHTS_STATUS.get(
+ data_source_name, None
+ )
if default:
self.default_rights_uri = default
@@ -1163,10 +1240,12 @@ def apply(self, _db, collection, replace=None):
# can only be stored in a LicensePool, but we have no
# Collection to tell us which LicensePool to use. This is
# indicative of an error in programming.
- if not collection and (self.licenses_owned is not None
- or self.licenses_available is not None
- or self.licenses_reserved is not None
- or self.patrons_in_hold_queue is not None):
+ if not collection and (
+ self.licenses_owned is not None
+ or self.licenses_available is not None
+ or self.licenses_reserved is not None
+ or self.patrons_in_hold_queue is not None
+ ):
raise ValueError(
"Cannot store circulation information because no "
"Collection was provided."
@@ -1193,8 +1272,11 @@ def apply(self, _db, collection, replace=None):
for link in self.links:
if link.rel in Hyperlink.CIRCULATION_ALLOWED:
link_obj, ignore = identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
link_objects[link] = link_obj
@@ -1228,10 +1310,12 @@ def apply(self, _db, collection, replace=None):
resource = None
# This can cause a non-open-access LicensePool to go open-access.
lpdm = LicensePoolDeliveryMechanism.set(
- data_source, identifier, format.content_type,
+ data_source,
+ identifier,
+ format.content_type,
format.drm_scheme,
format.rights_uri or self.default_rights_uri,
- resource
+ resource,
)
new_lpdms.append(lpdm)
@@ -1242,14 +1326,17 @@ def apply(self, _db, collection, replace=None):
for lpdm in old_lpdms:
if lpdm not in new_lpdms:
for loan in lpdm.fulfills:
- self.log.info("Loan %i is associated with a format that is no longer available. Deleting its delivery mechanism." % loan.id)
+ self.log.info(
+ "Loan %i is associated with a format that is no longer available. Deleting its delivery mechanism."
+ % loan.id
+ )
loan.fulfillment = None
# This can cause an open-access LicensePool to go
# non-open-access.
lpdm.delete()
new_open_access = any(pool.open_access for pool in pools)
- open_access_status_changed = (old_open_access != new_open_access)
+ open_access_status_changed = old_open_access != new_open_access
old_licenses = new_licenses = []
if pool:
@@ -1257,7 +1344,9 @@ def apply(self, _db, collection, replace=None):
for license in self.licenses:
license_obj, ignore = get_one_or_create(
- _db, License, identifier=license.identifier,
+ _db,
+ License,
+ identifier=license.identifier,
license_pool_id=pool.id,
)
license_obj.checkout_url = license.checkout_url
@@ -1276,7 +1365,10 @@ def apply(self, _db, collection, replace=None):
# still loans we'll need it.
# But if we track individual licenses for other protocols,
# we may need to handle this differently.
- self.log.warn("License %i is no longer available but still has loans." % license.id)
+ self.log.warn(
+ "License %i is no longer available but still has loans."
+ % license.id
+ )
# Finally, if we have data for a specific Collection's license
# for this book, find its LicensePool and update it.
@@ -1290,7 +1382,7 @@ def apply(self, _db, collection, replace=None):
new_licenses_reserved=self.licenses_reserved,
new_patrons_in_hold_queue=self.patrons_in_hold_queue,
analytics=analytics,
- as_of=self.last_checked
+ as_of=self.last_checked,
)
# If this is the first time we've seen this pool, or we never
@@ -1302,9 +1394,12 @@ def apply(self, _db, collection, replace=None):
work.set_presentation_ready()
work_changed = True
- made_changes = (made_changes or changed_availability
- or open_access_status_changed
- or work_changed)
+ made_changes = (
+ made_changes
+ or changed_availability
+ or open_access_status_changed
+ or work_changed
+ )
return pool, made_changes
@@ -1329,36 +1424,44 @@ class Metadata(MetaToModelUtility):
log = logging.getLogger("Abstract metadata layer")
BASIC_EDITION_FIELDS = [
- 'title', 'sort_title', 'subtitle', 'language', 'medium',
- 'series', 'series_position', 'publisher', 'imprint',
- 'issued', 'published'
+ "title",
+ "sort_title",
+ "subtitle",
+ "language",
+ "medium",
+ "series",
+ "series_position",
+ "publisher",
+ "imprint",
+ "issued",
+ "published",
]
def __init__(
- self,
- data_source,
- title=None,
- subtitle=None,
- sort_title=None,
- language=None,
- medium=None,
- series=None,
- series_position=None,
- publisher=None,
- imprint=None,
- issued=None,
- published=None,
- primary_identifier=None,
- identifiers=None,
- recommendations=None,
- subjects=None,
- contributors=None,
- measurements=None,
- links=None,
- data_source_last_updated=None,
- # Note: brought back to keep callers of bibliographic extraction process_one() methods simple.
- circulation=None,
- **kwargs
+ self,
+ data_source,
+ title=None,
+ subtitle=None,
+ sort_title=None,
+ language=None,
+ medium=None,
+ series=None,
+ series_position=None,
+ publisher=None,
+ imprint=None,
+ issued=None,
+ published=None,
+ primary_identifier=None,
+ identifiers=None,
+ recommendations=None,
+ subjects=None,
+ contributors=None,
+ measurements=None,
+ links=None,
+ data_source_last_updated=None,
+ # Note: brought back to keep callers of bibliographic extraction process_one() methods simple.
+ circulation=None,
+ **kwargs
):
# data_source is where the data comes from (e.g. overdrive, metadata wrangler, admin interface),
# and not necessarily where the associated Identifier's LicencePool's lending licenses are coming from.
@@ -1383,11 +1486,10 @@ def __init__(
self.issued = issued
self.published = published
- self.primary_identifier=primary_identifier
+ self.primary_identifier = primary_identifier
self.identifiers = identifiers or []
self.permanent_work_id = None
- if (self.primary_identifier
- and self.primary_identifier not in self.identifiers):
+ if self.primary_identifier and self.primary_identifier not in self.identifiers:
self.identifiers.append(self.primary_identifier)
self.recommendations = recommendations or []
self.subjects = subjects or []
@@ -1408,8 +1510,8 @@ def links(self):
@links.setter
def links(self, arg_links):
- """ If got passed all links, undiscriminately, filter out to only those relevant to
- editions (the image/cover/etc links).
+ """If got passed all links, undiscriminately, filter out to only those relevant to
+ editions (the image/cover/etc links).
"""
# start by deleting any old links
self.__links = []
@@ -1422,7 +1524,6 @@ def links(self, arg_links):
# only accept the types of links relevant to editions
self.__links.append(link)
-
@classmethod
def from_edition(cls, edition):
"""Create a basic Metadata object for the given Edition.
@@ -1444,9 +1545,11 @@ def from_edition(cls, edition):
# the NYT best-seller API.
if edition.sort_author and edition.sort_author != Edition.UNKNOWN_AUTHOR:
contributors.append(
- ContributorData(sort_name=edition.sort_author,
- display_name=edition.author,
- roles=[Contributor.PRIMARY_AUTHOR_ROLE])
+ ContributorData(
+ sort_name=edition.sort_author,
+ display_name=edition.author,
+ roles=[Contributor.PRIMARY_AUTHOR_ROLE],
+ )
)
i = edition.primary_identifier
@@ -1488,7 +1591,6 @@ def primary_author(self):
break
return primary_author
-
def update(self, metadata):
"""Update this Metadata object with values from the given Metadata
object.
@@ -1499,16 +1601,15 @@ def update(self, metadata):
fields = self.BASIC_EDITION_FIELDS
for field in fields:
new_value = getattr(metadata, field)
- if new_value != None and new_value != '':
+ if new_value != None and new_value != "":
setattr(self, field, new_value)
- new_value = getattr(metadata, 'contributors')
+ new_value = getattr(metadata, "contributors")
if new_value and isinstance(new_value, list):
- old_value = getattr(self, 'contributors')
+ old_value = getattr(self, "contributors")
# if we already have a better value, don't override it with a "missing info" placeholder value
if not (old_value and new_value[0].sort_name == Edition.UNKNOWN_AUTHOR):
- setattr(self, 'contributors', new_value)
-
+ setattr(self, "contributors", new_value)
def calculate_permanent_work_id(self, _db, metadata_client):
"""Try to calculate a permanent work ID from this metadata.
@@ -1522,23 +1623,21 @@ def calculate_permanent_work_id(self, _db, metadata_client):
return None, None
if not primary_author.sort_name and metadata_client:
- primary_author.find_sort_name(
- _db, self.identifiers, metadata_client
- )
+ primary_author.find_sort_name(_db, self.identifiers, metadata_client)
sort_author = primary_author.sort_name
pwid = Edition.calculate_permanent_work_id_for_title_and_author(
- self.title, sort_author, "book")
- self.permanent_work_id=pwid
+ self.title, sort_author, "book"
+ )
+ self.permanent_work_id = pwid
return pwid
- def associate_with_identifiers_based_on_permanent_work_id(
- self, _db):
+ def associate_with_identifiers_based_on_permanent_work_id(self, _db):
"""Try to associate this object's primary identifier with
the primary identifiers of Editions in the database which share
a permanent work ID.
"""
- if (not self.primary_identifier or not self.permanent_work_id):
+ if not self.primary_identifier or not self.permanent_work_id:
# We don't have the information necessary to carry out this
# task.
return
@@ -1553,25 +1652,28 @@ def associate_with_identifiers_based_on_permanent_work_id(
# Try to find the primary identifiers of other Editions with
# the same permanent work ID and the same medium, representing
# books already in our collection.
- qu = _db.query(Identifier).join(
- Identifier.primarily_identifies).filter(
- Edition.permanent_work_id==self.permanent_work_id).filter(
- Identifier.type.in_(
- Identifier.LICENSE_PROVIDING_IDENTIFIER_TYPES
- )
- ).filter(
- Edition.medium==self.medium
- )
+ qu = (
+ _db.query(Identifier)
+ .join(Identifier.primarily_identifies)
+ .filter(Edition.permanent_work_id == self.permanent_work_id)
+ .filter(Identifier.type.in_(Identifier.LICENSE_PROVIDING_IDENTIFIER_TYPES))
+ .filter(Edition.medium == self.medium)
+ )
identifiers_same_work_id = qu.all()
for same_work_id in identifiers_same_work_id:
- if (same_work_id.type != self.primary_identifier.type
- or same_work_id.identifier != self.primary_identifier.identifier):
+ if (
+ same_work_id.type != self.primary_identifier.type
+ or same_work_id.identifier != self.primary_identifier.identifier
+ ):
self.log.info(
"Discovered that %r is equivalent to %r because of matching permanent work ID %s",
- same_work_id, primary_identifier_obj, self.permanent_work_id
+ same_work_id,
+ primary_identifier_obj,
+ self.permanent_work_id,
)
primary_identifier_obj.equivalent_to(
- self.data_source(_db), same_work_id, 0.85)
+ self.data_source(_db), same_work_id, 0.85
+ )
def data_source(self, _db):
if not self.data_source_obj:
@@ -1583,22 +1685,20 @@ def data_source(self, _db):
return self.data_source_obj
def edition(self, _db, create_if_not_exists=True):
- """ Find or create the edition described by this Metadata object.
- """
+ """Find or create the edition described by this Metadata object."""
if not self.primary_identifier:
- raise ValueError(
- "Cannot find edition: metadata has no primary identifier."
- )
+ raise ValueError("Cannot find edition: metadata has no primary identifier.")
data_source = self.data_source(_db)
return Edition.for_foreign_id(
- _db, data_source, self.primary_identifier.type,
+ _db,
+ data_source,
+ self.primary_identifier.type,
self.primary_identifier.identifier,
- create_if_not_exists=create_if_not_exists
+ create_if_not_exists=create_if_not_exists,
)
-
def consolidate_identifiers(self):
by_weight = defaultdict(list)
for i in self.identifiers:
@@ -1606,8 +1706,7 @@ def consolidate_identifiers(self):
new_identifiers = []
for (type, identifier), weights in list(by_weight.items()):
new_identifiers.append(
- IdentifierData(type=type, identifier=identifier,
- weight=median(weights))
+ IdentifierData(type=type, identifier=identifier, weight=median(weights))
)
self.identifiers = new_identifiers
@@ -1616,33 +1715,39 @@ def guess_license_pools(self, _db, metadata_client):
potentials = {}
for contributor in self.contributors:
if not any(
- x in contributor.roles for x in
- (Contributor.AUTHOR_ROLE,
- Contributor.PRIMARY_AUTHOR_ROLE)
+ x in contributor.roles
+ for x in (Contributor.AUTHOR_ROLE, Contributor.PRIMARY_AUTHOR_ROLE)
):
continue
contributor.find_sort_name(_db, self.identifiers, metadata_client)
confidence = 0
- base = _db.query(Edition).filter(
- Edition.title.ilike(self.title)).filter(
- Edition.medium==Edition.BOOK_MEDIUM)
+ base = (
+ _db.query(Edition)
+ .filter(Edition.title.ilike(self.title))
+ .filter(Edition.medium == Edition.BOOK_MEDIUM)
+ )
success = False
# A match based on work ID is the most reliable.
pwid = self.calculate_permanent_work_id(_db, metadata_client)
- clause = and_(Edition.data_source_id==LicensePool.data_source_id, Edition.primary_identifier_id==LicensePool.identifier_id)
- qu = base.filter(Edition.permanent_work_id==pwid).join(LicensePool, clause)
+ clause = and_(
+ Edition.data_source_id == LicensePool.data_source_id,
+ Edition.primary_identifier_id == LicensePool.identifier_id,
+ )
+ qu = base.filter(Edition.permanent_work_id == pwid).join(
+ LicensePool, clause
+ )
success = self._run_query(qu, potentials, 0.95)
if not success and contributor.sort_name:
- qu = base.filter(Edition.sort_author==contributor.sort_name)
+ qu = base.filter(Edition.sort_author == contributor.sort_name)
success = self._run_query(qu, potentials, 0.9)
if not success and contributor.display_name:
- qu = base.filter(Edition.author==contributor.display_name)
+ qu = base.filter(Edition.author == contributor.display_name)
success = self._run_query(qu, potentials, 0.8)
if not success:
# Look for the book by an unknown author (our mistake)
- qu = base.filter(Edition.author==Edition.UNKNOWN_AUTHOR)
+ qu = base.filter(Edition.author == Edition.UNKNOWN_AUTHOR)
success = self._run_query(qu, potentials, 0.45)
if not success:
# See if there is any book with this title at all.
@@ -1660,21 +1765,27 @@ def _run_query(self, qu, potentials, confidence):
return success
REL_REQUIRES_NEW_PRESENTATION_EDITION = [
- LinkRelations.IMAGE, LinkRelations.THUMBNAIL_IMAGE
+ LinkRelations.IMAGE,
+ LinkRelations.THUMBNAIL_IMAGE,
]
REL_REQUIRES_FULL_RECALCULATION = [LinkRelations.DESCRIPTION]
# TODO: We need to change all calls to apply() to use a ReplacementPolicy
# instead of passing in individual `replace` arguments. Once that's done,
# we can get rid of the `replace` arguments.
- def apply(self, edition, collection, metadata_client=None, replace=None,
- replace_identifiers=False,
- replace_subjects=False,
- replace_contributions=False,
- replace_links=False,
- replace_formats=False,
- replace_rights=False,
- force=False,
+ def apply(
+ self,
+ edition,
+ collection,
+ metadata_client=None,
+ replace=None,
+ replace_identifiers=False,
+ replace_subjects=False,
+ replace_contributions=False,
+ replace_links=False,
+ replace_formats=False,
+ replace_rights=False,
+ force=False,
):
"""Apply this metadata to the given edition.
@@ -1704,17 +1815,21 @@ def apply(self, edition, collection, metadata_client=None, replace=None,
links=replace_links,
formats=replace_formats,
rights=replace_rights,
- even_if_not_apparently_updated=force
+ even_if_not_apparently_updated=force,
)
# We were given an Edition, so either this metadata's
# primary_identifier must be missing or it must match the
# Edition's primary identifier.
if self.primary_identifier:
- if (self.primary_identifier.type != edition.primary_identifier.type
- or self.primary_identifier.identifier != edition.primary_identifier.identifier):
+ if (
+ self.primary_identifier.type != edition.primary_identifier.type
+ or self.primary_identifier.identifier
+ != edition.primary_identifier.identifier
+ ):
raise ValueError(
- "Metadata's primary identifier (%s/%s) does not match edition's primary identifier (%r)" % (
+ "Metadata's primary identifier (%s/%s) does not match edition's primary identifier (%r)"
+ % (
self.primary_identifier.type,
self.primary_identifier.identifier,
edition.primary_identifier,
@@ -1738,14 +1853,16 @@ def apply(self, edition, collection, metadata_client=None, replace=None,
identifier = edition.primary_identifier
- self.log.info(
- "APPLYING METADATA TO EDITION: %s", self.title
- )
- fields = self.BASIC_EDITION_FIELDS+['permanent_work_id']
+ self.log.info("APPLYING METADATA TO EDITION: %s", self.title)
+ fields = self.BASIC_EDITION_FIELDS + ["permanent_work_id"]
for field in fields:
old_edition_value = getattr(edition, field)
new_metadata_value = getattr(self, field)
- if new_metadata_value != None and new_metadata_value != '' and (new_metadata_value != old_edition_value):
+ if (
+ new_metadata_value != None
+ and new_metadata_value != ""
+ and (new_metadata_value != old_edition_value)
+ ):
if new_metadata_value in [NO_VALUE, NO_NUMBER]:
new_metadata_value = None
setattr(edition, field, new_metadata_value)
@@ -1753,8 +1870,9 @@ def apply(self, edition, collection, metadata_client=None, replace=None,
# Create equivalencies between all given identifiers and
# the edition's primary identifier.
- contributors_changed = self.update_contributions(_db, edition,
- metadata_client, replace.contributions)
+ contributors_changed = self.update_contributions(
+ _db, edition, metadata_client, replace.contributions
+ )
if contributors_changed:
work_requires_new_presentation_edition = True
@@ -1763,21 +1881,22 @@ def apply(self, edition, collection, metadata_client=None, replace=None,
for identifier_data in self.identifiers:
if not identifier_data.identifier:
continue
- if (identifier_data.identifier==identifier.identifier and
- identifier_data.type==identifier.type):
+ if (
+ identifier_data.identifier == identifier.identifier
+ and identifier_data.type == identifier.type
+ ):
# These are the same identifier.
continue
new_identifier, ignore = Identifier.for_foreign_id(
- _db, identifier_data.type, identifier_data.identifier)
+ _db, identifier_data.type, identifier_data.identifier
+ )
identifier.equivalent_to(
- data_source, new_identifier, identifier_data.weight)
+ data_source, new_identifier, identifier_data.weight
+ )
new_subjects = {}
if self.subjects:
- new_subjects = dict(
- (subject.key, subject)
- for subject in self.subjects
- )
+ new_subjects = dict((subject.key, subject) for subject in self.subjects)
if replace.subjects:
# Remove any old Subjects from this data source, unless they
# are also in the list of new subjects.
@@ -1810,8 +1929,12 @@ def _key(classification):
# Apply all new subjects to the identifier.
for subject in list(new_subjects.values()):
identifier.classify(
- data_source, subject.type, subject.identifier,
- subject.name, weight=subject.weight)
+ data_source,
+ subject.type,
+ subject.identifier,
+ subject.name,
+ weight=subject.weight,
+ )
work_requires_full_recalculation = True
# Associate all links with the primary identifier.
@@ -1835,21 +1958,30 @@ def _key(classification):
if link.original:
rights_status = RightsStatus.lookup(_db, link.original.rights_uri)
original_resource, ignore = get_one_or_create(
- _db, Resource, url=link.original.href,
+ _db,
+ Resource,
+ url=link.original.href,
)
if not original_resource.data_source:
original_resource.data_source = data_source
original_resource.rights_status = rights_status
- original_resource.rights_explanation = link.original.rights_explanation
+ original_resource.rights_explanation = (
+ link.original.rights_explanation
+ )
if link.original.content:
original_resource.set_fetched_content(
link.original.guessed_media_type,
- link.original.content, None)
+ link.original.content,
+ None,
+ )
link_obj, ignore = identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
media_type=link.guessed_media_type,
- content=link.content, rights_status_uri=link.rights_uri,
+ content=link.content,
+ rights_status_uri=link.rights_uri,
rights_explanation=link.rights_explanation,
original_resource=original_resource,
transformation_settings=link.transformation_settings,
@@ -1864,28 +1996,26 @@ def _key(classification):
if link.thumbnail.rel == Hyperlink.THUMBNAIL_IMAGE:
thumbnail = link.thumbnail
thumbnail_obj, ignore = identifier.add_link(
- rel=thumbnail.rel, href=thumbnail.href,
+ rel=thumbnail.rel,
+ href=thumbnail.href,
data_source=data_source,
media_type=thumbnail.guessed_media_type,
- content=thumbnail.content
+ content=thumbnail.content,
)
work_requires_new_presentation_edition = True
- if (thumbnail_obj.resource
- and thumbnail_obj.resource.representation):
+ if thumbnail_obj.resource and thumbnail_obj.resource.representation:
thumbnail_obj.resource.representation.thumbnail_of = (
link_obj.resource.representation
)
else:
self.log.error(
- "Thumbnail link %r cannot be marked as a thumbnail of %r because it has no Representation, probably due to a missing media type." % (
- link.thumbnail, link
- )
+ "Thumbnail link %r cannot be marked as a thumbnail of %r because it has no Representation, probably due to a missing media type."
+ % (link.thumbnail, link)
)
else:
self.log.error(
- "Thumbnail link %r does not have the thumbnail link relation! Not acceptable as a thumbnail of %r." % (
- link.thumbnail, link
- )
+ "Thumbnail link %r does not have the thumbnail link relation! Not acceptable as a thumbnail of %r."
+ % (link.thumbnail, link)
)
link.thumbnail = None
@@ -1893,9 +2023,11 @@ def _key(classification):
for measurement in self.measurements:
work_requires_full_recalculation = True
identifier.add_measurement(
- data_source, measurement.quantity_measured,
- measurement.value, measurement.weight,
- measurement.taken_at
+ data_source,
+ measurement.quantity_measured,
+ measurement.value,
+ measurement.weight,
+ measurement.taken_at,
)
if not edition.sort_author:
@@ -1907,7 +2039,7 @@ def _key(classification):
self.log.info(
"In the absence of Contributor objects, setting Edition author name to %s/%s",
primary_author.sort_name,
- primary_author.display_name
+ primary_author.display_name,
)
edition.sort_author = primary_author.sort_name
edition.display_author = primary_author.display_name
@@ -1961,7 +2093,11 @@ def _key(classification):
DataSource.BIBLIOTHECA,
DataSource.AXIS_360,
]
- if work_requires_new_presentation_edition and (not data_source.integration_client) and (data_source.name not in METADATA_UPLOAD_BLACKLIST):
+ if (
+ work_requires_new_presentation_edition
+ and (not data_source.integration_client)
+ and (data_source.name not in METADATA_UPLOAD_BLACKLIST)
+ ):
# Create a transient failure CoverageRecord for this edition
# so it will be processed by the MetadataUploadCoverageProvider.
internal_processing = DataSource.lookup(_db, DataSource.INTERNAL_PROCESSING)
@@ -1969,19 +2105,27 @@ def _key(classification):
# If there's already a CoverageRecord, don't change it to transient failure.
# TODO: Once the metadata wrangler can handle it, we'd like to re-sync the
# metadata every time there's a change. For now,
- cr = CoverageRecord.lookup(edition, internal_processing,
- operation=CoverageRecord.METADATA_UPLOAD_OPERATION)
+ cr = CoverageRecord.lookup(
+ edition,
+ internal_processing,
+ operation=CoverageRecord.METADATA_UPLOAD_OPERATION,
+ )
if not cr:
- CoverageRecord.add_for(edition, internal_processing,
- operation=CoverageRecord.METADATA_UPLOAD_OPERATION,
- status=CoverageRecord.TRANSIENT_FAILURE)
+ CoverageRecord.add_for(
+ edition,
+ internal_processing,
+ operation=CoverageRecord.METADATA_UPLOAD_OPERATION,
+ status=CoverageRecord.TRANSIENT_FAILURE,
+ )
# Update the coverage record for this edition and data
# source. We omit the collection information, even if we know
# which collection this is, because we only changed metadata.
CoverageRecord.add_for(
- edition, data_source, timestamp=self.data_source_last_updated,
- collection=None
+ edition,
+ data_source,
+ timestamp=self.data_source_last_updated,
+ collection=None,
)
if work_requires_full_recalculation or work_requires_new_presentation_edition:
@@ -1991,8 +2135,10 @@ def _key(classification):
# Any LicensePool will do here, since all LicensePools for
# a given Identifier have the same Work.
pool = get_one(
- _db, LicensePool, identifier=edition.primary_identifier,
- on_multiple='interchangeable'
+ _db,
+ LicensePool,
+ identifier=edition.primary_identifier,
+ on_multiple="interchangeable",
)
if pool and pool.work:
work = pool.work
@@ -2003,7 +2149,6 @@ def _key(classification):
return edition, work_requires_new_presentation_edition
-
def make_thumbnail(self, data_source, link, link_obj):
"""Make sure a Hyperlink representing an image is connected
to its thumbnail.
@@ -2016,26 +2161,29 @@ def make_thumbnail(self, data_source, link, link_obj):
# The image serves as its own thumbnail. This is a
# hacky way to represent this in the database.
if link_obj.resource.representation:
- link_obj.resource.representation.image_height = Edition.MAX_THUMBNAIL_HEIGHT
+ link_obj.resource.representation.image_height = (
+ Edition.MAX_THUMBNAIL_HEIGHT
+ )
return link_obj
# The thumbnail and image are different. Make sure there's a
# separate link to the thumbnail.
thumbnail_obj, ignore = link_obj.identifier.add_link(
- rel=thumbnail.rel, href=thumbnail.href,
+ rel=thumbnail.rel,
+ href=thumbnail.href,
data_source=data_source,
media_type=thumbnail.media_type,
- content=thumbnail.content
+ content=thumbnail.content,
)
# And make sure the thumbnail knows it's a thumbnail of the main
# image.
if thumbnail_obj.resource.representation:
- thumbnail_obj.resource.representation.thumbnail_of = link_obj.resource.representation
+ thumbnail_obj.resource.representation.thumbnail_of = (
+ link_obj.resource.representation
+ )
return thumbnail_obj
-
- def update_contributions(self, _db, edition, metadata_client=None,
- replace=True):
+ def update_contributions(self, _db, edition, metadata_client=None, replace=True):
contributors_changed = False
old_contributors = []
new_contributors = []
@@ -2055,17 +2203,17 @@ def update_contributions(self, _db, edition, metadata_client=None,
edition.contributions = surviving_contributions
for contributor_data in self.contributors:
- contributor_data.find_sort_name(
- _db, self.identifiers, metadata_client
- )
- if (contributor_data.sort_name
+ contributor_data.find_sort_name(_db, self.identifiers, metadata_client)
+ if (
+ contributor_data.sort_name
or contributor_data.lc
- or contributor_data.viaf):
+ or contributor_data.viaf
+ ):
contributor = edition.add_contributor(
name=contributor_data.sort_name,
roles=contributor_data.roles,
lc=contributor_data.lc,
- viaf=contributor_data.viaf
+ viaf=contributor_data.viaf,
)
new_contributors.append(contributor.id)
if contributor_data.display_name:
@@ -2083,7 +2231,7 @@ def update_contributions(self, _db, edition, metadata_client=None,
else:
self.log.info(
"Not registering %s because no sort name, LC, or VIAF",
- contributor_data.display_name
+ contributor_data.display_name,
)
if sorted(old_contributors) != sorted(new_contributors):
@@ -2091,7 +2239,6 @@ def update_contributions(self, _db, edition, metadata_client=None,
return contributors_changed
-
def filter_recommendations(self, _db):
"""Filters out recommended identifiers that don't exist in the db.
Any IdentifierData objects will be replaced with Identifiers.
@@ -2103,9 +2250,11 @@ def filter_recommendations(self, _db):
self.recommendations = []
for type, identifiers in list(by_type.items()):
- existing_identifiers = _db.query(Identifier).\
- filter(Identifier.type==type).\
- filter(Identifier.identifier.in_(identifiers))
+ existing_identifiers = (
+ _db.query(Identifier)
+ .filter(Identifier.type == type)
+ .filter(Identifier.identifier.in_(identifiers))
+ )
self.recommendations += existing_identifiers.all()
if self.primary_identifier in self.recommendations:
@@ -2126,47 +2275,49 @@ class CSVMetadataImporter(object):
Identifier.AXIS_360_ID,
Identifier.OVERDRIVE_ID,
Identifier.THREEM_ID,
- Identifier.ISBN
+ Identifier.ISBN,
]
DEFAULT_IDENTIFIER_FIELD_NAMES = {
- Identifier.OVERDRIVE_ID : ("overdrive id", 0.75),
- Identifier.THREEM_ID : ("3m id", 0.75),
- Identifier.AXIS_360_ID : ("axis 360 id", 0.75),
- Identifier.ISBN : ("isbn", 0.75),
+ Identifier.OVERDRIVE_ID: ("overdrive id", 0.75),
+ Identifier.THREEM_ID: ("3m id", 0.75),
+ Identifier.AXIS_360_ID: ("axis 360 id", 0.75),
+ Identifier.ISBN: ("isbn", 0.75),
}
- # When classifications are imported from a CSV file, we treat
+ # When classifications are imported from a CSV file, we treat
# them as though they came from a trusted distributor.
DEFAULT_SUBJECT_FIELD_NAMES = {
- 'tags': (Subject.TAG, Classification.TRUSTED_DISTRIBUTOR_WEIGHT),
- 'age' : (Subject.AGE_RANGE, Classification.TRUSTED_DISTRIBUTOR_WEIGHT),
- 'audience' : (Subject.FREEFORM_AUDIENCE,
- Classification.TRUSTED_DISTRIBUTOR_WEIGHT),
+ "tags": (Subject.TAG, Classification.TRUSTED_DISTRIBUTOR_WEIGHT),
+ "age": (Subject.AGE_RANGE, Classification.TRUSTED_DISTRIBUTOR_WEIGHT),
+ "audience": (
+ Subject.FREEFORM_AUDIENCE,
+ Classification.TRUSTED_DISTRIBUTOR_WEIGHT,
+ ),
}
def __init__(
- self,
- data_source_name,
- title_field='title',
- language_field='language',
- default_language='eng',
- medium_field='medium',
- default_medium=Edition.BOOK_MEDIUM,
- series_field='series',
- publisher_field='publisher',
- imprint_field='imprint',
- issued_field='issued',
- published_field=['published', 'publication year'],
- identifier_fields=DEFAULT_IDENTIFIER_FIELD_NAMES,
- subject_fields=DEFAULT_SUBJECT_FIELD_NAMES,
- sort_author_field='file author as',
- display_author_field=['author', 'display author as']
+ self,
+ data_source_name,
+ title_field="title",
+ language_field="language",
+ default_language="eng",
+ medium_field="medium",
+ default_medium=Edition.BOOK_MEDIUM,
+ series_field="series",
+ publisher_field="publisher",
+ imprint_field="imprint",
+ issued_field="issued",
+ published_field=["published", "publication year"],
+ identifier_fields=DEFAULT_IDENTIFIER_FIELD_NAMES,
+ subject_fields=DEFAULT_SUBJECT_FIELD_NAMES,
+ sort_author_field="file author as",
+ display_author_field=["author", "display author as"],
):
self.data_source_name = data_source_name
self.title_field = title_field
- self.language_field=language_field
- self.default_language=default_language
+ self.language_field = language_field
+ self.default_language = default_language
self.medium_field = medium_field
self.default_medium = default_medium
self.series_field = series_field
@@ -2196,8 +2347,8 @@ def to_metadata(self, dictreader):
break
if not found_identifier_field:
raise CSVFormatError(
- "Could not find a primary identifier field. Possibilities: %r. Actualities: %r." %
- (possibilities, fields)
+ "Could not find a primary identifier field. Possibilities: %r. Actualities: %r."
+ % (possibilities, fields)
)
for row in dictreader:
@@ -2236,9 +2387,7 @@ def row_to_metadata(self, row):
if field_name in row:
value = self._field(row, field_name)
if value:
- identifier = IdentifierData(
- identifier_type, value, weight=weight
- )
+ identifier = IdentifierData(identifier_type, value, weight=weight)
identifiers.append(identifier)
if not primary_identifier:
primary_identifier = identifier
@@ -2248,11 +2397,7 @@ def row_to_metadata(self, row):
values = self.list_field(row, field_name)
for value in values:
subjects.append(
- SubjectData(
- type=subject_type,
- identifier=value,
- weight=weight
- )
+ SubjectData(type=subject_type, identifier=value, weight=weight)
)
contributors = []
@@ -2261,8 +2406,9 @@ def row_to_metadata(self, row):
if sort_author or display_author:
contributors.append(
ContributorData(
- sort_name=sort_author, display_name=display_author,
- roles=[Contributor.AUTHOR_ROLE]
+ sort_name=sort_author,
+ display_name=display_author,
+ roles=[Contributor.AUTHOR_ROLE],
)
)
@@ -2279,7 +2425,7 @@ def row_to_metadata(self, row):
primary_identifier=primary_identifier,
identifiers=identifiers,
subjects=subjects,
- contributors=contributors
+ contributors=contributors,
)
metadata.csv_row = row
return metadata
@@ -2348,7 +2494,7 @@ class MARCExtractor(object):
# Common things found in a MARC record after the name of the author
# which we sould like to remove.
END_OF_AUTHOR_NAME_RES = [
- re.compile(",\s+[0-9]+-"), # Birth year
+ re.compile(",\s+[0-9]+-"), # Birth year
re.compile(",\s+active "),
re.compile(",\s+graf,"),
re.compile(",\s+author."),
@@ -2361,7 +2507,7 @@ def name_cleanup(cls, name):
for regex in cls.END_OF_AUTHOR_NAME_RES:
match = regex.search(name)
if match:
- name = name[:match.start()]
+ name = name[: match.start()]
break
name = name_tidy(name)
return name
@@ -2383,15 +2529,15 @@ def parse(cls, file, data_source_name, default_medium_type=None):
for record in reader:
title = record.title()
- if title.endswith(' /'):
- title = title[:-len(' /')]
+ if title.endswith(" /"):
+ title = title[: -len(" /")]
issued_year = cls.parse_year(record.pubyear())
publisher = record.publisher()
- if publisher.endswith(','):
+ if publisher.endswith(","):
publisher = publisher[:-1]
links = []
- summary = record.notes()[0]['a']
+ summary = record.notes()[0]["a"]
if summary:
summary_link = LinkData(
@@ -2401,22 +2547,23 @@ def parse(cls, file, data_source_name, default_medium_type=None):
)
links.append(summary_link)
- isbn = record['020']['a'].split(" ")[0]
- primary_identifier = IdentifierData(
- Identifier.ISBN, isbn
- )
+ isbn = record["020"]["a"].split(" ")[0]
+ primary_identifier = IdentifierData(Identifier.ISBN, isbn)
- subjects = [SubjectData(
- Classifier.FAST,
- subject['a'],
- ) for subject in record.subjects()]
+ subjects = [
+ SubjectData(
+ Classifier.FAST,
+ subject["a"],
+ )
+ for subject in record.subjects()
+ ]
author = record.author()
if author:
author = cls.name_cleanup(author)
author_names = [author]
else:
- author_names = ['Anonymous']
+ author_names = ["Anonymous"]
contributors = [
ContributorData(
sort_name=author,
@@ -2425,16 +2572,18 @@ def parse(cls, file, data_source_name, default_medium_type=None):
for author in author_names
]
- metadata_records.append(Metadata(
- data_source=data_source_name,
- title=title,
- language='eng',
- medium=Edition.BOOK_MEDIUM,
- publisher=publisher,
- issued=issued_year,
- primary_identifier=primary_identifier,
- subjects=subjects,
- contributors=contributors,
- links=links
- ))
+ metadata_records.append(
+ Metadata(
+ data_source=data_source_name,
+ title=title,
+ language="eng",
+ medium=Edition.BOOK_MEDIUM,
+ publisher=publisher,
+ issued=issued_year,
+ primary_identifier=primary_identifier,
+ subjects=subjects,
+ contributors=contributors,
+ links=links,
+ )
+ )
return metadata_records
diff --git a/migration/20150813_set_open_access_download_url.py b/migration/20150813_set_open_access_download_url.py
index 9499375e2..391a82428 100644
--- a/migration/20150813_set_open_access_download_url.py
+++ b/migration/20150813_set_open_access_download_url.py
@@ -2,34 +2,34 @@
"""Set Edition.open_access_download_url for all Project Gutenberg books."""
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.monitor import EditionSweepMonitor
from core.model import (
- production_session,
DataSource,
+ DeliveryMechanism,
Edition,
Representation,
- DeliveryMechanism,
+ production_session,
)
+from core.monitor import EditionSweepMonitor
from core.scripts import RunMonitorScript
+set_delivery_mechanism = len(sys.argv) > 1 and sys.argv[1] == "delivery"
-set_delivery_mechanism = len(sys.argv) > 1 and sys.argv[1] == 'delivery'
class OpenAccessDownloadSetMonitor(EditionSweepMonitor):
"""Set the open-access link f."""
def __init__(self, _db, interval_seconds=None):
super(OpenAccessDownloadSetMonitor, self).__init__(
- _db, "Open Access Download link set", interval_seconds,
- batch_size=100
+ _db, "Open Access Download link set", interval_seconds, batch_size=100
)
def edition_query(self):
gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG)
- return self._db.query(Edition).filter(Edition.data_source==gutenberg)
+ return self._db.query(Edition).filter(Edition.data_source == gutenberg)
def process_edition(self, edition):
edition.set_open_access_link()
@@ -38,11 +38,11 @@ def process_edition(self, edition):
if link:
print(edition.id, edition.title, link.url)
edition.license_pool.set_delivery_mechanism(
- Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM,
- link
+ Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM, link
)
else:
print(edition.id, edition.title, "[no link]")
return True
+
RunMonitorScript(OpenAccessDownloadSetMonitor).run()
diff --git a/migration/20150826-4-recalculate_age_range_for_childrens_books.py b/migration/20150826-4-recalculate_age_range_for_childrens_books.py
index 1ca584d8b..79727a8ec 100644
--- a/migration/20150826-4-recalculate_age_range_for_childrens_books.py
+++ b/migration/20150826-4-recalculate_age_range_for_childrens_books.py
@@ -2,26 +2,26 @@
"""Recalculate the age range for all subjects whose audience is Children or Young Adult."""
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.monitor import SubjectSweepMonitor
from core.classifier import Classifier
-from core.model import (
- production_session,
- DataSource,
- Edition,
- Subject,
-)
+from core.model import DataSource, Edition, Subject, production_session
+from core.monitor import SubjectSweepMonitor
from core.scripts import RunMonitorScript
+
class RecalculateAgeRangeMonitor(SubjectSweepMonitor):
"""Recalculate the age range for every young adult or children's subject."""
def __init__(self, _db, interval_seconds=None):
super(RecalculateAgeRangeMonitor, self).__init__(
- _db, "20150825 migration - Recalculate age range for children's books",
- interval_seconds, batch_size=1000)
+ _db,
+ "20150825 migration - Recalculate age range for children's books",
+ interval_seconds,
+ batch_size=1000,
+ )
def subject_query(self):
audiences = [Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_CHILDREN]
@@ -33,4 +33,5 @@ def process_identifier(self, subject):
if subject.target_age != old_target_age and subject.target_age.lower != None:
print(("%r: %r->%r" % (subject, old_target_age, subject.target_age)))
+
RunMonitorScript(RecalculateAgeRangeMonitor).run()
diff --git a/migration/20150826-5-recalculate_work_presentation_for_childrens_books.py b/migration/20150826-5-recalculate_work_presentation_for_childrens_books.py
index 28fdb6336..b0a028b1b 100644
--- a/migration/20150826-5-recalculate_work_presentation_for_childrens_books.py
+++ b/migration/20150826-5-recalculate_work_presentation_for_childrens_books.py
@@ -2,45 +2,52 @@
"""Recalculate the age range for all subjects whose audience is Children or Young Adult."""
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.monitor import WorkSweepMonitor
from core.classifier import Classifier
from core.model import (
- production_session,
DataSource,
Edition,
+ Identifier,
Subject,
Work,
- Identifier,
+ production_session,
)
+from core.monitor import WorkSweepMonitor
from core.scripts import RunMonitorScript
from psycopg2.extras import NumericRange
+
class RecalculateAgeRangeMonitor(WorkSweepMonitor):
"""Recalculate the age range for every young adult or children's book."""
def __init__(self, _db, interval_seconds=None):
super(RecalculateAgeRangeMonitor, self).__init__(
- _db, "20150825 migration - Recalculate age range for children's books (Works)",
- interval_seconds, batch_size=10)
+ _db,
+ "20150825 migration - Recalculate age range for children's books (Works)",
+ interval_seconds,
+ batch_size=10,
+ )
def work_query(self):
audiences = [Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_CHILDREN]
return self._db.query(Work).filter(Work.audience.in_(audiences))
def process_work(self, work):
- primary_identifier_ids = [
- x.primary_identifier.id for x in work.editions]
+ primary_identifier_ids = [x.primary_identifier.id for x in work.editions]
data = Identifier.recursively_equivalent_identifier_ids(
- self._db, primary_identifier_ids, 5, threshold=0.5)
+ self._db, primary_identifier_ids, 5, threshold=0.5
+ )
flattened_data = Identifier.flatten_identifier_ids(data)
workgenres, work.fiction, work.audience, target_age = work.assign_genres(
- flattened_data)
+ flattened_data
+ )
old_target_age = work.target_age
work.target_age = NumericRange(*target_age)
if work.target_age != old_target_age and work.target_age.lower != None:
print("%r: %r->%r" % (work.title, old_target_age, work.target_age))
+
RunMonitorScript(RecalculateAgeRangeMonitor).run()
diff --git a/migration/20150828-recalculate_age_range_for_childrens_books.py b/migration/20150828-recalculate_age_range_for_childrens_books.py
index 6653b1c9c..df1dcedd4 100644
--- a/migration/20150828-recalculate_age_range_for_childrens_books.py
+++ b/migration/20150828-recalculate_age_range_for_childrens_books.py
@@ -2,26 +2,26 @@
"""Recalculate the age range for all subjects whose audience is Children or Young Adult."""
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.monitor import SubjectSweepMonitor
from core.classifier import Classifier
-from core.model import (
- production_session,
- DataSource,
- Edition,
- Subject,
-)
+from core.model import DataSource, Edition, Subject, production_session
+from core.monitor import SubjectSweepMonitor
from core.scripts import RunMonitorScript
+
class RecalculateAgeRangeMonitor(SubjectSweepMonitor):
"""Recalculate the age range for every young adult or children's subject."""
def __init__(self, _db, interval_seconds=None):
super(RecalculateAgeRangeMonitor, self).__init__(
- _db, "20150825 migration - Recalculate age range for children's books",
- interval_seconds, batch_size=1000)
+ _db,
+ "20150825 migration - Recalculate age range for children's books",
+ interval_seconds,
+ batch_size=1000,
+ )
def subject_query(self):
audiences = [Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_CHILDREN]
@@ -33,4 +33,5 @@ def process_identifier(self, subject):
if subject.target_age != old_target_age and subject.target_age.lower != None:
print("%r: %r->%r" % (subject, old_target_age, subject.target_age))
+
RunMonitorScript(RecalculateAgeRangeMonitor).run()
diff --git a/migration/20151002-fix-audiobooks-mislabeled-as-print-books.py b/migration/20151002-fix-audiobooks-mislabeled-as-print-books.py
index b69fde40e..495131e71 100644
--- a/migration/20151002-fix-audiobooks-mislabeled-as-print-books.py
+++ b/migration/20151002-fix-audiobooks-mislabeled-as-print-books.py
@@ -3,44 +3,53 @@
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..")
sys.path.append(os.path.abspath(package_dir))
-from monitor import IdentifierSweepMonitor
+from threem import ThreeMAPI
+
from model import (
+ DeliveryMechanism,
+ Edition,
Identifier,
LicensePool,
- DeliveryMechanism,
LicensePoolDeliveryMechanism,
- Edition,
)
-from scripts import RunMonitorScript
+from monitor import IdentifierSweepMonitor
from overdrive import OverdriveAPI, OverdriveRepresentationExtractor
-from threem import ThreeMAPI
+from scripts import RunMonitorScript
-class SetDeliveryMechanismMonitor(IdentifierSweepMonitor):
+class SetDeliveryMechanismMonitor(IdentifierSweepMonitor):
def __init__(self, _db, interval_seconds=None):
super(SetDeliveryMechanismMonitor, self).__init__(
- _db, "20151002 migration - Correct medium of mislabeled audiobooks",
- interval_seconds, batch_size=100)
+ _db,
+ "20151002 migration - Correct medium of mislabeled audiobooks",
+ interval_seconds,
+ batch_size=100,
+ )
self.overdrive = OverdriveAPI(_db)
self.threem = ThreeMAPI(_db)
- types = [Identifier.THREEM_ID, Identifier.OVERDRIVE_ID,
- Identifier.AXIS_360_ID]
+ types = [Identifier.THREEM_ID, Identifier.OVERDRIVE_ID, Identifier.AXIS_360_ID]
- content_types = ["application/epub+zip", "application/pdf",
- "Kindle via Amazon", "Streaming Text"]
+ content_types = [
+ "application/epub+zip",
+ "application/pdf",
+ "Kindle via Amazon",
+ "Streaming Text",
+ ]
def identifier_query(self):
- qu = self._db.query(Identifier).join(
- Identifier.licensed_through).join(
- LicensePool.delivery_mechanisms).join(
- LicensePoolDeliveryMechanism.delivery_mechanism).filter(
- Identifier.type.in_(self.types)).filter(
- ~DeliveryMechanism.content_type.in_(self.content_types)
- )
+ qu = (
+ self._db.query(Identifier)
+ .join(Identifier.licensed_through)
+ .join(LicensePool.delivery_mechanisms)
+ .join(LicensePoolDeliveryMechanism.delivery_mechanism)
+ .filter(Identifier.type.in_(self.types))
+ .filter(~DeliveryMechanism.content_type.in_(self.content_types))
+ )
return qu
def process_identifier(self, identifier):
@@ -51,12 +60,12 @@ def process_identifier(self, identifier):
correct_medium = lpdm.delivery_mechanism.implicit_medium
if correct_medium:
break
- if not correct_medium and identifier.type==Identifier.OVERDRIVE_ID:
+ if not correct_medium and identifier.type == Identifier.OVERDRIVE_ID:
content = self.overdrive.metadata_lookup(identifier)
metadata = OverdriveRepresentationExtractor.book_info_to_metadata(content)
correct_medium = metadata.medium
- if not correct_medium and identifier.type==Identifier.THREEM_ID:
+ if not correct_medium and identifier.type == Identifier.THREEM_ID:
metadata = self.threem.bibliographic_lookup(identifier)
correct_medium = metadata.medium
@@ -64,7 +73,13 @@ def process_identifier(self, identifier):
set_trace()
if lp.edition.medium != correct_medium:
- print(("%s is actually %s, not %s" % (lp.edition.title, correct_medium, lp.edition.medium)))
+ print(
+ (
+ "%s is actually %s, not %s"
+ % (lp.edition.title, correct_medium, lp.edition.medium)
+ )
+ )
lp.edition.medium = correct_medium or Edition.BOOK_MEDIUM
+
RunMonitorScript(SetDeliveryMechanismMonitor).run()
diff --git a/migration/20151019-recalculate-fiction-status-for-overdrive-subjects.py b/migration/20151019-recalculate-fiction-status-for-overdrive-subjects.py
index 57c194afd..9859a59af 100644
--- a/migration/20151019-recalculate-fiction-status-for-overdrive-subjects.py
+++ b/migration/20151019-recalculate-fiction-status-for-overdrive-subjects.py
@@ -2,29 +2,29 @@
"""Recalculate the age range for all subjects whose audience is Children or Young Adult."""
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.monitor import SubjectSweepMonitor
from core.classifier import Classifier
-from core.model import (
- production_session,
- DataSource,
- Edition,
- Subject,
-)
+from core.model import DataSource, Edition, Subject, production_session
+from core.monitor import SubjectSweepMonitor
from core.scripts import RunMonitorScript
+
class RecalculateFictionStatusMonitor(SubjectSweepMonitor):
"""Recalculate the age range for every young adult or children's subject."""
def __init__(self, _db, interval_seconds=None):
super(RecalculateFictionStatusMonitor, self).__init__(
- _db, "20150825 migration - Recalculate age range for children's books",
- interval_seconds, batch_size=1000)
+ _db,
+ "20150825 migration - Recalculate age range for children's books",
+ interval_seconds,
+ batch_size=1000,
+ )
def subject_query(self):
- return self._db.query(Subject).filter(Subject.type==Subject.OVERDRIVE)
+ return self._db.query(Subject).filter(Subject.type == Subject.OVERDRIVE)
def process_identifier(self, subject):
old_fiction = subject.fiction
@@ -32,4 +32,5 @@ def process_identifier(self, subject):
subject.assign_to_genre()
print("%s %s %s" % (subject.identifier, subject.fiction, subject.audience))
+
RunMonitorScript(RecalculateFictionStatusMonitor).run()
diff --git a/migration/20151020-fix-dewey-304.py b/migration/20151020-fix-dewey-304.py
index 83814eba1..ac9f4441d 100644
--- a/migration/20151020-fix-dewey-304.py
+++ b/migration/20151020-fix-dewey-304.py
@@ -5,21 +5,22 @@
Correct all 304.* and 305.* subjects, and reclassify every work classified under those IDs.
"""
+import logging
import os
import sys
-import logging
from pdb import set_trace
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
from core.model import (
- production_session,
Edition,
+ Genre,
Identifier,
+ Subject,
Work,
- Genre,
WorkGenre,
- Subject,
+ production_session,
)
_db = production_session()
@@ -27,23 +28,35 @@
def reclassify(ddc):
log = logging.getLogger("Migration script - Fix Dewey %s" % ddc)
- for subject in _db.query(Subject).filter(Subject.type==Subject.DDC).filter(Subject.identifier.like(ddc + "%")):
+ for subject in (
+ _db.query(Subject)
+ .filter(Subject.type == Subject.DDC)
+ .filter(Subject.identifier.like(ddc + "%"))
+ ):
log.info("Considering subject %s/%s", subject.identifier, subject.name)
subject.assign_to_genre()
for cl in subject.classifications:
ids = cl.identifier.equivalent_identifier_ids()
log.info("Looking for editions associated with %d ids.", len(ids))
- editions = _db.query(Edition).filter(Edition.primary_identifier_id.in_(ids)).all()
+ editions = (
+ _db.query(Edition).filter(Edition.primary_identifier_id.in_(ids)).all()
+ )
for edition in editions:
if edition.work:
old_genres = set(edition.work.genres)
edition.work.calculate_presentation()
if old_genres != set(edition.work.genres):
- log.info("%s GENRE CHANGE: %r -> %r", edition.title, old_genres, edition.work.genres)
+ log.info(
+ "%s GENRE CHANGE: %r -> %r",
+ edition.title,
+ old_genres,
+ edition.work.genres,
+ )
else:
edition.calculate_presentation()
_db.commit()
+
reclassify("205")
reclassify("304")
reclassify("305")
diff --git a/migration/20151021-set-license-pool-for-list-membership.py b/migration/20151021-set-license-pool-for-list-membership.py
index 508cc8dbc..6848f5f16 100644
--- a/migration/20151021-set-license-pool-for-list-membership.py
+++ b/migration/20151021-set-license-pool-for-list-membership.py
@@ -1,23 +1,21 @@
#!/usr/bin/env python
"""Make sure every CustomListEntry has a LicensePool set.
"""
+import logging
import os
import sys
-import logging
from pdb import set_trace
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.model import (
- production_session,
- CustomListEntry,
-)
+from core.model import CustomListEntry, production_session
_db = production_session()
-qu = _db.query(CustomListEntry).filter(CustomListEntry.license_pool==None)
+qu = _db.query(CustomListEntry).filter(CustomListEntry.license_pool == None)
print("Fixing %d custom list entries with no licensepool." % qu.count())
for cle in qu:
diff --git a/migration/20160706-repair-duplicate-and-ambiguous-contributor-records.py b/migration/20160706-repair-duplicate-and-ambiguous-contributor-records.py
index 0cb6ef8fe..8222f8084 100644
--- a/migration/20160706-repair-duplicate-and-ambiguous-contributor-records.py
+++ b/migration/20160706-repair-duplicate-and-ambiguous-contributor-records.py
@@ -3,34 +3,27 @@
and 'Author', and Editions that list the same contributor in an
'Unknown' role plus some more specific role.
"""
+import logging
import os
import sys
-import logging
from pdb import set_trace
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
import time
-from sqlalchemy.orm import (
- aliased,
-)
-from sqlalchemy.sql.expression import (
- and_,
- or_
-)
-from core.model import (
- Contribution,
- Contributor,
- Edition,
- production_session
-)
+from core.model import Contribution, Contributor, Edition, production_session
+from sqlalchemy.orm import aliased
+from sqlalchemy.sql.expression import and_, or_
def dedupe(edition):
print("Deduping edition %s (%s)" % (edition.id, edition.title))
- primary_author = [x for x in edition.contributions if x.role==Contributor.PRIMARY_AUTHOR_ROLE]
+ primary_author = [
+ x for x in edition.contributions if x.role == Contributor.PRIMARY_AUTHOR_ROLE
+ ]
seen = set()
contributors_with_roles = set()
unresolved_mysteries = {}
@@ -54,9 +47,12 @@ def dedupe(edition):
_db.delete(contribution)
continue
seen.add(key)
- if role == 'Unknown':
+ if role == "Unknown":
if contributor in contributors_with_roles:
- print(" Found unknown role for %s, but mystery already resolved." % contributor.name)
+ print(
+ " Found unknown role for %s, but mystery already resolved."
+ % contributor.name
+ )
_db.delete(contribution)
else:
print(" The role of %s is a mystery." % contributor.name)
@@ -71,6 +67,7 @@ def dedupe(edition):
del unresolved_mysteries[contributor]
_db.delete(now_resolved)
+
_db = production_session()
contribution2 = aliased(Contribution)
@@ -79,21 +76,24 @@ def dedupe(edition):
# and some other role. Also find editions where one Contributor is listed
# twice in author roles.
unknown_role_or_duplicate_author_role = or_(
- and_(Contribution.role==Contributor.UNKNOWN_ROLE,
- contribution2.role != Contributor.UNKNOWN_ROLE),
+ and_(
+ Contribution.role == Contributor.UNKNOWN_ROLE,
+ contribution2.role != Contributor.UNKNOWN_ROLE,
+ ),
and_(
Contribution.role.in_(Contributor.AUTHOR_ROLES),
contribution2.role.in_(Contributor.AUTHOR_ROLES),
- )
+ ),
)
-qu = _db.query(Edition).join(Edition.contributions).join(
- contribution2, contribution2.edition_id==Edition.id).filter(
- contribution2.id != Contribution.id).filter(
- contribution2.contributor_id==Contribution.contributor_id
- ).filter(
- unknown_role_or_duplicate_author_role
- )
+qu = (
+ _db.query(Edition)
+ .join(Edition.contributions)
+ .join(contribution2, contribution2.edition_id == Edition.id)
+ .filter(contribution2.id != Contribution.id)
+ .filter(contribution2.contributor_id == Contribution.contributor_id)
+ .filter(unknown_role_or_duplicate_author_role)
+)
print("Fixing %s Editions." % qu.count())
qu = qu.limit(1000)
@@ -102,9 +102,9 @@ def dedupe(edition):
a = time.time()
results = qu.all()
for ed in qu:
- #for contribution in ed.contributions:
+ # for contribution in ed.contributions:
# print contribution.contributor, contribution.role
dedupe(ed)
_db.commit()
b = time.time()
- print("Batch processed in %.2f sec" % (b-a))
+ print("Batch processed in %.2f sec" % (b - a))
diff --git a/migration/20160719-fix-incorrectly-encoded-work-descriptions.py b/migration/20160719-fix-incorrectly-encoded-work-descriptions.py
index 1dfec72c5..14710da11 100644
--- a/migration/20160719-fix-incorrectly-encoded-work-descriptions.py
+++ b/migration/20160719-fix-incorrectly-encoded-work-descriptions.py
@@ -2,10 +2,11 @@
"""Fix work descriptions that were originally UTF-8 but were incorrectly
encoded as Windows-1252.
"""
+import logging
import os
import sys
-import logging
from pdb import set_trace
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
@@ -13,10 +14,7 @@
import time
from core.external_search import ExternalSearchIndex
-from core.model import (
- production_session,
- Work,
-)
+from core.model import Work, production_session
_db = production_session()
client = ExternalSearchIndex()
diff --git a/migration/20160721-kick-out-licensepools-without-pwid.py b/migration/20160721-kick-out-licensepools-without-pwid.py
index f44b4802f..c3b366cb4 100644
--- a/migration/20160721-kick-out-licensepools-without-pwid.py
+++ b/migration/20160721-kick-out-licensepools-without-pwid.py
@@ -5,20 +5,16 @@
their Works.
"""
+import logging
import os
import sys
-import logging
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.model import (
- production_session,
- Edition,
- LicensePool,
- Work,
-)
+from core.model import Edition, LicensePool, Work, production_session
_db = production_session()
@@ -33,27 +29,34 @@ def fix(_db, description, qu):
print("Committing")
_db.commit()
-no_presentation_edition = _db.query(LicensePool).outerjoin(
- LicensePool.presentation_edition).filter(Edition.id==None).filter(
- LicensePool.work_id != None
- )
-no_permanent_work_id = _db.query(LicensePool).join(
- LicensePool.presentation_edition).filter(
- Edition.permanent_work_id==None
- ).filter(
- LicensePool.work_id != None
- )
+no_presentation_edition = (
+ _db.query(LicensePool)
+ .outerjoin(LicensePool.presentation_edition)
+ .filter(Edition.id == None)
+ .filter(LicensePool.work_id != None)
+)
-no_title = _db.query(LicensePool).join(
- LicensePool.presentation_edition).filter(
- Edition.title==None
- ).filter(
- LicensePool.work_id != None
- )
+no_permanent_work_id = (
+ _db.query(LicensePool)
+ .join(LicensePool.presentation_edition)
+ .filter(Edition.permanent_work_id == None)
+ .filter(LicensePool.work_id != None)
+)
+
+no_title = (
+ _db.query(LicensePool)
+ .join(LicensePool.presentation_edition)
+ .filter(Edition.title == None)
+ .filter(LicensePool.work_id != None)
+)
-licensepools_in_same_work_as_another_licensepool_with_different_pwid = _db.execute("select lp1.id from licensepools lp1 join works w on lp1.work_id=w.id join editions e1 on lp1.presentation_edition_id=e1.id join licensepools lp2 on lp2.work_id=w.id join editions e2 on e2.id=lp2.presentation_edition_id and e2.permanent_work_id != e1.permanent_work_id;")
-ids = [x[0] for x in licensepools_in_same_work_as_another_licensepool_with_different_pwid]
+licensepools_in_same_work_as_another_licensepool_with_different_pwid = _db.execute(
+ "select lp1.id from licensepools lp1 join works w on lp1.work_id=w.id join editions e1 on lp1.presentation_edition_id=e1.id join licensepools lp2 on lp2.work_id=w.id join editions e2 on e2.id=lp2.presentation_edition_id and e2.permanent_work_id != e1.permanent_work_id;"
+)
+ids = [
+ x[0] for x in licensepools_in_same_work_as_another_licensepool_with_different_pwid
+]
in_same_work = _db.query(LicensePool).filter(LicensePool.id.in_(ids))
fix(_db, "Pools in the same work as another pool with a different pwid", in_same_work)
diff --git a/migration/20160728-1-add-rights-status-names.py b/migration/20160728-1-add-rights-status-names.py
index 7cdc52a97..10ceed4dc 100755
--- a/migration/20160728-1-add-rights-status-names.py
+++ b/migration/20160728-1-add-rights-status-names.py
@@ -1,18 +1,16 @@
#!/usr/bin/env python
"""Add names to rightsstatus table."""
+import logging
import os
import sys
-import logging
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..", "..", "..")
sys.path.append(os.path.abspath(package_dir))
-from core.model import (
- production_session,
- RightsStatus,
-)
+from core.model import RightsStatus, production_session
_db = production_session()
diff --git a/migration/20170713-18-move-third-party-config-to-external-integrations.py b/migration/20170713-18-move-third-party-config-to-external-integrations.py
index dd8e8e969..147f5bd46 100755
--- a/migration/20170713-18-move-third-party-config-to-external-integrations.py
+++ b/migration/20170713-18-move-third-party-config-to-external-integrations.py
@@ -2,10 +2,9 @@
"""Move integration details from the Configuration file into the
database as ExternalIntegrations
"""
+import logging
import os
import sys
-import logging
-
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..")
@@ -13,24 +12,23 @@
from config import Configuration
from external_search import ExternalSearchIndex
-from model import (
- ExternalIntegration as EI,
- production_session,
-)
-
+from model import ExternalIntegration as EI
+from model import production_session
from s3 import S3Uploader
log = logging.getLogger(name="Core configuration import")
+
def log_import(integration_or_setting):
log.info("CREATED: %r" % integration_or_setting)
+
try:
Configuration.load()
_db = production_session()
# Import CDN configuration.
- cdn_conf = Configuration.integration('CDN')
+ cdn_conf = Configuration.integration("CDN")
if cdn_conf and isinstance(cdn_conf, dict):
for k, v in list(cdn_conf.items()):
@@ -41,9 +39,9 @@ def log_import(integration_or_setting):
log_import(cdn)
# Import Elasticsearch configuration.
- elasticsearch_conf = Configuration.integration('Elasticsearch')
+ elasticsearch_conf = Configuration.integration("Elasticsearch")
if elasticsearch_conf:
- url = elasticsearch_conf.get('url')
+ url = elasticsearch_conf.get("url")
works_index = elasticsearch_conf.get(ExternalSearchIndex.WORKS_INDEX_KEY)
integration = EI(protocol=EI.ELASTICSEARCH, goal=EI.SEARCH_GOAL)
@@ -52,19 +50,17 @@ def log_import(integration_or_setting):
if url:
integration.url = str(url)
if works_index:
- integration.set_setting(
- ExternalSearchIndex.WORKS_INDEX_KEY, works_index
- )
+ integration.set_setting(ExternalSearchIndex.WORKS_INDEX_KEY, works_index)
log_import(integration)
# Import S3 configuration.
- s3_conf = Configuration.integration('S3')
+ s3_conf = Configuration.integration("S3")
if s3_conf:
- username = s3_conf.get('access_key')
- password = s3_conf.get('secret_key')
- del s3_conf['access_key']
- del s3_conf['secret_key']
+ username = s3_conf.get("access_key")
+ password = s3_conf.get("secret_key")
+ del s3_conf["access_key"]
+ del s3_conf["secret_key"]
integration = EI(protocol=EI.S3, goal=EI.STORAGE_GOAL)
_db.add(integration)
diff --git a/migration/20170908-change-metadata-wrangler-settings.py b/migration/20170908-change-metadata-wrangler-settings.py
index 5482ab181..3e97f5b88 100755
--- a/migration/20170908-change-metadata-wrangler-settings.py
+++ b/migration/20170908-change-metadata-wrangler-settings.py
@@ -1,18 +1,15 @@
#!/usr/bin/env python
"""Delete outdated ConfigurationSettings for the metadata wrangler."""
+import logging
import os
import sys
-import logging
-
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..")
sys.path.append(os.path.abspath(package_dir))
-from model import (
- production_session,
- ExternalIntegration as EI,
-)
+from model import ExternalIntegration as EI
+from model import production_session
_db = production_session()
try:
@@ -20,10 +17,10 @@
if integration:
for setting in integration.settings:
- if setting.key == 'username':
+ if setting.key == "username":
# A username (or client_id) is no longer required.
_db.delete(setting)
- if setting.key == 'password':
+ if setting.key == "password":
# The password (previously client_secret) must be reset to
# register for a shared_secret.
setting.value = None
diff --git a/migration/20170926-migrate-log-configuration.py b/migration/20170926-migrate-log-configuration.py
index 035e78333..dc3d496e9 100755
--- a/migration/20170926-migrate-log-configuration.py
+++ b/migration/20170926-migrate-log-configuration.py
@@ -3,30 +3,27 @@
database as ExternalIntegrations
"""
+import logging
import os
import sys
-import logging
-
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..")
sys.path.append(os.path.abspath(package_dir))
from config import Configuration
-from model import (
- ExternalIntegration as EI,
- production_session,
-)
+from model import ExternalIntegration as EI
+from model import production_session
+
_db = production_session()
log = logging.getLogger(name="Log configuration import")
-loggly_conf = Configuration.integration('loggly')
+loggly_conf = Configuration.integration("loggly")
if loggly_conf:
integration = EI(goal=EI.LOGGING_GOAL, protocol=EI.LOGGLY)
_db.add(integration)
integration.url = loggly_conf.get(
- 'url', 'https://logs-01.loggly.com/inputs/%(token)s/tag/python/'
+ "url", "https://logs-01.loggly.com/inputs/%(token)s/tag/python/"
)
- integration.password = loggly_conf.get('token')
+ integration.password = loggly_conf.get("token")
_db.commit()
-
diff --git a/migration/20171201-rebuild-search-index.py b/migration/20171201-rebuild-search-index.py
index 3646796b8..4b984c048 100755
--- a/migration/20171201-rebuild-search-index.py
+++ b/migration/20171201-rebuild-search-index.py
@@ -5,8 +5,10 @@
"""
import os
import sys
+
bin_dir = os.path.split(__file__)[0]
package_dir = os.path.join(bin_dir, "..")
sys.path.append(os.path.abspath(package_dir))
from scripts import UpdateSearchIndexScript
+
UpdateSearchIndexScript().run()
diff --git a/mirror.py b/mirror.py
index 65b9e2453..a5a333465 100644
--- a/mirror.py
+++ b/mirror.py
@@ -1,15 +1,16 @@
-from abc import abstractmethod, ABCMeta
+from abc import ABCMeta, abstractmethod
from urllib.parse import urlsplit
from .config import CannotLoadConfiguration
from .util.datetime_helpers import utc_now
+
class MirrorUploader(metaclass=ABCMeta):
"""Handles the job of uploading a representation's content to
a mirror that we control.
"""
- STORAGE_GOAL = 'storage'
+ STORAGE_GOAL = "storage"
# Depending on the .protocol of an ExternalIntegration with
# .goal=STORAGE, a different subclass might be initialized by
@@ -38,9 +39,10 @@ def mirror(cls, _db, storage_name=None, integration=None):
def integration_by_name(cls, _db, storage_name=None):
"""Find the ExternalIntegration for the mirror by storage name."""
from .model import ExternalIntegration
+
qu = _db.query(ExternalIntegration).filter(
- ExternalIntegration.goal==cls.STORAGE_GOAL,
- ExternalIntegration.name==storage_name
+ ExternalIntegration.goal == cls.STORAGE_GOAL,
+ ExternalIntegration.name == storage_name,
)
integrations = qu.all()
if not integrations:
@@ -62,10 +64,14 @@ def for_collection(cls, collection, purpose):
mirror integration.
"""
from .model import ExternalIntegration
+
try:
from .model import Session
+
_db = Session.object_session(collection)
- integration = ExternalIntegration.for_collection_and_purpose(_db, collection, purpose)
+ integration = ExternalIntegration.for_collection_and_purpose(
+ _db, collection, purpose
+ )
except CannotLoadConfiguration as e:
return None
return cls.implementation(integration)
@@ -96,8 +102,8 @@ def __init__(self, integration, host):
# This collection's 'mirror integration' isn't intended to
# be used to mirror anything.
raise CannotLoadConfiguration(
- "Cannot create an MirrorUploader from an integration with goal=%s" %
- integration.goal
+ "Cannot create an MirrorUploader from an integration with goal=%s"
+ % integration.goal
)
self._host = host
@@ -132,10 +138,16 @@ def mirror_batch(self, representations):
"""Mirror a batch of Representations at once."""
for representation in representations:
- self.mirror_one(representation, '')
-
- def book_url(self, identifier, extension='.epub', open_access=True,
- data_source=None, title=None):
+ self.mirror_one(representation, "")
+
+ def book_url(
+ self,
+ identifier,
+ extension=".epub",
+ open_access=True,
+ data_source=None,
+ title=None,
+ ):
"""The URL of the hosted EPUB file for the given identifier.
This does not upload anything to the URL, but it is expected
@@ -144,8 +156,7 @@ def book_url(self, identifier, extension='.epub', open_access=True,
"""
raise NotImplementedError()
- def cover_image_url(self, data_source, identifier, filename=None,
- scaled_size=None):
+ def cover_image_url(self, data_source, identifier, filename=None, scaled_size=None):
"""The URL of the hosted cover image for the given identifier.
This does not upload anything to the URL, but it is expected
diff --git a/mock_analytics_provider.py b/mock_analytics_provider.py
index 8270c4f0a..c74f1dff3 100644
--- a/mock_analytics_provider.py
+++ b/mock_analytics_provider.py
@@ -12,4 +12,5 @@ def collect_event(self, library, lp, event_type, time=None, **kwargs):
self.event_type = event_type
self.time = time
+
Provider = MockAnalyticsProvider
diff --git a/model/__init__.py b/model/__init__.py
index c74ce0b5d..8781e936b 100644
--- a/model/__init__.py
+++ b/model/__init__.py
@@ -3,40 +3,21 @@
import logging
import os
import warnings
+
from psycopg2.extensions import adapt as sqlescape
from psycopg2.extras import NumericRange
-from sqlalchemy import (
- Column,
- create_engine,
- ForeignKey,
- Integer,
- Table,
- text,
-)
-from sqlalchemy.exc import (
- IntegrityError,
- SAWarning,
-)
+from sqlalchemy import Column, ForeignKey, Integer, Table, create_engine, text
+from sqlalchemy.exc import IntegrityError, SAWarning
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import (
- relationship,
- sessionmaker,
-)
-from sqlalchemy.orm.exc import (
- NoResultFound,
- MultipleResultsFound,
-)
-from sqlalchemy.sql import (
- compiler,
- select,
-)
-from sqlalchemy.sql.expression import (
- literal_column,
- table,
-)
+from sqlalchemy.orm import relationship, sessionmaker
+from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
+from sqlalchemy.sql import compiler, select
+from sqlalchemy.sql.expression import literal_column, table
Base = declarative_base()
+from .. import classifier
+from ..util.datetime_helpers import utc_now
from .constants import (
DataSourceConstants,
EditionConstants,
@@ -44,16 +25,15 @@
LinkRelations,
MediaTypes,
)
-from .. import classifier
-from ..util.datetime_helpers import utc_now
+
def flush(db):
"""Flush the database connection unless it's known to already be flushing."""
is_flushing = False
- if hasattr(db, '_flushing'):
+ if hasattr(db, "_flushing"):
# This is a regular database session.
is_flushing = db._flushing
- elif hasattr(db, 'registry'):
+ elif hasattr(db, "registry"):
# This is a flask_scoped_session scoped session.
is_flushing = db.registry()._flushing
else:
@@ -61,16 +41,16 @@ def flush(db):
if not is_flushing:
db.flush()
-def create(db, model, create_method='',
- create_method_kwargs=None,
- **kwargs):
+
+def create(db, model, create_method="", create_method_kwargs=None, **kwargs):
kwargs.update(create_method_kwargs or {})
created = getattr(model, create_method, model)(**kwargs)
db.add(created)
flush(db)
return created, True
-def get_one(db, model, on_multiple='error', constraint=None, **kwargs):
+
+def get_one(db, model, on_multiple="error", constraint=None, **kwargs):
"""Gets an object from the database based on its attributes.
:param constraint: A single clause that can be passed into
@@ -78,9 +58,9 @@ def get_one(db, model, on_multiple='error', constraint=None, **kwargs):
:return: object or None
"""
constraint = constraint
- if 'constraint' in kwargs:
- constraint = kwargs['constraint']
- del kwargs['constraint']
+ if "constraint" in kwargs:
+ constraint = kwargs["constraint"]
+ del kwargs["constraint"]
q = db.query(model).filter_by(**kwargs)
if constraint is not None:
@@ -89,9 +69,9 @@ def get_one(db, model, on_multiple='error', constraint=None, **kwargs):
try:
return q.one()
except MultipleResultsFound:
- if on_multiple == 'error':
+ if on_multiple == "error":
raise
- elif on_multiple == 'interchangeable':
+ elif on_multiple == "interchangeable":
# These records are interchangeable so we can use
# whichever one we want.
#
@@ -102,9 +82,8 @@ def get_one(db, model, on_multiple='error', constraint=None, **kwargs):
except NoResultFound:
return None
-def get_one_or_create(db, model, create_method='',
- create_method_kwargs=None,
- **kwargs):
+
+def get_one_or_create(db, model, create_method="", create_method_kwargs=None, **kwargs):
one = get_one(db, model, **kwargs)
if one:
return one, False
@@ -112,7 +91,7 @@ def get_one_or_create(db, model, create_method='',
__transaction = db.begin_nested()
try:
# These kwargs are supported by get_one() but not by create().
- get_one_keys = ['on_multiple', 'constraint']
+ get_one_keys = ["on_multiple", "constraint"]
for key in get_one_keys:
if key in kwargs:
del kwargs[key]
@@ -121,11 +100,16 @@ def get_one_or_create(db, model, create_method='',
return obj
except IntegrityError as e:
logging.info(
- "INTEGRITY ERROR on %r %r, %r: %r", model, create_method_kwargs,
- kwargs, e)
+ "INTEGRITY ERROR on %r %r, %r: %r",
+ model,
+ create_method_kwargs,
+ kwargs,
+ e,
+ )
__transaction.rollback()
return db.query(model).filter_by(**kwargs).one(), False
+
def numericrange_to_string(r):
"""Helper method to convert a NumericRange to a human-readable string."""
if not r:
@@ -144,7 +128,8 @@ def numericrange_to_string(r):
lower += 1
if upper == lower:
return str(lower)
- return "%s-%s" % (lower,upper)
+ return "%s-%s" % (lower, upper)
+
def numericrange_to_tuple(r):
"""Helper method to normalize NumericRange into a tuple."""
@@ -158,11 +143,13 @@ def numericrange_to_tuple(r):
upper -= 1
return lower, upper
+
def tuple_to_numericrange(t):
"""Helper method to convert a tuple to an inclusive NumericRange."""
if not t:
return None
- return NumericRange(t[0], t[1], '[]')
+ return NumericRange(t[0], t[1], "[]")
+
class PresentationCalculationPolicy(object):
"""Which parts of the Work or Edition's presentation
@@ -173,20 +160,21 @@ class PresentationCalculationPolicy(object):
DEFAULT_THRESHOLD = 0.5
DEFAULT_CUTOFF = 1000
- def __init__(self,
- choose_edition=True,
- set_edition_metadata=True,
- classify=True,
- choose_summary=True,
- calculate_quality=True,
- choose_cover=True,
- regenerate_opds_entries=False,
- regenerate_marc_record=False,
- update_search_index=False,
- verbose=True,
- equivalent_identifier_levels=DEFAULT_LEVELS,
- equivalent_identifier_threshold=DEFAULT_THRESHOLD,
- equivalent_identifier_cutoff=DEFAULT_CUTOFF,
+ def __init__(
+ self,
+ choose_edition=True,
+ set_edition_metadata=True,
+ classify=True,
+ choose_summary=True,
+ calculate_quality=True,
+ choose_cover=True,
+ regenerate_opds_entries=False,
+ regenerate_marc_record=False,
+ update_search_index=False,
+ verbose=True,
+ equivalent_identifier_levels=DEFAULT_LEVELS,
+ equivalent_identifier_threshold=DEFAULT_THRESHOLD,
+ equivalent_identifier_cutoff=DEFAULT_CUTOFF,
):
"""Constructor.
@@ -234,8 +222,8 @@ def __init__(self,
self.choose_edition = choose_edition
self.set_edition_metadata = set_edition_metadata
self.classify = classify
- self.choose_summary=choose_summary
- self.calculate_quality=calculate_quality
+ self.choose_summary = choose_summary
+ self.calculate_quality = calculate_quality
self.choose_cover = choose_cover
# We will regenerate OPDS entries if any of the metadata
@@ -257,7 +245,6 @@ def __init__(self,
self.equivalent_identifier_threshold = equivalent_identifier_threshold
self.equivalent_identifier_cutoff = equivalent_identifier_cutoff
-
@classmethod
def recalculate_everything(cls):
"""A PresentationCalculationPolicy that always recalculates
@@ -281,9 +268,10 @@ def reset_cover(cls):
set_edition_metadata=False,
classify=False,
choose_summary=False,
- calculate_quality=False
+ calculate_quality=False,
)
+
def dump_query(query):
dialect = query.session.bind.dialect
statement = query.statement
@@ -291,19 +279,21 @@ def dump_query(query):
comp.compile()
enc = dialect.encoding
params = {}
- for k,v in list(comp.params.items()):
+ for k, v in list(comp.params.items()):
if isinstance(v, str):
v = v.encode(enc)
params[k] = sqlescape(v)
return (comp.string.encode(enc) % params).decode(enc)
+
DEBUG = False
+
class SessionManager(object):
# A function that calculates recursively equivalent identifiers
# is also defined in SQL.
- RECURSIVE_EQUIVALENTS_FUNCTION = 'recursive_equivalents.sql'
+ RECURSIVE_EQUIVALENTS_FUNCTION = "recursive_equivalents.sql"
engine_for_url = {}
@@ -320,7 +310,7 @@ def sessionmaker(cls, url=None, session=None):
bind_obj = cls.engine(url)
elif session:
bind_obj = session.get_bind()
- if not os.environ.get('TESTING'):
+ if not os.environ.get("TESTING"):
# If a factory is being created from a session in test mode,
# use the same Connection for all of the tests so objects can
# be accessed. Otherwise, bind against an Engine object.
@@ -350,12 +340,10 @@ def initialize(cls, url, initialize_data=True, initialize_schema=True):
connection = engine.connect()
# Check if the recursive equivalents function exists already.
- query = select(
- [literal_column('proname')]
- ).select_from(
- table('pg_proc')
- ).where(
- literal_column('proname')=='fn_recursive_equivalents'
+ query = (
+ select([literal_column("proname")])
+ .select_from(table("pg_proc"))
+ .where(literal_column("proname") == "fn_recursive_equivalents")
)
result = connection.execute(query)
result = list(result)
@@ -366,7 +354,10 @@ def initialize(cls, url, initialize_data=True, initialize_schema=True):
cls.resource_directory(), cls.RECURSIVE_EQUIVALENTS_FUNCTION
)
if not os.path.exists(resource_file):
- raise IOError("Could not load recursive equivalents function from %s: file does not exist." % resource_file)
+ raise IOError(
+ "Could not load recursive equivalents function from %s: file does not exist."
+ % resource_file
+ )
sql = open(resource_file).read()
connection.execute(sql)
@@ -396,8 +387,9 @@ def initialize_schema(cls, engine):
"""Initialize the database schema."""
# Use SQLAlchemy to create all the tables.
to_create = [
- table_obj for name, table_obj in list(Base.metadata.tables.items())
- if not name.startswith('mv_')
+ table_obj
+ for name, table_obj in list(Base.metadata.tables.items())
+ if not name.startswith("mv_")
]
Base.metadata.create_all(engine, tables=to_create)
@@ -407,8 +399,9 @@ def session(cls, url, initialize_data=True, initialize_schema=True):
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=SAWarning)
engine, connection = cls.initialize(
- url, initialize_data=initialize_data,
- initialize_schema=initialize_schema
+ url,
+ initialize_data=initialize_data,
+ initialize_schema=initialize_schema,
)
session = Session(connection)
if initialize_data:
@@ -418,9 +411,10 @@ def session(cls, url, initialize_data=True, initialize_schema=True):
@classmethod
def initialize_data(cls, session, set_site_configuration=True):
# Create initial content.
- from .datasource import DataSource
from .classification import Genre
+ from .datasource import DataSource
from .licensing import DeliveryMechanism
+
list(DataSource.well_known_sources(session))
# Load all existing Genre objects.
@@ -437,7 +431,10 @@ def initialize_data(cls, session, set_site_configuration=True):
# Make sure that the mechanisms fulfillable by the default
# client are marked as such.
- for content_type, drm_scheme in DeliveryMechanism.default_client_can_fulfill_lookup:
+ for (
+ content_type,
+ drm_scheme,
+ ) in DeliveryMechanism.default_client_can_fulfill_lookup:
mechanism, is_new = DeliveryMechanism.lookup(
session, content_type, drm_scheme
)
@@ -446,9 +443,11 @@ def initialize_data(cls, session, set_site_configuration=True):
# If there is currently no 'site configuration change'
# Timestamp in the database, create one.
timestamp, is_new = get_one_or_create(
- session, Timestamp, collection=None,
+ session,
+ Timestamp,
+ collection=None,
service=Configuration.SITE_CONFIGURATION_CHANGED,
- create_method_kwargs=dict(finish=utc_now())
+ create_method_kwargs=dict(finish=utc_now()),
)
if is_new:
site_configuration_has_changed(session)
@@ -458,6 +457,7 @@ def initialize_data(cls, session, set_site_configuration=True):
# it was updated by cls.update_timestamps_table
return session
+
def production_session(initialize_data=True):
url = Configuration.database_url()
if url.startswith('"'):
@@ -473,62 +473,35 @@ def production_session(initialize_data=True):
# unit tests, and 2) package_setup() will call initialize() again
# with the right arguments.
from ..log import LogConfiguration
+
LogConfiguration.initialize(_db)
return _db
-from .admin import (
- Admin,
- AdminRole,
-)
-from .coverage import (
- BaseCoverageRecord,
- CoverageRecord,
- Timestamp,
- WorkCoverageRecord,
-)
-from .cachedfeed import (
- CachedFeed,
- WillNotGenerateExpensiveFeed,
- CachedMARCFile,
-)
+
+from .admin import Admin, AdminRole
+from .cachedfeed import CachedFeed, CachedMARCFile, WillNotGenerateExpensiveFeed
from .circulationevent import CirculationEvent
-from .classification import (
- Classification,
- Genre,
- Subject,
-)
+from .classification import Classification, Genre, Subject
from .collection import (
Collection,
CollectionIdentifier,
CollectionMissing,
collections_identifiers,
)
+from .complaint import Complaint
from .configuration import (
ConfigurationSetting,
ExternalIntegration,
ExternalIntegrationLink,
)
-from .complaint import Complaint
-from .contributor import (
- Contribution,
- Contributor,
-)
-from .credential import (
- Credential,
- DelegatedPatronIdentifier,
- DRMDeviceIdentifier,
-)
-from .customlist import (
- CustomList,
- CustomListEntry,
-)
+from .contributor import Contribution, Contributor
+from .coverage import BaseCoverageRecord, CoverageRecord, Timestamp, WorkCoverageRecord
+from .credential import Credential, DelegatedPatronIdentifier, DRMDeviceIdentifier
+from .customlist import CustomList, CustomListEntry
from .datasource import DataSource
from .edition import Edition
from .hasfulltablecache import HasFullTableCache
-from .identifier import (
- Equivalency,
- Identifier,
-)
+from .identifier import Equivalency, Identifier
from .integrationclient import IntegrationClient
from .library import Library
from .licensing import (
@@ -539,6 +512,7 @@ def production_session(initialize_data=True):
PolicyException,
RightsStatus,
)
+from .listeners import *
from .measurement import Measurement
from .patron import (
Annotation,
@@ -548,14 +522,5 @@ def production_session(initialize_data=True):
Patron,
PatronProfileStorage,
)
-from .listeners import *
-from .resource import (
- Hyperlink,
- Representation,
- Resource,
- ResourceTransformation,
-)
-from .work import (
- Work,
- WorkGenre,
-)
+from .resource import Hyperlink, Representation, Resource, ResourceTransformation
+from .work import Work, WorkGenre
diff --git a/model/admin.py b/model/admin.py
index 1db7846f1..c1a7b2aa0 100644
--- a/model/admin.py
+++ b/model/admin.py
@@ -2,32 +2,19 @@
# Admin, AdminRole
-from . import (
- Base,
- get_one,
- get_one_or_create
-)
-from .hasfulltablecache import HasFullTableCache
-
import bcrypt
-from sqlalchemy import (
- Column,
- ForeignKey,
- Index,
- Integer,
- Unicode,
- UniqueConstraint,
-)
+from sqlalchemy import Column, ForeignKey, Index, Integer, Unicode, UniqueConstraint
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import (
- relationship,
- validates
-)
+from sqlalchemy.orm import relationship, validates
from sqlalchemy.orm.session import Session
+from . import Base, get_one, get_one_or_create
+from .hasfulltablecache import HasFullTableCache
+
+
class Admin(Base, HasFullTableCache):
- __tablename__ = 'admins'
+ __tablename__ = "admins"
id = Column(Integer, primary_key=True)
email = Column(Unicode, unique=True, nullable=False)
@@ -52,7 +39,7 @@ def update_credentials(self, _db, credential=None):
self.credential = credential
_db.commit()
- @validates('email')
+ @validates("email")
def validate_email(self, key, address):
# strip any whitespace from email address
return address.strip()
@@ -73,6 +60,7 @@ def authenticate(cls, _db, email, password):
"""Finds an authenticated Admin by email and password
:return: Admin or None
"""
+
def lookup_hook():
return get_one(_db, Admin, email=str(email)), False
@@ -89,9 +77,16 @@ def with_password(cls, _db):
def is_system_admin(self):
_db = Session.object_session(self)
+
def lookup_hook():
- return get_one(_db, AdminRole, admin=self, role=AdminRole.SYSTEM_ADMIN), False
- role, ignore = AdminRole.by_cache_key(_db, (self.id, None, AdminRole.SYSTEM_ADMIN), lookup_hook)
+ return (
+ get_one(_db, AdminRole, admin=self, role=AdminRole.SYSTEM_ADMIN),
+ False,
+ )
+
+ role, ignore = AdminRole.by_cache_key(
+ _db, (self.id, None, AdminRole.SYSTEM_ADMIN), lookup_hook
+ )
if role:
return True
return False
@@ -100,9 +95,18 @@ def is_sitewide_library_manager(self):
_db = Session.object_session(self)
if self.is_system_admin():
return True
+
def lookup_hook():
- return get_one(_db, AdminRole, admin=self, role=AdminRole.SITEWIDE_LIBRARY_MANAGER), False
- role, ignore = AdminRole.by_cache_key(_db, (self.id, None, AdminRole.SITEWIDE_LIBRARY_MANAGER), lookup_hook)
+ return (
+ get_one(
+ _db, AdminRole, admin=self, role=AdminRole.SITEWIDE_LIBRARY_MANAGER
+ ),
+ False,
+ )
+
+ role, ignore = AdminRole.by_cache_key(
+ _db, (self.id, None, AdminRole.SITEWIDE_LIBRARY_MANAGER), lookup_hook
+ )
if role:
return True
return False
@@ -111,9 +115,16 @@ def is_sitewide_librarian(self):
_db = Session.object_session(self)
if self.is_sitewide_library_manager():
return True
+
def lookup_hook():
- return get_one(_db, AdminRole, admin=self, role=AdminRole.SITEWIDE_LIBRARIAN), False
- role, ignore = AdminRole.by_cache_key(_db, (self.id, None, AdminRole.SITEWIDE_LIBRARIAN), lookup_hook)
+ return (
+ get_one(_db, AdminRole, admin=self, role=AdminRole.SITEWIDE_LIBRARIAN),
+ False,
+ )
+
+ role, ignore = AdminRole.by_cache_key(
+ _db, (self.id, None, AdminRole.SITEWIDE_LIBRARIAN), lookup_hook
+ )
if role:
return True
return False
@@ -125,8 +136,20 @@ def is_library_manager(self, library):
return True
# If not, they could stil be a manager of _this_ library.
def lookup_hook():
- return get_one(_db, AdminRole, admin=self, library=library, role=AdminRole.LIBRARY_MANAGER), False
- role, ignore = AdminRole.by_cache_key(_db, (self.id, library.id, AdminRole.LIBRARY_MANAGER), lookup_hook)
+ return (
+ get_one(
+ _db,
+ AdminRole,
+ admin=self,
+ library=library,
+ role=AdminRole.LIBRARY_MANAGER,
+ ),
+ False,
+ )
+
+ role, ignore = AdminRole.by_cache_key(
+ _db, (self.id, library.id, AdminRole.LIBRARY_MANAGER), lookup_hook
+ )
if role:
return True
return False
@@ -141,8 +164,20 @@ def is_librarian(self, library):
return True
# If not, they might be a librarian of _this_ library.
def lookup_hook():
- return get_one(_db, AdminRole, admin=self, library=library, role=AdminRole.LIBRARIAN), False
- role, ignore = AdminRole.by_cache_key(_db, (self.id, library.id, AdminRole.LIBRARIAN), lookup_hook)
+ return (
+ get_one(
+ _db,
+ AdminRole,
+ admin=self,
+ library=library,
+ role=AdminRole.LIBRARIAN,
+ ),
+ False,
+ )
+
+ role, ignore = AdminRole.by_cache_key(
+ _db, (self.id, library.id, AdminRole.LIBRARIAN), lookup_hook
+ )
if role:
return True
return False
@@ -157,7 +192,9 @@ def can_see_collection(self, collection):
def add_role(self, role, library=None):
_db = Session.object_session(self)
- role, is_new = get_one_or_create(_db, AdminRole, admin=self, role=role, library=library)
+ role, is_new = get_one_or_create(
+ _db, AdminRole, admin=self, role=role, library=library
+ )
return role
def remove_role(self, role, library=None):
@@ -169,18 +206,17 @@ def remove_role(self, role, library=None):
def __repr__(self):
return "" % self.email
+
class AdminRole(Base, HasFullTableCache):
- __tablename__ = 'adminroles'
+ __tablename__ = "adminroles"
id = Column(Integer, primary_key=True)
admin_id = Column(Integer, ForeignKey("admins.id"), nullable=False, index=True)
library_id = Column(Integer, ForeignKey("libraries.id"), nullable=True, index=True)
role = Column(Unicode, nullable=False, index=True)
- __table_args__ = (
- UniqueConstraint('admin_id', 'library_id', 'role'),
- )
+ __table_args__ = (UniqueConstraint("admin_id", "library_id", "role"),)
SYSTEM_ADMIN = "system"
SITEWIDE_LIBRARY_MANAGER = "manager-all"
@@ -188,7 +224,13 @@ class AdminRole(Base, HasFullTableCache):
SITEWIDE_LIBRARIAN = "librarian-all"
LIBRARIAN = "librarian"
- ROLES = [SYSTEM_ADMIN, SITEWIDE_LIBRARY_MANAGER, LIBRARY_MANAGER, SITEWIDE_LIBRARIAN, LIBRARIAN]
+ ROLES = [
+ SYSTEM_ADMIN,
+ SITEWIDE_LIBRARY_MANAGER,
+ LIBRARY_MANAGER,
+ SITEWIDE_LIBRARIAN,
+ LIBRARIAN,
+ ]
_cache = HasFullTableCache.RESET
_id_cache = HasFullTableCache.RESET
@@ -203,7 +245,15 @@ def to_dict(self):
def __repr__(self):
return "" % (
- self.role, (self.library and self.library.short_name), self.admin.email)
+ self.role,
+ (self.library and self.library.short_name),
+ self.admin.email,
+ )
-Index("ix_adminroles_admin_id_library_id_role", AdminRole.admin_id, AdminRole.library_id, AdminRole.role)
+Index(
+ "ix_adminroles_admin_id_library_id_role",
+ AdminRole.admin_id,
+ AdminRole.library_id,
+ AdminRole.role,
+)
diff --git a/model/cachedfeed.py b/model/cachedfeed.py
index afcf8f699..9c8a1f87e 100644
--- a/model/cachedfeed.py
+++ b/model/cachedfeed.py
@@ -1,41 +1,27 @@
# encoding: utf-8
# CachedFeed, WillNotGenerateExpensiveFeed
-from . import (
- Base,
- flush,
- get_one,
- get_one_or_create,
-)
-from collections import namedtuple
import datetime
import logging
-from sqlalchemy import (
- Column,
- DateTime,
- ForeignKey,
- Index,
- Integer,
- Unicode,
-)
-from sqlalchemy.sql.expression import (
- and_,
-)
+from collections import namedtuple
+
+from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, Unicode
+from sqlalchemy.sql.expression import and_
-from ..util.flask_util import OPDSFeedResponse
from ..util.datetime_helpers import utc_now
+from ..util.flask_util import OPDSFeedResponse
+from . import Base, flush, get_one, get_one_or_create
+
class CachedFeed(Base):
- __tablename__ = 'cachedfeeds'
+ __tablename__ = "cachedfeeds"
id = Column(Integer, primary_key=True)
# Every feed is associated with a lane. If null, this is a feed
# for a WorkList. If work_id is also null, it's a feed for the
# top-level.
- lane_id = Column(
- Integer, ForeignKey('lanes.id'),
- nullable=True, index=True)
+ lane_id = Column(Integer, ForeignKey("lanes.id"), nullable=True, index=True)
# Every feed has a timestamp reflecting when it was created.
timestamp = Column(DateTime(timezone=True), nullable=True, index=True)
@@ -58,23 +44,20 @@ class CachedFeed(Base):
content = Column(Unicode, nullable=True)
# Every feed is associated with a Library.
- library_id = Column(
- Integer, ForeignKey('libraries.id'), index=True
- )
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True)
# A feed may be associated with a Work.
- work_id = Column(Integer, ForeignKey('works.id'),
- nullable=True, index=True)
+ work_id = Column(Integer, ForeignKey("works.id"), nullable=True, index=True)
# Distinct types of feeds that might be cached.
- GROUPS_TYPE = 'groups'
- PAGE_TYPE = 'page'
- NAVIGATION_TYPE = 'navigation'
- CRAWLABLE_TYPE = 'crawlable'
- RELATED_TYPE = 'related'
- RECOMMENDATIONS_TYPE = 'recommendations'
- SERIES_TYPE = 'series'
- CONTRIBUTOR_TYPE = 'contributor'
+ GROUPS_TYPE = "groups"
+ PAGE_TYPE = "page"
+ NAVIGATION_TYPE = "navigation"
+ CRAWLABLE_TYPE = "crawlable"
+ RELATED_TYPE = "related"
+ RECOMMENDATIONS_TYPE = "recommendations"
+ SERIES_TYPE = "series"
+ CONTRIBUTOR_TYPE = "contributor"
# Special constants for cache durations.
CACHE_FOREVER = object()
@@ -83,8 +66,16 @@ class CachedFeed(Base):
log = logging.getLogger("CachedFeed")
@classmethod
- def fetch(cls, _db, worklist, facets, pagination, refresher_method,
- max_age=None, raw=False, **response_kwargs
+ def fetch(
+ cls,
+ _db,
+ worklist,
+ facets,
+ pagination,
+ refresher_method,
+ max_age=None,
+ raw=False,
+ **response_kwargs
):
"""Retrieve a cached feed from the database if possible.
@@ -132,9 +123,9 @@ def fetch(cls, _db, worklist, facets, pagination, refresher_method,
# TODO: this constraint_clause might not be necessary anymore.
# ISTR it was an attempt to avoid race conditions, and we do a
# better job of that now.
- constraint_clause = and_(cls.content!=None, cls.timestamp!=None)
+ constraint_clause = and_(cls.content != None, cls.timestamp != None)
kwargs = dict(
- on_multiple='interchangeable',
+ on_multiple="interchangeable",
constraint=constraint_clause,
type=keys.feed_type,
library=keys.library,
@@ -142,10 +133,10 @@ def fetch(cls, _db, worklist, facets, pagination, refresher_method,
lane_id=keys.lane_id,
unique_key=keys.unique_key,
facets=keys.facets_key,
- pagination=keys.pagination_key
+ pagination=keys.pagination_key,
)
feed_data = None
- if (max_age is cls.IGNORE_CACHE or isinstance(max_age, int) and max_age <= 0):
+ if max_age is cls.IGNORE_CACHE or isinstance(max_age, int) and max_age <= 0:
# Don't even bother checking for a CachedFeed: we're
# just going to replace it.
feed_obj = None
@@ -189,13 +180,13 @@ def fetch(cls, _db, worklist, facets, pagination, refresher_method,
#
# Set some defaults in case the caller didn't pass them in.
if isinstance(max_age, int):
- response_kwargs.setdefault('max_age', max_age)
+ response_kwargs.setdefault("max_age", max_age)
if max_age == cls.IGNORE_CACHE:
# If we were asked to ignore our internal cache, we should
# also tell the client not to store this document in _its_
# internal cache.
- response_kwargs['max_age'] = 0
+ response_kwargs["max_age"] = 0
if keys.library and keys.library.has_root_lanes:
# If this feed is associated with a Library that guides
@@ -208,12 +199,9 @@ def fetch(cls, _db, worklist, facets, pagination, refresher_method,
# TODO: it might be possible to make this decision in a
# more fine-grained way, which would allow intermediaries
# to cache these feeds.
- response_kwargs['private'] = True
+ response_kwargs["private"] = True
- return OPDSFeedResponse(
- response=feed_data,
- **response_kwargs
- )
+ return OPDSFeedResponse(response=feed_data, **response_kwargs)
@classmethod
def feed_type(cls, worklist, facets):
@@ -288,9 +276,9 @@ def _should_refresh(cls, feed_obj, max_age):
# If we found *anything*, and the cache time is CACHE_FOREVER,
# we will never refresh.
should_refresh = False
- elif (feed_obj.timestamp
- and feed_obj.timestamp + datetime.timedelta(seconds=max_age) <=
- utc_now()
+ elif (
+ feed_obj.timestamp
+ and feed_obj.timestamp + datetime.timedelta(seconds=max_age) <= utc_now()
):
# Here it comes down to a date comparison: how old is the
# CachedFeed?
@@ -300,9 +288,16 @@ def _should_refresh(cls, feed_obj, max_age):
# This named tuple makes it easy to manage the return value of
# _prepare_keys.
CachedFeedKeys = namedtuple(
- 'CachedFeedKeys',
- ['feed_type', 'library', 'work', 'lane_id', 'unique_key', 'facets_key',
- 'pagination_key']
+ "CachedFeedKeys",
+ [
+ "feed_type",
+ "library",
+ "work",
+ "lane_id",
+ "unique_key",
+ "facets_key",
+ "pagination_key",
+ ],
)
@classmethod
@@ -319,9 +314,7 @@ def _prepare_keys(cls, _db, worklist, facets, pagination):
:return: A CachedFeedKeys object.
"""
if not worklist:
- raise ValueError(
- "Cannot prepare a CachedFeed without a WorkList."
- )
+ raise ValueError("Cannot prepare a CachedFeed without a WorkList.")
feed_type = cls.feed_type(worklist, facets)
@@ -330,10 +323,11 @@ def _prepare_keys(cls, _db, worklist, facets, pagination):
# A feed may be associated with a specific Work,
# e.g. recommendations for readers of that Work.
- work = getattr(worklist, 'work', None)
+ work = getattr(worklist, "work", None)
# Either lane_id or unique_key must be set, but not both.
from ..lane import Lane
+
if isinstance(worklist, Lane):
lane_id = worklist.id
unique_key = None
@@ -356,9 +350,13 @@ def _prepare_keys(cls, _db, worklist, facets, pagination):
pagination_key = pagination.query_string
return cls.CachedFeedKeys(
- feed_type=feed_type, library=library, work=work, lane_id=lane_id,
- unique_key=unique_key, facets_key=facets_key,
- pagination_key=pagination_key
+ feed_type=feed_type,
+ library=library,
+ work=work,
+ lane_id=lane_id,
+ unique_key=unique_key,
+ facets_key=facets_key,
+ pagination_key=pagination_key,
)
def update(self, _db, content):
@@ -372,16 +370,23 @@ def __repr__(self):
else:
length = "No content"
return "" % (
- self.id, self.lane_id, self.type,
- self.facets, self.pagination,
- self.timestamp, length
+ self.id,
+ self.lane_id,
+ self.type,
+ self.facets,
+ self.pagination,
+ self.timestamp,
+ length,
)
Index(
"ix_cachedfeeds_library_id_lane_id_type_facets_pagination",
- CachedFeed.library_id, CachedFeed.lane_id, CachedFeed.type,
- CachedFeed.facets, CachedFeed.pagination
+ CachedFeed.library_id,
+ CachedFeed.lane_id,
+ CachedFeed.type,
+ CachedFeed.facets,
+ CachedFeed.pagination,
)
@@ -389,28 +394,26 @@ class WillNotGenerateExpensiveFeed(Exception):
"""This exception is raised when a feed is not cached, but it's too
expensive to generate.
"""
+
pass
+
class CachedMARCFile(Base):
"""A record that a MARC file has been created and cached for a particular lane."""
- __tablename__ = 'cachedmarcfiles'
+ __tablename__ = "cachedmarcfiles"
id = Column(Integer, primary_key=True)
# Every MARC file is associated with a library and a lane. If the
# lane is null, the file is for the top-level WorkList.
- library_id = Column(
- Integer, ForeignKey('libraries.id'),
- nullable=False, index=True)
+ library_id = Column(Integer, ForeignKey("libraries.id"), nullable=False, index=True)
- lane_id = Column(
- Integer, ForeignKey('lanes.id'),
- nullable=True, index=True)
+ lane_id = Column(Integer, ForeignKey("lanes.id"), nullable=True, index=True)
# The representation for this file stores the URL where it was mirrored.
representation_id = Column(
- Integer, ForeignKey('representations.id'),
- nullable=False)
+ Integer, ForeignKey("representations.id"), nullable=False
+ )
start_time = Column(DateTime(timezone=True), nullable=True, index=True)
end_time = Column(DateTime(timezone=True), nullable=True, index=True)
diff --git a/model/circulationevent.py b/model/circulationevent.py
index 666d7fae1..2312f6b9f 100644
--- a/model/circulationevent.py
+++ b/model/circulationevent.py
@@ -3,21 +3,12 @@
import logging
-from sqlalchemy import (
- Column,
- DateTime,
- ForeignKey,
- Index,
- Integer,
- String,
- Unicode,
-)
-
-from . import (
- Base,
- get_one_or_create,
-)
+
+from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Unicode
+
from ..util.datetime_helpers import utc_now
+from . import Base, get_one_or_create
+
class CirculationEvent(Base):
@@ -25,7 +16,8 @@ class CirculationEvent(Base):
We log these so we can measure things like the velocity of
individual books.
"""
- __tablename__ = 'circulationevents'
+
+ __tablename__ = "circulationevents"
# Used to explicitly tag an event as happening at an unknown time.
NO_DATE = object()
@@ -33,8 +25,7 @@ class CirculationEvent(Base):
id = Column(Integer, primary_key=True)
# One LicensePool can have many circulation events.
- license_pool_id = Column(
- Integer, ForeignKey('licensepools.id'), index=True)
+ license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
type = Column(String(32), index=True)
start = Column(DateTime(timezone=True), index=True)
@@ -45,10 +36,7 @@ class CirculationEvent(Base):
# The Library associated with the event, if it happened in the
# context of a particular Library and we know which one.
- library_id = Column(
- Integer, ForeignKey('libraries.id'),
- index=True, nullable=True
- )
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True, nullable=True)
# The geographic location associated with the event. This string
# may mean different things for different libraries. It might be a
@@ -64,11 +52,7 @@ class CirculationEvent(Base):
#
# TODO: Maybe there should also be an index that takes
# library_id into account, for per-library event lists.
- Index(
- "ix_circulationevents_start_desc_nullslast",
- start.desc().nullslast()
- ),
-
+ Index("ix_circulationevents_start_desc_nullslast", start.desc().nullslast()),
# License pool ID + library ID + type + start must be unique.
Index(
"ix_circulationevents_license_pool_library_type_start",
@@ -76,9 +60,8 @@ class CirculationEvent(Base):
library_id,
type,
start,
- unique=True
+ unique=True,
),
-
# However, library_id may be null. If this is so, then license pool ID
# + type + start must be unique.
Index(
@@ -87,7 +70,7 @@ class CirculationEvent(Base):
type,
start,
unique=True,
- postgresql_where=(library_id==None)
+ postgresql_where=(library_id == None),
),
)
@@ -125,13 +108,22 @@ class CirculationEvent(Base):
OPEN_BOOK,
]
-
# The time format used when exporting to JSON.
TIME_FORMAT = "%Y-%m-%dT%H:%M:%S+00:00"
@classmethod
- def log(cls, _db, license_pool, event_name, old_value, new_value,
- start=None, end=None, library=None, location=None):
+ def log(
+ cls,
+ _db,
+ license_pool,
+ event_name,
+ old_value,
+ new_value,
+ start=None,
+ end=None,
+ library=None,
+ location=None,
+ ):
"""Log a CirculationEvent to the database, assuming it
hasn't already been recorded.
"""
@@ -144,15 +136,19 @@ def log(cls, _db, license_pool, event_name, old_value, new_value,
if not end:
end = start
event, was_new = get_one_or_create(
- _db, CirculationEvent, license_pool=license_pool,
- type=event_name, start=start, library=library,
+ _db,
+ CirculationEvent,
+ license_pool=license_pool,
+ type=event_name,
+ start=start,
+ library=library,
create_method_kwargs=dict(
old_value=old_value,
new_value=new_value,
delta=delta,
end=end,
- location=location
- )
+ location=location,
+ ),
)
if was_new:
logging.info("EVENT %s %s=>%s", event_name, old_value, new_value)
diff --git a/model/classification.py b/model/classification.py
index 20d557897..b13cc1a4e 100644
--- a/model/classification.py
+++ b/model/classification.py
@@ -2,25 +2,6 @@
# Subject, Classification, Genre
-from . import (
- Base,
- get_one,
- get_one_or_create,
- numericrange_to_string,
- numericrange_to_tuple,
- tuple_to_numericrange,
-)
-from .constants import DataSourceConstants
-from .hasfulltablecache import HasFullTableCache
-
-from .. import classifier
-from ..classifier import (
- Classifier,
- COMICS_AND_GRAPHIC_NOVELS,
- Erotica,
- GenreData,
-)
-
import logging
from sqlalchemy import (
@@ -28,10 +9,10 @@
Column,
Enum,
ForeignKey,
- func,
Integer,
Unicode,
UniqueConstraint,
+ func,
)
from sqlalchemy.dialects.postgresql import INT4RANGE
from sqlalchemy.ext.associationproxy import association_proxy
@@ -39,25 +20,37 @@
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.functions import func
+from .. import classifier
+from ..classifier import COMICS_AND_GRAPHIC_NOVELS, Classifier, Erotica, GenreData
+from . import (
+ Base,
+ get_one,
+ get_one_or_create,
+ numericrange_to_string,
+ numericrange_to_tuple,
+ tuple_to_numericrange,
+)
+from .constants import DataSourceConstants
+from .hasfulltablecache import HasFullTableCache
+
+
class Subject(Base):
"""A subject under which books might be classified."""
# Types of subjects.
- LCC = Classifier.LCC # Library of Congress Classification
- LCSH = Classifier.LCSH # Library of Congress Subject Headings
+ LCC = Classifier.LCC # Library of Congress Classification
+ LCSH = Classifier.LCSH # Library of Congress Subject Headings
FAST = Classifier.FAST
- DDC = Classifier.DDC # Dewey Decimal Classification
+ DDC = Classifier.DDC # Dewey Decimal Classification
OVERDRIVE = Classifier.OVERDRIVE # Overdrive's classification system
BISAC = Classifier.BISAC
- BIC = Classifier.BIC # BIC Subject Categories
- TAG = Classifier.TAG # Folksonomic tags.
+ BIC = Classifier.BIC # BIC Subject Categories
+ TAG = Classifier.TAG # Folksonomic tags.
FREEFORM_AUDIENCE = Classifier.FREEFORM_AUDIENCE
NYPL_APPEAL = Classifier.NYPL_APPEAL
# Types with terms that are suitable for search.
- TYPES_FOR_SEARCH = [
- FAST, OVERDRIVE, BISAC, TAG
- ]
+ TYPES_FOR_SEARCH = [FAST, OVERDRIVE, BISAC, TAG]
AXIS_360_AUDIENCE = Classifier.AXIS_360_AUDIENCE
GRADE_LEVEL = Classifier.GRADE_LEVEL
@@ -75,26 +68,26 @@ class Subject(Base):
SIMPLIFIED_FICTION_STATUS = Classifier.SIMPLIFIED_FICTION_STATUS
by_uri = {
- SIMPLIFIED_GENRE : SIMPLIFIED_GENRE,
- SIMPLIFIED_FICTION_STATUS : SIMPLIFIED_FICTION_STATUS,
- "http://librarysimplified.org/terms/genres/Overdrive/" : OVERDRIVE,
- "http://librarysimplified.org/terms/genres/3M/" : BISAC,
- "http://id.worldcat.org/fast/" : FAST, # I don't think this is official.
- "http://purl.org/dc/terms/LCC" : LCC,
- "http://purl.org/dc/terms/LCSH" : LCSH,
- "http://purl.org/dc/terms/DDC" : DDC,
- "http://schema.org/typicalAgeRange" : AGE_RANGE,
- "http://schema.org/audience" : FREEFORM_AUDIENCE,
- "http://www.bisg.org/standards/bisac_subject/" : BISAC,
+ SIMPLIFIED_GENRE: SIMPLIFIED_GENRE,
+ SIMPLIFIED_FICTION_STATUS: SIMPLIFIED_FICTION_STATUS,
+ "http://librarysimplified.org/terms/genres/Overdrive/": OVERDRIVE,
+ "http://librarysimplified.org/terms/genres/3M/": BISAC,
+ "http://id.worldcat.org/fast/": FAST, # I don't think this is official.
+ "http://purl.org/dc/terms/LCC": LCC,
+ "http://purl.org/dc/terms/LCSH": LCSH,
+ "http://purl.org/dc/terms/DDC": DDC,
+ "http://schema.org/typicalAgeRange": AGE_RANGE,
+ "http://schema.org/audience": FREEFORM_AUDIENCE,
+ "http://www.bisg.org/standards/bisac_subject/": BISAC,
# Feedbooks uses a modified BISAC which we know how to handle.
- "http://www.feedbooks.com/categories" : BISAC,
+ "http://www.feedbooks.com/categories": BISAC,
}
uri_lookup = dict()
for k, v in list(by_uri.items()):
uri_lookup[v] = k
- __tablename__ = 'subjects'
+ __tablename__ = "subjects"
id = Column(Integer, primary_key=True)
# Type should be one of the constants in this class.
type = Column(Unicode, index=True)
@@ -114,16 +107,24 @@ class Subject(Base):
# Whether classification under this subject implies anything about
# the book's audience.
audience = Column(
- Enum("Adult", "Young Adult", "Children", "Adults Only",
- "All Ages", "Research",
- name="audience"),
- default=None, index=True)
+ Enum(
+ "Adult",
+ "Young Adult",
+ "Children",
+ "Adults Only",
+ "All Ages",
+ "Research",
+ name="audience",
+ ),
+ default=None,
+ index=True,
+ )
# For children's books, the target age implied by this subject.
target_age = Column(INT4RANGE, default=None, index=True)
# Each Subject may claim affinity with one Genre.
- genre_id = Column(Integer, ForeignKey('genres.id'), index=True)
+ genre_id = Column(Integer, ForeignKey("genres.id"), index=True)
# A locked Subject has been reviewed by a human and software will
# not mess with it without permission.
@@ -134,14 +135,10 @@ class Subject(Base):
checked = Column(Boolean, default=False, index=True)
# One Subject may participate in many Classifications.
- classifications = relationship(
- "Classification", backref="subject"
- )
+ classifications = relationship("Classification", backref="subject")
# Type + identifier must be unique.
- __table_args__ = (
- UniqueConstraint('type', 'identifier'),
- )
+ __table_args__ = (UniqueConstraint("type", "identifier"),)
def __repr__(self):
if self.name:
@@ -162,14 +159,21 @@ def __repr__(self):
genre = ' genre="%s"' % self.genre.name
else:
genre = ""
- if (self.target_age is not None
- and (self.target_age.lower or self.target_age.upper)
+ if self.target_age is not None and (
+ self.target_age.lower or self.target_age.upper
):
- age_range= " " + self.target_age_string
+ age_range = " " + self.target_age_string
else:
age_range = ""
- a = '[%s:%s%s%s%s%s%s]' % (
- self.type, self.identifier, name, fiction, audience, genre, age_range)
+ a = "[%s:%s%s%s%s%s%s]" % (
+ self.type,
+ self.identifier,
+ name,
+ fiction,
+ audience,
+ genre,
+ age_range,
+ )
return str(a)
@property
@@ -185,7 +189,7 @@ def describes_format(self):
different adaptation of the same underlying work.
TODO: See note in assign_genres about the hacky way this is used.
"""
- if self.genre and self.genre.name==COMICS_AND_GRAPHIC_NOVELS:
+ if self.genre and self.genre.name == COMICS_AND_GRAPHIC_NOVELS:
return True
return False
@@ -210,14 +214,12 @@ def lookup(cls, _db, type, identifier, name, autocreate=True):
# Type + identifier is unique, but type + name is not
# (though maybe it should be). So we need to provide
# on_multiple.
- find_with = dict(name=name, on_multiple='interchangeable')
+ find_with = dict(name=name, on_multiple="interchangeable")
create_with = dict()
if autocreate:
subject, new = get_one_or_create(
- _db, Subject, type=type,
- create_method_kwargs=create_with,
- **find_with
+ _db, Subject, type=type, create_method_kwargs=create_with, **find_with
)
else:
subject = get_one(_db, Subject, type=type, **find_with)
@@ -229,20 +231,22 @@ def lookup(cls, _db, type, identifier, name, autocreate=True):
return subject, new
@classmethod
- def common_but_not_assigned_to_genre(cls, _db, min_occurances=1000,
- type_restriction=None):
- q = _db.query(Subject).join(Classification).filter(Subject.genre==None)
+ def common_but_not_assigned_to_genre(
+ cls, _db, min_occurances=1000, type_restriction=None
+ ):
+ q = _db.query(Subject).join(Classification).filter(Subject.genre == None)
if type_restriction:
- q = q.filter(Subject.type==type_restriction)
- q = q.group_by(Subject.id).having(
- func.count(Subject.id) > min_occurances).order_by(
- func.count(Classification.id).desc())
+ q = q.filter(Subject.type == type_restriction)
+ q = (
+ q.group_by(Subject.id)
+ .having(func.count(Subject.id) > min_occurances)
+ .order_by(func.count(Classification.id).desc())
+ )
return q
@classmethod
- def assign_to_genres(cls, _db, type_restriction=None, force=False,
- batch_size=1000):
+ def assign_to_genres(cls, _db, type_restriction=None, force=False, batch_size=1000):
"""Find subjects that have not been checked yet, assign each a
genre/audience/fiction status if possible, and mark each as checked.
@@ -253,13 +257,13 @@ def assign_to_genres(cls, _db, type_restriction=None, force=False,
subjects have been checked.
"""
- q = _db.query(Subject).filter(Subject.locked==False)
+ q = _db.query(Subject).filter(Subject.locked == False)
if type_restriction:
- q = q.filter(Subject.type==type_restriction)
+ q = q.filter(Subject.type == type_restriction)
if not force:
- q = q.filter(Subject.checked==False)
+ q = q.filter(Subject.checked == False)
counter = 0
for subject in q:
@@ -302,42 +306,39 @@ def assign_to_genre(self):
shorthand = ":".join(x for x in parts if x)
if genre != self.genre:
- log.info(
- "%s genre %r=>%r", shorthand, self.genre, genre
- )
+ log.info("%s genre %r=>%r", shorthand, self.genre, genre)
self.genre = genre
if audience:
if self.audience != audience:
- log.info(
- "%s audience %s=>%s", shorthand, self.audience, audience
- )
+ log.info("%s audience %s=>%s", shorthand, self.audience, audience)
self.audience = audience
if fiction is not None:
if self.fiction != fiction:
- log.info(
- "%s fiction %s=>%s", shorthand, self.fiction, fiction
- )
+ log.info("%s fiction %s=>%s", shorthand, self.fiction, fiction)
self.fiction = fiction
- if (numericrange_to_tuple(self.target_age) != target_age and
- not (not self.target_age and not target_age)):
+ if numericrange_to_tuple(self.target_age) != target_age and not (
+ not self.target_age and not target_age
+ ):
log.info(
- "%s target_age %r=>%r", shorthand,
- self.target_age, tuple_to_numericrange(target_age)
+ "%s target_age %r=>%r",
+ shorthand,
+ self.target_age,
+ tuple_to_numericrange(target_age),
)
self.target_age = tuple_to_numericrange(target_age)
class Classification(Base):
"""The assignment of a Identifier to a Subject."""
- __tablename__ = 'classifications'
+
+ __tablename__ = "classifications"
id = Column(Integer, primary_key=True)
- identifier_id = Column(
- Integer, ForeignKey('identifiers.id'), index=True)
- subject_id = Column(Integer, ForeignKey('subjects.id'), index=True)
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
+ identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
+ subject_id = Column(Integer, ForeignKey("subjects.id"), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
# How much weight the data source gives to this classification.
weight = Column(Integer)
@@ -366,12 +367,9 @@ def scaled_weight(self):
# This goes into Classification rather than Subject because it's
# possible that one particular data source could use a certain
# subject type in an unreliable way.
- _juvenile_subject_types = set([
- Subject.LCC
- ])
+ _juvenile_subject_types = set([Subject.LCC])
_quality_as_indicator_of_target_age = {
-
# Not all classifications are equally reliable as indicators
# of a target age. This dictionary contains the coefficients
# we multiply against the weights of incoming classifications
@@ -383,42 +381,35 @@ def scaled_weight(self):
# classifications. But we sometimes have very little
# information about target age, so being careful about how
# much we trust different data sources can become important.
-
- DataSourceConstants.MANUAL : 1.0,
+ DataSourceConstants.MANUAL: 1.0,
DataSourceConstants.LIBRARY_STAFF: 1.0,
- (DataSourceConstants.METADATA_WRANGLER, Subject.AGE_RANGE) : 1.0,
-
- Subject.AXIS_360_AUDIENCE : 0.9,
- (DataSourceConstants.OVERDRIVE, Subject.INTEREST_LEVEL) : 0.9,
- (DataSourceConstants.OVERDRIVE, Subject.OVERDRIVE) : 0.9, # But see below
- (DataSourceConstants.AMAZON, Subject.AGE_RANGE) : 0.85,
- (DataSourceConstants.AMAZON, Subject.GRADE_LEVEL) : 0.85,
-
+ (DataSourceConstants.METADATA_WRANGLER, Subject.AGE_RANGE): 1.0,
+ Subject.AXIS_360_AUDIENCE: 0.9,
+ (DataSourceConstants.OVERDRIVE, Subject.INTEREST_LEVEL): 0.9,
+ (DataSourceConstants.OVERDRIVE, Subject.OVERDRIVE): 0.9, # But see below
+ (DataSourceConstants.AMAZON, Subject.AGE_RANGE): 0.85,
+ (DataSourceConstants.AMAZON, Subject.GRADE_LEVEL): 0.85,
# Although Overdrive usually reserves Fiction and Nonfiction
# for books for adults, it's not as reliable an indicator as
# other Overdrive classifications.
- (DataSourceConstants.OVERDRIVE, Subject.OVERDRIVE, "Fiction") : 0.7,
- (DataSourceConstants.OVERDRIVE, Subject.OVERDRIVE, "Nonfiction") : 0.7,
-
- Subject.AGE_RANGE : 0.6,
- Subject.GRADE_LEVEL : 0.6,
-
+ (DataSourceConstants.OVERDRIVE, Subject.OVERDRIVE, "Fiction"): 0.7,
+ (DataSourceConstants.OVERDRIVE, Subject.OVERDRIVE, "Nonfiction"): 0.7,
+ Subject.AGE_RANGE: 0.6,
+ Subject.GRADE_LEVEL: 0.6,
# There's no real way to know what this measures, since it
# could be anything. If a tag mentions a target age or a grade
# level, the accuracy seems to be... not terrible.
- Subject.TAG : 0.45,
-
+ Subject.TAG: 0.45,
# Tags that come from OCLC Linked Data are of lower quality
# because they sometimes talk about completely the wrong book.
- (DataSourceConstants.OCLC_LINKED_DATA, Subject.TAG) : 0.3,
-
+ (DataSourceConstants.OCLC_LINKED_DATA, Subject.TAG): 0.3,
# These measure reading level, not age appropriateness.
# However, if the book is a remedial work for adults we won't
# be calculating a target age in the first place, so it's okay
# to use reading level as a proxy for age appropriateness in a
# pinch. (But not outside of a pinch.)
- (DataSourceConstants.OVERDRIVE, Subject.GRADE_LEVEL) : 0.35,
- Subject.LEXILE_SCORE : 0.1,
+ (DataSourceConstants.OVERDRIVE, Subject.GRADE_LEVEL): 0.35,
+ Subject.LEXILE_SCORE: 0.1,
Subject.ATOS_SCORE: 0.1,
}
@@ -444,7 +435,7 @@ def quality_as_indicator_of_target_age(self):
(data_source, subject_type, self.subject.identifier),
(data_source, subject_type),
data_source,
- subject_type
+ subject_type,
]
for key in keys:
if key in q:
@@ -472,7 +463,8 @@ class Genre(Base, HasFullTableCache):
"""A subject-matter classification for a book.
Much, much more general than Classification.
"""
- __tablename__ = 'genres'
+
+ __tablename__ = "genres"
id = Column(Integer, primary_key=True)
name = Column(Unicode, unique=True, index=True)
@@ -480,10 +472,11 @@ class Genre(Base, HasFullTableCache):
subjects = relationship("Subject", backref="genre")
# One Genre may participate in many WorkGenre assignments.
- works = association_proxy('work_genres', 'work')
+ works = association_proxy("work_genres", "work")
- work_genres = relationship("WorkGenre", backref="genre",
- cascade="all, delete-orphan")
+ work_genres = relationship(
+ "WorkGenre", backref="genre", cascade="all, delete-orphan"
+ )
_cache = HasFullTableCache.RESET
_id_cache = HasFullTableCache.RESET
@@ -494,7 +487,11 @@ def __repr__(self):
else:
length = 0
return "" % (
- self.name, len(self.subjects), len(self.works), length)
+ self.name,
+ len(self.subjects),
+ len(self.works),
+ length,
+ )
def cache_key(self):
return self.name
diff --git a/model/collection.py b/model/collection.py
index 22ef2bfb5..0fec39b4b 100644
--- a/model/collection.py
+++ b/model/collection.py
@@ -4,73 +4,54 @@
from abc import ABCMeta, abstractmethod
from sqlalchemy import (
+ Boolean,
Column,
- exists,
ForeignKey,
- func,
Integer,
Table,
Unicode,
UniqueConstraint,
- Boolean
+ exists,
+ func,
)
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import (
- backref,
- contains_eager,
- joinedload,
- mapper,
- relationship,
-)
+from sqlalchemy.orm import backref, contains_eager, joinedload, mapper, relationship
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.expression import (
- and_,
- or_,
-)
+from sqlalchemy.sql.expression import and_, or_
+from ..util.string_helpers import base64
+from . import Base, create, get_one, get_one_or_create
from .configuration import (
+ BaseConfigurationStorage,
ConfigurationSetting,
ExternalIntegration,
- BaseConfigurationStorage)
-from .constants import EditionConstants
-from .coverage import (
- CoverageRecord,
- WorkCoverageRecord,
)
+from .constants import EditionConstants
+from .coverage import CoverageRecord, WorkCoverageRecord
from .datasource import DataSource
from .edition import Edition
from .hasfulltablecache import HasFullTableCache
from .identifier import Identifier
from .integrationclient import IntegrationClient
from .library import Library
-from .licensing import (
- LicensePool,
- LicensePoolDeliveryMechanism,
-)
+from .licensing import LicensePool, LicensePoolDeliveryMechanism
from .work import Work
-from . import (
- Base,
- create,
- get_one,
- get_one_or_create,
-)
-from ..util.string_helpers import base64
+
class Collection(Base, HasFullTableCache):
- """A Collection is a set of LicensePools obtained through some mechanism.
- """
+ """A Collection is a set of LicensePools obtained through some mechanism."""
- __tablename__ = 'collections'
+ __tablename__ = "collections"
id = Column(Integer, primary_key=True)
name = Column(Unicode, unique=True, nullable=False, index=True)
- DATA_SOURCE_NAME_SETTING = 'data_source'
+ DATA_SOURCE_NAME_SETTING = "data_source"
# For use in forms that edit Collections.
- EXTERNAL_ACCOUNT_ID_KEY = 'external_account_id'
+ EXTERNAL_ACCOUNT_ID_KEY = "external_account_id"
# How does the provider of this collection distinguish it from
# other collections it provides? On the other side this is usually
@@ -84,14 +65,15 @@ class Collection(Base, HasFullTableCache):
# the metadata and licenses. Each Collection has a distinct
# ExternalIntegration.
external_integration_id = Column(
- Integer, ForeignKey('externalintegrations.id'), unique=True, index=True)
+ Integer, ForeignKey("externalintegrations.id"), unique=True, index=True
+ )
# A Collection may specialize some other Collection. For instance,
# an Overdrive Advantage collection is a specialization of an
# ordinary Overdrive collection. It uses the same access key and
# secret as the Overdrive collection, but it has a distinct
# external_account_id.
- parent_id = Column(Integer, ForeignKey('collections.id'), index=True)
+ parent_id = Column(Integer, ForeignKey("collections.id"), index=True)
# When deleting a collection, this flag is set to True so that the deletion
# script can take care of deleting it in the background. This is
@@ -102,20 +84,17 @@ class Collection(Base, HasFullTableCache):
# An Overdrive collection may have many children corresponding
# to Overdrive Advantage collections.
children = relationship(
- "Collection", backref=backref("parent", remote_side = [id]),
- uselist=True
+ "Collection", backref=backref("parent", remote_side=[id]), uselist=True
)
# A Collection can provide books to many Libraries.
libraries = relationship(
- "Library", secondary=lambda: collections_libraries,
- backref="collections"
+ "Library", secondary=lambda: collections_libraries, backref="collections"
)
# A Collection can include many LicensePools.
licensepools = relationship(
- "LicensePool", backref="collection",
- cascade="all, delete-orphan"
+ "LicensePool", backref="collection", cascade="all, delete-orphan"
)
# A Collection can have many associated Credentials.
@@ -126,15 +105,13 @@ class Collection(Base, HasFullTableCache):
timestamps = relationship("Timestamp", backref="collection")
catalog = relationship(
- "Identifier", secondary=lambda: collections_identifiers,
- backref="collections"
+ "Identifier", secondary=lambda: collections_identifiers, backref="collections"
)
# A Collection can be associated with multiple CoverageRecords
# for Identifiers in its catalog.
coverage_records = relationship(
- "CoverageRecord", backref="collection",
- cascade="all"
+ "CoverageRecord", backref="collection", cascade="all"
)
# A collection may be associated with one or more custom lists.
@@ -143,8 +120,7 @@ class Collection(Base, HasFullTableCache):
# the list and they won't be added back, so the list doesn't
# necessarily match the collection.
customlists = relationship(
- "CustomList", secondary=lambda: collections_customlists,
- backref="collections"
+ "CustomList", secondary=lambda: collections_customlists, backref="collections"
)
_cache = HasFullTableCache.RESET
@@ -156,9 +132,7 @@ class Collection(Base, HasFullTableCache):
GLOBAL_COLLECTION_DATA_SOURCES = [DataSource.ENKI]
def __repr__(self):
- return '' % (
- self.name, self.protocol, self.id
- )
+ return '' % (self.name, self.protocol, self.id)
def cache_key(self):
return (self.name, self.external_integration.protocol)
@@ -173,8 +147,10 @@ def by_name_and_protocol(cls, _db, name, protocol):
:return: A 2-tuple (collection, is_new)
"""
key = (name, protocol)
+
def lookup_hook():
return cls._by_name_and_protocol(_db, key)
+
return cls.by_cache_key(_db, key, lookup_hook)
@classmethod
@@ -190,7 +166,7 @@ def _by_name_and_protocol(cls, _db, cache_key):
name, protocol = cache_key
qu = cls.by_protocol(_db, protocol)
- qu = qu.filter(Collection.name==name)
+ qu = qu.filter(Collection.name == name)
try:
collection = qu.one()
is_new = False
@@ -201,14 +177,10 @@ def _by_name_and_protocol(cls, _db, cache_key):
# The collection already exists, it just uses a different
# protocol than the one we asked about.
raise ValueError(
- 'Collection "%s" does not use protocol "%s".' % (
- name, protocol
- )
+ 'Collection "%s" does not use protocol "%s".' % (name, protocol)
)
- integration = collection.create_external_integration(
- protocol=protocol
- )
- collection.external_integration.protocol=protocol
+ integration = collection.create_external_integration(protocol=protocol)
+ collection.external_integration.protocol = protocol
return collection, is_new
@classmethod
@@ -222,12 +194,14 @@ def by_protocol(cls, _db, protocol):
"""
qu = _db.query(Collection)
if protocol:
- qu = qu.join(
- ExternalIntegration,
- ExternalIntegration.id==Collection.external_integration_id).filter(
- ExternalIntegration.goal==ExternalIntegration.LICENSE_GOAL
- ).filter(ExternalIntegration.protocol==protocol).filter(
- Collection.marked_for_deletion==False
+ qu = (
+ qu.join(
+ ExternalIntegration,
+ ExternalIntegration.id == Collection.external_integration_id,
+ )
+ .filter(ExternalIntegration.goal == ExternalIntegration.LICENSE_GOAL)
+ .filter(ExternalIntegration.protocol == protocol)
+ .filter(Collection.marked_for_deletion == False)
)
return qu
@@ -241,13 +215,17 @@ def by_datasource(cls, _db, data_source):
if isinstance(data_source, DataSource):
data_source = data_source.name
- qu = _db.query(cls).join(ExternalIntegration,
- cls.external_integration_id==ExternalIntegration.id)\
- .join(ExternalIntegration.settings)\
- .filter(ConfigurationSetting.key==Collection.DATA_SOURCE_NAME_SETTING)\
- .filter(ConfigurationSetting.value==data_source).filter(
- Collection.marked_for_deletion==False
+ qu = (
+ _db.query(cls)
+ .join(
+ ExternalIntegration,
+ cls.external_integration_id == ExternalIntegration.id,
)
+ .join(ExternalIntegration.settings)
+ .filter(ConfigurationSetting.key == Collection.DATA_SOURCE_NAME_SETTING)
+ .filter(ConfigurationSetting.value == data_source)
+ .filter(Collection.marked_for_deletion == False)
+ )
return qu
@hybrid_property
@@ -262,9 +240,8 @@ def protocol(self, new_protocol):
"""Modify the protocol in use by this Collection."""
if self.parent and self.parent.protocol != new_protocol:
raise ValueError(
- "Proposed new protocol (%s) contradicts parent collection's protocol (%s)." % (
- new_protocol, self.parent.protocol
- )
+ "Proposed new protocol (%s) contradicts parent collection's protocol (%s)."
+ % (new_protocol, self.parent.protocol)
)
self.external_integration.protocol = new_protocol
for child in self.children:
@@ -272,13 +249,15 @@ def protocol(self, new_protocol):
@hybrid_property
def primary_identifier_source(self):
- """ Identify if should try to use another identifier than """
+ """Identify if should try to use another identifier than """
return self.external_integration.primary_identifier_source
@primary_identifier_source.setter
def primary_identifier_source(self, new_primary_identifier_source):
- """ Modify the primary identifier source in use by this Collection."""
- self.external_integration.primary_identifier_source = new_primary_identifier_source
+ """Modify the primary identifier source in use by this Collection."""
+ self.external_integration.primary_identifier_source = (
+ new_primary_identifier_source
+ )
# For collections that can control the duration of the loans they
# create, the durations are stored in these settings and new loans are
@@ -286,8 +265,8 @@ def primary_identifier_source(self, new_primary_identifier_source):
# where loan duration is negotiated out-of-bounds, all loans are
# _assumed_ to have these durations unless we hear otherwise from
# the server.
- AUDIOBOOK_LOAN_DURATION_KEY = 'audio_loan_duration'
- EBOOK_LOAN_DURATION_KEY = 'ebook_loan_duration'
+ AUDIOBOOK_LOAN_DURATION_KEY = "audio_loan_duration"
+ EBOOK_LOAN_DURATION_KEY = "ebook_loan_duration"
STANDARD_DEFAULT_LOAN_PERIOD = 21
def default_loan_period(self, library, medium=EditionConstants.BOOK_MEDIUM):
@@ -295,8 +274,10 @@ def default_loan_period(self, library, medium=EditionConstants.BOOK_MEDIUM):
that someone who borrows a non-open-access item from this
collection has it for this number of days.
"""
- return self.default_loan_period_setting(
- library, medium).int_value or self.STANDARD_DEFAULT_LOAN_PERIOD
+ return (
+ self.default_loan_period_setting(library, medium).int_value
+ or self.STANDARD_DEFAULT_LOAN_PERIOD
+ )
def default_loan_period_setting(self, library, medium=EditionConstants.BOOK_MEDIUM):
"""Until we hear otherwise from the license provider, we assume
@@ -309,15 +290,13 @@ def default_loan_period_setting(self, library, medium=EditionConstants.BOOK_MEDI
else:
key = self.EBOOK_LOAN_DURATION_KEY
if isinstance(library, Library):
- return (
- ConfigurationSetting.for_library_and_externalintegration(
- _db, key, library, self.external_integration
- )
+ return ConfigurationSetting.for_library_and_externalintegration(
+ _db, key, library, self.external_integration
)
elif isinstance(library, IntegrationClient):
return self.external_integration.setting(key)
- DEFAULT_RESERVATION_PERIOD_KEY = 'default_reservation_period'
+ DEFAULT_RESERVATION_PERIOD_KEY = "default_reservation_period"
STANDARD_DEFAULT_RESERVATION_PERIOD = 3
@hybrid_property
@@ -329,20 +308,22 @@ def default_reservation_period(self):
return (
self.external_integration.setting(
self.DEFAULT_RESERVATION_PERIOD_KEY,
- ).int_value or self.STANDARD_DEFAULT_RESERVATION_PERIOD
+ ).int_value
+ or self.STANDARD_DEFAULT_RESERVATION_PERIOD
)
@default_reservation_period.setter
def default_reservation_period(self, new_value):
new_value = int(new_value)
self.external_integration.setting(
- self.DEFAULT_RESERVATION_PERIOD_KEY).value = str(new_value)
+ self.DEFAULT_RESERVATION_PERIOD_KEY
+ ).value = str(new_value)
# When you import an OPDS feed, you may know the intended audience of the works (e.g. children or researchers),
# even though the OPDS feed may not contain that information.
# It should be possible to configure a collection with a default audience,
# so that books imported from the OPDS feed end up with the right audience.
- DEFAULT_AUDIENCE_KEY = 'default_audience'
+ DEFAULT_AUDIENCE_KEY = "default_audience"
@hybrid_property
def default_audience(self):
@@ -382,14 +363,15 @@ def create_external_integration(self, protocol):
_db = Session.object_session(self)
goal = ExternalIntegration.LICENSE_GOAL
external_integration, is_new = get_one_or_create(
- _db, ExternalIntegration, id=self.external_integration_id,
- create_method_kwargs=dict(protocol=protocol, goal=goal)
+ _db,
+ ExternalIntegration,
+ id=self.external_integration_id,
+ create_method_kwargs=dict(protocol=protocol, goal=goal),
)
if external_integration.protocol != protocol:
raise ValueError(
- "Located ExternalIntegration, but its protocol (%s) does not match desired protocol (%s)." % (
- external_integration.protocol, protocol
- )
+ "Located ExternalIntegration, but its protocol (%s) does not match desired protocol (%s)."
+ % (external_integration.protocol, protocol)
)
self.external_integration_id = external_integration.id
return external_integration
@@ -416,9 +398,11 @@ def external_integration(self):
@property
def unique_account_id(self):
"""Identifier that uniquely represents this Collection of works"""
- if (self.data_source
+ if (
+ self.data_source
and self.data_source.name in self.GLOBAL_COLLECTION_DATA_SOURCES
- and not self.parent):
+ and not self.parent
+ ):
# Every top-level collection from this data source has the
# same catalog. Treat them all as one collection named
# after the data source.
@@ -430,7 +414,7 @@ def unique_account_id(self):
raise ValueError("Unique account identifier not set")
if self.parent:
- return self.parent.unique_account_id + '+' + unique_account_id
+ return self.parent.unique_account_id + "+" + unique_account_id
return unique_account_id
@hybrid_property
@@ -448,9 +432,7 @@ def data_source(self):
the data source is a Collection-specific setting.
"""
data_source = None
- name = ExternalIntegration.DATA_SOURCE_FOR_LICENSE_PROTOCOL.get(
- self.protocol
- )
+ name = ExternalIntegration.DATA_SOURCE_FOR_LICENSE_PROTOCOL.get(self.protocol)
if not name:
name = self.external_integration.setting(
Collection.DATA_SOURCE_NAME_SETTING
@@ -502,14 +484,14 @@ def metadata_identifier(self):
if self.protocol == ExternalIntegration.OPDS_IMPORT:
# Remove ending / from OPDS url that could duplicate the collection
# on the Metadata Wrangler.
- while account_id.endswith('/'):
+ while account_id.endswith("/"):
account_id = account_id[:-1]
encode = base64.urlsafe_b64encode
account_id = encode(account_id)
protocol = encode(self.protocol)
- metadata_identifier = protocol + ':' + account_id
+ metadata_identifier = protocol + ":" + account_id
return encode(metadata_identifier)
def disassociate_library(self, library):
@@ -523,12 +505,12 @@ def disassociate_library(self, library):
self.libraries.remove(library)
_db = Session.object_session(self)
- qu = _db.query(
- ConfigurationSetting
- ).filter(
- ConfigurationSetting.library==library
- ).filter(
- ConfigurationSetting.external_integration==self.external_integration
+ qu = (
+ _db.query(ConfigurationSetting)
+ .filter(ConfigurationSetting.library == library)
+ .filter(
+ ConfigurationSetting.external_integration == self.external_integration
+ )
)
qu.delete()
@@ -540,13 +522,12 @@ def _decode_metadata_identifier(cls, metadata_identifier):
try:
decode = base64.urlsafe_b64decode
details = decode(metadata_identifier)
- encoded_details = details.split(':', 1)
+ encoded_details = details.split(":", 1)
[protocol, account_id] = [decode(d) for d in encoded_details]
except (TypeError, ValueError) as e:
raise ValueError(
- "Metadata identifier '%s' is invalid: %s" % (
- metadata_identifier, str(e)
- )
+ "Metadata identifier '%s' is invalid: %s"
+ % (metadata_identifier, str(e))
)
return protocol, account_id
@@ -571,9 +552,7 @@ def from_metadata_identifier(cls, _db, metadata_identifier, data_source=None):
# identifier. Give it an ExternalIntegration with the
# corresponding protocol, and set its data source and
# external_account_id.
- collection, is_new = create(
- _db, Collection, name=metadata_identifier
- )
+ collection, is_new = create(_db, Collection, name=metadata_identifier)
collection.create_external_integration(protocol)
if protocol == ExternalIntegration.OPDS_IMPORT:
@@ -581,9 +560,7 @@ def from_metadata_identifier(cls, _db, metadata_identifier, data_source=None):
# the OPDS feed (the "account ID") and the data source.
collection.external_account_id = account_id
if data_source and not isinstance(data_source, DataSource):
- data_source = DataSource.lookup(
- _db, data_source, autocreate=True
- )
+ data_source = DataSource.lookup(_db, data_source, autocreate=True)
collection.data_source = data_source
return collection, is_new
@@ -597,7 +574,7 @@ def pools_with_no_delivery_mechanisms(self):
"""
_db = Session.object_session(self)
qu = LicensePool.with_no_delivery_mechanisms(_db)
- return qu.filter(LicensePool.collection==self)
+ return qu.filter(LicensePool.collection == self)
def explain(self, include_secrets=False):
"""Create a series of human-readable strings to explain a collection's
@@ -612,7 +589,7 @@ def explain(self, include_secrets=False):
if self.name:
lines.append('Name: "%s"' % self.name)
if self.parent:
- lines.append('Parent: %s' % self.parent.name)
+ lines.append("Parent: %s" % self.parent.name)
integration = self.external_integration
if integration.protocol:
lines.append('Protocol: "%s"' % integration.protocol)
@@ -636,13 +613,13 @@ def catalog_identifiers(self, identifiers):
return
_db = Session.object_session(identifiers[0])
- already_in_catalog = _db.query(Identifier).join(
- CollectionIdentifier
- ).filter(
- CollectionIdentifier.collection_id==self.id
- ).filter(
- Identifier.id.in_([x.id for x in identifiers])
- ).all()
+ already_in_catalog = (
+ _db.query(Identifier)
+ .join(CollectionIdentifier)
+ .filter(CollectionIdentifier.collection_id == self.id)
+ .filter(Identifier.id.in_([x.id for x in identifiers]))
+ .all()
+ )
new_catalog_entries = [
dict(collection_id=self.id, identifier_id=identifier.id)
@@ -661,18 +638,20 @@ def unresolved_catalog(self, _db, data_source_name, operation):
"""
coverage_source = DataSource.lookup(_db, data_source_name)
is_not_resolved = and_(
- CoverageRecord.operation==operation,
- CoverageRecord.data_source_id==coverage_source.id,
- CoverageRecord.status!=CoverageRecord.SUCCESS,
+ CoverageRecord.operation == operation,
+ CoverageRecord.data_source_id == coverage_source.id,
+ CoverageRecord.status != CoverageRecord.SUCCESS,
)
- query = _db.query(Identifier)\
- .outerjoin(Identifier.licensed_through)\
- .outerjoin(Identifier.coverage_records)\
- .outerjoin(LicensePool.work).outerjoin(Identifier.collections)\
- .filter(
- Collection.id==self.id, is_not_resolved, Work.id==None
- ).order_by(Identifier.id)
+ query = (
+ _db.query(Identifier)
+ .outerjoin(Identifier.licensed_through)
+ .outerjoin(Identifier.coverage_records)
+ .outerjoin(LicensePool.work)
+ .outerjoin(Identifier.collections)
+ .filter(Collection.id == self.id, is_not_resolved, Work.id == None)
+ .order_by(Identifier.id)
+ )
return query
@@ -690,21 +669,25 @@ def licensepools_with_works_updated_since(self, _db, timestamp):
necessary to create full OPDS entries for the works.
"""
opds_operation = WorkCoverageRecord.GENERATE_OPDS_OPERATION
- qu = _db.query(
- LicensePool
- ).join(
- LicensePool.work,
- ).join(
- LicensePool.identifier,
- ).join(
- Work.coverage_records,
- ).join(
- CollectionIdentifier,
- Identifier.id==CollectionIdentifier.identifier_id
+ qu = (
+ _db.query(LicensePool)
+ .join(
+ LicensePool.work,
+ )
+ .join(
+ LicensePool.identifier,
+ )
+ .join(
+ Work.coverage_records,
+ )
+ .join(
+ CollectionIdentifier,
+ Identifier.id == CollectionIdentifier.identifier_id,
+ )
)
qu = qu.filter(
- WorkCoverageRecord.operation==opds_operation,
- CollectionIdentifier.collection_id==self.id
+ WorkCoverageRecord.operation == opds_operation,
+ CollectionIdentifier.collection_id == self.id,
)
qu = qu.options(
contains_eager(LicensePool.work),
@@ -712,30 +695,33 @@ def licensepools_with_works_updated_since(self, _db, timestamp):
)
if timestamp:
- qu = qu.filter(
- WorkCoverageRecord.timestamp > timestamp
- )
+ qu = qu.filter(WorkCoverageRecord.timestamp > timestamp)
qu = qu.order_by(WorkCoverageRecord.timestamp)
return qu
def isbns_updated_since(self, _db, timestamp):
"""Finds all ISBNs in a collection's catalog that have been updated
- since the timestamp but don't have a Work to show for it. Used in
- the metadata wrangler.
+ since the timestamp but don't have a Work to show for it. Used in
+ the metadata wrangler.
- :return: a Query
+ :return: a Query
"""
- isbns = _db.query(Identifier, func.max(CoverageRecord.timestamp).label('latest'))\
- .join(Identifier.collections)\
- .join(Identifier.coverage_records)\
- .outerjoin(Identifier.licensed_through)\
- .group_by(Identifier.id).order_by('latest')\
+ isbns = (
+ _db.query(Identifier, func.max(CoverageRecord.timestamp).label("latest"))
+ .join(Identifier.collections)
+ .join(Identifier.coverage_records)
+ .outerjoin(Identifier.licensed_through)
+ .group_by(Identifier.id)
+ .order_by("latest")
.filter(
- Collection.id==self.id,
- LicensePool.work_id==None,
- CoverageRecord.status==CoverageRecord.SUCCESS,
- ).enable_eagerloads(False).options(joinedload(Identifier.coverage_records))
+ Collection.id == self.id,
+ LicensePool.work_id == None,
+ CoverageRecord.status == CoverageRecord.SUCCESS,
+ )
+ .enable_eagerloads(False)
+ .options(joinedload(Identifier.coverage_records))
+ )
if timestamp:
isbns = isbns.filter(CoverageRecord.timestamp > timestamp)
@@ -744,8 +730,11 @@ def isbns_updated_since(self, _db, timestamp):
@classmethod
def restrict_to_ready_deliverable_works(
- cls, query, collection_ids=None, show_suppressed=False,
- allow_holds=True,
+ cls,
+ query,
+ collection_ids=None,
+ show_suppressed=False,
+ allow_holds=True,
):
"""Restrict a query to show only presentation-ready works present in
an appropriate collection which the default client can
@@ -772,8 +761,10 @@ def restrict_to_ready_deliverable_works(
# Only find books that have some kind of DeliveryMechanism.
LPDM = LicensePoolDeliveryMechanism
exists_clause = exists().where(
- and_(LicensePool.data_source_id==LPDM.data_source_id,
- LicensePool.identifier_id==LPDM.identifier_id)
+ and_(
+ LicensePool.data_source_id == LPDM.data_source_id,
+ LicensePool.identifier_id == LPDM.identifier_id,
+ )
)
query = query.filter(exists_clause)
@@ -783,17 +774,17 @@ def restrict_to_ready_deliverable_works(
_db = query.session
excluded = ConfigurationSetting.excluded_audio_data_sources(_db)
if excluded:
- audio_excluded_ids = [
- DataSource.lookup(_db, x).id for x in excluded
- ]
+ audio_excluded_ids = [DataSource.lookup(_db, x).id for x in excluded]
query = query.filter(
- or_(Edition.medium != EditionConstants.AUDIO_MEDIUM,
- ~LicensePool.data_source_id.in_(audio_excluded_ids))
+ or_(
+ Edition.medium != EditionConstants.AUDIO_MEDIUM,
+ ~LicensePool.data_source_id.in_(audio_excluded_ids),
+ )
)
# Only find books with unsuppressed LicensePools.
if not show_suppressed:
- query = query.filter(LicensePool.suppressed==False)
+ query = query.filter(LicensePool.suppressed == False)
# Only find books with available licenses or books from self-hosted collections using MirrorUploader
query = query.filter(
@@ -801,15 +792,13 @@ def restrict_to_ready_deliverable_works(
LicensePool.licenses_owned > 0,
LicensePool.open_access,
LicensePool.unlimited_access,
- LicensePool.self_hosted
+ LicensePool.self_hosted,
)
)
# Only find books in an appropriate collection.
if collection_ids is not None:
- query = query.filter(
- LicensePool.collection_id.in_(collection_ids)
- )
+ query = query.filter(LicensePool.collection_id.in_(collection_ids))
# If we don't allow holds, hide any books with no available copies.
if not allow_holds:
@@ -818,7 +807,7 @@ def restrict_to_ready_deliverable_works(
LicensePool.licenses_available > 0,
LicensePool.open_access,
LicensePool.self_hosted,
- LicensePool.unlimited_access
+ LicensePool.unlimited_access,
)
)
return query
@@ -856,7 +845,10 @@ def delete(self, search_index=None):
# Collection, assuming it wasn't deleted already.
if self.external_integration:
for link in self.external_integration.links:
- if link.other_integration and link.other_integration.goal == ExternalIntegration.STORAGE_GOAL:
+ if (
+ link.other_integration
+ and link.other_integration.goal == ExternalIntegration.STORAGE_GOAL
+ ):
logging.info(
f"Deletion of collection {self.name} is disassociating "
f"storage integration {link.other_integration.name}."
@@ -870,30 +862,40 @@ def delete(self, search_index=None):
collections_libraries = Table(
- 'collections_libraries', Base.metadata,
- Column(
- 'collection_id', Integer, ForeignKey('collections.id'),
- index=True, nullable=False
- ),
- Column(
- 'library_id', Integer, ForeignKey('libraries.id'),
- index=True, nullable=False
- ),
- UniqueConstraint('collection_id', 'library_id'),
- )
+ "collections_libraries",
+ Base.metadata,
+ Column(
+ "collection_id",
+ Integer,
+ ForeignKey("collections.id"),
+ index=True,
+ nullable=False,
+ ),
+ Column(
+ "library_id", Integer, ForeignKey("libraries.id"), index=True, nullable=False
+ ),
+ UniqueConstraint("collection_id", "library_id"),
+)
collections_identifiers = Table(
- 'collections_identifiers', Base.metadata,
+ "collections_identifiers",
+ Base.metadata,
Column(
- 'collection_id', Integer, ForeignKey('collections.id'),
- index=True, nullable=False
+ "collection_id",
+ Integer,
+ ForeignKey("collections.id"),
+ index=True,
+ nullable=False,
),
Column(
- 'identifier_id', Integer, ForeignKey('identifiers.id'),
- index=True, nullable=False
+ "identifier_id",
+ Integer,
+ ForeignKey("identifiers.id"),
+ index=True,
+ nullable=False,
),
- UniqueConstraint('collection_id', 'identifier_id'),
+ UniqueConstraint("collection_id", "identifier_id"),
)
# Create an ORM model for the collections_identifiers join table
@@ -907,27 +909,37 @@ class CollectionMissing(Exception):
of a Collection, but there was no Collection available.
"""
+
mapper(
- CollectionIdentifier, collections_identifiers,
+ CollectionIdentifier,
+ collections_identifiers,
primary_key=(
collections_identifiers.columns.collection_id,
- collections_identifiers.columns.identifier_id
- )
+ collections_identifiers.columns.identifier_id,
+ ),
)
collections_customlists = Table(
- 'collections_customlists', Base.metadata,
+ "collections_customlists",
+ Base.metadata,
Column(
- 'collection_id', Integer, ForeignKey('collections.id'),
- index=True, nullable=False,
+ "collection_id",
+ Integer,
+ ForeignKey("collections.id"),
+ index=True,
+ nullable=False,
),
Column(
- 'customlist_id', Integer, ForeignKey('customlists.id'),
- index=True, nullable=False,
+ "customlist_id",
+ Integer,
+ ForeignKey("customlists.id"),
+ index=True,
+ nullable=False,
),
- UniqueConstraint('collection_id', 'customlist_id'),
+ UniqueConstraint("collection_id", "customlist_id"),
)
+
class HasExternalIntegrationPerCollection(metaclass=ABCMeta):
"""Interface allowing to get access to an external integration"""
@@ -972,10 +984,12 @@ def save(self, db, setting_name, value):
:type value: Any
"""
collection = Collection.by_id(db, self._collection_id)
- integration = self._integration_owner.collection_external_integration(collection)
+ integration = self._integration_owner.collection_external_integration(
+ collection
+ )
ConfigurationSetting.for_externalintegration(
- setting_name,
- integration).value = value
+ setting_name, integration
+ ).value = value
def load(self, db, setting_name):
"""Loads and returns the library's configuration setting
@@ -989,9 +1003,11 @@ def load(self, db, setting_name):
:return: Any
"""
collection = Collection.by_id(db, self._collection_id)
- integration = self._integration_owner.collection_external_integration(collection)
+ integration = self._integration_owner.collection_external_integration(
+ collection
+ )
value = ConfigurationSetting.for_externalintegration(
- setting_name,
- integration).value
+ setting_name, integration
+ ).value
return value
diff --git a/model/complaint.py b/model/complaint.py
index 871a9c70f..c7d8a51e6 100644
--- a/model/complaint.py
+++ b/model/complaint.py
@@ -1,57 +1,49 @@
# encoding: utf-8
# Complaint
-from sqlalchemy import (
- Column,
- DateTime,
- ForeignKey,
- Integer,
- String,
-)
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy.orm.session import Session
-from . import (
- Base,
- create,
- get_one_or_create,
-)
from ..util.datetime_helpers import utc_now
+from . import Base, create, get_one_or_create
+
class Complaint(Base):
"""A complaint about a LicensePool (or, potentially, something else)."""
- __tablename__ = 'complaints'
-
- VALID_TYPES = set([
- "http://librarysimplified.org/terms/problem/" + x
- for x in [
- 'wrong-genre',
- 'wrong-audience',
- 'wrong-age-range',
- 'wrong-title',
- 'wrong-medium',
- 'wrong-author',
- 'bad-cover-image',
- 'bad-description',
- 'cannot-fulfill-loan',
- 'cannot-issue-loan',
- 'cannot-render',
- 'cannot-return',
- ]
- ])
+ __tablename__ = "complaints"
+
+ VALID_TYPES = set(
+ [
+ "http://librarysimplified.org/terms/problem/" + x
+ for x in [
+ "wrong-genre",
+ "wrong-audience",
+ "wrong-age-range",
+ "wrong-title",
+ "wrong-medium",
+ "wrong-author",
+ "bad-cover-image",
+ "bad-description",
+ "cannot-fulfill-loan",
+ "cannot-issue-loan",
+ "cannot-render",
+ "cannot-return",
+ ]
+ ]
+ )
LICENSE_POOL_TYPES = [
- 'cannot-fulfill-loan',
- 'cannot-issue-loan',
- 'cannot-render',
- 'cannot-return',
+ "cannot-fulfill-loan",
+ "cannot-issue-loan",
+ "cannot-render",
+ "cannot-return",
]
id = Column(Integer, primary_key=True)
# One LicensePool can have many complaints lodged against it.
- license_pool_id = Column(
- Integer, ForeignKey('licensepools.id'), index=True)
+ license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
# The type of complaint.
type = Column(String, nullable=False, index=True)
@@ -80,14 +72,16 @@ def register(self, license_pool, type, source, detail, resolved=None):
now = utc_now()
if source:
complaint, is_new = get_one_or_create(
- _db, Complaint,
+ _db,
+ Complaint,
license_pool=license_pool,
- source=source, type=type,
+ source=source,
+ type=type,
resolved=resolved,
- on_multiple='interchangeable',
- create_method_kwargs = dict(
+ on_multiple="interchangeable",
+ create_method_kwargs=dict(
timestamp=now,
- )
+ ),
)
complaint.timestamp = now
complaint.detail = detail
@@ -100,7 +94,7 @@ def register(self, license_pool, type, source, detail, resolved=None):
type=type,
timestamp=now,
detail=detail,
- resolved=resolved
+ resolved=resolved,
)
return complaint, is_new
diff --git a/model/configuration.py b/model/configuration.py
index f2e9b0f5d..afb7aae68 100644
--- a/model/configuration.py
+++ b/model/configuration.py
@@ -14,85 +14,92 @@
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import and_
-from .constants import DataSourceConstants
-from .hasfulltablecache import HasFullTableCache
-from .library import Library
from ..config import CannotLoadConfiguration, Configuration
from ..mirror import MirrorUploader
from ..util.string_helpers import random_string
from . import Base, get_one, get_one_or_create
+from .constants import DataSourceConstants
+from .hasfulltablecache import HasFullTableCache
+from .library import Library
+
class ExternalIntegrationLink(Base, HasFullTableCache):
- __tablename__ = 'externalintegrationslinks'
+ __tablename__ = "externalintegrationslinks"
NO_MIRROR_INTEGRATION = "NO_MIRROR"
# Possible purposes that a storage external integration can be used for.
# These string literals may be stored in the database, so changes to them
# may need to be accompanied by a DB migration.
- COVERS = 'covers_mirror'
- COVERS_KEY = '{0}_integration_id'.format(COVERS)
+ COVERS = "covers_mirror"
+ COVERS_KEY = "{0}_integration_id".format(COVERS)
- OPEN_ACCESS_BOOKS = 'books_mirror'
- OPEN_ACCESS_BOOKS_KEY = '{0}_integration_id'.format(OPEN_ACCESS_BOOKS)
+ OPEN_ACCESS_BOOKS = "books_mirror"
+ OPEN_ACCESS_BOOKS_KEY = "{0}_integration_id".format(OPEN_ACCESS_BOOKS)
- PROTECTED_ACCESS_BOOKS = 'protected_access_books_mirror'
- PROTECTED_ACCESS_BOOKS_KEY = '{0}_integration_id'.format(PROTECTED_ACCESS_BOOKS)
+ PROTECTED_ACCESS_BOOKS = "protected_access_books_mirror"
+ PROTECTED_ACCESS_BOOKS_KEY = "{0}_integration_id".format(PROTECTED_ACCESS_BOOKS)
MARC = "MARC_mirror"
id = Column(Integer, primary_key=True)
external_integration_id = Column(
- Integer, ForeignKey('externalintegrations.id'), index=True
- )
- library_id = Column(
- Integer, ForeignKey('libraries.id'), index=True
+ Integer, ForeignKey("externalintegrations.id"), index=True
)
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True)
other_integration_id = Column(
- Integer, ForeignKey('externalintegrations.id'), index=True
+ Integer, ForeignKey("externalintegrations.id"), index=True
)
purpose = Column(Unicode, index=True)
mirror_settings = [
{
- 'key': COVERS_KEY,
- 'type': COVERS,
- 'description_type': 'cover images',
- 'label': 'Covers Mirror'
+ "key": COVERS_KEY,
+ "type": COVERS,
+ "description_type": "cover images",
+ "label": "Covers Mirror",
},
{
- 'key': OPEN_ACCESS_BOOKS_KEY,
- 'type': OPEN_ACCESS_BOOKS,
- 'description_type': 'free books',
- 'label': 'Open Access Books Mirror'
+ "key": OPEN_ACCESS_BOOKS_KEY,
+ "type": OPEN_ACCESS_BOOKS,
+ "description_type": "free books",
+ "label": "Open Access Books Mirror",
},
{
- 'key': PROTECTED_ACCESS_BOOKS_KEY,
- 'type': PROTECTED_ACCESS_BOOKS,
- 'description_type': 'self-hosted, commercially licensed books',
- 'label': 'Protected Access Books Mirror'
- }
+ "key": PROTECTED_ACCESS_BOOKS_KEY,
+ "type": PROTECTED_ACCESS_BOOKS,
+ "description_type": "self-hosted, commercially licensed books",
+ "label": "Protected Access Books Mirror",
+ },
]
settings = []
for mirror_setting in mirror_settings:
- mirror_type = mirror_setting['type']
- mirror_description_type = mirror_setting['description_type']
- mirror_label = mirror_setting['label']
-
- settings.append({
- 'key': '{0}_integration_id'.format(mirror_type.lower()),
- 'label': _(mirror_label),
- "description": _('Any {0} encountered while importing content from this collection '
- 'can be mirrored to a server you control.'.format(mirror_description_type)),
- 'type': 'select',
- 'options': [
- {
- 'key': NO_MIRROR_INTEGRATION,
- 'label': _('None - Do not mirror {0}'.format(mirror_description_type))
- }
- ]
- })
+ mirror_type = mirror_setting["type"]
+ mirror_description_type = mirror_setting["description_type"]
+ mirror_label = mirror_setting["label"]
+
+ settings.append(
+ {
+ "key": "{0}_integration_id".format(mirror_type.lower()),
+ "label": _(mirror_label),
+ "description": _(
+ "Any {0} encountered while importing content from this collection "
+ "can be mirrored to a server you control.".format(
+ mirror_description_type
+ )
+ ),
+ "type": "select",
+ "options": [
+ {
+ "key": NO_MIRROR_INTEGRATION,
+ "label": _(
+ "None - Do not mirror {0}".format(mirror_description_type)
+ ),
+ }
+ ],
+ }
+ )
COLLECTION_MIRROR_SETTINGS = settings
@@ -107,21 +114,21 @@ class ExternalIntegration(Base, HasFullTableCache):
#
# These integrations are associated with external services such as
# Google Enterprise which authenticate library administrators.
- ADMIN_AUTH_GOAL = 'admin_auth'
+ ADMIN_AUTH_GOAL = "admin_auth"
# These integrations are associated with external services such as
# SIP2 which authenticate library patrons. Other constants related
# to this are defined in the circulation manager.
- PATRON_AUTH_GOAL = 'patron_auth'
+ PATRON_AUTH_GOAL = "patron_auth"
# These integrations are associated with external services such
# as Overdrive which provide access to books.
- LICENSE_GOAL = 'licenses'
+ LICENSE_GOAL = "licenses"
# These integrations are associated with external services such as
# the metadata wrangler, which provide information about books,
# but not the books themselves.
- METADATA_GOAL = 'metadata'
+ METADATA_GOAL = "metadata"
# These integrations are associated with external services such as
# S3 that provide access to book covers.
@@ -129,40 +136,40 @@ class ExternalIntegration(Base, HasFullTableCache):
# These integrations are associated with external services like
# Cloudfront or other CDNs that mirror and/or cache certain domains.
- CDN_GOAL = 'CDN'
+ CDN_GOAL = "CDN"
# These integrations are associated with external services such as
# Elasticsearch that provide indexed search.
- SEARCH_GOAL = 'search'
+ SEARCH_GOAL = "search"
# These integrations are associated with external services such as
# Google Analytics, which receive analytics events.
- ANALYTICS_GOAL = 'analytics'
+ ANALYTICS_GOAL = "analytics"
# These integrations are associated with external services such as
# Adobe Vendor ID, which manage access to DRM-dependent content.
- DRM_GOAL = 'drm'
+ DRM_GOAL = "drm"
# These integrations are associated with external services that
# help patrons find libraries.
- DISCOVERY_GOAL = 'discovery'
+ DISCOVERY_GOAL = "discovery"
# These integrations are associated with external services that
# collect logs of server-side events.
- LOGGING_GOAL = 'logging'
+ LOGGING_GOAL = "logging"
# These integrations are associated with external services that
# a library uses to manage its catalog.
- CATALOG_GOAL = 'ils_catalog'
+ CATALOG_GOAL = "ils_catalog"
# Supported protocols for ExternalIntegrations with LICENSE_GOAL.
- OPDS_IMPORT = 'OPDS Import'
- OPDS2_IMPORT = 'OPDS 2.0 Import'
+ OPDS_IMPORT = "OPDS Import"
+ OPDS2_IMPORT = "OPDS 2.0 Import"
OVERDRIVE = DataSourceConstants.OVERDRIVE
ODILO = DataSourceConstants.ODILO
BIBLIOTHECA = DataSourceConstants.BIBLIOTHECA
AXIS_360 = DataSourceConstants.AXIS_360
- OPDS_FOR_DISTRIBUTORS = 'OPDS for Distributors'
+ OPDS_FOR_DISTRIBUTORS = "OPDS for Distributors"
ENKI = DataSourceConstants.ENKI
FEEDBOOKS = DataSourceConstants.FEEDBOOKS
ODL = "ODL"
@@ -180,58 +187,64 @@ class ExternalIntegration(Base, HasFullTableCache):
GUTENBERG = DataSourceConstants.GUTENBERG
LICENSE_PROTOCOLS = [
- OPDS_IMPORT, OVERDRIVE, ODILO, BIBLIOTHECA, AXIS_360,
- GUTENBERG, ENKI, MANUAL
+ OPDS_IMPORT,
+ OVERDRIVE,
+ ODILO,
+ BIBLIOTHECA,
+ AXIS_360,
+ GUTENBERG,
+ ENKI,
+ MANUAL,
]
# Some integrations with LICENSE_GOAL imply that the data and
# licenses come from a specific data source.
DATA_SOURCE_FOR_LICENSE_PROTOCOL = {
- OVERDRIVE : DataSourceConstants.OVERDRIVE,
- ODILO : DataSourceConstants.ODILO,
- BIBLIOTHECA : DataSourceConstants.BIBLIOTHECA,
- AXIS_360 : DataSourceConstants.AXIS_360,
- ENKI : DataSourceConstants.ENKI,
- FEEDBOOKS : DataSourceConstants.FEEDBOOKS,
+ OVERDRIVE: DataSourceConstants.OVERDRIVE,
+ ODILO: DataSourceConstants.ODILO,
+ BIBLIOTHECA: DataSourceConstants.BIBLIOTHECA,
+ AXIS_360: DataSourceConstants.AXIS_360,
+ ENKI: DataSourceConstants.ENKI,
+ FEEDBOOKS: DataSourceConstants.FEEDBOOKS,
}
# Integrations with METADATA_GOAL
- BIBBLIO = 'Bibblio'
- CONTENT_CAFE = 'Content Cafe'
- NOVELIST = 'NoveList Select'
- NYPL_SHADOWCAT = 'Shadowcat'
- NYT = 'New York Times'
- METADATA_WRANGLER = 'Metadata Wrangler'
- CONTENT_SERVER = 'Content Server'
+ BIBBLIO = "Bibblio"
+ CONTENT_CAFE = "Content Cafe"
+ NOVELIST = "NoveList Select"
+ NYPL_SHADOWCAT = "Shadowcat"
+ NYT = "New York Times"
+ METADATA_WRANGLER = "Metadata Wrangler"
+ CONTENT_SERVER = "Content Server"
# Integrations with STORAGE_GOAL
- S3 = 'Amazon S3'
- MINIO = 'MinIO'
- LCP = 'LCP'
+ S3 = "Amazon S3"
+ MINIO = "MinIO"
+ LCP = "LCP"
# Integrations with CDN_GOAL
- CDN = 'CDN'
+ CDN = "CDN"
# Integrations with SEARCH_GOAL
- ELASTICSEARCH = 'Elasticsearch'
+ ELASTICSEARCH = "Elasticsearch"
# Integrations with DRM_GOAL
- ADOBE_VENDOR_ID = 'Adobe Vendor ID'
+ ADOBE_VENDOR_ID = "Adobe Vendor ID"
# Integrations with DISCOVERY_GOAL
- OPDS_REGISTRATION = 'OPDS Registration'
+ OPDS_REGISTRATION = "OPDS Registration"
# Integrations with ANALYTICS_GOAL
- GOOGLE_ANALYTICS = 'Google Analytics'
+ GOOGLE_ANALYTICS = "Google Analytics"
# Integrations with ADMIN_AUTH_GOAL
- GOOGLE_OAUTH = 'Google OAuth'
+ GOOGLE_OAUTH = "Google OAuth"
# List of such ADMIN_AUTH_GOAL integrations
ADMIN_AUTH_PROTOCOLS = [GOOGLE_OAUTH]
# Integrations with LOGGING_GOAL
- INTERNAL_LOGGING = 'Internal logging'
+ INTERNAL_LOGGING = "Internal logging"
LOGGLY = "Loggly"
CLOUDWATCH = "AWS Cloudwatch Logs"
@@ -261,7 +274,7 @@ class ExternalIntegration(Base, HasFullTableCache):
_cache = HasFullTableCache.RESET
_id_cache = HasFullTableCache.RESET
- __tablename__ = 'externalintegrations'
+ __tablename__ = "externalintegrations"
id = Column(Integer, primary_key=True)
# Each integration should have a protocol (explaining what type of
@@ -279,34 +292,41 @@ class ExternalIntegration(Base, HasFullTableCache):
# Any additional configuration information goes into
# ConfigurationSettings.
settings = relationship(
- "ConfigurationSetting", backref="external_integration",
- lazy="joined", cascade="all, delete",
+ "ConfigurationSetting",
+ backref="external_integration",
+ lazy="joined",
+ cascade="all, delete",
)
# Any number of Collections may designate an ExternalIntegration
# as the source of their configuration
collections = relationship(
- "Collection", backref="_external_integration",
- foreign_keys='Collection.external_integration_id',
+ "Collection",
+ backref="_external_integration",
+ foreign_keys="Collection.external_integration_id",
)
links = relationship(
"ExternalIntegrationLink",
backref="integration",
foreign_keys="ExternalIntegrationLink.external_integration_id",
- cascade="all, delete-orphan"
+ cascade="all, delete-orphan",
)
other_links = relationship(
"ExternalIntegrationLink",
backref="other_integration",
foreign_keys="ExternalIntegrationLink.other_integration_id",
- cascade="all, delete-orphan"
+ cascade="all, delete-orphan",
)
def __repr__(self):
return "" % (
- self.protocol, self.goal, len(self.settings), self.id)
+ self.protocol,
+ self.goal,
+ len(self.settings),
+ self.id,
+ )
def cache_key(self):
# TODO: This is not ideal, but the lookup method isn't like
@@ -320,13 +340,8 @@ def cache_key(self):
@classmethod
def for_goal(cls, _db, goal):
- """Return all external integrations by goal type.
- """
- integrations = _db.query(cls).filter(
- cls.goal==goal
- ).order_by(
- cls.name
- )
+ """Return all external integrations by goal type."""
+ integrations = _db.query(cls).filter(cls.goal == goal).order_by(cls.name)
return integrations
@@ -337,22 +352,28 @@ def for_collection_and_purpose(cls, _db, collection, purpose):
:param collection: Use the mirror configuration for this Collection.
:param purpose: Use the purpose of the mirror configuration.
"""
- qu = _db.query(cls).join(
- ExternalIntegrationLink,
- ExternalIntegrationLink.other_integration_id==cls.id
- ).filter(
- ExternalIntegrationLink.external_integration_id==collection.external_integration_id,
- ExternalIntegrationLink.purpose==purpose
+ qu = (
+ _db.query(cls)
+ .join(
+ ExternalIntegrationLink,
+ ExternalIntegrationLink.other_integration_id == cls.id,
+ )
+ .filter(
+ ExternalIntegrationLink.external_integration_id
+ == collection.external_integration_id,
+ ExternalIntegrationLink.purpose == purpose,
+ )
)
integrations = qu.all()
if not integrations:
raise CannotLoadConfiguration(
- "No storage integration for collection '%s' and purpose '%s' is configured." %
- (collection.name, purpose)
+ "No storage integration for collection '%s' and purpose '%s' is configured."
+ % (collection.name, purpose)
)
if len(integrations) > 1:
raise CannotLoadConfiguration(
- "Multiple integrations found for collection '%s' and purpose '%s'" % (collection.name, purpose)
+ "Multiple integrations found for collection '%s' and purpose '%s'"
+ % (collection.name, purpose)
)
[integration] = integrations
@@ -361,12 +382,14 @@ def for_collection_and_purpose(cls, _db, collection, purpose):
@classmethod
def lookup(cls, _db, protocol, goal, library=None):
- integrations = _db.query(cls).outerjoin(cls.libraries).filter(
- cls.protocol==protocol, cls.goal==goal
+ integrations = (
+ _db.query(cls)
+ .outerjoin(cls.libraries)
+ .filter(cls.protocol == protocol, cls.goal == goal)
)
if library:
- integrations = integrations.filter(Library.id==library.id)
+ integrations = integrations.filter(Library.id == library.id)
integrations = integrations.all()
if len(integrations) > 1:
@@ -374,7 +397,7 @@ def lookup(cls, _db, protocol, goal, library=None):
if [i for i in integrations if i.libraries] and not library:
raise ValueError(
- 'This ExternalIntegration requires a library and none was provided.'
+ "This ExternalIntegration requires a library and none was provided."
)
if not integrations:
@@ -398,18 +421,13 @@ def with_setting_value(cls, _db, protocol, goal, key, value):
has this value.
:return: A Query object.
"""
- return _db.query(
- ExternalIntegration
- ).join(
- ExternalIntegration.settings
- ).filter(
- ExternalIntegration.goal==goal
- ).filter(
- ExternalIntegration.protocol==protocol
- ).filter(
- ConfigurationSetting.key==key
- ).filter(
- ConfigurationSetting.value==value
+ return (
+ _db.query(ExternalIntegration)
+ .join(ExternalIntegration.settings)
+ .filter(ExternalIntegration.goal == goal)
+ .filter(ExternalIntegration.protocol == protocol)
+ .filter(ConfigurationSetting.key == key)
+ .filter(ConfigurationSetting.value == value)
)
@classmethod
@@ -423,12 +441,11 @@ def for_library_and_goal(cls, _db, library, goal):
Library and the given goal.
:return: A Query.
"""
- return _db.query(ExternalIntegration).join(
- ExternalIntegration.libraries
- ).filter(
- ExternalIntegration.goal==goal
- ).filter(
- Library.id==library.id
+ return (
+ _db.query(ExternalIntegration)
+ .join(ExternalIntegration.libraries)
+ .filter(ExternalIntegration.goal == goal)
+ .filter(Library.id == library.id)
)
@classmethod
@@ -443,9 +460,8 @@ def one_for_library_and_goal(cls, _db, library, goal):
return None
if len(integrations) > 1:
raise CannotLoadConfiguration(
- "Library %s defines multiple integrations with goal %s!" % (
- library.name, goal
- )
+ "Library %s defines multiple integrations with goal %s!"
+ % (library.name, goal)
)
return integrations[0]
@@ -460,9 +476,7 @@ def setting(self, key):
:param key: Name of the setting.
:return: A ConfigurationSetting
"""
- return ConfigurationSetting.for_externalintegration(
- key, self
- )
+ return ConfigurationSetting.for_externalintegration(key, self)
@hybrid_property
def url(self):
@@ -502,8 +516,9 @@ def primary_identifier_source(self):
@primary_identifier_source.setter
def primary_identifier_source(self, new_primary_identifier_source):
- return self.set_setting(self.PRIMARY_IDENTIFIER_SOURCE,
- new_primary_identifier_source)
+ return self.set_setting(
+ self.PRIMARY_IDENTIFIER_SOURCE, new_primary_identifier_source
+ )
def explain(self, library=None, include_secrets=False):
"""Create a series of human-readable strings to explain an
@@ -525,6 +540,7 @@ def key(setting):
if setting.library:
return setting.key, setting.library.name
return (setting.key, None)
+
for setting in sorted(self.settings, key=key):
if library and setting.library and setting.library != library:
# This is a different library's specialization of
@@ -536,7 +552,8 @@ def key(setting):
explanation = "%s='%s'" % (setting.key, setting.value)
if setting.library:
explanation = "%s (applies only to %s)" % (
- explanation, setting.library.name
+ explanation,
+ setting.library.name,
)
if include_secrets or not setting.is_secret:
lines.append(explanation)
@@ -560,56 +577,53 @@ class ConfigurationSetting(Base, HasFullTableCache):
is a patron of, is associated with both a Library and an
ExternalIntegration.
"""
- __tablename__ = 'configurationsettings'
+
+ __tablename__ = "configurationsettings"
id = Column(Integer, primary_key=True)
external_integration_id = Column(
- Integer, ForeignKey('externalintegrations.id'), index=True
- )
- library_id = Column(
- Integer, ForeignKey('libraries.id'), index=True
+ Integer, ForeignKey("externalintegrations.id"), index=True
)
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True)
key = Column(Unicode)
_value = Column(Unicode, name="value")
__table_args__ = (
# Unique indexes to prevent the creation of redundant
# configuration settings.
-
# If both external_integration_id and library_id are null,
# then the key--the name of a sitewide setting--must be unique.
Index(
"ix_configurationsettings_key",
key,
unique=True,
- postgresql_where=and_(
- external_integration_id==None, library_id==None
- )
+ postgresql_where=and_(external_integration_id == None, library_id == None),
),
-
# If external_integration_id is null but library_id is not,
# then (library_id, key) must be unique.
Index(
"ix_configurationsettings_library_id_key",
- library_id, key,
+ library_id,
+ key,
unique=True,
- postgresql_where=(external_integration_id==None)
+ postgresql_where=(external_integration_id == None),
),
-
# If library_id is null but external_integration_id is not,
# then (external_integration_id, key) must be unique.
Index(
"ix_configurationsettings_external_integration_id_key",
- external_integration_id, key,
+ external_integration_id,
+ key,
unique=True,
- postgresql_where=library_id==None
+ postgresql_where=library_id == None,
),
-
# If both external_integration_id and library_id have values,
# then (external_integration_id, library_id, key) must be
# unique.
Index(
"ix_configurationsettings_external_integration_id_library_id_key",
- external_integration_id, library_id, key,
+ external_integration_id,
+ library_id,
+ key,
unique=True,
),
)
@@ -618,8 +632,7 @@ class ConfigurationSetting(Base, HasFullTableCache):
_id_cache = HasFullTableCache.RESET
def __repr__(self):
- return '' % (
- self.key, self.id)
+ return "" % (self.key, self.id)
@classmethod
def sitewide_secret(cls, _db, key):
@@ -640,9 +653,11 @@ def explain(cls, _db, include_secrets=False):
lines = []
site_wide_settings = []
- for setting in _db.query(ConfigurationSetting).filter(
- ConfigurationSetting.library==None).filter(
- ConfigurationSetting.external_integration==None):
+ for setting in (
+ _db.query(ConfigurationSetting)
+ .filter(ConfigurationSetting.library == None)
+ .filter(ConfigurationSetting.external_integration == None)
+ ):
if not include_secrets and setting.key.endswith("_secret"):
continue
site_wide_settings.append(setting)
@@ -693,19 +708,22 @@ def cache_key(self):
@classmethod
def for_library_and_externalintegration(
- cls, _db, key, library, external_integration
+ cls, _db, key, library, external_integration
):
"""Find or create a ConfigurationSetting associated with a Library
and an ExternalIntegration.
"""
+
def create():
"""Function called when a ConfigurationSetting is not found in cache
and must be created.
"""
return get_one_or_create(
- _db, ConfigurationSetting,
- library=library, external_integration=external_integration,
- key=key
+ _db,
+ ConfigurationSetting,
+ library=library,
+ external_integration=external_integration,
+ key=key,
)
# ConfigurationSettings are stored in cache based on their library,
@@ -729,7 +747,8 @@ def value(self):
# ExternalIntegration. Treat the value set on the
# ExternalIntegration as a default.
return self.for_externalintegration(
- self.key, self.external_integration).value
+ self.key, self.external_integration
+ ).value
elif self.library:
# This is a library-specific setting. Treat the site-wide
# value as a default.
@@ -752,11 +771,11 @@ def _is_secret(self, key):
saying that a specific setting should be treated as secret.
"""
return any(
- key == x or
- key.startswith('%s_' % x) or
- key.endswith('_%s' % x) or
- ("_%s_" %x) in key
- for x in ('secret', 'password')
+ key == x
+ or key.startswith("%s_" % x)
+ or key.endswith("_%s" % x)
+ or ("_%s_" % x) in key
+ for x in ("secret", "password")
)
@property
@@ -772,7 +791,8 @@ def value_or_default(self, default):
self.value = default
return self.value
- MEANS_YES = set(['true', 't', 'yes', 'y'])
+ MEANS_YES = set(["true", "t", "yes", "y"])
+
@property
def bool_value(self):
"""Turn the value into a boolean if possible.
@@ -827,9 +847,7 @@ def excluded_audio_data_sources(cls, _db):
Most methods like this go into Configuration, but this one needs
to reference data model objects for its default value.
"""
- value = cls.sitewide(
- _db, Configuration.EXCLUDED_AUDIO_DATA_SOURCES
- ).json_value
+ value = cls.sitewide(_db, Configuration.EXCLUDED_AUDIO_DATA_SOURCES).json_value
if value is None:
value = cls.EXCLUDED_AUDIO_DATA_SOURCES_DEFAULT
return value
@@ -909,8 +927,8 @@ def save(self, db, setting_name, value):
"""
integration = self._integration_association.external_integration(db)
ConfigurationSetting.for_externalintegration(
- setting_name,
- integration).value = value
+ setting_name, integration
+ ).value = value
def load(self, db, setting_name):
"""Loads and returns the library's configuration setting
@@ -925,8 +943,8 @@ def load(self, db, setting_name):
"""
integration = self._integration_association.external_integration(db)
value = ConfigurationSetting.for_externalintegration(
- setting_name,
- integration).value
+ setting_name, integration
+ ).value
return value
@@ -934,12 +952,12 @@ def load(self, db, setting_name):
class ConfigurationAttributeType(Enum):
"""Enumeration of configuration setting types"""
- TEXT = 'text'
- TEXTAREA = 'textarea'
- SELECT = 'select'
- NUMBER = 'number'
- LIST = 'list'
- MENU = 'menu'
+ TEXT = "text"
+ TEXTAREA = "textarea"
+ SELECT = "select"
+ NUMBER = "number"
+ LIST = "list"
+ MENU = "menu"
def to_control_type(self):
"""Converts the value to a attribute type understandable by circulation-admin
@@ -959,15 +977,15 @@ def to_control_type(self):
class ConfigurationAttribute(Enum):
"""Enumeration of configuration setting attributes"""
- KEY = 'key'
- LABEL = 'label'
- DESCRIPTION = 'description'
- TYPE = 'type'
- REQUIRED = 'required'
- DEFAULT = 'default'
- OPTIONS = 'options'
- CATEGORY = 'category'
- FORMAT = 'format'
+ KEY = "key"
+ LABEL = "label"
+ DESCRIPTION = "description"
+ TYPE = "type"
+ REQUIRED = "required"
+ DEFAULT = "default"
+ OPTIONS = "options"
+ CATEGORY = "category"
+ FORMAT = "format"
class ConfigurationOption(object):
@@ -997,9 +1015,7 @@ def __eq__(self, other):
if not isinstance(other, ConfigurationOption):
return False
- return \
- self.key == other.key and \
- self.label == other.label
+ return self.key == other.key and self.label == other.label
@property
def key(self):
@@ -1025,10 +1041,7 @@ def to_settings(self):
:return: Dictionary containing option metadata in the SETTINGS format
:rtype: Dict
"""
- return {
- 'key': self.key,
- 'label': self.label
- }
+ return {"key": self.key, "label": self.label}
@staticmethod
def from_enum(cls):
@@ -1041,12 +1054,9 @@ def from_enum(cls):
:rtype: List[Dict]
"""
if not issubclass(cls, Enum):
- raise ValueError('Class should be descendant of Enum')
+ raise ValueError("Class should be descendant of Enum")
- return [
- ConfigurationOption(element.value, element.name)
- for element in cls
- ]
+ return [ConfigurationOption(element.value, element.name) for element in cls]
class HasConfigurationSettings(metaclass=ABCMeta):
@@ -1083,17 +1093,17 @@ class ConfigurationMetadata(object):
_counter = 0
def __init__(
- self,
- key,
- label,
- description,
- type,
- required=False,
- default=None,
- options=None,
- category=None,
- format=None,
- index=None
+ self,
+ key,
+ label,
+ description,
+ type,
+ required=False,
+ default=None,
+ options=None,
+ category=None,
+ format=None,
+ index=None,
):
"""Initializes a new instance of ConfigurationMetadata class
@@ -1157,7 +1167,9 @@ def __get__(self, owner_instance, owner_type):
return self
if not isinstance(owner_instance, HasConfigurationSettings):
- raise Exception('owner must be an instance of ConfigurationSettingsMetadataOwner type')
+ raise Exception(
+ "owner must be an instance of ConfigurationSettingsMetadataOwner type"
+ )
return owner_instance.get_setting_value(self._key)
@@ -1171,7 +1183,9 @@ def __set__(self, owner_instance, value):
:type value: Any
"""
if not isinstance(owner_instance, HasConfigurationSettings):
- raise Exception('owner must be an instance ConfigurationSettingsMetadataOwner type')
+ raise Exception(
+ "owner must be an instance ConfigurationSettingsMetadataOwner type"
+ )
return owner_instance.set_setting_value(self._key, value)
@@ -1289,12 +1303,13 @@ def to_settings(self):
ConfigurationAttribute.TYPE.value: self.type.to_control_type(),
ConfigurationAttribute.REQUIRED.value: self.required,
ConfigurationAttribute.DEFAULT.value: self.default,
- ConfigurationAttribute.OPTIONS.value:
- [option.to_settings() for option in self.options]
- if self.options
- else None,
+ ConfigurationAttribute.OPTIONS.value: [
+ option.to_settings() for option in self.options
+ ]
+ if self.options
+ else None,
ConfigurationAttribute.CATEGORY.value: self.category,
- ConfigurationAttribute.FORMAT.value: self.format
+ ConfigurationAttribute.FORMAT.value: self.format,
}
@staticmethod
@@ -1369,12 +1384,22 @@ def to_settings_generator(cls):
for name, member in ConfigurationMetadata.get_configuration_metadata(cls):
key_attribute = getattr(member, ConfigurationAttribute.KEY.value, None)
label_attribute = getattr(member, ConfigurationAttribute.LABEL.value, None)
- description_attribute = getattr(member, ConfigurationAttribute.DESCRIPTION.value, None)
+ description_attribute = getattr(
+ member, ConfigurationAttribute.DESCRIPTION.value, None
+ )
type_attribute = getattr(member, ConfigurationAttribute.TYPE.value, None)
- required_attribute = getattr(member, ConfigurationAttribute.REQUIRED.value, None)
- default_attribute = getattr(member, ConfigurationAttribute.DEFAULT.value, None)
- options_attribute = getattr(member, ConfigurationAttribute.OPTIONS.value, None)
- category_attribute = getattr(member, ConfigurationAttribute.CATEGORY.value, None)
+ required_attribute = getattr(
+ member, ConfigurationAttribute.REQUIRED.value, None
+ )
+ default_attribute = getattr(
+ member, ConfigurationAttribute.DEFAULT.value, None
+ )
+ options_attribute = getattr(
+ member, ConfigurationAttribute.OPTIONS.value, None
+ )
+ category_attribute = getattr(
+ member, ConfigurationAttribute.CATEGORY.value, None
+ )
yield {
ConfigurationAttribute.KEY.value: key_attribute,
@@ -1383,11 +1408,12 @@ def to_settings_generator(cls):
ConfigurationAttribute.TYPE.value: type_attribute.to_control_type(),
ConfigurationAttribute.REQUIRED.value: required_attribute,
ConfigurationAttribute.DEFAULT.value: default_attribute,
- ConfigurationAttribute.OPTIONS.value:
- [option.to_settings() for option in options_attribute]
- if options_attribute
- else None,
- ConfigurationAttribute.CATEGORY.value: category_attribute
+ ConfigurationAttribute.OPTIONS.value: [
+ option.to_settings() for option in options_attribute
+ ]
+ if options_attribute
+ else None,
+ ConfigurationAttribute.CATEGORY.value: category_attribute,
}
@classmethod
@@ -1419,5 +1445,7 @@ def create(self, configuration_storage, db, configuration_grouping_class):
:return: ConfigurationGrouping instance
:rtype: ConfigurationGrouping
"""
- with configuration_grouping_class(configuration_storage, db) as configuration_bucket:
+ with configuration_grouping_class(
+ configuration_storage, db
+ ) as configuration_bucket:
yield configuration_bucket
diff --git a/model/constants.py b/model/constants.py
index 7fe3c90b9..d1e4af6c2 100644
--- a/model/constants.py
+++ b/model/constants.py
@@ -5,6 +5,7 @@
import re
from collections import OrderedDict
+
class DataSourceConstants(object):
GUTENBERG = "Gutenberg"
OVERDRIVE = "Overdrive"
@@ -42,9 +43,7 @@ class DataSourceConstants(object):
LCP = "LCP"
PROQUEST = "ProQuest"
- DEPRECATED_NAMES = {
- "3M" : BIBLIOTHECA
- }
+ DEPRECATED_NAMES = {"3M": BIBLIOTHECA}
THREEM = BIBLIOTHECA
# Some sources of open-access ebooks are better than others. This
@@ -89,6 +88,7 @@ class DataSourceConstants(object):
# higher priority than the source of the license pool.
COVER_IMAGE_PRIORITY = [METADATA_WRANGLER] + PRESENTATION_EDITION_PRIORITY
+
class EditionConstants(object):
ALL_MEDIUM = object()
BOOK_MEDIUM = "Book"
@@ -103,21 +103,28 @@ class EditionConstants(object):
CODEX_FORMAT = "Codex"
# These are all media known to the system.
- KNOWN_MEDIA = (BOOK_MEDIUM, PERIODICAL_MEDIUM, AUDIO_MEDIUM, MUSIC_MEDIUM,
- VIDEO_MEDIUM, IMAGE_MEDIUM, COURSEWARE_MEDIUM)
+ KNOWN_MEDIA = (
+ BOOK_MEDIUM,
+ PERIODICAL_MEDIUM,
+ AUDIO_MEDIUM,
+ MUSIC_MEDIUM,
+ VIDEO_MEDIUM,
+ IMAGE_MEDIUM,
+ COURSEWARE_MEDIUM,
+ )
# These are the media types currently fulfillable by the default
# client.
FULFILLABLE_MEDIA = [BOOK_MEDIUM, AUDIO_MEDIUM]
medium_to_additional_type = {
- BOOK_MEDIUM : "http://schema.org/EBook",
- AUDIO_MEDIUM : "http://bib.schema.org/Audiobook",
- PERIODICAL_MEDIUM : "http://schema.org/PublicationIssue",
- MUSIC_MEDIUM : "http://schema.org/MusicRecording",
- VIDEO_MEDIUM : "http://schema.org/VideoObject",
+ BOOK_MEDIUM: "http://schema.org/EBook",
+ AUDIO_MEDIUM: "http://bib.schema.org/Audiobook",
+ PERIODICAL_MEDIUM: "http://schema.org/PublicationIssue",
+ MUSIC_MEDIUM: "http://schema.org/MusicRecording",
+ VIDEO_MEDIUM: "http://schema.org/VideoObject",
IMAGE_MEDIUM: "http://schema.org/ImageObject",
- COURSEWARE_MEDIUM: "http://schema.org/Course"
+ COURSEWARE_MEDIUM: "http://schema.org/Course",
}
additional_type_to_medium = {}
@@ -129,13 +136,13 @@ class EditionConstants(object):
# Map the medium constants to the strings used when generating
# permanent work IDs.
medium_for_permanent_work_id = {
- BOOK_MEDIUM : "book",
- AUDIO_MEDIUM : "book",
- MUSIC_MEDIUM : "music",
- PERIODICAL_MEDIUM : "book",
+ BOOK_MEDIUM: "book",
+ AUDIO_MEDIUM: "book",
+ MUSIC_MEDIUM: "music",
+ PERIODICAL_MEDIUM: "book",
VIDEO_MEDIUM: "movie",
IMAGE_MEDIUM: "image",
- COURSEWARE_MEDIUM: "courseware"
+ COURSEWARE_MEDIUM: "courseware",
}
@@ -163,23 +170,27 @@ class IdentifierConstants(object):
PROQUEST_ID = "ProQuest Doc ID"
DEPRECATED_NAMES = {
- "3M ID" : BIBLIOTHECA_ID,
+ "3M ID": BIBLIOTHECA_ID,
}
THREEM_ID = BIBLIOTHECA_ID
LICENSE_PROVIDING_IDENTIFIER_TYPES = [
- BIBLIOTHECA_ID, OVERDRIVE_ID, ODILO_ID, AXIS_360_ID,
- GUTENBERG_ID, ELIB_ID, SUDOC_CALL_NUMBER,
+ BIBLIOTHECA_ID,
+ OVERDRIVE_ID,
+ ODILO_ID,
+ AXIS_360_ID,
+ GUTENBERG_ID,
+ ELIB_ID,
+ SUDOC_CALL_NUMBER,
]
URN_SCHEME_PREFIX = "urn:librarysimplified.org/terms/id/"
ISBN_URN_SCHEME_PREFIX = "urn:isbn:"
GUTENBERG_URN_SCHEME_PREFIX = "http://www.gutenberg.org/ebooks/"
- GUTENBERG_URN_SCHEME_RE = re.compile(
- GUTENBERG_URN_SCHEME_PREFIX + "([0-9]+)")
+ GUTENBERG_URN_SCHEME_RE = re.compile(GUTENBERG_URN_SCHEME_PREFIX + "([0-9]+)")
OTHER_URN_SCHEME_PREFIX = "urn:"
- IDEAL_COVER_ASPECT_RATIO = 2.0/3
+ IDEAL_COVER_ASPECT_RATIO = 2.0 / 3
IDEAL_IMAGE_HEIGHT = 240
IDEAL_IMAGE_WIDTH = 160
@@ -203,9 +214,24 @@ class LinkRelations(object):
DRM_ENCRYPTED_DOWNLOAD = "http://opds-spec.org/acquisition/"
BORROW = "http://opds-spec.org/acquisition/borrow"
- CIRCULATION_ALLOWED = [OPEN_ACCESS_DOWNLOAD, DRM_ENCRYPTED_DOWNLOAD, BORROW, GENERIC_OPDS_ACQUISITION]
- METADATA_ALLOWED = [CANONICAL, IMAGE, THUMBNAIL_IMAGE, ILLUSTRATION, REVIEW,
- DESCRIPTION, SHORT_DESCRIPTION, AUTHOR, ALTERNATE, SAMPLE]
+ CIRCULATION_ALLOWED = [
+ OPEN_ACCESS_DOWNLOAD,
+ DRM_ENCRYPTED_DOWNLOAD,
+ BORROW,
+ GENERIC_OPDS_ACQUISITION,
+ ]
+ METADATA_ALLOWED = [
+ CANONICAL,
+ IMAGE,
+ THUMBNAIL_IMAGE,
+ ILLUSTRATION,
+ REVIEW,
+ DESCRIPTION,
+ SHORT_DESCRIPTION,
+ AUTHOR,
+ ALTERNATE,
+ SAMPLE,
+ ]
MIRRORED = [OPEN_ACCESS_DOWNLOAD, GENERIC_OPDS_ACQUISITION, IMAGE, THUMBNAIL_IMAGE]
SELF_HOSTED_BOOKS = list(set(CIRCULATION_ALLOWED) & set(MIRRORED))
@@ -237,13 +263,17 @@ class MediaTypes(object):
# (hopefully future) ebook manifests, we invent values for the
# 'profile' parameter.
OVERDRIVE_MANIFEST_MEDIA_TYPE = "application/vnd.overdrive.circulation.api+json"
- OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE = OVERDRIVE_MANIFEST_MEDIA_TYPE + ";profile=audiobook"
- OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE = OVERDRIVE_MANIFEST_MEDIA_TYPE + ";profile=ebook"
+ OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE = (
+ OVERDRIVE_MANIFEST_MEDIA_TYPE + ";profile=audiobook"
+ )
+ OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE = (
+ OVERDRIVE_MANIFEST_MEDIA_TYPE + ";profile=ebook"
+ )
AUDIOBOOK_MEDIA_TYPES = [
OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE,
AUDIOBOOK_MANIFEST_MEDIA_TYPE,
- AUDIOBOOK_PACKAGE_MEDIA_TYPE
+ AUDIOBOOK_PACKAGE_MEDIA_TYPE,
]
BOOK_MEDIA_TYPES = [
@@ -268,7 +298,7 @@ class MediaTypes(object):
SUPPORTED_BOOK_MEDIA_TYPES = [
EPUB_MEDIA_TYPE,
PDF_MEDIA_TYPE,
- AUDIOBOOK_MANIFEST_MEDIA_TYPE
+ AUDIOBOOK_MANIFEST_MEDIA_TYPE,
]
# Most of the time, if you believe a resource to be media type A,
@@ -297,21 +327,21 @@ class MediaTypes(object):
(APPLICATION_XML_MEDIA_TYPE, "xml"),
(AUDIOBOOK_MANIFEST_MEDIA_TYPE, "audiobook-manifest"),
(AUDIOBOOK_PACKAGE_MEDIA_TYPE, "audiobook"),
- (SCORM_MEDIA_TYPE, "zip")
+ (SCORM_MEDIA_TYPE, "zip"),
]
)
- COMMON_EBOOK_EXTENSIONS = ['.epub', '.pdf', '.audiobook']
- COMMON_IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif']
+ COMMON_EBOOK_EXTENSIONS = [".epub", ".pdf", ".audiobook"]
+ COMMON_IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif"]
# Invert FILE_EXTENSIONS and add some extra guesses.
MEDIA_TYPE_FOR_EXTENSION = {
- ".htm" : TEXT_HTML_MEDIA_TYPE,
- ".jpeg" : JPEG_MEDIA_TYPE,
+ ".htm": TEXT_HTML_MEDIA_TYPE,
+ ".jpeg": JPEG_MEDIA_TYPE,
}
for media_type, extension in list(FILE_EXTENSIONS.items()):
- extension = '.' + extension
+ extension = "." + extension
if extension not in MEDIA_TYPE_FOR_EXTENSION:
# FILE_EXTENSIONS lists more common extensions first. If
# multiple media types have the same extension, the most
diff --git a/model/contributor.py b/model/contributor.py
index 23865b80e..e6817c86f 100644
--- a/model/contributor.py
+++ b/model/contributor.py
@@ -2,38 +2,24 @@
# Contributor, Contribution
-from . import (
- Base,
- flush,
- get_one_or_create,
-)
-
import logging
import re
-from sqlalchemy import (
- Column,
- ForeignKey,
- Integer,
- Unicode,
- UniqueConstraint,
-)
-from sqlalchemy.dialects.postgresql import (
- ARRAY,
- JSON,
-)
+
+from sqlalchemy import Column, ForeignKey, Integer, Unicode, UniqueConstraint
+from sqlalchemy.dialects.postgresql import ARRAY, JSON
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import (
- relationship,
- synonym,
-)
+from sqlalchemy.orm import relationship, synonym
from sqlalchemy.orm.session import Session
+
from ..util.personal_names import display_name_to_sort_name
+from . import Base, flush, get_one_or_create
+
class Contributor(Base):
"""Someone (usually human) who contributes to books."""
- __tablename__ = 'contributors'
+ __tablename__ = "contributors"
id = Column(Integer, primary_key=True)
# Standard identifiers for this contributor.
@@ -42,7 +28,7 @@ class Contributor(Base):
# This is the name by which this person is known in the original
# catalog. It is sortable, e.g. "Twain, Mark".
- _sort_name = Column('sort_name', Unicode, index=True)
+ _sort_name = Column("sort_name", Unicode, index=True)
aliases = Column(ARRAY(Unicode), default=[])
# This is the name we will display publicly. Ideally it will be
@@ -82,25 +68,25 @@ class Contributor(Base):
FOREWORD_ROLE = "Foreword Author"
AFTERWORD_ROLE = "Afterword Author"
COLOPHON_ROLE = "Colophon Author"
- UNKNOWN_ROLE = 'Unknown'
- DIRECTOR_ROLE = 'Director'
- PRODUCER_ROLE = 'Producer'
- EXECUTIVE_PRODUCER_ROLE = 'Executive Producer'
- ACTOR_ROLE = 'Actor'
- LYRICIST_ROLE = 'Lyricist'
- CONTRIBUTOR_ROLE = 'Contributor'
- COMPOSER_ROLE = 'Composer'
- NARRATOR_ROLE = 'Narrator'
- COMPILER_ROLE = 'Compiler'
- ADAPTER_ROLE = 'Adapter'
- PERFORMER_ROLE = 'Performer'
- MUSICIAN_ROLE = 'Musician'
- ASSOCIATED_ROLE = 'Associated name'
- COLLABORATOR_ROLE = 'Collaborator'
- ENGINEER_ROLE = 'Engineer'
- COPYRIGHT_HOLDER_ROLE = 'Copyright holder'
- TRANSCRIBER_ROLE = 'Transcriber'
- DESIGNER_ROLE = 'Designer'
+ UNKNOWN_ROLE = "Unknown"
+ DIRECTOR_ROLE = "Director"
+ PRODUCER_ROLE = "Producer"
+ EXECUTIVE_PRODUCER_ROLE = "Executive Producer"
+ ACTOR_ROLE = "Actor"
+ LYRICIST_ROLE = "Lyricist"
+ CONTRIBUTOR_ROLE = "Contributor"
+ COMPOSER_ROLE = "Composer"
+ NARRATOR_ROLE = "Narrator"
+ COMPILER_ROLE = "Compiler"
+ ADAPTER_ROLE = "Adapter"
+ PERFORMER_ROLE = "Performer"
+ MUSICIAN_ROLE = "Musician"
+ ASSOCIATED_ROLE = "Associated name"
+ COLLABORATOR_ROLE = "Collaborator"
+ ENGINEER_ROLE = "Engineer"
+ COPYRIGHT_HOLDER_ROLE = "Copyright holder"
+ TRANSCRIBER_ROLE = "Transcriber"
+ DESIGNER_ROLE = "Designer"
AUTHOR_ROLES = set([PRIMARY_AUTHOR_ROLE, AUTHOR_ROLE])
# Map our recognized roles to MARC relators.
@@ -108,55 +94,63 @@ class Contributor(Base):
#
# This is used when crediting contributors in OPDS feeds.
MARC_ROLE_CODES = {
- ACTOR_ROLE : 'act',
- ADAPTER_ROLE : 'adp',
- AFTERWORD_ROLE : 'aft',
- ARTIST_ROLE : 'art',
- ASSOCIATED_ROLE : 'asn',
- AUTHOR_ROLE : 'aut', # Joint author: USE Author
- COLLABORATOR_ROLE : 'ctb', # USE Contributor
- COLOPHON_ROLE : 'aft', # Author of afterword, colophon, etc.
- COMPILER_ROLE : 'com',
- COMPOSER_ROLE : 'cmp',
- CONTRIBUTOR_ROLE : 'ctb',
- COPYRIGHT_HOLDER_ROLE : 'cph',
- DESIGNER_ROLE : 'dsr',
- DIRECTOR_ROLE : 'drt',
- EDITOR_ROLE : 'edt',
- ENGINEER_ROLE : 'eng',
- EXECUTIVE_PRODUCER_ROLE : 'pro',
- FOREWORD_ROLE : 'wpr', # Writer of preface
- ILLUSTRATOR_ROLE : 'ill',
- INTRODUCTION_ROLE : 'win',
- LYRICIST_ROLE : 'lyr',
- MUSICIAN_ROLE : 'mus',
- NARRATOR_ROLE : 'nrt',
- PERFORMER_ROLE : 'prf',
- PHOTOGRAPHER_ROLE : 'pht',
- PRIMARY_AUTHOR_ROLE : 'aut',
- PRODUCER_ROLE : 'pro',
- TRANSCRIBER_ROLE : 'trc',
- TRANSLATOR_ROLE : 'trl',
- LETTERER_ROLE : 'ctb',
- PENCILER_ROLE : 'ctb',
- COLORIST_ROLE : 'clr',
- INKER_ROLE : 'ctb',
- UNKNOWN_ROLE : 'asn',
+ ACTOR_ROLE: "act",
+ ADAPTER_ROLE: "adp",
+ AFTERWORD_ROLE: "aft",
+ ARTIST_ROLE: "art",
+ ASSOCIATED_ROLE: "asn",
+ AUTHOR_ROLE: "aut", # Joint author: USE Author
+ COLLABORATOR_ROLE: "ctb", # USE Contributor
+ COLOPHON_ROLE: "aft", # Author of afterword, colophon, etc.
+ COMPILER_ROLE: "com",
+ COMPOSER_ROLE: "cmp",
+ CONTRIBUTOR_ROLE: "ctb",
+ COPYRIGHT_HOLDER_ROLE: "cph",
+ DESIGNER_ROLE: "dsr",
+ DIRECTOR_ROLE: "drt",
+ EDITOR_ROLE: "edt",
+ ENGINEER_ROLE: "eng",
+ EXECUTIVE_PRODUCER_ROLE: "pro",
+ FOREWORD_ROLE: "wpr", # Writer of preface
+ ILLUSTRATOR_ROLE: "ill",
+ INTRODUCTION_ROLE: "win",
+ LYRICIST_ROLE: "lyr",
+ MUSICIAN_ROLE: "mus",
+ NARRATOR_ROLE: "nrt",
+ PERFORMER_ROLE: "prf",
+ PHOTOGRAPHER_ROLE: "pht",
+ PRIMARY_AUTHOR_ROLE: "aut",
+ PRODUCER_ROLE: "pro",
+ TRANSCRIBER_ROLE: "trc",
+ TRANSLATOR_ROLE: "trl",
+ LETTERER_ROLE: "ctb",
+ PENCILER_ROLE: "ctb",
+ COLORIST_ROLE: "clr",
+ INKER_ROLE: "ctb",
+ UNKNOWN_ROLE: "asn",
}
# People from these roles can be put into the 'author' slot if no
# author proper is given.
AUTHOR_SUBSTITUTE_ROLES = [
- EDITOR_ROLE, COMPILER_ROLE, COMPOSER_ROLE, DIRECTOR_ROLE,
- CONTRIBUTOR_ROLE, TRANSLATOR_ROLE, ADAPTER_ROLE, PHOTOGRAPHER_ROLE,
- ARTIST_ROLE, LYRICIST_ROLE, COPYRIGHT_HOLDER_ROLE
+ EDITOR_ROLE,
+ COMPILER_ROLE,
+ COMPOSER_ROLE,
+ DIRECTOR_ROLE,
+ CONTRIBUTOR_ROLE,
+ TRANSLATOR_ROLE,
+ ADAPTER_ROLE,
+ PHOTOGRAPHER_ROLE,
+ ARTIST_ROLE,
+ LYRICIST_ROLE,
+ COPYRIGHT_HOLDER_ROLE,
]
PERFORMER_ROLES = [ACTOR_ROLE, PERFORMER_ROLE, NARRATOR_ROLE, MUSICIAN_ROLE]
# Extra fields
- BIRTH_DATE = 'birthDate'
- DEATH_DATE = 'deathDate'
+ BIRTH_DATE = "birthDate"
+ DEATH_DATE = "deathDate"
def __repr__(self):
extra = ""
@@ -174,8 +168,17 @@ def author_contributor_tiers(cls):
yield cls.PERFORMER_ROLES
@classmethod
- def lookup(cls, _db, sort_name=None, viaf=None, lc=None, aliases=None,
- extra=None, create_new=True, name=None):
+ def lookup(
+ cls,
+ _db,
+ sort_name=None,
+ viaf=None,
+ lc=None,
+ aliases=None,
+ extra=None,
+ create_new=True,
+ name=None,
+ ):
"""Find or create a record (or list of records) for the given Contributor.
:return: A tuple of found Contributor (or None), and a boolean flag
indicating if new Contributor database object has beed created.
@@ -189,15 +192,16 @@ def lookup(cls, _db, sort_name=None, viaf=None, lc=None, aliases=None,
extra = extra or dict()
create_method_kwargs = {
- Contributor.sort_name.name : sort_name,
- Contributor.aliases.name : aliases,
- Contributor.extra.name : extra
+ Contributor.sort_name.name: sort_name,
+ Contributor.aliases.name: aliases,
+ Contributor.extra.name: extra,
}
if not sort_name and not lc and not viaf:
raise ValueError(
"Cannot look up a Contributor without any identifying "
- "information whatsoever!")
+ "information whatsoever!"
+ )
if sort_name and not lc and not viaf:
# We will not create a Contributor based solely on a name
@@ -207,7 +211,7 @@ def lookup(cls, _db, sort_name=None, viaf=None, lc=None, aliases=None,
# return all of them.
#
# We currently do not check aliases when doing name lookups.
- q = _db.query(Contributor).filter(Contributor.sort_name==sort_name)
+ q = _db.query(Contributor).filter(Contributor.sort_name == sort_name)
contributors = q.all()
if contributors:
return contributors, new
@@ -233,8 +237,10 @@ def lookup(cls, _db, sort_name=None, viaf=None, lc=None, aliases=None,
if create_new:
contributor, new = get_one_or_create(
- _db, Contributor, create_method_kwargs=create_method_kwargs,
- on_multiple='interchangeable',
+ _db,
+ Contributor,
+ create_method_kwargs=create_method_kwargs,
+ on_multiple="interchangeable",
**query
)
if contributor:
@@ -246,14 +252,13 @@ def lookup(cls, _db, sort_name=None, viaf=None, lc=None, aliases=None,
return contributors, new
-
@property
def sort_name(self):
return self._sort_name
@sort_name.setter
def sort_name(self, new_sort_name):
- """ See if the passed-in value is in the prescribed Last, First format.
+ """See if the passed-in value is in the prescribed Last, First format.
If it is, great, set the self._sprt_name to the new value.
If new value is not in correct format, then
@@ -280,8 +285,7 @@ def sort_name(self, new_sort_name):
self._sort_name = new_sort_name
# tell SQLAlchemy to use the sort_name setter for ort_name, not _sort_name, after all.
- sort_name = synonym('_sort_name', descriptor=sort_name)
-
+ sort_name = synonym("_sort_name", descriptor=sort_name)
def merge_into(self, destination):
"""Two Contributor records should be the same.
@@ -303,7 +307,7 @@ def merge_into(self, destination):
self,
self.viaf,
destination,
- destination.viaf
+ destination.viaf,
)
# make sure we're not losing any names we know for the contributor
@@ -343,9 +347,10 @@ def merge_into(self, destination):
# the old contribution) or not (in which case we switch the
# contributor ID)?
existing_record = _db.query(Contribution).filter(
- Contribution.contributor_id==destination.id,
- Contribution.edition_id==contribution.edition.id,
- Contribution.role==contribution.role)
+ Contribution.contributor_id == destination.id,
+ Contribution.edition_id == contribution.edition.id,
+ Contribution.role == contribution.role,
+ )
if existing_record.count():
_db.delete(contribution)
else:
@@ -360,14 +365,16 @@ def merge_into(self, destination):
ALPHABETIC = re.compile("[a-zA-z]")
NUMBERS = re.compile("[0-9]")
- DATE_RES = [re.compile("\(?" + x + "\)?") for x in
- ("[0-9?]+-",
- "[0-9]+st cent",
- "[0-9]+nd cent",
- "[0-9]+th cent",
- "\bcirca",)
- ]
-
+ DATE_RES = [
+ re.compile("\(?" + x + "\)?")
+ for x in (
+ "[0-9?]+-",
+ "[0-9]+st cent",
+ "[0-9]+nd cent",
+ "[0-9]+th cent",
+ "\bcirca",
+ )
+ ]
def default_names(self, default_display_name=None):
"""Attempt to derive a family name ("Twain") and a display name ("Mark
@@ -394,15 +401,16 @@ def _default_names(cls, name, default_display_name=None):
name = cls.PARENTHETICAL.sub("", name)
name = name.strip()
- if ', ' in name:
+ if ", " in name:
# This is probably a personal name.
parts = name.split(", ")
if len(parts) > 2:
# The most likely scenario is that the final part
# of the name is a date or a set of dates. If this
# seems true, just delete that part.
- if (cls.NUMBERS.search(parts[-1])
- or not cls.ALPHABETIC.search(parts[-1])):
+ if cls.NUMBERS.search(parts[-1]) or not cls.ALPHABETIC.search(
+ parts[-1]
+ ):
parts = parts[:-1]
# The final part of the name may have a date or a set
# of dates at the end. If so, remove it from that string.
@@ -410,7 +418,7 @@ def _default_names(cls, name, default_display_name=None):
for date_re in cls.DATE_RES:
m = date_re.search(final)
if m:
- new_part = final[:m.start()].strip()
+ new_part = final[: m.start()].strip()
if new_part:
parts[-1] = new_part
else:
@@ -419,9 +427,12 @@ def _default_names(cls, name, default_display_name=None):
family_name = parts[0]
p = parts[-1].lower()
- if (p in ('llc', 'inc', 'inc.')
- or p.endswith("company") or p.endswith(" co.")
- or p.endswith(" co")):
+ if (
+ p in ("llc", "inc", "inc.")
+ or p.endswith("company")
+ or p.endswith(" co.")
+ or p.endswith(" co")
+ ):
# No, this is a corporate name that contains a comma.
# It can't be split on the comma, so don't bother.
family_name = None
@@ -436,7 +447,7 @@ def _default_names(cls, name, default_display_name=None):
display_name = parts[1] + " " + parts[0]
if len(parts) > 2:
# There's a leftover bit.
- if parts[2] in ('Mrs.', 'Mrs', 'Sir'):
+ if parts[2] in ("Mrs.", "Mrs", "Sir"):
# "Jones, Bob, Mrs."
# => "Mrs. Bob Jones"
display_name = parts[2] + " " + display_name
@@ -451,16 +462,15 @@ def _default_names(cls, name, default_display_name=None):
return family_name, display_name
+
class Contribution(Base):
"""A contribution made by a Contributor to a Edition."""
- __tablename__ = 'contributions'
+
+ __tablename__ = "contributions"
id = Column(Integer, primary_key=True)
- edition_id = Column(Integer, ForeignKey('editions.id'), index=True,
- nullable=False)
- contributor_id = Column(Integer, ForeignKey('contributors.id'), index=True,
- nullable=False)
- role = Column(Unicode, index=True, nullable=False)
- __table_args__ = (
- UniqueConstraint('edition_id', 'contributor_id', 'role'),
+ edition_id = Column(Integer, ForeignKey("editions.id"), index=True, nullable=False)
+ contributor_id = Column(
+ Integer, ForeignKey("contributors.id"), index=True, nullable=False
)
-
+ role = Column(Unicode, index=True, nullable=False)
+ __table_args__ = (UniqueConstraint("edition_id", "contributor_id", "role"),)
diff --git a/model/coverage.py b/model/coverage.py
index 7808eda60..564e28fea 100644
--- a/model/coverage.py
+++ b/model/coverage.py
@@ -13,29 +13,21 @@
UniqueConstraint,
)
from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.expression import (
- and_,
- or_,
- literal,
- literal_column,
-)
+from sqlalchemy.sql.expression import and_, literal, literal_column, or_
-from . import (
- Base,
- get_one,
- get_one_or_create,
-)
from ..util.datetime_helpers import utc_now
+from . import Base, get_one, get_one_or_create
+
class BaseCoverageRecord(object):
"""Contains useful constants used by both CoverageRecord and
WorkCoverageRecord.
"""
- SUCCESS = 'success'
- TRANSIENT_FAILURE = 'transient failure'
- PERSISTENT_FAILURE = 'persistent failure'
- REGISTERED = 'registered'
+ SUCCESS = "success"
+ TRANSIENT_FAILURE = "transient failure"
+ PERSISTENT_FAILURE = "persistent failure"
+ REGISTERED = "registered"
ALL_STATUSES = [REGISTERED, SUCCESS, TRANSIENT_FAILURE, PERSISTENT_FAILURE]
@@ -47,12 +39,18 @@ class BaseCoverageRecord(object):
# as present if it ended in transient failure.
DEFAULT_COUNT_AS_COVERED = [SUCCESS, PERSISTENT_FAILURE]
- status_enum = Enum(SUCCESS, TRANSIENT_FAILURE, PERSISTENT_FAILURE,
- REGISTERED, name='coverage_status')
+ status_enum = Enum(
+ SUCCESS,
+ TRANSIENT_FAILURE,
+ PERSISTENT_FAILURE,
+ REGISTERED,
+ name="coverage_status",
+ )
@classmethod
- def not_covered(cls, count_as_covered=None,
- count_as_not_covered_if_covered_before=None):
+ def not_covered(
+ cls, count_as_covered=None, count_as_not_covered_if_covered_before=None
+ ):
"""Filter a query to find only items without coverage records.
:param count_as_covered: A list of constants that indicate
@@ -70,13 +68,11 @@ def not_covered(cls, count_as_covered=None,
# If there is no coverage record, then of course the item is
# not covered.
- missing = cls.id==None
+ missing = cls.id == None
# If we're looking for specific coverage statuses, then a
# record does not count if it has some other status.
- missing = or_(
- missing, ~cls.status.in_(count_as_covered)
- )
+ missing = or_(missing, ~cls.status.in_(count_as_covered))
# If the record's timestamp is before the cutoff time, we
# don't count it as covered, regardless of which status it
@@ -94,7 +90,7 @@ class Timestamp(Base):
and general scripts.
"""
- __tablename__ = 'timestamps'
+ __tablename__ = "timestamps"
MONITOR_TYPE = "monitor"
COVERAGE_PROVIDER_TYPE = "coverage_provider"
@@ -106,7 +102,9 @@ class Timestamp(Base):
CLEAR_VALUE = object()
service_type_enum = Enum(
- MONITOR_TYPE, COVERAGE_PROVIDER_TYPE, SCRIPT_TYPE,
+ MONITOR_TYPE,
+ COVERAGE_PROVIDER_TYPE,
+ SCRIPT_TYPE,
name="service_type",
)
@@ -123,8 +121,9 @@ class Timestamp(Base):
# The collection, if any, associated with this service -- some services
# run separately on a number of collections.
- collection_id = Column(Integer, ForeignKey('collections.id'),
- index=True, nullable=True)
+ collection_id = Column(
+ Integer, ForeignKey("collections.id"), index=True, nullable=True
+ )
# The last time the service _started_ running.
start = Column(DateTime(timezone=True), nullable=True)
@@ -150,7 +149,7 @@ class Timestamp(Base):
exception = Column(Unicode, nullable=True)
def __repr__(self):
- format = '%b %d, %Y at %H:%M'
+ format = "%b %d, %Y at %H:%M"
if self.finish:
finish = self.finish.strftime(format)
else:
@@ -165,21 +164,27 @@ def __repr__(self):
collection = None
message = "" % (
- self.service, collection, start, finish, self.counter
+ self.service,
+ collection,
+ start,
+ finish,
+ self.counter,
)
return message
@classmethod
def lookup(cls, _db, service, service_type, collection):
return get_one(
- _db, Timestamp, service=service, service_type=service_type,
- collection=collection
+ _db,
+ Timestamp,
+ service=service,
+ service_type=service_type,
+ collection=collection,
)
@classmethod
def value(cls, _db, service, service_type, collection):
- """Return the current value of the given Timestamp, if it exists.
- """
+ """Return the current value of the given Timestamp, if it exists."""
stamp = cls.lookup(_db, service, service_type, collection)
if not stamp:
return None
@@ -187,8 +192,16 @@ def value(cls, _db, service, service_type, collection):
@classmethod
def stamp(
- cls, _db, service, service_type, collection=None, start=None,
- finish=None, achievements=None, counter=None, exception=None
+ cls,
+ _db,
+ service,
+ service_type,
+ collection=None,
+ start=None,
+ finish=None,
+ achievements=None,
+ counter=None,
+ exception=None,
):
"""Set a Timestamp, creating it if necessary.
@@ -221,7 +234,8 @@ def stamp(
elif finish is None:
finish = start
stamp, was_new = get_one_or_create(
- _db, Timestamp,
+ _db,
+ Timestamp,
service=service,
service_type=service_type,
collection=collection,
@@ -232,8 +246,9 @@ def stamp(
_db.commit()
return stamp
- def update(self, start=None, finish=None, achievements=None,
- counter=None, exception=None):
+ def update(
+ self, start=None, finish=None, achievements=None, counter=None, exception=None
+ ):
"""Use a single method to update all the fields that aren't
used to identify a Timestamp.
"""
@@ -270,36 +285,36 @@ def update(self, start=None, finish=None, achievements=None,
def to_data(self):
"""Convert this Timestamp to an unfinalized TimestampData."""
from ..metadata_layer import TimestampData
+
return TimestampData(
- start=self.start, finish=self.finish,
- achievements=self.achievements, counter=self.counter
+ start=self.start,
+ finish=self.finish,
+ achievements=self.achievements,
+ counter=self.counter,
)
- __table_args__ = (
- UniqueConstraint('service', 'collection_id'),
- )
+ __table_args__ = (UniqueConstraint("service", "collection_id"),)
+
class CoverageRecord(Base, BaseCoverageRecord):
"""A record of a Identifier being used as input into some process."""
- __tablename__ = 'coveragerecords'
- SET_EDITION_METADATA_OPERATION = 'set-edition-metadata'
- CHOOSE_COVER_OPERATION = 'choose-cover'
- REAP_OPERATION = 'reap'
- IMPORT_OPERATION = 'import'
- RESOLVE_IDENTIFIER_OPERATION = 'resolve-identifier'
- REPAIR_SORT_NAME_OPERATION = 'repair-sort-name'
- METADATA_UPLOAD_OPERATION = 'metadata-upload'
+ __tablename__ = "coveragerecords"
+
+ SET_EDITION_METADATA_OPERATION = "set-edition-metadata"
+ CHOOSE_COVER_OPERATION = "choose-cover"
+ REAP_OPERATION = "reap"
+ IMPORT_OPERATION = "import"
+ RESOLVE_IDENTIFIER_OPERATION = "resolve-identifier"
+ REPAIR_SORT_NAME_OPERATION = "repair-sort-name"
+ METADATA_UPLOAD_OPERATION = "metadata-upload"
id = Column(Integer, primary_key=True)
- identifier_id = Column(
- Integer, ForeignKey('identifiers.id'), index=True)
+ identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
# If applicable, this is the ID of the data source that took the
# Identifier as input.
- data_source_id = Column(
- Integer, ForeignKey('datasources.id')
- )
+ data_source_id = Column(Integer, ForeignKey("datasources.id"))
operation = Column(String(255), default=None)
timestamp = Column(DateTime(timezone=True), index=True)
@@ -310,19 +325,24 @@ class CoverageRecord(Base, BaseCoverageRecord):
# If applicable, this is the ID of the collection for which
# coverage has taken place. This is currently only applicable
# for Metadata Wrangler coverage.
- collection_id = Column(
- Integer, ForeignKey('collections.id'), nullable=True
- )
+ collection_id = Column(Integer, ForeignKey("collections.id"), nullable=True)
__table_args__ = (
Index(
- 'ix_identifier_id_data_source_id_operation',
- identifier_id, data_source_id, operation,
- unique=True, postgresql_where=collection_id.is_(None)),
+ "ix_identifier_id_data_source_id_operation",
+ identifier_id,
+ data_source_id,
+ operation,
+ unique=True,
+ postgresql_where=collection_id.is_(None),
+ ),
Index(
- 'ix_identifier_id_data_source_id_operation_collection_id',
- identifier_id, data_source_id, operation, collection_id,
- unique=True
+ "ix_identifier_id_data_source_id_operation_collection_id",
+ identifier_id,
+ data_source_id,
+ operation,
+ collection_id,
+ unique=True,
),
)
@@ -335,11 +355,11 @@ def human_readable(self, template):
if self.operation:
operation = ' operation="%s"' % self.operation
else:
- operation = ''
+ operation = ""
if self.exception:
exception = ' exception="%s"' % self.exception
else:
- exception = ''
+ exception = ""
return template % dict(
timestamp=self.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
identifier_type=self.identifier.type,
@@ -351,8 +371,9 @@ def human_readable(self, template):
)
@classmethod
- def lookup(cls, edition_or_identifier, data_source, operation=None,
- collection=None):
+ def lookup(
+ cls, edition_or_identifier, data_source, operation=None, collection=None
+ ):
from .datasource import DataSource
from .edition import Edition
from .identifier import Identifier
@@ -363,24 +384,31 @@ def lookup(cls, edition_or_identifier, data_source, operation=None,
elif isinstance(edition_or_identifier, Edition):
identifier = edition_or_identifier.primary_identifier
else:
- raise ValueError(
- "Cannot look up a coverage record for %r." % edition)
+ raise ValueError("Cannot look up a coverage record for %r." % edition)
if isinstance(data_source, (bytes, str)):
data_source = DataSource.lookup(_db, data_source)
return get_one(
- _db, CoverageRecord,
+ _db,
+ CoverageRecord,
identifier=identifier,
data_source=data_source,
operation=operation,
collection=collection,
- on_multiple='interchangeable',
+ on_multiple="interchangeable",
)
@classmethod
- def add_for(self, edition, data_source, operation=None, timestamp=None,
- status=BaseCoverageRecord.SUCCESS, collection=None):
+ def add_for(
+ self,
+ edition,
+ data_source,
+ operation=None,
+ timestamp=None,
+ status=BaseCoverageRecord.SUCCESS,
+ collection=None,
+ ):
from .edition import Edition
from .identifier import Identifier
@@ -390,24 +418,31 @@ def add_for(self, edition, data_source, operation=None, timestamp=None,
elif isinstance(edition, Edition):
identifier = edition.primary_identifier
else:
- raise ValueError(
- "Cannot create a coverage record for %r." % edition)
+ raise ValueError("Cannot create a coverage record for %r." % edition)
timestamp = timestamp or utc_now()
coverage_record, is_new = get_one_or_create(
- _db, CoverageRecord,
+ _db,
+ CoverageRecord,
identifier=identifier,
data_source=data_source,
operation=operation,
collection=collection,
- on_multiple='interchangeable'
+ on_multiple="interchangeable",
)
coverage_record.status = status
coverage_record.timestamp = timestamp
return coverage_record, is_new
@classmethod
- def bulk_add(cls, identifiers, data_source, operation=None, timestamp=None,
- status=BaseCoverageRecord.SUCCESS, exception=None, collection=None,
+ def bulk_add(
+ cls,
+ identifiers,
+ data_source,
+ operation=None,
+ timestamp=None,
+ status=BaseCoverageRecord.SUCCESS,
+ exception=None,
+ collection=None,
force=False,
):
"""Create and update CoverageRecords so that every Identifier in
@@ -424,9 +459,9 @@ def bulk_add(cls, identifiers, data_source, operation=None, timestamp=None,
identifier_ids = [i.id for i in identifiers]
equivalent_record = and_(
- cls.operation==operation,
- cls.data_source==data_source,
- cls.collection==collection,
+ cls.operation == operation,
+ cls.data_source == data_source,
+ cls.collection == collection,
)
updated_or_created_results = list()
@@ -434,18 +469,27 @@ def bulk_add(cls, identifiers, data_source, operation=None, timestamp=None,
# Make sure that works that previously had a
# CoverageRecord for this operation have their timestamp
# and status updated.
- update = cls.__table__.update().where(and_(
- cls.identifier_id.in_(identifier_ids),
- equivalent_record,
- )).values(
- dict(timestamp=timestamp, status=status, exception=exception)
- ).returning(cls.id, cls.identifier_id)
+ update = (
+ cls.__table__.update()
+ .where(
+ and_(
+ cls.identifier_id.in_(identifier_ids),
+ equivalent_record,
+ )
+ )
+ .values(dict(timestamp=timestamp, status=status, exception=exception))
+ .returning(cls.id, cls.identifier_id)
+ )
updated_or_created_results = _db.execute(update).fetchall()
- already_covered = _db.query(cls.id, cls.identifier_id).filter(
- equivalent_record,
- cls.identifier_id.in_(identifier_ids),
- ).subquery()
+ already_covered = (
+ _db.query(cls.id, cls.identifier_id)
+ .filter(
+ equivalent_record,
+ cls.identifier_id.in_(identifier_ids),
+ )
+ .subquery()
+ )
# Make sure that any identifiers that need a CoverageRecord get one.
# The SELECT part of the INSERT...SELECT query.
@@ -454,33 +498,43 @@ def bulk_add(cls, identifiers, data_source, operation=None, timestamp=None,
if collection:
collection_id = collection.id
- new_records = _db.query(
- Identifier.id.label('identifier_id'),
- literal(operation, type_=String(255)).label('operation'),
- literal(timestamp, type_=DateTime).label('timestamp'),
- literal(status, type_=BaseCoverageRecord.status_enum).label('status'),
- literal(exception, type_=Unicode).label('exception'),
- literal(data_source_id, type_=Integer).label('data_source_id'),
- literal(collection_id, type_=Integer).label('collection_id'),
- ).select_from(Identifier).outerjoin(
- already_covered, Identifier.id==already_covered.c.identifier_id,
- ).filter(already_covered.c.id==None)
+ new_records = (
+ _db.query(
+ Identifier.id.label("identifier_id"),
+ literal(operation, type_=String(255)).label("operation"),
+ literal(timestamp, type_=DateTime).label("timestamp"),
+ literal(status, type_=BaseCoverageRecord.status_enum).label("status"),
+ literal(exception, type_=Unicode).label("exception"),
+ literal(data_source_id, type_=Integer).label("data_source_id"),
+ literal(collection_id, type_=Integer).label("collection_id"),
+ )
+ .select_from(Identifier)
+ .outerjoin(
+ already_covered,
+ Identifier.id == already_covered.c.identifier_id,
+ )
+ .filter(already_covered.c.id == None)
+ )
new_records = new_records.filter(Identifier.id.in_(identifier_ids))
# The INSERT part.
- insert = cls.__table__.insert().from_select(
- [
- literal_column('identifier_id'),
- literal_column('operation'),
- literal_column('timestamp'),
- literal_column('status'),
- literal_column('exception'),
- literal_column('data_source_id'),
- literal_column('collection_id'),
- ],
- new_records
- ).returning(cls.id, cls.identifier_id)
+ insert = (
+ cls.__table__.insert()
+ .from_select(
+ [
+ literal_column("identifier_id"),
+ literal_column("operation"),
+ literal_column("timestamp"),
+ literal_column("status"),
+ literal_column("exception"),
+ literal_column("data_source_id"),
+ literal_column("collection_id"),
+ ],
+ new_records,
+ )
+ .returning(cls.id, cls.identifier_id)
+ )
inserts = _db.execute(insert).fetchall()
@@ -496,15 +550,24 @@ def bulk_add(cls, identifiers, data_source, operation=None, timestamp=None,
impacted_identifier_ids = [r[1] for r in updated_or_created_results]
if new_and_updated_record_ids:
- new_records = _db.query(cls).filter(cls.id.in_(
- new_and_updated_record_ids
- )).all()
+ new_records = (
+ _db.query(cls).filter(cls.id.in_(new_and_updated_record_ids)).all()
+ )
- ignored_identifiers = [i for i in identifiers if i.id not in impacted_identifier_ids]
+ ignored_identifiers = [
+ i for i in identifiers if i.id not in impacted_identifier_ids
+ ]
return new_records, ignored_identifiers
-Index("ix_coveragerecords_data_source_id_operation_identifier_id", CoverageRecord.data_source_id, CoverageRecord.operation, CoverageRecord.identifier_id)
+
+Index(
+ "ix_coveragerecords_data_source_id_operation_identifier_id",
+ CoverageRecord.data_source_id,
+ CoverageRecord.operation,
+ CoverageRecord.identifier_id,
+)
+
class WorkCoverageRecord(Base, BaseCoverageRecord):
"""A record of some operation that was performed on a Work.
@@ -513,18 +576,19 @@ class WorkCoverageRecord(Base, BaseCoverageRecord):
we presume that all the operations involve internal work only,
and as such there is no data_source_id.
"""
- __tablename__ = 'workcoveragerecords'
- CHOOSE_EDITION_OPERATION = 'choose-edition'
- CLASSIFY_OPERATION = 'classify'
- SUMMARY_OPERATION = 'summary'
- QUALITY_OPERATION = 'quality'
- GENERATE_OPDS_OPERATION = 'generate-opds'
- GENERATE_MARC_OPERATION = 'generate-marc'
- UPDATE_SEARCH_INDEX_OPERATION = 'update-search-index'
+ __tablename__ = "workcoveragerecords"
+
+ CHOOSE_EDITION_OPERATION = "choose-edition"
+ CLASSIFY_OPERATION = "classify"
+ SUMMARY_OPERATION = "summary"
+ QUALITY_OPERATION = "quality"
+ GENERATE_OPDS_OPERATION = "generate-opds"
+ GENERATE_MARC_OPERATION = "generate-marc"
+ UPDATE_SEARCH_INDEX_OPERATION = "update-search-index"
id = Column(Integer, primary_key=True)
- work_id = Column(Integer, ForeignKey('works.id'), index=True)
+ work_id = Column(Integer, ForeignKey("works.id"), index=True)
operation = Column(String(255), index=True, default=None)
timestamp = Column(DateTime(timezone=True), index=True)
@@ -532,50 +596,56 @@ class WorkCoverageRecord(Base, BaseCoverageRecord):
status = Column(BaseCoverageRecord.status_enum, index=True)
exception = Column(Unicode, index=True)
- __table_args__ = (
- UniqueConstraint('work_id', 'operation'),
- )
+ __table_args__ = (UniqueConstraint("work_id", "operation"),)
def __repr__(self):
if self.exception:
exception = ' exception="%s"' % self.exception
else:
- exception = ''
+ exception = ""
template = ''
return template % (
- self.work_id, self.operation,
+ self.work_id,
+ self.operation,
self.timestamp.strftime("%Y-%m-%d %H:%M:%S"),
- exception
+ exception,
)
@classmethod
def lookup(self, work, operation):
_db = Session.object_session(work)
return get_one(
- _db, WorkCoverageRecord,
+ _db,
+ WorkCoverageRecord,
work=work,
operation=operation,
- on_multiple='interchangeable',
+ on_multiple="interchangeable",
)
@classmethod
- def add_for(self, work, operation, timestamp=None,
- status=CoverageRecord.SUCCESS):
+ def add_for(self, work, operation, timestamp=None, status=CoverageRecord.SUCCESS):
_db = Session.object_session(work)
timestamp = timestamp or utc_now()
coverage_record, is_new = get_one_or_create(
- _db, WorkCoverageRecord,
+ _db,
+ WorkCoverageRecord,
work=work,
operation=operation,
- on_multiple='interchangeable'
+ on_multiple="interchangeable",
)
coverage_record.status = status
coverage_record.timestamp = timestamp
return coverage_record, is_new
@classmethod
- def bulk_add(self, works, operation, timestamp=None,
- status=CoverageRecord.SUCCESS, exception=None):
+ def bulk_add(
+ self,
+ works,
+ operation,
+ timestamp=None,
+ status=CoverageRecord.SUCCESS,
+ exception=None,
+ ):
"""Create and update WorkCoverageRecords so that every Work in
`works` has an identical record.
"""
@@ -591,10 +661,16 @@ def bulk_add(self, works, operation, timestamp=None,
# Make sure that works that previously had a
# WorkCoverageRecord for this operation have their timestamp
# and status updated.
- update = WorkCoverageRecord.__table__.update().where(
- and_(WorkCoverageRecord.work_id.in_(work_ids),
- WorkCoverageRecord.operation==operation)
- ).values(dict(timestamp=timestamp, status=status, exception=exception))
+ update = (
+ WorkCoverageRecord.__table__.update()
+ .where(
+ and_(
+ WorkCoverageRecord.work_id.in_(work_ids),
+ WorkCoverageRecord.operation == operation,
+ )
+ )
+ .values(dict(timestamp=timestamp, status=status, exception=exception))
+ )
_db.execute(update)
# Make sure that any works that are missing a
@@ -602,38 +678,39 @@ def bulk_add(self, works, operation, timestamp=None,
# Works that already have a WorkCoverageRecord will be ignored
# by the INSERT but handled by the UPDATE.
- already_covered = _db.query(WorkCoverageRecord.work_id).select_from(
- WorkCoverageRecord).filter(
- WorkCoverageRecord.work_id.in_(work_ids)
- ).filter(
- WorkCoverageRecord.operation==operation
- )
+ already_covered = (
+ _db.query(WorkCoverageRecord.work_id)
+ .select_from(WorkCoverageRecord)
+ .filter(WorkCoverageRecord.work_id.in_(work_ids))
+ .filter(WorkCoverageRecord.operation == operation)
+ )
# The SELECT part of the INSERT...SELECT query.
new_records = _db.query(
- Work.id.label('work_id'),
- literal(operation, type_=String(255)).label('operation'),
- literal(timestamp, type_=DateTime).label('timestamp'),
- literal(status, type_=BaseCoverageRecord.status_enum).label('status')
- ).select_from(
- Work
- )
- new_records = new_records.filter(
- Work.id.in_(work_ids)
- ).filter(
+ Work.id.label("work_id"),
+ literal(operation, type_=String(255)).label("operation"),
+ literal(timestamp, type_=DateTime).label("timestamp"),
+ literal(status, type_=BaseCoverageRecord.status_enum).label("status"),
+ ).select_from(Work)
+ new_records = new_records.filter(Work.id.in_(work_ids)).filter(
~Work.id.in_(already_covered)
)
# The INSERT part.
insert = WorkCoverageRecord.__table__.insert().from_select(
[
- literal_column('work_id'),
- literal_column('operation'),
- literal_column('timestamp'),
- literal_column('status'),
+ literal_column("work_id"),
+ literal_column("operation"),
+ literal_column("timestamp"),
+ literal_column("status"),
],
- new_records
+ new_records,
)
_db.execute(insert)
-Index("ix_workcoveragerecords_operation_work_id", WorkCoverageRecord.operation, WorkCoverageRecord.work_id)
+
+Index(
+ "ix_workcoveragerecords_operation_work_id",
+ WorkCoverageRecord.operation,
+ WorkCoverageRecord.work_id,
+)
diff --git a/model/credential.py b/model/credential.py
index e062362a8..a0e40abc1 100644
--- a/model/credential.py
+++ b/model/credential.py
@@ -2,6 +2,7 @@
# Credential, DRMDeviceIdentifier, DelegatedPatronIdentifier
import datetime
import uuid
+
import sqlalchemy
from sqlalchemy import (
Column,
@@ -16,62 +17,70 @@
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import and_
-from . import Base, get_one, get_one_or_create
from ..util import is_session
from ..util.datetime_helpers import utc_now
+from . import Base, get_one, get_one_or_create
class Credential(Base):
"""A place to store credentials for external services."""
- __tablename__ = 'credentials'
+
+ __tablename__ = "credentials"
id = Column(Integer, primary_key=True)
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
- patron_id = Column(Integer, ForeignKey('patrons.id'), index=True)
- collection_id = Column(Integer, ForeignKey('collections.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
+ patron_id = Column(Integer, ForeignKey("patrons.id"), index=True)
+ collection_id = Column(Integer, ForeignKey("collections.id"), index=True)
type = Column(String(255), index=True)
credential = Column(String)
expires = Column(DateTime(timezone=True), index=True)
# One Credential can have many associated DRMDeviceIdentifiers.
drm_device_identifiers = relationship(
- "DRMDeviceIdentifier", backref=backref("credential", lazy='joined')
+ "DRMDeviceIdentifier", backref=backref("credential", lazy="joined")
)
__table_args__ = (
-
# Unique indexes to prevent the creation of redundant credentials.
-
# If both patron_id and collection_id are null, then (data_source_id,
# type, credential) must be unique.
Index(
"ix_credentials_data_source_id_type_token",
- data_source_id, type, credential, unique=True,
- postgresql_where=and_(patron_id==None, collection_id==None)
+ data_source_id,
+ type,
+ credential,
+ unique=True,
+ postgresql_where=and_(patron_id == None, collection_id == None),
),
-
# If patron_id is null but collection_id is not, then
# (data_source, type, collection_id) must be unique.
Index(
"ix_credentials_data_source_id_type_collection_id",
- data_source_id, type, collection_id,
- unique=True, postgresql_where=(patron_id==None)
+ data_source_id,
+ type,
+ collection_id,
+ unique=True,
+ postgresql_where=(patron_id == None),
),
-
# If collection_id is null but patron_id is not, then
# (data_source, type, patron_id) must be unique.
# (At the moment this never happens.)
Index(
"ix_credentials_data_source_id_type_patron_id",
- data_source_id, type, patron_id,
- unique=True, postgresql_where=(collection_id==None)
+ data_source_id,
+ type,
+ patron_id,
+ unique=True,
+ postgresql_where=(collection_id == None),
),
-
# If neither collection_id nor patron_id is null, then
# (data_source, type, patron_id, collection_id)
# must be unique.
Index(
"ix_credentials_data_source_id_type_patron_id_collection_id",
- data_source_id, type, patron_id, collection_id,
+ data_source_id,
+ type,
+ patron_id,
+ collection_id,
unique=True,
),
)
@@ -110,32 +119,44 @@ def _filter_invalid_credential(cls, credential, allow_persistent_token):
return None
@classmethod
- def lookup(cls, _db, data_source, token_type, patron, refresher_method,
- allow_persistent_token=False, allow_empty_token=False,
- collection=None, force_refresh=False):
+ def lookup(
+ cls,
+ _db,
+ data_source,
+ token_type,
+ patron,
+ refresher_method,
+ allow_persistent_token=False,
+ allow_empty_token=False,
+ collection=None,
+ force_refresh=False,
+ ):
from .datasource import DataSource
+
if isinstance(data_source, str):
data_source = DataSource.lookup(_db, data_source)
credential, is_new = get_one_or_create(
- _db, Credential, data_source=data_source, type=token_type, patron=patron, collection=collection)
- if (is_new
+ _db,
+ Credential,
+ data_source=data_source,
+ type=token_type,
+ patron=patron,
+ collection=collection,
+ )
+ if (
+ is_new
or force_refresh
or (not credential.expires and not allow_persistent_token)
or (not credential.credential and not allow_empty_token)
- or (credential.expires
- and credential.expires <= utc_now())):
+ or (credential.expires and credential.expires <= utc_now())
+ ):
if refresher_method:
refresher_method(credential)
return credential
@classmethod
def lookup_by_token(
- cls,
- _db,
- data_source,
- token_type,
- token,
- allow_persistent_token=False
+ cls, _db, data_source, token_type, token, allow_persistent_token=False
):
"""Look up a unique token.
Lookup will fail on expired tokens. Unless persistent tokens
@@ -143,20 +164,20 @@ def lookup_by_token(
"""
credential = get_one(
- _db, Credential, data_source=data_source, type=token_type,
- credential=token)
+ _db, Credential, data_source=data_source, type=token_type, credential=token
+ )
return cls._filter_invalid_credential(credential, allow_persistent_token)
@classmethod
def lookup_by_patron(
- cls,
- _db,
- data_source_name,
- token_type,
- patron,
- allow_persistent_token=False,
- auto_create_datasource=True
+ cls,
+ _db,
+ data_source_name,
+ token_type,
+ patron,
+ allow_persistent_token=False,
+ auto_create_datasource=True,
):
"""Look up a unique token.
Lookup will fail on expired tokens. Unless persistent tokens
@@ -197,17 +218,12 @@ def lookup_by_patron(
raise ValueError('"auto_create_datasource" argument must be boolean')
from .datasource import DataSource
+
data_source = DataSource.lookup(
- _db,
- data_source_name,
- autocreate=auto_create_datasource
+ _db, data_source_name, autocreate=auto_create_datasource
)
credential = get_one(
- _db,
- Credential,
- data_source=data_source,
- type=token_type,
- patron=patron
+ _db, Credential, data_source=data_source, type=token_type, patron=patron
)
return cls._filter_invalid_credential(credential, allow_persistent_token)
@@ -218,19 +234,12 @@ def lookup_and_expire_temporary_token(cls, _db, data_source, type, token):
credential = cls.lookup_by_token(_db, data_source, type, token)
if not credential:
return None
- credential.expires = utc_now() - datetime.timedelta(
- seconds=5)
+ credential.expires = utc_now() - datetime.timedelta(seconds=5)
return credential
@classmethod
def temporary_token_create(
- cls,
- _db,
- data_source,
- token_type,
- patron,
- duration,
- value=None
+ cls, _db, data_source, token_type, patron, duration, value=None
):
"""Create a temporary token for the given data_source/type/patron.
The token will be good for the specified `duration`.
@@ -238,62 +247,72 @@ def temporary_token_create(
expires = utc_now() + duration
token_string = value or str(uuid.uuid1())
credential, is_new = get_one_or_create(
- _db, Credential, data_source=data_source, type=token_type, patron=patron)
+ _db, Credential, data_source=data_source, type=token_type, patron=patron
+ )
# If there was already a token of this type for this patron,
# the new one overwrites the old one.
- credential.credential=token_string
- credential.expires=expires
+ credential.credential = token_string
+ credential.expires = expires
return credential, is_new
@classmethod
- def persistent_token_create(self, _db, data_source, type, patron, token_string=None):
+ def persistent_token_create(
+ self, _db, data_source, type, patron, token_string=None
+ ):
"""Create or retrieve a persistent token for the given
data_source/type/patron.
"""
if token_string is None:
token_string = str(uuid.uuid1())
credential, is_new = get_one_or_create(
- _db, Credential, data_source=data_source, type=type, patron=patron,
- create_method_kwargs=dict(credential=token_string)
+ _db,
+ Credential,
+ data_source=data_source,
+ type=type,
+ patron=patron,
+ create_method_kwargs=dict(credential=token_string),
)
- credential.expires=None
+ credential.expires = None
return credential, is_new
# A Credential may have many associated DRMDeviceIdentifiers.
def register_drm_device_identifier(self, device_identifier):
_db = Session.object_session(self)
return get_one_or_create(
- _db, DRMDeviceIdentifier,
+ _db,
+ DRMDeviceIdentifier,
credential=self,
- device_identifier=device_identifier
+ device_identifier=device_identifier,
)
def deregister_drm_device_identifier(self, device_identifier):
_db = Session.object_session(self)
device_id_obj = get_one(
- _db, DRMDeviceIdentifier,
+ _db,
+ DRMDeviceIdentifier,
credential=self,
- device_identifier=device_identifier
+ device_identifier=device_identifier,
)
if device_id_obj:
_db.delete(device_id_obj)
def __repr__(self):
- return \
- ')'.format(
- self.data_source_id,
- self.patron_id,
- self.collection_id,
- self.type,
- self.credential,
- self.expires
- )
+ return (
+ ")".format(
+ self.data_source_id,
+ self.patron_id,
+ self.collection_id,
+ self.type,
+ self.credential,
+ self.expires,
+ )
+ )
class DRMDeviceIdentifier(Base):
@@ -301,9 +320,10 @@ class DRMDeviceIdentifier(Base):
Associated with a Credential, most commonly a patron's "Identifier
for Adobe account ID purposes" Credential.
"""
- __tablename__ = 'drmdeviceidentifiers'
+
+ __tablename__ = "drmdeviceidentifiers"
id = Column(Integer, primary_key=True)
- credential_id = Column(Integer, ForeignKey('credentials.id'), index=True)
+ credential_id = Column(Integer, ForeignKey("credentials.id"), index=True)
device_identifier = Column(String(255), index=True)
@@ -314,9 +334,10 @@ class DelegatedPatronIdentifier(Base):
the SimplyE app.
Those identifiers are stored here.
"""
- ADOBE_ACCOUNT_ID = 'Adobe Account ID'
- __tablename__ = 'delegatedpatronidentifiers'
+ ADOBE_ACCOUNT_ID = "Adobe Account ID"
+
+ __tablename__ = "delegatedpatronidentifiers"
id = Column(Integer, primary_key=True)
type = Column(String(255), index=True)
library_uri = Column(String(255), index=True)
@@ -329,14 +350,11 @@ class DelegatedPatronIdentifier(Base):
# foreign library is trying to look up.
delegated_identifier = Column(String)
- __table_args__ = (
- UniqueConstraint('type', 'library_uri', 'patron_identifier'),
- )
+ __table_args__ = (UniqueConstraint("type", "library_uri", "patron_identifier"),)
@classmethod
def get_one_or_create(
- cls, _db, library_uri, patron_identifier, identifier_type,
- create_function
+ cls, _db, library_uri, patron_identifier, identifier_type, create_function
):
"""Look up the delegated identifier for the given patron. If there is
none, create one.
@@ -355,8 +373,11 @@ def get_one_or_create(
:return: A 2-tuple (DelegatedPatronIdentifier, is_new)
"""
identifier, is_new = get_one_or_create(
- _db, DelegatedPatronIdentifier, library_uri=library_uri,
- patron_identifier=patron_identifier, type=identifier_type
+ _db,
+ DelegatedPatronIdentifier,
+ library_uri=library_uri,
+ patron_identifier=patron_identifier,
+ type=identifier_type,
)
if is_new:
identifier.delegated_identifier = create_function()
diff --git a/model/customlist.py b/model/customlist.py
index cacce2654..d14c9953f 100644
--- a/model/customlist.py
+++ b/model/customlist.py
@@ -1,9 +1,10 @@
# encoding: utf-8
# CustomList, CustomListEntry
-from pdb import set_trace
-from functools import total_ordering
import logging
+from functools import total_ordering
+from pdb import set_trace
+
from sqlalchemy import (
Boolean,
Column,
@@ -15,18 +16,16 @@
UniqueConstraint,
)
from sqlalchemy.orm import relationship
-from sqlalchemy.sql.expression import or_
from sqlalchemy.orm.session import Session
+from sqlalchemy.sql.expression import or_
-from . import (
- Base,
- get_one_or_create,
-)
+from ..util.datetime_helpers import utc_now
+from . import Base, get_one_or_create
from .datasource import DataSource
from .identifier import Identifier
from .licensing import LicensePool
from .work import Work
-from ..util.datetime_helpers import utc_now
+
@total_ordering
class CustomList(Base):
@@ -34,28 +33,27 @@ class CustomList(Base):
STAFF_PICKS_NAME = "Staff Picks"
- __tablename__ = 'customlists'
+ __tablename__ = "customlists"
id = Column(Integer, primary_key=True)
primary_language = Column(Unicode, index=True)
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
foreign_identifier = Column(Unicode, index=True)
name = Column(Unicode, index=True)
description = Column(Unicode)
created = Column(DateTime(timezone=True), index=True)
updated = Column(DateTime(timezone=True), index=True)
responsible_party = Column(Unicode)
- library_id = Column(Integer, ForeignKey('libraries.id'), index=True, nullable=True)
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True, nullable=True)
# How many titles are in this list? This is calculated and
# cached when the list contents change.
size = Column(Integer, nullable=False, default=0)
- entries = relationship(
- "CustomListEntry", backref="customlist")
+ entries = relationship("CustomListEntry", backref="customlist")
__table_args__ = (
- UniqueConstraint('data_source_id', 'foreign_identifier'),
- UniqueConstraint('name', 'library_id'),
+ UniqueConstraint("data_source_id", "foreign_identifier"),
+ UniqueConstraint("name", "library_id"),
)
# TODO: It should be possible to associate a CustomList with an
@@ -64,7 +62,9 @@ class CustomList(Base):
def __repr__(self):
return '' % (
- self.name, self.foreign_identifier, len(self.entries)
+ self.name,
+ self.foreign_identifier,
+ len(self.entries),
)
def __eq__(self, other):
@@ -72,17 +72,17 @@ def __eq__(self, other):
if other is None or not isinstance(other, CustomList):
return False
return (self.foreign_identifier, self.name) == (
- other.foreign_identifier, other.name
+ other.foreign_identifier,
+ other.name,
)
def __lt__(self, other):
"""Comparison implementation for total_ordering."""
if other is None or not isinstance(other, CustomList):
return False
- return (
- self.foreign_identifier, self.name
- ) < (
- other.foreign_identifier, other.name
+ return (self.foreign_identifier, self.name) < (
+ other.foreign_identifier,
+ other.name,
)
@classmethod
@@ -110,15 +110,19 @@ def find(cls, _db, foreign_identifier_or_name, data_source=None, library=None):
qu = _db.query(cls)
if source_name:
qu = qu.join(CustomList.data_source).filter(
- DataSource.name==str(source_name))
+ DataSource.name == str(source_name)
+ )
qu = qu.filter(
- or_(CustomList.foreign_identifier==foreign_identifier,
- CustomList.name==foreign_identifier))
+ or_(
+ CustomList.foreign_identifier == foreign_identifier,
+ CustomList.name == foreign_identifier,
+ )
+ )
if library:
- qu = qu.filter(CustomList.library_id==library.id)
+ qu = qu.filter(CustomList.library_id == library.id)
else:
- qu = qu.filter(CustomList.library_id==None)
+ qu = qu.filter(CustomList.library_id == None)
custom_lists = qu.all()
@@ -136,8 +140,14 @@ def featured_works(self):
identifiers = [ed.primary_identifier for ed in editions]
return Work.from_identifiers(_db, identifiers)
- def add_entry(self, work_or_edition, annotation=None, first_appearance=None,
- featured=None, update_external_index=True):
+ def add_entry(
+ self,
+ work_or_edition,
+ annotation=None,
+ first_appearance=None,
+ featured=None,
+ update_external_index=True,
+ ):
"""Add a Work or Edition to a CustomList.
:param work_or_edition: A Work or an Edition. If this is a
@@ -195,13 +205,18 @@ def add_entry(self, work_or_edition, annotation=None, first_appearance=None,
# exact same book may already be on the list. Either find
# an exact duplicate, or create a new entry.
entry, was_new = get_one_or_create(
- _db, CustomListEntry,
- customlist=self, edition=edition, work=work,
+ _db,
+ CustomListEntry,
+ customlist=self,
+ edition=edition,
+ work=work,
create_method_kwargs=dict(first_appearance=first_appearance),
)
- if (not entry.most_recent_appearance
- or entry.most_recent_appearance < first_appearance):
+ if (
+ not entry.most_recent_appearance
+ or entry.most_recent_appearance < first_appearance
+ ):
entry.most_recent_appearance = first_appearance
if annotation:
entry.annotation = str(annotation)
@@ -258,7 +273,7 @@ def entries_for_work(self, work_or_edition):
if equivalent_ids:
clauses.append(CustomListEntry.edition_id.in_(equivalent_ids))
if work:
- clauses.append(CustomListEntry.work==work)
+ clauses.append(CustomListEntry.work == work)
if len(clauses) == 0:
# This shouldn't happen, but if it does, there can be
# no matching results.
@@ -268,10 +283,11 @@ def entries_for_work(self, work_or_edition):
else:
clause = or_(*clauses)
- qu = _db.query(CustomListEntry).filter(
- CustomListEntry.customlist==self).filter(
- clause
- )
+ qu = (
+ _db.query(CustomListEntry)
+ .filter(CustomListEntry.customlist == self)
+ .filter(clause)
+ )
return qu
def update_size(self):
@@ -280,11 +296,11 @@ def update_size(self):
class CustomListEntry(Base):
- __tablename__ = 'customlistentries'
+ __tablename__ = "customlistentries"
id = Column(Integer, primary_key=True)
- list_id = Column(Integer, ForeignKey('customlists.id'), index=True)
- edition_id = Column(Integer, ForeignKey('editions.id'), index=True)
- work_id = Column(Integer, ForeignKey('works.id'), index=True)
+ list_id = Column(Integer, ForeignKey("customlists.id"), index=True)
+ edition_id = Column(Integer, ForeignKey("editions.id"), index=True)
+ work_id = Column(Integer, ForeignKey("works.id"), index=True)
featured = Column(Boolean, nullable=False, default=False)
annotation = Column(Unicode)
@@ -312,14 +328,15 @@ def set_work(self, metadata=None, metadata_client=None, policy=None):
new_work = None
if not metadata:
from ..metadata_layer import Metadata
+
metadata = Metadata.from_edition(edition)
# Try to guess based on metadata, if we can get a high-quality
# guess.
- potential_license_pools = metadata.guess_license_pools(
- _db, metadata_client)
+ potential_license_pools = metadata.guess_license_pools(_db, metadata_client)
for lp, quality in sorted(
- list(potential_license_pools.items()), key=lambda x: -x[1]):
+ list(potential_license_pools.items()), key=lambda x: -x[1]
+ ):
if lp.deliverable and lp.work and quality >= 0.8:
# This work has at least one deliverable LicensePool
# associated with it, so it's likely to be real
@@ -330,13 +347,21 @@ def set_work(self, metadata=None, metadata_client=None, policy=None):
if not new_work:
# Try using the less reliable, more expensive method of
# matching based on equivalent identifiers.
- equivalent_identifier_id_subquery = Identifier.recursively_equivalent_identifier_ids_query(
- self.edition.primary_identifier.id, policy=policy
+ equivalent_identifier_id_subquery = (
+ Identifier.recursively_equivalent_identifier_ids_query(
+ self.edition.primary_identifier.id, policy=policy
+ )
)
- pool_q = _db.query(LicensePool).filter(
- LicensePool.identifier_id.in_(equivalent_identifier_id_subquery)).order_by(
+ pool_q = (
+ _db.query(LicensePool)
+ .filter(
+ LicensePool.identifier_id.in_(equivalent_identifier_id_subquery)
+ )
+ .order_by(
LicensePool.licenses_available.desc(),
- LicensePool.patrons_in_hold_queue.asc())
+ LicensePool.patrons_in_hold_queue.asc(),
+ )
+ )
pools = [x for x in pool_q if x.deliverable]
for pool in pools:
if pool.deliverable and pool.work:
@@ -348,12 +373,13 @@ def set_work(self, metadata=None, metadata_client=None, policy=None):
if old_work:
logging.info(
"Changing work for list entry %r to %r (was %r)",
- self.edition, new_work, old_work
+ self.edition,
+ new_work,
+ old_work,
)
else:
logging.info(
- "Setting work for list entry %r to %r",
- self.edition, new_work
+ "Setting work for list entry %r to %r", self.edition, new_work
)
self.work = new_work
return self.work
@@ -371,7 +397,7 @@ def update(self, _db, equivalent_entries=None):
# Confirm that all the entries are from the same CustomList.
list_ids = set([e.list_id for e in equivalent_entries])
- if not len(list_ids)==1:
+ if not len(list_ids) == 1:
raise ValueError("Cannot combine entries on different CustomLists.")
# Confirm that all the entries are equivalent.
@@ -390,20 +416,17 @@ def update(self, _db, equivalent_entries=None):
works = [w for w in works if w]
if works:
- if not len(works)==1:
+ if not len(works) == 1:
# This shouldn't happen, given all the Editions are equivalent.
raise ValueError(error)
[work] = works
- self.first_appearance = min(
- [e.first_appearance for e in equivalent_entries]
- )
+ self.first_appearance = min([e.first_appearance for e in equivalent_entries])
self.most_recent_appearance = max(
[e.most_recent_appearance for e in equivalent_entries]
)
- annotations = [str(e.annotation) for e in equivalent_entries
- if e.annotation]
+ annotations = [str(e.annotation) for e in equivalent_entries if e.annotation]
if annotations:
if len(annotations) > 1:
# Just pick the longest one?
@@ -419,10 +442,12 @@ def update(self, _db, equivalent_entries=None):
if work and not best_edition:
work.calculate_presentation()
best_edition = work.presentation_edition
- if best_edition and not best_edition==self.edition:
+ if best_edition and not best_edition == self.edition:
logging.info(
"Changing edition for list entry %r to %r from %r",
- self, best_edition, self.edition
+ self,
+ best_edition,
+ self.edition,
)
self.edition = best_edition
@@ -433,9 +458,14 @@ def update(self, _db, equivalent_entries=None):
_db.delete(entry)
_db.commit
+
# TODO: This was originally designed to speed up queries against the
# materialized view that use custom list membership as a way to cut
# down on the result set. Now that we've removed the materialized
# view, is this still necessary? It might still be necessary for
# similar queries against Work.
-Index("ix_customlistentries_work_id_list_id", CustomListEntry.work_id, CustomListEntry.list_id)
+Index(
+ "ix_customlistentries_work_id_list_id",
+ CustomListEntry.work_id,
+ CustomListEntry.list_id,
+)
diff --git a/model/datasource.py b/model/datasource.py
index fd95eab1a..955a8b49a 100644
--- a/model/datasource.py
+++ b/model/datasource.py
@@ -2,41 +2,25 @@
# DataSource
-from . import (
- Base,
- get_one,
- get_one_or_create,
-)
-from .constants import (
- DataSourceConstants,
- IdentifierConstants,
-)
+from collections import defaultdict
+from urllib.parse import quote, unquote
+
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String
+from sqlalchemy.dialects.postgresql import JSON
+from sqlalchemy.ext.mutable import MutableDict
+from sqlalchemy.orm import backref, relationship
+
+from . import Base, get_one, get_one_or_create
+from .constants import DataSourceConstants, IdentifierConstants
from .hasfulltablecache import HasFullTableCache
from .licensing import LicensePoolDeliveryMechanism
-from collections import defaultdict
-from sqlalchemy import (
- Boolean,
- Column,
- ForeignKey,
- Integer,
- String,
-)
-from sqlalchemy.dialects.postgresql import JSON
-from sqlalchemy.ext.mutable import (
- MutableDict,
-)
-from sqlalchemy.orm import (
- backref,
- relationship,
-)
-from urllib.parse import quote, unquote
class DataSource(Base, HasFullTableCache, DataSourceConstants):
"""A source for information about books, and possibly the books themselves."""
- __tablename__ = 'datasources'
+ __tablename__ = "datasources"
id = Column(Integer, primary_key=True)
name = Column(String, unique=True, index=True)
offers_licenses = Column(Boolean, default=False)
@@ -45,9 +29,15 @@ class DataSource(Base, HasFullTableCache, DataSourceConstants):
# One DataSource can have one IntegrationClient.
integration_client_id = Column(
- Integer, ForeignKey('integrationclients.id'),
- unique=True, index=True, nullable=True)
- integration_client = relationship("IntegrationClient", backref=backref("data_source", uselist=False))
+ Integer,
+ ForeignKey("integrationclients.id"),
+ unique=True,
+ index=True,
+ nullable=True,
+ )
+ integration_client = relationship(
+ "IntegrationClient", backref=backref("data_source", uselist=False)
+ )
# One DataSource can generate many Editions.
editions = relationship("Edition", backref="data_source")
@@ -60,7 +50,8 @@ class DataSource(Base, HasFullTableCache, DataSourceConstants):
# One DataSource can grant access to many LicensePools.
license_pools = relationship(
- "LicensePool", backref=backref("data_source", lazy='joined'))
+ "LicensePool", backref=backref("data_source", lazy="joined")
+ )
# One DataSource can provide many Hyperlinks.
links = relationship("Hyperlink", backref="data_source")
@@ -82,8 +73,9 @@ class DataSource(Base, HasFullTableCache, DataSourceConstants):
# One DataSource can have provide many LicensePoolDeliveryMechanisms.
delivery_mechanisms = relationship(
- "LicensePoolDeliveryMechanism", backref="data_source",
- foreign_keys=lambda: [LicensePoolDeliveryMechanism.data_source_id]
+ "LicensePoolDeliveryMechanism",
+ backref="data_source",
+ foreign_keys=lambda: [LicensePoolDeliveryMechanism.data_source_id],
)
_cache = HasFullTableCache.RESET
@@ -96,8 +88,14 @@ def cache_key(self):
return self.name
@classmethod
- def lookup(cls, _db, name, autocreate=False, offers_licenses=False,
- primary_identifier_type=None):
+ def lookup(
+ cls,
+ _db,
+ name,
+ autocreate=False,
+ offers_licenses=False,
+ primary_identifier_type=None,
+ ):
# Turn a deprecated name (e.g. "3M" into the current name
# (e.g. "Bibliotheca").
name = cls.DEPRECATED_NAMES.get(name, name)
@@ -108,11 +106,13 @@ def lookup_hook():
"""
if autocreate:
data_source, is_new = get_one_or_create(
- _db, DataSource, name=name,
+ _db,
+ DataSource,
+ name=name,
create_method_kwargs=dict(
offers_licenses=offers_licenses,
- primary_identifier_type=primary_identifier_type
- )
+ primary_identifier_type=primary_identifier_type,
+ ),
)
else:
data_source = get_one(_db, DataSource, name=name)
@@ -133,7 +133,7 @@ def name_from_uri(cls, uri):
"""
if not uri.startswith(cls.URI_PREFIX):
return None
- name = uri[len(cls.URI_PREFIX):]
+ name = uri[len(cls.URI_PREFIX) :]
return unquote(name)
@classmethod
@@ -163,8 +163,11 @@ def license_sources_for(cls, _db, identifier):
type = identifier
else:
type = identifier.type
- q =_db.query(DataSource).filter(DataSource.offers_licenses==True).filter(
- DataSource.primary_identifier_type==type)
+ q = (
+ _db.query(DataSource)
+ .filter(DataSource.offers_licenses == True)
+ .filter(DataSource.primary_identifier_type == type)
+ )
return q
@classmethod
@@ -177,7 +180,7 @@ def metadata_sources_for(cls, _db, identifier):
else:
type = identifier.type
- if not hasattr(cls, 'metadata_lookups_by_identifier_type'):
+ if not hasattr(cls, "metadata_lookups_by_identifier_type"):
# This should only happen during testing.
list(DataSource.well_known_sources(_db))
@@ -186,50 +189,87 @@ def metadata_sources_for(cls, _db, identifier):
@classmethod
def well_known_sources(cls, _db):
- """Make sure all the well-known sources exist in the database.
- """
+ """Make sure all the well-known sources exist in the database."""
cls.metadata_lookups_by_identifier_type = defaultdict(list)
- for (name, offers_licenses, offers_metadata_lookup, primary_identifier_type, refresh_rate) in (
- (cls.GUTENBERG, True, False, IdentifierConstants.GUTENBERG_ID, None),
- (cls.OVERDRIVE, True, False, IdentifierConstants.OVERDRIVE_ID, 0),
- (cls.BIBLIOTHECA, True, False, IdentifierConstants.BIBLIOTHECA_ID, 60*60*6),
- (cls.ODILO, True, False, IdentifierConstants.ODILO_ID, 0),
- (cls.AXIS_360, True, False, IdentifierConstants.AXIS_360_ID, 0),
- (cls.OCLC, False, False, None, None),
- (cls.OCLC_LINKED_DATA, False, False, None, None),
- (cls.AMAZON, False, False, None, None),
- (cls.OPEN_LIBRARY, False, False, IdentifierConstants.OPEN_LIBRARY_ID, None),
- (cls.GUTENBERG_COVER_GENERATOR, False, False, IdentifierConstants.GUTENBERG_ID, None),
- (cls.GUTENBERG_EPUB_GENERATOR, False, False, IdentifierConstants.GUTENBERG_ID, None),
- (cls.WEB, True, False, IdentifierConstants.URI, None),
- (cls.VIAF, False, False, None, None),
- (cls.CONTENT_CAFE, True, True, IdentifierConstants.ISBN, None),
- (cls.MANUAL, False, False, None, None),
- (cls.NYT, False, False, IdentifierConstants.ISBN, None),
- (cls.LIBRARY_STAFF, False, False, None, None),
- (cls.METADATA_WRANGLER, False, False, None, None),
- (cls.PROJECT_GITENBERG, True, False, IdentifierConstants.GUTENBERG_ID, None),
- (cls.STANDARD_EBOOKS, True, False, IdentifierConstants.URI, None),
- (cls.UNGLUE_IT, True, False, IdentifierConstants.URI, None),
- (cls.ADOBE, False, False, None, None),
- (cls.PLYMPTON, True, False, IdentifierConstants.ISBN, None),
- (cls.ELIB, True, False, IdentifierConstants.ELIB_ID, None),
- (cls.OA_CONTENT_SERVER, True, False, None, None),
- (cls.NOVELIST, False, True, IdentifierConstants.NOVELIST_ID, None),
- (cls.PRESENTATION_EDITION, False, False, None, None),
- (cls.INTERNAL_PROCESSING, False, False, None, None),
- (cls.FEEDBOOKS, True, False, IdentifierConstants.URI, None),
- (cls.BIBBLIO, False, True, IdentifierConstants.BIBBLIO_CONTENT_ITEM_ID, None),
- (cls.ENKI, True, False, IdentifierConstants.ENKI_ID, None),
- (cls.PROQUEST, True, False, IdentifierConstants.PROQUEST_ID, None)
+ for (
+ name,
+ offers_licenses,
+ offers_metadata_lookup,
+ primary_identifier_type,
+ refresh_rate,
+ ) in (
+ (cls.GUTENBERG, True, False, IdentifierConstants.GUTENBERG_ID, None),
+ (cls.OVERDRIVE, True, False, IdentifierConstants.OVERDRIVE_ID, 0),
+ (
+ cls.BIBLIOTHECA,
+ True,
+ False,
+ IdentifierConstants.BIBLIOTHECA_ID,
+ 60 * 60 * 6,
+ ),
+ (cls.ODILO, True, False, IdentifierConstants.ODILO_ID, 0),
+ (cls.AXIS_360, True, False, IdentifierConstants.AXIS_360_ID, 0),
+ (cls.OCLC, False, False, None, None),
+ (cls.OCLC_LINKED_DATA, False, False, None, None),
+ (cls.AMAZON, False, False, None, None),
+ (cls.OPEN_LIBRARY, False, False, IdentifierConstants.OPEN_LIBRARY_ID, None),
+ (
+ cls.GUTENBERG_COVER_GENERATOR,
+ False,
+ False,
+ IdentifierConstants.GUTENBERG_ID,
+ None,
+ ),
+ (
+ cls.GUTENBERG_EPUB_GENERATOR,
+ False,
+ False,
+ IdentifierConstants.GUTENBERG_ID,
+ None,
+ ),
+ (cls.WEB, True, False, IdentifierConstants.URI, None),
+ (cls.VIAF, False, False, None, None),
+ (cls.CONTENT_CAFE, True, True, IdentifierConstants.ISBN, None),
+ (cls.MANUAL, False, False, None, None),
+ (cls.NYT, False, False, IdentifierConstants.ISBN, None),
+ (cls.LIBRARY_STAFF, False, False, None, None),
+ (cls.METADATA_WRANGLER, False, False, None, None),
+ (
+ cls.PROJECT_GITENBERG,
+ True,
+ False,
+ IdentifierConstants.GUTENBERG_ID,
+ None,
+ ),
+ (cls.STANDARD_EBOOKS, True, False, IdentifierConstants.URI, None),
+ (cls.UNGLUE_IT, True, False, IdentifierConstants.URI, None),
+ (cls.ADOBE, False, False, None, None),
+ (cls.PLYMPTON, True, False, IdentifierConstants.ISBN, None),
+ (cls.ELIB, True, False, IdentifierConstants.ELIB_ID, None),
+ (cls.OA_CONTENT_SERVER, True, False, None, None),
+ (cls.NOVELIST, False, True, IdentifierConstants.NOVELIST_ID, None),
+ (cls.PRESENTATION_EDITION, False, False, None, None),
+ (cls.INTERNAL_PROCESSING, False, False, None, None),
+ (cls.FEEDBOOKS, True, False, IdentifierConstants.URI, None),
+ (
+ cls.BIBBLIO,
+ False,
+ True,
+ IdentifierConstants.BIBBLIO_CONTENT_ITEM_ID,
+ None,
+ ),
+ (cls.ENKI, True, False, IdentifierConstants.ENKI_ID, None),
+ (cls.PROQUEST, True, False, IdentifierConstants.PROQUEST_ID, None),
):
obj = DataSource.lookup(
- _db, name, autocreate=True,
+ _db,
+ name,
+ autocreate=True,
offers_licenses=offers_licenses,
- primary_identifier_type = primary_identifier_type
+ primary_identifier_type=primary_identifier_type,
)
if offers_metadata_lookup:
diff --git a/model/edition.py b/model/edition.py
index d287e7c01..6add08798 100644
--- a/model/edition.py
+++ b/model/edition.py
@@ -2,51 +2,25 @@
# Edition
-from . import (
- Base,
- get_one,
- get_one_or_create,
- PresentationCalculationPolicy,
-)
-from .coverage import CoverageRecord
-from .constants import (
- DataSourceConstants,
- EditionConstants,
- LinkRelations,
- MediaTypes,
-)
-from .contributor import (
- Contributor,
- Contribution,
-)
-from .datasource import DataSource
-from .identifier import Identifier
-from .licensing import (
- DeliveryMechanism,
- LicensePool,
-)
-
-from collections import defaultdict
import logging
-from sqlalchemy import (
- Column,
- Date,
- Enum,
- ForeignKey,
- Index,
- Integer,
- String,
- Unicode,
-)
+from collections import defaultdict
+
+from sqlalchemy import Column, Date, Enum, ForeignKey, Index, Integer, String, Unicode
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import Session
-from ..util import (
- LanguageCodes,
- TitleProcessor
-)
+
+from ..util import LanguageCodes, TitleProcessor
from ..util.permanent_work_id import WorkIDCalculator
+from . import Base, PresentationCalculationPolicy, get_one, get_one_or_create
+from .constants import DataSourceConstants, EditionConstants, LinkRelations, MediaTypes
+from .contributor import Contribution, Contributor
+from .coverage import CoverageRecord
+from .datasource import DataSource
+from .identifier import Identifier
+from .licensing import DeliveryMechanism, LicensePool
+
class Edition(Base, EditionConstants):
@@ -55,10 +29,10 @@ class Edition(Base, EditionConstants):
as a "book" with a "title" it can go in here.
"""
- __tablename__ = 'editions'
+ __tablename__ = "editions"
id = Column(Integer, primary_key=True)
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
MAX_THUMBNAIL_HEIGHT = 300
MAX_THUMBNAIL_WIDTH = 200
@@ -71,8 +45,7 @@ class Edition(Base, EditionConstants):
# identifier--the one used by its data source to identify
# it. Through the Equivalency class, it is associated with a
# (probably huge) number of other identifiers.
- primary_identifier_id = Column(
- Integer, ForeignKey('identifiers.id'), index=True)
+ primary_identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
# An Edition may be the presentation edition for a single Work. If it's not
# a presentation edition for a work, work will be None.
@@ -82,9 +55,7 @@ class Edition(Base, EditionConstants):
custom_list_entries = relationship("CustomListEntry", backref="edition")
# An Edition may be the presentation edition for many LicensePools.
- is_presentation_for = relationship(
- "LicensePool", backref="presentation_edition"
- )
+ is_presentation_for = relationship("LicensePool", backref="presentation_edition")
title = Column(Unicode, index=True)
sort_title = Column(Unicode, index=True)
@@ -119,9 +90,10 @@ class Edition(Base, EditionConstants):
medium = Column(MEDIUM_ENUM, index=True)
cover_id = Column(
- Integer, ForeignKey(
- 'resources.id', use_alter=True, name='fk_editions_summary_id'),
- index=True)
+ Integer,
+ ForeignKey("resources.id", use_alter=True, name="fk_editions_summary_id"),
+ index=True,
+ )
# These two let us avoid actually loading up the cover Resource
# every time.
cover_full_url = Column(Unicode)
@@ -137,9 +109,11 @@ class Edition(Base, EditionConstants):
def __repr__(self):
id_repr = repr(self.primary_identifier)
return "Edition %s [%r] (%s/%s/%s)" % (
- self.id, id_repr, self.title,
+ self.id,
+ id_repr,
+ self.title,
", ".join([x.sort_name for x in self.contributors]),
- self.language
+ self.language,
)
@property
@@ -185,10 +159,12 @@ def author_contributors(self):
primary_author = x.contributor
elif x.role in Contributor.AUTHOR_ROLES:
other_authors.append(x.contributor)
- elif x.role.lower().startswith('author and'):
+ elif x.role.lower().startswith("author and"):
other_authors.append(x.contributor)
- elif (x.role in Contributor.AUTHOR_SUBSTITUTE_ROLES
- or x.role in Contributor.PERFORMER_ROLES):
+ elif (
+ x.role in Contributor.AUTHOR_SUBSTITUTE_ROLES
+ or x.role in Contributor.PERFORMER_ROLES
+ ):
l = acceptable_substitutes[x.role]
if x.contributor not in l:
l.append(x.contributor)
@@ -207,15 +183,14 @@ def dedupe(l):
return deduped
if primary_author:
- return dedupe([primary_author] + sorted(other_authors, key=lambda x: x.sort_name))
+ return dedupe(
+ [primary_author] + sorted(other_authors, key=lambda x: x.sort_name)
+ )
if other_authors:
return dedupe(other_authors)
- for role in (
- Contributor.AUTHOR_SUBSTITUTE_ROLES
- + Contributor.PERFORMER_ROLES
- ):
+ for role in Contributor.AUTHOR_SUBSTITUTE_ROLES + Contributor.PERFORMER_ROLES:
if role in acceptable_substitutes:
contributors = acceptable_substitutes[role]
return dedupe(sorted(contributors, key=lambda x: x.sort_name))
@@ -226,7 +201,6 @@ def dedupe(l):
# want to put them down as 'author'.
return []
-
@classmethod
def medium_from_media_type(cls, media_type):
"""Derive a value for Edition.medium from a media type.
@@ -254,9 +228,9 @@ def medium_from_media_type(cls, media_type):
return None
@classmethod
- def for_foreign_id(cls, _db, data_source,
- foreign_id_type, foreign_id,
- create_if_not_exists=True):
+ def for_foreign_id(
+ cls, _db, data_source, foreign_id_type, foreign_id, create_if_not_exists=True
+ ):
"""Find the Edition representing the given data source's view of
the work that it primarily identifies by foreign ID.
e.g. for_foreign_id(_db, DataSource.OVERDRIVE, Identifier.OVERDRIVE_ID, uuid)
@@ -273,8 +247,7 @@ def for_foreign_id(cls, _db, data_source,
if isinstance(data_source, (bytes, str)):
data_source = DataSource.lookup(_db, data_source)
- identifier, ignore = Identifier.for_foreign_id(
- _db, foreign_id_type, foreign_id)
+ identifier, ignore = Identifier.for_foreign_id(_db, foreign_id_type, foreign_id)
# Combine the two to get/create a Edition.
if create_if_not_exists:
@@ -283,9 +256,13 @@ def for_foreign_id(cls, _db, data_source,
else:
f = get_one
kwargs = dict()
- r = f(_db, Edition, data_source=data_source,
- primary_identifier=identifier,
- **kwargs)
+ r = f(
+ _db,
+ Edition,
+ data_source=data_source,
+ primary_identifier=identifier,
+ **kwargs
+ )
return r
@property
@@ -294,9 +271,14 @@ def license_pools(self):
by this Edition.
"""
_db = Session.object_session(self)
- return _db.query(LicensePool).filter(
- LicensePool.data_source==self.data_source,
- LicensePool.identifier==self.primary_identifier).all()
+ return (
+ _db.query(LicensePool)
+ .filter(
+ LicensePool.data_source == self.data_source,
+ LicensePool.identifier == self.primary_identifier,
+ )
+ .all()
+ )
def equivalent_identifiers(self, type=None, policy=None):
"""All Identifiers equivalent to this
@@ -307,13 +289,12 @@ def equivalent_identifiers(self, type=None, policy=None):
identifier_id_subquery = Identifier.recursively_equivalent_identifier_ids_query(
self.primary_identifier.id, policy=policy
)
- q = _db.query(Identifier).filter(
- Identifier.id.in_(identifier_id_subquery))
+ q = _db.query(Identifier).filter(Identifier.id.in_(identifier_id_subquery))
if type:
if isinstance(type, list):
q = q.filter(Identifier.type.in_(type))
else:
- q = q.filter(Identifier.type==type)
+ q = q.filter(Identifier.type == type)
return q.all()
def equivalent_editions(self, policy=None):
@@ -325,12 +306,12 @@ def equivalent_editions(self, policy=None):
self.primary_identifier.id, policy=policy
)
return _db.query(Edition).filter(
- Edition.primary_identifier_id.in_(identifier_id_subquery))
+ Edition.primary_identifier_id.in_(identifier_id_subquery)
+ )
@classmethod
def missing_coverage_from(
- cls, _db, edition_data_sources, coverage_data_source,
- operation=None
+ cls, _db, edition_data_sources, coverage_data_source, operation=None
):
"""Find Editions from `edition_data_source` whose primary
identifiers have no CoverageRecord from
@@ -339,7 +320,7 @@ def missing_coverage_from(
gutenberg = DataSource.lookup(_db, DataSource.GUTENBERG)
oclc_classify = DataSource.lookup(_db, DataSource.OCLC)
missing_coverage_from(_db, gutenberg, oclc_classify)
-
+
will find Editions that came from Project Gutenberg and
have never been used as input to the OCLC Classify web
service.
@@ -348,16 +329,15 @@ def missing_coverage_from(
edition_data_sources = [edition_data_sources]
edition_data_source_ids = [x.id for x in edition_data_sources]
join_clause = (
- (Edition.primary_identifier_id==CoverageRecord.identifier_id) &
- (CoverageRecord.data_source_id==coverage_data_source.id) &
- (CoverageRecord.operation==operation)
+ (Edition.primary_identifier_id == CoverageRecord.identifier_id)
+ & (CoverageRecord.data_source_id == coverage_data_source.id)
+ & (CoverageRecord.operation == operation)
)
- q = _db.query(Edition).outerjoin(
- CoverageRecord, join_clause)
+ q = _db.query(Edition).outerjoin(CoverageRecord, join_clause)
if edition_data_source_ids:
q = q.filter(Edition.data_source_id.in_(edition_data_source_ids))
- q2 = q.filter(CoverageRecord.id==None)
+ q2 = q.filter(CoverageRecord.id == None)
return q2
@classmethod
@@ -366,6 +346,7 @@ def sort_by_priority(cls, editions, license_source=None):
this LicensePool, in the order they should be used to create a
presentation Edition for the LicensePool.
"""
+
def sort_key(edition):
"""Return a numeric ordering of this edition."""
source = edition.data_source
@@ -390,9 +371,12 @@ def sort_key(edition):
return -1.5
if source.name in DataSourceConstants.PRESENTATION_EDITION_PRIORITY:
- return DataSourceConstants.PRESENTATION_EDITION_PRIORITY.index(source.name)
+ return DataSourceConstants.PRESENTATION_EDITION_PRIORITY.index(
+ source.name
+ )
else:
return -2
+
return sorted(editions, key=sort_key)
@classmethod
@@ -418,8 +402,10 @@ def set_cover(self, resource):
# versions of this representation and we need some way of
# choosing between them. Right now we just pick the first one
# that works.
- if (resource.representation.image_height
- and resource.representation.image_height <= self.MAX_THUMBNAIL_HEIGHT):
+ if (
+ resource.representation.image_height
+ and resource.representation.image_height <= self.MAX_THUMBNAIL_HEIGHT
+ ):
# This image doesn't need a thumbnail.
self.cover_thumbnail_url = resource.representation.public_url
else:
@@ -427,21 +413,25 @@ def set_cover(self, resource):
best_thumbnail = resource.representation.best_thumbnail
if best_thumbnail:
self.cover_thumbnail_url = best_thumbnail.public_url
- if (not self.cover_thumbnail_url and
- resource.representation.image_height
- and resource.representation.image_height <= self.MAX_FALLBACK_THUMBNAIL_HEIGHT):
+ if (
+ not self.cover_thumbnail_url
+ and resource.representation.image_height
+ and resource.representation.image_height
+ <= self.MAX_FALLBACK_THUMBNAIL_HEIGHT
+ ):
# The full-sized image is too large to be a thumbnail, but it's
# not huge, and there is no other thumbnail, so use it.
self.cover_thumbnail_url = resource.representation.public_url
if old_cover != self.cover or old_cover_full_url != self.cover_full_url:
logging.debug(
"Setting cover for %s/%s: full=%s thumb=%s",
- self.primary_identifier.type, self.primary_identifier.identifier,
- self.cover_full_url, self.cover_thumbnail_url
+ self.primary_identifier.type,
+ self.primary_identifier.identifier,
+ self.cover_full_url,
+ self.cover_thumbnail_url,
)
- def add_contributor(self, name, roles, aliases=None, lc=None, viaf=None,
- **kwargs):
+ def add_contributor(self, name, roles, aliases=None, lc=None, viaf=None, **kwargs):
"""Assign a contributor to this Edition."""
_db = Session.object_session(self)
if isinstance(roles, (bytes, str)):
@@ -451,8 +441,7 @@ def add_contributor(self, name, roles, aliases=None, lc=None, viaf=None,
if isinstance(name, Contributor):
contributor = name
else:
- contributor, was_new = Contributor.lookup(
- _db, name, lc, viaf, aliases)
+ contributor, was_new = Contributor.lookup(_db, name, lc, viaf, aliases)
if isinstance(contributor, list):
# Contributor was looked up/created by name,
# which returns a list.
@@ -461,8 +450,8 @@ def add_contributor(self, name, roles, aliases=None, lc=None, viaf=None,
# Then add their Contributions.
for role in roles:
contribution, was_new = get_one_or_create(
- _db, Contribution, edition=self, contributor=contributor,
- role=role)
+ _db, Contribution, edition=self, contributor=contributor, role=role
+ )
return contributor
def similarity_to(self, other_record):
@@ -506,16 +495,18 @@ def similarity_to(self, other_record):
# English, the penalty will be less if one of the
# languages is English. It's more likely that an unlabeled
# record is in English than that it's in some other language.
- if self.language == 'eng' or other_record.language == 'eng':
+ if self.language == "eng" or other_record.language == "eng":
language_factor = 0.80
else:
language_factor = 0.50
title_quotient = MetadataSimilarity.title_similarity(
- self.title, other_record.title)
+ self.title, other_record.title
+ )
author_quotient = MetadataSimilarity.author_similarity(
- self.author_contributors, other_record.author_contributors)
+ self.author_contributors, other_record.author_contributors
+ )
if author_quotient == 0:
# The two works have no authors in common. Immediate
# disqualification.
@@ -524,8 +515,7 @@ def similarity_to(self, other_record):
# We weight title more heavily because it's much more likely
# that one author wrote two different books than that two
# books with the same title have different authors.
- return language_factor * (
- (title_quotient * 0.80) + (author_quotient * 0.20))
+ return language_factor * ((title_quotient * 0.80) + (author_quotient * 0.20))
def apply_similarity_threshold(self, candidates, threshold=0.5):
"""Yield the Editions from the given list that are similar
@@ -542,7 +532,7 @@ def apply_similarity_threshold(self, candidates, threshold=0.5):
def best_cover_within_distance(self, distance, rel=None, policy=None):
_db = Session.object_session(self)
identifier_ids = [self.primary_identifier.id]
-
+
if distance > 0:
if policy is None:
new_policy = PresentationCalculationPolicy()
@@ -564,7 +554,7 @@ def best_cover_within_distance(self, distance, rel=None, policy=None):
def title_for_permanent_work_id(self):
title = self.title
if self.subtitle:
- title += (": " + self.subtitle)
+ title += ": " + self.subtitle
return title
@property
@@ -596,11 +586,18 @@ def calculate_permanent_work_id(self, debug=False):
old_id = self.permanent_work_id
self.permanent_work_id = self.calculate_permanent_work_id_for_title_and_author(
- title, author, medium)
+ title, author, medium
+ )
args = (
"Permanent work ID for %d: %s/%s -> %s/%s/%s -> %s (was %s)",
- self.id, title, author, norm_title, norm_author, medium,
- self.permanent_work_id, old_id
+ self.id,
+ title,
+ author,
+ norm_title,
+ norm_author,
+ medium,
+ self.permanent_work_id,
+ old_id,
)
if debug:
logging.debug(*args)
@@ -608,19 +605,15 @@ def calculate_permanent_work_id(self, debug=False):
logging.info(*args)
@classmethod
- def calculate_permanent_work_id_for_title_and_author(
- cls, title, author, medium):
+ def calculate_permanent_work_id_for_title_and_author(cls, title, author, medium):
w = WorkIDCalculator
norm_title = w.normalize_title(title)
norm_author = w.normalize_author(author)
- return WorkIDCalculator.permanent_id(
- norm_title, norm_author, medium)
+ return WorkIDCalculator.permanent_id(norm_title, norm_author, medium)
UNKNOWN_AUTHOR = "[Unknown]"
-
-
def calculate_presentation(self, policy=None):
"""Make sure the presentation of this Edition is up-to-date."""
_db = Session.object_session(self)
@@ -643,14 +636,16 @@ def calculate_presentation(self, policy=None):
self.sort_title = TitleProcessor.sort_title_for(self.title)
self.calculate_permanent_work_id()
CoverageRecord.add_for(
- self, data_source=self.data_source,
- operation=CoverageRecord.SET_EDITION_METADATA_OPERATION
+ self,
+ data_source=self.data_source,
+ operation=CoverageRecord.SET_EDITION_METADATA_OPERATION,
)
if policy.choose_cover:
self.choose_cover(policy=policy)
- if (self.author != old_author
+ if (
+ self.author != old_author
or self.sort_author != old_sort_author
or self.sort_title != old_sort_title
or self.permanent_work_id != old_work_id
@@ -670,9 +665,15 @@ def calculate_presentation(self, policy=None):
level = logging.debug
msg = "Presentation %s for Edition %s (by %s, pub=%s, ident=%s/%s, pwid=%s, language=%s, cover=%r)"
- args = [changed_status, self.title, self.author, self.publisher,
- self.primary_identifier.type, self.primary_identifier.identifier,
- self.permanent_work_id, self.language
+ args = [
+ changed_status,
+ self.title,
+ self.author,
+ self.publisher,
+ self.primary_identifier.type,
+ self.primary_identifier.identifier,
+ self.permanent_work_id,
+ self.language,
]
if self.cover and self.cover.representation:
args.append(self.cover.representation.public_url)
@@ -729,7 +730,7 @@ def choose_cover(self, policy=None):
logging.warn(
"Best cover for %r (%s) was never thumbnailed!",
self.primary_identifier,
- rep.public_url
+ rep.public_url,
)
self.set_cover(best_cover)
break
@@ -753,7 +754,8 @@ def choose_cover(self, policy=None):
# look for a cover.
for distance in (0, 5):
best_thumbnail, thumbnails = self.best_cover_within_distance(
- distance=distance, policy=policy,
+ distance=distance,
+ policy=policy,
rel=LinkRelations.THUMBNAIL_IMAGE,
)
if best_thumbnail:
@@ -776,8 +778,15 @@ def choose_cover(self, policy=None):
# Whether or not we succeeded in setting the cover,
# record the fact that we tried.
CoverageRecord.add_for(
- self, data_source=self.data_source,
- operation=CoverageRecord.CHOOSE_COVER_OPERATION
+ self,
+ data_source=self.data_source,
+ operation=CoverageRecord.CHOOSE_COVER_OPERATION,
)
-Index("ix_editions_data_source_id_identifier_id", Edition.data_source_id, Edition.primary_identifier_id, unique=True)
+
+Index(
+ "ix_editions_data_source_id_identifier_id",
+ Edition.data_source_id,
+ Edition.primary_identifier_id,
+ unique=True,
+)
diff --git a/model/hasfulltablecache.py b/model/hasfulltablecache.py
index 5ecf0fc09..16740ddbe 100644
--- a/model/hasfulltablecache.py
+++ b/model/hasfulltablecache.py
@@ -1,9 +1,10 @@
# encoding: utf-8
# HasFullTableCache
+import logging
+
from . import get_one
-import logging
class HasFullTableCache(object):
"""A mixin class for ORM classes that maintain an in-memory cache of
@@ -108,7 +109,8 @@ def _cache_lookup(cls, _db, cache, cache_name, cache_key, lookup_hook):
except Exception as e:
logging.error(
"Unable to merge cached object %r into database session",
- obj, exc_info=e
+ obj,
+ exc_info=e,
)
# Try to look up a fresh copy of the object.
obj, new = lookup_hook()
@@ -124,15 +126,15 @@ def _cache_lookup(cls, _db, cache, cache_name, cache_key, lookup_hook):
@classmethod
def by_id(cls, _db, id):
"""Look up an item by its unique database ID."""
+
def lookup_hook():
return get_one(_db, cls, id=id), False
+
obj, is_new = cls._cache_lookup(
- _db, cls._id_cache, '_id_cache', id, lookup_hook
+ _db, cls._id_cache, "_id_cache", id, lookup_hook
)
return obj
@classmethod
def by_cache_key(cls, _db, cache_key, lookup_hook):
- return cls._cache_lookup(
- _db, cls._cache, '_cache', cache_key, lookup_hook
- )
+ return cls._cache_lookup(_db, cls._cache, "_cache", cache_key, lookup_hook)
diff --git a/model/identifier.py b/model/identifier.py
index 3a1965a53..f7898f277 100644
--- a/model/identifier.py
+++ b/model/identifier.py
@@ -2,10 +2,11 @@
# Identifier, Equivalency
import logging
import random
-from urllib.parse import quote, unquote
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from functools import total_ordering
+from urllib.parse import quote, unquote
+
import isbnlib
from sqlalchemy import (
Boolean,
@@ -23,15 +24,15 @@
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import and_, or_
+from ..util.datetime_helpers import utc_now
+from ..util.summary import SummaryEvaluator
+from . import Base, PresentationCalculationPolicy, create, get_one, get_one_or_create
from .classification import Classification, Subject
from .constants import IdentifierConstants, LinkRelations
from .coverage import CoverageRecord
from .datasource import DataSource
from .licensing import LicensePoolDeliveryMechanism, RightsStatus
from .measurement import Measurement
-from . import Base, PresentationCalculationPolicy, create, get_one, get_one_or_create
-from ..util.summary import SummaryEvaluator
-from ..util.datetime_helpers import utc_now
class IdentifierParser(metaclass=ABCMeta):
@@ -53,10 +54,9 @@ def parse(self, identifier_string):
@total_ordering
class Identifier(Base, IdentifierConstants):
- """A way of uniquely referring to a particular edition.
- """
+ """A way of uniquely referring to a particular edition."""
- __tablename__ = 'identifiers'
+ __tablename__ = "identifiers"
id = Column(Integer, primary_key=True)
type = Column(String(64), index=True)
identifier = Column(String, index=True)
@@ -64,13 +64,15 @@ class Identifier(Base, IdentifierConstants):
equivalencies = relationship(
"Equivalency",
primaryjoin=("Identifier.id==Equivalency.input_id"),
- backref="input_identifiers", cascade="all, delete-orphan"
+ backref="input_identifiers",
+ cascade="all, delete-orphan",
)
inbound_equivalencies = relationship(
"Equivalency",
primaryjoin=("Identifier.id==Equivalency.output_id"),
- backref="output_identifiers", cascade="all, delete-orphan"
+ backref="output_identifiers",
+ cascade="all, delete-orphan",
)
# One Identifier may have many associated CoverageRecords.
@@ -86,46 +88,37 @@ def __repr__(self):
# One Identifier may serve as the primary identifier for
# several Editions.
- primarily_identifies = relationship(
- "Edition", backref="primary_identifier"
- )
+ primarily_identifies = relationship("Edition", backref="primary_identifier")
# One Identifier may serve as the identifier for many
# LicensePools, through different Collections.
licensed_through = relationship(
- "LicensePool", backref="identifier", lazy='joined',
+ "LicensePool",
+ backref="identifier",
+ lazy="joined",
)
# One Identifier may have many Links.
- links = relationship(
- "Hyperlink", backref="identifier"
- )
+ links = relationship("Hyperlink", backref="identifier")
# One Identifier may be the subject of many Measurements.
- measurements = relationship(
- "Measurement", backref="identifier"
- )
+ measurements = relationship("Measurement", backref="identifier")
# One Identifier may participate in many Classifications.
- classifications = relationship(
- "Classification", backref="identifier"
- )
+ classifications = relationship("Classification", backref="identifier")
# One identifier may participate in many Annotations.
- annotations = relationship(
- "Annotation", backref="identifier"
- )
+ annotations = relationship("Annotation", backref="identifier")
# One Identifier can have have many LicensePoolDeliveryMechanisms.
delivery_mechanisms = relationship(
- "LicensePoolDeliveryMechanism", backref="identifier",
- foreign_keys=lambda: [LicensePoolDeliveryMechanism.identifier_id]
+ "LicensePoolDeliveryMechanism",
+ backref="identifier",
+ foreign_keys=lambda: [LicensePoolDeliveryMechanism.identifier_id],
)
# Type + identifier is unique.
- __table_args__ = (
- UniqueConstraint('type', 'identifier'),
- )
+ __table_args__ = (UniqueConstraint("type", "identifier"),)
@classmethod
def from_asin(cls, _db, asin, autocreate=True):
@@ -145,8 +138,7 @@ def from_asin(cls, _db, asin, autocreate=True):
return cls.for_foreign_id(_db, type, asin, autocreate)
@classmethod
- def for_foreign_id(cls, _db, foreign_identifier_type, foreign_id,
- autocreate=True):
+ def for_foreign_id(cls, _db, foreign_identifier_type, foreign_id, autocreate=True):
"""Turn a foreign ID into an Identifier."""
foreign_identifier_type, foreign_id = cls.prepare_foreign_type_and_identifier(
foreign_identifier_type, foreign_id
@@ -159,8 +151,7 @@ def for_foreign_id(cls, _db, foreign_identifier_type, foreign_id,
else:
m = get_one
- result = m(_db, cls, type=foreign_identifier_type,
- identifier=foreign_id)
+ result = m(_db, cls, type=foreign_identifier_type, identifier=foreign_id)
if isinstance(result, tuple):
return result
@@ -180,9 +171,9 @@ def prepare_foreign_type_and_identifier(cls, foreign_type, foreign_identifier):
foreign_identifier = foreign_identifier.lower()
if not cls.valid_as_foreign_identifier(foreign_type, foreign_identifier):
- raise ValueError('"%s" is not a valid %s.' % (
- foreign_identifier, foreign_type
- ))
+ raise ValueError(
+ '"%s" is not a valid %s.' % (foreign_identifier, foreign_type)
+ )
return (foreign_type, foreign_identifier)
@@ -197,16 +188,16 @@ def valid_as_foreign_identifier(cls, type, id):
currently don't enforce that). We only reject an ID out of
hand if it will cause problems with a third-party API.
"""
- forbidden_characters = ''
+ forbidden_characters = ""
if type == Identifier.BIBLIOTHECA_ID:
# IDs are joined with commas and provided as a URL path
# element. Embedded commas or slashes will confuse the
# Bibliotheca API.
- forbidden_characters = ',/'
+ forbidden_characters = ",/"
elif type == Identifier.AXIS_360_ID:
# IDs are joined with commas during a lookup. Embedded
# commas will confuse the Axis 360 API.
- forbidden_characters = ','
+ forbidden_characters = ","
if any(x in id for x in forbidden_characters):
return False
return True
@@ -222,8 +213,7 @@ def urn(self):
return self.GUTENBERG_URN_SCHEME_PREFIX + identifier_text
else:
identifier_type = quote(self.type)
- return self.URN_SCHEME_PREFIX + "%s/%s" % (
- identifier_type, identifier_text)
+ return self.URN_SCHEME_PREFIX + "%s/%s" % (identifier_type, identifier_text)
@property
def work(self):
@@ -248,19 +238,26 @@ def type_and_identifier_for_urn(cls, identifier_string):
if m:
type = Identifier.GUTENBERG_ID
identifier_string = m.groups()[0]
- elif identifier_string.startswith("http:") or identifier_string.startswith("https:"):
+ elif identifier_string.startswith("http:") or identifier_string.startswith(
+ "https:"
+ ):
type = Identifier.URI
elif identifier_string.startswith(Identifier.URN_SCHEME_PREFIX):
- identifier_string = identifier_string[len(Identifier.URN_SCHEME_PREFIX):]
- type, identifier_string = list(map(
- unquote, identifier_string.split("/", 1)))
+ identifier_string = identifier_string[len(Identifier.URN_SCHEME_PREFIX) :]
+ type, identifier_string = list(
+ map(unquote, identifier_string.split("/", 1))
+ )
elif identifier_string.lower().startswith(Identifier.ISBN_URN_SCHEME_PREFIX):
type = Identifier.ISBN
- identifier_string = identifier_string[len(Identifier.ISBN_URN_SCHEME_PREFIX):]
+ identifier_string = identifier_string[
+ len(Identifier.ISBN_URN_SCHEME_PREFIX) :
+ ]
identifier_string = unquote(identifier_string)
# Make sure this is a valid ISBN, and convert it to an ISBN-13.
- if not (isbnlib.is_isbn10(identifier_string) or
- isbnlib.is_isbn13(identifier_string)):
+ if not (
+ isbnlib.is_isbn10(identifier_string)
+ or isbnlib.is_isbn13(identifier_string)
+ ):
raise ValueError("%s is not a valid ISBN." % identifier_string)
if isbnlib.is_isbn10(identifier_string):
identifier_string = isbnlib.to_isbn13(identifier_string)
@@ -268,13 +265,12 @@ def type_and_identifier_for_urn(cls, identifier_string):
type = Identifier.URI
else:
raise ValueError(
- "Could not turn %s into a recognized identifier." %
- identifier_string)
+ "Could not turn %s into a recognized identifier." % identifier_string
+ )
return (type, identifier_string)
@classmethod
- def parse_urns(cls, _db, identifier_strings, autocreate=True,
- allowed_types=None):
+ def parse_urns(cls, _db, identifier_strings, autocreate=True, allowed_types=None):
"""Converts a batch of URNs into Identifier objects.
:param _db: A database connection
@@ -300,8 +296,11 @@ def parse_urns(cls, _db, identifier_strings, autocreate=True,
(type, identifier) = cls.prepare_foreign_type_and_identifier(
*cls.type_and_identifier_for_urn(urn)
)
- if (type and identifier and
- (allowed_types is None or type in allowed_types)):
+ if (
+ type
+ and identifier
+ and (allowed_types is None or type in allowed_types)
+ ):
identifier_details[urn] = (type, identifier)
else:
failures.append(urn)
@@ -309,14 +308,13 @@ def parse_urns(cls, _db, identifier_strings, autocreate=True,
failures.append(urn)
identifiers_by_urn = dict()
+
def find_existing_identifiers(identifier_details):
if not identifier_details:
return
and_clauses = list()
for type, identifier in identifier_details:
- and_clauses.append(
- and_(cls.type==type, cls.identifier==identifier)
- )
+ and_clauses.append(and_(cls.type == type, cls.identifier == identifier))
identifiers = _db.query(cls).filter(or_(*and_clauses)).all()
for identifier in identifiers:
@@ -327,9 +325,12 @@ def find_existing_identifiers(identifier_details):
# Remove the existing identifiers from the identifier_details list,
# regardless of whether the provided URN was accurate.
- existing_details = [(i.type, i.identifier) for i in list(identifiers_by_urn.values())]
+ existing_details = [
+ (i.type, i.identifier) for i in list(identifiers_by_urn.values())
+ ]
identifier_details = {
- k: v for k, v in list(identifier_details.items())
+ k: v
+ for k, v in list(identifier_details.items())
if v not in existing_details and k not in list(identifiers_by_urn.keys())
}
@@ -360,7 +361,9 @@ def find_existing_identifiers(identifier_details):
return identifiers_by_urn, failures
@classmethod
- def _parse_urn(cls, _db, identifier_string, identifier_type, must_support_license_pools=False):
+ def _parse_urn(
+ cls, _db, identifier_string, identifier_type, must_support_license_pools=False
+ ):
"""Parse identifier string.
:param _db: Database session
@@ -407,9 +410,13 @@ def parse_urn(cls, _db, identifier_string, must_support_license_pools=False):
:return: 2-tuple containing Identifier object and a boolean value indicating whether it's new
:rtype: Tuple[core.model.identifier.Identifier, bool]
"""
- identifier_type, identifier_string = cls.type_and_identifier_for_urn(identifier_string)
+ identifier_type, identifier_string = cls.type_and_identifier_for_urn(
+ identifier_string
+ )
- return cls._parse_urn(_db, identifier_string, identifier_type, must_support_license_pools)
+ return cls._parse_urn(
+ _db, identifier_string, identifier_type, must_support_license_pools
+ )
@classmethod
def parse(cls, _db, identifier_string, parser, must_support_license_pools=False):
@@ -433,7 +440,9 @@ def parse(cls, _db, identifier_string, parser, must_support_license_pools=False)
"""
identifier_type, identifier_string = parser.parse(identifier_string)
- return cls._parse_urn(_db, identifier_string, identifier_type, must_support_license_pools)
+ return cls._parse_urn(
+ _db, identifier_string, identifier_type, must_support_license_pools
+ )
def equivalent_to(self, data_source, identifier, strength):
"""Make one Identifier equivalent to another.
@@ -446,23 +455,24 @@ def equivalent_to(self, data_source, identifier, strength):
# Do nothing.
return None
eq, new = get_one_or_create(
- _db, Equivalency,
+ _db,
+ Equivalency,
data_source=data_source,
input=self,
output=identifier,
- on_multiple='interchangeable'
+ on_multiple="interchangeable",
)
- eq.strength=strength
+ eq.strength = strength
if new:
logging.info(
- "Identifier equivalency: %r==%r p=%.2f", self, identifier,
- strength
+ "Identifier equivalency: %r==%r p=%.2f", self, identifier, strength
)
return eq
@classmethod
def recursively_equivalent_identifier_ids_query(
- cls, identifier_id_column, policy=None):
+ cls, identifier_id_column, policy=None
+ ):
"""Get a SQL statement that will return all Identifier IDs
equivalent to a given ID at the given confidence threshold.
`identifier_id_column` can be a single Identifier ID, or a column
@@ -489,8 +499,7 @@ def _recursively_equivalent_identifier_ids_query(
)
@classmethod
- def recursively_equivalent_identifier_ids(
- cls, _db, identifier_ids, policy=None):
+ def recursively_equivalent_identifier_ids(cls, _db, identifier_ids, policy=None):
"""All Identifier IDs equivalent to the given set of Identifier
IDs at the given confidence threshold.
This uses the function defined in files/recursive_equivalents.sql.
@@ -503,9 +512,7 @@ def recursively_equivalent_identifier_ids(
how you've chosen to make the tradeoff between performance,
data quality, and sheer number of equivalent identifiers.
"""
- fn = cls._recursively_equivalent_identifier_ids_query(
- Identifier.id, policy
- )
+ fn = cls._recursively_equivalent_identifier_ids_query(Identifier.id, policy)
query = select([Identifier.id, fn], Identifier.id.in_(identifier_ids))
results = _db.execute(query)
equivalents = defaultdict(list)
@@ -517,9 +524,7 @@ def recursively_equivalent_identifier_ids(
def equivalent_identifier_ids(self, policy=None):
_db = Session.object_session(self)
- return Identifier.recursively_equivalent_identifier_ids(
- _db, [self.id], policy
- )
+ return Identifier.recursively_equivalent_identifier_ids(_db, [self.id], policy)
def licensed_through_collection(self, collection):
"""Find the LicensePool, if any, for this Identifier
@@ -530,9 +535,19 @@ def licensed_through_collection(self, collection):
if lp.collection == collection:
return lp
- def add_link(self, rel, href, data_source, media_type=None, content=None,
- content_path=None, rights_status_uri=None, rights_explanation=None,
- original_resource=None, transformation_settings=None):
+ def add_link(
+ self,
+ rel,
+ href,
+ data_source,
+ media_type=None,
+ content=None,
+ content_path=None,
+ rights_status_uri=None,
+ rights_explanation=None,
+ original_resource=None,
+ transformation_settings=None,
+ ):
"""Create a link between this Identifier and a (potentially new)
Resource.
TODO: There's some code in metadata_layer for automatically
@@ -540,6 +555,7 @@ def add_link(self, rel, href, data_source, media_type=None, content=None,
created. It might be good to move that code into here.
"""
from .resource import Hyperlink, Representation, Resource
+
_db = Session.object_session(self)
# Find or create the Resource.
@@ -549,22 +565,30 @@ def add_link(self, rel, href, data_source, media_type=None, content=None,
if rights_status_uri:
rights_status = RightsStatus.lookup(_db, rights_status_uri)
resource, new_resource = get_one_or_create(
- _db, Resource, url=href,
- create_method_kwargs=dict(data_source=data_source,
- rights_status=rights_status,
- rights_explanation=rights_explanation)
+ _db,
+ Resource,
+ url=href,
+ create_method_kwargs=dict(
+ data_source=data_source,
+ rights_status=rights_status,
+ rights_explanation=rights_explanation,
+ ),
)
# Find or create the Hyperlink.
link, new_link = get_one_or_create(
- _db, Hyperlink, rel=rel, data_source=data_source,
- identifier=self, resource=resource,
+ _db,
+ Hyperlink,
+ rel=rel,
+ data_source=data_source,
+ identifier=self,
+ resource=resource,
)
if content or content_path:
# We have content for this resource.
resource.set_fetched_content(media_type, content, content_path)
- elif (media_type and not resource.representation):
+ elif media_type and not resource.representation:
# We know the type of the resource, so make a
# Representation for it.
resource.representation, is_new = get_one_or_create(
@@ -578,24 +602,33 @@ def add_link(self, rel, href, data_source, media_type=None, content=None,
# wanted to.
return link, new_link
- def add_measurement(self, data_source, quantity_measured, value,
- weight=1, taken_at=None):
+ def add_measurement(
+ self, data_source, quantity_measured, value, weight=1, taken_at=None
+ ):
"""Associate a new Measurement with this Identifier."""
_db = Session.object_session(self)
logging.debug(
"MEASUREMENT: %s on %s/%s: %s == %s (wt=%d)",
- data_source.name, self.type, self.identifier,
- quantity_measured, value, weight)
+ data_source.name,
+ self.type,
+ self.identifier,
+ quantity_measured,
+ value,
+ weight,
+ )
now = utc_now()
taken_at = taken_at or now
# Is there an existing most recent measurement?
most_recent = get_one(
- _db, Measurement, identifier=self,
+ _db,
+ Measurement,
+ identifier=self,
data_source=data_source,
quantity_measured=quantity_measured,
- is_most_recent=True, on_multiple='interchangeable'
+ is_most_recent=True,
+ on_multiple="interchangeable",
)
if most_recent and most_recent.value == value and taken_at == now:
# The value hasn't changed since last time. Just update
@@ -606,13 +639,20 @@ def add_measurement(self, data_source, quantity_measured, value,
most_recent.is_most_recent = False
return create(
- _db, Measurement,
- identifier=self, data_source=data_source,
- quantity_measured=quantity_measured, taken_at=taken_at,
- value=value, weight=weight, is_most_recent=True)[0]
-
- def classify(self, data_source, subject_type, subject_identifier,
- subject_name=None, weight=1):
+ _db,
+ Measurement,
+ identifier=self,
+ data_source=data_source,
+ quantity_measured=quantity_measured,
+ taken_at=taken_at,
+ value=value,
+ weight=weight,
+ is_most_recent=True,
+ )[0]
+
+ def classify(
+ self, data_source, subject_type, subject_identifier, subject_name=None, weight=1
+ ):
"""Classify this Identifier under a Subject.
:param type: Classification scheme; one of the constants from Subject.
@@ -628,30 +668,40 @@ def classify(self, data_source, subject_type, subject_identifier,
# Turn the subject type and identifier into a Subject.
classifications = []
subject, is_new = Subject.lookup(
- _db, subject_type, subject_identifier, subject_name,
+ _db,
+ subject_type,
+ subject_identifier,
+ subject_name,
)
logging.debug(
"CLASSIFICATION: %s on %s/%s: %s %s/%s (wt=%d)",
- data_source.name, self.type, self.identifier,
- subject.type, subject.identifier, subject.name,
- weight
+ data_source.name,
+ self.type,
+ self.identifier,
+ subject.type,
+ subject.identifier,
+ subject.name,
+ weight,
)
# Use a Classification to connect the Identifier to the
# Subject.
try:
classification, is_new = get_one_or_create(
- _db, Classification,
+ _db,
+ Classification,
identifier=self,
subject=subject,
- data_source=data_source)
+ data_source=data_source,
+ )
except MultipleResultsFound as e:
# TODO: This is a hack.
all_classifications = _db.query(Classification).filter(
- Classification.identifier==self,
- Classification.subject==subject,
- Classification.data_source==data_source)
+ Classification.identifier == self,
+ Classification.subject == subject,
+ Classification.data_source == data_source,
+ )
all_classifications = all_classifications.all()
classification = all_classifications[0]
for i in all_classifications[1:]:
@@ -661,37 +711,45 @@ def classify(self, data_source, subject_type, subject_identifier,
return classification
@classmethod
- def resources_for_identifier_ids(self, _db, identifier_ids, rel=None,
- data_source=None):
+ def resources_for_identifier_ids(
+ self, _db, identifier_ids, rel=None, data_source=None
+ ):
from .resource import Hyperlink, Resource
- resources = _db.query(Resource).join(Resource.links).filter(
- Hyperlink.identifier_id.in_(identifier_ids))
+
+ resources = (
+ _db.query(Resource)
+ .join(Resource.links)
+ .filter(Hyperlink.identifier_id.in_(identifier_ids))
+ )
if data_source:
if isinstance(data_source, DataSource):
data_source = [data_source]
- resources = resources.filter(Hyperlink.data_source_id.in_([d.id for d in data_source]))
+ resources = resources.filter(
+ Hyperlink.data_source_id.in_([d.id for d in data_source])
+ )
if rel:
if isinstance(rel, list):
resources = resources.filter(Hyperlink.rel.in_(rel))
else:
- resources = resources.filter(Hyperlink.rel==rel)
- resources = resources.options(joinedload('representation'))
+ resources = resources.filter(Hyperlink.rel == rel)
+ resources = resources.options(joinedload("representation"))
return resources
@classmethod
def classifications_for_identifier_ids(self, _db, identifier_ids):
classifications = _db.query(Classification).filter(
- Classification.identifier_id.in_(identifier_ids))
- return classifications.options(joinedload('subject'))
+ Classification.identifier_id.in_(identifier_ids)
+ )
+ return classifications.options(joinedload("subject"))
@classmethod
def best_cover_for(cls, _db, identifier_ids, rel=None):
# Find all image resources associated with any of
# these identifiers.
from .resource import Hyperlink, Resource
+
rel = rel or Hyperlink.IMAGE
- images = cls.resources_for_identifier_ids(
- _db, identifier_ids, rel)
+ images = cls.resources_for_identifier_ids(_db, identifier_ids, rel)
images = images.join(Resource.representation)
images = images.all()
@@ -706,8 +764,9 @@ def best_cover_for(cls, _db, identifier_ids, rel=None):
return champion, images
@classmethod
- def evaluate_summary_quality(cls, _db, identifier_ids,
- privileged_data_sources=None):
+ def evaluate_summary_quality(
+ cls, _db, identifier_ids, privileged_data_sources=None
+ ):
"""Evaluate the summaries for the given group of Identifier IDs.
This is an automatic evaluation based solely on the content of
the summaries. It will be combined with human-entered ratings
@@ -732,7 +791,8 @@ def evaluate_summary_quality(cls, _db, identifier_ids,
# these records.
rels = [LinkRelations.DESCRIPTION, LinkRelations.SHORT_DESCRIPTION]
descriptions = cls.resources_for_identifier_ids(
- _db, identifier_ids, rels, privileged_data_source).all()
+ _db, identifier_ids, rels, privileged_data_source
+ ).all()
champion = None
# Add each resource's content to the evaluator's corpus.
@@ -753,14 +813,22 @@ def evaluate_summary_quality(cls, _db, identifier_ids,
if privileged_data_source and not champion:
# We could not find any descriptions from the privileged
# data source. Try relaxing that restriction.
- return cls.evaluate_summary_quality(_db, identifier_ids, privileged_data_sources[1:])
+ return cls.evaluate_summary_quality(
+ _db, identifier_ids, privileged_data_sources[1:]
+ )
return champion, descriptions
@classmethod
def missing_coverage_from(
- cls, _db, identifier_types, coverage_data_source, operation=None,
- count_as_covered=None, count_as_missing_before=None, identifiers=None,
- collection=None
+ cls,
+ _db,
+ identifier_types,
+ coverage_data_source,
+ operation=None,
+ count_as_covered=None,
+ count_as_missing_before=None,
+ identifiers=None,
+ collection=None,
):
"""Find identifiers of the given types which have no CoverageRecord
from `coverage_data_source`.
@@ -777,17 +845,16 @@ def missing_coverage_from(
if coverage_data_source:
data_source_id = coverage_data_source.id
- clause = and_(Identifier.id==CoverageRecord.identifier_id,
- CoverageRecord.data_source_id==data_source_id,
- CoverageRecord.operation==operation,
- CoverageRecord.collection_id==collection_id
+ clause = and_(
+ Identifier.id == CoverageRecord.identifier_id,
+ CoverageRecord.data_source_id == data_source_id,
+ CoverageRecord.operation == operation,
+ CoverageRecord.collection_id == collection_id,
)
qu = _db.query(Identifier).outerjoin(CoverageRecord, clause)
if identifier_types:
qu = qu.filter(Identifier.type.in_(identifier_types))
- missing = CoverageRecord.not_covered(
- count_as_covered, count_as_missing_before
- )
+ missing = CoverageRecord.not_covered(count_as_covered, count_as_missing_before)
qu = qu.filter(missing)
if identifiers:
@@ -816,8 +883,9 @@ def opds_entry(self):
resource = link.resource
if link.rel == LinkRelations.IMAGE:
if not cover_image or (
- not cover_image.representation.thumbnails and
- resource.representation.thumbnails):
+ not cover_image.representation.thumbnails
+ and resource.representation.thumbnails
+ ):
cover_image = resource
if cover_image.representation:
# This is technically redundant because
@@ -831,20 +899,23 @@ def opds_entry(self):
description = resource
if self.coverage_records:
- timestamps.extend([
- c.timestamp for c in self.coverage_records if c.timestamp
- ])
+ timestamps.extend(
+ [c.timestamp for c in self.coverage_records if c.timestamp]
+ )
if timestamps:
most_recent_update = max(timestamps)
quality = Measurement.overall_quality(self.measurements)
from ..opds import AcquisitionFeed
+
return AcquisitionFeed.minimal_opds_entry(
- identifier=self, cover=cover_image, description=description,
- quality=quality, most_recent_update=most_recent_update
+ identifier=self,
+ cover=cover_image,
+ description=description,
+ quality=quality,
+ most_recent_update=most_recent_update,
)
-
def __eq__(self, other):
"""Equality implementation for total_ordering."""
# We don't want an Identifier to be == an IdentifierData
@@ -852,7 +923,7 @@ def __eq__(self, other):
if other is None or not isinstance(other, Identifier):
return False
return (self.type, self.identifier) == (other.type, other.identifier)
-
+
def __hash__(self):
return hash((self.type, self.identifier))
@@ -868,18 +939,19 @@ class Equivalency(Base):
This assertion comes with a 'strength' which represents how confident
the data source is in the assertion.
"""
- __tablename__ = 'equivalents'
+
+ __tablename__ = "equivalents"
# 'input' is the ID that was used as input to the datasource.
# 'output' is the output
id = Column(Integer, primary_key=True)
- input_id = Column(Integer, ForeignKey('identifiers.id'), index=True)
+ input_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
input = relationship("Identifier", foreign_keys=input_id)
- output_id = Column(Integer, ForeignKey('identifiers.id'), index=True)
+ output_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
output = relationship("Identifier", foreign_keys=output_id)
# Who says?
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
# How many distinct votes went into this assertion? This will let
# us scale the change to the strength when additional votes come
@@ -901,7 +973,9 @@ def __repr__(self):
r = "[%s ->\n %s\n source=%s strength=%.2f votes=%d)]" % (
repr(self.input).decode("utf8"),
repr(self.output).decode("utf8"),
- self.data_source.name, self.strength, self.votes
+ self.data_source.name,
+ self.strength,
+ self.votes,
)
return r
@@ -912,9 +986,15 @@ def for_identifiers(self, _db, identifiers, exclude_ids=None):
return []
if isinstance(identifiers, list) and isinstance(identifiers[0], Identifier):
identifiers = [x.id for x in identifiers]
- q = _db.query(Equivalency).distinct().filter(
- or_(Equivalency.input_id.in_(identifiers),
- Equivalency.output_id.in_(identifiers))
+ q = (
+ _db.query(Equivalency)
+ .distinct()
+ .filter(
+ or_(
+ Equivalency.input_id.in_(identifiers),
+ Equivalency.output_id.in_(identifiers),
+ )
+ )
)
if exclude_ids:
q = q.filter(~Equivalency.id.in_(exclude_ids))
diff --git a/model/integrationclient.py b/model/integrationclient.py
index 29d05a84d..2b5d1c9e6 100644
--- a/model/integrationclient.py
+++ b/model/integrationclient.py
@@ -3,24 +3,14 @@
import os
import re
-from sqlalchemy import (
- Boolean,
- Column,
- DateTime,
- Integer,
- Unicode,
-)
-from sqlalchemy.orm import (
- relationship,
-)
-
-from . import (
- Base,
- get_one,
- get_one_or_create,
-)
-from ..util.string_helpers import random_string
+
+from sqlalchemy import Boolean, Column, DateTime, Integer, Unicode
+from sqlalchemy.orm import relationship
+
from ..util.datetime_helpers import utc_now
+from ..util.string_helpers import random_string
+from . import Base, get_one, get_one_or_create
+
class IntegrationClient(Base):
"""A client that has authenticated access to this application.
@@ -28,7 +18,8 @@ class IntegrationClient(Base):
Currently used to represent circulation managers that have access
to the metadata wrangler.
"""
- __tablename__ = 'integrationclients'
+
+ __tablename__ = "integrationclients"
id = Column(Integer, primary_key=True)
@@ -45,8 +36,8 @@ class IntegrationClient(Base):
created = Column(DateTime(timezone=True))
last_accessed = Column(DateTime(timezone=True))
- loans = relationship('Loan', backref='integration_client')
- holds = relationship('Hold', backref='integration_client')
+ loans = relationship("Loan", backref="integration_client")
+ holds = relationship("Hold", backref="integration_client")
def __repr__(self):
return "" % (self.url, self.id)
@@ -72,8 +63,12 @@ def register(cls, _db, url, submitted_secret=None):
"""Creates a new server with client details."""
client, is_new = cls.for_url(_db, url)
- if not is_new and (not submitted_secret or submitted_secret != client.shared_secret):
- raise ValueError('Cannot update existing IntegratedClient without valid shared_secret')
+ if not is_new and (
+ not submitted_secret or submitted_secret != client.shared_secret
+ ):
+ raise ValueError(
+ "Cannot update existing IntegratedClient without valid shared_secret"
+ )
generate_secret = (client.shared_secret is None) or submitted_secret
if generate_secret:
@@ -83,9 +78,9 @@ def register(cls, _db, url, submitted_secret=None):
@classmethod
def normalize_url(cls, url):
- url = re.sub(r'^(http://|https://)', '', url)
- url = re.sub(r'^www\.', '', url)
- if url.endswith('/'):
+ url = re.sub(r"^(http://|https://)", "", url)
+ url = re.sub(r"^www\.", "", url)
+ if url.endswith("/"):
url = url[:-1]
return str(url.lower())
diff --git a/model/library.py b/model/library.py
index 6dd28853b..911d1166b 100644
--- a/model/library.py
+++ b/model/library.py
@@ -1,44 +1,42 @@
# encoding: utf-8
# Library
-from expiringdict import ExpiringDict
-
-
-from . import (
- Base,
- get_one,
-)
-from ..config import Configuration
-from .circulationevent import CirculationEvent
-from .edition import Edition
-from ..entrypoint import EntryPoint
-from ..facets import FacetConstants
-from .hasfulltablecache import HasFullTableCache
-from .licensing import LicensePool
-from .work import Work
-
-from collections import Counter
import logging
+from collections import Counter
+
+from expiringdict import ExpiringDict
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
- func,
Integer,
Table,
Unicode,
UniqueConstraint,
+ func,
)
-from sqlalchemy.orm import relationship
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.sql.functions import func
+from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import Session
+from sqlalchemy.sql.functions import func
+
+from ..config import Configuration
+from ..entrypoint import EntryPoint
+from ..facets import FacetConstants
+from . import Base, get_one
+from .circulationevent import CirculationEvent
+from .edition import Edition
+from .hasfulltablecache import HasFullTableCache
+from .licensing import LicensePool
+from .work import Work
+
class Library(Base, HasFullTableCache):
"""A library that uses this circulation manager to authenticate
its patrons and manage access to its content.
A circulation manager may serve many libraries.
"""
- __tablename__ = 'libraries'
+
+ __tablename__ = "libraries"
id = Column(Integer, primary_key=True)
@@ -58,13 +56,13 @@ class Library(Base, HasFullTableCache):
# One, and only one, library may be the default. The default
# library is the one chosen when an incoming request does not
# designate a library.
- _is_default = Column(Boolean, index=True, default=False, name='is_default')
+ _is_default = Column(Boolean, index=True, default=False, name="is_default")
# The name of this library to use when signing short client tokens
# for consumption by the library registry. e.g. "NYNYPL" for NYPL.
# This name must be unique across the library registry.
_library_registry_short_name = Column(
- Unicode, unique=True, name='library_registry_short_name'
+ Unicode, unique=True, name="library_registry_short_name"
)
# The shared secret to use when signing short client tokens for
@@ -72,47 +70,53 @@ class Library(Base, HasFullTableCache):
library_registry_shared_secret = Column(Unicode, unique=True)
# A library may have many Patrons.
- patrons = relationship(
- 'Patron', backref='library', cascade="all, delete-orphan"
- )
+ patrons = relationship("Patron", backref="library", cascade="all, delete-orphan")
# An Library may have many admin roles.
- adminroles = relationship("AdminRole", backref="library", cascade="all, delete-orphan")
+ adminroles = relationship(
+ "AdminRole", backref="library", cascade="all, delete-orphan"
+ )
# A Library may have many CachedFeeds.
cachedfeeds = relationship(
- "CachedFeed", backref="library",
+ "CachedFeed",
+ backref="library",
cascade="all, delete-orphan",
)
# A Library may have many CachedMARCFiles.
cachedmarcfiles = relationship(
- "CachedMARCFile", backref="library",
+ "CachedMARCFile",
+ backref="library",
cascade="all, delete-orphan",
)
# A Library may have many CustomLists.
custom_lists = relationship(
- "CustomList", backref="library", lazy='joined',
+ "CustomList",
+ backref="library",
+ lazy="joined",
)
# A Library may have many ExternalIntegrations.
integrations = relationship(
- "ExternalIntegration", secondary=lambda: externalintegrations_libraries,
- backref="libraries"
+ "ExternalIntegration",
+ secondary=lambda: externalintegrations_libraries,
+ backref="libraries",
)
# Any additional configuration information is stored as
# ConfigurationSettings.
settings = relationship(
- "ConfigurationSetting", backref="library",
- lazy="joined", cascade="all, delete",
+ "ConfigurationSetting",
+ backref="library",
+ lazy="joined",
+ cascade="all, delete",
)
# A Library may have many CirculationEvents
circulation_events = relationship(
- "CirculationEvent", backref="library",
- cascade='all, delete-orphan'
+ "CirculationEvent", backref="library", cascade="all, delete-orphan"
)
_cache = HasFullTableCache.RESET
@@ -124,8 +128,9 @@ class Library(Base, HasFullTableCache):
_has_root_lane_cache = ExpiringDict(max_len=1000, max_age_seconds=3600)
def __repr__(self):
- return '' % (
- self.name, self.short_name, self.uuid, self.library_registry_short_name
+ return (
+ ''
+ % (self.name, self.short_name, self.uuid, self.library_registry_short_name)
)
def cache_key(self):
@@ -134,9 +139,11 @@ def cache_key(self):
@classmethod
def lookup(cls, _db, short_name):
"""Look up a library by short name."""
+
def _lookup():
library = get_one(_db, Library, short_name=short_name)
return library, False
+
library, is_new = cls.by_cache_key(_db, short_name, _lookup)
return library
@@ -147,8 +154,12 @@ def default(cls, _db):
# the database, they're not actually interchangeable, but
# raising an error here might make it impossible to fix the
# problem.
- defaults = _db.query(Library).filter(
- Library._is_default==True).order_by(Library.id.asc()).all()
+ defaults = (
+ _db.query(Library)
+ .filter(Library._is_default == True)
+ .order_by(Library.id.asc())
+ .all()
+ )
if len(defaults) == 1:
# This is the normal case.
return defaults[0]
@@ -163,9 +174,8 @@ def default(cls, _db):
return None
[default_library] = libraries
logging.warn(
- "No default library, setting %s as default." % (
- default_library.short_name
- )
+ "No default library, setting %s as default."
+ % (default_library.short_name)
)
else:
# There is more than one default, probably caused by a
@@ -173,9 +183,8 @@ def default(cls, _db):
# of the libraries as the default.
default_library = defaults[0]
logging.warn(
- "Multiple default libraries, setting %s as default." % (
- default_library.short_name
- )
+ "Multiple default libraries, setting %s as default."
+ % (default_library.short_name)
)
default_library.is_default = True
return default_library
@@ -190,7 +199,7 @@ def library_registry_short_name(self, value):
"""Uppercase the library registry short name on the way in."""
if value:
value = value.upper()
- if '|' in value:
+ if "|" in value:
raise ValueError(
"Library registry short name cannot contain the pipe character."
)
@@ -203,9 +212,8 @@ def setting(self, key):
:return: A ConfigurationSetting
"""
from .configuration import ConfigurationSetting
- return ConfigurationSetting.for_library(
- key, self
- )
+
+ return ConfigurationSetting.for_library(key, self)
@property
def all_collections(self):
@@ -218,7 +226,7 @@ def all_collections(self):
# The name of the per-library regular expression used to derive a patron's
# external_type from their authorization_identifier.
- EXTERNAL_TYPE_REGULAR_EXPRESSION = 'external_type_regular_expression'
+ EXTERNAL_TYPE_REGULAR_EXPRESSION = "external_type_regular_expression"
# The name of the per-library configuration policy that controls whether
# books may be put on hold.
@@ -286,12 +294,11 @@ def enabled_facets(self, group_name):
try:
value = setting.json_value
except ValueError as e:
- logging.error("Invalid list of enabled facets for %s: %s",
- group_name, setting.value)
- if value is None:
- value = list(
- FacetConstants.DEFAULT_ENABLED_FACETS.get(group_name, [])
+ logging.error(
+ "Invalid list of enabled facets for %s: %s", group_name, setting.value
)
+ if value is None:
+ value = list(FacetConstants.DEFAULT_ENABLED_FACETS.get(group_name, []))
return value
def enabled_facets_setting(self, group_name):
@@ -318,18 +325,22 @@ def has_root_lanes(self):
value = Library._has_root_lane_cache.get(self.id, None)
if value is None:
from ..lane import Lane
+
_db = Session.object_session(self)
- root_lanes = _db.query(Lane).filter(
- Lane.library==self
- ).filter(
- Lane.root_for_patron_type!=None
+ root_lanes = (
+ _db.query(Lane)
+ .filter(Lane.library == self)
+ .filter(Lane.root_for_patron_type != None)
)
- value = (root_lanes.count() > 0)
+ value = root_lanes.count() > 0
Library._has_root_lane_cache[self.id] = value
return value
def restrict_to_ready_deliverable_works(
- self, query, collection_ids=None, show_suppressed=False,
+ self,
+ query,
+ collection_ids=None,
+ show_suppressed=False,
):
"""Restrict a query to show only presentation-ready works present in
an appropriate collection which the default client can
@@ -343,10 +354,13 @@ def restrict_to_ready_deliverable_works(
suppressed LicensePools.
"""
from .collection import Collection
+
collection_ids = collection_ids or [x.id for x in self.all_collections]
return Collection.restrict_to_ready_deliverable_works(
- query, collection_ids=collection_ids,
- show_suppressed=show_suppressed, allow_holds=self.allow_holds
+ query,
+ collection_ids=collection_ids,
+ show_suppressed=show_suppressed,
+ allow_holds=self.allow_holds,
)
def estimated_holdings_by_language(self, include_open_access=True):
@@ -357,14 +371,17 @@ def estimated_holdings_by_language(self, include_open_access=True):
of titles in that language.
"""
_db = Session.object_session(self)
- qu = _db.query(
- Edition.language, func.count(Work.id).label("work_count")
- ).select_from(Work).join(Work.license_pools).join(
- Work.presentation_edition
- ).filter(Edition.language != None).group_by(Edition.language)
+ qu = (
+ _db.query(Edition.language, func.count(Work.id).label("work_count"))
+ .select_from(Work)
+ .join(Work.license_pools)
+ .join(Work.presentation_edition)
+ .filter(Edition.language != None)
+ .group_by(Edition.language)
+ )
qu = self.restrict_to_ready_deliverable_works(qu)
if not include_open_access:
- qu = qu.filter(LicensePool.open_access==False)
+ qu = qu.filter(LicensePool.open_access == False)
counter = Counter()
for language, count in qu:
counter[language] = count
@@ -399,13 +416,13 @@ def explain(self, include_secrets=False):
if self.library_registry_short_name:
lines.append(
- 'Short name (for library registry): "%s"' %
- self.library_registry_short_name
+ 'Short name (for library registry): "%s"'
+ % self.library_registry_short_name
)
- if (self.library_registry_shared_secret and include_secrets):
+ if self.library_registry_shared_secret and include_secrets:
lines.append(
- 'Shared secret (for library registry): "%s"' %
- self.library_registry_shared_secret
+ 'Shared secret (for library registry): "%s"'
+ % self.library_registry_shared_secret
)
# Find all ConfigurationSettings that are set on the library
@@ -425,9 +442,7 @@ def explain(self, include_secrets=False):
lines.append("External integrations:")
lines.append("----------------------")
for integration in integrations:
- lines.extend(
- integration.explain(self, include_secrets=include_secrets)
- )
+ lines.extend(integration.explain(self, include_secrets=include_secrets))
lines.append("")
return lines
@@ -450,15 +465,19 @@ def is_default(self, new_is_default):
else:
library._is_default = False
+
externalintegrations_libraries = Table(
- 'externalintegrations_libraries', Base.metadata,
- Column(
- 'externalintegration_id', Integer, ForeignKey('externalintegrations.id'),
- index=True, nullable=False
- ),
- Column(
- 'library_id', Integer, ForeignKey('libraries.id'),
- index=True, nullable=False
- ),
- UniqueConstraint('externalintegration_id', 'library_id'),
- )
+ "externalintegrations_libraries",
+ Base.metadata,
+ Column(
+ "externalintegration_id",
+ Integer,
+ ForeignKey("externalintegrations.id"),
+ index=True,
+ nullable=False,
+ ),
+ Column(
+ "library_id", Integer, ForeignKey("libraries.id"), index=True, nullable=False
+ ),
+ UniqueConstraint("externalintegration_id", "library_id"),
+)
diff --git a/model/licensing.py b/model/licensing.py
index 8d9f9fb02..5e03f97e8 100644
--- a/model/licensing.py
+++ b/model/licensing.py
@@ -2,6 +2,7 @@
# PolicyException LicensePool, LicensePoolDeliveryMechanism, DeliveryMechanism,
# RightsStatus
import logging
+
from sqlalchemy import (
Boolean,
Column,
@@ -18,13 +19,13 @@
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.functions import func
+from ..util.datetime_helpers import utc_now
+from . import Base, create, flush, get_one, get_one_or_create
from .circulationevent import CirculationEvent
from .complaint import Complaint
from .constants import DataSourceConstants, EditionConstants, LinkRelations, MediaTypes
from .hasfulltablecache import HasFullTableCache
from .patron import Hold, Loan, Patron
-from . import Base, create, flush, get_one, get_one_or_create
-from ..util.datetime_helpers import utc_now
class PolicyException(Exception):
@@ -38,7 +39,7 @@ class License(Base):
delivery mechanisms, which may not always be true.
"""
- __tablename__ = 'licenses'
+ __tablename__ = "licenses"
id = Column(Integer, primary_key=True)
identifier = Column(Unicode)
@@ -50,16 +51,12 @@ class License(Base):
concurrent_checkouts = Column(Integer)
# A License belongs to one LicensePool.
- license_pool_id = Column(Integer, ForeignKey('licensepools.id'), index=True)
+ license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
# One License can have many Loans.
- loans = relationship(
- 'Loan', backref='license', cascade='all, delete-orphan'
- )
+ loans = relationship("Loan", backref="license", cascade="all, delete-orphan")
- __table_args__ = (
- UniqueConstraint('identifier', 'license_pool_id'),
- )
+ __table_args__ = (UniqueConstraint("identifier", "license_pool_id"),)
def loan_to(self, patron_or_client, **kwargs):
loan, is_new = self.license_pool.loan_to(patron_or_client, **kwargs)
@@ -81,8 +78,9 @@ def is_loan_limited(self):
@property
def is_expired(self):
now = utc_now()
- return ((self.expires and self.expires <= now) or
- (self.remaining_checkouts is not None and self.remaining_checkouts <= 0))
+ return (self.expires and self.expires <= now) or (
+ self.remaining_checkouts is not None and self.remaining_checkouts <= 0
+ )
class LicensePool(Base):
@@ -90,51 +88,47 @@ class LicensePool(Base):
UNLIMITED_ACCESS = -1
- __tablename__ = 'licensepools'
+ __tablename__ = "licensepools"
id = Column(Integer, primary_key=True)
# A LicensePool may be associated with a Work. (If it's not, no one
# can check it out.)
- work_id = Column(Integer, ForeignKey('works.id'), index=True)
+ work_id = Column(Integer, ForeignKey("works.id"), index=True)
# Each LicensePool is associated with one DataSource and one
# Identifier.
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
- identifier_id = Column(Integer, ForeignKey('identifiers.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
+ identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
# Each LicensePool belongs to one Collection.
- collection_id = Column(Integer, ForeignKey('collections.id'),
- index=True, nullable=False)
+ collection_id = Column(
+ Integer, ForeignKey("collections.id"), index=True, nullable=False
+ )
# Each LicensePool has an Edition which contains the metadata used
# to describe this book.
- presentation_edition_id = Column(Integer, ForeignKey('editions.id'), index=True)
+ presentation_edition_id = Column(Integer, ForeignKey("editions.id"), index=True)
# If the source provides information about individual licenses, the
# LicensePool may have many Licenses.
licenses = relationship(
- 'License', backref='license_pool', cascade='all, delete-orphan'
+ "License", backref="license_pool", cascade="all, delete-orphan"
)
# One LicensePool can have many Loans.
- loans = relationship(
- 'Loan', backref='license_pool', cascade='all, delete-orphan'
- )
+ loans = relationship("Loan", backref="license_pool", cascade="all, delete-orphan")
# One LicensePool can have many Holds.
- holds = relationship(
- 'Hold', backref='license_pool', cascade='all, delete-orphan'
- )
+ holds = relationship("Hold", backref="license_pool", cascade="all, delete-orphan")
# One LicensePool can have many CirculationEvents
circulation_events = relationship(
- "CirculationEvent", backref="license_pool",
- cascade='all, delete-orphan'
+ "CirculationEvent", backref="license_pool", cascade="all, delete-orphan"
)
# One LicensePool can be associated with many Complaints.
complaints = relationship(
- 'Complaint', backref='license_pool', cascade='all, delete-orphan'
+ "Complaint", backref="license_pool", cascade="all, delete-orphan"
)
# The date this LicensePool was first created in our db
@@ -158,9 +152,9 @@ class LicensePool(Base):
open_access = Column(Boolean, index=True)
last_checked = Column(DateTime(timezone=True), index=True)
licenses_owned = Column(Integer, default=0, index=True)
- licenses_available = Column(Integer,default=0, index=True)
- licenses_reserved = Column(Integer,default=0)
- patrons_in_hold_queue = Column(Integer,default=0)
+ licenses_available = Column(Integer, default=0, index=True)
+ licenses_reserved = Column(Integer, default=0)
+ patrons_in_hold_queue = Column(Integer, default=0)
# Set to True for collections imported using MirrorUploaded
self_hosted = Column(Boolean, index=True, nullable=False, default=False)
@@ -172,7 +166,7 @@ class LicensePool(Base):
# A Collection can not have more than one LicensePool for a given
# Identifier from a given DataSource.
__table_args__ = (
- UniqueConstraint('identifier_id', 'data_source_id', 'collection_id'),
+ UniqueConstraint("identifier_id", "data_source_id", "collection_id"),
)
delivery_mechanisms = relationship(
@@ -184,13 +178,19 @@ class LicensePool(Base):
def __repr__(self):
if self.identifier:
- identifier = "%s/%s" % (self.identifier.type,
- self.identifier.identifier)
+ identifier = "%s/%s" % (self.identifier.type, self.identifier.identifier)
else:
identifier = "unknown identifier"
- return "" % (
- self.id, identifier, self.licenses_owned, self.licenses_available,
- self.licenses_reserved, self.patrons_in_hold_queue
+ return (
+ ""
+ % (
+ self.id,
+ identifier,
+ self.licenses_owned,
+ self.licenses_available,
+ self.licenses_reserved,
+ self.patrons_in_hold_queue,
+ )
)
@hybrid_property
@@ -219,12 +219,21 @@ def unlimited_access(self, value):
self.licenses_available = 0
@classmethod
- def for_foreign_id(self, _db, data_source, foreign_id_type, foreign_id,
- rights_status=None, collection=None, autocreate=True):
+ def for_foreign_id(
+ self,
+ _db,
+ data_source,
+ foreign_id_type,
+ foreign_id,
+ rights_status=None,
+ collection=None,
+ autocreate=True,
+ ):
"""Find or create a LicensePool for the given foreign ID."""
from .collection import CollectionMissing
from .datasource import DataSource
from .identifier import Identifier
+
if not collection:
raise CollectionMissing()
@@ -234,27 +243,28 @@ def for_foreign_id(self, _db, data_source, foreign_id_type, foreign_id,
# The type of the foreign ID must be the primary identifier
# type for the data source.
- if (data_source.primary_identifier_type and
- foreign_id_type != data_source.primary_identifier_type
- and foreign_id_type != Identifier.DEPRECATED_NAMES.get(data_source.primary_identifier_type)
+ if (
+ data_source.primary_identifier_type
+ and foreign_id_type != data_source.primary_identifier_type
+ and foreign_id_type
+ != Identifier.DEPRECATED_NAMES.get(data_source.primary_identifier_type)
):
raise ValueError(
"License pools for data source '%s' are keyed to "
- "identifier type '%s' (not '%s', which was provided)" % (
- data_source.name, data_source.primary_identifier_type,
- foreign_id_type
+ "identifier type '%s' (not '%s', which was provided)"
+ % (
+ data_source.name,
+ data_source.primary_identifier_type,
+ foreign_id_type,
)
)
# Get the Identifier.
- identifier, ignore = Identifier.for_foreign_id(
- _db, foreign_id_type, foreign_id
- )
+ identifier, ignore = Identifier.for_foreign_id(_db, foreign_id_type, foreign_id)
- kw = dict(data_source=data_source, identifier=identifier,
- collection=collection)
+ kw = dict(data_source=data_source, identifier=identifier, collection=collection)
if rights_status:
- kw['rights_status'] = rights_status
+ kw["rights_status"] = rights_status
# Get the LicensePool that corresponds to the
# DataSource/Identifier/Collection.
@@ -282,8 +292,8 @@ def for_foreign_id(self, _db, data_source, foreign_id_type, foreign_id,
def with_no_work(cls, _db):
"""Find LicensePools that have no corresponding Work."""
from .work import Work
- return _db.query(LicensePool).outerjoin(Work).filter(
- Work.id==None).all()
+
+ return _db.query(LicensePool).outerjoin(Work).filter(Work.id == None).all()
@classmethod
def with_no_delivery_mechanisms(cls, _db):
@@ -291,21 +301,20 @@ def with_no_delivery_mechanisms(cls, _db):
:return: A query object.
"""
- return _db.query(LicensePool).outerjoin(
- LicensePool.delivery_mechanisms).filter(
- LicensePoolDeliveryMechanism.id==None
- )
+ return (
+ _db.query(LicensePool)
+ .outerjoin(LicensePool.delivery_mechanisms)
+ .filter(LicensePoolDeliveryMechanism.id == None)
+ )
@property
def deliverable(self):
- """This LicensePool can actually be delivered to patrons.
- """
- return (
- (self.open_access or self.licenses_owned > 0)
- and any(
- [dm.delivery_mechanism.default_client_can_fulfill
- for dm in self.delivery_mechanisms]
- )
+ """This LicensePool can actually be delivered to patrons."""
+ return (self.open_access or self.licenses_owned > 0) and any(
+ [
+ dm.delivery_mechanism.default_client_can_fulfill
+ for dm in self.delivery_mechanisms
+ ]
)
@classmethod
@@ -313,19 +322,19 @@ def with_complaint(cls, library, resolved=False):
"""Return query for LicensePools that have at least one Complaint."""
from .collection import Collection
from .library import Library
+
_db = Session.object_session(library)
- subquery = _db.query(
- LicensePool.id,
- func.count(LicensePool.id).label("complaint_count")
- ).select_from(LicensePool).join(
- LicensePool.collection).join(
- Collection.libraries).filter(
- Library.id==library.id
- ).join(
- LicensePool.complaints
- ).group_by(
- LicensePool.id
- )
+ subquery = (
+ _db.query(
+ LicensePool.id, func.count(LicensePool.id).label("complaint_count")
+ )
+ .select_from(LicensePool)
+ .join(LicensePool.collection)
+ .join(Collection.libraries)
+ .filter(Library.id == library.id)
+ .join(LicensePool.complaints)
+ .group_by(LicensePool.id)
+ )
if resolved == False:
subquery = subquery.filter(Complaint.resolved == None)
@@ -334,10 +343,12 @@ def with_complaint(cls, library, resolved=False):
subquery = subquery.subquery()
- return _db.query(LicensePool).\
- join(subquery, LicensePool.id == subquery.c.id).\
- order_by(subquery.c.complaint_count.desc()).\
- add_columns(subquery.c.complaint_count)
+ return (
+ _db.query(LicensePool)
+ .join(subquery, LicensePool.id == subquery.c.id)
+ .order_by(subquery.c.complaint_count.desc())
+ .add_columns(subquery.c.complaint_count)
+ )
@property
def open_access_source_priority(self):
@@ -358,7 +369,7 @@ def open_access_source_priority(self):
return priority
def better_open_access_pool_than(self, champion):
- """ Is this open-access pool generally known for better-quality
+ """Is this open-access pool generally known for better-quality
download files than the passed-in pool?
"""
# A license pool with no identifier shouldn't happen, but it
@@ -401,8 +412,10 @@ def better_open_access_pool_than(self, champion):
if challenger_priority < champion_priority:
return False
- if (self.data_source.name == DataSourceConstants.GUTENBERG
- and champion.data_source == self.data_source):
+ if (
+ self.data_source.name == DataSourceConstants.GUTENBERG
+ and champion.data_source == self.data_source
+ ):
# These two LicensePools are both from Gutenberg, and
# normally this wouldn't matter, but higher Gutenberg
# numbers beat lower Gutenberg numbers.
@@ -411,8 +424,7 @@ def better_open_access_pool_than(self, champion):
if challenger_id > champion_id:
logging.info(
- "Gutenberg %d beats Gutenberg %d",
- challenger_id, champion_id
+ "Gutenberg %d beats Gutenberg %d", challenger_id, champion_id
)
return True
return False
@@ -441,6 +453,7 @@ def set_presentation_edition(self, equivalent_editions=None):
information associated with this LicensePool actually changed.
"""
from .edition import Edition
+
_db = Session.object_session(self)
old_presentation_edition = self.presentation_edition
changed = False
@@ -460,11 +473,16 @@ def set_presentation_edition(self, equivalent_editions=None):
# than creating an identical composite.
self.presentation_edition = all_editions[0]
else:
- edition_identifier = IdentifierData(self.identifier.type, self.identifier.identifier)
- metadata = Metadata(data_source=DataSourceConstants.PRESENTATION_EDITION, primary_identifier=edition_identifier)
+ edition_identifier = IdentifierData(
+ self.identifier.type, self.identifier.identifier
+ )
+ metadata = Metadata(
+ data_source=DataSourceConstants.PRESENTATION_EDITION,
+ primary_identifier=edition_identifier,
+ )
for edition in all_editions:
- if (edition.data_source.name != DataSourceConstants.PRESENTATION_EDITION):
+ if edition.data_source.name != DataSourceConstants.PRESENTATION_EDITION:
metadata.update(Metadata.from_edition(edition))
# Note: Since this is a presentation edition it does not have a
@@ -488,16 +506,21 @@ def set_presentation_edition(self, equivalent_editions=None):
if self.work and not self.work.presentation_edition:
self.work.set_presentation_edition(self.presentation_edition)
- return (
- self.presentation_edition != old_presentation_edition
- or changed
- )
-
- def add_link(self, rel, href, data_source, media_type=None,
- content=None, content_path=None,
- rights_status_uri=None, rights_explanation=None,
- original_resource=None, transformation_settings=None,
- ):
+ return self.presentation_edition != old_presentation_edition or changed
+
+ def add_link(
+ self,
+ rel,
+ href,
+ data_source,
+ media_type=None,
+ content=None,
+ content_path=None,
+ rights_status_uri=None,
+ rights_explanation=None,
+ original_resource=None,
+ transformation_settings=None,
+ ):
"""Add a link between this LicensePool and a Resource.
:param rel: The relationship between this LicensePool and the resource
@@ -517,9 +540,17 @@ def add_link(self, rel, href, data_source, media_type=None,
resource into this resource.
"""
return self.identifier.add_link(
- rel, href, data_source, media_type, content, content_path,
- rights_status_uri, rights_explanation, original_resource,
- transformation_settings)
+ rel,
+ href,
+ data_source,
+ media_type,
+ content,
+ content_path,
+ rights_status_uri,
+ rights_explanation,
+ original_resource,
+ transformation_settings,
+ )
def needs_update(self):
"""Is it time to update the circulation info for this license pool?"""
@@ -528,7 +559,8 @@ def needs_update(self):
# This pool has never had its circulation info checked.
return True
maximum_stale_time = self.data_source.extra.get(
- 'circulation_refresh_rate_seconds')
+ "circulation_refresh_rate_seconds"
+ )
if maximum_stale_time is None:
# This pool never needs to have its circulation info checked.
return False
@@ -536,9 +568,14 @@ def needs_update(self):
return age > maximum_stale_time
def update_availability(
- self, new_licenses_owned, new_licenses_available,
- new_licenses_reserved, new_patrons_in_hold_queue,
- analytics=None, as_of=None):
+ self,
+ new_licenses_owned,
+ new_licenses_available,
+ new_licenses_reserved,
+ new_patrons_in_hold_queue,
+ analytics=None,
+ as_of=None,
+ ):
"""Update the LicensePool with new availability information.
Log the implied changes with the analytics provider.
"""
@@ -557,15 +594,31 @@ def update_availability(
old_patrons_in_hold_queue = self.patrons_in_hold_queue
for old_value, new_value, more_event, fewer_event in (
- [self.patrons_in_hold_queue, new_patrons_in_hold_queue,
- CirculationEvent.DISTRIBUTOR_HOLD_PLACE, CirculationEvent.DISTRIBUTOR_HOLD_RELEASE],
- [self.licenses_available, new_licenses_available,
- CirculationEvent.DISTRIBUTOR_CHECKIN, CirculationEvent.DISTRIBUTOR_CHECKOUT],
- [self.licenses_reserved, new_licenses_reserved,
- CirculationEvent.DISTRIBUTOR_AVAILABILITY_NOTIFY, None],
- [self.licenses_owned, new_licenses_owned,
- CirculationEvent.DISTRIBUTOR_LICENSE_ADD,
- CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE]):
+ [
+ self.patrons_in_hold_queue,
+ new_patrons_in_hold_queue,
+ CirculationEvent.DISTRIBUTOR_HOLD_PLACE,
+ CirculationEvent.DISTRIBUTOR_HOLD_RELEASE,
+ ],
+ [
+ self.licenses_available,
+ new_licenses_available,
+ CirculationEvent.DISTRIBUTOR_CHECKIN,
+ CirculationEvent.DISTRIBUTOR_CHECKOUT,
+ ],
+ [
+ self.licenses_reserved,
+ new_licenses_reserved,
+ CirculationEvent.DISTRIBUTOR_AVAILABILITY_NOTIFY,
+ None,
+ ],
+ [
+ self.licenses_owned,
+ new_licenses_owned,
+ CirculationEvent.DISTRIBUTOR_LICENSE_ADD,
+ CirculationEvent.DISTRIBUTOR_LICENSE_REMOVE,
+ ],
+ ):
if new_value is None:
continue
if old_value == new_value:
@@ -610,24 +663,33 @@ def update_availability(
if changes_made:
message, args = self.circulation_changelog(
- old_licenses_owned, old_licenses_available,
- old_licenses_reserved, old_patrons_in_hold_queue
+ old_licenses_owned,
+ old_licenses_available,
+ old_licenses_reserved,
+ old_patrons_in_hold_queue,
)
logging.info(message, *args)
return changes_made
- def collect_analytics_event(self, analytics, event_name, as_of,
- old_value, new_value):
+ def collect_analytics_event(
+ self, analytics, event_name, as_of, old_value, new_value
+ ):
if not analytics:
return
for library in self.collection.libraries:
analytics.collect_event(
- library, self, event_name, as_of,
- old_value=old_value, new_value=new_value
+ library,
+ self,
+ event_name,
+ as_of,
+ old_value=old_value,
+ new_value=new_value,
)
- def update_availability_from_delta(self, event_type, event_date, delta, analytics=None):
+ def update_availability_from_delta(
+ self, event_type, event_date, delta, analytics=None
+ ):
"""Call update_availability based on a single change seen in the
distributor data, rather than a complete snapshot of
distributor information as of a certain time.
@@ -647,7 +709,11 @@ def update_availability_from_delta(self, event_type, event_date, delta, analytic
:param delta: The magnitude of the change that was seen.
"""
ignore = False
- if event_date != CirculationEvent.NO_DATE and self.last_checked and event_date < self.last_checked:
+ if (
+ event_date != CirculationEvent.NO_DATE
+ and self.last_checked
+ and event_date < self.last_checked
+ ):
# This is an old event and its effect on availability has
# already been taken into account.
ignore = True
@@ -659,24 +725,26 @@ def update_availability_from_delta(self, event_type, event_date, delta, analytic
ignore = True
if not ignore:
- (new_licenses_owned, new_licenses_available,
- new_licenses_reserved,
- new_patrons_in_hold_queue) = self._calculate_change_from_one_event(
- event_type, delta
- )
+ (
+ new_licenses_owned,
+ new_licenses_available,
+ new_licenses_reserved,
+ new_patrons_in_hold_queue,
+ ) = self._calculate_change_from_one_event(event_type, delta)
changes_made = self.update_availability(
- new_licenses_owned, new_licenses_available,
- new_licenses_reserved, new_patrons_in_hold_queue,
- analytics=analytics, as_of=event_date
+ new_licenses_owned,
+ new_licenses_available,
+ new_licenses_reserved,
+ new_patrons_in_hold_queue,
+ analytics=analytics,
+ as_of=event_date,
)
if ignore or not changes_made:
# Even if the event was ignored or didn't actually change
# availability, we want to record receipt of the event
# in the analytics.
- self.collect_analytics_event(
- analytics, event_type, event_date, 0, 0
- )
+ self.collect_analytics_event(analytics, event_type, event_date, 0, 0)
def _calculate_change_from_one_event(self, type, delta):
new_licenses_owned = self.licenses_owned
@@ -687,7 +755,7 @@ def _calculate_change_from_one_event(self, type, delta):
def deduct(value):
# It's impossible for any of these numbers to be
# negative.
- return max(value-delta, 0)
+ return max(value - delta, 0)
CE = CirculationEvent
added = False
@@ -712,7 +780,7 @@ def deduct(value):
# future as DISTRIBUTOR_AVAILABILITY_NOTIFICATION events
# are sent out.
if delta > new_patrons_in_hold_queue:
- new_licenses_available += (delta-new_patrons_in_hold_queue)
+ new_licenses_available += delta - new_patrons_in_hold_queue
elif type == CE.DISTRIBUTOR_CHECKOUT:
if new_licenses_available == 0:
# The only way to borrow books while there are no
@@ -746,30 +814,42 @@ def deduct(value):
# latter is more likely.
new_licenses_available = new_licenses_owned
- return (new_licenses_owned, new_licenses_available,
- new_licenses_reserved, new_patrons_in_hold_queue)
+ return (
+ new_licenses_owned,
+ new_licenses_available,
+ new_licenses_reserved,
+ new_patrons_in_hold_queue,
+ )
- def circulation_changelog(self, old_licenses_owned, old_licenses_available,
- old_licenses_reserved, old_patrons_in_hold_queue):
+ def circulation_changelog(
+ self,
+ old_licenses_owned,
+ old_licenses_available,
+ old_licenses_reserved,
+ old_patrons_in_hold_queue,
+ ):
"""Generate a log message describing a change to the circulation.
:return: a 2-tuple (message, args) suitable for passing into
logging.info or a similar method
"""
edition = self.presentation_edition
- message = 'CHANGED '
+ message = "CHANGED "
args = []
if self.identifier:
- identifier_template = '%s/%s'
+ identifier_template = "%s/%s"
identifier_args = [self.identifier.type, self.identifier.identifier]
else:
- identifier_template = '%s'
+ identifier_template = "%s"
identifier_args = [self.identifier]
if edition:
- message += '%s "%s" %s (' + identifier_template + ')'
- args.extend([edition.medium,
- edition.title or "[NO TITLE]",
- edition.author or "[NO AUTHOR]"]
- )
+ message += '%s "%s" %s (' + identifier_template + ")"
+ args.extend(
+ [
+ edition.medium,
+ edition.title or "[NO TITLE]",
+ edition.author or "[NO AUTHOR]",
+ ]
+ )
args.extend(identifier_args)
else:
message += identifier_template
@@ -778,7 +858,7 @@ def circulation_changelog(self, old_licenses_owned, old_licenses_available,
def _part(message, args, string, old_value, new_value):
if old_value != new_value:
args.extend([string, old_value, new_value])
- message += ' %s: %s=>%s'
+ message += " %s: %s=>%s"
return message, args
message, args = _part(
@@ -786,29 +866,35 @@ def _part(message, args, string, old_value, new_value):
)
message, args = _part(
- message, args, "AVAIL", old_licenses_available,
- self.licenses_available
+ message, args, "AVAIL", old_licenses_available, self.licenses_available
)
message, args = _part(
- message, args, "RSRV", old_licenses_reserved,
- self.licenses_reserved
+ message, args, "RSRV", old_licenses_reserved, self.licenses_reserved
)
- message, args =_part(
- message, args, "HOLD", old_patrons_in_hold_queue,
- self.patrons_in_hold_queue
+ message, args = _part(
+ message, args, "HOLD", old_patrons_in_hold_queue, self.patrons_in_hold_queue
)
return message, tuple(args)
- def loan_to(self, patron_or_client, start=None, end=None, fulfillment=None, external_identifier=None):
+ def loan_to(
+ self,
+ patron_or_client,
+ start=None,
+ end=None,
+ fulfillment=None,
+ external_identifier=None,
+ ):
_db = Session.object_session(patron_or_client)
- kwargs = dict(start=start or utc_now(),
- end=end)
+ kwargs = dict(start=start or utc_now(), end=end)
if isinstance(patron_or_client, Patron):
loan, is_new = get_one_or_create(
- _db, Loan, patron=patron_or_client, license_pool=self,
- create_method_kwargs=kwargs
+ _db,
+ Loan,
+ patron=patron_or_client,
+ license_pool=self,
+ create_method_kwargs=kwargs,
)
if is_new:
@@ -820,17 +906,31 @@ def loan_to(self, patron_or_client, start=None, end=None, fulfillment=None, exte
# An IntegrationClient can have multiple loans, so this always creates
# a new loan rather than returning an existing loan.
loan, is_new = create(
- _db, Loan, integration_client=patron_or_client, license_pool=self,
- create_method_kwargs=kwargs)
+ _db,
+ Loan,
+ integration_client=patron_or_client,
+ license_pool=self,
+ create_method_kwargs=kwargs,
+ )
if fulfillment:
loan.fulfillment = fulfillment
if external_identifier:
loan.external_identifier = external_identifier
return loan, is_new
- def on_hold_to(self, patron_or_client, start=None, end=None, position=None, external_identifier=None):
+ def on_hold_to(
+ self,
+ patron_or_client,
+ start=None,
+ end=None,
+ position=None,
+ external_identifier=None,
+ ):
_db = Session.object_session(patron_or_client)
- if isinstance(patron_or_client, Patron) and not patron_or_client.library.allow_holds:
+ if (
+ isinstance(patron_or_client, Patron)
+ and not patron_or_client.library.allow_holds
+ ):
raise PolicyException("Holds are disabled for this library.")
start = start or utc_now()
if isinstance(patron_or_client, Patron):
@@ -846,7 +946,8 @@ def on_hold_to(self, patron_or_client, start=None, end=None, position=None, exte
# An IntegrationClient can have multiple holds, so this always creates
# a new hold rather than returning an existing loan.
hold, new = create(
- _db, Hold, integration_client=patron_or_client, license_pool=self)
+ _db, Hold, integration_client=patron_or_client, license_pool=self
+ )
hold.update(start, end, position)
if external_identifier:
hold.external_identifier = external_identifier
@@ -876,17 +977,27 @@ def best_available_license(self):
if license.is_expired:
continue
- active_loan_count = len([l for l in license.loans if not l.end or l.end > now])
+ active_loan_count = len(
+ [l for l in license.loans if not l.end or l.end > now]
+ )
if active_loan_count >= license.concurrent_checkouts:
continue
if (
- not best or
- (license.is_time_limited and not best.is_time_limited) or
- (license.is_time_limited and best.is_time_limited and license.expires < best.expires) or
- (license.is_perpetual and not best.is_time_limited) or
- (license.is_loan_limited and best.is_loan_limited and license.remaining_checkouts > best.remaining_checkouts)
- ):
+ not best
+ or (license.is_time_limited and not best.is_time_limited)
+ or (
+ license.is_time_limited
+ and best.is_time_limited
+ and license.expires < best.expires
+ )
+ or (license.is_perpetual and not best.is_time_limited)
+ or (
+ license.is_loan_limited
+ and best.is_loan_limited
+ and license.remaining_checkouts > best.remaining_checkouts
+ )
+ ):
best = license
return best
@@ -896,9 +1007,7 @@ def consolidate_works(cls, _db, batch_size=10):
"""Assign a (possibly new) Work to every unassigned LicensePool."""
a = 0
lps = cls.with_no_work(_db)
- logging.info(
- "Assigning Works to %d LicensePools with no Work.", len(lps)
- )
+ logging.info("Assigning Works to %d LicensePools with no Work.", len(lps))
for unassigned in lps:
etext, new = unassigned.calculate_work()
if not etext:
@@ -912,10 +1021,8 @@ def consolidate_works(cls, _db, batch_size=10):
_db.commit()
_db.commit()
-
def calculate_work(
- self, known_edition=None, exclude_search=False,
- even_if_no_title=False
+ self, known_edition=None, exclude_search=False, even_if_no_title=False
):
"""Find or create a Work for this LicensePool.
A pool that is not open-access will always have its own
@@ -955,8 +1062,9 @@ def calculate_work(
if not presentation_edition:
# We don't have any information about the identifier
# associated with this LicensePool, so we can't create a work.
- logging.warn("NO EDITION for %s, cowardly refusing to create work.",
- self.identifier)
+ logging.warn(
+ "NO EDITION for %s, cowardly refusing to create work.", self.identifier
+ )
# If there was a work associated with this LicensePool,
# it was by mistake. Remove it.
@@ -969,10 +1077,14 @@ def calculate_work(
if not presentation_edition.title and not even_if_no_title:
if presentation_edition.work:
logging.warn(
- "Edition %r has no title but has a Work assigned. This will not stand.", presentation_edition
+ "Edition %r has no title but has a Work assigned. This will not stand.",
+ presentation_edition,
)
else:
- logging.info("Edition %r has no title and it will not get a Work.", presentation_edition)
+ logging.info(
+ "Edition %r has no title and it will not get a Work.",
+ presentation_edition,
+ )
self.work = None
self.work_id = None
return None, False
@@ -992,8 +1104,10 @@ def calculate_work(
# Work.open_access_for_permanent_work_id may result in works being
# merged.
work, is_new = Work.open_access_for_permanent_work_id(
- _db, presentation_edition.permanent_work_id,
- presentation_edition.medium, presentation_edition.language
+ _db,
+ presentation_edition.permanent_work_id,
+ presentation_edition.medium,
+ presentation_edition.language,
)
# Run a sanity check to make sure every LicensePool
@@ -1015,7 +1129,8 @@ def calculate_work(
existing_works = set([x.work for x in self.identifier.licensed_through])
if len(existing_works) > 1:
logging.warn(
- "LicensePools for %r have more than one Work between them. Removing them all and starting over.", self.identifier
+ "LicensePools for %r have more than one Work between them. Removing them all and starting over.",
+ self.identifier,
)
for lp in self.identifier.licensed_through:
lp.work = None
@@ -1054,17 +1169,14 @@ def calculate_work(
if not (self.open_access and pool.open_access):
pool.work = None
pool.calculate_work(
- exclude_search=exclude_search,
- even_if_no_title=even_if_no_title
+ exclude_search=exclude_search, even_if_no_title=even_if_no_title
)
licensepools_changed = True
else:
# There is no better choice than creating a brand new Work.
is_new = True
- logging.info(
- "Creating a new work for %r" % presentation_edition.title
- )
+ logging.info("Creating a new work for %r" % presentation_edition.title)
work = Work()
_db = Session.object_session(self)
_db.add(work)
@@ -1096,11 +1208,11 @@ def calculate_work(
# All done!
return work, is_new
-
@property
def open_access_links(self):
"""Yield all open-access Resources for this LicensePool."""
from .identifier import Identifier
+
open_access = LinkRelations.OPEN_ACCESS_DOWNLOAD
_db = Session.object_session(self)
if not self.identifier:
@@ -1143,10 +1255,13 @@ def best_open_access_resource(self):
best_priority = -1
for resource in self.open_access_links:
if not any(
- [resource.representation and
- resource.representation.media_type and
- resource.representation.media_type.startswith(x)
- for x in MediaTypes.SUPPORTED_BOOK_MEDIA_TYPES]):
+ [
+ resource.representation
+ and resource.representation.media_type
+ and resource.representation.media_type.startswith(x)
+ for x in MediaTypes.SUPPORTED_BOOK_MEDIA_TYPES
+ ]
+ ):
# This representation is not in a media type we
# support. We can't serve it, so we won't consider it.
continue
@@ -1158,10 +1273,12 @@ def best_open_access_resource(self):
best_priority = data_source_priority
continue
- if (best.data_source.name==DataSourceConstants.GUTENBERG
- and resource.data_source.name==DataSourceConstants.GUTENBERG
- and 'noimages' in best.representation.public_url
- and not 'noimages' in resource.representation.public_url):
+ if (
+ best.data_source.name == DataSourceConstants.GUTENBERG
+ and resource.data_source.name == DataSourceConstants.GUTENBERG
+ and "noimages" in best.representation.public_url
+ and not "noimages" in resource.representation.public_url
+ ):
# A Project Gutenberg-ism: an epub without 'noimages'
# in the filename is better than an epub with
# 'noimages' in the filename.
@@ -1203,7 +1320,15 @@ def set_delivery_mechanism(self, *args, **kwargs):
self.data_source, self.identifier, *args, **kwargs
)
-Index("ix_licensepools_data_source_id_identifier_id_collection_id", LicensePool.collection_id, LicensePool.data_source_id, LicensePool.identifier_id, unique=True)
+
+Index(
+ "ix_licensepools_data_source_id_identifier_id_collection_id",
+ LicensePool.collection_id,
+ LicensePool.data_source_id,
+ LicensePool.identifier_id,
+ unique=True,
+)
+
class LicensePoolDeliveryMechanism(Base):
"""A mechanism for delivering a specific book from a specific
@@ -1215,36 +1340,42 @@ class LicensePoolDeliveryMechanism(Base):
(i.e. a static link to a downloadable file) which explains exactly
where to go for delivery.
"""
- __tablename__ = 'licensepooldeliveries'
+
+ __tablename__ = "licensepooldeliveries"
id = Column(Integer, primary_key=True)
data_source_id = Column(
- Integer, ForeignKey('datasources.id'), index=True, nullable=False
+ Integer, ForeignKey("datasources.id"), index=True, nullable=False
)
identifier_id = Column(
- Integer, ForeignKey('identifiers.id'), index=True, nullable=False
+ Integer, ForeignKey("identifiers.id"), index=True, nullable=False
)
delivery_mechanism_id = Column(
- Integer, ForeignKey('deliverymechanisms.id'),
- index=True,
- nullable=False
+ Integer, ForeignKey("deliverymechanisms.id"), index=True, nullable=False
)
- resource_id = Column(Integer, ForeignKey('resources.id'), nullable=True)
+ resource_id = Column(Integer, ForeignKey("resources.id"), nullable=True)
# One LicensePoolDeliveryMechanism may fulfill many Loans.
fulfills = relationship("Loan", backref="fulfillment")
# One LicensePoolDeliveryMechanism may be associated with one RightsStatus.
- rightsstatus_id = Column(
- Integer, ForeignKey('rightsstatus.id'), index=True)
+ rightsstatus_id = Column(Integer, ForeignKey("rightsstatus.id"), index=True)
@classmethod
- def set(cls, data_source, identifier, content_type, drm_scheme, rights_uri,
- resource=None, autocommit=True):
+ def set(
+ cls,
+ data_source,
+ identifier,
+ content_type,
+ drm_scheme,
+ rights_uri,
+ resource=None,
+ autocommit=True,
+ ):
"""Register the fact that a distributor makes a title available in a
certain format.
@@ -1271,11 +1402,12 @@ def set(cls, data_source, identifier, content_type, drm_scheme, rights_uri,
)
rights_status = RightsStatus.lookup(_db, rights_uri)
lpdm, dirty = get_one_or_create(
- _db, LicensePoolDeliveryMechanism,
+ _db,
+ LicensePoolDeliveryMechanism,
identifier=identifier,
data_source=data_source,
delivery_mechanism=delivery_mechanism,
- resource=resource
+ resource=resource,
)
if not lpdm.rights_status or rights_status.uri != RightsStatus.UNKNOWN:
# We have better information available about the
@@ -1301,8 +1433,7 @@ def set(cls, data_source, identifier, content_type, drm_scheme, rights_uri,
@property
def is_open_access(self):
"""Is this an open-access delivery mechanism?"""
- return (self.rights_status
- and self.rights_status.uri in RightsStatus.OPEN_ACCESS)
+ return self.rights_status and self.rights_status.uri in RightsStatus.OPEN_ACCESS
def compatible_with(self, other):
"""Can a single loan be fulfilled with both this
@@ -1313,7 +1444,7 @@ def compatible_with(self, other):
if not isinstance(other, LicensePoolDeliveryMechanism):
return False
- if other.id==self.id:
+ if other.id == self.id:
# They two LicensePoolDeliveryMechanisms are the same object.
return True
@@ -1338,11 +1469,8 @@ def compatible_with(self, other):
# DeliveryMechanisms are the same or that one of them is a
# streaming mechanism.
open_access_rules = self.is_open_access and other.is_open_access
- return (
- other.delivery_mechanism
- and self.delivery_mechanism.compatible_with(
- other.delivery_mechanism, open_access_rules
- )
+ return other.delivery_mechanism and self.delivery_mechanism.compatible_with(
+ other.delivery_mechanism, open_access_rules
)
def delete(self):
@@ -1376,36 +1504,37 @@ def set_rights_status(self, uri):
@property
def license_pools(self):
- """Find all LicensePools for this LicensePoolDeliveryMechanism.
- """
+ """Find all LicensePools for this LicensePoolDeliveryMechanism."""
_db = Session.object_session(self)
- return _db.query(LicensePool).filter(
- LicensePool.data_source==self.data_source).filter(
- LicensePool.identifier==self.identifier)
+ return (
+ _db.query(LicensePool)
+ .filter(LicensePool.data_source == self.data_source)
+ .filter(LicensePool.identifier == self.identifier)
+ )
def __repr__(self):
- return "".format(
- str(self.data_source),
- repr(self.identifier),
- repr(self.delivery_mechanism)
+ return "".format(
+ str(self.data_source), repr(self.identifier), repr(self.delivery_mechanism)
)
__table_args__ = (
- UniqueConstraint('data_source_id', 'identifier_id',
- 'delivery_mechanism_id', 'resource_id'),
+ UniqueConstraint(
+ "data_source_id", "identifier_id", "delivery_mechanism_id", "resource_id"
+ ),
)
+
# The uniqueness constraint doesn't enforce uniqueness when one of the
# fields is null, and one of these fields -- resource_id -- is
# _usually_ null. So we also need a unique partial index to properly
# enforce the constraint.
Index(
- 'ix_licensepooldeliveries_unique_when_no_resource',
+ "ix_licensepooldeliveries_unique_when_no_resource",
LicensePoolDeliveryMechanism.data_source_id,
LicensePoolDeliveryMechanism.identifier_id,
LicensePoolDeliveryMechanism.delivery_mechanism_id,
unique=True,
- postgresql_where=(LicensePoolDeliveryMechanism.resource_id==None)
+ postgresql_where=(LicensePoolDeliveryMechanism.resource_id == None),
)
@@ -1416,6 +1545,7 @@ class DeliveryMechanism(Base, HasFullTableCache):
(e.g. "application/vnd.adobe.adept+xml" or "application/epub+zip") or an
informal name ("Kindle via Amazon").
"""
+
KINDLE_CONTENT_TYPE = "Kindle via Amazon"
NOOK_CONTENT_TYPE = "Nook via B&N"
STREAMING_TEXT_CONTENT_TYPE = "Streaming Text"
@@ -1442,20 +1572,30 @@ class DeliveryMechanism(Base, HasFullTableCache):
LIBBY_DRM = "Libby DRM"
KNOWN_DRM_TYPES = {
- ADOBE_DRM, FINDAWAY_DRM, AXISNOW_DRM, KINDLE_DRM, NOOK_DRM, STREAMING_DRM, LCP_DRM, OVERDRIVE_DRM, LIBBY_DRM
+ ADOBE_DRM,
+ FINDAWAY_DRM,
+ AXISNOW_DRM,
+ KINDLE_DRM,
+ NOOK_DRM,
+ STREAMING_DRM,
+ LCP_DRM,
+ OVERDRIVE_DRM,
+ LIBBY_DRM,
}
BEARER_TOKEN = "application/vnd.librarysimplified.bearer-token+json"
FEEDBOOKS_AUDIOBOOK_DRM = "http://www.feedbooks.com/audiobooks/access-restriction"
FEEDBOOKS_AUDIOBOOK_PROFILE = ';profile="%s"' % FEEDBOOKS_AUDIOBOOK_DRM
- STREAMING_PROFILE = ';profile="http://librarysimplified.org/terms/profiles/streaming-media"'
+ STREAMING_PROFILE = (
+ ';profile="http://librarysimplified.org/terms/profiles/streaming-media"'
+ )
MEDIA_TYPES_FOR_STREAMING = {
STREAMING_TEXT_CONTENT_TYPE: MediaTypes.TEXT_HTML_MEDIA_TYPE,
STREAMING_AUDIO_CONTENT_TYPE: MediaTypes.TEXT_HTML_MEDIA_TYPE,
}
- __tablename__ = 'deliverymechanisms'
+ __tablename__ = "deliverymechanisms"
id = Column(Integer, primary_key=True)
content_type = Column(String)
drm_scheme = Column(String)
@@ -1469,29 +1609,26 @@ class DeliveryMechanism(Base, HasFullTableCache):
#
# This is primarily used when deciding which books can be imported
# from an OPDS For Distributors collection.
- default_client_can_fulfill_lookup = set([
- # EPUB books
- (MediaTypes.EPUB_MEDIA_TYPE, NO_DRM),
- (MediaTypes.EPUB_MEDIA_TYPE, ADOBE_DRM),
-
- # PDF books
- (MediaTypes.PDF_MEDIA_TYPE, NO_DRM),
-
- # Various audiobook formats
- (None, FINDAWAY_DRM),
- (MediaTypes.AUDIOBOOK_MANIFEST_MEDIA_TYPE, NO_DRM),
-
- (MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE, LIBBY_DRM),
- ])
+ default_client_can_fulfill_lookup = set(
+ [
+ # EPUB books
+ (MediaTypes.EPUB_MEDIA_TYPE, NO_DRM),
+ (MediaTypes.EPUB_MEDIA_TYPE, ADOBE_DRM),
+ # PDF books
+ (MediaTypes.PDF_MEDIA_TYPE, NO_DRM),
+ # Various audiobook formats
+ (None, FINDAWAY_DRM),
+ (MediaTypes.AUDIOBOOK_MANIFEST_MEDIA_TYPE, NO_DRM),
+ (MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE, LIBBY_DRM),
+ ]
+ )
# If the default client supports a given media type with no DRM,
# we can infer that the client _also_ supports that media type via
# bearer token exchange.
for _media_type, _drm in list(default_client_can_fulfill_lookup):
if _media_type is not None and _drm == NO_DRM:
- default_client_can_fulfill_lookup.add(
- (_media_type, BEARER_TOKEN)
- )
+ default_client_can_fulfill_lookup.add((_media_type, BEARER_TOKEN))
license_pool_delivery_mechanisms = relationship(
"LicensePoolDeliveryMechanism",
@@ -1501,9 +1638,7 @@ class DeliveryMechanism(Base, HasFullTableCache):
_cache = HasFullTableCache.RESET
_id_cache = HasFullTableCache.RESET
- __table_args__ = (
- UniqueConstraint('content_type', 'drm_scheme'),
- )
+ __table_args__ = (UniqueConstraint("content_type", "drm_scheme"),)
@property
def name(self):
@@ -1523,17 +1658,15 @@ def __repr__(self):
else:
fulfillable = "not fulfillable"
- return "" % (
- self.name, fulfillable
- )
+ return "" % (self.name, fulfillable)
@classmethod
def lookup(cls, _db, content_type, drm_scheme):
def lookup_hook():
return get_one_or_create(
- _db, DeliveryMechanism, content_type=content_type,
- drm_scheme=drm_scheme
+ _db, DeliveryMechanism, content_type=content_type, drm_scheme=drm_scheme
)
+
return cls.by_cache_key(_db, (content_type, drm_scheme), lookup_hook)
@property
@@ -1542,13 +1675,14 @@ def implicit_medium(self):
available through this DeliveryMechanism?
"""
if self.content_type in (
- MediaTypes.EPUB_MEDIA_TYPE,
- MediaTypes.PDF_MEDIA_TYPE,
- "Kindle via Amazon",
- "Streaming Text"):
+ MediaTypes.EPUB_MEDIA_TYPE,
+ MediaTypes.PDF_MEDIA_TYPE,
+ "Kindle via Amazon",
+ "Streaming Text",
+ ):
return EditionConstants.BOOK_MEDIUM
elif self.content_type in (
- "Streaming Video" or self.content_type.startswith('video/')
+ "Streaming Video" or self.content_type.startswith("video/")
):
return EditionConstants.VIDEO_MEDIUM
else:
@@ -1560,8 +1694,10 @@ def is_media_type(cls, x):
if x is None:
return False
- return any(x.startswith(prefix) for prefix in
- ['vnd.', 'application', 'text', 'video', 'audio', 'image'])
+ return any(
+ x.startswith(prefix)
+ for prefix in ["vnd.", "application", "text", "video", "audio", "image"]
+ )
@property
def is_streaming(self):
@@ -1617,8 +1753,11 @@ def compatible_with(self, other, open_access_rules=False):
# For an open-access book, loans are not locked to delivery
# mechanisms, so as long as neither delivery mechanism has
# DRM, they're compatible.
- if (open_access_rules and self.drm_scheme==self.NO_DRM
- and other.drm_scheme==self.NO_DRM):
+ if (
+ open_access_rules
+ and self.drm_scheme == self.NO_DRM
+ and other.drm_scheme == self.NO_DRM
+ ):
return True
# For non-open-access books, locking a license pool to a
@@ -1626,15 +1765,16 @@ def compatible_with(self, other, open_access_rules=False):
# other non-streaming delivery mechanism.
return False
+
# The uniqueness constraint doesn't enforce uniqueness when one of the
# fields is null, and one of these fields -- drm_scheme -- is
# frequently null. So we also need a unique partial index to properly
# enforce the constraint.
Index(
- 'ix_deliverymechanisms_unique_when_no_drm',
+ "ix_deliverymechanisms_unique_when_no_drm",
DeliveryMechanism.content_type,
unique=True,
- postgresql_where=(DeliveryMechanism.drm_scheme==None)
+ postgresql_where=(DeliveryMechanism.drm_scheme == None),
)
@@ -1650,10 +1790,14 @@ class RightsStatus(Base):
IN_COPYRIGHT = "http://librarysimplified.org/terms/rights-status/in-copyright"
# Public domain in the USA.
- PUBLIC_DOMAIN_USA = "http://librarysimplified.org/terms/rights-status/public-domain-usa"
+ PUBLIC_DOMAIN_USA = (
+ "http://librarysimplified.org/terms/rights-status/public-domain-usa"
+ )
# Public domain in some unknown territory
- PUBLIC_DOMAIN_UNKNOWN = "http://librarysimplified.org/terms/rights-status/public-domain-unknown"
+ PUBLIC_DOMAIN_UNKNOWN = (
+ "http://librarysimplified.org/terms/rights-status/public-domain-unknown"
+ )
# Creative Commons Public Domain Dedication (No rights reserved)
CC0 = "https://creativecommons.org/publicdomain/zero/1.0/"
@@ -1677,7 +1821,9 @@ class RightsStatus(Base):
CC_BY_NC_ND = "https://creativecommons.org/licenses/by-nc-nd/4.0"
# Open access download but no explicit license
- GENERIC_OPEN_ACCESS = "http://librarysimplified.org/terms/rights-status/generic-open-access"
+ GENERIC_OPEN_ACCESS = (
+ "http://librarysimplified.org/terms/rights-status/generic-open-access"
+ )
# Unknown copyright status.
UNKNOWN = "http://librarysimplified.org/terms/rights-status/unknown"
@@ -1723,14 +1869,13 @@ class RightsStatus(Base):
DataSourceConstants.GUTENBERG: PUBLIC_DOMAIN_USA,
DataSourceConstants.PLYMPTON: CC_BY_NC,
# workaround for opds-imported license pools with 'content server' as data source
- DataSourceConstants.OA_CONTENT_SERVER : GENERIC_OPEN_ACCESS,
-
+ DataSourceConstants.OA_CONTENT_SERVER: GENERIC_OPEN_ACCESS,
DataSourceConstants.OVERDRIVE: IN_COPYRIGHT,
DataSourceConstants.BIBLIOTHECA: IN_COPYRIGHT,
DataSourceConstants.AXIS_360: IN_COPYRIGHT,
}
- __tablename__ = 'rightsstatus'
+ __tablename__ = "rightsstatus"
id = Column(Integer, primary_key=True)
# A URI unique to the license. This may be a URL (e.g. Creative
@@ -1741,7 +1886,9 @@ class RightsStatus(Base):
name = Column(String, index=True)
# One RightsStatus may apply to many LicensePoolDeliveryMechanisms.
- licensepooldeliverymechanisms = relationship("LicensePoolDeliveryMechanism", backref="rights_status")
+ licensepooldeliverymechanisms = relationship(
+ "LicensePoolDeliveryMechanism", backref="rights_status"
+ )
# One RightsStatus may apply to many Resources.
resources = relationship("Resource", backref="rights_status")
@@ -1753,40 +1900,38 @@ def lookup(cls, _db, uri):
name = cls.NAMES.get(uri)
create_method_kwargs = dict(name=name)
status, ignore = get_one_or_create(
- _db, RightsStatus, uri=uri,
- create_method_kwargs=create_method_kwargs
+ _db, RightsStatus, uri=uri, create_method_kwargs=create_method_kwargs
)
return status
@classmethod
def rights_uri_from_string(cls, rights):
rights = rights.lower()
- if rights == 'public domain in the usa.':
+ if rights == "public domain in the usa.":
return RightsStatus.PUBLIC_DOMAIN_USA
- elif rights == 'public domain in the united states.':
+ elif rights == "public domain in the united states.":
return RightsStatus.PUBLIC_DOMAIN_USA
- elif rights == 'pd-us':
+ elif rights == "pd-us":
return RightsStatus.PUBLIC_DOMAIN_USA
- elif rights.startswith('public domain'):
+ elif rights.startswith("public domain"):
return RightsStatus.PUBLIC_DOMAIN_UNKNOWN
- elif rights.startswith('copyrighted.'):
+ elif rights.startswith("copyrighted."):
return RightsStatus.IN_COPYRIGHT
- elif rights == 'cc0':
+ elif rights == "cc0":
return RightsStatus.CC0
- elif rights == 'cc by':
+ elif rights == "cc by":
return RightsStatus.CC_BY
- elif rights == 'cc by-sa':
+ elif rights == "cc by-sa":
return RightsStatus.CC_BY_SA
- elif rights == 'cc by-nd':
+ elif rights == "cc by-nd":
return RightsStatus.CC_BY_ND
- elif rights == 'cc by-nc':
+ elif rights == "cc by-nc":
return RightsStatus.CC_BY_NC
- elif rights == 'cc by-nc-sa':
+ elif rights == "cc by-nc-sa":
return RightsStatus.CC_BY_NC_SA
- elif rights == 'cc by-nc-nd':
+ elif rights == "cc by-nc-nd":
return RightsStatus.CC_BY_NC_ND
- elif (rights in RightsStatus.OPEN_ACCESS
- or rights == RightsStatus.IN_COPYRIGHT):
+ elif rights in RightsStatus.OPEN_ACCESS or rights == RightsStatus.IN_COPYRIGHT:
return rights
else:
return RightsStatus.UNKNOWN
diff --git a/model/listeners.py b/model/listeners.py
index 1ad93eb72..0b69b519f 100644
--- a/model/listeners.py
+++ b/model/listeners.py
@@ -1,40 +1,28 @@
# encoding: utf-8
import datetime
-from sqlalchemy import (
- event,
- text,
-)
from pdb import set_trace
+from threading import RLock
+
+from sqlalchemy import event, text
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.session import Session
-from threading import RLock
-from pdb import set_trace
-from . import (
- Base,
-)
-from .admin import (
- Admin,
- AdminRole,
-)
-from .datasource import DataSource
+
+from ..config import Configuration
+from ..util.datetime_helpers import to_utc, utc_now
+from . import Base
+from .admin import Admin, AdminRole
from .classification import Genre
from .collection import Collection
-from ..config import Configuration
-from .configuration import (
- ConfigurationSetting,
- ExternalIntegration,
-)
+from .configuration import ConfigurationSetting, ExternalIntegration
+from .datasource import DataSource
from .library import Library
-from .licensing import (
- DeliveryMechanism,
- LicensePool,
-)
+from .licensing import DeliveryMechanism, LicensePool
from .work import Work
-from ..util.datetime_helpers import to_utc, utc_now
-
site_configuration_has_changed_lock = RLock()
+
+
def site_configuration_has_changed(_db, cooldown=1):
"""Call this whenever you want to indicate that the site configuration
has changed and needs to be reloaded.
@@ -62,6 +50,7 @@ def site_configuration_has_changed(_db, cooldown=1):
finally:
site_configuration_has_changed_lock.release()
+
def _site_configuration_has_changed(_db, cooldown=1):
"""Actually changes the timestamp on the site configuration."""
now = utc_now()
@@ -79,29 +68,30 @@ def _site_configuration_has_changed(_db, cooldown=1):
# Update the timestamp.
now = utc_now()
- earlier = now-datetime.timedelta(seconds=cooldown)
+ earlier = now - datetime.timedelta(seconds=cooldown)
sql = "UPDATE timestamps SET finish=(:finish at time zone 'utc') WHERE service=:service AND collection_id IS NULL AND finish<=(:earlier at time zone 'utc');"
_db.execute(
text(sql),
- dict(service=Configuration.SITE_CONFIGURATION_CHANGED,
- finish=now, earlier=earlier)
+ dict(
+ service=Configuration.SITE_CONFIGURATION_CHANGED,
+ finish=now,
+ earlier=earlier,
+ ),
)
# Update the Configuration's record of when the configuration
# was updated. This will update our local record immediately
# without requiring a trip to the database.
- Configuration.site_configuration_last_update(
- _db, known_value=now
- )
+ Configuration.site_configuration_last_update(_db, known_value=now)
+
def directly_modified(obj):
"""Return True only if `obj` has itself been modified, as opposed to
having an object added or removed to one of its associated
collections.
"""
- return Session.object_session(obj).is_modified(
- obj, include_collections=False
- )
+ return Session.object_session(obj).is_modified(obj, include_collections=False)
+
# Most of the time, we can know whether a change to the database is
# likely to require that the application reload the portion of the
@@ -112,97 +102,107 @@ def directly_modified(obj):
# should trigger a ConfigurationSetting reload -- that needs to be
# handled on the application level -- but it should be good enough to
# catch most that slip through the cracks.
-@event.listens_for(Collection.children, 'append')
-@event.listens_for(Collection.children, 'remove')
-@event.listens_for(Collection.libraries, 'append')
-@event.listens_for(Collection.libraries, 'remove')
-@event.listens_for(ExternalIntegration.settings, 'append')
-@event.listens_for(ExternalIntegration.settings, 'remove')
-@event.listens_for(Library.integrations, 'append')
-@event.listens_for(Library.integrations, 'remove')
-@event.listens_for(Library.settings, 'append')
-@event.listens_for(Library.settings, 'remove')
+@event.listens_for(Collection.children, "append")
+@event.listens_for(Collection.children, "remove")
+@event.listens_for(Collection.libraries, "append")
+@event.listens_for(Collection.libraries, "remove")
+@event.listens_for(ExternalIntegration.settings, "append")
+@event.listens_for(ExternalIntegration.settings, "remove")
+@event.listens_for(Library.integrations, "append")
+@event.listens_for(Library.integrations, "remove")
+@event.listens_for(Library.settings, "append")
+@event.listens_for(Library.settings, "remove")
def configuration_relevant_collection_change(target, value, initiator):
site_configuration_has_changed(target)
-@event.listens_for(Library, 'after_insert')
-@event.listens_for(Library, 'after_delete')
-@event.listens_for(ExternalIntegration, 'after_insert')
-@event.listens_for(ExternalIntegration, 'after_delete')
-@event.listens_for(Collection, 'after_insert')
-@event.listens_for(Collection, 'after_delete')
-@event.listens_for(ConfigurationSetting, 'after_insert')
-@event.listens_for(ConfigurationSetting, 'after_delete')
+
+@event.listens_for(Library, "after_insert")
+@event.listens_for(Library, "after_delete")
+@event.listens_for(ExternalIntegration, "after_insert")
+@event.listens_for(ExternalIntegration, "after_delete")
+@event.listens_for(Collection, "after_insert")
+@event.listens_for(Collection, "after_delete")
+@event.listens_for(ConfigurationSetting, "after_insert")
+@event.listens_for(ConfigurationSetting, "after_delete")
def configuration_relevant_lifecycle_event(mapper, connection, target):
site_configuration_has_changed(target)
-@event.listens_for(Library, 'after_update')
-@event.listens_for(ExternalIntegration, 'after_update')
-@event.listens_for(Collection, 'after_update')
-@event.listens_for(ConfigurationSetting, 'after_update')
+
+@event.listens_for(Library, "after_update")
+@event.listens_for(ExternalIntegration, "after_update")
+@event.listens_for(Collection, "after_update")
+@event.listens_for(ConfigurationSetting, "after_update")
def configuration_relevant_update(mapper, connection, target):
if directly_modified(target):
site_configuration_has_changed(target)
-@event.listens_for(Admin, 'after_insert')
-@event.listens_for(Admin, 'after_delete')
-@event.listens_for(Admin, 'after_update')
+
+@event.listens_for(Admin, "after_insert")
+@event.listens_for(Admin, "after_delete")
+@event.listens_for(Admin, "after_update")
def refresh_admin_cache(mapper, connection, target):
# The next time someone tries to access an Admin,
# the cache will be repopulated.
Admin.reset_cache()
-@event.listens_for(AdminRole, 'after_insert')
-@event.listens_for(AdminRole, 'after_delete')
-@event.listens_for(AdminRole, 'after_update')
+
+@event.listens_for(AdminRole, "after_insert")
+@event.listens_for(AdminRole, "after_delete")
+@event.listens_for(AdminRole, "after_update")
def refresh_admin_role_cache(mapper, connection, target):
# The next time someone tries to access an AdminRole,
# the cache will be repopulated.
AdminRole.reset_cache()
-@event.listens_for(Collection, 'after_insert')
-@event.listens_for(Collection, 'after_delete')
-@event.listens_for(Collection, 'after_update')
+
+@event.listens_for(Collection, "after_insert")
+@event.listens_for(Collection, "after_delete")
+@event.listens_for(Collection, "after_update")
def refresh_collection_cache(mapper, connection, target):
# The next time someone tries to access a Collection,
# the cache will be repopulated.
Collection.reset_cache()
-@event.listens_for(ConfigurationSetting, 'after_insert')
-@event.listens_for(ConfigurationSetting, 'after_delete')
-@event.listens_for(ConfigurationSetting, 'after_update')
+
+@event.listens_for(ConfigurationSetting, "after_insert")
+@event.listens_for(ConfigurationSetting, "after_delete")
+@event.listens_for(ConfigurationSetting, "after_update")
def refresh_configuration_settings(mapper, connection, target):
# The next time someone tries to access a configuration setting,
# the cache will be repopulated.
ConfigurationSetting.reset_cache()
-@event.listens_for(DataSource, 'after_insert')
-@event.listens_for(DataSource, 'after_delete')
-@event.listens_for(DataSource, 'after_update')
+
+@event.listens_for(DataSource, "after_insert")
+@event.listens_for(DataSource, "after_delete")
+@event.listens_for(DataSource, "after_update")
def refresh_datasource_cache(mapper, connection, target):
# The next time someone tries to access a DataSource,
# the cache will be repopulated.
DataSource.reset_cache()
-@event.listens_for(DeliveryMechanism, 'after_insert')
-@event.listens_for(DeliveryMechanism, 'after_delete')
-@event.listens_for(DeliveryMechanism, 'after_update')
+
+@event.listens_for(DeliveryMechanism, "after_insert")
+@event.listens_for(DeliveryMechanism, "after_delete")
+@event.listens_for(DeliveryMechanism, "after_update")
def refresh_datasource_cache(mapper, connection, target):
# The next time someone tries to access a DeliveryMechanism,
# the cache will be repopulated.
DeliveryMechanism.reset_cache()
-@event.listens_for(ExternalIntegration, 'after_insert')
-@event.listens_for(ExternalIntegration, 'after_delete')
-@event.listens_for(ExternalIntegration, 'after_update')
+
+@event.listens_for(ExternalIntegration, "after_insert")
+@event.listens_for(ExternalIntegration, "after_delete")
+@event.listens_for(ExternalIntegration, "after_update")
def refresh_datasource_cache(mapper, connection, target):
# The next time someone tries to access an ExternalIntegration,
# the cache will be repopulated.
ExternalIntegration.reset_cache()
-@event.listens_for(Genre, 'after_insert')
-@event.listens_for(Genre, 'after_delete')
-@event.listens_for(Genre, 'after_update')
+
+@event.listens_for(Genre, "after_insert")
+@event.listens_for(Genre, "after_delete")
+@event.listens_for(Genre, "after_update")
def refresh_genre_cache(mapper, connection, target):
# The next time someone tries to access a genre,
# the cache will be repopulated.
@@ -211,21 +211,23 @@ def refresh_genre_cache(mapper, connection, target):
# site is brought up, but just in case.
Genre.reset_cache()
-@event.listens_for(Library, 'after_insert')
-@event.listens_for(Library, 'after_delete')
-@event.listens_for(Library, 'after_update')
+
+@event.listens_for(Library, "after_insert")
+@event.listens_for(Library, "after_delete")
+@event.listens_for(Library, "after_update")
def refresh_library_cache(mapper, connection, target):
# The next time someone tries to access a library,
# the cache will be repopulated.
Library.reset_cache()
+
# When a pool gets a work and a presentation edition for the first time,
# the work should be added to any custom lists associated with the pool's
# collection.
# In some cases, the work may be generated before the presentation edition.
# Then we need to add it when the work gets a presentation edition.
-@event.listens_for(LicensePool.work_id, 'set')
-@event.listens_for(Work.presentation_edition_id, 'set')
+@event.listens_for(LicensePool.work_id, "set")
+@event.listens_for(Work.presentation_edition_id, "set")
def add_work_to_customlists_for_collection(pool_or_work, value, oldvalue, initiator):
if isinstance(pool_or_work, LicensePool):
work = pool_or_work.work
@@ -234,7 +236,12 @@ def add_work_to_customlists_for_collection(pool_or_work, value, oldvalue, initia
work = pool_or_work
pools = work.license_pools
- if (not oldvalue or oldvalue is NO_VALUE) and value and work and work.presentation_edition:
+ if (
+ (not oldvalue or oldvalue is NO_VALUE)
+ and value
+ and work
+ and work.presentation_edition
+ ):
for pool in pools:
if not pool.collection:
# This shouldn't happen, but don't crash if it does --
@@ -248,18 +255,20 @@ def add_work_to_customlists_for_collection(pool_or_work, value, oldvalue, initia
# second one.
list.add_entry(work, featured=True, update_external_index=False)
+
# Certain ORM events, however they occur, indicate that a work's
# external index needs updating.
-@event.listens_for(Work.license_pools, 'append')
-@event.listens_for(Work.license_pools, 'remove')
+
+@event.listens_for(Work.license_pools, "append")
+@event.listens_for(Work.license_pools, "remove")
def licensepool_removed_from_work(target, value, initiator):
- """When a Work gains or loses a LicensePool, it needs to be reindexed.
- """
+ """When a Work gains or loses a LicensePool, it needs to be reindexed."""
if target:
target.external_index_needs_updating()
-@event.listens_for(LicensePool, 'after_delete')
+
+@event.listens_for(LicensePool, "after_delete")
def licensepool_deleted(mapper, connection, target):
"""A LicensePool is deleted only when its collection is deleted.
If this happens, we need to keep the Work's index up to date.
@@ -268,7 +277,8 @@ def licensepool_deleted(mapper, connection, target):
if work:
record = work.external_index_needs_updating()
-@event.listens_for(LicensePool.collection_id, 'set')
+
+@event.listens_for(LicensePool.collection_id, "set")
def licensepool_collection_change(target, value, oldvalue, initiator):
"""A LicensePool should never change collections, but if it is,
we need to keep the search index up to date.
@@ -280,8 +290,9 @@ def licensepool_collection_change(target, value, oldvalue, initiator):
return
work.external_index_needs_updating()
-@event.listens_for(LicensePool.open_access, 'set')
-@event.listens_for(LicensePool.self_hosted, 'set')
+
+@event.listens_for(LicensePool.open_access, "set")
+@event.listens_for(LicensePool.self_hosted, "set")
def licensepool_storage_status_change(target, value, oldvalue, initiator):
"""A Work may need to have its search document re-indexed if one of
its LicensePools changes its open-access status.
@@ -295,7 +306,8 @@ def licensepool_storage_status_change(target, value, oldvalue, initiator):
return
work.external_index_needs_updating()
-@event.listens_for(Work.last_update_time, 'set')
+
+@event.listens_for(Work.last_update_time, "set")
def last_update_time_change(target, value, oldvalue, initator):
"""A Work needs to have its search document re-indexed whenever its
last_update_time changes.
diff --git a/model/measurement.py b/model/measurement.py
index 5f571eac0..e0bad5182 100644
--- a/model/measurement.py
+++ b/model/measurement.py
@@ -2,26 +2,21 @@
# Measurement
+import bisect
+import logging
+
+from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, Unicode
+
from . import Base
from .constants import DataSourceConstants
-import bisect
-import logging
-from sqlalchemy import (
- Boolean,
- Column,
- DateTime,
- Float,
- ForeignKey,
- Integer,
- Unicode,
-)
class Measurement(Base):
"""A measurement of some numeric quantity associated with a
Identifier.
"""
- __tablename__ = 'measurements'
+
+ __tablename__ = "measurements"
# Some common measurement types
POPULARITY = "http://librarysimplified.org/terms/rel/popularity"
@@ -33,7 +28,9 @@ class Measurement(Base):
PAGE_COUNT = "https://schema.org/numberOfPages"
AWARDS = "http://librarysimplified.org/terms/rel/awards"
- GUTENBERG_FAVORITE = "http://librarysimplified.org/terms/rel/lists/gutenberg-favorite"
+ GUTENBERG_FAVORITE = (
+ "http://librarysimplified.org/terms/rel/lists/gutenberg-favorite"
+ )
# We have a number of ways of measuring popularity: by an opaque
# number such as Amazon's Sales Rank, or by a directly measured
@@ -56,41 +53,644 @@ class Measurement(Base):
# A book may have a popularity score derived from an opaque
# measure of 'popularity' from some other source.
#
- POPULARITY : {
+ POPULARITY: {
# Overdrive provides a 'popularity' score for each book.
- DataSourceConstants.OVERDRIVE : [1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 9, 9, 10, 10, 11, 12, 13, 14, 15, 15, 16, 18, 19, 20, 21, 22, 24, 25, 26, 28, 30, 31, 33, 35, 37, 39, 41, 43, 46, 48, 51, 53, 56, 59, 63, 66, 70, 74, 78, 82, 87, 92, 97, 102, 108, 115, 121, 128, 135, 142, 150, 159, 168, 179, 190, 202, 216, 230, 245, 260, 277, 297, 319, 346, 372, 402, 436, 478, 521, 575, 632, 702, 777, 861, 965, 1100, 1248, 1428, 1665, 2020, 2560, 3535, 5805],
-
+ DataSourceConstants.OVERDRIVE: [
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 3,
+ 3,
+ 4,
+ 4,
+ 5,
+ 5,
+ 6,
+ 6,
+ 7,
+ 7,
+ 8,
+ 9,
+ 9,
+ 10,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 15,
+ 16,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 24,
+ 25,
+ 26,
+ 28,
+ 30,
+ 31,
+ 33,
+ 35,
+ 37,
+ 39,
+ 41,
+ 43,
+ 46,
+ 48,
+ 51,
+ 53,
+ 56,
+ 59,
+ 63,
+ 66,
+ 70,
+ 74,
+ 78,
+ 82,
+ 87,
+ 92,
+ 97,
+ 102,
+ 108,
+ 115,
+ 121,
+ 128,
+ 135,
+ 142,
+ 150,
+ 159,
+ 168,
+ 179,
+ 190,
+ 202,
+ 216,
+ 230,
+ 245,
+ 260,
+ 277,
+ 297,
+ 319,
+ 346,
+ 372,
+ 402,
+ 436,
+ 478,
+ 521,
+ 575,
+ 632,
+ 702,
+ 777,
+ 861,
+ 965,
+ 1100,
+ 1248,
+ 1428,
+ 1665,
+ 2020,
+ 2560,
+ 3535,
+ 5805,
+ ],
# Amazon Sales Rank - lower means more sales.
- DataSourceConstants.AMAZON : [14937330, 1974074, 1702163, 1553600, 1432635, 1327323, 1251089, 1184878, 1131998, 1075720, 1024272, 978514, 937726, 898606, 868506, 837523, 799879, 770211, 743194, 718052, 693932, 668030, 647121, 627642, 609399, 591843, 575970, 559942, 540713, 524397, 511183, 497576, 483884, 470850, 458438, 444475, 432528, 420088, 408785, 398420, 387895, 377244, 366837, 355406, 344288, 333747, 324280, 315002, 305918, 296420, 288522, 279185, 270824, 262801, 253865, 246224, 238239, 230537, 222611, 215989, 208641, 202597, 195817, 188939, 181095, 173967, 166058, 160032, 153526, 146706, 139981, 133348, 126689, 119201, 112447, 106795, 101250, 96534, 91052, 85837, 80619, 75292, 69957, 65075, 59901, 55616, 51624, 47598, 43645, 39403, 35645, 31795, 27990, 24496, 20780, 17740, 14102, 10498, 7090, 3861],
-
+ DataSourceConstants.AMAZON: [
+ 14937330,
+ 1974074,
+ 1702163,
+ 1553600,
+ 1432635,
+ 1327323,
+ 1251089,
+ 1184878,
+ 1131998,
+ 1075720,
+ 1024272,
+ 978514,
+ 937726,
+ 898606,
+ 868506,
+ 837523,
+ 799879,
+ 770211,
+ 743194,
+ 718052,
+ 693932,
+ 668030,
+ 647121,
+ 627642,
+ 609399,
+ 591843,
+ 575970,
+ 559942,
+ 540713,
+ 524397,
+ 511183,
+ 497576,
+ 483884,
+ 470850,
+ 458438,
+ 444475,
+ 432528,
+ 420088,
+ 408785,
+ 398420,
+ 387895,
+ 377244,
+ 366837,
+ 355406,
+ 344288,
+ 333747,
+ 324280,
+ 315002,
+ 305918,
+ 296420,
+ 288522,
+ 279185,
+ 270824,
+ 262801,
+ 253865,
+ 246224,
+ 238239,
+ 230537,
+ 222611,
+ 215989,
+ 208641,
+ 202597,
+ 195817,
+ 188939,
+ 181095,
+ 173967,
+ 166058,
+ 160032,
+ 153526,
+ 146706,
+ 139981,
+ 133348,
+ 126689,
+ 119201,
+ 112447,
+ 106795,
+ 101250,
+ 96534,
+ 91052,
+ 85837,
+ 80619,
+ 75292,
+ 69957,
+ 65075,
+ 59901,
+ 55616,
+ 51624,
+ 47598,
+ 43645,
+ 39403,
+ 35645,
+ 31795,
+ 27990,
+ 24496,
+ 20780,
+ 17740,
+ 14102,
+ 10498,
+ 7090,
+ 3861,
+ ],
# This is as measured by the criteria defined in
# ContentCafeSOAPClient.estimate_popularity(), in which
# popularity is the maximum of a) the largest number of books
# ordered in a single month within the last year, or b)
# one-half the largest number of books ever ordered in a
# single month.
- DataSourceConstants.CONTENT_CAFE : [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 7, 8, 9, 10, 11, 14, 18, 25, 41, 125, 387],
+ DataSourceConstants.CONTENT_CAFE: [
+ 0,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 5,
+ 5,
+ 5,
+ 5,
+ 6,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 14,
+ 18,
+ 25,
+ 41,
+ 125,
+ 387,
+ ],
},
-
# The popularity of a book may be deduced from the number of
# libraries with that book in their collections.
#
- HOLDINGS : {
- DataSourceConstants.OCLC : [1, 8, 12, 16, 20, 24, 28, 33, 37, 43, 49, 55, 62, 70, 78, 86, 94, 102, 110, 118, 126, 134, 143, 151, 160, 170, 178, 187, 196, 205, 214, 225, 233, 243, 253, 263, 275, 286, 298, 310, 321, 333, 345, 358, 370, 385, 398, 413, 427, 443, 458, 475, 492, 511, 530, 549, 567, 586, 606, 627, 647, 669, 693, 718, 741, 766, 794, 824, 852, 882, 914, 947, 980, 1018, 1056, 1098, 1142, 1188, 1235, 1288, 1347, 1410, 1477, 1545, 1625, 1714, 1812, 1923, 2039, 2164, 2304, 2479, 2671, 2925, 3220, 3565, 3949, 4476, 5230, 7125, 34811],
+ HOLDINGS: {
+ DataSourceConstants.OCLC: [
+ 1,
+ 8,
+ 12,
+ 16,
+ 20,
+ 24,
+ 28,
+ 33,
+ 37,
+ 43,
+ 49,
+ 55,
+ 62,
+ 70,
+ 78,
+ 86,
+ 94,
+ 102,
+ 110,
+ 118,
+ 126,
+ 134,
+ 143,
+ 151,
+ 160,
+ 170,
+ 178,
+ 187,
+ 196,
+ 205,
+ 214,
+ 225,
+ 233,
+ 243,
+ 253,
+ 263,
+ 275,
+ 286,
+ 298,
+ 310,
+ 321,
+ 333,
+ 345,
+ 358,
+ 370,
+ 385,
+ 398,
+ 413,
+ 427,
+ 443,
+ 458,
+ 475,
+ 492,
+ 511,
+ 530,
+ 549,
+ 567,
+ 586,
+ 606,
+ 627,
+ 647,
+ 669,
+ 693,
+ 718,
+ 741,
+ 766,
+ 794,
+ 824,
+ 852,
+ 882,
+ 914,
+ 947,
+ 980,
+ 1018,
+ 1056,
+ 1098,
+ 1142,
+ 1188,
+ 1235,
+ 1288,
+ 1347,
+ 1410,
+ 1477,
+ 1545,
+ 1625,
+ 1714,
+ 1812,
+ 1923,
+ 2039,
+ 2164,
+ 2304,
+ 2479,
+ 2671,
+ 2925,
+ 3220,
+ 3565,
+ 3949,
+ 4476,
+ 5230,
+ 7125,
+ 34811,
+ ],
},
-
# The popularity of a book may be deduced from the number of
# published editions of that book.
#
- PUBLISHED_EDITIONS : {
- DataSourceConstants.OCLC : [1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 12, 12, 12, 13, 13, 14, 14, 14, 15, 15, 16, 16, 17, 18, 19, 19, 20, 21, 22, 24, 25, 26, 28, 30, 32, 34, 36, 39, 42, 46, 50, 56, 64, 73, 87, 112, 156, 281, 2812],
+ PUBLISHED_EDITIONS: {
+ DataSourceConstants.OCLC: [
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 4,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 6,
+ 6,
+ 6,
+ 6,
+ 6,
+ 6,
+ 6,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 8,
+ 8,
+ 8,
+ 8,
+ 8,
+ 8,
+ 9,
+ 9,
+ 9,
+ 9,
+ 10,
+ 10,
+ 10,
+ 10,
+ 11,
+ 11,
+ 11,
+ 12,
+ 12,
+ 12,
+ 13,
+ 13,
+ 14,
+ 14,
+ 14,
+ 15,
+ 15,
+ 16,
+ 16,
+ 17,
+ 18,
+ 19,
+ 19,
+ 20,
+ 21,
+ 22,
+ 24,
+ 25,
+ 26,
+ 28,
+ 30,
+ 32,
+ 34,
+ 36,
+ 39,
+ 42,
+ 46,
+ 50,
+ 56,
+ 64,
+ 73,
+ 87,
+ 112,
+ 156,
+ 281,
+ 2812,
+ ],
},
-
# The popularity of a book may be deduced from the number of
# recent downloads from some site.
#
- DOWNLOADS : {
- DataSourceConstants.GUTENBERG : [0, 1, 2, 3, 4, 5, 5, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 12, 12, 12, 13, 14, 14, 15, 15, 16, 16, 17, 18, 18, 19, 19, 20, 21, 21, 22, 23, 23, 24, 25, 26, 27, 28, 28, 29, 30, 32, 33, 34, 35, 36, 37, 38, 40, 41, 43, 45, 46, 48, 50, 52, 55, 57, 60, 62, 65, 69, 72, 76, 79, 83, 87, 93, 99, 106, 114, 122, 130, 140, 152, 163, 179, 197, 220, 251, 281, 317, 367, 432, 501, 597, 658, 718, 801, 939, 1065, 1286, 1668, 2291, 4139],
+ DOWNLOADS: {
+ DataSourceConstants.GUTENBERG: [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 5,
+ 6,
+ 7,
+ 7,
+ 8,
+ 8,
+ 9,
+ 9,
+ 10,
+ 10,
+ 11,
+ 12,
+ 12,
+ 12,
+ 13,
+ 14,
+ 14,
+ 15,
+ 15,
+ 16,
+ 16,
+ 17,
+ 18,
+ 18,
+ 19,
+ 19,
+ 20,
+ 21,
+ 21,
+ 22,
+ 23,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 28,
+ 29,
+ 30,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 40,
+ 41,
+ 43,
+ 45,
+ 46,
+ 48,
+ 50,
+ 52,
+ 55,
+ 57,
+ 60,
+ 62,
+ 65,
+ 69,
+ 72,
+ 76,
+ 79,
+ 83,
+ 87,
+ 93,
+ 99,
+ 106,
+ 114,
+ 122,
+ 130,
+ 140,
+ 152,
+ 163,
+ 179,
+ 197,
+ 220,
+ 251,
+ 281,
+ 317,
+ 367,
+ 432,
+ 501,
+ 597,
+ 658,
+ 718,
+ 801,
+ 939,
+ 1065,
+ 1286,
+ 1668,
+ 2291,
+ 4139,
+ ],
},
}
@@ -98,8 +698,8 @@ class Measurement(Base):
# to another. Once we know the scale used by a given data source, we can
# scale its ratings to the 0..1 range and create a 'quality' rating.
RATING_SCALES = {
- DataSourceConstants.OVERDRIVE : [1, 5],
- DataSourceConstants.AMAZON : [1, 5],
+ DataSourceConstants.OVERDRIVE: [1, 5],
+ DataSourceConstants.AMAZON: [1, 5],
DataSourceConstants.UNGLUE_IT: [1, 5],
DataSourceConstants.NOVELIST: [0, 5],
DataSourceConstants.LIBRARY_STAFF: [1, 5],
@@ -108,12 +708,10 @@ class Measurement(Base):
id = Column(Integer, primary_key=True)
# A Measurement is always associated with some Identifier.
- identifier_id = Column(
- Integer, ForeignKey('identifiers.id'), index=True)
+ identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
# A Measurement always comes from some DataSource.
- data_source_id = Column(
- Integer, ForeignKey('datasources.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
# The quantity being measured.
quantity_measured = Column(Unicode, index=True)
@@ -138,18 +736,22 @@ class Measurement(Base):
def __repr__(self):
return "%s(%r)=%s (norm=%.2f)" % (
- self.quantity_measured, self.identifier, self.value,
- self.normalized_value or 0)
+ self.quantity_measured,
+ self.identifier,
+ self.value,
+ self.normalized_value or 0,
+ )
@classmethod
- def overall_quality(cls, measurements, popularity_weight=0.3,
- rating_weight=0.7, default_value=0):
+ def overall_quality(
+ cls, measurements, popularity_weight=0.3, rating_weight=0.7, default_value=0
+ ):
"""Turn a bunch of measurements into an overall measure of quality."""
if popularity_weight + rating_weight != 1.0:
raise ValueError(
- "Popularity weight and rating weight must sum to 1! (%.2f + %.2f)" % (
- popularity_weight, rating_weight)
- )
+ "Popularity weight and rating weight must sum to 1! (%.2f + %.2f)"
+ % (popularity_weight, rating_weight)
+ )
popularities = []
ratings = []
qualities = []
@@ -193,7 +795,11 @@ def overall_quality(cls, measurements, popularity_weight=0.3,
final = (popularity * popularity_weight) + (rating * rating_weight)
logging.debug(
"(%.2f * %.2f) + (%.2f * %.2f) = %.2f",
- popularity, popularity_weight, rating, rating_weight, final
+ popularity,
+ popularity_weight,
+ rating,
+ rating_weight,
+ final,
)
if quality:
logging.debug("Popularity+Rating: %.2f, Quality: %.2f" % (final, quality))
@@ -210,7 +816,7 @@ def _average_normalized_value(cls, measurements):
if v is None:
continue
num_measurements += m.weight
- measurement_total += (v * m.weight)
+ measurement_total += v * m.weight
if num_measurements:
return measurement_total / num_measurements
else:
@@ -229,20 +835,20 @@ def normalized_value(self):
elif self.data_source.name == DataSourceConstants.METADATA_WRANGLER:
# Data from the metadata wrangler comes in pre-normalized.
self._normalized_value = self.value
- elif (self.quantity_measured == self.RATING
- and self.data_source.name in self.RATING_SCALES):
+ elif (
+ self.quantity_measured == self.RATING
+ and self.data_source.name in self.RATING_SCALES
+ ):
# Ratings need to be normalized from a scale that depends
# on the data source (e.g. Amazon's 1-5 stars) to a 0..1 scale.
scale_min, scale_max = self.RATING_SCALES[self.data_source.name]
- width = float(scale_max-scale_min)
- value = self.value-scale_min
+ width = float(scale_max - scale_min)
+ value = self.value - scale_min
self._normalized_value = value / width
elif self.quantity_measured in self.PERCENTILE_SCALES:
# Other measured quantities need to be normalized using
# a percentile scale determined emperically.
- by_data_source = self.PERCENTILE_SCALES[
- self.quantity_measured
- ]
+ by_data_source = self.PERCENTILE_SCALES[self.quantity_measured]
if not self.data_source.name in by_data_source:
# We don't know how to normalize measurements from
# this data source. Ignore this data.
diff --git a/model/patron.py b/model/patron.py
index 363f37441..e17ba5a7f 100644
--- a/model/patron.py
+++ b/model/patron.py
@@ -3,6 +3,9 @@
import datetime
import logging
+import uuid
+
+from psycopg2.extras import NumericRange
from sqlalchemy import (
Boolean,
Column,
@@ -15,24 +18,18 @@
Unicode,
UniqueConstraint,
)
-from psycopg2.extras import NumericRange
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.orm.session import Session
-import uuid
-from . import (
- Base,
- get_one_or_create,
- numericrange_to_tuple
-)
-from .credential import Credential
from ..classifier import Classifier
from ..user_profile import ProfileStorage
from ..util.datetime_helpers import utc_now
+from . import Base, get_one_or_create, numericrange_to_tuple
+from .credential import Credential
-class LoanAndHoldMixin(object):
+class LoanAndHoldMixin(object):
@property
def work(self):
"""Try to find the corresponding work for this Loan/Hold."""
@@ -53,18 +50,16 @@ def library(self):
# If this Loan/Hold belongs to a external patron, there may be no library.
return None
+
class Patron(Base):
- __tablename__ = 'patrons'
+ __tablename__ = "patrons"
id = Column(Integer, primary_key=True)
# Each patron is the patron _of_ one particular library. An
# individual human being may patronize multiple libraries, but
# they will have a different patron account at each one.
- library_id = Column(
- Integer, ForeignKey('libraries.id'), index=True,
- nullable=False
- )
+ library_id = Column(Integer, ForeignKey("libraries.id"), index=True, nullable=False)
# The patron's permanent unique identifier in an external library
# system, probably never seen by the patron.
@@ -128,8 +123,9 @@ class Patron(Base):
# Whether or not the patron wants their annotations synchronized
# across devices (which requires storing those annotations on a
# library server).
- _synchronize_annotations = Column(Boolean, default=None,
- name="synchronize_annotations")
+ _synchronize_annotations = Column(
+ Boolean, default=None, name="synchronize_annotations"
+ )
# If the circulation manager is set up to associate a patron's
# neighborhood with circulation events, and it would be
@@ -152,18 +148,23 @@ class Patron(Base):
# be an explicit decision of the ILS integration code.
cached_neighborhood = Column(Unicode, default=None, index=True)
- loans = relationship('Loan', backref='patron', cascade='delete')
- holds = relationship('Hold', backref='patron', cascade='delete')
+ loans = relationship("Loan", backref="patron", cascade="delete")
+ holds = relationship("Hold", backref="patron", cascade="delete")
- annotations = relationship('Annotation', backref='patron', order_by="desc(Annotation.timestamp)", cascade='delete')
+ annotations = relationship(
+ "Annotation",
+ backref="patron",
+ order_by="desc(Annotation.timestamp)",
+ cascade="delete",
+ )
# One Patron can have many associated Credentials.
credentials = relationship("Credential", backref="patron", cascade="delete")
__table_args__ = (
- UniqueConstraint('library_id', 'username'),
- UniqueConstraint('library_id', 'authorization_identifier'),
- UniqueConstraint('library_id', 'external_identifier'),
+ UniqueConstraint("library_id", "username"),
+ UniqueConstraint("library_id", "authorization_identifier"),
+ UniqueConstraint("library_id", "external_identifier"),
)
# A patron with borrowing privileges should have their local
@@ -182,9 +183,11 @@ def date(d):
if isinstance(d, datetime.datetime):
return d.date()
return d
- return '<Patron authentication_identifier=%s expires=%s sync=%s>' % (
- self.authorization_identifier, date(self.authorization_expires),
- date(self.last_external_sync)
+
+ return "<Patron authentication_identifier=%s expires=%s sync=%s>" % (
+ self.authorization_identifier,
+ date(self.authorization_expires),
+ date(self.last_external_sync),
)
def identifier_to_remote_service(self, remote_data_source, generator=None):
@@ -194,21 +197,27 @@ def identifier_to_remote_service(self, remote_data_source, generator=None):
DataSource) corresponding to the remote service.
"""
_db = Session.object_session(self)
+
def refresh(credential):
if generator and callable(generator):
identifier = generator()
else:
identifier = str(uuid.uuid1())
credential.credential = identifier
+
credential = Credential.lookup(
- _db, remote_data_source, Credential.IDENTIFIER_TO_REMOTE_SERVICE,
- self, refresh, allow_persistent_token=True
+ _db,
+ remote_data_source,
+ Credential.IDENTIFIER_TO_REMOTE_SERVICE,
+ self,
+ refresh,
+ allow_persistent_token=True,
)
return credential.credential
def works_on_loan(self):
db = Session.object_session(self)
- loans = db.query(Loan).filter(Loan.patron==self)
+ loans = db.query(Loan).filter(Loan.patron == self)
return [loan.work for loan in self.loans if loan.work]
def works_on_loan_or_on_hold(self):
@@ -248,9 +257,7 @@ def last_loan_activity_sync(self):
# We have an answer, but it may be so old that we should clear
# it out.
now = utc_now()
- expires = value + datetime.timedelta(
- seconds=self.loan_activity_max_age
- )
+ expires = value + datetime.timedelta(seconds=self.loan_activity_max_age)
if now > expires:
# The value has expired. Clear it out.
value = None
@@ -273,12 +280,10 @@ def synchronize_annotations(self, value):
if value is None:
# A patron cannot decide to go back to the state where
# they hadn't made a decision.
- raise ValueError(
- "synchronize_annotations cannot be unset once set."
- )
+ raise ValueError("synchronize_annotations cannot be unset once set.")
if value is False:
_db = Session.object_session(self)
- qu = _db.query(Annotation).filter(Annotation.patron==self)
+ qu = _db.query(Annotation).filter(Annotation.patron == self)
for annotation in qu:
_db.delete(annotation)
self._synchronize_annotations = value
@@ -302,11 +307,13 @@ def root_lane(self):
_db = Session.object_session(self)
from ..lane import Lane
- qu = _db.query(Lane).filter(
- Lane.library==self.library
- ).filter(
- Lane.root_for_patron_type.any(self.external_type)
- ).order_by(Lane.id)
+
+ qu = (
+ _db.query(Lane)
+ .filter(Lane.library == self.library)
+ .filter(Lane.root_for_patron_type.any(self.external_type))
+ .order_by(Lane.id)
+ )
lanes = qu.all()
if len(lanes) < 1:
# The most common situation -- this patron has no special
@@ -317,8 +324,7 @@ def root_lane(self):
# configuration problem, but we shouldn't make the patron
# pay the price -- just pick the first one.
logging.error(
- "Multiple root lanes found for patron type %s.",
- self.external_type
+ "Multiple root lanes found for patron type %s.", self.external_type
)
return lanes[0]
@@ -355,16 +361,14 @@ def work_is_age_appropriate(self, work_audience, work_target_age):
# are a match for the title's audience and target age.
return any(
self.age_appropriate_match(
- work_audience, work_target_age,
- audience, root.target_age
+ work_audience, work_target_age, audience, root.target_age
)
for audience in root.audiences
)
@classmethod
def age_appropriate_match(
- cls, work_audience, work_target_age,
- reader_audience, reader_age
+ cls, work_audience, work_target_age, reader_audience, reader_age
):
"""Match the audience and target age of a work with that of a reader,
and see whether they are an age-appropriate match.
@@ -393,10 +397,8 @@ def age_appropriate_match(
log = logging.getLogger("Age-appropriate match calculator")
log.debug(
- "Matching work %s/%s to reader %s/%s" % (
- work_audience, work_target_age,
- reader_audience, reader_age
- )
+ "Matching work %s/%s to reader %s/%s"
+ % (work_audience, work_target_age, reader_audience, reader_age)
)
if reader_audience not in Classifier.AUDIENCES_JUVENILE:
@@ -430,9 +432,13 @@ def ensure_tuple(x):
# A YA reader is treated as an adult (with no reading
# restrictions) if they have no associated age range, or their
# age range includes ADULT_AGE_CUTOFF.
- if (reader_audience == Classifier.AUDIENCE_YOUNG_ADULT
- and (reader_age is None
- or (isinstance(reader_age, int) and reader_age >= Classifier.ADULT_AGE_CUTOFF))):
+ if reader_audience == Classifier.AUDIENCE_YOUNG_ADULT and (
+ reader_age is None
+ or (
+ isinstance(reader_age, int)
+ and reader_age >= Classifier.ADULT_AGE_CUTOFF
+ )
+ ):
log.debug("YA reader to be treated as an adult.")
return True
@@ -445,13 +451,16 @@ def ensure_tuple(x):
# At this point we know we have a juvenile reader and a
# juvenile book.
- if (reader_audience == Classifier.AUDIENCE_YOUNG_ADULT
- and work_audience in (Classifier.AUDIENCES_YOUNG_CHILDREN)):
+ if reader_audience == Classifier.AUDIENCE_YOUNG_ADULT and work_audience in (
+ Classifier.AUDIENCES_YOUNG_CHILDREN
+ ):
log.debug("YA reader can access any children's title.")
return True
- if (reader_audience in (Classifier.AUDIENCES_YOUNG_CHILDREN)
- and work_audience == Classifier.AUDIENCE_YOUNG_ADULT):
+ if (
+ reader_audience in (Classifier.AUDIENCES_YOUNG_CHILDREN)
+ and work_audience == Classifier.AUDIENCE_YOUNG_ADULT
+ ):
log.debug("Child reader cannot access any YA title.")
return False
@@ -462,9 +471,7 @@ def ensure_tuple(x):
if work_target_age is None:
# This is a generic children's or YA book with no
# particular target age. Assume it's age appropriate.
- log.debug(
- "Juvenile book with no target age is presumed age-appropriate."
- )
+ log.debug("Juvenile book with no target age is presumed age-appropriate.")
return True
if reader_age is None:
@@ -479,43 +486,50 @@ def ensure_tuple(x):
# The audience for this book matches the patron's
# audience, but the book has a target age that is too high
# for the reader.
- log.debug(
- "Audience matches, but work's target age is too high for reader."
- )
+ log.debug("Audience matches, but work's target age is too high for reader.")
return False
log.debug("Both audience and target age match; it's age-appropriate.")
return True
-Index("ix_patron_library_id_external_identifier", Patron.library_id, Patron.external_identifier)
-Index("ix_patron_library_id_authorization_identifier", Patron.library_id, Patron.authorization_identifier)
+Index(
+ "ix_patron_library_id_external_identifier",
+ Patron.library_id,
+ Patron.external_identifier,
+)
+Index(
+ "ix_patron_library_id_authorization_identifier",
+ Patron.library_id,
+ Patron.authorization_identifier,
+)
Index("ix_patron_library_id_username", Patron.library_id, Patron.username)
+
class Loan(Base, LoanAndHoldMixin):
- __tablename__ = 'loans'
+ __tablename__ = "loans"
id = Column(Integer, primary_key=True)
- patron_id = Column(Integer, ForeignKey('patrons.id'), index=True)
- integration_client_id = Column(Integer, ForeignKey('integrationclients.id'), index=True)
+ patron_id = Column(Integer, ForeignKey("patrons.id"), index=True)
+ integration_client_id = Column(
+ Integer, ForeignKey("integrationclients.id"), index=True
+ )
# A Loan is always associated with a LicensePool.
- license_pool_id = Column(Integer, ForeignKey('licensepools.id'), index=True)
+ license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
# It may also be associated with an individual License if the source
# provides information about individual licenses.
- license_id = Column(Integer, ForeignKey('licenses.id'), index=True, nullable=True)
+ license_id = Column(Integer, ForeignKey("licenses.id"), index=True, nullable=True)
- fulfillment_id = Column(Integer, ForeignKey('licensepooldeliveries.id'))
+ fulfillment_id = Column(Integer, ForeignKey("licensepooldeliveries.id"))
start = Column(DateTime(timezone=True), index=True)
end = Column(DateTime(timezone=True), index=True)
# Some distributors (e.g. Feedbooks) may have an identifier that can
# be used to check the status of a specific Loan.
external_identifier = Column(Unicode, unique=True, nullable=True)
- __table_args__ = (
- UniqueConstraint('patron_id', 'license_pool_id'),
- )
-
+ __table_args__ = (UniqueConstraint("patron_id", "license_pool_id"),)
+
def __lt__(self, other):
return self.id < other.id
@@ -529,26 +543,34 @@ def until(self, default_loan_period):
start = self.start or utc_now()
return start + default_loan_period
+
class Hold(Base, LoanAndHoldMixin):
- """A patron is in line to check out a book.
- """
- __tablename__ = 'holds'
+ """A patron is in line to check out a book."""
+
+ __tablename__ = "holds"
id = Column(Integer, primary_key=True)
- patron_id = Column(Integer, ForeignKey('patrons.id'), index=True)
- integration_client_id = Column(Integer, ForeignKey('integrationclients.id'), index=True)
- license_pool_id = Column(Integer, ForeignKey('licensepools.id'), index=True)
+ patron_id = Column(Integer, ForeignKey("patrons.id"), index=True)
+ integration_client_id = Column(
+ Integer, ForeignKey("integrationclients.id"), index=True
+ )
+ license_pool_id = Column(Integer, ForeignKey("licensepools.id"), index=True)
start = Column(DateTime(timezone=True), index=True)
end = Column(DateTime(timezone=True), index=True)
position = Column(Integer, index=True)
external_identifier = Column(Unicode, unique=True, nullable=True)
-
+
def __lt__(self, other):
return self.id < other.id
@classmethod
def _calculate_until(
- self, start, queue_position, total_licenses, default_loan_period,
- default_reservation_period):
+ self,
+ start,
+ queue_position,
+ total_licenses,
+ default_loan_period,
+ default_reservation_period,
+ ):
"""Helper method for `Hold.until` that can be tested independently.
We have to wait for the available licenses to cycle a
certain number of times before we get a turn.
@@ -576,7 +598,7 @@ def _calculate_until(
# in front of you to get a reservation notification, borrow
# the book at the last minute, and keep the book for the
# maximum allowable time.
- cycle_period = (default_reservation_period + default_loan_period)
+ cycle_period = default_reservation_period + default_loan_period
# This will happen at least once.
cycles = 1
@@ -590,11 +612,10 @@ def _calculate_until(
# they'll wait a while, get a reservation, and then keep
# the book for a while, and so on.
cycles += queue_position // total_licenses
- if (total_licenses > 1 and queue_position % total_licenses == 0):
+ if total_licenses > 1 and queue_position % total_licenses == 0:
cycles -= 1
return start + (cycle_period * cycles)
-
def until(self, default_loan_period, default_reservation_period):
"""Give or estimate the time at which the book will be available
to this patron.
@@ -622,8 +643,12 @@ def until(self, default_loan_period, default_reservation_period):
# end.
position = self.license_pool.patrons_in_hold_queue
return self._calculate_until(
- start, position, licenses_available,
- default_loan_period, default_reservation_period)
+ start,
+ position,
+ licenses_available,
+ default_loan_period,
+ default_reservation_period,
+ )
def update(self, start, end, position):
"""When the book becomes available, position will be 0 and end will be
@@ -638,9 +663,8 @@ def update(self, start, end, position):
if position is not None:
self.position = position
- __table_args__ = (
- UniqueConstraint('patron_id', 'license_pool_id'),
- )
+ __table_args__ = (UniqueConstraint("patron_id", "license_pool_id"),)
+
class Annotation(Base):
# The Web Annotation Data Model defines a basic set of motivations.
@@ -650,18 +674,18 @@ class Annotation(Base):
# We need to define some terms of our own.
LS_NAMESPACE = "http://librarysimplified.org/terms/annotation/"
- IDLING = LS_NAMESPACE + 'idling'
- BOOKMARKING = OA_NAMESPACE + 'bookmarking'
+ IDLING = LS_NAMESPACE + "idling"
+ BOOKMARKING = OA_NAMESPACE + "bookmarking"
MOTIVATIONS = [
IDLING,
BOOKMARKING,
]
- __tablename__ = 'annotations'
+ __tablename__ = "annotations"
id = Column(Integer, primary_key=True)
- patron_id = Column(Integer, ForeignKey('patrons.id'), index=True)
- identifier_id = Column(Integer, ForeignKey('identifiers.id'), index=True)
+ patron_id = Column(Integer, ForeignKey("patrons.id"), index=True)
+ identifier_id = Column(Integer, ForeignKey("identifiers.id"), index=True)
motivation = Column(Unicode, index=True)
timestamp = Column(DateTime(timezone=True), index=True)
active = Column(Boolean, default=True)
@@ -674,19 +698,16 @@ def get_one_or_create(self, _db, patron, *args, **kwargs):
annotation sync turned on.
"""
if not patron.synchronize_annotations:
- raise ValueError(
- "Patron has opted out of synchronizing annotations."
- )
+ raise ValueError("Patron has opted out of synchronizing annotations.")
- return get_one_or_create(
- _db, Annotation, patron=patron, *args, **kwargs
- )
+ return get_one_or_create(_db, Annotation, patron=patron, *args, **kwargs)
def set_inactive(self):
self.active = False
self.content = None
self.timestamp = utc_now()
+
class PatronProfileStorage(ProfileStorage):
"""Interface between a Patron object and the User Profile Management
Protocol.
@@ -713,13 +734,10 @@ def profile_document(self):
patron = self.patron
doc[self.AUTHORIZATION_IDENTIFIER] = patron.authorization_identifier
if patron.authorization_expires:
- doc[self.AUTHORIZATION_EXPIRES] = (
- patron.authorization_expires.strftime("%Y-%m-%dT%H:%M:%SZ")
+ doc[self.AUTHORIZATION_EXPIRES] = patron.authorization_expires.strftime(
+ "%Y-%m-%dT%H:%M:%SZ"
)
- settings = {
- self.SYNCHRONIZE_ANNOTATIONS :
- patron.synchronize_annotations
- }
+ settings = {self.SYNCHRONIZE_ANNOTATIONS: patron.synchronize_annotations}
doc[self.SETTINGS_KEY] = settings
return doc
diff --git a/model/resource.py b/model/resource.py
index 4efdcf53b..2c774a5bb 100644
--- a/model/resource.py
+++ b/model/resource.py
@@ -2,15 +2,19 @@
# Resource, ResourceTransformation, Hyperlink, Representation
-from io import BytesIO
import datetime
import json
import logging
-from hashlib import md5
import os
-from PIL import Image
import re
+import time
+import traceback
+from hashlib import md5
+from io import BytesIO
+from urllib.parse import quote, urlparse, urlsplit
+
import requests
+from PIL import Image
from sqlalchemy import (
Binary,
Column,
@@ -23,22 +27,14 @@
)
from sqlalchemy.dialects.postgresql import JSON
from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import (
- backref,
- relationship,
-)
+from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.expression import or_
-import time
-import traceback
-from urllib.parse import urlparse, urlsplit, quote
-from . import (
- Base,
- get_one,
- get_one_or_create,
-)
from ..config import Configuration
+from ..util.datetime_helpers import utc_now
+from ..util.http import HTTP
+from . import Base, get_one, get_one_or_create
from .constants import (
DataSourceConstants,
IdentifierConstants,
@@ -46,19 +42,15 @@
MediaTypes,
)
from .edition import Edition
-from .licensing import (
- LicensePool,
- LicensePoolDeliveryMechanism,
-)
-from ..util.http import HTTP
-from ..util.datetime_helpers import utc_now
+from .licensing import LicensePool, LicensePoolDeliveryMechanism
+
class Resource(Base):
"""An external resource that may be mirrored locally.
E.g: a cover image, an epub, a description.
"""
- __tablename__ = 'resources'
+ __tablename__ = "resources"
# How many votes is the initial quality estimate worth?
ESTIMATED_QUALITY_WEIGHT = 5
@@ -76,51 +68,56 @@ class Resource(Base):
# Many Editions may choose this resource (as opposed to other
# resources linked to them with rel="image") as their cover image.
- cover_editions = relationship("Edition", backref="cover", foreign_keys=[Edition.cover_id])
+ cover_editions = relationship(
+ "Edition", backref="cover", foreign_keys=[Edition.cover_id]
+ )
# Many Works may use this resource (as opposed to other resources
# linked to them with rel="description") as their summary.
from .work import Work
- summary_works = relationship("Work", backref="summary", foreign_keys=[Work.summary_id])
+
+ summary_works = relationship(
+ "Work", backref="summary", foreign_keys=[Work.summary_id]
+ )
# Many LicensePools (but probably one at most) may use this
# resource in a delivery mechanism.
licensepooldeliverymechanisms = relationship(
- "LicensePoolDeliveryMechanism", backref="resource",
- foreign_keys=[LicensePoolDeliveryMechanism.resource_id]
+ "LicensePoolDeliveryMechanism",
+ backref="resource",
+ foreign_keys=[LicensePoolDeliveryMechanism.resource_id],
)
links = relationship("Hyperlink", backref="resource")
# The DataSource that is the controlling authority for this Resource.
- data_source_id = Column(Integer, ForeignKey('datasources.id'), index=True)
+ data_source_id = Column(Integer, ForeignKey("datasources.id"), index=True)
# An archived Representation of this Resource.
- representation_id = Column(
- Integer, ForeignKey('representations.id'), index=True)
+ representation_id = Column(Integer, ForeignKey("representations.id"), index=True)
# The rights status of this Resource.
- rights_status_id = Column(Integer, ForeignKey('rightsstatus.id'))
+ rights_status_id = Column(Integer, ForeignKey("rightsstatus.id"))
# An optional explanation of the rights status.
rights_explanation = Column(Unicode)
# A Resource may be transformed into many derivatives.
transformations = relationship(
- 'ResourceTransformation',
+ "ResourceTransformation",
primaryjoin="ResourceTransformation.original_id==Resource.id",
foreign_keys=id,
lazy="joined",
- backref=backref('original', uselist=False),
+ backref=backref("original", uselist=False),
uselist=True,
)
# A derivative resource may have one original.
derived_through = relationship(
- 'ResourceTransformation',
+ "ResourceTransformation",
primaryjoin="ResourceTransformation.derivative_id==Resource.id",
foreign_keys=id,
- backref=backref('derivative', uselist=False),
+ backref=backref("derivative", uselist=False),
lazy="joined",
uselist=False,
)
@@ -143,9 +140,7 @@ class Resource(Base):
quality = Column(Float, index=True)
# URL must be unique.
- __table_args__ = (
- UniqueConstraint('url'),
- )
+ __table_args__ = (UniqueConstraint("url"),)
@property
def final_url(self):
@@ -177,13 +172,12 @@ def set_fetched_content(self, media_type, content, content_path):
_db = Session.object_session(self)
if not (content or content_path):
- raise ValueError(
- "One of content and content_path must be specified.")
+ raise ValueError("One of content and content_path must be specified.")
if content and content_path:
- raise ValueError(
- "Only one of content and content_path may be specified.")
+ raise ValueError("Only one of content and content_path may be specified.")
representation, is_new = get_one_or_create(
- _db, Representation, url=self.url, media_type=media_type)
+ _db, Representation, url=self.url, media_type=media_type
+ )
self.representation = representation
representation.set_fetched_content(content, content_path)
@@ -198,7 +192,7 @@ def add_quality_votes(self, quality, weight=1):
self.votes_for_quality = self.votes_for_quality or 0
total_quality = self.voted_quality * self.votes_for_quality
- total_quality += (quality * weight)
+ total_quality += quality * weight
self.votes_for_quality += weight
self.voted_quality = total_quality / float(self.votes_for_quality)
self.update_quality()
@@ -260,8 +254,9 @@ def update_quality(self):
total_weight = estimated_weight + votes_for_quality
voted_quality = (self.voted_quality or 0) * votes_for_quality
- total_quality = (((self.estimated_quality or 0) * self.ESTIMATED_QUALITY_WEIGHT) +
- voted_quality)
+ total_quality = (
+ (self.estimated_quality or 0) * self.ESTIMATED_QUALITY_WEIGHT
+ ) + voted_quality
if voted_quality < 0 and total_quality > 0:
# If `voted_quality` is negative, the Resource has been
@@ -297,7 +292,7 @@ def best_covers_among(cls, resources):
continue
media_priority = cls.image_type_priority(rep.media_type)
if media_priority is None:
- media_priority = float('inf')
+ media_priority = float("inf")
# This method will set the quality if it hasn't been set before.
r.quality_as_thumbnail_image
@@ -328,8 +323,7 @@ def best_covers_among(cls, resources):
@property
def quality_as_thumbnail_image(self):
- """Determine this image's suitability for use as a thumbnail image.
- """
+ """Determine this image's suitability for use as a thumbnail image."""
rep = self.representation
if not rep:
return 0
@@ -341,11 +335,11 @@ def quality_as_thumbnail_image(self):
# Scale the estimated quality by the source of the image.
source_name = self.data_source.name
- if source_name==DataSourceConstants.GUTENBERG_COVER_GENERATOR:
+ if source_name == DataSourceConstants.GUTENBERG_COVER_GENERATOR:
quality = quality * 0.60
- elif source_name==DataSourceConstants.GUTENBERG:
+ elif source_name == DataSourceConstants.GUTENBERG:
quality = quality * 0.50
- elif source_name==DataSourceConstants.OPEN_LIBRARY:
+ elif source_name == DataSourceConstants.OPEN_LIBRARY:
quality = quality * 0.25
elif source_name in DataSourceConstants.COVER_IMAGE_PRIORITY:
# Covers from the data sources listed in
@@ -354,7 +348,7 @@ def quality_as_thumbnail_image(self):
# over all others, relative to their position in
# COVER_IMAGE_PRIORITY.
i = DataSourceConstants.COVER_IMAGE_PRIORITY.index(source_name)
- quality = quality * (i+2)
+ quality = quality * (i + 2)
self.set_estimated_quality(quality)
return quality
@@ -362,50 +356,56 @@ def add_derivative(self, derivative_resource, settings=None):
_db = Session.object_session(self)
transformation, ignore = get_one_or_create(
- _db, ResourceTransformation, derivative_id=derivative_resource.id)
+ _db, ResourceTransformation, derivative_id=derivative_resource.id
+ )
transformation.original_id = self.id
transformation.settings = settings or {}
return transformation
+
class ResourceTransformation(Base):
"""A record that a resource is a derivative of another resource,
and the settings that were used to transform the original into it.
"""
- __tablename__ = 'resourcetransformations'
+ __tablename__ = "resourcetransformations"
# The derivative resource. A resource can only be derived from one other resource.
derivative_id = Column(
- Integer, ForeignKey('resources.id'), index=True, primary_key=True)
+ Integer, ForeignKey("resources.id"), index=True, primary_key=True
+ )
# The original resource that was transformed into the derivative.
- original_id = Column(
- Integer, ForeignKey('resources.id'), index=True)
+ original_id = Column(Integer, ForeignKey("resources.id"), index=True)
# The settings used for the transformation.
settings = Column(MutableDict.as_mutable(JSON), default={})
+
class Hyperlink(Base, LinkRelations):
"""A link between an Identifier and a Resource."""
- __tablename__ = 'hyperlinks'
+ __tablename__ = "hyperlinks"
id = Column(Integer, primary_key=True)
# A Hyperlink is always associated with some Identifier.
identifier_id = Column(
- Integer, ForeignKey('identifiers.id'), index=True, nullable=False)
+ Integer, ForeignKey("identifiers.id"), index=True, nullable=False
+ )
# The DataSource through which this link was discovered.
data_source_id = Column(
- Integer, ForeignKey('datasources.id'), index=True, nullable=False)
+ Integer, ForeignKey("datasources.id"), index=True, nullable=False
+ )
# The link relation between the Identifier and the Resource.
rel = Column(Unicode, index=True, nullable=False)
# The Resource on the other end of the link.
resource_id = Column(
- Integer, ForeignKey('resources.id'), index=True, nullable=False)
+ Integer, ForeignKey("resources.id"), index=True, nullable=False
+ )
@classmethod
def unmirrored(cls, collection):
@@ -416,29 +416,30 @@ def unmirrored(cls, collection):
was created but not mirrored.)
"""
from .identifier import Identifier
+
_db = Session.object_session(collection)
- qu = _db.query(Hyperlink).join(
- Hyperlink.identifier
- ).join(
- Identifier.licensed_through
- ).outerjoin(
- Hyperlink.resource
- ).outerjoin(
- Resource.representation
+ qu = (
+ _db.query(Hyperlink)
+ .join(Hyperlink.identifier)
+ .join(Identifier.licensed_through)
+ .outerjoin(Hyperlink.resource)
+ .outerjoin(Resource.representation)
)
- qu = qu.filter(LicensePool.collection_id==collection.id)
+ qu = qu.filter(LicensePool.collection_id == collection.id)
qu = qu.filter(Hyperlink.rel.in_(Hyperlink.MIRRORED))
- qu = qu.filter(Hyperlink.data_source==collection.data_source)
+ qu = qu.filter(Hyperlink.data_source == collection.data_source)
qu = qu.filter(
or_(
- Representation.id==None,
- Representation.mirror_url==None,
+ Representation.id == None,
+ Representation.mirror_url == None,
)
)
# Without this ordering, the query does a table scan looking for
# items that match. With the ordering, they're all at the front.
- qu = qu.order_by(Representation.mirror_url.asc().nullsfirst(),
- Representation.id.asc().nullsfirst())
+ qu = qu.order_by(
+ Representation.mirror_url.asc().nullsfirst(),
+ Representation.id.asc().nullsfirst(),
+ )
return qu
@classmethod
@@ -465,11 +466,11 @@ def generic_uri(cls, data_source, identifier, rel, content=None):
@classmethod
def _default_filename(self, rel):
if rel == self.OPEN_ACCESS_DOWNLOAD:
- return 'content'
+ return "content"
elif rel == self.IMAGE:
- return 'cover'
+ return "cover"
elif rel == self.THUMBNAIL_IMAGE:
- return 'cover-thumbnail'
+ return "cover-thumbnail"
@property
def default_filename(self):
@@ -487,7 +488,7 @@ class Representation(Base, MediaTypes):
of.
"""
- __tablename__ = 'representations'
+ __tablename__ = "representations"
id = Column(Integer, primary_key=True)
# URL from which the representation was fetched.
@@ -531,13 +532,14 @@ class Representation(Base, MediaTypes):
# An image Representation may be a thumbnail version of another
# Representation.
- thumbnail_of_id = Column(
- Integer, ForeignKey('representations.id'), index=True)
+ thumbnail_of_id = Column(Integer, ForeignKey("representations.id"), index=True)
thumbnails = relationship(
"Representation",
- backref=backref("thumbnail_of", remote_side = [id]),
- lazy="joined", post_update=True)
+ backref=backref("thumbnail_of", remote_side=[id]),
+ lazy="joined",
+ post_update=True,
+ )
# The HTTP status code from the last fetch.
status_code = Column(Integer)
@@ -574,19 +576,20 @@ class Representation(Base, MediaTypes):
# A Representation may be a CachedMARCFile.
marc_file = relationship(
- "CachedMARCFile", backref="representation",
+ "CachedMARCFile",
+ backref="representation",
cascade="all, delete-orphan",
)
# At any given time, we will have a single representation for a
# given URL and media type.
- __table_args__ = (
- UniqueConstraint('url', 'media_type'),
- )
+ __table_args__ = (UniqueConstraint("url", "media_type"),)
# A User-Agent to use when acting like a web browser.
# BROWSER_USER_AGENT = "Mozilla/5.0 (Windows NT 6.3; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/37.0.2049.0 Safari/537.36 (Simplified)"
- BROWSER_USER_AGENT = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:37.0) Gecko/20100101 Firefox/37.0"
+ BROWSER_USER_AGENT = (
+ "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:37.0) Gecko/20100101 Firefox/37.0"
+ )
@property
def age(self):
@@ -598,7 +601,11 @@ def age(self):
def has_content(self):
if self.content and self.status_code == 200 and self.fetch_exception is None:
return True
- if self.local_content_path and os.path.exists(self.local_content_path) and self.fetch_exception is None:
+ if (
+ self.local_content_path
+ and os.path.exists(self.local_content_path)
+ and self.fetch_exception is None
+ ):
return True
return False
@@ -624,7 +631,9 @@ def is_usable(self):
a status code that's not in the 5xx series.
"""
if not self.fetch_exception and (
- self.content or self.local_path or self.status_code
+ self.content
+ or self.local_path
+ or self.status_code
and self.status_code // 100 != 5
):
return True
@@ -636,17 +645,20 @@ def is_media_type(cls, s):
if not s:
return False
s = s.lower()
- return any(s.startswith(x) for x in [
- 'application/',
- 'audio/',
- 'example/',
- 'image/',
- 'message/',
- 'model/',
- 'multipart/',
- 'text/',
- 'video/'
- ])
+ return any(
+ s.startswith(x)
+ for x in [
+ "application/",
+ "audio/",
+ "example/",
+ "image/",
+ "message/",
+ "model/",
+ "multipart/",
+ "text/",
+ "video/",
+ ]
+ )
@classmethod
def guess_url_media_type_from_path(cls, url):
@@ -674,13 +686,25 @@ def is_fresher_than(self, max_age):
if not self.is_usable:
return False
- return (max_age is None or max_age > self.age)
+ return max_age is None or max_age > self.age
@classmethod
- def get(cls, _db, url, do_get=None, extra_request_headers=None,
- accept=None, max_age=None, pause_before=0, allow_redirects=True,
- presumed_media_type=None, debug=True, response_reviewer=None,
- exception_handler=None, url_normalizer=None):
+ def get(
+ cls,
+ _db,
+ url,
+ do_get=None,
+ extra_request_headers=None,
+ accept=None,
+ max_age=None,
+ pause_before=0,
+ allow_redirects=True,
+ presumed_media_type=None,
+ debug=True,
+ response_reviewer=None,
+ exception_handler=None,
+ url_normalizer=None,
+ ):
"""Retrieve a representation from the cache if possible.
If not possible, retrieve it from the web and store it in the
cache.
@@ -754,8 +778,8 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
a = dict(url=normalized_url)
if accept:
- a['media_type'] = accept
- representation = get_one(_db, Representation, 'interchangeable', **a)
+ a["media_type"] = accept
+ representation = get_one(_db, Representation, "interchangeable", **a)
usable_representation = fresh_representation = False
if representation:
@@ -785,16 +809,16 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
if extra_request_headers:
headers.update(extra_request_headers)
if accept:
- headers['Accept'] = accept
+ headers["Accept"] = accept
if usable_representation:
# We have a representation but it's not fresh. We will
# be making a conditional HTTP request to see if there's
# a new version.
if representation.last_modified:
- headers['If-Modified-Since'] = representation.last_modified
+ headers["If-Modified-Since"] = representation.last_modified
if representation.etag:
- headers['If-None-Match'] = representation.etag
+ headers["If-None-Match"] = representation.etag
fetched_at = utc_now()
if pause_before:
@@ -817,7 +841,9 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
# request, not that the HTTP request returned an error
# condition.
fetch_exception = e
- logging.error("Error making HTTP request to %s", url, exc_info=fetch_exception)
+ logging.error(
+ "Error making HTTP request to %s", url, exc_info=fetch_exception
+ )
exception_traceback = traceback.format_exc()
status_code = None
@@ -829,18 +855,17 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
# we don't have one already, or if the URL or media type we
# actually got from the server differs from what we thought
# we had.
- if (not usable_representation
+ if (
+ not usable_representation
or media_type != representation.media_type
- or normalized_url != representation.url):
+ or normalized_url != representation.url
+ ):
representation, is_new = get_one_or_create(
- _db, Representation, url=normalized_url,
- media_type=str(media_type)
+ _db, Representation, url=normalized_url, media_type=str(media_type)
)
if fetch_exception:
- exception_handler(
- representation, fetch_exception, exception_traceback
- )
+ exception_handler(representation, fetch_exception, exception_traceback)
representation.fetched_at = fetched_at
if status_code == 304:
@@ -856,7 +881,7 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
else:
status_code_series = None
- if status_code_series in (2,3) or status_code in (404, 410):
+ if status_code_series in (2, 3) or status_code in (404, 410):
# We have a new, good representation. Update the
# Representation object and return it as fresh.
representation.status_code = status_code
@@ -864,9 +889,10 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
representation.media_type = media_type
for header, field in (
- ('etag', 'etag'),
- ('last-modified', 'last_modified'),
- ('location', 'location')):
+ ("etag", "etag"),
+ ("last-modified", "last_modified"),
+ ("location", "location"),
+ ):
if header in headers:
value = headers[header]
else:
@@ -880,8 +906,9 @@ def get(cls, _db, url, do_get=None, extra_request_headers=None,
# Okay, things didn't go so well.
date_string = fetched_at.strftime("%Y-%m-%d %H:%M:%S")
representation.fetch_exception = representation.fetch_exception or (
- "Most recent fetch attempt (at %s) got status code %s" % (
- date_string, status_code))
+ "Most recent fetch attempt (at %s) got status code %s"
+ % (date_string, status_code)
+ )
if usable_representation:
# If we have a usable (but stale) representation, we'd
# rather return the cached data than destroy the information.
@@ -905,9 +932,9 @@ def _best_media_type(cls, url, headers, default):
derive one from the URL extension.
"""
default = default or cls.guess_url_media_type_from_path(url)
- if not headers or not 'content-type' in headers:
+ if not headers or not "content-type" in headers:
return default
- headers_type = headers['content-type'].lower()
+ headers_type = headers["content-type"].lower()
clean = cls._clean_media_type(headers_type)
if clean in Representation.GENERIC_MEDIA_TYPES and default:
return default
@@ -926,19 +953,22 @@ def record_exception(cls, representation, exception, traceback):
representation.fetch_exception = traceback
@classmethod
- def post(cls, _db, url, data, max_age=None, response_reviewer=None,
- **kwargs):
+ def post(cls, _db, url, data, max_age=None, response_reviewer=None, **kwargs):
"""Finds or creates POST request as a Representation"""
- original_do_get = kwargs.pop('do_get', cls.simple_http_post)
+ original_do_get = kwargs.pop("do_get", cls.simple_http_post)
def do_post(url, headers, **kwargs):
- kwargs.update({'data' : data})
+ kwargs.update({"data": data})
return original_do_get(url, headers, **kwargs)
return cls.get(
- _db, url, do_get=do_post, max_age=max_age,
- response_reviewer=response_reviewer, **kwargs
+ _db,
+ url,
+ do_get=do_post,
+ max_age=max_age,
+ response_reviewer=response_reviewer,
+ **kwargs
)
@property
@@ -948,9 +978,8 @@ def mirrorable_media_type(self):
Basically, images and books.
"""
return any(
- self.media_type in x for x in
- (Representation.BOOK_MEDIA_TYPES,
- Representation.IMAGE_MEDIA_TYPES)
+ self.media_type in x
+ for x in (Representation.BOOK_MEDIA_TYPES, Representation.IMAGE_MEDIA_TYPES)
)
def update_image_size(self):
@@ -958,7 +987,7 @@ def update_image_size(self):
Clears .image_height and .image_width if the representation
is not an image.
"""
- if self.media_type and self.media_type.startswith('image/'):
+ if self.media_type and self.media_type.startswith("image/"):
image = self.as_image()
if image:
self.image_width, self.image_height = image.size
@@ -971,8 +1000,8 @@ def normalize_content_path(cls, content_path, base=None):
return None
base = base or Configuration.data_directory()
if content_path.startswith(base):
- content_path = content_path[len(base):]
- if content_path.startswith('/'):
+ content_path = content_path[len(base) :]
+ if content_path.startswith("/"):
content_path = content_path[1:]
return content_path
@@ -982,7 +1011,7 @@ def unicode_content(self):
If all attempts fail, we will return None rather than raise an exception.
"""
content = None
- for encoding in ('utf-8', 'windows-1252'):
+ for encoding in ("utf-8", "windows-1252"):
try:
content = self.content.decode(encoding)
break
@@ -1024,17 +1053,17 @@ def headers_to_string(cls, d):
@classmethod
def simple_http_get(cls, url, headers, **kwargs):
"""The most simple HTTP-based GET."""
- if not 'allow_redirects' in kwargs:
- kwargs['allow_redirects'] = True
+ if not "allow_redirects" in kwargs:
+ kwargs["allow_redirects"] = True
response = HTTP.get_with_timeout(url, headers=headers, **kwargs)
return response.status_code, response.headers, response.content
@classmethod
def simple_http_post(cls, url, headers, **kwargs):
"""The most simple HTTP-based POST."""
- data = kwargs.get('data')
- if 'data' in kwargs:
- del kwargs['data']
+ data = kwargs.get("data")
+ if "data" in kwargs:
+ del kwargs["data"]
response = HTTP.post_with_timeout(url, data, headers=headers, **kwargs)
return response.status_code, response.headers, response.content
@@ -1049,10 +1078,9 @@ def http_get_no_redirect(cls, url, headers, **kwargs):
@classmethod
def browser_http_get(cls, url, headers, **kwargs):
- """GET the representation that would be displayed to a web browser.
- """
+ """GET the representation that would be displayed to a web browser."""
headers = dict(headers)
- headers['User-Agent'] = cls.BROWSER_USER_AGENT
+ headers["User-Agent"] = cls.BROWSER_USER_AGENT
return cls.simple_http_get(url, headers, **kwargs)
@classmethod
@@ -1066,25 +1094,20 @@ def cautious_http_get(cls, url, headers, **kwargs):
gutenberg.org quickly result in IP bans. So we don't make those
requests.
"""
- do_not_access = kwargs.pop(
- 'do_not_access', cls.AVOID_WHEN_CAUTIOUS_DOMAINS
- )
+ do_not_access = kwargs.pop("do_not_access", cls.AVOID_WHEN_CAUTIOUS_DOMAINS)
check_for_redirect = kwargs.pop(
- 'check_for_redirect', cls.EXERCISE_CAUTION_DOMAINS
+ "check_for_redirect", cls.EXERCISE_CAUTION_DOMAINS
)
- do_get = kwargs.pop('do_get', cls.simple_http_get)
- head_client = kwargs.pop('cautious_head_client', requests.head)
+ do_get = kwargs.pop("do_get", cls.simple_http_get)
+ head_client = kwargs.pop("cautious_head_client", requests.head)
if cls.get_would_be_useful(
- url, headers, do_not_access, check_for_redirect,
- head_client
+ url, headers, do_not_access, check_for_redirect, head_client
):
# Go ahead and make the GET request.
return do_get(url, headers, **kwargs)
else:
- logging.info(
- "Declining to make non-useful HTTP request to %s", url
- )
+ logging.info("Declining to make non-useful HTTP request to %s", url)
# 417 Expectation Failed - "... if the server is a proxy,
# the server has unambiguous evidence that the request
# could not be met by the next-hop server."
@@ -1094,23 +1117,23 @@ def cautious_http_get(cls, url, headers, **kwargs):
# request".
return (
417,
- {"content-type" :
- "application/vnd.librarysimplified-did-not-make-request"},
- "Cautiously decided not to make a GET request to %s" % url
+ {
+ "content-type": "application/vnd.librarysimplified-did-not-make-request"
+ },
+ "Cautiously decided not to make a GET request to %s" % url,
)
# Sites known to host both free books and redirects to a domain in
# AVOID_WHEN_CAUTIOUS_DOMAINS.
- EXERCISE_CAUTION_DOMAINS = ['unglue.it']
+ EXERCISE_CAUTION_DOMAINS = ["unglue.it"]
# Sites that cause problems for us if we make automated
# HTTP requests to them while trying to find free books.
- AVOID_WHEN_CAUTIOUS_DOMAINS = ['gutenberg.org', 'books.google.com']
+ AVOID_WHEN_CAUTIOUS_DOMAINS = ["gutenberg.org", "books.google.com"]
@classmethod
def get_would_be_useful(
- cls, url, headers, do_not_access=None, check_for_redirect=None,
- head_client=None
+ cls, url, headers, do_not_access=None, check_for_redirect=None, head_client=None
):
"""Determine whether making a GET request to a given URL is likely to
have a useful result.
@@ -1131,8 +1154,7 @@ def has_domain(domain, check_against):
"""Is the given `domain` in `check_against`,
or maybe a subdomain of one of the domains in `check_against`?
"""
- return any(domain == x or domain.endswith('.' + x)
- for x in check_against)
+ return any(domain == x or domain.endswith("." + x) for x in check_against)
netloc = urlparse(url).netloc
if has_domain(netloc, do_not_access):
@@ -1155,7 +1177,7 @@ def has_domain(domain, check_against):
# Yes, it's a redirect. Does it redirect to a
# domain we don't want to access?
- location = head_response.headers.get('location', '')
+ location = head_response.headers.get("location", "")
netloc = urlparse(location).netloc
return not has_domain(netloc, do_not_access)
@@ -1168,8 +1190,7 @@ def local_path(self):
"""Return the full local path to the representation on disk."""
if not self.local_content_path:
return None
- return os.path.join(Configuration.data_directory(),
- self.local_content_path)
+ return os.path.join(Configuration.data_directory(), self.local_content_path)
@property
def clean_media_type(self):
@@ -1233,16 +1254,16 @@ def extension(self, destination_type=None):
def _clean_media_type(cls, media_type):
if not media_type:
return media_type
- if ';' in media_type:
- media_type = media_type[:media_type.index(';')].strip()
+ if ";" in media_type:
+ media_type = media_type[: media_type.index(";")].strip()
return media_type
@classmethod
def _extension(cls, media_type):
- value = cls.FILE_EXTENSIONS.get(media_type, '')
+ value = cls.FILE_EXTENSIONS.get(media_type, "")
if not value:
return value
- return '.' + value
+ return "." + value
def default_filename(self, link=None, destination_type=None):
"""Try to come up with a good filename for this representation."""
@@ -1259,12 +1280,16 @@ def default_filename(self, link=None, destination_type=None):
# This is the absolute last-ditch filename solution, and
# it's basically only used when we try to mirror the root
# URL of a domain.
- filename = 'resource'
+ filename = "resource"
default_extension = self.extension()
extension = self.extension(destination_type)
- if default_extension and default_extension != extension and filename.endswith(default_extension):
- filename = filename[:-len(default_extension)] + extension
+ if (
+ default_extension
+ and default_extension != extension
+ and filename.endswith(default_extension)
+ ):
+ filename = filename[: -len(default_extension)] + extension
elif extension and not filename.endswith(extension):
filename += extension
return filename
@@ -1292,7 +1317,7 @@ def content_fh(self):
elif self.local_path:
if not os.path.exists(self.local_path):
raise ValueError("%s does not exist." % self.local_path)
- return open(self.local_path, 'rb')
+ return open(self.local_path, "rb")
return None
def as_image(self):
@@ -1300,7 +1325,8 @@ def as_image(self):
if not self.is_image:
raise ValueError(
"Cannot load non-image representation as image: type %s."
- % self.media_type)
+ % self.media_type
+ )
if not self.content and not self.local_path:
raise ValueError("Image representation has no content.")
@@ -1315,8 +1341,14 @@ def as_image(self):
"image/jpeg": "jpeg",
}
- def scale(self, max_height, max_width,
- destination_url, destination_media_type, force=False):
+ def scale(
+ self,
+ max_height,
+ max_width,
+ destination_url,
+ destination_media_type,
+ force=False,
+ ):
"""Return a Representation that's a scaled-down version of this
Representation, creating it if necessary.
:param destination_url: The URL the scaled-down resource will
@@ -1326,7 +1358,9 @@ def scale(self, max_height, max_width,
_db = Session.object_session(self)
if not destination_media_type in self.pil_format_for_media_type:
- raise ValueError("Unsupported destination media type: %s" % destination_media_type)
+ raise ValueError(
+ "Unsupported destination media type: %s" % destination_media_type
+ )
pil_format = self.pil_format_for_media_type[destination_media_type]
@@ -1340,7 +1374,8 @@ def scale(self, max_height, max_width,
# This most likely indicates an error during the fetch
# phrase.
self.fetch_exception = "Error found while scaling: %s" % (
- self.scale_exception)
+ self.scale_exception
+ )
logging.error("Error found while scaling %r", self, exc_info=e)
if not image:
@@ -1351,16 +1386,17 @@ def scale(self, max_height, max_width,
self.image_width, self.image_height = image.size
# If the image is already a thumbnail-size bitmap, don't bother.
- if (self.clean_media_type != Representation.SVG_MEDIA_TYPE
+ if (
+ self.clean_media_type != Representation.SVG_MEDIA_TYPE
and self.image_height <= max_height
- and self.image_width <= max_width):
+ and self.image_width <= max_width
+ ):
self.thumbnails = []
return self, False
# Do we already have a representation for the given URL?
thumbnail, is_new = get_one_or_create(
- _db, Representation, url=destination_url,
- media_type=destination_media_type
+ _db, Representation, url=destination_url, media_type=destination_media_type
)
if thumbnail not in self.thumbnails:
thumbnail.thumbnail_of = self
@@ -1399,8 +1435,8 @@ def scale(self, max_height, max_width,
# Save the thumbnail image to the database under
# thumbnail.content.
output = BytesIO()
- if image.mode != 'RGB':
- image = image.convert('RGB')
+ if image.mode != "RGB":
+ image = image.convert("RGB")
try:
image.save(output, pil_format)
except Exception as e:
@@ -1408,7 +1444,9 @@ def scale(self, max_height, max_width,
self.scaled_at = None
# This most likely indicates a problem during the fetch phase,
# Set fetch_exception so we'll retry the fetch.
- self.fetch_exception = "Error found while scaling: %s" % (self.scale_exception)
+ self.fetch_exception = "Error found while scaling: %s" % (
+ self.scale_exception
+ )
return self, False
thumbnail.content = output.getvalue()
thumbnail.image_width, thumbnail.image_height = image.size
@@ -1419,9 +1457,7 @@ def scale(self, max_height, max_width,
@property
def thumbnail_size_quality_penalty(self):
- return self._thumbnail_size_quality_penalty(
- self.image_width, self.image_height
- )
+ return self._thumbnail_size_quality_penalty(self.image_width, self.image_height)
@classmethod
def _thumbnail_size_quality_penalty(cls, width, height):
@@ -1462,21 +1498,25 @@ def _thumbnail_size_quality_penalty(cls, width, height):
if aspect_ratio > ideal:
deviation = ideal / aspect_ratio
else:
- deviation = aspect_ratio/ideal
+ deviation = aspect_ratio / ideal
if deviation != 1:
quotient *= deviation
# Penalize an image for not being wide enough.
width_shortfall = (
- float(width - IdentifierConstants.IDEAL_IMAGE_WIDTH) / IdentifierConstants.IDEAL_IMAGE_WIDTH)
+ float(width - IdentifierConstants.IDEAL_IMAGE_WIDTH)
+ / IdentifierConstants.IDEAL_IMAGE_WIDTH
+ )
if width_shortfall < 0:
- quotient *= (1+width_shortfall)
+ quotient *= 1 + width_shortfall
# Penalize an image for not being tall enough.
height_shortfall = (
- float(height - IdentifierConstants.IDEAL_IMAGE_HEIGHT) / IdentifierConstants.IDEAL_IMAGE_HEIGHT)
+ float(height - IdentifierConstants.IDEAL_IMAGE_HEIGHT)
+ / IdentifierConstants.IDEAL_IMAGE_HEIGHT
+ )
if height_shortfall < 0:
- quotient *= (1+height_shortfall)
+ quotient *= 1 + height_shortfall
return quotient
@property
diff --git a/model/work.py b/model/work.py
index 83f5be978..0a93d32b8 100644
--- a/model/work.py
+++ b/model/work.py
@@ -3,6 +3,7 @@
import logging
from collections import Counter
+
from sqlalchemy import (
Boolean,
Column,
@@ -17,61 +18,40 @@
)
from sqlalchemy.dialects.postgresql import INT4RANGE
from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import (
- contains_eager,
- relationship,
-)
+from sqlalchemy.orm import contains_eager, relationship
from sqlalchemy.orm.session import Session
-from sqlalchemy.sql.expression import (
- and_,
- or_,
- select,
- join,
- literal_column,
- case,
-)
+from sqlalchemy.sql.expression import and_, case, join, literal_column, or_, select
from sqlalchemy.sql.functions import func
-from .constants import (
- DataSourceConstants,
-)
-from .contributor import (
- Contribution,
- Contributor,
-)
-from .coverage import (
- CoverageRecord,
- WorkCoverageRecord,
-)
-from .datasource import DataSource
-from .edition import Edition
-from .identifier import Identifier
-from .measurement import Measurement
+from ..classifier import Classifier, WorkClassifier
+from ..config import CannotLoadConfiguration
+from ..util import LanguageCodes
+from ..util.datetime_helpers import utc_now
from . import (
Base,
+ PresentationCalculationPolicy,
flush,
get_one_or_create,
numericrange_to_string,
numericrange_to_tuple,
- PresentationCalculationPolicy,
tuple_to_numericrange,
)
-from ..classifier import (
- Classifier,
- WorkClassifier,
-)
-from ..config import CannotLoadConfiguration
-from ..util import LanguageCodes
-from ..util.datetime_helpers import utc_now
+from .constants import DataSourceConstants
+from .contributor import Contribution, Contributor
+from .coverage import CoverageRecord, WorkCoverageRecord
+from .datasource import DataSource
+from .edition import Edition
+from .identifier import Identifier
+from .measurement import Measurement
class WorkGenre(Base):
"""An assignment of a genre to a work."""
- __tablename__ = 'workgenres'
+ __tablename__ = "workgenres"
id = Column(Integer, primary_key=True)
- genre_id = Column(Integer, ForeignKey('genres.id'), index=True)
- work_id = Column(Integer, ForeignKey('works.id'), index=True)
+ genre_id = Column(Integer, ForeignKey("genres.id"), index=True)
+ work_id = Column(Integer, ForeignKey("works.id"), index=True)
affinity = Column(Float, index=True, default=0)
@classmethod
@@ -81,7 +61,7 @@ def from_genre(cls, genre):
return wg
def __repr__(self):
- return "%s (%d%%)" % (self.genre.name, self.affinity*100)
+ return "%s (%d%%)" % (self.genre.name, self.affinity * 100)
class Work(Base):
@@ -115,53 +95,53 @@ class Work(Base):
default_quality_by_data_source = {
DataSourceConstants.GUTENBERG: 0,
DataSourceConstants.OVERDRIVE: 0.4,
- DataSourceConstants.BIBLIOTHECA : 0.65,
+ DataSourceConstants.BIBLIOTHECA: 0.65,
DataSourceConstants.AXIS_360: 0.65,
DataSourceConstants.STANDARD_EBOOKS: 0.8,
DataSourceConstants.UNGLUE_IT: 0.4,
DataSourceConstants.PLYMPTON: 0.5,
}
- __tablename__ = 'works'
+ __tablename__ = "works"
id = Column(Integer, primary_key=True)
# One Work may have copies scattered across many LicensePools.
- license_pools = relationship("LicensePool", backref="work", lazy='joined')
+ license_pools = relationship("LicensePool", backref="work", lazy="joined")
# A Work takes its presentation metadata from a single Edition.
# But this Edition is a composite of provider, metadata wrangler, admin interface, etc.-derived Editions.
- presentation_edition_id = Column(Integer, ForeignKey('editions.id'), index=True)
+ presentation_edition_id = Column(Integer, ForeignKey("editions.id"), index=True)
# One Work may have many associated WorkCoverageRecords.
coverage_records = relationship(
- "WorkCoverageRecord", backref="work",
- cascade="all, delete-orphan"
+ "WorkCoverageRecord", backref="work", cascade="all, delete-orphan"
)
# One Work may be associated with many CustomListEntries.
# However, a CustomListEntry may lose its Work without
# ceasing to exist.
- custom_list_entries = relationship('CustomListEntry', backref='work')
+ custom_list_entries = relationship("CustomListEntry", backref="work")
# One Work may have multiple CachedFeeds, and if a CachedFeed
# loses its Work, it ceases to exist.
cached_feeds = relationship(
- 'CachedFeed', backref='work', cascade="all, delete-orphan"
+ "CachedFeed", backref="work", cascade="all, delete-orphan"
)
# One Work may participate in many WorkGenre assignments.
- genres = association_proxy('work_genres', 'genre',
- creator=WorkGenre.from_genre)
- work_genres = relationship("WorkGenre", backref="work",
- cascade="all, delete-orphan")
+ genres = association_proxy("work_genres", "genre", creator=WorkGenre.from_genre)
+ work_genres = relationship(
+ "WorkGenre", backref="work", cascade="all, delete-orphan"
+ )
audience = Column(Unicode, index=True)
target_age = Column(INT4RANGE, index=True)
fiction = Column(Boolean, index=True)
summary_id = Column(
- Integer, ForeignKey(
- 'resources.id', use_alter=True, name='fk_works_summary_id'),
- index=True)
+ Integer,
+ ForeignKey("resources.id", use_alter=True, name="fk_works_summary_id"),
+ index=True,
+ )
# This gives us a convenient place to store a cleaned-up version of
# the content of the summary Resource.
summary_text = Column(Unicode)
@@ -169,7 +149,7 @@ class Work(Base):
# The overall suitability of this work for unsolicited
# presentation to a patron. This is a calculated value taking both
# rating and popularity into account.
- quality = Column(Numeric(4,3), index=True)
+ quality = Column(Numeric(4, 3), index=True)
# The overall rating given to this work.
rating = Column(Float, index=True)
@@ -177,9 +157,16 @@ class Work(Base):
# The overall current popularity of this work.
popularity = Column(Float, index=True)
- appeal_type = Enum(CHARACTER_APPEAL, LANGUAGE_APPEAL, SETTING_APPEAL,
- STORY_APPEAL, NOT_APPLICABLE_APPEAL, NO_APPEAL,
- UNKNOWN_APPEAL, name="appeal")
+ appeal_type = Enum(
+ CHARACTER_APPEAL,
+ LANGUAGE_APPEAL,
+ SETTING_APPEAL,
+ STORY_APPEAL,
+ NOT_APPLICABLE_APPEAL,
+ NO_APPEAL,
+ UNKNOWN_APPEAL,
+ name="appeal",
+ )
primary_appeal = Column(appeal_type, default=None, index=True)
secondary_appeal = Column(appeal_type, default=None, index=True)
@@ -198,7 +185,9 @@ class Work(Base):
presentation_ready = Column(Boolean, default=False, index=True)
# This is the last time we tried to make this work presentation ready.
- presentation_ready_attempt = Column(DateTime(timezone=True), default=None, index=True)
+ presentation_ready_attempt = Column(
+ DateTime(timezone=True), default=None, index=True
+ )
# This is the error that occured while trying to make this Work
# presentation ready. Until this is cleared, no further attempt
@@ -222,8 +211,10 @@ class Work(Base):
# These fields are potentially large and can be deferred if you
# don't need all the data in a Work.
LARGE_FIELDS = [
- 'simple_opds_entry', 'verbose_opds_entry', 'marc_record',
- 'summary_text',
+ "simple_opds_entry",
+ "verbose_opds_entry",
+ "marc_record",
+ "summary_text",
]
@property
@@ -320,22 +311,26 @@ def complaints(self):
def __repr__(self):
return '' % (
- self.id, self.title, self.author,
- ", ".join([g.name for g in self.genres]), self.language,
- len(self.license_pools)
+ self.id,
+ self.title,
+ self.author,
+ ", ".join([g.name for g in self.genres]),
+ self.language,
+ len(self.license_pools),
)
@classmethod
def missing_coverage_from(
- cls, _db, operation=None, count_as_covered=None,
- count_as_missing_before=None
+ cls, _db, operation=None, count_as_covered=None, count_as_missing_before=None
):
"""Find Works which have no WorkCoverageRecord for the given
`operation`.
"""
- clause = and_(Work.id==WorkCoverageRecord.work_id,
- WorkCoverageRecord.operation==operation)
+ clause = and_(
+ Work.id == WorkCoverageRecord.work_id,
+ WorkCoverageRecord.operation == operation,
+ )
q = _db.query(Work).outerjoin(WorkCoverageRecord, clause)
missing = WorkCoverageRecord.not_covered(
@@ -346,25 +341,26 @@ def missing_coverage_from(
@classmethod
def for_unchecked_subjects(cls, _db):
- from .classification import (
- Classification,
- Subject,
- )
+ from .classification import Classification, Subject
from .licensing import LicensePool
+
"""Find all Works whose LicensePools have an Identifier that
is classified under an unchecked Subject.
This is a good indicator that the Work needs to be
reclassified.
"""
- qu = _db.query(Work).join(Work.license_pools).join(
- LicensePool.identifier).join(
- Identifier.classifications).join(
- Classification.subject)
- return qu.filter(Subject.checked==False).order_by(Subject.id)
+ qu = (
+ _db.query(Work)
+ .join(Work.license_pools)
+ .join(LicensePool.identifier)
+ .join(Identifier.classifications)
+ .join(Classification.subject)
+ )
+ return qu.filter(Subject.checked == False).order_by(Subject.id)
@classmethod
def _potential_open_access_works_for_permanent_work_id(
- cls, _db, pwid, medium, language
+ cls, _db, pwid, medium, language
):
"""Find all Works that might be suitable for use as the
canonical open-access Work for the given `pwid`, `medium`,
@@ -375,16 +371,15 @@ def _potential_open_access_works_for_permanent_work_id(
associated with a given work.
"""
from .licensing import LicensePool
- qu = _db.query(LicensePool).join(
- LicensePool.presentation_edition).filter(
- LicensePool.open_access==True
- ).filter(
- Edition.permanent_work_id==pwid
- ).filter(
- Edition.medium==medium
- ).filter(
- Edition.language==language
- )
+
+ qu = (
+ _db.query(LicensePool)
+ .join(LicensePool.presentation_edition)
+ .filter(LicensePool.open_access == True)
+ .filter(Edition.permanent_work_id == pwid)
+ .filter(Edition.medium == medium)
+ .filter(Edition.language == language)
+ )
pools = set(qu.all())
# Build the Counter of Works that are eligible to represent
@@ -399,8 +394,9 @@ def _potential_open_access_works_for_permanent_work_id(
continue
pe = work.presentation_edition
if pe and (
- pe.language != language or pe.medium != medium
- or pe.permanent_work_id != pwid
+ pe.language != language
+ or pe.medium != medium
+ or pe.permanent_work_id != pwid
):
# This Work's presentation edition doesn't match
# this LicensePool's presentation edition.
@@ -424,7 +420,10 @@ def open_access_for_permanent_work_id(cls, _db, pwid, medium, language):
"""
is_new = False
- licensepools, licensepools_for_work = cls._potential_open_access_works_for_permanent_work_id(
+ (
+ licensepools,
+ licensepools_for_work,
+ ) = cls._potential_open_access_works_for_permanent_work_id(
_db, pwid, medium, language
)
if not licensepools:
@@ -463,7 +462,9 @@ def open_access_for_permanent_work_id(cls, _db, pwid, medium, language):
# nothing but LicensePools whose permanent
# work ID matches the permanent work ID of the
# Work we're about to merge into.
- needs_merge.make_exclusive_open_access_for_permanent_work_id(pwid, medium, language)
+ needs_merge.make_exclusive_open_access_for_permanent_work_id(
+ pwid, medium, language
+ )
needs_merge.merge_into(work)
# At this point we have one, and only one, Work for this
@@ -501,7 +502,7 @@ def make_exclusive_open_access_for_permanent_work_id(self, pwid, medium, languag
# cannot have an associated Work.
logging.warning(
"LicensePool %r has no presentation edition, setting .work to None.",
- pool
+ pool,
)
pool.work = None
else:
@@ -512,7 +513,7 @@ def make_exclusive_open_access_for_permanent_work_id(self, pwid, medium, languag
# cannot have an associated Work.
logging.warning(
"Presentation edition for LicensePool %r has no PWID, setting .work to None.",
- pool
+ pool,
)
e.work = None
pool.work = None
@@ -538,7 +539,10 @@ def pwids(self):
"""
pwids = set()
for pool in self.license_pools:
- if pool.presentation_edition and pool.presentation_edition.permanent_work_id:
+ if (
+ pool.presentation_edition
+ and pool.presentation_edition.permanent_work_id
+ ):
pwids.add(pool.presentation_edition.permanent_work_id)
return pwids
@@ -551,18 +555,20 @@ def merge_into(self, other_work):
for pool in w.license_pools:
if not pool.open_access:
raise ValueError(
-
- "Refusing to merge %r into %r because it would put an open-access LicensePool into the same work as a non-open-access LicensePool." %
- (self, other_work)
- )
+ "Refusing to merge %r into %r because it would put an open-access LicensePool into the same work as a non-open-access LicensePool."
+ % (self, other_work)
+ )
my_pwids = self.pwids
other_pwids = other_work.pwids
if not my_pwids == other_pwids:
raise ValueError(
- "Refusing to merge %r into %r because permanent work IDs don't match: %s vs. %s" % (
- self, other_work, ",".join(sorted(my_pwids)),
- ",".join(sorted(other_pwids))
+ "Refusing to merge %r into %r because permanent work IDs don't match: %s vs. %s"
+ % (
+ self,
+ other_work,
+ ",".join(sorted(my_pwids)),
+ ",".join(sorted(other_pwids)),
)
)
@@ -587,17 +593,16 @@ def set_summary(self, resource):
self.summary_text = resource.representation.unicode_content
else:
self.summary_text = ""
- WorkCoverageRecord.add_for(
- self, operation=WorkCoverageRecord.SUMMARY_OPERATION
- )
+ WorkCoverageRecord.add_for(self, operation=WorkCoverageRecord.SUMMARY_OPERATION)
@classmethod
def with_genre(cls, _db, genre):
"""Find all Works classified under the given genre."""
from .classification import Genre
+
if isinstance(genre, (bytes, str)):
genre, ignore = Genre.lookup(_db, genre)
- return _db.query(Work).join(WorkGenre).filter(WorkGenre.genre==genre)
+ return _db.query(Work).join(WorkGenre).filter(WorkGenre.genre == genre)
@classmethod
def with_no_genres(self, q):
@@ -605,7 +610,7 @@ def with_no_genres(self, q):
any genre."""
q = q.outerjoin(Work.work_genres)
q = q.options(contains_eager(Work.work_genres))
- q = q.filter(WorkGenre.genre==None)
+ q = q.filter(WorkGenre.genre == None)
return q
@classmethod
@@ -620,6 +625,7 @@ def from_identifiers(cls, _db, identifiers, base_query=None, policy=None):
about equivalencies.
"""
from .licensing import LicensePool
+
identifier_ids = [identifier.id for identifier in identifiers]
if not identifier_ids:
return None
@@ -627,31 +633,32 @@ def from_identifiers(cls, _db, identifiers, base_query=None, policy=None):
if not base_query:
# A raw base query that makes no accommodations for works that are
# suppressed or otherwise undeliverable.
- base_query = _db.query(Work).join(Work.license_pools).\
- join(LicensePool.identifier)
+ base_query = (
+ _db.query(Work).join(Work.license_pools).join(LicensePool.identifier)
+ )
if policy is None:
policy = PresentationCalculationPolicy(
- equivalent_identifier_levels=1,
- equivalent_identifier_threshold=0.999
+ equivalent_identifier_levels=1, equivalent_identifier_threshold=0.999
)
- identifier_ids_subquery = Identifier.recursively_equivalent_identifier_ids_query(
- Identifier.id, policy=policy)
- identifier_ids_subquery = identifier_ids_subquery.where(Identifier.id.in_(identifier_ids))
+ identifier_ids_subquery = (
+ Identifier.recursively_equivalent_identifier_ids_query(
+ Identifier.id, policy=policy
+ )
+ )
+ identifier_ids_subquery = identifier_ids_subquery.where(
+ Identifier.id.in_(identifier_ids)
+ )
query = base_query.filter(Identifier.id.in_(identifier_ids_subquery))
return query
@classmethod
- def reject_covers(cls, _db, works_or_identifiers,
- search_index_client=None):
+ def reject_covers(cls, _db, works_or_identifiers, search_index_client=None):
"""Suppresses the currently visible covers of a number of Works"""
from .licensing import LicensePool
- from .resource import (
- Resource,
- Hyperlink,
- )
+ from .resource import Hyperlink, Resource
works = list(set(works_or_identifiers))
if not isinstance(works[0], cls):
@@ -680,11 +687,12 @@ def reject_covers(cls, _db, works_or_identifiers,
# covers suppressed. Nothing to see here.
return
- covers = _db.query(Resource).join(Hyperlink.identifier).\
- join(Identifier.licensed_through).filter(
- Resource.url.in_(cover_urls),
- LicensePool.work_id.in_(work_ids)
- )
+ covers = (
+ _db.query(Resource)
+ .join(Hyperlink.identifier)
+ .join(Identifier.licensed_through)
+ .filter(Resource.url.in_(cover_urls), LicensePool.work_id.in_(work_ids))
+ )
editions = list()
for cover in covers:
@@ -715,9 +723,7 @@ def reject_covers(cls, _db, works_or_identifiers,
def reject_cover(self, search_index_client=None):
"""Suppresses the current cover of the Work"""
_db = Session.object_session(self)
- self.suppress_covers(
- _db, [self], search_index_client=search_index_client
- )
+ self.suppress_covers(_db, [self], search_index_client=search_index_client)
def all_editions(self, policy=None):
"""All Editions identified by an Identifier equivalent to
@@ -728,11 +734,16 @@ def all_editions(self, policy=None):
Identifiers.
"""
from .licensing import LicensePool
+
_db = Session.object_session(self)
- identifier_ids_subquery = Identifier.recursively_equivalent_identifier_ids_query(
- LicensePool.identifier_id, policy=policy
+ identifier_ids_subquery = (
+ Identifier.recursively_equivalent_identifier_ids_query(
+ LicensePool.identifier_id, policy=policy
+ )
+ )
+ identifier_ids_subquery = identifier_ids_subquery.where(
+ LicensePool.work_id == self.id
)
- identifier_ids_subquery = identifier_ids_subquery.where(LicensePool.work_id==self.id)
q = _db.query(Edition).filter(
Edition.primary_identifier_id.in_(identifier_ids_subquery)
@@ -744,10 +755,7 @@ def _direct_identifier_ids(self):
"""Return all Identifier IDs associated with one of this
Work's LicensePools.
"""
- return [
- lp.identifier.id for lp in self.license_pools
- if lp.identifier
- ]
+ return [lp.identifier.id for lp in self.license_pools if lp.identifier]
def all_identifier_ids(self, policy=None):
"""Return all Identifier IDs associated with this Work.
@@ -788,13 +796,15 @@ def age_appropriate_for_patron(self, patron):
return patron.work_is_age_appropriate(self.audience, self.target_age)
def set_presentation_edition(self, new_presentation_edition):
- """ Sets presentation edition and lets owned pools and editions know.
- Raises exception if edition to set to is None.
+ """Sets presentation edition and lets owned pools and editions know.
+ Raises exception if edition to set to is None.
"""
# only bother if something changed, or if were explicitly told to
# set (useful for setting to None)
if not new_presentation_edition:
- error_message = "Trying to set presentation_edition to None on Work [%s]" % self.id
+ error_message = (
+ "Trying to set presentation_edition to None on Work [%s]" % self.id
+ )
raise ValueError(error_message)
self.presentation_edition = new_presentation_edition
@@ -805,7 +815,7 @@ def set_presentation_edition(self, new_presentation_edition):
pool.work = self
def calculate_presentation_edition(self, policy=None):
- """ Which of this Work's Editions should be used as the default?
+ """Which of this Work's Editions should be used as the default?
First, every LicensePool associated with this work must have
its presentation edition set.
Then, we go through the pools, see which has the best presentation edition,
@@ -835,10 +845,7 @@ def calculate_presentation_edition(self, policy=None):
# make sure the pool has most up-to-date idea of its presentation edition,
# and then ask what it is.
pool_edition_changed = pool.set_presentation_edition()
- edition_metadata_changed = (
- edition_metadata_changed or
- pool_edition_changed
- )
+ edition_metadata_changed = edition_metadata_changed or pool_edition_changed
potential_presentation_edition = pool.presentation_edition
# We currently have no real way to choose between
@@ -848,14 +855,18 @@ def calculate_presentation_edition(self, policy=None):
#
# So basically we pick the first available edition and
# make it the presentation edition.
- if (not new_presentation_edition
- or (potential_presentation_edition is old_presentation_edition and old_presentation_edition)):
+ if not new_presentation_edition or (
+ potential_presentation_edition is old_presentation_edition
+ and old_presentation_edition
+ ):
# We would prefer not to change the Work's presentation
# edition unnecessarily, so if the current presentation
# edition is still an option, choose it.
new_presentation_edition = potential_presentation_edition
- if ((self.presentation_edition != new_presentation_edition) and new_presentation_edition != None):
+ if (
+ self.presentation_edition != new_presentation_edition
+ ) and new_presentation_edition != None:
# did we find a pool whose presentation edition was better than the work's?
self.set_presentation_edition(new_presentation_edition)
@@ -865,8 +876,8 @@ def calculate_presentation_edition(self, policy=None):
)
changed = (
- edition_metadata_changed or
- old_presentation_edition != self.presentation_edition
+ edition_metadata_changed
+ or old_presentation_edition != self.presentation_edition
)
return changed
@@ -883,8 +894,12 @@ def _get_default_audience(self):
return None
def calculate_presentation(
- self, policy=None, search_index_client=None, exclude_search=False,
- default_fiction=None, default_audience=None
+ self,
+ policy=None,
+ search_index_client=None,
+ exclude_search=False,
+ default_fiction=None,
+ default_audience=None,
):
"""Make a Work ready to show to patrons.
Call calculate_presentation_edition() to find the best-quality presentation edition
@@ -948,7 +963,7 @@ def calculate_presentation(
classification_changed = self.assign_genres(
all_identifier_ids,
default_fiction=default_fiction,
- default_audience=default_audience
+ default_audience=default_audience,
)
WorkCoverageRecord.add_for(
self, operation=WorkCoverageRecord.CLASSIFY_OPERATION
@@ -956,8 +971,7 @@ def calculate_presentation(
if policy.choose_summary:
self._choose_summary(
- direct_identifier_ids, all_identifier_ids,
- licensed_data_sources
+ direct_identifier_ids, all_identifier_ids, licensed_data_sources
)
if policy.calculate_quality:
@@ -968,9 +982,7 @@ def calculate_presentation(
# put some work into deciding which books to buy.
default_quality = None
for source in licensed_data_sources:
- q = self.default_quality_by_data_source.get(
- source.name, None
- )
+ q = self.default_quality_by_data_source.get(source.name, None)
if q is None:
continue
if default_quality is None or q > default_quality:
@@ -980,9 +992,7 @@ def calculate_presentation(
# if we still haven't found anything of a quality measurement,
# then at least make it an integer zero, not none.
default_quality = 0
- self.calculate_quality(
- all_identifier_ids, default_quality
- )
+ self.calculate_quality(all_identifier_ids, default_quality)
if self.summary_text:
if isinstance(self.summary_text, str):
@@ -993,11 +1003,11 @@ def calculate_presentation(
new_summary_text = self.summary_text
changed = (
- edition_changed or
- classification_changed or
- summary != self.summary or
- summary_text != new_summary_text or
- float(quality) != float(self.quality)
+ edition_changed
+ or classification_changed
+ or summary != self.summary
+ or summary_text != new_summary_text
+ or float(quality) != float(self.quality)
)
if changed:
@@ -1032,8 +1042,7 @@ def calculate_presentation(
self.set_presentation_ready_based_on_content()
def _choose_summary(
- self, direct_identifier_ids, all_identifier_ids,
- licensed_data_sources
+ self, direct_identifier_ids, all_identifier_ids, licensed_data_sources
):
"""Helper method for choosing a summary as part of presentation
calculation.
@@ -1056,9 +1065,7 @@ def _choose_summary(
they are trusted sources such as library staff.
"""
_db = Session.object_session(self)
- staff_data_source = DataSource.lookup(
- _db, DataSourceConstants.LIBRARY_STAFF
- )
+ staff_data_source = DataSource.lookup(_db, DataSourceConstants.LIBRARY_STAFF)
data_sources = [staff_data_source, licensed_data_sources]
summary = None
for id_set in (direct_identifier_ids, all_identifier_ids):
@@ -1080,7 +1087,7 @@ def detailed_representation(self):
if self.presentation_edition and self.presentation_edition.primary_identifier:
primary_identifier = self.presentation_edition.primary_identifier
else:
- primary_identifier=None
+ primary_identifier = None
l.append(" primary id=%s" % primary_identifier)
if self.fiction:
fiction = "Fiction"
@@ -1092,9 +1099,10 @@ def detailed_representation(self):
target_age = " age=" + self.target_age_string
else:
target_age = ""
- l.append(" %(fiction)s a=%(audience)s%(target_age)r" % (
- dict(fiction=fiction,
- audience=self.audience, target_age=target_age)))
+ l.append(
+ " %(fiction)s a=%(audience)s%(target_age)r"
+ % (dict(fiction=fiction, audience=self.audience, target_age=target_age))
+ )
l.append(" " + ", ".join(repr(wg) for wg in self.work_genres))
if self.cover_full_url:
@@ -1124,6 +1132,7 @@ def detailed_representation(self):
l.append(" " + r.final_url)
elif expect_downloads:
l.append(" Expected open-access downloads but found none.")
+
def _ensure(s):
if not s:
return ""
@@ -1141,15 +1150,10 @@ def _ensure(s):
return "\n".join(l)
def calculate_opds_entries(self, verbose=True):
- from ..opds import (
- AcquisitionFeed,
- Annotator,
- VerboseAnnotator,
- )
+ from ..opds import AcquisitionFeed, Annotator, VerboseAnnotator
+
_db = Session.object_session(self)
- simple = AcquisitionFeed.single_entry(
- _db, self, Annotator, force_create=True
- )
+ simple = AcquisitionFeed.single_entry(_db, self, Annotator, force_create=True)
if verbose is True:
verbose = AcquisitionFeed.single_entry(
_db, self, VerboseAnnotator, force_create=True
@@ -1159,13 +1163,12 @@ def calculate_opds_entries(self, verbose=True):
)
def calculate_marc_record(self):
- from ..marc import (
- Annotator,
- MARCExporter
- )
+ from ..marc import Annotator, MARCExporter
+
_db = Session.object_session(self)
record = MARCExporter.create_record(
- self, annotator=Annotator, force_create=True)
+ self, annotator=Annotator, force_create=True
+ )
WorkCoverageRecord.add_for(
self, operation=WorkCoverageRecord.GENERATE_MARC_OPERATION
)
@@ -1212,9 +1215,7 @@ def external_index_needs_updating(self):
This is a more efficient alternative to reindexing immediately,
since these WorkCoverageRecords are handled in large batches.
"""
- return self._reset_coverage(
- WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
- )
+ return self._reset_coverage(WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION)
def update_external_index(self, client, add_coverage_record=True):
"""Create a WorkCoverageRecord so that this work's
@@ -1277,7 +1278,8 @@ def set_presentation_ready_based_on_content(self, search_index_client=None):
TODO: search_index_client is redundant here.
"""
- if (not self.presentation_edition
+ if (
+ not self.presentation_edition
or not self.license_pools
or not self.title
or not self.language
@@ -1298,22 +1300,29 @@ def calculate_quality(self, identifier_ids, default_quality=0):
# and quality, plus any quantity that might be mapppable to the 0..1
# range -- ratings, and measurements with an associated percentile
# score.
- quantities = set([
- Measurement.POPULARITY, Measurement.QUALITY, Measurement.RATING
- ])
+ quantities = set(
+ [Measurement.POPULARITY, Measurement.QUALITY, Measurement.RATING]
+ )
quantities = quantities.union(list(Measurement.PERCENTILE_SCALES.keys()))
- measurements = _db.query(Measurement).filter(
- Measurement.identifier_id.in_(identifier_ids)).filter(
- Measurement.is_most_recent==True).filter(
- Measurement.quantity_measured.in_(quantities)).all()
+ measurements = (
+ _db.query(Measurement)
+ .filter(Measurement.identifier_id.in_(identifier_ids))
+ .filter(Measurement.is_most_recent == True)
+ .filter(Measurement.quantity_measured.in_(quantities))
+ .all()
+ )
self.quality = Measurement.overall_quality(
- measurements, default_value=default_quality)
- WorkCoverageRecord.add_for(
- self, operation=WorkCoverageRecord.QUALITY_OPERATION
+ measurements, default_value=default_quality
)
+ WorkCoverageRecord.add_for(self, operation=WorkCoverageRecord.QUALITY_OPERATION)
- def assign_genres(self, identifier_ids, default_fiction=False, default_audience=Classifier.AUDIENCE_ADULT):
+ def assign_genres(
+ self,
+ identifier_ids,
+ default_fiction=False,
+ default_audience=Classifier.AUDIENCE_ADULT,
+ ):
"""Set classification information for this work based on the
subquery to get equivalent identifiers.
:return: A boolean explaining whether or not any data actually
@@ -1332,20 +1341,18 @@ def assign_genres(self, identifier_ids, default_fiction=False, default_audience=
for classification in classifications:
classifier.add(classification)
- (genre_weights, self.fiction, self.audience,
- target_age) = classifier.classify(default_fiction=default_fiction,
- default_audience=default_audience)
+ (genre_weights, self.fiction, self.audience, target_age) = classifier.classify(
+ default_fiction=default_fiction, default_audience=default_audience
+ )
self.target_age = tuple_to_numericrange(target_age)
- workgenres, workgenres_changed = self.assign_genres_from_weights(
- genre_weights
- )
+ workgenres, workgenres_changed = self.assign_genres_from_weights(genre_weights)
classification_changed = (
- workgenres_changed or
- old_fiction != self.fiction or
- old_audience != self.audience or
- numericrange_to_tuple(old_target_age) != target_age
+ workgenres_changed
+ or old_fiction != self.fiction
+ or old_audience != self.audience
+ or numericrange_to_tuple(old_target_age) != target_age
)
return classification_changed
@@ -1353,11 +1360,12 @@ def assign_genres(self, identifier_ids, default_fiction=False, default_audience=
def assign_genres_from_weights(self, genre_weights):
# Assign WorkGenre objects to the remainder.
from .classification import Genre
+
changed = False
_db = Session.object_session(self)
total_genre_weight = float(sum(genre_weights.values()))
workgenres = []
- current_workgenres = _db.query(WorkGenre).filter(WorkGenre.work==self)
+ current_workgenres = _db.query(WorkGenre).filter(WorkGenre.work == self)
by_genre = dict()
for wg in current_workgenres:
by_genre[wg.genre] = wg
@@ -1370,9 +1378,8 @@ def assign_genres_from_weights(self, genre_weights):
is_new = False
del by_genre[g]
else:
- wg, is_new = get_one_or_create(
- _db, WorkGenre, work=self, genre=g)
- if is_new or round(wg.affinity,2) != round(affinity, 2):
+ wg, is_new = get_one_or_create(_db, WorkGenre, work=self, genre=g)
+ if is_new or round(wg.affinity, 2) != round(affinity, 2):
changed = True
wg.affinity = affinity
workgenres.append(wg)
@@ -1388,9 +1395,7 @@ def assign_genres_from_weights(self, genre_weights):
return workgenres, changed
-
- def assign_appeals(self, character, language, setting, story,
- cutoff=0.20):
+ def assign_appeals(self, character, language, setting, story, cutoff=0.20):
"""Assign the given appeals to the corresponding database fields,
as well as calculating the primary and secondary appeal.
"""
@@ -1447,77 +1452,76 @@ def to_search_documents(cls, works, policy=None):
# interested in. The work_id, edition_id, and identifier_id columns are used
# by other subqueries to filter, and the remaining columns are used directly
# to create the json document.
- works_alias = select(
- [Work.id.label('work_id'),
- Edition.id.label('edition_id'),
- Edition.primary_identifier_id.label('identifier_id'),
- Edition.title,
- Edition.subtitle,
- Edition.series,
- Edition.series_position,
- Edition.language,
- Edition.sort_title,
- Edition.author,
- Edition.sort_author,
- Edition.medium,
- Edition.publisher,
- Edition.imprint,
- Edition.permanent_work_id,
- Work.fiction,
- Work.audience,
- Work.summary_text,
- Work.quality,
- Work.rating,
- Work.popularity,
- Work.presentation_ready,
- Work.presentation_edition_id,
- func.extract(
- "EPOCH",
- Work.last_update_time,
- ).label('last_update_time')
- ],
- Work.id.in_((w.id for w in works))
- ).select_from(
- join(
- Work, Edition,
- Work.presentation_edition_id==Edition.id
+ works_alias = (
+ select(
+ [
+ Work.id.label("work_id"),
+ Edition.id.label("edition_id"),
+ Edition.primary_identifier_id.label("identifier_id"),
+ Edition.title,
+ Edition.subtitle,
+ Edition.series,
+ Edition.series_position,
+ Edition.language,
+ Edition.sort_title,
+ Edition.author,
+ Edition.sort_author,
+ Edition.medium,
+ Edition.publisher,
+ Edition.imprint,
+ Edition.permanent_work_id,
+ Work.fiction,
+ Work.audience,
+ Work.summary_text,
+ Work.quality,
+ Work.rating,
+ Work.popularity,
+ Work.presentation_ready,
+ Work.presentation_edition_id,
+ func.extract(
+ "EPOCH",
+ Work.last_update_time,
+ ).label("last_update_time"),
+ ],
+ Work.id.in_((w.id for w in works)),
+ )
+ .select_from(
+ join(Work, Edition, Work.presentation_edition_id == Edition.id)
)
- ).alias('works_alias')
+ .alias("works_alias")
+ )
work_id_column = literal_column(
- works_alias.name + '.' + works_alias.c.work_id.name
+ works_alias.name + "." + works_alias.c.work_id.name
)
work_presentation_edition_id_column = literal_column(
- works_alias.name + '.' + works_alias.c.presentation_edition_id.name
+ works_alias.name + "." + works_alias.c.presentation_edition_id.name
)
work_quality_column = literal_column(
- works_alias.name + '.' + works_alias.c.quality.name
+ works_alias.name + "." + works_alias.c.quality.name
)
def query_to_json(query):
"""Convert the results of a query to a JSON object."""
- return select(
- [func.row_to_json(literal_column(query.name))]
- ).select_from(query)
+ return select([func.row_to_json(literal_column(query.name))]).select_from(
+ query
+ )
def query_to_json_array(query):
"""Convert the results of a query into a JSON array."""
return select(
- [func.array_to_json(
- func.array_agg(
- func.row_to_json(
- literal_column(query.name)
- )))]
+ [
+ func.array_to_json(
+ func.array_agg(func.row_to_json(literal_column(query.name)))
+ )
+ ]
).select_from(query)
# This subquery gets Collection IDs for collections
# that own more than zero licenses for this book.
- from .classification import (
- Genre,
- Subject,
- )
+ from .classification import Genre, Subject
from .customlist import CustomListEntry
from .licensing import LicensePool
@@ -1546,49 +1550,52 @@ def explicit_bool(label, t):
# True/None. Elasticsearch can't filter on null values.
return case([(t, True)], else_=False).label(label)
- licensepools = select(
- [
- LicensePool.id.label('licensepool_id'),
- LicensePool.data_source_id.label('data_source_id'),
- LicensePool.collection_id.label('collection_id'),
- LicensePool.open_access.label('open_access'),
- LicensePool.suppressed,
-
- explicit_bool(
- 'available',
- or_(
- LicensePool.unlimited_access,
- LicensePool.self_hosted,
- LicensePool.licenses_available > 0,
- )
- ),
- explicit_bool(
- 'licensed',
+ licensepools = (
+ select(
+ [
+ LicensePool.id.label("licensepool_id"),
+ LicensePool.data_source_id.label("data_source_id"),
+ LicensePool.collection_id.label("collection_id"),
+ LicensePool.open_access.label("open_access"),
+ LicensePool.suppressed,
+ explicit_bool(
+ "available",
+ or_(
+ LicensePool.unlimited_access,
+ LicensePool.self_hosted,
+ LicensePool.licenses_available > 0,
+ ),
+ ),
+ explicit_bool(
+ "licensed",
+ or_(
+ LicensePool.unlimited_access,
+ LicensePool.self_hosted,
+ LicensePool.licenses_owned > 0,
+ ),
+ ),
+ work_quality_column,
+ Edition.medium,
+ func.extract(
+ "EPOCH",
+ LicensePool.availability_time,
+ ).label("availability_time"),
+ ]
+ )
+ .where(
+ and_(
+ LicensePool.work_id == work_id_column,
+ work_presentation_edition_id_column == Edition.id,
or_(
+ LicensePool.open_access,
LicensePool.unlimited_access,
LicensePool.self_hosted,
- LicensePool.licenses_owned > 0
- )
- ),
- work_quality_column,
- Edition.medium,
- func.extract(
- "EPOCH",
- LicensePool.availability_time,
- ).label('availability_time')
- ]
- ).where(
- and_(
- LicensePool.work_id==work_id_column,
- work_presentation_edition_id_column==Edition.id,
- or_(
- LicensePool.open_access,
- LicensePool.unlimited_access,
- LicensePool.self_hosted,
- LicensePool.licenses_owned>0,
- ),
+ LicensePool.licenses_owned > 0,
+ ),
+ )
)
- ).alias("licensepools_subquery")
+ .alias("licensepools_subquery")
+ )
licensepools_json = query_to_json_array(licensepools)
# This subquery gets CustomList IDs for all lists
@@ -1601,37 +1608,49 @@ def explicit_bool(label, t):
# And we keep track of the first time the work appears on the list.
# This is used when generating a crawlable feed for the customlist,
# which is ordered by a work's first appearance on the list.
- customlists = select(
- [
- CustomListEntry.list_id.label('list_id'),
- CustomListEntry.featured.label('featured'),
- func.extract(
- "EPOCH",
- CustomListEntry.first_appearance,
- ).label('first_appearance')
- ]
- ).where(
- CustomListEntry.work_id==work_id_column
- ).alias("listentries_subquery")
+ customlists = (
+ select(
+ [
+ CustomListEntry.list_id.label("list_id"),
+ CustomListEntry.featured.label("featured"),
+ func.extract(
+ "EPOCH",
+ CustomListEntry.first_appearance,
+ ).label("first_appearance"),
+ ]
+ )
+ .where(CustomListEntry.work_id == work_id_column)
+ .alias("listentries_subquery")
+ )
customlists_json = query_to_json_array(customlists)
# This subquery gets Contributors, filtered on edition_id.
- contributors = select(
- [Contributor.sort_name,
- Contributor.display_name,
- Contributor.family_name,
- Contributor.lc,
- Contributor.viaf,
- Contribution.role,
- ]
- ).where(
- Contribution.edition_id==literal_column(works_alias.name + "." + works_alias.c.edition_id.name)
- ).select_from(
- join(
- Contributor, Contribution,
- Contributor.id==Contribution.contributor_id
+ contributors = (
+ select(
+ [
+ Contributor.sort_name,
+ Contributor.display_name,
+ Contributor.family_name,
+ Contributor.lc,
+ Contributor.viaf,
+ Contribution.role,
+ ]
+ )
+ .where(
+ Contribution.edition_id
+ == literal_column(
+ works_alias.name + "." + works_alias.c.edition_id.name
+ )
+ )
+ .select_from(
+ join(
+ Contributor,
+ Contribution,
+ Contributor.id == Contribution.contributor_id,
+ )
)
- ).alias("contributors_subquery")
+ .alias("contributors_subquery")
+ )
contributors_json = query_to_json_array(contributors)
# Use a subquery to get recursively equivalent Identifiers
@@ -1643,121 +1662,137 @@ def explicit_bool(label, t):
# and recommendations. The index is completely rebuilt once a
# day, and that's good enough.
equivalent_identifiers = Identifier.recursively_equivalent_identifier_ids_query(
- literal_column(
- works_alias.name + "." + works_alias.c.identifier_id.name
- ),
- policy=policy
+ literal_column(works_alias.name + "." + works_alias.c.identifier_id.name),
+ policy=policy,
).alias("equivalent_identifiers_subquery")
- identifiers = select(
- [
- Identifier.identifier.label('identifier'),
- Identifier.type.label('type'),
- ]
- ).where(
- Identifier.id.in_(equivalent_identifiers)
- ).alias("identifier_subquery")
+ identifiers = (
+ select(
+ [
+ Identifier.identifier.label("identifier"),
+ Identifier.type.label("type"),
+ ]
+ )
+ .where(Identifier.id.in_(equivalent_identifiers))
+ .alias("identifier_subquery")
+ )
identifiers_json = query_to_json_array(identifiers)
# Map our constants for Subject type to their URIs.
scheme_column = case(
- [(Subject.type==key, literal_column("'%s'" % val)) for key, val in list(Subject.uri_lookup.items())]
+ [
+ (Subject.type == key, literal_column("'%s'" % val))
+ for key, val in list(Subject.uri_lookup.items())
+ ]
)
# If the Subject has a name, use that, otherwise use the Subject's identifier.
# Also, 3M's classifications have slashes, e.g. "FICTION/Adventure". Make sure
# we get separated words for search.
- term_column = func.replace(case([(Subject.name != None, Subject.name)], else_=Subject.identifier), "/", " ")
+ term_column = func.replace(
+ case([(Subject.name != None, Subject.name)], else_=Subject.identifier),
+ "/",
+ " ",
+ )
# Normalize by dividing each weight by the sum of the weights for that Identifier's Classifications.
from .classification import Classification
- weight_column = func.sum(Classification.weight) / func.sum(func.sum(Classification.weight)).over()
+
+ weight_column = (
+ func.sum(Classification.weight)
+ / func.sum(func.sum(Classification.weight)).over()
+ )
# The subquery for Subjects, with those three columns. The labels will become keys in json objects.
- subjects = select(
- [scheme_column.label('scheme'),
- term_column.label('term'),
- weight_column.label('weight'),
- ],
- # Only include Subjects with terms that are useful for search.
- and_(Subject.type.in_(Subject.TYPES_FOR_SEARCH),
- term_column != None)
- ).group_by(
- scheme_column, term_column
- ).where(
- Classification.identifier_id.in_(equivalent_identifiers)
- ).select_from(
- join(Classification, Subject, Classification.subject_id==Subject.id)
- ).alias("subjects_subquery")
+ subjects = (
+ select(
+ [
+ scheme_column.label("scheme"),
+ term_column.label("term"),
+ weight_column.label("weight"),
+ ],
+ # Only include Subjects with terms that are useful for search.
+ and_(Subject.type.in_(Subject.TYPES_FOR_SEARCH), term_column != None),
+ )
+ .group_by(scheme_column, term_column)
+ .where(Classification.identifier_id.in_(equivalent_identifiers))
+ .select_from(
+ join(Classification, Subject, Classification.subject_id == Subject.id)
+ )
+ .alias("subjects_subquery")
+ )
subjects_json = query_to_json_array(subjects)
-
# Subquery for genres.
- genres = select(
- # All Genres have the same scheme - the simplified genre URI.
- [literal_column("'%s'" % Subject.SIMPLIFIED_GENRE).label('scheme'),
- Genre.name,
- Genre.id.label('term'),
- WorkGenre.affinity.label('weight'),
- ]
- ).where(
- WorkGenre.work_id==literal_column(works_alias.name + "." + works_alias.c.work_id.name)
- ).select_from(
- join(WorkGenre, Genre, WorkGenre.genre_id==Genre.id)
- ).alias("genres_subquery")
+ genres = (
+ select(
+ # All Genres have the same scheme - the simplified genre URI.
+ [
+ literal_column("'%s'" % Subject.SIMPLIFIED_GENRE).label("scheme"),
+ Genre.name,
+ Genre.id.label("term"),
+ WorkGenre.affinity.label("weight"),
+ ]
+ )
+ .where(
+ WorkGenre.work_id
+ == literal_column(works_alias.name + "." + works_alias.c.work_id.name)
+ )
+ .select_from(join(WorkGenre, Genre, WorkGenre.genre_id == Genre.id))
+ .alias("genres_subquery")
+ )
genres_json = query_to_json_array(genres)
target_age = cls.target_age_query(
literal_column(works_alias.name + "." + works_alias.c.work_id.name)
- ).alias('target_age_subquery')
+ ).alias("target_age_subquery")
target_age_json = query_to_json(target_age)
# Now, create a query that brings together everything we need for the final
# search document.
- search_data = select(
- [works_alias.c.work_id.label("_id"),
- works_alias.c.work_id.label("work_id"),
- works_alias.c.title,
- works_alias.c.sort_title,
- works_alias.c.subtitle,
- works_alias.c.series,
- works_alias.c.series_position,
- works_alias.c.language,
- works_alias.c.author,
- works_alias.c.sort_author,
- works_alias.c.medium,
- works_alias.c.publisher,
- works_alias.c.imprint,
- works_alias.c.permanent_work_id,
- works_alias.c.presentation_ready,
- works_alias.c.last_update_time,
-
- # Convert true/false to "Fiction"/"Nonfiction".
- case(
- [(works_alias.c.fiction==True, literal_column("'Fiction'"))],
- else_=literal_column("'Nonfiction'")
+ search_data = (
+ select(
+ [
+ works_alias.c.work_id.label("_id"),
+ works_alias.c.work_id.label("work_id"),
+ works_alias.c.title,
+ works_alias.c.sort_title,
+ works_alias.c.subtitle,
+ works_alias.c.series,
+ works_alias.c.series_position,
+ works_alias.c.language,
+ works_alias.c.author,
+ works_alias.c.sort_author,
+ works_alias.c.medium,
+ works_alias.c.publisher,
+ works_alias.c.imprint,
+ works_alias.c.permanent_work_id,
+ works_alias.c.presentation_ready,
+ works_alias.c.last_update_time,
+ # Convert true/false to "Fiction"/"Nonfiction".
+ case(
+ [(works_alias.c.fiction == True, literal_column("'Fiction'"))],
+ else_=literal_column("'Nonfiction'"),
).label("fiction"),
-
- # Replace "Young Adult" with "YoungAdult" and "Adults Only" with "AdultsOnly".
- func.replace(works_alias.c.audience, " ", "").label('audience'),
-
- works_alias.c.summary_text.label('summary'),
- works_alias.c.quality,
- works_alias.c.rating,
- works_alias.c.popularity,
-
- # Here are all the subqueries.
- licensepools_json.label("licensepools"),
- customlists_json.label("customlists"),
- contributors_json.label("contributors"),
- identifiers_json.label("identifiers"),
- subjects_json.label("classifications"),
- genres_json.label('genres'),
- target_age_json.label('target_age'),
- ]
- ).select_from(
- works_alias
- ).alias("search_data_subquery")
+ # Replace "Young Adult" with "YoungAdult" and "Adults Only" with "AdultsOnly".
+ func.replace(works_alias.c.audience, " ", "").label("audience"),
+ works_alias.c.summary_text.label("summary"),
+ works_alias.c.quality,
+ works_alias.c.rating,
+ works_alias.c.popularity,
+ # Here are all the subqueries.
+ licensepools_json.label("licensepools"),
+ customlists_json.label("customlists"),
+ contributors_json.label("contributors"),
+ identifiers_json.label("identifiers"),
+ subjects_json.label("classifications"),
+ genres_json.label("genres"),
+ target_age_json.label("target_age"),
+ ]
+ )
+ .select_from(works_alias)
+ .alias("search_data_subquery")
+ )
# Finally, convert everything to json.
search_json = query_to_json(search_data)
@@ -1772,25 +1807,19 @@ def target_age_query(self, foreign_work_id_field):
# it alone. Otherwise, we subtract one to make it inclusive.
upper_field = func.upper(Work.target_age)
upper = case(
- [(func.upper_inc(Work.target_age), upper_field)],
- else_=upper_field-1
- ).label('upper')
+ [(func.upper_inc(Work.target_age), upper_field)], else_=upper_field - 1
+ ).label("upper")
# If the lower limit of the target age is inclusive, we leave
# it alone. Otherwise, we add one to make it inclusive.
lower_field = func.lower(Work.target_age)
lower = case(
- [(func.lower_inc(Work.target_age), lower_field)],
- else_=lower_field+1
- ).label('lower')
+ [(func.lower_inc(Work.target_age), lower_field)], else_=lower_field + 1
+ ).label("lower")
# Subquery for target age. This has to be a subquery so it can
# become a nested object in the final json.
- target_age = select(
- [upper, lower]
- ).where(
- Work.id==foreign_work_id_field
- )
+ target_age = select([upper, lower]).where(Work.id == foreign_work_id_field)
return target_age
def to_search_document(self):
@@ -1818,58 +1847,66 @@ def mark_licensepools_as_superceded(self):
@classmethod
def restrict_to_custom_lists_from_data_source(
- cls, _db, base_query, data_source, on_list_as_of=None):
+ cls, _db, base_query, data_source, on_list_as_of=None
+ ):
"""Annotate a query that joins Work against Edition to match only
Works that are on a custom list from the given data source."""
- condition = CustomList.data_source==data_source
+ condition = CustomList.data_source == data_source
return cls._restrict_to_customlist_subquery_condition(
- _db, base_query, condition, on_list_as_of)
+ _db, base_query, condition, on_list_as_of
+ )
@classmethod
def restrict_to_custom_lists(
- cls, _db, base_query, custom_lists, on_list_as_of=None):
+ cls, _db, base_query, custom_lists, on_list_as_of=None
+ ):
"""Annotate a query that joins Work against Edition to match only
Works that are on one of the given custom lists."""
condition = CustomList.id.in_([x.id for x in custom_lists])
return cls._restrict_to_customlist_subquery_condition(
- _db, base_query, condition, on_list_as_of)
+ _db, base_query, condition, on_list_as_of
+ )
@classmethod
def _restrict_to_customlist_subquery_condition(
- cls, _db, base_query, condition, on_list_as_of=None):
+ cls, _db, base_query, condition, on_list_as_of=None
+ ):
"""Annotate a query that joins Work against Edition to match only
Works that are on a custom list from the given data source."""
# Find works that are on a list that meets the given condition.
qu = base_query.join(LicensePool.custom_list_entries).join(
- CustomListEntry.customlist)
+ CustomListEntry.customlist
+ )
if on_list_as_of:
- qu = qu.filter(
- CustomListEntry.most_recent_appearance >= on_list_as_of)
+ qu = qu.filter(CustomListEntry.most_recent_appearance >= on_list_as_of)
qu = qu.filter(condition)
return qu
def classifications_with_genre(self):
- from .classification import (
- Classification,
- Subject,
- )
+ from .classification import Classification, Subject
+
_db = Session.object_session(self)
identifier = self.presentation_edition.primary_identifier
- return _db.query(Classification) \
- .join(Subject) \
- .filter(Classification.identifier_id == identifier.id) \
- .filter(Subject.genre_id != None) \
+ return (
+ _db.query(Classification)
+ .join(Subject)
+ .filter(Classification.identifier_id == identifier.id)
+ .filter(Subject.genre_id != None)
.order_by(Classification.weight.desc())
+ )
def top_genre(self):
from .classification import Genre
+
_db = Session.object_session(self)
- genre = _db.query(Genre) \
- .join(WorkGenre) \
- .filter(WorkGenre.work_id == self.id) \
- .order_by(WorkGenre.affinity.desc()) \
+ genre = (
+ _db.query(Genre)
+ .join(WorkGenre)
+ .filter(WorkGenre.work_id == self.id)
+ .order_by(WorkGenre.affinity.desc())
.first()
+ )
return genre.name if genre else None
def delete(self, search_index=None):
@@ -1878,6 +1915,7 @@ def delete(self, search_index=None):
if search_index is None:
try:
from ..external_search import ExternalSearchIndex
+
search_index = ExternalSearchIndex(_db)
except CannotLoadConfiguration as e:
# No search index is configured. This is fine -- just skip that part.
diff --git a/monitor.py b/monitor.py
index fb92899c8..0e8797469 100644
--- a/monitor.py
+++ b/monitor.py
@@ -1,13 +1,11 @@
import datetime
import logging
import traceback
+
from sqlalchemy.orm import defer
-from sqlalchemy.sql.expression import (
- and_,
- or_,
-)
+from sqlalchemy.sql.expression import and_, or_
-from . import log # This sets the appropriate log format and level.
+from . import log # This sets the appropriate log format and level.
from .config import Configuration
from .metadata_layer import TimestampData
from .model import (
@@ -40,12 +38,12 @@ class CollectionMonitorLogger(logging.LoggerAdapter):
def __init__(self, logger, extra):
self.logger = logger
self.extra = extra
- collection = self.extra.get('collection', None)
- self.log_prefix = '[{}] '.format(collection.name) if collection else ''
+ collection = self.extra.get("collection", None)
+ self.log_prefix = "[{}] ".format(collection.name) if collection else ""
self.warn = self.warning
def process(self, msg, kwargs):
- return '{}{}'.format(self.log_prefix, msg), kwargs
+ return "{}{}".format(self.log_prefix, msg), kwargs
class Monitor(object):
@@ -68,6 +66,7 @@ class Monitor(object):
that needs to be run on every Collection of a certain type.
"""
+
# In your subclass, set this to the name of the service,
# e.g. "Overdrive Circulation Monitor". All instances of your
# subclass will give this as their service name and track their
@@ -76,7 +75,7 @@ class Monitor(object):
# Some useful relative constants for DEFAULT_START_TIME (below).
ONE_MINUTE_AGO = datetime.timedelta(seconds=60)
- ONE_YEAR_AGO = datetime.timedelta(seconds=60*60*24*365)
+ ONE_YEAR_AGO = datetime.timedelta(seconds=60 * 60 * 24 * 365)
NEVER = object()
# If there is no Timestamp for this Monitor, this time will be
@@ -97,9 +96,7 @@ def __init__(self, _db, collection=None):
self.service_name = self.SERVICE_NAME
default_start_time = cls.DEFAULT_START_TIME
if isinstance(default_start_time, datetime.timedelta):
- default_start_time = (
- utc_now() - default_start_time
- )
+ default_start_time = utc_now() - default_start_time
self.default_start_time = default_start_time
self.default_counter = cls.DEFAULT_COUNTER
@@ -112,10 +109,10 @@ def __init__(self, _db, collection=None):
@property
def log(self):
- if not hasattr(self, '_log'):
+ if not hasattr(self, "_log"):
self._log = CollectionMonitorLogger(
logging.getLogger(self.service_name),
- {'collection': self.collection},
+ {"collection": self.collection},
)
return self._log
@@ -150,7 +147,8 @@ def timestamp(self):
"""
initial_timestamp = self.initial_start_time
timestamp, new = get_one_or_create(
- self._db, Timestamp,
+ self._db,
+ Timestamp,
service=self.service_name,
service_type=Timestamp.MONITOR_TYPE,
collection=self.collection,
@@ -158,7 +156,7 @@ def timestamp(self):
start=initial_timestamp,
finish=None,
counter=self.default_counter,
- )
+ ),
)
return timestamp
@@ -197,7 +195,7 @@ def run(self):
collection=self.collection,
start=this_run_start,
finish=this_run_finish,
- exception=None
+ exception=None,
)
new_timestamp.apply(self._db)
else:
@@ -208,7 +206,7 @@ def run(self):
this_run_finish = utc_now()
self.log.exception(
"Error running %s monitor. Timestamp will not be updated.",
- self.service_name
+ self.service_name,
)
exception = traceback.format_exc()
if exception is not None:
@@ -222,7 +220,8 @@ def run(self):
duration = this_run_finish - this_run_start
self.log.info(
- "Ran %s monitor in %.2f sec.", self.service_name,
+ "Ran %s monitor in %.2f sec.",
+ self.service_name,
duration.total_seconds(),
)
@@ -257,6 +256,7 @@ class TimelineMonitor(Monitor):
the span of time covered in the most recent run, not the time it
actually took to run.
"""
+
OVERLAP = datetime.timedelta(minutes=5)
def run_once(self, progress):
@@ -352,9 +352,8 @@ def _validate_collection(cls, collection, protocol=None):
protocol = protocol or cls.PROTOCOL
if protocol and collection.protocol != protocol:
raise ValueError(
- "Collection protocol (%s) does not match Monitor protocol (%s)" % (
- collection.protocol, protocol
- )
+ "Collection protocol (%s) does not match Monitor protocol (%s)"
+ % (collection.protocol, protocol)
)
@classmethod
@@ -379,14 +378,15 @@ def all(cls, _db, collections=None, **constructor_kwargs):
into the CollectionMonitor constructor.
"""
- service_match = or_(Timestamp.service==cls.SERVICE_NAME,
- Timestamp.service==None)
+ service_match = or_(
+ Timestamp.service == cls.SERVICE_NAME, Timestamp.service == None
+ )
collections_for_protocol = Collection.by_protocol(_db, cls.PROTOCOL).outerjoin(
Timestamp,
and_(
- Timestamp.collection_id==Collection.id,
+ Timestamp.collection_id == Collection.id,
service_match,
- )
+ ),
)
if collections:
@@ -395,11 +395,13 @@ def all(cls, _db, collections=None, **constructor_kwargs):
try:
cls._validate_collection(coll, cls.PROTOCOL)
except ValueError as e:
- additional_info = 'Only the following collections are available: {!r}'.format(
- [c.name for c in collections_for_protocol]
+ additional_info = (
+ "Only the following collections are available: {!r}".format(
+ [c.name for c in collections_for_protocol]
+ )
)
e.args += (additional_info,)
- raise ValueError(str(e) + '\n' + additional_info)
+ raise ValueError(str(e) + "\n" + additional_info)
else:
collections = collections_for_protocol.order_by(
Timestamp.start.asc().nullsfirst()
@@ -466,8 +468,10 @@ def run_once(self, *ignore):
self.log.debug(
"%s monitor went from offset %s to %s in %.2f sec",
- self.service_name, offset, new_offset,
- (batch_ended_at-batch_started_at).total_seconds()
+ self.service_name,
+ offset,
+ new_offset,
+ (batch_ended_at - batch_started_at).total_seconds(),
)
achievements = "Records processed: %d." % total_processed
@@ -479,8 +483,7 @@ def run_once(self, *ignore):
# We need to do another batch. If it should raise an exception,
# we don't want to lose the progress we've already made.
timestamp.update(
- counter=new_offset, finish=batch_ended_at,
- achievements=achievements
+ counter=new_offset, finish=batch_ended_at, achievements=achievements
)
self._db.commit()
@@ -510,8 +513,12 @@ def process_items(self, items):
def fetch_batch(self, offset):
"""Retrieve one batch of work from the database."""
- q = self.item_query().filter(self.model_class.id > offset).order_by(
- self.model_class.id).limit(self.batch_size)
+ q = (
+ self.item_query()
+ .filter(self.model_class.id > offset)
+ .order_by(self.model_class.id)
+ .limit(self.batch_size)
+ )
return q
def item_query(self):
@@ -544,17 +551,19 @@ def process_item(self, item):
class IdentifierSweepMonitor(SweepMonitor):
"""A Monitor that does some work for every Identifier."""
+
MODEL_CLASS = Identifier
def scope_to_collection(self, qu, collection):
"""Only find Identifiers licensed through the given Collection."""
return qu.join(Identifier.licensed_through).filter(
- LicensePool.collection==collection
+ LicensePool.collection == collection
)
class SubjectSweepMonitor(SweepMonitor):
"""A Monitor that does some work for every Subject."""
+
MODEL_CLASS = Subject
# It's usually easy to process a Subject, so make the batch size
@@ -575,12 +584,12 @@ def item_query(self):
"""Find only Subjects that match the given filters."""
qu = self._db.query(Subject)
if self.subject_type:
- qu = qu.filter(Subject.type==self.subject_type)
+ qu = qu.filter(Subject.type == self.subject_type)
if self.filter_string:
- filter_string = '%' + self.filter_string + '%'
+ filter_string = "%" + self.filter_string + "%"
or_clause = or_(
Subject.identifier.ilike(filter_string),
- Subject.name.ilike(filter_string)
+ Subject.name.ilike(filter_string),
)
qu = qu.filter(or_clause)
return qu
@@ -592,52 +601,58 @@ def scope_to_collection(self, qu, collection):
class CustomListEntrySweepMonitor(SweepMonitor):
"""A Monitor that does something to every CustomListEntry."""
+
MODEL_CLASS = CustomListEntry
def scope_to_collection(self, qu, collection):
"""Restrict the query to only find CustomListEntries whose
Work is in the given Collection.
"""
- return qu.join(CustomListEntry.work).join(Work.license_pools).filter(
- LicensePool.collection==collection
+ return (
+ qu.join(CustomListEntry.work)
+ .join(Work.license_pools)
+ .filter(LicensePool.collection == collection)
)
class EditionSweepMonitor(SweepMonitor):
"""A Monitor that does something to every Edition."""
+
MODEL_CLASS = Edition
def scope_to_collection(self, qu, collection):
"""Restrict the query to only find Editions whose
primary Identifier is licensed to the given Collection.
"""
- return qu.join(Edition.primary_identifier).join(
- Identifier.licensed_through).filter(
- LicensePool.collection==collection
- )
+ return (
+ qu.join(Edition.primary_identifier)
+ .join(Identifier.licensed_through)
+ .filter(LicensePool.collection == collection)
+ )
class WorkSweepMonitor(SweepMonitor):
"""A Monitor that does something to every Work."""
+
MODEL_CLASS = Work
def scope_to_collection(self, qu, collection):
"""Restrict the query to only find Works found in the given
Collection.
"""
- return qu.join(Work.license_pools).filter(
- LicensePool.collection==collection
- )
+ return qu.join(Work.license_pools).filter(LicensePool.collection == collection)
class PresentationReadyWorkSweepMonitor(WorkSweepMonitor):
"""A Monitor that does something to every presentation-ready Work."""
def item_query(self):
- return super(
- PresentationReadyWorkSweepMonitor, self).item_query().filter(
- Work.presentation_ready==True
- )
+ return (
+ super(PresentationReadyWorkSweepMonitor, self)
+ .item_query()
+ .filter(Work.presentation_ready == True)
+ )
+
class NotPresentationReadyWorkSweepMonitor(WorkSweepMonitor):
"""A Monitor that does something to every Work that is not
@@ -646,17 +661,18 @@ class NotPresentationReadyWorkSweepMonitor(WorkSweepMonitor):
def item_query(self):
not_presentation_ready = or_(
- Work.presentation_ready==False,
- Work.presentation_ready==None
+ Work.presentation_ready == False, Work.presentation_ready == None
+ )
+ return (
+ super(NotPresentationReadyWorkSweepMonitor, self)
+ .item_query()
+ .filter(not_presentation_ready)
)
- return super(
- NotPresentationReadyWorkSweepMonitor, self).item_query().filter(
- not_presentation_ready
- )
# SweepMonitors that do something specific.
+
class OPDSEntryCacheMonitor(PresentationReadyWorkSweepMonitor):
"""A Monitor that recalculates the OPDS entries for every
presentation-ready Work.
@@ -665,15 +681,18 @@ class OPDSEntryCacheMonitor(PresentationReadyWorkSweepMonitor):
which only processes works that are missing a WorkCoverageRecord
with the 'generate-opds' operation.
"""
+
SERVICE_NAME = "ODPS Entry Cache Monitor"
def process_item(self, work):
work.calculate_opds_entries()
+
class PermanentWorkIDRefreshMonitor(EditionSweepMonitor):
"""A monitor that calculates or recalculates the permanent work ID for
every edition.
"""
+
SERVICE_NAME = "Permanent work ID refresh"
def process_item(self, edition):
@@ -688,14 +707,13 @@ class MakePresentationReadyMonitor(NotPresentationReadyWorkSweepMonitor):
the ensure_coverage() calls succeed, presentation of the work is
calculated and the work is marked presentation ready.
"""
+
SERVICE_NAME = "Make Works Presentation Ready"
def __init__(self, _db, coverage_providers, collection=None):
super(MakePresentationReadyMonitor, self).__init__(_db, collection)
self.coverage_providers = coverage_providers
- self.policy = PresentationCalculationPolicy(
- choose_edition=False
- )
+ self.policy = PresentationCalculationPolicy(choose_edition=False)
def run(self):
"""Before doing anything, consolidate works."""
@@ -713,9 +731,7 @@ def process_item(self, work):
except CoverageProvidersFailed as e:
exception = "Provider(s) failed: %s" % e
except Exception as e:
- self.log.error(
- "Exception processing work %r", work, exc_info=e
- )
+ self.log.error("Exception processing work %r", work, exc_info=e)
exception = str(e)
if exception:
@@ -745,9 +761,11 @@ def prepare(self, work):
covered_types = provider.input_identifier_types
if covered_types and identifier.type in covered_types:
coverage_record = provider.ensure_coverage(identifier)
- if (not isinstance(coverage_record, CoverageRecord)
+ if (
+ not isinstance(coverage_record, CoverageRecord)
or coverage_record.status != CoverageRecord.SUCCESS
- or coverage_record.exception is not None):
+ or coverage_record.exception is not None
+ ):
# This provider has failed.
failures.append(provider)
if failures:
@@ -759,6 +777,7 @@ class CoverageProvidersFailed(Exception):
"""We tried to run CoverageProviders on a Work's identifier,
but some of the providers failed.
"""
+
def __init__(self, failed_providers):
self.failed_providers = failed_providers
super(CoverageProvidersFailed, self).__init__(
@@ -769,6 +788,7 @@ def __init__(self, failed_providers):
class CustomListEntryWorkUpdateMonitor(CustomListEntrySweepMonitor):
"""Set or reset the Work associated with each custom list entry."""
+
SERVICE_NAME = "Update Works for custom list entries"
DEFAULT_BATCH_SIZE = 100
@@ -799,6 +819,7 @@ class ReaperMonitor(Monitor):
into a list called LARGE_FIELDS and the Reaper will avoid fetching
that information, improving performance.
"""
+
MODEL_CLASS = None
TIMESTAMP_FIELD = None
MAX_AGE = None
@@ -815,8 +836,7 @@ def __init__(self, *args, **kwargs):
@property
def cutoff(self):
- """Items with a timestamp earlier than this time will be reaped.
- """
+ """Items with a timestamp earlier than this time will be reaped."""
if isinstance(self.MAX_AGE, datetime.timedelta):
max_age = self.MAX_AGE
else:
@@ -829,14 +849,13 @@ def timestamp_field(self):
@property
def where_clause(self):
- """A SQLAlchemy clause that identifies the database rows to be reaped.
- """
+ """A SQLAlchemy clause that identifies the database rows to be reaped."""
return self.timestamp_field < self.cutoff
def run_once(self, *args, **kwargs):
rows_deleted = 0
qu = self.query()
- to_defer = getattr(self.MODEL_CLASS, 'LARGE_FIELDS', [])
+ to_defer = getattr(self.MODEL_CLASS, "LARGE_FIELDS", [])
for x in to_defer:
qu = qu.options(defer(x))
count = qu.count()
@@ -862,28 +881,40 @@ def delete(self, row):
def query(self):
return self._db.query(self.MODEL_CLASS).filter(self.where_clause)
+
# ReaperMonitors that do something specific.
+
class CachedFeedReaper(ReaperMonitor):
"""Removed cached feeds older than thirty days."""
+
MODEL_CLASS = CachedFeed
- TIMESTAMP_FIELD = 'timestamp'
+ TIMESTAMP_FIELD = "timestamp"
MAX_AGE = 30
+
+
ReaperMonitor.REGISTRY.append(CachedFeedReaper)
class CredentialReaper(ReaperMonitor):
"""Remove Credentials that expired more than a day ago."""
+
MODEL_CLASS = Credential
- TIMESTAMP_FIELD = 'expires'
+ TIMESTAMP_FIELD = "expires"
MAX_AGE = 1
+
+
ReaperMonitor.REGISTRY.append(CredentialReaper)
+
class PatronRecordReaper(ReaperMonitor):
"""Remove patron records that expired more than 60 days ago"""
+
MODEL_CLASS = Patron
- TIMESTAMP_FIELD = 'authorization_expires'
+ TIMESTAMP_FIELD = "authorization_expires"
MAX_AGE = 60
+
+
ReaperMonitor.REGISTRY.append(PatronRecordReaper)
@@ -893,36 +924,39 @@ class WorkReaper(ReaperMonitor):
Unlike other reapers, no timestamp is relevant. As soon as a Work
loses its last LicensePool it can be removed.
"""
+
MODEL_CLASS = Work
def __init__(self, *args, **kwargs):
from .external_search import ExternalSearchIndex
- search_index_client = kwargs.pop('search_index_client', None)
+
+ search_index_client = kwargs.pop("search_index_client", None)
super(WorkReaper, self).__init__(*args, **kwargs)
- self.search_index_client = (
- search_index_client or ExternalSearchIndex(self._db)
- )
+ self.search_index_client = search_index_client or ExternalSearchIndex(self._db)
def query(self):
- return self._db.query(Work).outerjoin(Work.license_pools).filter(
- LicensePool.id==None
+ return (
+ self._db.query(Work)
+ .outerjoin(Work.license_pools)
+ .filter(LicensePool.id == None)
)
def delete(self, work):
"""Delete work from elasticsearch and database."""
work.delete(self.search_index_client)
+
ReaperMonitor.REGISTRY.append(WorkReaper)
class CollectionReaper(ReaperMonitor):
"""Remove collections that have been marked for deletion."""
+
MODEL_CLASS = Collection
@property
def where_clause(self):
- """A SQLAlchemy clause that identifies the database rows to be reaped.
- """
+ """A SQLAlchemy clause that identifies the database rows to be reaped."""
return Collection.marked_for_deletion == True
def delete(self, collection):
@@ -934,17 +968,26 @@ def delete(self, collection):
failure.
"""
collection.delete()
+
+
ReaperMonitor.REGISTRY.append(CollectionReaper)
class MeasurementReaper(ReaperMonitor):
"""Remove measurements that are not the most recent"""
+
MODEL_CLASS = Measurement
def run(self):
- enabled = ConfigurationSetting.sitewide(self._db, Configuration.MEASUREMENT_REAPER).bool_value
+ enabled = ConfigurationSetting.sitewide(
+ self._db, Configuration.MEASUREMENT_REAPER
+ ).bool_value
if enabled is not None and not enabled:
- self.log.info("{} skipped because it is disabled in configuration.".format(self.service_name))
+ self.log.info(
+ "{} skipped because it is disabled in configuration.".format(
+ self.service_name
+ )
+ )
return
return super(ReaperMonitor, self).run()
@@ -957,6 +1000,7 @@ def run_once(self, *args, **kwargs):
self._db.commit()
return TimestampData(achievements="Items deleted: %d" % rows_deleted)
+
ReaperMonitor.REGISTRY.append(MeasurementReaper)
@@ -973,6 +1017,7 @@ class ScrubberMonitor(ReaperMonitor):
* SCRUB_FIELD - The field whose value will be set to None when a row
is scrubbed.
"""
+
def __init__(self, *args, **kwargs):
"""Set the name of the Monitor based on which field is being
scrubbed.
@@ -980,18 +1025,19 @@ def __init__(self, *args, **kwargs):
super(ScrubberMonitor, self).__init__(*args, **kwargs)
self.SERVICE_NAME = "Scrubber for %s.%s" % (
self.MODEL_CLASS.__name__,
- self.SCRUB_FIELD
+ self.SCRUB_FIELD,
)
def run_once(self, *args, **kwargs):
"""Find all rows that need to be scrubbed, and scrub them."""
rows_scrubbed = 0
cls = self.MODEL_CLASS
- update = cls.__table__.update().where(
- self.where_clause
- ).values(
- {self.SCRUB_FIELD : None}
- ).returning(cls.id)
+ update = (
+ cls.__table__.update()
+ .where(self.where_clause)
+ .values({self.SCRUB_FIELD: None})
+ .returning(cls.id)
+ )
scrubbed = self._db.execute(update).fetchall()
self._db.commit()
return TimestampData(achievements="Items scrubbed: %d" % len(scrubbed))
@@ -1002,10 +1048,7 @@ def where_clause(self):
SCRUB_FIELD. If the field is already null, there's no need to
scrub it.
"""
- return and_(
- super(ScrubberMonitor, self).where_clause,
- self.scrub_field != None
- )
+ return and_(super(ScrubberMonitor, self).where_clause, self.scrub_field != None)
@property
def scrub_field(self):
@@ -1017,10 +1060,13 @@ def scrub_field(self):
class CirculationEventLocationScrubber(ScrubberMonitor):
"""Scrub location information from old CirculationEvents."""
+
MODEL_CLASS = CirculationEvent
- TIMESTAMP_FIELD = 'start'
+ TIMESTAMP_FIELD = "start"
MAX_AGE = 365
- SCRUB_FIELD = 'location'
+ SCRUB_FIELD = "location"
+
+
ReaperMonitor.REGISTRY.append(CirculationEventLocationScrubber)
@@ -1028,8 +1074,11 @@ class PatronNeighborhoodScrubber(ScrubberMonitor):
"""Scrub cached neighborhood information from patrons who haven't been
seen in a while.
"""
+
MODEL_CLASS = Patron
- TIMESTAMP_FIELD = 'last_external_sync'
+ TIMESTAMP_FIELD = "last_external_sync"
MAX_AGE = Patron.MAX_SYNC_TIME
- SCRUB_FIELD = 'cached_neighborhood'
+ SCRUB_FIELD = "cached_neighborhood"
+
+
ReaperMonitor.REGISTRY.append(PatronNeighborhoodScrubber)
diff --git a/opds.py b/opds.py
index fc56ea6a6..1d17487ea 100644
--- a/opds.py
+++ b/opds.py
@@ -48,6 +48,7 @@ class UnfulfillableWork(Exception):
none of the delivery mechanisms could be mirrored.
"""
+
class Annotator(object):
"""The Annotator knows how to present an OPDS feed in a specific
application context.
@@ -68,8 +69,9 @@ def is_work_entry_solo(self, work):
"""
return False
- def annotate_work_entry(self, work, active_license_pool, edition,
- identifier, feed, entry, updated=None):
+ def annotate_work_entry(
+ self, work, active_license_pool, edition, identifier, feed, entry, updated=None
+ ):
"""Make any custom modifications necessary to integrate this
OPDS entry into the application's workflow.
@@ -105,14 +107,12 @@ def annotate_work_entry(self, work, active_license_pool, edition,
)
if permalink_uri:
OPDSFeed.add_link_to_entry(
- entry, rel='alternate', href=permalink_uri,
- type=permalink_type
+ entry, rel="alternate", href=permalink_uri, type=permalink_type
)
if self.is_work_entry_solo(work):
OPDSFeed.add_link_to_entry(
- entry, rel='self', href=permalink_uri,
- type=permalink_type
+ entry, rel="self", href=permalink_uri, type=permalink_type
)
if active_license_pool:
@@ -123,10 +123,9 @@ def annotate_work_entry(self, work, active_license_pool, edition,
# This component is not actually distributing the book,
# so it should not have a bibframe:distribution tag.
provider_name_attr = "{%s}ProviderName" % AtomFeed.BIBFRAME_NS
- kwargs = {provider_name_attr : data_source}
+ kwargs = {provider_name_attr: data_source}
data_source_tag = AtomFeed.makeelement(
- "{%s}distribution" % AtomFeed.BIBFRAME_NS,
- **kwargs
+ "{%s}distribution" % AtomFeed.BIBFRAME_NS, **kwargs
)
entry.extend([data_source_tag])
@@ -138,7 +137,7 @@ def annotate_work_entry(self, work, active_license_pool, edition,
today = datetime.date.today()
if isinstance(avail, datetime.datetime):
avail = avail.date()
- if avail <= today: # Avoid obviously wrong values.
+ if avail <= today: # Avoid obviously wrong values.
availability_tag = AtomFeed.makeelement("published")
# TODO: convert to local timezone.
availability_tag.text = AtomFeed._strftime(avail)
@@ -147,13 +146,10 @@ def annotate_work_entry(self, work, active_license_pool, edition,
# If this OPDS entry is being used as part of a grouped feed
# (which is up to the Annotator subclass), we need to add a
# group link.
- group_uri, group_title = self.group_uri(
- work, active_license_pool, identifier
- )
+ group_uri, group_title = self.group_uri(work, active_license_pool, identifier)
if group_uri:
OPDSFeed.add_link_to_entry(
- entry, rel=OPDSFeed.GROUP_REL, href=group_uri,
- title=str(group_title)
+ entry, rel=OPDSFeed.GROUP_REL, href=group_uri, title=str(group_title)
)
if not updated and work.last_update_time:
@@ -188,10 +184,10 @@ def group_uri(cls, work, license_pool, identifier):
def rating_tag(cls, type_uri, value):
"""Generate a schema:Rating tag for the given type and value."""
rating_tag = AtomFeed.makeelement(AtomFeed.schema_("Rating"))
- value_key = AtomFeed.schema_('ratingValue')
+ value_key = AtomFeed.schema_("ratingValue")
rating_tag.set(value_key, "%.4f" % value)
if type_uri:
- type_key = AtomFeed.schema_('additionalType')
+ type_key = AtomFeed.schema_("additionalType")
rating_tag.set(type_key, type_uri)
return rating_tag
@@ -232,14 +228,13 @@ def categories(cls, work):
fiction_term = None
if work.fiction == True:
- fiction_term = 'Fiction'
+ fiction_term = "Fiction"
elif work.fiction == False:
- fiction_term = 'Nonfiction'
+ fiction_term = "Nonfiction"
if fiction_term:
fiction_scheme = Subject.SIMPLIFIED_FICTION_STATUS
categories[fiction_scheme] = [
- dict(term=fiction_scheme + fiction_term,
- label=fiction_term)
+ dict(term=fiction_scheme + fiction_term, label=fiction_term)
]
simplified_genres = []
@@ -248,8 +243,7 @@ def categories(cls, work):
if simplified_genres:
categories[Subject.SIMPLIFIED_GENRE] = [
- dict(term=Subject.SIMPLIFIED_GENRE + quote(x),
- label=x)
+ dict(term=Subject.SIMPLIFIED_GENRE + quote(x), label=x)
for x in simplified_genres
]
@@ -259,10 +253,10 @@ def categories(cls, work):
appeals = []
categories[schema_url] = appeals
for name, value in (
- (Work.CHARACTER_APPEAL, work.appeal_character),
- (Work.LANGUAGE_APPEAL, work.appeal_language),
- (Work.SETTING_APPEAL, work.appeal_setting),
- (Work.STORY_APPEAL, work.appeal_story),
+ (Work.CHARACTER_APPEAL, work.appeal_character),
+ (Work.LANGUAGE_APPEAL, work.appeal_language),
+ (Work.SETTING_APPEAL, work.appeal_setting),
+ (Work.STORY_APPEAL, work.appeal_story),
):
if value:
appeal = dict(term=schema_url + name, label=name)
@@ -274,16 +268,15 @@ def categories(cls, work):
# http://schema.org/audience
if work.audience:
audience_uri = AtomFeed.SCHEMA_NS + "audience"
- categories[audience_uri] = [
- dict(term=work.audience, label=work.audience)
- ]
+ categories[audience_uri] = [dict(term=work.audience, label=work.audience)]
# Any book can have a target age, but the target age
# is only relevant for childrens' and YA books.
audiences_with_target_age = (
- Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT
+ Classifier.AUDIENCE_CHILDREN,
+ Classifier.AUDIENCE_YOUNG_ADULT,
)
- if (work.target_age and work.audience in audiences_with_target_age):
+ if work.target_age and work.audience in audiences_with_target_age:
uri = Subject.uri_lookup[Subject.AGE_RANGE]
target_age = work.target_age_string
if target_age:
@@ -355,7 +348,7 @@ def contributor_tag(cls, contribution, state):
# Okay, we're creating a tag.
properties = dict()
if marc_role:
- properties['{%s}role' % AtomFeed.OPF_NS] = marc_role
+ properties["{%s}role" % AtomFeed.OPF_NS] = marc_role
tag = tag_f(AtomFeed.name(name), **properties)
# Record the fact that we credited this person with this role,
@@ -370,9 +363,9 @@ def series(cls, series_name, series_position):
if not series_name:
return None
series_details = dict()
- series_details['name'] = series_name
+ series_details["name"] = series_name
if series_position != None:
- series_details[AtomFeed.schema_('position')] = str(series_position)
+ series_details[AtomFeed.schema_("position")] = str(series_position)
series_tag = AtomFeed.makeelement(AtomFeed.schema_("Series"), **series_details)
return series_tag
@@ -467,8 +460,9 @@ class VerboseAnnotator(Annotator):
opds_cache_field = Work.verbose_opds_entry.name
- def annotate_work_entry(self, work, active_license_pool, edition,
- identifier, feed, entry):
+ def annotate_work_entry(
+ self, work, active_license_pool, edition, identifier, feed, entry
+ ):
super(VerboseAnnotator, self).annotate_work_entry(
work, active_license_pool, edition, identifier, feed, entry
)
@@ -476,12 +470,11 @@ def annotate_work_entry(self, work, active_license_pool, edition,
@classmethod
def add_ratings(cls, work, entry):
- """Add a quality rating to the work.
- """
+ """Add a quality rating to the work."""
for type_uri, value in [
- (Measurement.QUALITY, work.quality),
- (None, work.rating),
- (Measurement.POPULARITY, work.popularity),
+ (Measurement.QUALITY, work.quality),
+ (None, work.rating),
+ (Measurement.POPULARITY, work.popularity),
]:
if value:
entry.append(cls.rating_tag(type_uri, value))
@@ -504,7 +497,8 @@ def categories(cls, work, policy=None):
by_scheme_and_term = dict()
identifier_ids = work.all_identifier_ids(policy=policy)
classifications = Identifier.classifications_for_identifier_ids(
- _db, identifier_ids)
+ _db, identifier_ids
+ )
for c in classifications:
subject = c.subject
if subject.type in Subject.uri_lookup:
@@ -515,7 +509,7 @@ def categories(cls, work, policy=None):
if not key in by_scheme_and_term:
value = dict(term=subject.identifier)
if subject.name:
- value['label'] = subject.name
+ value["label"] = subject.name
value[weight_field] = 0
by_scheme_and_term[key] = value
by_scheme_and_term[key][weight_field] += c.weight
@@ -530,8 +524,7 @@ def categories(cls, work, policy=None):
@classmethod
def authors(cls, work, edition):
"""Create a detailed tag for each author."""
- return [cls.detailed_author(author)
- for author in edition.author_contributors]
+ return [cls.detailed_author(author) for author in edition.author_contributors]
@classmethod
def detailed_author(cls, contributor):
@@ -550,7 +543,8 @@ def detailed_author(cls, contributor):
if contributor.wikipedia_name:
wikipedia_name = AtomFeed.makeelement(
- "{%s}wikipedia_name" % AtomFeed.SIMPLIFIED_NS)
+ "{%s}wikipedia_name" % AtomFeed.SIMPLIFIED_NS
+ )
wikipedia_name.text = contributor.wikipedia_name
children.append(wikipedia_name)
@@ -564,20 +558,27 @@ def detailed_author(cls, contributor):
lc_tag.text = "http://id.loc.gov/authorities/names/%s" % contributor.lc
children.append(lc_tag)
-
return AtomFeed.author(*children)
-
class AcquisitionFeed(OPDSFeed):
FACET_REL = "http://opds-spec.org/facet"
@classmethod
- def groups(cls, _db, title, url, worklist, annotator,
- pagination=None, facets=None, max_age=None,
- search_engine=None, search_debug=False,
- **response_kwargs
+ def groups(
+ cls,
+ _db,
+ title,
+ url,
+ worklist,
+ annotator,
+ pagination=None,
+ facets=None,
+ max_age=None,
+ search_engine=None,
+ search_debug=False,
+ **response_kwargs
):
"""The acquisition feed for 'featured' items from a given lane's
sublanes, organized into per-lane groups.
@@ -600,21 +601,39 @@ def groups(cls, _db, title, url, worklist, annotator,
def refresh():
return cls._generate_groups(
- _db=_db, title=title, url=url, worklist=worklist,
- annotator=annotator, pagination=pagination, facets=facets,
- search_engine=search_engine, search_debug=search_debug
+ _db=_db,
+ title=title,
+ url=url,
+ worklist=worklist,
+ annotator=annotator,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
+ search_debug=search_debug,
)
return CachedFeed.fetch(
- _db=_db, worklist=worklist, pagination=pagination,
- facets=facets, refresher_method=refresh, max_age=max_age,
+ _db=_db,
+ worklist=worklist,
+ pagination=pagination,
+ facets=facets,
+ refresher_method=refresh,
+ max_age=max_age,
**response_kwargs
)
@classmethod
def _generate_groups(
- cls, _db, title, url, worklist, annotator,
- pagination, facets, search_engine, search_debug
+ cls,
+ _db,
+ title,
+ url,
+ worklist,
+ annotator,
+ pagination,
+ facets,
+ search_engine,
+ search_debug,
):
"""Internal method called by groups() when a grouped feed
must be regenerated.
@@ -623,15 +642,19 @@ def _generate_groups(
# Try to get a set of (Work, WorkList) 2-tuples
# to make a normal grouped feed.
works_and_lanes = [
- x for x in worklist.groups(
- _db=_db, pagination=pagination, facets=facets,
- search_engine=search_engine, debug=search_debug
+ x
+ for x in worklist.groups(
+ _db=_db,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
+ debug=search_debug,
)
]
# Make a typical grouped feed.
all_works = []
for work, sublane in works_and_lanes:
- if sublane==worklist:
+ if sublane == worklist:
# We are looking at the groups feed for (e.g.)
# "Science Fiction", and we're seeing a book
# that is featured within "Science Fiction" itself
@@ -670,13 +693,13 @@ def _generate_groups(
# the data.
entrypoints = facets.selectable_entrypoints(worklist)
if entrypoints:
+
def make_link(ep):
return annotator.groups_url(
worklist, facets=facets.navigate(entrypoint=ep)
)
- cls.add_entrypoint_links(
- feed, make_link, entrypoints, facets.entrypoint
- )
+
+ cls.add_entrypoint_links(feed, make_link, entrypoints, facets.entrypoint)
# A grouped feed may have breadcrumb links.
feed.add_breadcrumb_links(worklist, facets.entrypoint)
@@ -687,10 +710,19 @@ def make_link(ep):
return feed
@classmethod
- def page(cls, _db, title, url, worklist, annotator,
- facets=None, pagination=None,
- max_age=None, search_engine=None, search_debug=False,
- **response_kwargs
+ def page(
+ cls,
+ _db,
+ title,
+ url,
+ worklist,
+ annotator,
+ facets=None,
+ pagination=None,
+ max_age=None,
+ search_engine=None,
+ search_debug=False,
+ **response_kwargs
):
"""Create a feed representing one page of works from a given lane.
@@ -706,27 +738,49 @@ def page(cls, _db, title, url, worklist, annotator,
def refresh():
return cls._generate_page(
- _db, title, url, worklist, annotator, facets, pagination,
- search_engine, search_debug
+ _db,
+ title,
+ url,
+ worklist,
+ annotator,
+ facets,
+ pagination,
+ search_engine,
+ search_debug,
)
- response_kwargs.setdefault('max_age', max_age)
+ response_kwargs.setdefault("max_age", max_age)
return CachedFeed.fetch(
- _db, worklist=worklist, pagination=pagination, facets=facets,
- refresher_method=refresh, **response_kwargs
+ _db,
+ worklist=worklist,
+ pagination=pagination,
+ facets=facets,
+ refresher_method=refresh,
+ **response_kwargs
)
@classmethod
def _generate_page(
- cls, _db, title, url, lane, annotator, facets, pagination,
- search_engine, search_debug
+ cls,
+ _db,
+ title,
+ url,
+ lane,
+ annotator,
+ facets,
+ pagination,
+ search_engine,
+ search_debug,
):
"""Internal method called by page() when a cached feed
must be regenerated.
"""
works = lane.works(
- _db, pagination=pagination, facets=facets,
- search_engine=search_engine, debug=search_debug
+ _db,
+ pagination=pagination,
+ facets=facets,
+ search_engine=search_engine,
+ debug=search_debug,
)
if not isinstance(works, list):
@@ -747,12 +801,9 @@ def _generate_page(
# A paginated feed may have multiple entry points into the
# same dataset.
def make_link(ep):
- return annotator.feed_url(
- lane, facets=facets.navigate(entrypoint=ep)
- )
- cls.add_entrypoint_links(
- feed, make_link, entrypoints, facets.entrypoint
- )
+ return annotator.feed_url(lane, facets=facets.navigate(entrypoint=ep))
+
+ cls.add_entrypoint_links(feed, make_link, entrypoints, facets.entrypoint)
# Add URLs to change faceted views of the collection.
for args in cls.facet_links(annotator, facets):
@@ -760,14 +811,26 @@ def make_link(ep):
if len(works) > 0 and pagination.has_next_page:
# There are works in this list. Add a 'next' link.
- OPDSFeed.add_link_to_feed(feed=feed.feed, rel="next", href=annotator.feed_url(lane, facets, pagination.next_page))
+ OPDSFeed.add_link_to_feed(
+ feed=feed.feed,
+ rel="next",
+ href=annotator.feed_url(lane, facets, pagination.next_page),
+ )
if pagination.offset > 0:
- OPDSFeed.add_link_to_feed(feed=feed.feed, rel="first", href=annotator.feed_url(lane, facets, pagination.first_page))
+ OPDSFeed.add_link_to_feed(
+ feed=feed.feed,
+ rel="first",
+ href=annotator.feed_url(lane, facets, pagination.first_page),
+ )
previous_page = pagination.previous_page
if previous_page:
- OPDSFeed.add_link_to_feed(feed=feed.feed, rel="previous", href=annotator.feed_url(lane, facets, previous_page))
+ OPDSFeed.add_link_to_feed(
+ feed=feed.feed,
+ rel="previous",
+ href=annotator.feed_url(lane, facets, previous_page),
+ )
if isinstance(facets, FacetsWithEntryPoint):
feed.add_breadcrumb_links(lane, facets.entrypoint)
@@ -795,11 +858,19 @@ def from_query(cls, query, _db, feed_name, url, pagination, url_fn, annotator):
feed = cls(_db, feed_name, url, page_of_works, annotator)
if pagination.total_size > 0 and pagination.has_next_page:
- OPDSFeed.add_link_to_feed(feed=feed.feed, rel="next", href=url_fn(pagination.next_page.offset))
+ OPDSFeed.add_link_to_feed(
+ feed=feed.feed, rel="next", href=url_fn(pagination.next_page.offset)
+ )
if pagination.offset > 0:
- OPDSFeed.add_link_to_feed(feed=feed.feed, rel="first", href=url_fn(pagination.first_page.offset))
+ OPDSFeed.add_link_to_feed(
+ feed=feed.feed, rel="first", href=url_fn(pagination.first_page.offset)
+ )
if pagination.previous_page:
- OPDSFeed.add_link_to_feed(feed=feed.feed, rel="previous", href=url_fn(pagination.previous_page.offset))
+ OPDSFeed.add_link_to_feed(
+ feed=feed.feed,
+ rel="previous",
+ href=url_fn(pagination.previous_page.offset),
+ )
return feed
@@ -812,8 +883,8 @@ def as_error_response(self, **kwargs):
by intermediaries as an error -- that is, treated as private
and not cached.
"""
- kwargs['max_age'] = 0
- kwargs['private'] = True
+ kwargs["max_age"] = 0
+ kwargs["private"] = True
return self.as_response(**kwargs)
@classmethod
@@ -840,15 +911,16 @@ def facet_link(cls, href, title, facet_group_name, is_active):
keyword arguments into OPDSFeed.add_link_to_feed.
"""
args = dict(href=href, title=title)
- args['rel'] = cls.FACET_REL
- args['{%s}facetGroup' % AtomFeed.OPDS_NS] = facet_group_name
+ args["rel"] = cls.FACET_REL
+ args["{%s}facetGroup" % AtomFeed.OPDS_NS] = facet_group_name
if is_active:
- args['{%s}activeFacet' % AtomFeed.OPDS_NS] = "true"
+ args["{%s}activeFacet" % AtomFeed.OPDS_NS] = "true"
return args
@classmethod
- def add_entrypoint_links(cls, feed, url_generator, entrypoints,
- selected_entrypoint, group_name='Formats'):
+ def add_entrypoint_links(
+ cls, feed, url_generator, entrypoints, selected_entrypoint, group_name="Formats"
+ ):
"""Add links to a feed forming an OPDS facet group for a set of
EntryPoints.
@@ -858,8 +930,7 @@ def add_entrypoint_links(cls, feed, url_generator, entrypoints,
:param entrypoints: A list of all EntryPoints in the facet group.
:param selected_entrypoint: The current EntryPoint, if selected.
"""
- if (len(entrypoints) == 1
- and selected_entrypoint in (None, entrypoints[0])):
+ if len(entrypoints) == 1 and selected_entrypoint in (None, entrypoints[0]):
# There is only one entry point. Unless the currently
# selected entry point is somehow different, there's no
# need to put any links at all here -- a facet group with
@@ -869,8 +940,7 @@ def add_entrypoint_links(cls, feed, url_generator, entrypoints,
is_default = True
for entrypoint in entrypoints:
link = cls._entrypoint_link(
- url_generator, entrypoint, selected_entrypoint, is_default,
- group_name
+ url_generator, entrypoint, selected_entrypoint, is_default, group_name
)
if link is not None:
cls.add_link_to_feed(feed.feed, **link)
@@ -878,8 +948,7 @@ def add_entrypoint_links(cls, feed, url_generator, entrypoints,
@classmethod
def _entrypoint_link(
- cls, url_generator, entrypoint, selected_entrypoint,
- is_default, group_name
+ cls, url_generator, entrypoint, selected_entrypoint, is_default, group_name
):
"""Create arguments for add_link_to_feed for a link that navigates
between EntryPoints.
@@ -899,7 +968,9 @@ def _entrypoint_link(
#
# In OPDS 2 this can become an additional rel value,
# removing the need for a custom attribute.
- link['{%s}facetGroupType' % AtomFeed.SIMPLIFIED_NS] = FacetConstants.ENTRY_POINT_REL
+ link[
+ "{%s}facetGroupType" % AtomFeed.SIMPLIFIED_NS
+ ] = FacetConstants.ENTRY_POINT_REL
return link
def add_breadcrumb_links(self, lane, entrypoint=None):
@@ -918,8 +989,10 @@ def add_breadcrumb_links(self, lane, entrypoint=None):
annotator = self.annotator
top_level_title = annotator.top_level_title() or "Collection Home"
self.add_link_to_feed(
- feed=xml, rel='start', href=annotator.default_lane_url(),
- title=top_level_title
+ feed=xml,
+ rel="start",
+ href=annotator.default_lane_url(),
+ title=top_level_title,
)
# Add a link to the direct parent with rel="up".
@@ -938,9 +1011,7 @@ def add_breadcrumb_links(self, lane, entrypoint=None):
if parent:
up_uri = annotator.lane_url(parent)
- self.add_link_to_feed(
- feed=xml, href=up_uri, rel="up", title=parent_title
- )
+ self.add_link_to_feed(feed=xml, href=up_uri, rel="up", title=parent_title)
self.add_breadcrumbs(lane, entrypoint=entrypoint)
# Annotate the feed with a simplified:entryPoint for the
@@ -948,9 +1019,18 @@ def add_breadcrumb_links(self, lane, entrypoint=None):
self.show_current_entrypoint(entrypoint)
@classmethod
- def search(cls, _db, title, url, lane, search_engine, query,
- pagination=None, facets=None, annotator=None,
- **response_kwargs
+ def search(
+ cls,
+ _db,
+ title,
+ url,
+ lane,
+ search_engine,
+ query,
+ pagination=None,
+ facets=None,
+ annotator=None,
+ **response_kwargs
):
"""Run a search against the given search engine and return
the results as a Flask Response.
@@ -972,24 +1052,24 @@ def search(cls, _db, title, url, lane, search_engine, query,
results = lane.search(
_db, query, search_engine, pagination=pagination, facets=facets
)
- opds_feed = AcquisitionFeed(
- _db, title, url, results, annotator=annotator
- )
+ opds_feed = AcquisitionFeed(_db, title, url, results, annotator=annotator)
AcquisitionFeed.add_link_to_feed(
- feed=opds_feed.feed, rel='start',
+ feed=opds_feed.feed,
+ rel="start",
href=annotator.default_lane_url(),
- title=annotator.top_level_title()
+ title=annotator.top_level_title(),
)
# A feed of search results may link to alternate entry points
# into those results.
entrypoints = facets.selectable_entrypoints(lane)
if entrypoints:
+
def make_link(ep):
return annotator.search_url(
- lane, query, pagination=None,
- facets=facets.navigate(entrypoint=ep)
+ lane, query, pagination=None, facets=facets.navigate(entrypoint=ep)
)
+
cls.add_entrypoint_links(
opds_feed, make_link, entrypoints, facets.entrypoint
)
@@ -997,19 +1077,30 @@ def make_link(ep):
if len(results) > 0:
# There are works in this list. Add a 'next' link.
next_url = annotator.search_url(lane, query, pagination.next_page, facets)
- AcquisitionFeed.add_link_to_feed(feed=opds_feed.feed, rel="next", href=next_url)
+ AcquisitionFeed.add_link_to_feed(
+ feed=opds_feed.feed, rel="next", href=next_url
+ )
if pagination.offset > 0:
first_url = annotator.search_url(lane, query, pagination.first_page, facets)
- AcquisitionFeed.add_link_to_feed(feed=opds_feed.feed, rel="first", href=first_url)
+ AcquisitionFeed.add_link_to_feed(
+ feed=opds_feed.feed, rel="first", href=first_url
+ )
previous_page = pagination.previous_page
if previous_page:
previous_url = annotator.search_url(lane, query, previous_page, facets)
- AcquisitionFeed.add_link_to_feed(feed=opds_feed.feed, rel="previous", href=previous_url)
+ AcquisitionFeed.add_link_to_feed(
+ feed=opds_feed.feed, rel="previous", href=previous_url
+ )
# Add "up" link.
- AcquisitionFeed.add_link_to_feed(feed=opds_feed.feed, rel="up", href=annotator.lane_url(lane), title=str(lane.display_name))
+ AcquisitionFeed.add_link_to_feed(
+ feed=opds_feed.feed,
+ rel="up",
+ href=annotator.lane_url(lane),
+ title=str(lane.display_name),
+ )
# We do not add breadcrumbs to this feed since you're not
# technically searching the this lane; you are searching the
@@ -1021,7 +1112,14 @@ def make_link(ep):
@classmethod
def single_entry(
- cls, _db, work, annotator, force_create=False, raw=False, use_cache=True, **response_kwargs
+ cls,
+ _db,
+ work,
+ annotator,
+ force_create=False,
+ raw=False,
+ use_cache=True,
+ **response_kwargs
):
"""Create a single-entry OPDS document for one specific work.
@@ -1041,22 +1139,26 @@ def single_entry(
OPDSFeed.create_entry.
"""
- feed = cls(_db, '', '', [], annotator=annotator)
+ feed = cls(_db, "", "", [], annotator=annotator)
if not isinstance(work, Edition) and not work.presentation_edition:
return None
- entry = feed.create_entry(work, even_if_no_license_pool=True,
- force_create=force_create, use_cache=use_cache)
+ entry = feed.create_entry(
+ work,
+ even_if_no_license_pool=True,
+ force_create=force_create,
+ use_cache=use_cache,
+ )
# Since this tag is going to be the root of an XML
# document it's essential that it include an up-to-date nsmap,
# even if it was generated from an old cached tag that
# had an older nsmap.
- if isinstance(entry, etree._Element) and not 'drm' in entry.nsmap:
+ if isinstance(entry, etree._Element) and not "drm" in entry.nsmap:
# This workaround (creating a brand new tag) is necessary
# because the nsmap attribute is immutable. See
# https://bugs.launchpad.net/lxml/+bug/555602
nsmap = entry.nsmap
- nsmap['drm'] = AtomFeed.DRM_NS
+ nsmap["drm"] = AtomFeed.DRM_NS
new_root = etree.Element(entry.tag, nsmap=nsmap)
new_root[:] = entry[:]
entry = new_root
@@ -1066,16 +1168,16 @@ def single_entry(
entry = str(entry)
# This is probably an error message; don't cache it
# even if it would otherwise be cached.
- response_kwargs['max_age'] = 0
- response_kwargs['private'] = True
+ response_kwargs["max_age"] = 0
+ response_kwargs["private"] = True
elif isinstance(entry, etree._Element):
entry = etree.tostring(entry, encoding="unicode")
# It's common for a single OPDS entry to be returned as the
# result of an unsafe operation, so we will default to setting
# the response as private and uncacheable.
- response_kwargs.setdefault('max_age', 0)
- response_kwargs.setdefault('private', True)
+ response_kwargs.setdefault("max_age", 0)
+ response_kwargs.setdefault("private", True)
return OPDSEntryResponse(response=entry, **response_kwargs)
@@ -1107,12 +1209,9 @@ def facet_links(cls, annotator, facets):
# system. It may be left over from an earlier version,
# or just weird junk data.
continue
- yield cls.facet_link(
- url, str(facet_title), str(group_title), selected
- )
+ yield cls.facet_link(url, str(facet_title), str(group_title), selected)
- def __init__(self, _db, title, url, works, annotator=None,
- precomposed_entries=[]):
+ def __init__(self, _db, title, url, works, annotator=None, precomposed_entries=[]):
"""Turn a list of works, messages, and precomposed entries
into a feed.
"""
@@ -1145,8 +1244,9 @@ def add_entry(self, work):
self.feed.append(entry)
return entry
- def create_entry(self, work, even_if_no_license_pool=False,
- force_create=False, use_cache=True):
+ def create_entry(
+ self, work, even_if_no_license_pool=False, force_create=False, use_cache=True
+ ):
"""Turn a work into an entry for an acquisition feed."""
identifier = None
if isinstance(work, Edition):
@@ -1178,7 +1278,7 @@ def create_entry(self, work, even_if_no_license_pool=False,
return self.error_message(
identifier,
403,
- "I've heard about this work but have no active licenses for it."
+ "I've heard about this work but have no active licenses for it.",
)
if not active_edition:
@@ -1186,13 +1286,17 @@ def create_entry(self, work, even_if_no_license_pool=False,
return self.error_message(
identifier,
403,
- "I've heard about this work but have no metadata for it."
+ "I've heard about this work but have no metadata for it.",
)
try:
return self._create_entry(
- work, active_license_pool, active_edition, identifier,
- force_create, use_cache
+ work,
+ active_license_pool,
+ active_edition,
+ identifier,
+ force_create,
+ use_cache,
)
except UnfulfillableWork as e:
logging.info(
@@ -1202,17 +1306,21 @@ def create_entry(self, work, even_if_no_license_pool=False,
return self.error_message(
identifier,
403,
- "I know about this work but can offer no way of fulfilling it."
+ "I know about this work but can offer no way of fulfilling it.",
)
except Exception as e:
- logging.error(
- "Exception generating OPDS entry for %r", work,
- exc_info = e
- )
+ logging.error("Exception generating OPDS entry for %r", work, exc_info=e)
return None
- def _create_entry(self, work, active_license_pool, edition,
- identifier, force_create=False, use_cache=True):
+ def _create_entry(
+ self,
+ work,
+ active_license_pool,
+ edition,
+ identifier,
+ force_create=False,
+ use_cache=True,
+ ):
"""Build a complete OPDS entry for the given Work.
The OPDS entry will contain bibliographic information about
@@ -1252,7 +1360,8 @@ def _create_entry(self, work, active_license_pool, edition,
# Now add the stuff specific to the selected Identifier
# and LicensePool.
self.annotator.annotate_work_entry(
- work, active_license_pool, edition, identifier, self, xml)
+ work, active_license_pool, edition, identifier, self, xml
+ )
return xml
@@ -1285,8 +1394,9 @@ def _make_entry_xml(self, work, edition):
thumbnail_urls, full_urls = self.annotator.cover_links(work)
for rel, urls in (
- (Hyperlink.IMAGE, full_urls),
- (Hyperlink.THUMBNAIL_IMAGE, thumbnail_urls)):
+ (Hyperlink.IMAGE, full_urls),
+ (Hyperlink.THUMBNAIL_IMAGE, thumbnail_urls),
+ ):
for url in urls:
# TODO: This is suboptimal. We know the media types
# associated with these URLs when they are
@@ -1306,21 +1416,16 @@ def _make_entry_xml(self, work, edition):
if isinstance(content, bytes):
content = content.decode("utf8")
- content_type = 'html'
+ content_type = "html"
kw = {}
if edition.medium:
- additional_type = Edition.medium_to_additional_type.get(
- edition.medium)
+ additional_type = Edition.medium_to_additional_type.get(edition.medium)
if not additional_type:
- logging.warn("No additionalType for medium %s",
- edition.medium)
+ logging.warn("No additionalType for medium %s", edition.medium)
additional_type_field = AtomFeed.schema_("additionalType")
kw[additional_type_field] = additional_type
- entry = AtomFeed.entry(
- AtomFeed.title(edition.title or OPDSFeed.NO_TITLE),
- **kw
- )
+ entry = AtomFeed.entry(AtomFeed.title(edition.title or OPDSFeed.NO_TITLE), **kw)
if edition.subtitle:
subtitle_tag = AtomFeed.makeelement(AtomFeed.schema_("alternativeHeadline"))
subtitle_tag.text = edition.subtitle
@@ -1330,13 +1435,16 @@ def _make_entry_xml(self, work, edition):
entry.extend(author_tags)
if edition.series:
- entry.extend([self.annotator.series(edition.series, edition.series_position)])
+ entry.extend(
+ [self.annotator.series(edition.series, edition.series_position)]
+ )
if content:
entry.extend([AtomFeed.summary(content, type=content_type)])
-
- permanent_work_id_tag = AtomFeed.makeelement("{%s}pwid" % AtomFeed.SIMPLIFIED_NS)
+ permanent_work_id_tag = AtomFeed.makeelement(
+ "{%s}pwid" % AtomFeed.SIMPLIFIED_NS
+ )
permanent_work_id_tag.text = edition.permanent_work_id
entry.append(permanent_work_id_tag)
@@ -1348,7 +1456,9 @@ def _make_entry_xml(self, work, edition):
for category in categories:
if isinstance(category, (bytes, str)):
category = dict(term=category)
- category = dict(list(map(str, (k, v))) for k, v in list(category.items()))
+ category = dict(
+ list(map(str, (k, v))) for k, v in list(category.items())
+ )
category_tag = AtomFeed.category(scheme=scheme, **category)
category_tags.append(category_tag)
entry.extend(category_tags)
@@ -1366,7 +1476,9 @@ def _make_entry_xml(self, work, edition):
entry.extend([publisher_tag])
if edition.imprint:
- imprint_tag = AtomFeed.makeelement("{%s}publisherImprint" % AtomFeed.BIB_SCHEMA_NS)
+ imprint_tag = AtomFeed.makeelement(
+ "{%s}publisherImprint" % AtomFeed.BIB_SCHEMA_NS
+ )
imprint_tag.text = edition.imprint
entry.extend([imprint_tag])
@@ -1386,27 +1498,25 @@ def _make_entry_xml(self, work, edition):
# it can't be used to extract this information. However, these
# tags are consistent with the OPDS spec.
issued = edition.issued or edition.published
- if (isinstance(issued, datetime.datetime)
- or isinstance(issued, datetime.date)):
+ if isinstance(issued, datetime.datetime) or isinstance(issued, datetime.date):
now = utc_now()
today = datetime.date.today()
issued_already = False
if isinstance(issued, datetime.datetime):
- issued_already = (issued <= now)
+ issued_already = issued <= now
elif isinstance(issued, datetime.date):
- issued_already = (issued <= today)
+ issued_already = issued <= today
if issued_already:
issued_tag = AtomFeed.makeelement("{%s}issued" % AtomFeed.DCTERMS_NS)
# Use datetime.isoformat instead of datetime.strftime because
# strftime only works on dates after 1890, and we have works
# that were issued much earlier than that.
# TODO: convert to local timezone, not that it matters much.
- issued_tag.text = issued.isoformat().split('T')[0]
+ issued_tag.text = issued.isoformat().split("T")[0]
entry.extend([issued_tag])
return entry
-
CURRENT_ENTRYPOINT_ATTRIBUTE = "{%s}entryPoint" % AtomFeed.SIMPLIFIED_NS
def show_current_entrypoint(self, entrypoint):
@@ -1462,12 +1572,9 @@ def add_breadcrumbs(self, lane, include_lane=False, entrypoint=None):
usable_parentage.append(ancestor)
annotator = self.annotator
- if (
- lane == site_root_lane or
- (
- site_root_lane is None and
- annotator.lane_url(lane) == annotator.default_lane_url()
- )
+ if lane == site_root_lane or (
+ site_root_lane is None
+ and annotator.lane_url(lane) == annotator.default_lane_url()
):
# There are no extra breadcrumbs: either we are at the
# site root, or we are at a lane that is the root for a
@@ -1475,9 +1582,7 @@ def add_breadcrumbs(self, lane, include_lane=False, entrypoint=None):
return
# Start work on a simplified:breadcrumbs tag.
- breadcrumbs = AtomFeed.makeelement(
- "{%s}breadcrumbs" % AtomFeed.SIMPLIFIED_NS
- )
+ breadcrumbs = AtomFeed.makeelement("{%s}breadcrumbs" % AtomFeed.SIMPLIFIED_NS)
# Add root link. This is either the link to the site root
# or to the root lane for some patron type.
@@ -1494,8 +1599,7 @@ def add_breadcrumbs(self, lane, include_lane=False, entrypoint=None):
if entrypoint:
breadcrumbs.append(
AtomFeed.link(
- title=entrypoint.INTERNAL_NAME,
- href=root_url + entrypoint_query
+ title=entrypoint.INTERNAL_NAME, href=root_url + entrypoint_query
)
)
@@ -1509,8 +1613,7 @@ def add_breadcrumbs(self, lane, include_lane=False, entrypoint=None):
breadcrumbs.append(
AtomFeed.link(
- title=ancestor.display_name,
- href=lane_url + entrypoint_query
+ title=ancestor.display_name, href=lane_url + entrypoint_query
)
)
@@ -1518,8 +1621,8 @@ def add_breadcrumbs(self, lane, include_lane=False, entrypoint=None):
self.feed.append(breadcrumbs)
@classmethod
- def minimal_opds_entry(cls, identifier, cover, description, quality,
- most_recent_update=None
+ def minimal_opds_entry(
+ cls, identifier, cover, description, quality, most_recent_update=None
):
elements = []
representations = []
@@ -1527,34 +1630,38 @@ def minimal_opds_entry(cls, identifier, cover, description, quality,
cover_representation = cover.representation
representations.append(cover.representation)
cover_link = AtomFeed.makeelement(
- "link", href=cover_representation.public_url,
- type=cover_representation.media_type, rel=Hyperlink.IMAGE)
+ "link",
+ href=cover_representation.public_url,
+ type=cover_representation.media_type,
+ rel=Hyperlink.IMAGE,
+ )
elements.append(cover_link)
if cover_representation.thumbnails:
thumbnail = cover_representation.thumbnails[0]
representations.append(thumbnail)
thumbnail_link = AtomFeed.makeelement(
- "link", href=thumbnail.public_url,
+ "link",
+ href=thumbnail.public_url,
type=thumbnail.media_type,
- rel=Hyperlink.THUMBNAIL_IMAGE
+ rel=Hyperlink.THUMBNAIL_IMAGE,
)
elements.append(thumbnail_link)
if description:
content = description.representation.content
if isinstance(content, bytes):
content = content.decode("utf8")
- description_e = AtomFeed.summary(content, type='html')
+ description_e = AtomFeed.summary(content, type="html")
elements.append(description_e)
representations.append(description.representation)
if quality:
- elements.append(
- Annotator.rating_tag(Measurement.QUALITY, quality))
+ elements.append(Annotator.rating_tag(Measurement.QUALITY, quality))
# The update date is the most recent date any of these
# resources were mirrored/fetched.
potential_update_dates = [
- r.mirrored_at or r.fetched_at for r in representations
+ r.mirrored_at or r.fetched_at
+ for r in representations
if r.mirrored_at or r.fetched_at
]
if most_recent_update:
@@ -1564,9 +1671,7 @@ def minimal_opds_entry(cls, identifier, cover, description, quality,
update_date = max(potential_update_dates)
elements.append(AtomFeed.updated(AtomFeed._strftime(update_date)))
entry = AtomFeed.entry(
- AtomFeed.id(identifier.urn),
- AtomFeed.title(OPDSFeed.NO_TITLE),
- *elements
+ AtomFeed.id(identifier.urn), AtomFeed.title(OPDSFeed.NO_TITLE), *elements
)
return entry
@@ -1622,7 +1727,8 @@ def indirect_acquisition(cls, indirect_types):
parent = None
for t in indirect_types:
indirect_link = AtomFeed.makeelement(
- "{%s}indirectAcquisition" % AtomFeed.OPDS_NS, type=t)
+ "{%s}indirectAcquisition" % AtomFeed.OPDS_NS, type=t
+ )
if parent is not None:
parent.extend([indirect_link])
parent = indirect_link
@@ -1654,7 +1760,7 @@ def license_tags(cls, license_pool, loan, hold):
collection.default_loan_period(obj.library or obj.integration_client)
)
if loan:
- status = 'available'
+ status = "available"
since = loan.start
until = loan.until(default_loan_period)
elif hold:
@@ -1664,30 +1770,36 @@ def license_tags(cls, license_pool, loan, hold):
)
until = hold.until(default_loan_period, default_reservation_period)
if hold.position == 0:
- status = 'ready'
+ status = "ready"
since = None
else:
- status = 'reserved'
+ status = "reserved"
since = hold.start
- elif (license_pool.open_access or license_pool.unlimited_access or license_pool.self_hosted or (
- license_pool.licenses_available > 0 and
- license_pool.licenses_owned > 0)
- ):
- status = 'available'
+ elif (
+ license_pool.open_access
+ or license_pool.unlimited_access
+ or license_pool.self_hosted
+ or (license_pool.licenses_available > 0 and license_pool.licenses_owned > 0)
+ ):
+ status = "available"
else:
- status='unavailable'
+ status = "unavailable"
kw = dict(status=status)
if since:
- kw['since'] = AtomFeed._strftime(since)
+ kw["since"] = AtomFeed._strftime(since)
if until:
- kw['until'] = AtomFeed._strftime(until)
+ kw["until"] = AtomFeed._strftime(until)
tag_name = "{%s}availability" % AtomFeed.OPDS_NS
availability_tag = AtomFeed.makeelement(tag_name, **kw)
tags.append(availability_tag)
# Open-access pools do not need to display or .
- if license_pool.open_access or license_pool.unlimited_access or license_pool.self_hosted:
+ if (
+ license_pool.open_access
+ or license_pool.unlimited_access
+ or license_pool.self_hosted
+ ):
return tags
holds_kw = dict()
@@ -1702,7 +1814,7 @@ def license_tags(cls, license_pool, loan, hold):
position = hold.position
if position > 0:
- holds_kw['position'] = str(position)
+ holds_kw["position"] = str(position)
if position > total:
# The patron's hold position appears larger than the total
# number of holds. This happens frequently because the
@@ -1716,7 +1828,7 @@ def license_tags(cls, license_pool, loan, hold):
# where we know that the total number of holds is
# *greater* than the hold position.
total = 1
- holds_kw['total'] = str(total)
+ holds_kw["total"] = str(total)
holds = AtomFeed.makeelement("{%s}holds" % AtomFeed.OPDS_NS, **holds_kw)
tags.append(holds)
@@ -1783,7 +1895,10 @@ def create_entry(self, work):
error_message = "Identifier not found in collection"
elif identifier.work != work:
error_status = 500
- error_message = 'I tried to generate an OPDS entry for the identifier "%s" using a Work not associated with that identifier.' % identifier.urn
+ error_message = (
+ 'I tried to generate an OPDS entry for the identifier "%s" using a Work not associated with that identifier.'
+ % identifier.urn
+ )
if error_status:
return self.error_message(identifier, error_status, error_message)
@@ -1793,28 +1908,35 @@ def create_entry(self, work):
else:
edition = work.presentation_edition
try:
- return self._create_entry(
- work, active_licensepool, edition, identifier
- )
+ return self._create_entry(work, active_licensepool, edition, identifier)
except UnfulfillableWork as e:
logging.info(
- "Work %r is not fulfillable, refusing to create an .",
- work
+ "Work %r is not fulfillable, refusing to create an .", work
)
return self.error_message(
identifier,
403,
- "I know about this work but can offer no way of fulfilling it."
+ "I know about this work but can offer no way of fulfilling it.",
)
+
class NavigationFacets(FeaturedFacets):
CACHED_FEED_TYPE = CachedFeed.NAVIGATION_TYPE
-class NavigationFeed(OPDSFeed):
+class NavigationFeed(OPDSFeed):
@classmethod
- def navigation(cls, _db, title, url, worklist, annotator,
- facets=None, max_age=None, **response_kwargs):
+ def navigation(
+ cls,
+ _db,
+ title,
+ url,
+ worklist,
+ annotator,
+ facets=None,
+ max_age=None,
+ **response_kwargs
+ ):
"""The navigation feed with links to a given lane's sublanes.
:param response_kwargs: Extra keyword arguments to pass into
@@ -1827,11 +1949,9 @@ def navigation(cls, _db, title, url, worklist, annotator,
facets = facets or NavigationFacets.default(worklist)
def refresh():
- return cls._generate_navigation(
- _db, title, url, worklist, annotator
- )
+ return cls._generate_navigation(_db, title, url, worklist, annotator)
- response_kwargs.setdefault('mimetype', OPDSFeed.NAVIGATION_FEED_TYPE)
+ response_kwargs.setdefault("mimetype", OPDSFeed.NAVIGATION_FEED_TYPE)
return CachedFeed.fetch(
_db,
worklist=worklist,
@@ -1843,8 +1963,7 @@ def refresh():
)
@classmethod
- def _generate_navigation(cls, _db, title, url, worklist,
- annotator):
+ def _generate_navigation(cls, _db, title, url, worklist, annotator):
feed = NavigationFeed(title, url)
@@ -1870,8 +1989,7 @@ def _generate_navigation(cls, _db, title, url, worklist,
def add_entry(self, url, title, type=OPDSFeed.NAVIGATION_FEED_TYPE):
"""Create an OPDS navigation entry for a URL."""
- entry = AtomFeed.entry(
- AtomFeed.title(title))
+ entry = AtomFeed.entry(AtomFeed.title(title))
entry.extend([AtomFeed.id(url)])
entry.extend([AtomFeed.link(rel="subsection", href=url, type=type)])
self.feed.append(entry)
@@ -1879,8 +1997,8 @@ def add_entry(self, url, title, type=OPDSFeed.NAVIGATION_FEED_TYPE):
# Mock annotators for use in unit tests.
-class TestAnnotator(Annotator):
+class TestAnnotator(Annotator):
def __init__(self):
self.lanes_by_work = defaultdict(list)
@@ -1899,10 +2017,10 @@ def feed_url(cls, lane, facets=None, pagination=None):
base = "http://%s/" % lane.url_name
else:
base = "http://%s/" % lane.display_name
- sep = '?'
+ sep = "?"
if facets:
base += sep + facets.query_string
- sep = '&'
+ sep = "&"
if pagination:
base += sep + pagination.query_string
return base
@@ -1913,10 +2031,10 @@ def search_url(cls, lane, query, pagination, facets=None):
base = "http://%s/" % lane.url_name
else:
base = "http://%s/" % lane.display_name
- sep = '?'
+ sep = "?"
if pagination:
base += sep + pagination.query_string
- sep = '&'
+ sep = "&"
if facets:
facet_query_string = facets.query_string
if facet_query_string:
@@ -1930,9 +2048,9 @@ def groups_url(cls, lane, facets=None):
else:
identifier = ""
if facets:
- facet_string = '?' + facets.query_string
+ facet_string = "?" + facets.query_string
else:
- facet_string = ''
+ facet_string = ""
return "http://groups/%s%s" % (identifier, facet_string)
@@ -1960,22 +2078,22 @@ def top_level_title(cls):
class TestAnnotatorWithGroup(TestAnnotator):
-
def group_uri(self, work, license_pool, identifier):
lanes = self.lanes_by_work.get(work, None)
if lanes:
lane_dic = lanes.pop(0)
- lane_name = lane_dic['lane'].display_name
+ lane_name = lane_dic["lane"].display_name
else:
lane_name = str(work.id)
- return ("http://group/%s" % lane_name,
- "Group Title for %s!" % lane_name)
+ return ("http://group/%s" % lane_name, "Group Title for %s!" % lane_name)
def group_uri_for_lane(self, lane):
if lane:
- return ("http://groups/%s" % lane.display_name,
- "Groups of %s" % lane.display_name)
+ return (
+ "http://groups/%s" % lane.display_name,
+ "Groups of %s" % lane.display_name,
+ )
else:
return "http://groups/", "Top-level groups"
diff --git a/opds2_import.py b/opds2_import.py
index 739c06699..0c5fc5bdf 100644
--- a/opds2_import.py
+++ b/opds2_import.py
@@ -376,9 +376,7 @@ def _extract_image_links(self, publication, feed_self_url):
:rtype: List[LinkData]
"""
self._logger.debug(
- "Started extracting image links from {0}".format(
- encode(publication.images)
- )
+ "Started extracting image links from {0}".format(encode(publication.images))
)
if not publication.images:
@@ -397,7 +395,8 @@ def _extract_image_links(self, publication, feed_self_url):
sorted_raw_image_links = list(
reversed(
sorted(
- publication.images.links, key=lambda link: (link.width or 0, link.height or 0)
+ publication.images.links,
+ key=lambda link: (link.width or 0, link.height or 0),
)
)
)
@@ -862,7 +861,9 @@ def _record_coverage_failure(
return failure
- def _record_publication_unrecognizable_identifier(self, publication: opds2_ast.OPDS2Publication) -> None:
+ def _record_publication_unrecognizable_identifier(
+ self, publication: opds2_ast.OPDS2Publication
+ ) -> None:
"""Record a publication's unrecognizable identifier, i.e. identifier that has an unknown format
and could not be parsed by CM.
@@ -875,7 +876,9 @@ def _record_publication_unrecognizable_identifier(self, publication: opds2_ast.O
if original_identifier is None:
self._logger.warning(f"Publication '{title}' does not have an identifier.")
else:
- self._logger.warning(f"Publication # {original_identifier} ('{title}') has an unrecognizable identifier.")
+ self._logger.warning(
+ f"Publication # {original_identifier} ('{title}') has an unrecognizable identifier."
+ )
def extract_next_links(self, feed):
"""Extracts "next" links from the feed.
@@ -937,7 +940,9 @@ def extract_feed_data(self, feed, feed_url=None):
for publication in self._get_publications(feed):
recognized_identifier = self._extract_identifier(publication)
- if not recognized_identifier or not self._is_identifier_allowed(recognized_identifier):
+ if not recognized_identifier or not self._is_identifier_allowed(
+ recognized_identifier
+ ):
self._record_publication_unrecognizable_identifier(publication)
continue
@@ -959,7 +964,9 @@ def extract_feed_data(self, feed, feed_url=None):
if publication:
recognized_identifier = self._extract_identifier(publication)
- if not recognized_identifier or not self._is_identifier_allowed(recognized_identifier):
+ if not recognized_identifier or not self._is_identifier_allowed(
+ recognized_identifier
+ ):
self._record_publication_unrecognizable_identifier(publication)
else:
self._record_coverage_failure(
diff --git a/opds_import.py b/opds_import.py
index eb9b9c0eb..237c48900 100644
--- a/opds_import.py
+++ b/opds_import.py
@@ -2,15 +2,16 @@
import logging
import traceback
from io import BytesIO
+from urllib.parse import quote, urljoin, urlparse
import dateutil
import feedparser
from flask_babel import lazy_gettext as _
from lxml import etree
-from urllib.parse import urljoin, urlparse, quote
from sqlalchemy.orm import aliased
from sqlalchemy.orm.session import Session
+from .classifier import Classifier
from .config import CannotLoadConfiguration, IntegrationException
from .coverage import CoverageFailure
from .metadata_layer import (
@@ -44,12 +45,11 @@
from .model.configuration import ExternalIntegrationLink
from .monitor import CollectionMonitor
from .selftest import HasSelfTests, SelfTestResult
-from .classifier import Classifier
+from .util.datetime_helpers import datetime_utc, to_utc, utc_now
from .util.http import HTTP, BadResponseException
from .util.opds_writer import OPDSFeed, OPDSMessage
from .util.string_helpers import base64
from .util.xmlparser import XMLParser
-from .util.datetime_helpers import datetime_utc, utc_now, to_utc
def parse_identifier(db, identifier):
@@ -75,6 +75,7 @@ def parse_identifier(db, identifier):
class AccessNotAuthenticated(Exception):
"""No authentication is configured for this service"""
+
pass
@@ -85,27 +86,23 @@ class SimplifiedOPDSLookup(object):
@classmethod
def check_content_type(cls, response):
- content_type = response.headers.get('content-type')
+ content_type = response.headers.get("content-type")
if content_type != OPDSFeed.ACQUISITION_FEED_TYPE:
raise BadResponseException.from_response(
- response.url,
- "Wrong media type: %s" % content_type,
- response
+ response.url, "Wrong media type: %s" % content_type, response
)
@classmethod
- def from_protocol(cls, _db, protocol,
- goal=ExternalIntegration.LICENSE_GOAL, library=None
+ def from_protocol(
+ cls, _db, protocol, goal=ExternalIntegration.LICENSE_GOAL, library=None
):
- integration = ExternalIntegration.lookup(
- _db, protocol, goal, library=library
- )
+ integration = ExternalIntegration.lookup(_db, protocol, goal, library=library)
if not integration or not integration.url:
return None
return cls(integration.url)
def __init__(self, base_url):
- if not base_url.endswith('/'):
+ if not base_url.endswith("/"):
base_url += "/"
self.base_url = base_url
@@ -115,9 +112,9 @@ def lookup_endpoint(self):
def _get(self, url, **kwargs):
"""Make an HTTP request. This method is overridden in the mock class."""
- kwargs['timeout'] = kwargs.get('timeout', 300)
- kwargs['allowed_response_codes'] = kwargs.get('allowed_response_codes', [])
- kwargs['allowed_response_codes'] += ['2xx', '3xx']
+ kwargs["timeout"] = kwargs.get("timeout", 300)
+ kwargs["allowed_response_codes"] = kwargs.get("allowed_response_codes", [])
+ kwargs["allowed_response_codes"] += ["2xx", "3xx"]
return HTTP.get_with_timeout(url, **kwargs)
def urn_args(self, identifiers):
@@ -138,47 +135,50 @@ class MetadataWranglerOPDSLookup(SimplifiedOPDSLookup, HasSelfTests):
CARDINALITY = 1
SETTINGS = [
- { "key": ExternalIntegration.URL,
- "label": _("URL"),
- "default": "http://metadata.librarysimplified.org/",
- "required": True,
- "format": "url",
+ {
+ "key": ExternalIntegration.URL,
+ "label": _("URL"),
+ "default": "http://metadata.librarysimplified.org/",
+ "required": True,
+ "format": "url",
},
]
SITEWIDE = True
- ADD_ENDPOINT = 'add'
- ADD_WITH_METADATA_ENDPOINT = 'add_with_metadata'
- METADATA_NEEDED_ENDPOINT = 'metadata_needed'
- REMOVE_ENDPOINT = 'remove'
- UPDATES_ENDPOINT = 'updates'
- CANONICALIZE_ENDPOINT = 'canonical-author-name'
+ ADD_ENDPOINT = "add"
+ ADD_WITH_METADATA_ENDPOINT = "add_with_metadata"
+ METADATA_NEEDED_ENDPOINT = "metadata_needed"
+ REMOVE_ENDPOINT = "remove"
+ UPDATES_ENDPOINT = "updates"
+ CANONICALIZE_ENDPOINT = "canonical-author-name"
@classmethod
def from_config(cls, _db, collection=None):
integration = ExternalIntegration.lookup(
- _db, ExternalIntegration.METADATA_WRANGLER,
- ExternalIntegration.METADATA_GOAL
+ _db,
+ ExternalIntegration.METADATA_WRANGLER,
+ ExternalIntegration.METADATA_GOAL,
)
if not integration:
raise CannotLoadConfiguration(
- "No ExternalIntegration found for the Metadata Wrangler.")
+ "No ExternalIntegration found for the Metadata Wrangler."
+ )
if not integration.url:
raise CannotLoadConfiguration("Metadata Wrangler improperly configured.")
return cls(
- integration.url, shared_secret=integration.password,
- collection=collection
+ integration.url, shared_secret=integration.password, collection=collection
)
@classmethod
def external_integration(cls, _db):
return ExternalIntegration.lookup(
- _db, ExternalIntegration.METADATA_WRANGLER,
- ExternalIntegration.METADATA_GOAL
+ _db,
+ ExternalIntegration.METADATA_WRANGLER,
+ ExternalIntegration.METADATA_GOAL,
)
def _run_self_tests(self, _db, lookup_class=None):
@@ -225,14 +225,12 @@ def _run_collection_self_tests(self):
# Check various endpoints that yield OPDS feeds.
one_day_ago = utc_now() - datetime.timedelta(days=1)
for title, m, args in (
- (
- "Metadata updates in last 24 hours",
- self.updates, [one_day_ago]
- ),
+ ("Metadata updates in last 24 hours", self.updates, [one_day_ago]),
(
"Titles where we could (but haven't) provide information to the metadata wrangler",
- self.metadata_needed, []
- )
+ self.metadata_needed,
+ [],
+ ),
):
yield self._feed_self_test(title, m, *args)
@@ -246,7 +244,7 @@ def _feed_self_test(self, name, method, *args):
# If the server returns a 500 error we don't want to raise an
# exception -- we want to record it as part of the test
# result.
- kwargs = dict(allowed_response_codes=['%sxx' % f for f in range(1,6)])
+ kwargs = dict(allowed_response_codes=["%sxx" % f for f in range(1, 6)])
response = method(*args, **kwargs)
self._annotate_feed_response(result, response)
@@ -272,23 +270,21 @@ def _annotate_feed_response(cls, result, response):
lines = []
lines.append("Request URL: %s" % response.url)
lines.append(
- "Request authorization: %s" %
- response.request.headers.get('Authorization')
+ "Request authorization: %s" % response.request.headers.get("Authorization")
)
lines.append("Status code: %d" % response.status_code)
result.success = response.status_code == 200
if result.success:
feed = feedparser.parse(response.content)
- total_results = feed['feed'].get('opensearch_totalresults')
+ total_results = feed["feed"].get("opensearch_totalresults")
if total_results is not None:
lines.append(
- "Total identifiers registered with this collection: %s" % (
- total_results
- )
+ "Total identifiers registered with this collection: %s"
+ % (total_results)
)
- lines.append("Entries on this page: %d" % len(feed['entries']))
- for i in feed['entries']:
- lines.append(" " + i['title'])
+ lines.append("Entries on this page: %d" % len(feed["entries"]))
+ for i in feed["entries"]:
+ lines.append(" " + i["title"])
result.result = lines
def __init__(self, url, shared_secret=None, collection=None):
@@ -303,40 +299,40 @@ def authenticated(self):
@property
def authorization(self):
if self.authenticated:
- token = 'Bearer ' + base64.b64encode(self.shared_secret)
- return { 'Authorization' : token }
+ token = "Bearer " + base64.b64encode(self.shared_secret)
+ return {"Authorization": token}
return None
@property
def lookup_endpoint(self):
if not (self.authenticated and self.collection):
return self.LOOKUP_ENDPOINT
- return self.collection.metadata_identifier + '/' + self.LOOKUP_ENDPOINT
+ return self.collection.metadata_identifier + "/" + self.LOOKUP_ENDPOINT
def _get(self, url, **kwargs):
if self.authenticated:
- headers = kwargs.get('headers', {})
+ headers = kwargs.get("headers", {})
headers.update(self.authorization)
- kwargs['headers'] = headers
+ kwargs["headers"] = headers
return super(MetadataWranglerOPDSLookup, self)._get(url, **kwargs)
def _post(self, url, data="", **kwargs):
"""Make an HTTP request. This method is overridden in the mock class."""
if self.authenticated:
- headers = kwargs.get('headers', {})
+ headers = kwargs.get("headers", {})
headers.update(self.authorization)
- kwargs['headers'] = headers
- kwargs['timeout'] = kwargs.get('timeout', 120)
- kwargs['allowed_response_codes'] = kwargs.get('allowed_response_codes', [])
- kwargs['allowed_response_codes'] += ['2xx', '3xx']
+ kwargs["headers"] = headers
+ kwargs["timeout"] = kwargs.get("timeout", 120)
+ kwargs["allowed_response_codes"] = kwargs.get("allowed_response_codes", [])
+ kwargs["allowed_response_codes"] += ["2xx", "3xx"]
return HTTP.post_with_timeout(url, data, **kwargs)
def add_args(self, url, arg_string):
- joiner = '?'
+ joiner = "?"
if joiner in url:
# This URL already has an argument (namely: data_source), so
# append the new arguments.
- joiner = '&'
+ joiner = "&"
return url + joiner + arg_string
def get_collection_url(self, endpoint):
@@ -345,15 +341,19 @@ def get_collection_url(self, endpoint):
if not self.collection:
raise ValueError("No Collection provided.")
- data_source = ''
+ data_source = ""
if self.collection.protocol == ExternalIntegration.OPDS_IMPORT:
# Open access OPDS_IMPORT collections need to send a DataSource to
# allow OPDS lookups on the Metadata Wrangler.
- data_source = '?data_source=' + quote(self.collection.data_source.name)
+ data_source = "?data_source=" + quote(self.collection.data_source.name)
- return (self.base_url
+ return (
+ self.base_url
+ self.collection.metadata_identifier
- + '/' + endpoint + data_source)
+ + "/"
+ + endpoint
+ + data_source
+ )
def add(self, identifiers):
"""Add items to an authenticated Metadata Wrangler Collection"""
@@ -372,9 +372,7 @@ def metadata_needed(self, **kwargs):
"""Get a feed of items that need additional metadata to be processed
by the Metadata Wrangler.
"""
- metadata_needed_url = self.get_collection_url(
- self.METADATA_NEEDED_ENDPOINT
- )
+ metadata_needed_url = self.get_collection_url(self.METADATA_NEEDED_ENDPOINT)
return self._get(metadata_needed_url, **kwargs)
def remove(self, identifiers):
@@ -394,8 +392,8 @@ def updates(self, last_update_time, **kwargs):
"""
url = self.get_collection_url(self.UPDATES_ENDPOINT)
if last_update_time:
- formatted_time = last_update_time.strftime('%Y-%m-%dT%H:%M:%SZ')
- url = self.add_args(url, ('last_update_time='+formatted_time))
+ formatted_time = last_update_time.strftime("%Y-%m-%dT%H:%M:%SZ")
+ url = self.add_args(url, ("last_update_time=" + formatted_time))
logging.info("Metadata Wrangler Collection Updates URL: %s", url)
return self._get(url, **kwargs)
@@ -408,9 +406,7 @@ def canonicalize_author_name(self, identifier, working_display_name):
(i.e. the name format human being used as opposed to the name
that goes into library records).
"""
- args = "display_name=%s" % (
- quote(working_display_name.encode("utf8"))
- )
+ args = "display_name=%s" % (quote(working_display_name.encode("utf8")))
if identifier:
args += "&urn=%s" % quote(identifier.urn)
url = self.base_url + self.CANONICALIZE_ENDPOINT + "?" + args
@@ -419,7 +415,6 @@ def canonicalize_author_name(self, identifier, working_display_name):
class MockSimplifiedOPDSLookup(SimplifiedOPDSLookup):
-
def __init__(self, *args, **kwargs):
self.requests = []
self.responses = []
@@ -427,45 +422,50 @@ def __init__(self, *args, **kwargs):
def queue_response(self, status_code, headers={}, content=None):
from .testing import MockRequestsResponse
- self.responses.insert(
- 0, MockRequestsResponse(status_code, headers, content)
- )
+
+ self.responses.insert(0, MockRequestsResponse(status_code, headers, content))
def _get(self, url, *args, **kwargs):
self.requests.append((url, args, kwargs))
response = self.responses.pop()
return HTTP._process_response(
- url, response, kwargs.get('allowed_response_codes'),
- kwargs.get('disallowed_response_codes')
+ url,
+ response,
+ kwargs.get("allowed_response_codes"),
+ kwargs.get("disallowed_response_codes"),
)
-class MockMetadataWranglerOPDSLookup(MockSimplifiedOPDSLookup, MetadataWranglerOPDSLookup):
-
+class MockMetadataWranglerOPDSLookup(
+ MockSimplifiedOPDSLookup, MetadataWranglerOPDSLookup
+):
def _post(self, url, *args, **kwargs):
self.requests.append((url, args, kwargs))
response = self.responses.pop()
return HTTP._process_response(
- url, response, kwargs.get('allowed_response_codes'),
- kwargs.get('disallowed_response_codes')
+ url,
+ response,
+ kwargs.get("allowed_response_codes"),
+ kwargs.get("disallowed_response_codes"),
)
class OPDSXMLParser(XMLParser):
- NAMESPACES = { "simplified": "http://librarysimplified.org/terms/",
- "app" : "http://www.w3.org/2007/app",
- "dcterms" : "http://purl.org/dc/terms/",
- "dc" : "http://purl.org/dc/elements/1.1/",
- "opds": "http://opds-spec.org/2010/catalog",
- "schema" : "http://schema.org/",
- "atom" : "http://www.w3.org/2005/Atom",
- "drm": "http://librarysimplified.org/terms/drm",
+ NAMESPACES = {
+ "simplified": "http://librarysimplified.org/terms/",
+ "app": "http://www.w3.org/2007/app",
+ "dcterms": "http://purl.org/dc/terms/",
+ "dc": "http://purl.org/dc/elements/1.1/",
+ "opds": "http://opds-spec.org/2010/catalog",
+ "schema": "http://schema.org/",
+ "atom": "http://www.w3.org/2005/Atom",
+ "drm": "http://librarysimplified.org/terms/drm",
}
class OPDSImporter(object):
- """ Imports editions and license pools from an OPDS feed.
+ """Imports editions and license pools from an OPDS feed.
Creates Edition, LicensePool and Work rows in the database, if those
don't already exist.
@@ -475,12 +475,13 @@ class OPDSImporter(object):
"""
COULD_NOT_CREATE_LICENSE_POOL = (
- "No existing license pool for this identifier and no way of creating one.")
+ "No existing license pool for this identifier and no way of creating one."
+ )
NAME = ExternalIntegration.OPDS_IMPORT
DESCRIPTION = _("Import books from a publicly-accessible OPDS feed.")
- NO_DEFAULT_AUDIENCE = ''
+ NO_DEFAULT_AUDIENCE = ""
# These settings are used by all OPDS-derived import methods.
BASE_SETTINGS = [
@@ -488,7 +489,7 @@ class OPDSImporter(object):
"key": Collection.EXTERNAL_ACCOUNT_ID_KEY,
"label": _("URL"),
"required": True,
- "format": "url"
+ "format": "url",
},
{
"key": Collection.DATA_SOURCE_NAME_SETTING,
@@ -504,21 +505,14 @@ class OPDSImporter(object):
),
"type": "select",
"format": "narrow",
- "options": [
- {
- "key": NO_DEFAULT_AUDIENCE,
- "label": _("No default audience")
- }
- ] + [
- {
- "key": audience,
- "label": audience
- }
+ "options": [{"key": NO_DEFAULT_AUDIENCE, "label": _("No default audience")}]
+ + [
+ {"key": audience, "label": audience}
for audience in sorted(Classifier.AUDIENCES)
],
"default": NO_DEFAULT_AUDIENCE,
"required": False,
- "readOnly": True
+ "readOnly": True,
},
]
@@ -528,24 +522,32 @@ class OPDSImporter(object):
{
"key": ExternalIntegration.USERNAME,
"label": _("Username"),
- "description": _("If HTTP Basic authentication is required to access the OPDS feed (it usually isn't), enter the username here."),
+ "description": _(
+ "If HTTP Basic authentication is required to access the OPDS feed (it usually isn't), enter the username here."
+ ),
},
{
"key": ExternalIntegration.PASSWORD,
"label": _("Password"),
- "description": _("If HTTP Basic authentication is required to access the OPDS feed (it usually isn't), enter the password here."),
+ "description": _(
+ "If HTTP Basic authentication is required to access the OPDS feed (it usually isn't), enter the password here."
+ ),
},
{
"key": ExternalIntegration.CUSTOM_ACCEPT_HEADER,
"label": _("Custom accept header"),
"required": False,
- "description": _("Some servers expect an accept header to decide which file to send. You can use */* if the server doesn't expect anything."),
- "default": ','.join([
- OPDSFeed.ACQUISITION_FEED_TYPE,
- "application/atom+xml;q=0.9",
- "application/xml;q=0.8",
- "*/*;q=0.1",
- ])
+ "description": _(
+ "Some servers expect an accept header to decide which file to send. You can use */* if the server doesn't expect anything."
+ ),
+ "default": ",".join(
+ [
+ OPDSFeed.ACQUISITION_FEED_TYPE,
+ "application/atom+xml;q=0.9",
+ "application/xml;q=0.8",
+ "*/*;q=0.1",
+ ]
+ ),
},
{
"key": ExternalIntegration.PRIMARY_IDENTIFIER_SOURCE,
@@ -554,14 +556,11 @@ class OPDSImporter(object):
"description": _("Which book identifier to use as ID."),
"type": "select",
"options": [
- {
- "key": "",
- "label": _("(Default) Use ")
- },
- {
- "key": ExternalIntegration.DCTERMS_IDENTIFIER,
- "label": _("Use first, if not exist use ")
- },
+ {"key": "", "label": _("(Default) Use ")},
+ {
+ "key": ExternalIntegration.DCTERMS_IDENTIFIER,
+ "label": _("Use first, if not exist use "),
+ },
],
},
]
@@ -576,10 +575,17 @@ class OPDSImporter(object):
# when they show up in tags.
SUCCESS_STATUS_CODES = None
- def __init__(self, _db, collection, data_source_name=None,
- identifier_mapping=None, http_get=None,
- metadata_client=None, content_modifier=None,
- map_from_collection=None, mirrors=None,
+ def __init__(
+ self,
+ _db,
+ collection,
+ data_source_name=None,
+ identifier_mapping=None,
+ http_get=None,
+ metadata_client=None,
+ content_modifier=None,
+ map_from_collection=None,
+ mirrors=None,
):
""":param collection: LicensePools created by this OPDS import
will be associated with the given Collection. If this is None,
@@ -629,16 +635,27 @@ def __init__(self, _db, collection, data_source_name=None,
self.data_source_name = data_source_name
self.identifier_mapping = identifier_mapping
try:
- self.metadata_client = metadata_client or MetadataWranglerOPDSLookup.from_config(_db, collection=collection)
+ self.metadata_client = (
+ metadata_client
+ or MetadataWranglerOPDSLookup.from_config(_db, collection=collection)
+ )
except CannotLoadConfiguration:
# The Metadata Wrangler isn't configured, but we can import without it.
- self.log.warn("Metadata Wrangler integration couldn't be loaded, importing without it.")
+ self.log.warn(
+ "Metadata Wrangler integration couldn't be loaded, importing without it."
+ )
self.metadata_client = None
# Check to see if a mirror for each purpose was passed in.
# If not, then attempt to create one.
- covers_mirror = mirrors.get(ExternalIntegrationLink.COVERS, None) if mirrors else None
- books_mirror = mirrors.get(ExternalIntegrationLink.OPEN_ACCESS_BOOKS, None) if mirrors else None
+ covers_mirror = (
+ mirrors.get(ExternalIntegrationLink.COVERS, None) if mirrors else None
+ )
+ books_mirror = (
+ mirrors.get(ExternalIntegrationLink.OPEN_ACCESS_BOOKS, None)
+ if mirrors
+ else None
+ )
self.primary_identifier_source = None
if collection:
if not covers_mirror:
@@ -680,10 +697,12 @@ def data_source(self):
"""Look up or create a DataSource object representing the
source of this OPDS feed.
"""
- offers_licenses = (self.collection is not None)
+ offers_licenses = self.collection is not None
return DataSource.lookup(
- self._db, self.data_source_name, autocreate=True,
- offers_licenses = offers_licenses
+ self._db,
+ self.data_source_name,
+ autocreate=True,
+ offers_licenses=offers_licenses,
)
def assert_importable_content(self, feed, feed_url, max_get_attempts=5):
@@ -709,15 +728,16 @@ def assert_importable_content(self, feed, feed_url, max_get_attempts=5):
return success
get_attempts += 1
if get_attempts >= max_get_attempts:
- error = "Was unable to GET supposedly open-access content such as %s (tried %s times)" % (
- url, get_attempts
+ error = (
+ "Was unable to GET supposedly open-access content such as %s (tried %s times)"
+ % (url, get_attempts)
)
explanation = "This might be an OPDS For Distributors feed, or it might require different authentication credentials."
raise IntegrationException(error, explanation)
raise IntegrationException(
"No open-access links were found in the OPDS feed.",
- "This might be an OPDS for Distributors feed."
+ "This might be an OPDS for Distributors feed.",
)
@classmethod
@@ -750,7 +770,9 @@ def _is_open_access_link(self, url, type):
return "Found a book-like thing at %s" % url
self.log.error(
"Supposedly open-access link %s didn't give us a book. Status=%s, body length=%s",
- url, status, len(body)
+ url,
+ status,
+ len(body),
)
return False
@@ -797,7 +819,12 @@ def import_from_feed(self, feed, feed_url=None):
self.log.error("Error importing an OPDS item", exc_info=e)
identifier, ignore = Identifier.parse_urn(self._db, key)
data_source = self.data_source
- failure = CoverageFailure(identifier, traceback.format_exc(), data_source=data_source, transient=False)
+ failure = CoverageFailure(
+ identifier,
+ traceback.format_exc(),
+ data_source=data_source,
+ transient=False,
+ )
failures[key] = failure
# clean up any edition might have created
if key in imported_editions:
@@ -814,17 +841,25 @@ def import_from_feed(self, feed, feed_url=None):
except Exception as e:
identifier, ignore = Identifier.parse_urn(self._db, key)
data_source = self.data_source
- failure = CoverageFailure(identifier, traceback.format_exc(), data_source=data_source, transient=False)
+ failure = CoverageFailure(
+ identifier,
+ traceback.format_exc(),
+ data_source=data_source,
+ transient=False,
+ )
failures[key] = failure
- return list(imported_editions.values()), list(pools.values()), list(works.values()), failures
+ return (
+ list(imported_editions.values()),
+ list(pools.values()),
+ list(works.values()),
+ failures,
+ )
- def import_edition_from_metadata(
- self, metadata
- ):
- """ For the passed-in Metadata object, see if can find or create an Edition
- in the database. Also create a LicensePool if the Metadata has
- CirculationData in it.
+ def import_edition_from_metadata(self, metadata):
+ """For the passed-in Metadata object, see if can find or create an Edition
+ in the database. Also create a LicensePool if the Metadata has
+ CirculationData in it.
"""
# Locate or create an Edition for this book.
edition, is_new_edition = metadata.edition(self._db)
@@ -841,8 +876,10 @@ def import_edition_from_metadata(
http_get=self.http_get,
)
metadata.apply(
- edition=edition, collection=self.collection,
- metadata_client=self.metadata_client, replace=policy
+ edition=edition,
+ collection=self.collection,
+ metadata_client=self.metadata_client,
+ replace=policy,
)
return edition
@@ -862,8 +899,10 @@ def update_work_for_edition(self, edition):
# different data source or a different collection, that's fine
# too.
pool = get_one(
- self._db, LicensePool, identifier=edition.primary_identifier,
- on_multiple='interchangeable'
+ self._db,
+ LicensePool,
+ identifier=edition.primary_identifier,
+ on_multiple="interchangeable",
)
if pool:
@@ -889,12 +928,11 @@ def extract_next_links(self, feed):
parsed = feedparser.parse(feed)
else:
parsed = feed
- feed = parsed['feed']
+ feed = parsed["feed"]
next_links = []
- if feed and 'links' in feed:
+ if feed and "links" in feed:
next_links = [
- link['href'] for link in feed['links']
- if link['rel'] == 'next'
+ link["href"] for link in feed["links"] if link["rel"] == "next"
]
return next_links
@@ -905,7 +943,7 @@ def extract_last_update_dates(self, feed):
parsed_feed = feed
dates = [
self.last_update_date_for_feedparser_entry(entry)
- for entry in parsed_feed['entries']
+ for entry in parsed_feed["entries"]
]
return [x for x in dates if x and x[1]]
@@ -928,14 +966,16 @@ def build_identifier_mapping(self, external_urns):
external_identifiers = list(identifiers_by_urn.values())
internal_identifier = aliased(Identifier)
- qu = self._db.query(Identifier, internal_identifier)\
- .join(Identifier.inbound_equivalencies)\
- .join(internal_identifier, Equivalency.input)\
- .join(internal_identifier.licensed_through)\
+ qu = (
+ self._db.query(Identifier, internal_identifier)
+ .join(Identifier.inbound_equivalencies)
+ .join(internal_identifier, Equivalency.input)
+ .join(internal_identifier.licensed_through)
.filter(
Identifier.id.in_([x.id for x in external_identifiers]),
- LicensePool.collection==self.collection
+ LicensePool.collection == self.collection,
)
+ )
for external_identifier, internal_identifier in qu:
mapping[external_identifier] = internal_identifier
@@ -947,7 +987,9 @@ def extract_feed_data(self, feed, feed_url=None):
with associated messages and next_links.
"""
data_source = self.data_source
- fp_metadata, fp_failures = self.extract_data_from_feedparser(feed=feed, data_source=data_source)
+ fp_metadata, fp_failures = self.extract_data_from_feedparser(
+ feed=feed, data_source=data_source
+ )
# gets: medium, measurements, links, contributors, etc.
xml_data_meta, xml_failures = self.extract_metadata_from_elementtree(
feed, data_source=data_source, feed_url=feed_url, do_get=self.http_get
@@ -955,7 +997,9 @@ def extract_feed_data(self, feed, feed_url=None):
if self.map_from_collection:
# Build the identifier_mapping based on the Collection.
- self.build_identifier_mapping(list(fp_metadata.keys()) + list(fp_failures.keys()))
+ self.build_identifier_mapping(
+ list(fp_metadata.keys()) + list(fp_failures.keys())
+ )
# translate the id in failures to identifier.urn
identified_failures = {}
@@ -975,25 +1019,28 @@ def extract_feed_data(self, feed, feed_url=None):
# first value from the dcterms identifier, that came from the metadata as an
# IdentifierData object and it must be validated as a foreign_id before be used
# as and external_identifier.
- dcterms_ids = xml_data_dict.get('dcterms_identifiers', [])
+ dcterms_ids = xml_data_dict.get("dcterms_identifiers", [])
if len(dcterms_ids) > 0:
external_identifier, ignore = Identifier.for_foreign_id(
- self._db, dcterms_ids[0].type, dcterms_ids[0].identifier
+ self._db, dcterms_ids[0].type, dcterms_ids[0].identifier
)
# the external identifier will be add later, so it must be removed at this point
new_identifiers = dcterms_ids[1:]
# Id must be in the identifiers with lower weight.
id_type, id_identifier = Identifier.type_and_identifier_for_urn(id)
id_weight = 1
- new_identifiers.append(IdentifierData(id_type, id_identifier, id_weight))
- xml_data_dict['identifiers'] = new_identifiers
+ new_identifiers.append(
+ IdentifierData(id_type, id_identifier, id_weight)
+ )
+ xml_data_dict["identifiers"] = new_identifiers
if external_identifier is None:
external_identifier, ignore = Identifier.parse_urn(self._db, id)
if self.identifier_mapping:
internal_identifier = self.identifier_mapping.get(
- external_identifier, external_identifier)
+ external_identifier, external_identifier
+ )
else:
internal_identifier = external_identifier
@@ -1002,16 +1049,15 @@ def extract_feed_data(self, feed, feed_url=None):
continue
identifier_obj = IdentifierData(
- type=internal_identifier.type,
- identifier=internal_identifier.identifier
+ type=internal_identifier.type, identifier=internal_identifier.identifier
)
# form the Metadata object
combined_meta = self.combine(m_data_dict, xml_data_dict)
- if combined_meta.get('data_source') is None:
- combined_meta['data_source'] = self.data_source_name
+ if combined_meta.get("data_source") is None:
+ combined_meta["data_source"] = self.data_source_name
- combined_meta['primary_identifier'] = identifier_obj
+ combined_meta["primary_identifier"] = identifier_obj
metadata[internal_identifier.urn] = Metadata(**combined_meta)
@@ -1020,8 +1066,8 @@ def extract_feed_data(self, feed, feed_url=None):
# would result.
c_data_dict = None
if self.collection:
- c_circulation_dict = m_data_dict.get('circulation')
- xml_circulation_dict = xml_data_dict.get('circulation', {})
+ c_circulation_dict = m_data_dict.get("circulation")
+ xml_circulation_dict = xml_data_dict.get("circulation", {})
c_data_dict = self.combine(c_circulation_dict, xml_circulation_dict)
# Unless there's something useful in c_data_dict, we're
@@ -1032,13 +1078,13 @@ def extract_feed_data(self, feed, feed_url=None):
if c_data_dict:
circ_links_dict = {}
# extract just the links to pass to CirculationData constructor
- if 'links' in xml_data_dict:
- circ_links_dict['links'] = xml_data_dict['links']
+ if "links" in xml_data_dict:
+ circ_links_dict["links"] = xml_data_dict["links"]
combined_circ = self.combine(c_data_dict, circ_links_dict)
- if combined_circ.get('data_source') is None:
- combined_circ['data_source'] = self.data_source_name
+ if combined_circ.get("data_source") is None:
+ combined_circ["data_source"] = self.data_source_name
- combined_circ['primary_identifier'] = identifier_obj
+ combined_circ["primary_identifier"] = identifier_obj
circulation = CirculationData(**combined_circ)
self._add_format_data(circulation)
@@ -1072,7 +1118,8 @@ def handle_failure(self, urn, failure):
# The identifier found in the OPDS feed is different from
# the identifier we want to export.
internal_identifier = self.identifier_mapping.get(
- external_identifier, external_identifier)
+ external_identifier, external_identifier
+ )
else:
internal_identifier = external_identifier
if isinstance(failure, Identifier):
@@ -1135,8 +1182,10 @@ def extract_data_from_feedparser(self, feed, data_source):
feedparser_parsed = feedparser.parse(feed)
values = {}
failures = {}
- for entry in feedparser_parsed['entries']:
- identifier, detail, failure = self.data_detail_for_feedparser_entry(entry=entry, data_source=data_source)
+ for entry in feedparser_parsed["entries"]:
+ identifier, detail, failure = self.data_detail_for_feedparser_entry(
+ entry=entry, data_source=data_source
+ )
if identifier:
if failure:
@@ -1147,11 +1196,16 @@ def extract_data_from_feedparser(self, feed, data_source):
else:
# That's bad. Can't make an item-specific error message, but write to
# log that something very wrong happened.
- logging.error("Tried to parse an element without a valid identifier. feed=%s" % feed)
+ logging.error(
+ "Tried to parse an element without a valid identifier. feed=%s"
+ % feed
+ )
return values, failures
@classmethod
- def extract_metadata_from_elementtree(cls, feed, data_source, feed_url=None, do_get=None):
+ def extract_metadata_from_elementtree(
+ cls, feed, data_source, feed_url=None, do_get=None
+ ):
"""Parse the OPDS as XML and extract all author and subject
information, as well as ratings and medium.
@@ -1178,16 +1232,14 @@ def extract_metadata_from_elementtree(cls, feed, data_source, feed_url=None, do_
# for this, so if anyone actually uses that we'll get around
# to checking it.
if not feed_url:
- links = [child.attrib for child in root.getroot() if 'link' in child.tag]
- self_links = [link['href'] for link in links if link.get('rel') == 'self']
+ links = [child.attrib for child in root.getroot() if "link" in child.tag]
+ self_links = [link["href"] for link in links if link.get("rel") == "self"]
if self_links:
feed_url = self_links[0]
# First, turn Simplified tags into CoverageFailure
# objects.
- for failure in cls.coveragefailures_from_messages(
- data_source, parser, root
- ):
+ for failure in cls.coveragefailures_from_messages(data_source, parser, root):
if isinstance(failure, Identifier):
# The Simplified tag does not actually
# represent a failure -- it was turned into an
@@ -1198,7 +1250,7 @@ def extract_metadata_from_elementtree(cls, feed, data_source, feed_url=None, do_
failures[urn] = failure
# Then turn Atom tags into Metadata objects.
- for entry in parser._xpath(root, '/atom:feed/atom:entry'):
+ for entry in parser._xpath(root, "/atom:feed/atom:entry"):
identifier, detail, failure = cls.detail_for_elementtree_entry(
parser, entry, data_source, feed_url, do_get=do_get
)
@@ -1217,8 +1269,8 @@ def _datetime(cls, entry, key):
return datetime_utc(*value[:6])
def last_update_date_for_feedparser_entry(self, entry):
- identifier = entry.get('id')
- updated = self._datetime(entry, 'updated_parsed')
+ identifier = entry.get("id")
+ updated = self._datetime(entry, "updated_parsed")
return (identifier, updated)
@classmethod
@@ -1228,7 +1280,7 @@ def data_detail_for_feedparser_entry(cls, entry, data_source):
:return: A 3-tuple (identifier, kwargs for Metadata constructor, failure)
"""
- identifier = entry.get('id')
+ identifier = entry.get("id")
if not identifier:
return None, None, None
@@ -1241,8 +1293,7 @@ def data_detail_for_feedparser_entry(cls, entry, data_source):
_db = Session.object_session(data_source)
identifier_obj, ignore = Identifier.parse_urn(_db, identifier)
failure = CoverageFailure(
- identifier_obj, traceback.format_exc(), data_source,
- transient=True
+ identifier_obj, traceback.format_exc(), data_source, transient=True
)
return identifier, None, failure
@@ -1252,10 +1303,10 @@ def _data_detail_for_feedparser_entry(cls, entry, metadata_data_source):
entry. This method can be overridden in tests to check that callers handle things
properly when it throws an exception.
"""
- title = entry.get('title', None)
+ title = entry.get("title", None)
if title == OPDSFeed.NO_TITLE:
title = None
- subtitle = entry.get('schema_alternativeheadline', None)
+ subtitle = entry.get("schema_alternativeheadline", None)
# Generally speaking, a data source will provide either
# metadata (e.g. the Simplified metadata wrangler) or both
@@ -1274,10 +1325,10 @@ def _data_detail_for_feedparser_entry(cls, entry, metadata_data_source):
# tag to keep track of which data
# source provides the circulation data.
circulation_data_source = metadata_data_source
- circulation_data_source_tag = entry.get('bibframe_distribution')
+ circulation_data_source_tag = entry.get("bibframe_distribution")
if circulation_data_source_tag:
circulation_data_source_name = circulation_data_source_tag.get(
- 'bibframe:providername'
+ "bibframe:providername"
)
if circulation_data_source_name:
_db = Session.object_session(metadata_data_source)
@@ -1285,47 +1336,46 @@ def _data_detail_for_feedparser_entry(cls, entry, metadata_data_source):
# that's what the is there
# to say.
circulation_data_source = DataSource.lookup(
- _db, circulation_data_source_name, autocreate=True,
- offers_licenses=True
+ _db,
+ circulation_data_source_name,
+ autocreate=True,
+ offers_licenses=True,
)
if not circulation_data_source:
raise ValueError(
- "Unrecognized circulation data source: %s" % (
- circulation_data_source_name
- )
+ "Unrecognized circulation data source: %s"
+ % (circulation_data_source_name)
)
- last_opds_update = cls._datetime(entry, 'updated_parsed')
+ last_opds_update = cls._datetime(entry, "updated_parsed")
- publisher = entry.get('publisher', None)
+ publisher = entry.get("publisher", None)
if not publisher:
- publisher = entry.get('dcterms_publisher', None)
+ publisher = entry.get("dcterms_publisher", None)
- language = entry.get('language', None)
+ language = entry.get("language", None)
if not language:
- language = entry.get('dcterms_language', None)
+ language = entry.get("dcterms_language", None)
links = []
def summary_to_linkdata(detail):
if not detail:
return None
- if not 'value' in detail or not detail['value']:
+ if not "value" in detail or not detail["value"]:
return None
- content = detail['value']
- media_type = detail.get('type', 'text/plain')
+ content = detail["value"]
+ media_type = detail.get("type", "text/plain")
return cls.make_link_data(
- rel=Hyperlink.DESCRIPTION,
- media_type=media_type,
- content=content
+ rel=Hyperlink.DESCRIPTION, media_type=media_type, content=content
)
- summary_detail = entry.get('summary_detail', None)
+ summary_detail = entry.get("summary_detail", None)
link = summary_to_linkdata(summary_detail)
if link:
links.append(link)
- for content_detail in entry.get('content', []):
+ for content_detail in entry.get("content", []):
link = summary_to_linkdata(content_detail)
if link:
links.append(link)
@@ -1350,7 +1400,7 @@ def summary_to_linkdata(detail):
links=list(links),
default_rights_uri=rights_uri,
)
- kwargs_meta['circulation'] = kwargs_circ
+ kwargs_meta["circulation"] = kwargs_circ
return kwargs_meta
@classmethod
@@ -1366,7 +1416,7 @@ def rights_uri_from_feedparser_entry(cls, entry):
:return: A rights URI.
"""
- rights = entry.get('rights', "")
+ rights = entry.get("rights", "")
return cls.rights_uri(rights)
@classmethod
@@ -1375,7 +1425,7 @@ def rights_uri_from_entry_tag(cls, entry):
:return: A rights URI.
"""
- rights = cls.PARSER_CLASS._xpath1(entry, 'rights')
+ rights = cls.PARSER_CLASS._xpath1(entry, "rights")
if rights:
return cls.rights_uri(rights)
@@ -1384,19 +1434,19 @@ def extract_messages(cls, parser, feed_tag):
"""Extract tags from an OPDS feed and convert
them into OPDSMessage objects.
"""
- path = '/atom:feed/simplified:message'
+ path = "/atom:feed/simplified:message"
for message_tag in parser._xpath(feed_tag, path):
# First thing to do is determine which Identifier we're
# talking about.
- identifier_tag = parser._xpath1(message_tag, 'atom:id')
+ identifier_tag = parser._xpath1(message_tag, "atom:id")
if identifier_tag is None:
urn = None
else:
urn = identifier_tag.text
# What status code is associated with the message?
- status_code_tag = parser._xpath1(message_tag, 'simplified:status_code')
+ status_code_tag = parser._xpath1(message_tag, "simplified:status_code")
if status_code_tag is None:
status_code = None
else:
@@ -1406,9 +1456,9 @@ def extract_messages(cls, parser, feed_tag):
status_code = None
# What is the human-readable message?
- description_tag = parser._xpath1(message_tag, 'schema:description')
+ description_tag = parser._xpath1(message_tag, "schema:description")
if description_tag is None:
- description = ''
+ description = ""
else:
description = description_tag.text
@@ -1445,8 +1495,7 @@ def coveragefailure_from_message(cls, data_source, message):
# Identifier so we can't turn it into a CoverageFailure.
return None
- if (cls.SUCCESS_STATUS_CODES
- and message.status_code in cls.SUCCESS_STATUS_CODES):
+ if cls.SUCCESS_STATUS_CODES and message.status_code in cls.SUCCESS_STATUS_CODES:
# This message is telling us that nothing went wrong. It
# should be treated as a success.
return identifier
@@ -1465,17 +1514,15 @@ def coveragefailure_from_message(cls, data_source, message):
elif description:
exception = description
else:
- exception = 'No detail provided.'
+ exception = "No detail provided."
# All these CoverageFailures are transient because ATM we can
# only assume that the server will eventually have the data.
- return CoverageFailure(
- identifier, exception, data_source, transient=True
- )
+ return CoverageFailure(identifier, exception, data_source, transient=True)
@classmethod
def detail_for_elementtree_entry(
- cls, parser, entry_tag, data_source, feed_url=None, do_get=None
+ cls, parser, entry_tag, data_source, feed_url=None, do_get=None
):
"""Turn an tag into a dictionary of metadata that can be
@@ -1483,7 +1530,7 @@ def detail_for_elementtree_entry(
:return: A 2-tuple (identifier, kwargs)
"""
- identifier = parser._xpath1(entry_tag, 'atom:id')
+ identifier = parser._xpath1(entry_tag, "atom:id")
if identifier is None or not identifier.text:
# This tag doesn't identify a book so we
# can't derive any information from it.
@@ -1500,13 +1547,14 @@ def detail_for_elementtree_entry(
_db = Session.object_session(data_source)
identifier_obj, ignore = Identifier.parse_urn(_db, identifier)
failure = CoverageFailure(
- identifier_obj, traceback.format_exc(), data_source,
- transient=True
+ identifier_obj, traceback.format_exc(), data_source, transient=True
)
return identifier, None, failure
@classmethod
- def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get=None):
+ def _detail_for_elementtree_entry(
+ cls, parser, entry_tag, feed_url=None, do_get=None
+ ):
"""Helper method that extracts metadata and circulation data from an elementtree
entry. This method can be overridden in tests to check that callers handle things
properly when it throws an exception.
@@ -1520,43 +1568,45 @@ def _detail_for_elementtree_entry(cls, parser, entry_tag, feed_url=None, do_get=
v = cls.extract_identifier(id_tag)
if v:
alternate_identifiers.append(v)
- data['dcterms_identifiers'] = alternate_identifiers
+ data["dcterms_identifiers"] = alternate_identifiers
# If exist another identifer, add here
- data['identifiers'] = data['dcterms_identifiers']
+ data["identifiers"] = data["dcterms_identifiers"]
- data['contributors'] = []
- for author_tag in parser._xpath(entry_tag, 'atom:author'):
+ data["contributors"] = []
+ for author_tag in parser._xpath(entry_tag, "atom:author"):
contributor = cls.extract_contributor(parser, author_tag)
if contributor is not None:
- data['contributors'].append(contributor)
+ data["contributors"].append(contributor)
- data['subjects'] = [
+ data["subjects"] = [
cls.extract_subject(parser, category_tag)
- for category_tag in parser._xpath(entry_tag, 'atom:category')
+ for category_tag in parser._xpath(entry_tag, "atom:category")
]
ratings = []
- for rating_tag in parser._xpath(entry_tag, 'schema:Rating'):
+ for rating_tag in parser._xpath(entry_tag, "schema:Rating"):
v = cls.extract_measurement(rating_tag)
if v:
ratings.append(v)
- data['measurements'] = ratings
+ data["measurements"] = ratings
rights_uri = cls.rights_uri_from_entry_tag(entry_tag)
- data['links'] = cls.consolidate_links([
- cls.extract_link(link_tag, feed_url, rights_uri)
- for link_tag in parser._xpath(entry_tag, 'atom:link')
- ])
+ data["links"] = cls.consolidate_links(
+ [
+ cls.extract_link(link_tag, feed_url, rights_uri)
+ for link_tag in parser._xpath(entry_tag, "atom:link")
+ ]
+ )
- derived_medium = cls.get_medium_from_links(data['links'])
- data['medium'] = cls.extract_medium(entry_tag, derived_medium)
+ derived_medium = cls.get_medium_from_links(data["links"])
+ data["medium"] = cls.extract_medium(entry_tag, derived_medium)
- series_tag = parser._xpath(entry_tag, 'schema:Series')
+ series_tag = parser._xpath(entry_tag, "schema:Series")
if series_tag:
- data['series'], data['series_position'] = cls.extract_series(series_tag[0])
+ data["series"], data["series_position"] = cls.extract_series(series_tag[0])
- issued_tag = parser._xpath(entry_tag, 'dcterms:issued')
+ issued_tag = parser._xpath(entry_tag, "dcterms:issued")
if issued_tag:
date_string = issued_tag[0].text
# By default, the date for strings that only have a year will
@@ -1575,9 +1625,10 @@ def get_medium_from_links(cls, links):
"""Get medium if derivable from information in an acquisition link."""
derived = None
for link in links:
- if (not link.rel
+ if (
+ not link.rel
or not link.media_type
- or not link.rel.startswith('http://opds-spec.org/acquisition/')
+ or not link.rel.startswith("http://opds-spec.org/acquisition/")
):
continue
derived = Edition.medium_from_media_type(link.media_type)
@@ -1589,7 +1640,9 @@ def get_medium_from_links(cls, links):
def extract_identifier(cls, identifier_tag):
"""Turn a tag into an IdentifierData object."""
try:
- type, identifier = Identifier.type_and_identifier_for_urn(identifier_tag.text.lower())
+ type, identifier = Identifier.type_and_identifier_for_urn(
+ identifier_tag.text.lower()
+ )
return IdentifierData(type, identifier)
except ValueError:
return None
@@ -1606,13 +1659,11 @@ def extract_medium(cls, entry_tag, default=Edition.BOOK_MEDIUM):
return default
medium = None
- additional_type = entry_tag.get('{http://schema.org/}additionalType')
+ additional_type = entry_tag.get("{http://schema.org/}additionalType")
if additional_type:
- medium = Edition.additional_type_to_medium.get(
- additional_type, None
- )
+ medium = Edition.additional_type_to_medium.get(additional_type, None)
if not medium:
- format_tag = entry_tag.find('{http://purl.org/dc/terms/}format')
+ format_tag = entry_tag.find("{http://purl.org/dc/terms/}format")
if format_tag is not None:
media_type = format_tag.text
medium = Edition.medium_from_media_type(media_type)
@@ -1622,8 +1673,8 @@ def extract_medium(cls, entry_tag, default=Edition.BOOK_MEDIUM):
def extract_contributor(cls, parser, author_tag):
"""Turn an tag into a ContributorData object."""
subtag = parser.text_of_optional_subtag
- sort_name = subtag(author_tag, 'simplified:sort_name')
- display_name = subtag(author_tag, 'atom:name')
+ sort_name = subtag(author_tag, "simplified:sort_name")
+ display_name = subtag(author_tag, "atom:name")
family_name = subtag(author_tag, "simplified:family_name")
wikipedia_name = subtag(author_tag, "simplified:wikipedia_name")
# TODO: we need a way of conveying roles. I believe Bibframe
@@ -1636,13 +1687,16 @@ def extract_contributor(cls, parser, author_tag):
viaf = None
if sort_name or display_name or viaf:
return ContributorData(
- sort_name=sort_name, display_name=display_name,
+ sort_name=sort_name,
+ display_name=display_name,
family_name=family_name,
wikipedia_name=wikipedia_name,
- roles=None
+ roles=None,
)
- logging.info("Refusing to create ContributorData for contributor with no sort name, display name, or VIAF.")
+ logging.info(
+ "Refusing to create ContributorData for contributor with no sort name, display name, or VIAF."
+ )
return None
@classmethod
@@ -1652,7 +1706,7 @@ def extract_subject(cls, parser, category_tag):
# Retrieve the type of this subject - FAST, Dewey Decimal,
# etc.
- scheme = attr.get('scheme')
+ scheme = attr.get("scheme")
subject_type = Subject.by_uri.get(scheme)
if not subject_type:
# We can't represent this subject because we don't
@@ -1661,22 +1715,17 @@ def extract_subject(cls, parser, category_tag):
# Retrieve the term (e.g. "827") and human-readable name
# (e.g. "English Satire & Humor") for this subject.
- term = attr.get('term')
- name = attr.get('label')
+ term = attr.get("term")
+ name = attr.get("label")
default_weight = 1
- weight = attr.get('{http://schema.org/}ratingValue', default_weight)
+ weight = attr.get("{http://schema.org/}ratingValue", default_weight)
try:
weight = int(weight)
except ValueError as e:
weight = default_weight
- return SubjectData(
- type=subject_type,
- identifier=term,
- name=name,
- weight=weight
- )
+ return SubjectData(type=subject_type, identifier=term, name=name, weight=weight)
@classmethod
def extract_link(cls, link_tag, feed_url=None, entry_rights_uri=None):
@@ -1691,14 +1740,14 @@ def extract_link(cls, link_tag, feed_url=None, entry_rights_uri=None):
if made available on these terms.
"""
attr = link_tag.attrib
- rel = attr.get('rel')
- media_type = attr.get('type')
- href = attr.get('href')
+ rel = attr.get("rel")
+ media_type = attr.get("type")
+ href = attr.get("href")
if not href or not rel:
# The link exists but has no destination, or no specified
# relationship to the entry.
return None
- rights = attr.get('{%s}rights' % OPDSXMLParser.NAMESPACES["dcterms"])
+ rights = attr.get("{%s}rights" % OPDSXMLParser.NAMESPACES["dcterms"])
if rights:
# Rights associated with the link override rights
# associated with the entry.
@@ -1711,14 +1760,19 @@ def extract_link(cls, link_tag, feed_url=None, entry_rights_uri=None):
return cls.make_link_data(rel, href, media_type, rights_uri)
@classmethod
- def make_link_data(cls, rel, href=None, media_type=None, rights_uri=None,
- content=None):
+ def make_link_data(
+ cls, rel, href=None, media_type=None, rights_uri=None, content=None
+ ):
"""Hook method for creating a LinkData object.
Intended to be overridden in subclasses.
"""
- return LinkData(rel=rel, href=href, media_type=media_type,
- rights_uri=rights_uri, content=content
+ return LinkData(
+ rel=rel,
+ href=href,
+ media_type=media_type,
+ rights_uri=rights_uri,
+ content=content,
)
@classmethod
@@ -1752,23 +1806,26 @@ def consolidate_links(cls, links):
next_link_already_handled = False
continue
- if i == len(links)-1:
+ if i == len(links) - 1:
# This is the last link. Since there is no next link
# there's nothing to do here.
continue
# Peek at the next link.
- next_link = links[i+1]
-
+ next_link = links[i + 1]
- if (link.rel == Hyperlink.THUMBNAIL_IMAGE
- and next_link.rel == Hyperlink.IMAGE):
+ if (
+ link.rel == Hyperlink.THUMBNAIL_IMAGE
+ and next_link.rel == Hyperlink.IMAGE
+ ):
# This link is a thumbnail and the next link is
# (presumably) the corresponding image.
thumbnail_link = link
image_link = next_link
- elif (link.rel == Hyperlink.IMAGE
- and next_link.rel == Hyperlink.THUMBNAIL_IMAGE):
+ elif (
+ link.rel == Hyperlink.IMAGE
+ and next_link.rel == Hyperlink.THUMBNAIL_IMAGE
+ ):
thumbnail_link = next_link
image_link = link
else:
@@ -1784,10 +1841,10 @@ def consolidate_links(cls, links):
@classmethod
def extract_measurement(cls, rating_tag):
- type = rating_tag.get('{http://schema.org/}additionalType')
- value = rating_tag.get('{http://schema.org/}ratingValue')
+ type = rating_tag.get("{http://schema.org/}additionalType")
+ value = rating_tag.get("{http://schema.org/}ratingValue")
if not value:
- value = rating_tag.attrib.get('{http://schema.org}ratingValue')
+ value = rating_tag.attrib.get("{http://schema.org}ratingValue")
if not type:
type = Measurement.RATING
try:
@@ -1802,8 +1859,8 @@ def extract_measurement(cls, rating_tag):
@classmethod
def extract_series(cls, series_tag):
attr = series_tag.attrib
- series_name = attr.get('{http://schema.org/}name', None)
- series_position = attr.get('{http://schema.org/}position', None)
+ series_name = attr.get("{http://schema.org/}name", None)
+ series_position = attr.get("{http://schema.org/}position", None)
return series_name, series_position
@@ -1822,8 +1879,9 @@ class OPDSImportMonitor(CollectionMonitor, HasSelfTests):
# specialize OPDS import should override this.
PROTOCOL = ExternalIntegration.OPDS_IMPORT
- def __init__(self, _db, collection, import_class,
- force_reimport=False, **import_class_kwargs):
+ def __init__(
+ self, _db, collection, import_class, force_reimport=False, **import_class_kwargs
+ ):
if not collection:
raise ValueError(
"OPDSImportMonitor can only be run in the context of a Collection."
@@ -1831,9 +1889,8 @@ def __init__(self, _db, collection, import_class,
if collection.protocol != self.PROTOCOL:
raise ValueError(
- "Collection %s is configured for protocol %s, not %s." % (
- collection.name, collection.protocol, self.PROTOCOL
- )
+ "Collection %s is configured for protocol %s, not %s."
+ % (collection.name, collection.protocol, self.PROTOCOL)
)
data_source = self.data_source(collection)
@@ -1849,21 +1906,18 @@ def __init__(self, _db, collection, import_class,
self.password = collection.external_integration.password
self.custom_accept_header = collection.external_integration.custom_accept_header
- self.importer = import_class(
- _db, collection=collection,
- **import_class_kwargs
- )
+ self.importer = import_class(_db, collection=collection, **import_class_kwargs)
super(OPDSImportMonitor, self).__init__(_db, collection)
def external_integration(self, _db):
- return get_one(_db, ExternalIntegration,
- id=self.external_integration_id)
+ return get_one(_db, ExternalIntegration, id=self.external_integration_id)
def _run_self_tests(self, _db):
"""Retrieve the first page of the OPDS feed"""
first_page = self.run_test(
"Retrieve the first page of the OPDS feed (%s)" % self.feed_url,
- self.follow_one_link, self.feed_url
+ self.follow_one_link,
+ self.feed_url,
)
yield first_page
if not first_page.result:
@@ -1877,7 +1931,8 @@ def _run_self_tests(self, _db):
yield self.run_test(
"Checking for importable content",
self.importer.assert_importable_content,
- content, self.feed_url
+ content,
+ self.feed_url,
)
def _get(self, url, headers):
@@ -1886,28 +1941,31 @@ def _get(self, url, headers):
Long timeout, raise error on anything but 2xx or 3xx.
"""
headers = self._update_headers(headers)
- kwargs = dict(timeout=120, allowed_response_codes=['2xx', '3xx'])
+ kwargs = dict(timeout=120, allowed_response_codes=["2xx", "3xx"])
response = HTTP.get_with_timeout(url, headers=headers, **kwargs)
return response.status_code, response.headers, response.content
def _get_accept_header(self):
- return ','.join([
- OPDSFeed.ACQUISITION_FEED_TYPE,
- "application/atom+xml;q=0.9",
- "application/xml;q=0.8",
- "*/*;q=0.1",
- ])
+ return ",".join(
+ [
+ OPDSFeed.ACQUISITION_FEED_TYPE,
+ "application/atom+xml;q=0.9",
+ "application/xml;q=0.8",
+ "*/*;q=0.1",
+ ]
+ )
def _update_headers(self, headers):
headers = dict(headers) if headers else {}
- if self.username and self.password and not 'Authorization' in headers:
- headers['Authorization'] = "Basic %s" % base64.b64encode("%s:%s" % (self.username,
- self.password))
+ if self.username and self.password and not "Authorization" in headers:
+ headers["Authorization"] = "Basic %s" % base64.b64encode(
+ "%s:%s" % (self.username, self.password)
+ )
if self.custom_accept_header:
- headers['Accept'] = self.custom_accept_header
- elif not 'Accept' in headers:
- headers['Accept'] = self._get_accept_header()
+ headers["Accept"] = self.custom_accept_header
+ elif not "Accept" in headers:
+ headers["Accept"] = self._get_accept_header()
return headers
@@ -1959,9 +2017,7 @@ def feed_contains_new_data(self, feed):
# Maybe this is new, maybe not, but we can't associate
# the information with an Identifier, so we can't do
# anything about it.
- self.log.info(
- "Ignoring %s because unable to turn into an Identifier."
- )
+ self.log.info("Ignoring %s because unable to turn into an Identifier.")
continue
if self.identifier_needs_import(identifier, remote_updated):
@@ -1980,16 +2036,16 @@ def identifier_needs_import(self, identifier, last_updated_remote):
return False
record = CoverageRecord.lookup(
- identifier, self.importer.data_source,
- operation=CoverageRecord.IMPORT_OPERATION
+ identifier,
+ self.importer.data_source,
+ operation=CoverageRecord.IMPORT_OPERATION,
)
if not record:
# We have no record of importing this Identifier. Import
# it now.
self.log.info(
- "Counting %s as new because it has no CoverageRecord.",
- identifier
+ "Counting %s as new because it has no CoverageRecord.", identifier
)
return True
@@ -1999,7 +2055,8 @@ def identifier_needs_import(self, identifier, last_updated_remote):
if record.status == CoverageRecord.TRANSIENT_FAILURE:
self.log.info(
"Counting %s as new because previous attempt resulted in transient failure: %s",
- identifier, record.exception
+ identifier,
+ record.exception,
)
return True
@@ -2016,7 +2073,7 @@ def identifier_needs_import(self, identifier, last_updated_remote):
# has been updated. Import it again to be safe.
self.log.info(
"Counting %s as new because remote has no information about when it was updated.",
- identifier
+ identifier,
)
return True
@@ -2024,21 +2081,22 @@ def identifier_needs_import(self, identifier, last_updated_remote):
# This book has been updated.
self.log.info(
"Counting %s as new because its coverage date is %s and remote has %s.",
- identifier, record.timestamp, last_updated_remote
+ identifier,
+ record.timestamp,
+ last_updated_remote,
)
return True
def _verify_media_type(self, url, status_code, headers, feed):
# Make sure we got an OPDS feed, and not an error page that was
# sent with a 200 status code.
- media_type = headers.get('content-type')
+ media_type = headers.get("content-type")
if not media_type or not any(
x in media_type for x in (OPDSFeed.ATOM_LIKE_TYPES)
):
message = "Expected Atom feed, got %s" % media_type
raise BadResponseException(
- url, message=message, debug_message=feed,
- status_code=status_code
+ url, message=message, debug_message=feed, status_code=status_code
)
def follow_one_link(self, url, do_get=None):
@@ -2074,16 +2132,16 @@ def import_one_feed(self, feed):
# Because we are importing into a Collection, we will immediately
# mark a book as presentation-ready if possible.
imported_editions, pools, works, failures = self.importer.import_from_feed(
- feed,
- feed_url=self.opds_url(self.collection)
+ feed, feed_url=self.opds_url(self.collection)
)
# Create CoverageRecords for the successful imports.
for edition in imported_editions:
record = CoverageRecord.add_for(
- edition, self.importer.data_source,
+ edition,
+ self.importer.data_source,
CoverageRecord.IMPORT_OPERATION,
- status=CoverageRecord.SUCCESS
+ status=CoverageRecord.SUCCESS,
)
# Create CoverageRecords for the failures.
@@ -2141,7 +2199,8 @@ def run_once(self, progress_ignore):
self._db.commit()
achievements = "Items imported: %d. Failures: %d." % (
- total_imported, total_failures
+ total_imported,
+ total_failures,
)
return TimestampData(achievements=achievements)
diff --git a/opensearch.py b/opensearch.py
index f9bcfb595..658bb6f0c 100644
--- a/opensearch.py
+++ b/opensearch.py
@@ -20,23 +20,23 @@ def search_info(cls, lane):
description = "Search %s" % lane.search_target.display_name
else:
description = "Search"
- d['description'] = description
- d['tags'] = " ".join(tags)
+ d["description"] = description
+ d["tags"] = " ".join(tags)
return d
@classmethod
def url_template(self, base_url):
"""Turn a base URL into an OpenSearch URL template."""
- if '?' in base_url:
- query = '&'
+ if "?" in base_url:
+ query = "&"
else:
- query = '?'
+ query = "?"
return base_url + query + "q={searchTerms}"
@classmethod
def for_lane(cls, lane, base_url):
info = cls.search_info(lane)
- info['url_template'] = cls.url_template(base_url)
+ info["url_template"] = cls.url_template(base_url)
info = cls.escape_entities(info)
return cls.TEMPLATE % info
diff --git a/overdrive.py b/overdrive.py
index 4fc9f7c09..c70ff5779 100644
--- a/overdrive.py
+++ b/overdrive.py
@@ -1,25 +1,28 @@
import datetime
-import isbnlib
-import os
import json
import logging
-from urllib.parse import urlsplit, quote, urlunsplit
+import os
import sys
-from sqlalchemy.orm.exc import (
- NoResultFound,
-)
+from urllib.parse import quote, urlsplit, urlunsplit
+
+import isbnlib
+from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.session import Session
from .classifier import Classifier
-from .config import (
- temp_config,
- CannotLoadConfiguration,
- Configuration,
+from .config import CannotLoadConfiguration, Configuration, temp_config
+from .coverage import BibliographicCoverageProvider
+from .metadata_layer import (
+ CirculationData,
+ ContributorData,
+ FormatData,
+ IdentifierData,
+ LinkData,
+ MeasurementData,
+ Metadata,
+ SubjectData,
)
-
from .model import (
- get_one,
- get_one_or_create,
Classification,
Collection,
ConfigurationSetting,
@@ -36,33 +39,15 @@
MediaTypes,
Representation,
Subject,
+ get_one,
+ get_one_or_create,
)
-
-from .metadata_layer import (
- CirculationData,
- ContributorData,
- FormatData,
- IdentifierData,
- Metadata,
- MeasurementData,
- LinkData,
- SubjectData,
-)
-
-from .coverage import (
- BibliographicCoverageProvider,
-)
-
-from .testing import DatabaseTest
-
-from .util.http import (
- HTTP,
- BadResponseException,
-)
+from .testing import DatabaseTest, MockRequestsResponse
+from .util.datetime_helpers import strptime_utc, to_utc, utc_now
+from .util.http import HTTP, BadResponseException
from .util.string_helpers import base64
from .util.worker_pools import RLock
-from .util.datetime_helpers import strptime_utc, to_utc, utc_now
-from .testing import MockRequestsResponse
+
class OverdriveAPI(object):
@@ -77,14 +62,14 @@ class OverdriveAPI(object):
PRODUCTION_SERVERS = "production"
TESTING_SERVERS = "testing"
HOSTS = {
- PRODUCTION_SERVERS : dict(
+ PRODUCTION_SERVERS: dict(
host="https://api.overdrive.com",
patron_host="https://patron.api.overdrive.com",
),
- TESTING_SERVERS : dict(
+ TESTING_SERVERS: dict(
host="https://integration.api.overdrive.com",
patron_host="https://integration-patron.api.overdrive.com",
- )
+ ),
}
# Production and testing setups use the same URLs for Client
@@ -92,8 +77,8 @@ class OverdriveAPI(object):
# system as for other hostnames to give a consistent look to the
# templates.
for host in list(HOSTS.values()):
- host['oauth_patron_host'] = "https://oauth-patron.overdrive.com"
- host['oauth_host'] = "https://oauth.overdrive.com"
+ host["oauth_patron_host"] = "https://oauth-patron.overdrive.com"
+ host["oauth_host"] = "https://oauth.overdrive.com"
# Each of these endpoint URLs has a slot to plug in one of the
# appropriate servers. This will be filled in either by a call to
@@ -104,16 +89,24 @@ class OverdriveAPI(object):
PATRON_TOKEN_ENDPOINT = "%(oauth_patron_host)s/patrontoken"
LIBRARY_ENDPOINT = "%(host)s/v1/libraries/%(library_id)s"
- ADVANTAGE_LIBRARY_ENDPOINT = "%(host)s/v1/libraries/%(parent_library_id)s/advantageAccounts/%(library_id)s"
- ALL_PRODUCTS_ENDPOINT = "%(host)s/v1/collections/%(collection_token)s/products?sort=%(sort)s"
- METADATA_ENDPOINT = "%(host)s/v1/collections/%(collection_token)s/products/%(item_id)s/metadata"
+ ADVANTAGE_LIBRARY_ENDPOINT = (
+ "%(host)s/v1/libraries/%(parent_library_id)s/advantageAccounts/%(library_id)s"
+ )
+ ALL_PRODUCTS_ENDPOINT = (
+ "%(host)s/v1/collections/%(collection_token)s/products?sort=%(sort)s"
+ )
+ METADATA_ENDPOINT = (
+ "%(host)s/v1/collections/%(collection_token)s/products/%(item_id)s/metadata"
+ )
EVENTS_ENDPOINT = "%(host)s/v1/collections/%(collection_token)s/products?lastUpdateTime=%(lastupdatetime)s&sort=%(sort)s&limit=%(limit)s"
AVAILABILITY_ENDPOINT = "%(host)s/v2/collections/%(collection_token)s/products/%(product_id)s/availability"
PATRON_INFORMATION_ENDPOINT = "%(patron_host)s/v1/patrons/me"
CHECKOUTS_ENDPOINT = "%(patron_host)s/v1/patrons/me/checkouts"
CHECKOUT_ENDPOINT = "%(patron_host)s/v1/patrons/me/checkouts/%(overdrive_id)s"
- FORMATS_ENDPOINT = "%(patron_host)s/v1/patrons/me/checkouts/%(overdrive_id)s/formats"
+ FORMATS_ENDPOINT = (
+ "%(patron_host)s/v1/patrons/me/checkouts/%(overdrive_id)s/formats"
+ )
HOLDS_ENDPOINT = "%(patron_host)s/v1/patrons/me/holds"
HOLD_ENDPOINT = "%(patron_host)s/v1/patrons/me/holds/%(product_id)s"
ME_ENDPOINT = "%(patron_host)s/v1/patrons/me"
@@ -126,12 +119,13 @@ class OverdriveAPI(object):
EVENT_DELAY = datetime.timedelta(minutes=120)
# The formats we care about.
- FORMATS = "ebook-epub-open,ebook-epub-adobe,ebook-pdf-adobe,ebook-pdf-open,audiobook-overdrive".split(",")
+ FORMATS = "ebook-epub-open,ebook-epub-adobe,ebook-pdf-adobe,ebook-pdf-open,audiobook-overdrive".split(
+ ","
+ )
# The formats that can be read by the default Library Simplified reader.
DEFAULT_READABLE_FORMATS = set(
- ["ebook-epub-open", "ebook-epub-adobe", "ebook-pdf-open",
- "audiobook-overdrive"]
+ ["ebook-epub-open", "ebook-epub-adobe", "ebook-pdf-open", "audiobook-overdrive"]
)
# The formats that indicate the book has been fulfilled on an
@@ -155,8 +149,8 @@ class OverdriveAPI(object):
def __init__(self, _db, collection):
if collection.protocol != ExternalIntegration.OVERDRIVE:
raise ValueError(
- "Collection protocol is %s, but passed into OverdriveAPI!" %
- collection.protocol
+ "Collection protocol is %s, but passed into OverdriveAPI!"
+ % collection.protocol
)
self._db = _db
self.library_id = collection.external_account_id
@@ -176,17 +170,18 @@ def __init__(self, _db, collection):
self.client_key = integration.username
self.client_secret = integration.password
self.website_id = integration.setting(self.WEBSITE_ID).value
- if (not self.client_key or not self.client_secret or not self.website_id
- or not self.library_id):
- raise CannotLoadConfiguration(
- "Overdrive configuration is incomplete."
- )
+ if (
+ not self.client_key
+ or not self.client_secret
+ or not self.website_id
+ or not self.library_id
+ ):
+ raise CannotLoadConfiguration("Overdrive configuration is incomplete.")
# Figure out which hostnames we'll be using when constructing
# endpoint URLs.
server_nickname = (
- integration.setting(self.SERVER_NICKNAME).value
- or self.PRODUCTION_SERVERS
+ integration.setting(self.SERVER_NICKNAME).value or self.PRODUCTION_SERVERS
)
if server_nickname not in self.HOSTS:
server_nickname = self.PRODUCTION_SERVERS
@@ -198,7 +193,7 @@ def __init__(self, _db, collection):
# Use utf8 instead of unicode encoding
settings = [self.client_key, self.client_secret, self.website_id]
self.client_key, self.client_secret, self.website_id = (
- setting.encode('utf8') for setting in settings
+ setting.encode("utf8") for setting in settings
)
# This is set by an access to .token, or by a call to
@@ -216,7 +211,7 @@ def endpoint(self, url, **kwargs):
The server hostname will be interpolated automatically; you
don't have to pass it in.
"""
- if not '%(' in url:
+ if not "%(" in url:
# Nothing to interpolate.
return url
kwargs.update(self.hosts)
@@ -238,14 +233,14 @@ def collection_token(self):
if not self._collection_token:
self.check_creds()
library = self.get_library()
- error = library.get('errorCode')
+ error = library.get("errorCode")
if error:
- message = library.get('message')
+ message = library.get("message")
raise CannotLoadConfiguration(
"Overdrive credentials are valid but could not fetch library: %s"
% message
)
- self._collection_token = library['collectionToken']
+ self._collection_token = library["collectionToken"]
return self._collection_token
@property
@@ -257,8 +252,7 @@ def source(self):
return DataSource.lookup(self._db, DataSource.OVERDRIVE)
def ils_name(self, library):
- """Determine the ILS name to use for the given Library.
- """
+ """Determine the ILS name to use for the given Library."""
return self.ils_name_setting(
self._db, self.collection, library
).value_or_default(self.ILS_NAME_DEFAULT)
@@ -308,8 +302,12 @@ def credential_object(self, refresh):
the Overdrive API.
"""
return Credential.lookup(
- self._db, DataSource.OVERDRIVE, None, None, refresh,
- collection=self.collection
+ self._db,
+ DataSource.OVERDRIVE,
+ None,
+ None,
+ refresh,
+ collection=self.collection,
)
def refresh_creds(self, credential):
@@ -317,7 +315,7 @@ def refresh_creds(self, credential):
response = self.token_post(
self.TOKEN_ENDPOINT,
dict(grant_type="client_credentials"),
- allowed_response_codes=[200]
+ allowed_response_codes=[200],
)
data = response.json()
self._update_credential(credential, data)
@@ -334,7 +332,7 @@ def get(self, url, extra_headers, exception_on_401=False):
raise BadResponseException.from_response(
url,
"Something's wrong with the Overdrive OAuth Bearer Token!",
- (status_code, headers, content)
+ (status_code, headers, content),
)
else:
# Refresh the token and try again.
@@ -351,15 +349,14 @@ def token_authorization_header(self):
def token_post(self, url, payload, headers={}, **kwargs):
"""Make an HTTP POST request for purposes of getting an OAuth token."""
headers = dict(headers)
- headers['Authorization'] = self.token_authorization_header
+ headers["Authorization"] = self.token_authorization_header
return self._do_post(url, payload, headers, **kwargs)
def _update_credential(self, credential, overdrive_data):
"""Copy Overdrive OAuth data into a Credential object."""
- credential.credential = overdrive_data['access_token']
- expires_in = (overdrive_data['expires_in'] * 0.9)
- credential.expires = utc_now() + datetime.timedelta(
- seconds=expires_in)
+ credential.credential = overdrive_data["access_token"]
+ expires_in = overdrive_data["expires_in"] * 0.9
+ credential.expires = utc_now() + datetime.timedelta(seconds=expires_in)
@property
def _library_endpoint(self):
@@ -374,7 +371,7 @@ def _library_endpoint(self):
args = dict(library_id=self.library_id)
if self.parent_library_id:
# This is an Overdrive advantage account.
- args['parent_library_id'] = self.parent_library_id
+ args["parent_library_id"] = self.parent_library_id
endpoint = self.ADVANTAGE_LIBRARY_ENDPOINT
else:
endpoint = self.LIBRARY_ENDPOINT
@@ -387,7 +384,9 @@ def get_library(self):
url = self._library_endpoint
with self.lock:
representation, cached = Representation.get(
- self._db, url, self.get,
+ self._db,
+ url,
+ self.get,
exception_handler=Representation.reraise_exception,
)
return json.loads(representation.content)
@@ -398,23 +397,23 @@ def get_advantage_accounts(self):
:yield: A sequence of OverdriveAdvantageAccount objects.
"""
library = self.get_library()
- links = library.get('links', {})
- advantage = links.get('advantageAccounts')
+ links = library.get("links", {})
+ advantage = links.get("advantageAccounts")
if not advantage:
return []
if advantage:
# This library has Overdrive Advantage accounts, or at
# least a link where some may be found.
- advantage_url = advantage.get('href')
+ advantage_url = advantage.get("href")
if not advantage_url:
return
representation, cached = Representation.get(
- self._db, advantage_url, self.get,
+ self._db,
+ advantage_url,
+ self.get,
exception_handler=Representation.reraise_exception,
)
- return OverdriveAdvantageAccount.from_representation(
- representation.content
- )
+ return OverdriveAdvantageAccount.from_representation(representation.content)
def all_ids(self):
"""Get IDs for every book in the system, with the most recently added
@@ -422,9 +421,7 @@ def all_ids(self):
"""
next_link = self._all_products_link
while next_link:
- page_inventory, next_link = self._get_book_list_page(
- next_link, 'next'
- )
+ page_inventory, next_link = self._get_book_list_page(next_link, "next")
for i in page_inventory:
yield i
@@ -434,12 +431,11 @@ def _all_products_link(self):
url = self.endpoint(
self.ALL_PRODUCTS_ENDPOINT,
collection_token=self.collection_token,
- sort="dateAdded:desc"
+ sort="dateAdded:desc",
)
return self.make_link_safe(url)
- def _get_book_list_page(self, link, rel_to_follow='next',
- extractor_class=None):
+ def _get_book_list_page(self, link, rel_to_follow="next", extractor_class=None):
"""Process a page of inventory whose circulation we need to check.
Returns a 2-tuple: (availability_info, next_link).
@@ -459,7 +455,7 @@ def _get_book_list_page(self, link, rel_to_follow='next',
# Prepare to get availability information for all the books on
# this page.
- availability_queue = (extractor_class.availability_link_list(content))
+ availability_queue = extractor_class.availability_link_list(content)
return availability_queue, next_link
def recently_changed_ids(self, start, cutoff):
@@ -469,11 +465,8 @@ def recently_changed_ids(self, start, cutoff):
# `cutoff` is not supported by Overdrive, so we ignore it. All
# we can do is get events between the start time and now.
- last_update_time = start-self.EVENT_DELAY
- self.log.info(
- "Asking for circulation changes since %s",
- last_update_time
- )
+ last_update_time = start - self.EVENT_DELAY
+ self.log.info("Asking for circulation changes since %s", last_update_time)
last_update = last_update_time.strftime(self.TIME_FORMAT)
next_link = self.endpoint(
@@ -481,7 +474,7 @@ def recently_changed_ids(self, start, cutoff):
lastupdatetime=last_update,
sort="popularity:desc",
limit=self.PAGE_SIZE_LIMIT,
- collection_token=self.collection_token
+ collection_token=self.collection_token,
)
next_link = self.make_link_safe(next_link)
while next_link:
@@ -494,12 +487,11 @@ def recently_changed_ids(self, start, cutoff):
yield i
def metadata_lookup(self, identifier):
- """Look up metadata for an Overdrive identifier.
- """
+ """Look up metadata for an Overdrive identifier."""
url = self.endpoint(
self.METADATA_ENDPOINT,
collection_token=self.collection_token,
- item_id=identifier.identifier
+ item_id=identifier.identifier,
)
status_code, headers, content = self.get(url, {})
if isinstance(content, (bytes, str)):
@@ -510,14 +502,13 @@ def metadata_lookup_obj(self, identifier):
url = self.endpoint(
self.METADATA_ENDPOINT,
collection_token=self.collection_token,
- item_id=identifier
+ item_id=identifier,
)
status_code, headers, content = self.get(url, {})
if isinstance(content, (bytes, str)):
content = json.loads(content)
return OverdriveRepresentationExtractor.book_info_to_metadata(content)
-
@classmethod
def make_link_safe(self, url):
"""Turn a server-provided link into a link the server will accept!
@@ -531,11 +522,10 @@ def make_link_safe(self, url):
parts = list(urlsplit(url))
parts[2] = quote(parts[2])
endings = ("/availability", "/availability/")
- if (parts[2].startswith("/v1/collections/")
- and any(parts[2].endswith(x) for x in endings)):
- parts[2] = parts[2].replace(
- "/v1/collections/", "/v2/collections/", 1
- )
+ if parts[2].startswith("/v1/collections/") and any(
+ parts[2].endswith(x) for x in endings
+ ):
+ parts[2] = parts[2].replace("/v1/collections/", "/v2/collections/", 1)
query_string = parts[3]
query_string = query_string.replace("+", "%2B")
query_string = query_string.replace(":", "%3A")
@@ -547,9 +537,7 @@ def make_link_safe(self, url):
def _do_get(self, url, headers):
"""This method is overridden in MockOverdriveAPI."""
url = self.endpoint(url)
- return Representation.simple_http_get(
- url, headers
- )
+ return Representation.simple_http_get(url, headers)
def _do_post(self, url, payload, headers, **kwargs):
"""This method is overridden in MockOverdriveAPI."""
@@ -558,30 +546,33 @@ def _do_post(self, url, payload, headers, **kwargs):
class MockOverdriveAPI(OverdriveAPI):
-
@classmethod
- def mock_collection(self, _db, library=None,
- name="Test Overdrive Collection",
- client_key="a", client_secret="b",
- library_id="c", website_id="d",
- ils_name="e",
- ):
+ def mock_collection(
+ self,
+ _db,
+ library=None,
+ name="Test Overdrive Collection",
+ client_key="a",
+ client_secret="b",
+ library_id="c",
+ website_id="d",
+ ils_name="e",
+ ):
"""Create a mock Overdrive collection for use in tests."""
if library is None:
library = DatabaseTest.make_default_library(_db)
collection, ignore = get_one_or_create(
- _db, Collection,
- name=name,
- create_method_kwargs=dict(
- external_account_id=library_id
- )
- )
+ _db,
+ Collection,
+ name=name,
+ create_method_kwargs=dict(external_account_id=library_id),
+ )
integration = collection.create_external_integration(
protocol=ExternalIntegration.OVERDRIVE
)
integration.username = client_key
integration.password = client_secret
- integration.set_setting('website_id', website_id)
+ integration.set_setting("website_id", website_id)
library.collections.append(collection)
OverdriveAPI.ils_name_setting(_db, collection, library).value = ils_name
return collection
@@ -594,18 +585,14 @@ def __init__(self, _db, collection, *args, **kwargs):
# Almost all tests will try to request the access token, so
# set the response that will be returned if an attempt is
# made.
- self.access_token_response = self.mock_access_token_response(
- "bearer token"
- )
+ self.access_token_response = self.mock_access_token_response("bearer token")
super(MockOverdriveAPI, self).__init__(_db, collection, *args, **kwargs)
def queue_collection_token(self):
# Many tests immediately try to access the
# collection token. This is a helper method to make it easy to
# queue up the response.
- self.queue_response(
- 200, content=self.mock_collection_token("collection token")
- )
+ self.queue_response(200, content=self.mock_collection_token("collection token"))
def token_post(self, url, payload, headers={}, **kwargs):
"""Mock the request for an OAuth token.
@@ -630,9 +617,7 @@ def mock_collection_token(self, token):
return json.dumps(dict(collectionToken=token))
def queue_response(self, status_code, headers={}, content=None):
- self.responses.insert(
- 0, MockRequestsResponse(status_code, headers, content)
- )
+ self.responses.insert(0, MockRequestsResponse(status_code, headers, content))
def _do_get(self, url, *args, **kwargs):
"""Simulate Representation.simple_http_get."""
@@ -647,8 +632,10 @@ def _make_request(self, url, *args, **kwargs):
response = self.responses.pop()
self.requests.append((url, args, kwargs))
return HTTP._process_response(
- url, response, kwargs.get('allowed_response_codes'),
- kwargs.get('disallowed_response_codes')
+ url,
+ response,
+ kwargs.get("allowed_response_codes"),
+ kwargs.get("disallowed_response_codes"),
)
@@ -669,98 +656,88 @@ def __init__(self, api):
@classmethod
def availability_link_list(cls, book_list):
- """:return: A list of dictionaries with keys `id`, `title`, `availability_link`.
- """
+ """:return: A list of dictionaries with keys `id`, `title`, `availability_link`."""
l = []
- if not 'products' in book_list:
+ if not "products" in book_list:
return []
- products = book_list['products']
+ products = book_list["products"]
for product in products:
- if not 'id' in product:
+ if not "id" in product:
cls.log.warn("No ID found in %r", product)
continue
- book_id = product['id']
+ book_id = product["id"]
data = dict(
id=book_id,
- title=product.get('title'),
+ title=product.get("title"),
author_name=None,
- date_added=product.get('dateAdded')
+ date_added=product.get("dateAdded"),
)
- if 'primaryCreator' in product:
- creator = product['primaryCreator']
- if creator.get('role') == 'Author':
- data['author_name'] = creator.get('name')
- links = product.get('links', [])
- if 'availability' in links:
- link = links['availability']['href']
- data['availability_link'] = OverdriveAPI.make_link_safe(link)
+ if "primaryCreator" in product:
+ creator = product["primaryCreator"]
+ if creator.get("role") == "Author":
+ data["author_name"] = creator.get("name")
+ links = product.get("links", [])
+ if "availability" in links:
+ link = links["availability"]["href"]
+ data["availability_link"] = OverdriveAPI.make_link_safe(link)
else:
logging.getLogger("Overdrive API").warn(
- "No availability link for %s", book_id)
+ "No availability link for %s", book_id
+ )
l.append(data)
return l
@classmethod
def link(self, page, rel):
- if 'links' in page and rel in page['links']:
- raw_link = page['links'][rel]['href']
+ if "links" in page and rel in page["links"]:
+ raw_link = page["links"][rel]["href"]
link = OverdriveAPI.make_link_safe(raw_link)
else:
link = None
return link
format_data_for_overdrive_format = {
-
- "ebook-pdf-adobe" : (
- Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
- ),
- "ebook-pdf-open" : (
- Representation.PDF_MEDIA_TYPE, DeliveryMechanism.NO_DRM
- ),
- "ebook-epub-adobe" : (
- Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
- ),
- "ebook-epub-open" : (
- Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM
- ),
- "audiobook-mp3" : (
- "application/x-od-media", DeliveryMechanism.OVERDRIVE_DRM
- ),
- "music-mp3" : (
- "application/x-od-media", DeliveryMechanism.OVERDRIVE_DRM
+ "ebook-pdf-adobe": (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM),
+ "ebook-pdf-open": (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.NO_DRM),
+ "ebook-epub-adobe": (
+ Representation.EPUB_MEDIA_TYPE,
+ DeliveryMechanism.ADOBE_DRM,
),
- "ebook-overdrive" : [
+ "ebook-epub-open": (Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM),
+ "audiobook-mp3": ("application/x-od-media", DeliveryMechanism.OVERDRIVE_DRM),
+ "music-mp3": ("application/x-od-media", DeliveryMechanism.OVERDRIVE_DRM),
+ "ebook-overdrive": [
(
MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE,
- DeliveryMechanism.LIBBY_DRM
+ DeliveryMechanism.LIBBY_DRM,
),
(
DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM
+ DeliveryMechanism.STREAMING_DRM,
),
],
- "audiobook-overdrive" : [
+ "audiobook-overdrive": [
(
MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE,
DeliveryMechanism.LIBBY_DRM,
),
(
DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM
+ DeliveryMechanism.STREAMING_DRM,
),
],
- 'video-streaming' : (
+ "video-streaming": (
DeliveryMechanism.STREAMING_VIDEO_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM
+ DeliveryMechanism.STREAMING_DRM,
),
- "ebook-kindle" : (
+ "ebook-kindle": (
DeliveryMechanism.KINDLE_CONTENT_TYPE,
- DeliveryMechanism.KINDLE_DRM
+ DeliveryMechanism.KINDLE_DRM,
),
- "periodicals-nook" : (
+ "periodicals-nook": (
DeliveryMechanism.NOOK_CONTENT_TYPE,
- DeliveryMechanism.NOOK_DRM
+ DeliveryMechanism.NOOK_DRM,
),
}
@@ -784,46 +761,46 @@ def internal_formats(cls, overdrive_format):
ignorable_overdrive_formats = set([])
overdrive_role_to_simplified_role = {
- "actor" : Contributor.ACTOR_ROLE,
- "artist" : Contributor.ARTIST_ROLE,
- "book producer" : Contributor.PRODUCER_ROLE,
- "associated name" : Contributor.ASSOCIATED_ROLE,
- "author" : Contributor.AUTHOR_ROLE,
- "author of introduction" : Contributor.INTRODUCTION_ROLE,
- "author of foreword" : Contributor.FOREWORD_ROLE,
- "author of afterword" : Contributor.AFTERWORD_ROLE,
- "contributor" : Contributor.CONTRIBUTOR_ROLE,
- "colophon" : Contributor.COLOPHON_ROLE,
- "adapter" : Contributor.ADAPTER_ROLE,
- "etc." : Contributor.UNKNOWN_ROLE,
- "cast member" : Contributor.ACTOR_ROLE,
- "collaborator" : Contributor.COLLABORATOR_ROLE,
- "compiler" : Contributor.COMPILER_ROLE,
- "composer" : Contributor.COMPOSER_ROLE,
- "copyright holder" : Contributor.COPYRIGHT_HOLDER_ROLE,
- "director" : Contributor.DIRECTOR_ROLE,
- "editor" : Contributor.EDITOR_ROLE,
- "engineer" : Contributor.ENGINEER_ROLE,
- "executive producer" : Contributor.EXECUTIVE_PRODUCER_ROLE,
- "illustrator" : Contributor.ILLUSTRATOR_ROLE,
- "musician" : Contributor.MUSICIAN_ROLE,
- "narrator" : Contributor.NARRATOR_ROLE,
- "other" : Contributor.UNKNOWN_ROLE,
- "performer" : Contributor.PERFORMER_ROLE,
- "producer" : Contributor.PRODUCER_ROLE,
- "translator" : Contributor.TRANSLATOR_ROLE,
- "photographer" : Contributor.PHOTOGRAPHER_ROLE,
- "lyricist" : Contributor.LYRICIST_ROLE,
- "transcriber" : Contributor.TRANSCRIBER_ROLE,
- "designer" : Contributor.DESIGNER_ROLE,
+ "actor": Contributor.ACTOR_ROLE,
+ "artist": Contributor.ARTIST_ROLE,
+ "book producer": Contributor.PRODUCER_ROLE,
+ "associated name": Contributor.ASSOCIATED_ROLE,
+ "author": Contributor.AUTHOR_ROLE,
+ "author of introduction": Contributor.INTRODUCTION_ROLE,
+ "author of foreword": Contributor.FOREWORD_ROLE,
+ "author of afterword": Contributor.AFTERWORD_ROLE,
+ "contributor": Contributor.CONTRIBUTOR_ROLE,
+ "colophon": Contributor.COLOPHON_ROLE,
+ "adapter": Contributor.ADAPTER_ROLE,
+ "etc.": Contributor.UNKNOWN_ROLE,
+ "cast member": Contributor.ACTOR_ROLE,
+ "collaborator": Contributor.COLLABORATOR_ROLE,
+ "compiler": Contributor.COMPILER_ROLE,
+ "composer": Contributor.COMPOSER_ROLE,
+ "copyright holder": Contributor.COPYRIGHT_HOLDER_ROLE,
+ "director": Contributor.DIRECTOR_ROLE,
+ "editor": Contributor.EDITOR_ROLE,
+ "engineer": Contributor.ENGINEER_ROLE,
+ "executive producer": Contributor.EXECUTIVE_PRODUCER_ROLE,
+ "illustrator": Contributor.ILLUSTRATOR_ROLE,
+ "musician": Contributor.MUSICIAN_ROLE,
+ "narrator": Contributor.NARRATOR_ROLE,
+ "other": Contributor.UNKNOWN_ROLE,
+ "performer": Contributor.PERFORMER_ROLE,
+ "producer": Contributor.PRODUCER_ROLE,
+ "translator": Contributor.TRANSLATOR_ROLE,
+ "photographer": Contributor.PHOTOGRAPHER_ROLE,
+ "lyricist": Contributor.LYRICIST_ROLE,
+ "transcriber": Contributor.TRANSCRIBER_ROLE,
+ "designer": Contributor.DESIGNER_ROLE,
}
overdrive_medium_to_simplified_medium = {
- "eBook" : Edition.BOOK_MEDIUM,
- "Video" : Edition.VIDEO_MEDIUM,
- "Audiobook" : Edition.AUDIO_MEDIUM,
- "Music" : Edition.MUSIC_MEDIUM,
- "Periodicals" : Edition.PERIODICAL_MEDIUM,
+ "eBook": Edition.BOOK_MEDIUM,
+ "Video": Edition.VIDEO_MEDIUM,
+ "Audiobook": Edition.AUDIO_MEDIUM,
+ "Music": Edition.MUSIC_MEDIUM,
+ "Periodicals": Edition.PERIODICAL_MEDIUM,
}
DATE_FORMAT = "%Y-%m-%d"
@@ -832,19 +809,18 @@ def internal_formats(cls, overdrive_format):
def parse_roles(cls, id, rolestring):
rolestring = rolestring.lower()
roles = [x.strip() for x in rolestring.split(",")]
- if ' and ' in roles[-1]:
+ if " and " in roles[-1]:
roles = roles[:-1] + [x.strip() for x in roles[-1].split(" and ")]
processed = []
for x in roles:
if x not in cls.overdrive_role_to_simplified_role:
- cls.log.error(
- "Could not process role %s for %s", x, id)
+ cls.log.error("Could not process role %s for %s", x, id)
else:
processed.append(cls.overdrive_role_to_simplified_role[x])
return processed
def book_info_to_circulation(self, book):
- """ Note: The json data passed into this method is from a different file/stream
+ """Note: The json data passed into this method is from a different file/stream
from the json data that goes into the book_info_to_metadata() method.
"""
# In Overdrive, 'reserved' books show up as books on
@@ -859,14 +835,12 @@ def book_info_to_circulation(self, book):
# circulation code sticks the known book ID into `book` ahead
# of time. That's a code smell indicating that this system
# needs to be refactored.
- if 'reserveId' in book and not 'id' in book:
- book['id'] = book['reserveId']
- if not 'id' in book:
+ if "reserveId" in book and not "id" in book:
+ book["id"] = book["reserveId"]
+ if not "id" in book:
return None
- overdrive_id = book['id']
- primary_identifier = IdentifierData(
- Identifier.OVERDRIVE_ID, overdrive_id
- )
+ overdrive_id = book["id"]
+ primary_identifier = IdentifierData(Identifier.OVERDRIVE_ID, overdrive_id)
# TODO: We might be able to use this information to avoid the
# need for explicit configuration of Advantage collections, or
# at least to keep Advantage collections more up-to-date than
@@ -884,35 +858,35 @@ def book_info_to_circulation(self, book):
# similarly, though those can abruptly become unavailable, so
# UNLIMITED_ACCESS is probably not appropriate.
- error_code = book.get('errorCode')
+ error_code = book.get("errorCode")
# TODO: It's not clear what other error codes there might be.
# The current behavior will respond to errors other than
# NotFound by leaving the book alone, but this might not be
# the right behavior.
- if error_code == 'NotFound':
+ if error_code == "NotFound":
licenses_owned = 0
licenses_available = 0
patrons_in_hold_queue = 0
- elif book.get('isOwnedByCollections') is not False:
+ elif book.get("isOwnedByCollections") is not False:
# We own this book.
- for account in book.get('accounts', []):
+ for account in book.get("accounts", []):
# Only keep track of copies owned by the collection
# we're tracking.
- if account.get('id') != self.library_id:
+ if account.get("id") != self.library_id:
continue
- if 'copiesOwned' in account:
+ if "copiesOwned" in account:
if licenses_owned is None:
licenses_owned = 0
- licenses_owned += int(account['copiesOwned'])
- if 'copiesAvailable' in account:
+ licenses_owned += int(account["copiesOwned"])
+ if "copiesAvailable" in account:
if licenses_available is None:
licenses_available = 0
- licenses_available += int(account['copiesAvailable'])
- if 'numberOfHolds' in book:
+ licenses_available += int(account["copiesAvailable"])
+ if "numberOfHolds" in book:
if patrons_in_hold_queue is None:
patrons_in_hold_queue = 0
- patrons_in_hold_queue += book['numberOfHolds']
+ patrons_in_hold_queue += book["numberOfHolds"]
return CirculationData(
data_source=DataSource.OVERDRIVE,
primary_identifier=primary_identifier,
@@ -924,154 +898,163 @@ def book_info_to_circulation(self, book):
@classmethod
def image_link_to_linkdata(cls, link, rel):
- if not link or not 'href' in link:
+ if not link or not "href" in link:
return None
- href = link['href']
- if '00000000-0000-0000-0000' in href:
+ href = link["href"]
+ if "00000000-0000-0000-0000" in href:
# This is a stand-in cover for preorders. It's better not
# to have a cover at all -- we might be able to get one
# later, or from another source.
return None
href = OverdriveAPI.make_link_safe(href)
- media_type = link.get('type', None)
+ media_type = link.get("type", None)
return LinkData(rel=rel, href=href, media_type=media_type)
-
@classmethod
- def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats=True):
+ def book_info_to_metadata(
+ cls, book, include_bibliographic=True, include_formats=True
+ ):
"""Turn Overdrive's JSON representation of a book into a Metadata
object.
Note: The json data passed into this method is from a different file/stream
from the json data that goes into the book_info_to_circulation() method.
"""
- if not 'id' in book:
+ if not "id" in book:
return None
- overdrive_id = book['id']
- primary_identifier = IdentifierData(
- Identifier.OVERDRIVE_ID, overdrive_id
- )
+ overdrive_id = book["id"]
+ primary_identifier = IdentifierData(Identifier.OVERDRIVE_ID, overdrive_id)
# If we trust classification data, we'll give it this weight.
# Otherwise we'll probably give it a fraction of this weight.
trusted_weight = Classification.TRUSTED_DISTRIBUTOR_WEIGHT
if include_bibliographic:
- title = book.get('title', None)
- sort_title = book.get('sortTitle')
- subtitle = book.get('subtitle', None)
- series = book.get('series', None)
- publisher = book.get('publisher', None)
- imprint = book.get('imprint', None)
-
- if 'publishDate' in book:
- published = strptime_utc(
- book['publishDate'][:10], cls.DATE_FORMAT)
+ title = book.get("title", None)
+ sort_title = book.get("sortTitle")
+ subtitle = book.get("subtitle", None)
+ series = book.get("series", None)
+ publisher = book.get("publisher", None)
+ imprint = book.get("imprint", None)
+
+ if "publishDate" in book:
+ published = strptime_utc(book["publishDate"][:10], cls.DATE_FORMAT)
else:
published = None
- languages = [l['code'] for l in book.get('languages', [])]
- if 'eng' in languages or not languages:
- language = 'eng'
+ languages = [l["code"] for l in book.get("languages", [])]
+ if "eng" in languages or not languages:
+ language = "eng"
else:
language = sorted(languages)[0]
contributors = []
- for creator in book.get('creators', []):
- sort_name = creator['fileAs']
- display_name = creator['name']
- role = creator['role']
- roles = cls.parse_roles(overdrive_id, role) or [Contributor.UNKNOWN_ROLE]
+ for creator in book.get("creators", []):
+ sort_name = creator["fileAs"]
+ display_name = creator["name"]
+ role = creator["role"]
+ roles = cls.parse_roles(overdrive_id, role) or [
+ Contributor.UNKNOWN_ROLE
+ ]
contributor = ContributorData(
- sort_name=sort_name, display_name=display_name,
- roles=roles, biography = creator.get('bioText', None)
+ sort_name=sort_name,
+ display_name=display_name,
+ roles=roles,
+ biography=creator.get("bioText", None),
)
contributors.append(contributor)
subjects = []
- for sub in book.get('subjects', []):
+ for sub in book.get("subjects", []):
subject = SubjectData(
- type=Subject.OVERDRIVE, identifier=sub['value'],
+ type=Subject.OVERDRIVE,
+ identifier=sub["value"],
weight=trusted_weight,
)
subjects.append(subject)
- for sub in book.get('keywords', []):
+ for sub in book.get("keywords", []):
subject = SubjectData(
- type=Subject.TAG, identifier=sub['value'],
+ type=Subject.TAG,
+ identifier=sub["value"],
# We don't use TRUSTED_DISTRIBUTOR_WEIGHT because
# we don't know where the tags come from --
# probably Overdrive users -- and they're
# frequently wrong.
- weight=1
+ weight=1,
)
subjects.append(subject)
extra = dict()
- if 'grade_levels' in book:
+ if "grade_levels" in book:
# n.b. Grade levels are measurements of reading level, not
# age appropriateness. We can use them as a measure of age
# appropriateness in a pinch, but we weight them less
# heavily than TRUSTED_DISTRIBUTOR_WEIGHT.
- for i in book['grade_levels']:
+ for i in book["grade_levels"]:
subject = SubjectData(
type=Subject.GRADE_LEVEL,
- identifier=i['value'],
- weight=trusted_weight / 10
+ identifier=i["value"],
+ weight=trusted_weight / 10,
)
subjects.append(subject)
- overdrive_medium = book.get('mediaType', None)
- if overdrive_medium and overdrive_medium not in cls.overdrive_medium_to_simplified_medium:
+ overdrive_medium = book.get("mediaType", None)
+ if (
+ overdrive_medium
+ and overdrive_medium not in cls.overdrive_medium_to_simplified_medium
+ ):
cls.log.error(
- "Could not process medium %s for %s", overdrive_medium, overdrive_id)
+ "Could not process medium %s for %s", overdrive_medium, overdrive_id
+ )
medium = cls.overdrive_medium_to_simplified_medium.get(
overdrive_medium, Edition.BOOK_MEDIUM
)
measurements = []
- if 'awards' in book:
- extra['awards'] = book.get('awards', [])
- num_awards = len(extra['awards'])
+ if "awards" in book:
+ extra["awards"] = book.get("awards", [])
+ num_awards = len(extra["awards"])
measurements.append(
- MeasurementData(
- Measurement.AWARDS, str(num_awards)
- )
+ MeasurementData(Measurement.AWARDS, str(num_awards))
)
for name, subject_type in (
- ('ATOS', Subject.ATOS_SCORE),
- ('lexileScore', Subject.LEXILE_SCORE),
- ('interestLevel', Subject.INTEREST_LEVEL)
+ ("ATOS", Subject.ATOS_SCORE),
+ ("lexileScore", Subject.LEXILE_SCORE),
+ ("interestLevel", Subject.INTEREST_LEVEL),
):
if not name in book:
continue
identifier = str(book[name])
subjects.append(
- SubjectData(type=subject_type, identifier=identifier,
- weight=trusted_weight
- )
+ SubjectData(
+ type=subject_type, identifier=identifier, weight=trusted_weight
+ )
)
- for grade_level_info in book.get('gradeLevels', []):
- grade_level = grade_level_info.get('value')
+ for grade_level_info in book.get("gradeLevels", []):
+ grade_level = grade_level_info.get("value")
subjects.append(
- SubjectData(type=Subject.GRADE_LEVEL, identifier=grade_level,
- weight=trusted_weight)
+ SubjectData(
+ type=Subject.GRADE_LEVEL,
+ identifier=grade_level,
+ weight=trusted_weight,
+ )
)
identifiers = []
links = []
- for format in book.get('formats', []):
- for new_id in format.get('identifiers', []):
- t = new_id['type']
- v = new_id['value']
+ for format in book.get("formats", []):
+ for new_id in format.get("identifiers", []):
+ t = new_id["type"]
+ v = new_id["value"]
orig_v = v
type_key = None
- if t == 'ASIN':
+ if t == "ASIN":
type_key = Identifier.ASIN
- elif t == 'ISBN':
+ elif t == "ISBN":
type_key = Identifier.ISBN
if len(v) == 10:
v = isbnlib.to_isbn13(v)
@@ -1082,47 +1065,43 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
# books appear to have the same ISBN. ISBNs
# which fail check digit checks or are invalid
# also can occur. Log them for review.
- cls.log.info(
- "Bad ISBN value provided: %s", orig_v
- )
+ cls.log.info("Bad ISBN value provided: %s", orig_v)
continue
- elif t == 'DOI':
+ elif t == "DOI":
type_key = Identifier.DOI
- elif t == 'UPC':
+ elif t == "UPC":
type_key = Identifier.UPC
- elif t == 'PublisherCatalogNumber':
+ elif t == "PublisherCatalogNumber":
continue
if type_key and v:
- identifiers.append(
- IdentifierData(type_key, v, 1)
- )
+ identifiers.append(IdentifierData(type_key, v, 1))
# Samples become links.
- if 'samples' in format:
- overdrive_name = format['id']
+ if "samples" in format:
+ overdrive_name = format["id"]
internal_names = list(cls.internal_formats(overdrive_name))
if not internal_names:
# Useless to us.
continue
for content_type, drm_scheme in internal_names:
if Representation.is_media_type(content_type):
- for sample_info in format['samples']:
- href = sample_info['url']
+ for sample_info in format["samples"]:
+ href = sample_info["url"]
links.append(
LinkData(
rel=Hyperlink.SAMPLE,
href=href,
- media_type=content_type
+ media_type=content_type,
)
)
# A cover and its thumbnail become a single LinkData.
- if 'images' in book:
- images = book['images']
+ if "images" in book:
+ images = book["images"]
image_data = cls.image_link_to_linkdata(
- images.get('cover'), Hyperlink.IMAGE
+ images.get("cover"), Hyperlink.IMAGE
)
- for name in ['cover300Wide', 'cover150Wide', 'thumbnail']:
+ for name in ["cover300Wide", "cover150Wide", "thumbnail"]:
# Try to get a thumbnail that's as close as possible
# to the size we use.
image = images.get(name)
@@ -1130,9 +1109,7 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
image, Hyperlink.THUMBNAIL_IMAGE
)
if not image_data:
- image_data = cls.image_link_to_linkdata(
- image, Hyperlink.IMAGE
- )
+ image_data = cls.image_link_to_linkdata(image, Hyperlink.IMAGE)
if thumbnail_data:
break
@@ -1142,8 +1119,8 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
links.append(image_data)
# Descriptions become links.
- short = book.get('shortDescription')
- full = book.get('fullDescription')
+ short = book.get("shortDescription")
+ full = book.get("fullDescription")
if full:
links.append(
LinkData(
@@ -1163,19 +1140,18 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
)
# Add measurements: rating and popularity
- if book.get('starRating') is not None and book['starRating'] > 0:
+ if book.get("starRating") is not None and book["starRating"] > 0:
measurements.append(
MeasurementData(
- quantity_measured=Measurement.RATING,
- value=book['starRating']
+ quantity_measured=Measurement.RATING, value=book["starRating"]
)
)
- if book.get('popularity'):
+ if book.get("popularity"):
measurements.append(
MeasurementData(
quantity_measured=Measurement.POPULARITY,
- value=book['popularity']
+ value=book["popularity"],
)
)
@@ -1205,8 +1181,8 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
if include_formats:
formats = []
- for format in book.get('formats', []):
- format_id = format['id']
+ for format in book.get("formats", []):
+ format_id = format["id"]
internal_formats = list(cls.internal_formats(format_id))
if internal_formats:
for content_type, drm_scheme in internal_formats:
@@ -1214,7 +1190,8 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
elif format_id not in cls.ignorable_overdrive_formats:
cls.log.error(
"Could not process Overdrive format %s for %s",
- format_id, overdrive_id
+ format_id,
+ overdrive_id,
)
# Also make a CirculationData so we can write the formats,
@@ -1230,8 +1207,7 @@ def book_info_to_metadata(cls, book, include_bibliographic=True, include_formats
class OverdriveAdvantageAccount(object):
- """Holder and parser for data associated with Overdrive Advantage.
- """
+ """Holder and parser for data associated with Overdrive Advantage."""
def __init__(self, parent_library_id, library_id, name):
"""Constructor.
@@ -1255,15 +1231,14 @@ def from_representation(cls, content):
:yield: A sequence of OverdriveAdvantageAccount objects.
"""
data = json.loads(content)
- parent_id = str(data.get('id'))
- accounts = data.get('advantageAccounts', {})
+ parent_id = str(data.get("id"))
+ accounts = data.get("advantageAccounts", {})
for account in accounts:
- name = account['name']
- products_link = account['links']['products']['href']
- library_id = str(account.get('id'))
- name = account.get('name')
- yield cls(parent_library_id=parent_id, library_id=library_id,
- name=name)
+ name = account["name"]
+ products_link = account["links"]["products"]["href"]
+ library_id = str(account.get("id"))
+ name = account.get("name")
+ yield cls(parent_library_id=parent_id, library_id=library_id, name=name)
def to_collection(self, _db):
"""Find or create a Collection object for this Overdrive Advantage
@@ -1274,9 +1249,11 @@ def to_collection(self, _db):
"""
# First find the parent Collection.
try:
- parent = Collection.by_protocol(_db, ExternalIntegration.OVERDRIVE).filter(
- Collection.external_account_id==self.parent_library_id
- ).one()
+ parent = (
+ Collection.by_protocol(_db, ExternalIntegration.OVERDRIVE)
+ .filter(Collection.external_account_id == self.parent_library_id)
+ .one()
+ )
except NoResultFound as e:
# Without the parent's credentials we can't access the child.
raise ValueError(
@@ -1284,9 +1261,11 @@ def to_collection(self, _db):
)
name = parent.name + " / " + self.name
child, is_new = get_one_or_create(
- _db, Collection, parent_id=parent.id,
+ _db,
+ Collection,
+ parent_id=parent.id,
external_account_id=self.library_id,
- create_method_kwargs=dict(name=name)
+ create_method_kwargs=dict(name=name),
)
if is_new:
# Make sure the child has its protocol set appropriately.
@@ -1337,17 +1316,15 @@ def __init__(self, collection, api_class=OverdriveAPI, **kwargs):
def process_item(self, identifier):
info = self.api.metadata_lookup(identifier)
error = None
- if info.get('errorCode') == 'NotFound':
+ if info.get("errorCode") == "NotFound":
error = "ID not recognized by Overdrive: %s" % identifier.identifier
- elif info.get('errorCode') == 'InvalidGuid':
+ elif info.get("errorCode") == "InvalidGuid":
error = "Invalid Overdrive ID: %s" % identifier.identifier
if error:
return self.failure(identifier, error, transient=False)
- metadata = OverdriveRepresentationExtractor.book_info_to_metadata(
- info
- )
+ metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info)
if not metadata:
e = "Could not extract metadata from Overdrive data: %r" % info
diff --git a/problem_details.py b/problem_details.py
index 02b63fdff..bfef5d0b1 100644
--- a/problem_details.py
+++ b/problem_details.py
@@ -1,65 +1,66 @@
-from .util.problem_detail import ProblemDetail as pd
-from .util.http import INTEGRATION_ERROR
from flask_babel import lazy_gettext as _
+from .util.http import INTEGRATION_ERROR
+from .util.problem_detail import ProblemDetail as pd
+
# Generic problem detail documents that recapitulate HTTP errors.
# call detailed() to add more specific information.
INVALID_INPUT = pd(
- "http://librarysimplified.org/terms/problem/invalid-input",
- 400,
- _("Invalid input."),
- _("You provided invalid or unrecognized input."),
+ "http://librarysimplified.org/terms/problem/invalid-input",
+ 400,
+ _("Invalid input."),
+ _("You provided invalid or unrecognized input."),
)
INVALID_CREDENTIALS = pd(
- "http://librarysimplified.org/terms/problem/credentials-invalid",
- 401,
- _("Invalid credentials"),
- _("Valid credentials are required."),
+ "http://librarysimplified.org/terms/problem/credentials-invalid",
+ 401,
+ _("Invalid credentials"),
+ _("Valid credentials are required."),
)
METHOD_NOT_ALLOWED = pd(
- "http://librarysimplified.org/terms/problem/method-not-allowed",
- 405,
- _("Method not allowed"),
- _("The HTTP method you used is not allowed on this resource."),
+ "http://librarysimplified.org/terms/problem/method-not-allowed",
+ 405,
+ _("Method not allowed"),
+ _("The HTTP method you used is not allowed on this resource."),
)
UNSUPPORTED_MEDIA_TYPE = pd(
- "http://librarysimplified.org/terms/problem/unsupported-media-type",
- 415,
- _("Unsupported media type"),
- _("You submitted an unsupported media type."),
+ "http://librarysimplified.org/terms/problem/unsupported-media-type",
+ 415,
+ _("Unsupported media type"),
+ _("You submitted an unsupported media type."),
)
PAYLOAD_TOO_LARGE = pd(
- "http://librarysimplified.org/terms/problem/unsupported-media-type",
- 413,
- _("Payload too large"),
- _("You submitted a document that was too large."),
+ "http://librarysimplified.org/terms/problem/unsupported-media-type",
+ 413,
+ _("Payload too large"),
+ _("You submitted a document that was too large."),
)
INTERNAL_SERVER_ERROR = pd(
- "http://librarysimplified.org/terms/problem/internal-server-error",
- 500,
- _("Internal server error."),
- _("Internal server error"),
+ "http://librarysimplified.org/terms/problem/internal-server-error",
+ 500,
+ _("Internal server error."),
+ _("Internal server error"),
)
# Problem detail documents that are specific to the Library Simplified
# domain.
INVALID_URN = pd(
- "http://librarysimplified.org/terms/problem/could-not-parse-urn",
- 400,
- _("Invalid URN"),
- _("Could not parse identifier."),
+ "http://librarysimplified.org/terms/problem/could-not-parse-urn",
+ 400,
+ _("Invalid URN"),
+ _("Could not parse identifier."),
)
UNRECOGNIZED_DATA_SOURCE = pd(
- "http://librarysimplified.org/terms/problem/unrecognized-data-source",
- 400,
- _("Unrecognized data source."),
- _("I don't know anything about that data source."),
+ "http://librarysimplified.org/terms/problem/unrecognized-data-source",
+ 400,
+ _("Unrecognized data source."),
+ _("I don't know anything about that data source."),
)
diff --git a/python_expression_dsl/ast.py b/python_expression_dsl/ast.py
index a71060c4f..c97826d9a 100644
--- a/python_expression_dsl/ast.py
+++ b/python_expression_dsl/ast.py
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
from enum import Enum
+
class Visitor(metaclass=ABCMeta):
"""Interface for visitors walking through abstract syntax trees (AST)."""
diff --git a/python_expression_dsl/evaluator.py b/python_expression_dsl/evaluator.py
index 549facf7a..d3bdb2e04 100644
--- a/python_expression_dsl/evaluator.py
+++ b/python_expression_dsl/evaluator.py
@@ -1,4 +1,5 @@
import operator
+import types
from copy import copy, deepcopy
from multipledispatch import dispatch
@@ -20,7 +21,7 @@
Visitor,
)
from .parser import DSLParser
-import types
+
class DSLEvaluationError(BaseError):
"""Raised when evaluation of a DSL expression fails."""
diff --git a/s3.py b/s3.py
index a34eeb8d9..74fc4a6e0 100644
--- a/s3.py
+++ b/s3.py
@@ -1,27 +1,26 @@
import functools
import logging
from contextlib import contextmanager
-from urllib.parse import quote, urlsplit, unquote_plus
+from enum import Enum
+from urllib.parse import quote, unquote_plus, urlsplit
import boto3
import botocore
from botocore.config import Config
-from botocore.exceptions import (
- BotoCoreError,
- ClientError,
-)
-from enum import Enum
+from botocore.exceptions import BotoCoreError, ClientError
from flask_babel import lazy_gettext as _
+
from .mirror import MirrorUploader
from .model import ExternalIntegration
from .model.configuration import (
- ConfigurationOption,
+ ConfigurationAttributeType,
ConfigurationGrouping,
ConfigurationMetadata,
- ConfigurationAttributeType
+ ConfigurationOption,
)
-class MultipartS3Upload():
+
+class MultipartS3Upload:
def __init__(self, uploader, representation, mirror_to):
self.uploader = uploader
self.representation = representation
@@ -80,7 +79,7 @@ def _get_available_regions():
"""
session = boto3.session.Session()
- return session.get_available_regions(service_name='s3')
+ return session.get_available_regions(service_name="s3")
def _get_available_region_options():
@@ -98,142 +97,138 @@ def _get_available_region_options():
class S3AddressingStyle(Enum):
"""Enumeration of different addressing styles supported by boto"""
- VIRTUAL = 'virtual'
- PATH = 'path'
- AUTO = 'auto'
+ VIRTUAL = "virtual"
+ PATH = "path"
+ AUTO = "auto"
class S3UploaderConfiguration(ConfigurationGrouping):
- S3_REGION = 's3_region'
- S3_DEFAULT_REGION = 'us-east-1'
+ S3_REGION = "s3_region"
+ S3_DEFAULT_REGION = "us-east-1"
- S3_ADDRESSING_STYLE = 's3_addressing_style'
+ S3_ADDRESSING_STYLE = "s3_addressing_style"
S3_DEFAULT_ADDRESSING_STYLE = S3AddressingStyle.VIRTUAL.value
- S3_PRESIGNED_URL_EXPIRATION = 's3_presigned_url_expiration'
+ S3_PRESIGNED_URL_EXPIRATION = "s3_presigned_url_expiration"
S3_DEFAULT_PRESIGNED_URL_EXPIRATION = 3600
- BOOK_COVERS_BUCKET_KEY = 'book_covers_bucket'
- OA_CONTENT_BUCKET_KEY = 'open_access_content_bucket'
- PROTECTED_CONTENT_BUCKET_KEY = 'protected_content_bucket'
- MARC_BUCKET_KEY = 'marc_bucket'
+ BOOK_COVERS_BUCKET_KEY = "book_covers_bucket"
+ OA_CONTENT_BUCKET_KEY = "open_access_content_bucket"
+ PROTECTED_CONTENT_BUCKET_KEY = "protected_content_bucket"
+ MARC_BUCKET_KEY = "marc_bucket"
- URL_TEMPLATE_KEY = 'bucket_name_transform'
- URL_TEMPLATE_HTTP = 'http'
- URL_TEMPLATE_HTTPS = 'https'
- URL_TEMPLATE_DEFAULT = 'identity'
+ URL_TEMPLATE_KEY = "bucket_name_transform"
+ URL_TEMPLATE_HTTP = "http"
+ URL_TEMPLATE_HTTPS = "https"
+ URL_TEMPLATE_DEFAULT = "identity"
URL_TEMPLATES_BY_TEMPLATE = {
- URL_TEMPLATE_HTTP: 'http://%(bucket)s/%(key)s',
- URL_TEMPLATE_HTTPS: 'https://%(bucket)s/%(key)s',
- URL_TEMPLATE_DEFAULT: 'https://%(bucket)s.s3.%(region)s/%(key)s'
+ URL_TEMPLATE_HTTP: "http://%(bucket)s/%(key)s",
+ URL_TEMPLATE_HTTPS: "https://%(bucket)s/%(key)s",
+ URL_TEMPLATE_DEFAULT: "https://%(bucket)s.s3.%(region)s/%(key)s",
}
access_key = ConfigurationMetadata(
key=ExternalIntegration.USERNAME,
- label=_('Access Key'),
- description='',
+ label=_("Access Key"),
+ description="",
type=ConfigurationAttributeType.TEXT,
- required=False
+ required=False,
)
secret_key = ConfigurationMetadata(
key=ExternalIntegration.PASSWORD,
- label=_('Secret Key'),
+ label=_("Secret Key"),
description=_(
- 'If the Access Key and Secret Key are not given here credentials '
- 'will be used as outlined in the '
+ "If the Access Key and Secret Key are not given here credentials "
+ "will be used as outlined in the "
'Boto3 documenation. '
- 'If Access Key is given, Secrent Key must also be given.'
+ "If Access Key is given, Secrent Key must also be given."
),
type=ConfigurationAttributeType.TEXT,
- required=False
+ required=False,
)
book_covers_bucket = ConfigurationMetadata(
key=BOOK_COVERS_BUCKET_KEY,
- label=_('Book Covers Bucket'),
+ label=_("Book Covers Bucket"),
description=_(
- 'All book cover images encountered will be mirrored to this S3 bucket. '
- 'Large images will be scaled down, and the scaled-down copies will also be uploaded to this bucket. '
- '
The bucket must already exist—it will not be created automatically.
'
+ "All book cover images encountered will be mirrored to this S3 bucket. "
+ "Large images will be scaled down, and the scaled-down copies will also be uploaded to this bucket. "
+ "
The bucket must already exist—it will not be created automatically.
"
),
type=ConfigurationAttributeType.TEXT,
- required=False
+ required=False,
)
open_access_content_bucket = ConfigurationMetadata(
key=OA_CONTENT_BUCKET_KEY,
- label=_('Open Access Content Bucket'),
+ label=_("Open Access Content Bucket"),
description=_(
- 'All open-access books encountered will be uploaded to this S3 bucket. '
- '
The bucket must already exist—it will not be created automatically.
'
+ "All open-access books encountered will be uploaded to this S3 bucket. "
+ "
The bucket must already exist—it will not be created automatically.
"
),
type=ConfigurationAttributeType.TEXT,
- required=False
+ required=False,
)
protected_access_content_bucket = ConfigurationMetadata(
key=PROTECTED_CONTENT_BUCKET_KEY,
- label=_('Protected Access Content Bucket'),
+ label=_("Protected Access Content Bucket"),
description=_(
- 'Self-hosted books will be uploaded to this S3 bucket. '
- '
The bucket must already exist—it will not be created automatically.
'
+ "Self-hosted books will be uploaded to this S3 bucket. "
+ "
The bucket must already exist—it will not be created automatically.
"
),
type=ConfigurationAttributeType.TEXT,
- required=False
+ required=False,
)
marc_file_bucket = ConfigurationMetadata(
key=MARC_BUCKET_KEY,
- label=_('MARC File Bucket'),
+ label=_("MARC File Bucket"),
description=_(
- 'All generated MARC files will be uploaded to this S3 bucket. '
- '
The bucket must already exist—it will not be created automatically.
'
+ "All generated MARC files will be uploaded to this S3 bucket. "
+ "
The bucket must already exist—it will not be created automatically.
"
),
type=ConfigurationAttributeType.TEXT,
- required=False
+ required=False,
)
s3_region = ConfigurationMetadata(
key=S3_REGION,
- label=_('S3 region'),
- description=_(
- 'S3 region which will be used for storing the content.'
- ),
+ label=_("S3 region"),
+ description=_("S3 region which will be used for storing the content."),
type=ConfigurationAttributeType.SELECT,
required=False,
default=S3_DEFAULT_REGION,
- options=_get_available_region_options()
+ options=_get_available_region_options(),
)
s3_addressing_style = ConfigurationMetadata(
key=S3_ADDRESSING_STYLE,
- label=_('S3 addressing style'),
+ label=_("S3 addressing style"),
description=_(
- 'Buckets created after September 30, 2020, will support only virtual hosted-style requests. '
- 'Path-style requests will continue to be supported for buckets created on or before this date. '
- 'For more information, '
+ "Buckets created after September 30, 2020, will support only virtual hosted-style requests. "
+ "Path-style requests will continue to be supported for buckets created on or before this date. "
+ "For more information, "
'see '
- 'Amazon S3 Path Deprecation Plan - The Rest of the Story.'
+ "Amazon S3 Path Deprecation Plan - The Rest of the Story."
),
type=ConfigurationAttributeType.SELECT,
required=False,
default=S3_DEFAULT_REGION,
options=[
- ConfigurationOption(S3AddressingStyle.VIRTUAL.value, _('Virtual')),
- ConfigurationOption(S3AddressingStyle.PATH.value, _('Path')),
- ConfigurationOption(S3AddressingStyle.AUTO.value, _('Auto'))
- ]
+ ConfigurationOption(S3AddressingStyle.VIRTUAL.value, _("Virtual")),
+ ConfigurationOption(S3AddressingStyle.PATH.value, _("Path")),
+ ConfigurationOption(S3AddressingStyle.AUTO.value, _("Auto")),
+ ],
)
s3_presigned_url_expiration = ConfigurationMetadata(
key=S3_PRESIGNED_URL_EXPIRATION,
- label=_('S3 presigned URL expiration'),
- description=_(
- 'Time in seconds for the presigned URL to remain valid'
- ),
+ label=_("S3 presigned URL expiration"),
+ description=_("Time in seconds for the presigned URL to remain valid"),
type=ConfigurationAttributeType.NUMBER,
required=False,
default=S3_DEFAULT_PRESIGNED_URL_EXPIRATION,
@@ -241,25 +236,27 @@ class S3UploaderConfiguration(ConfigurationGrouping):
url_template = ConfigurationMetadata(
key=URL_TEMPLATE_KEY,
- label=_('URL format'),
+ label=_("URL format"),
description=_(
- 'A file mirrored to S3 is available at http://{bucket}.s3.{region}.amazonaws.com/{filename}. '
- 'If you\'ve set up your DNS so that http://[bucket]/ or https://[bucket]/ points to the appropriate '
- 'S3 bucket, you can configure this S3 integration to shorten the URLs. '
- '
If you haven\'t set up your S3 buckets, don\'t change this from the default -- '
- 'you\'ll get URLs that don\'t work.
'
+ "A file mirrored to S3 is available at http://{bucket}.s3.{region}.amazonaws.com/{filename}. "
+ "If you've set up your DNS so that http://[bucket]/ or https://[bucket]/ points to the appropriate "
+ "S3 bucket, you can configure this S3 integration to shorten the URLs. "
+ "
If you haven't set up your S3 buckets, don't change this from the default -- "
+ "you'll get URLs that don't work.
"
),
type=ConfigurationAttributeType.SELECT,
required=False,
default=URL_TEMPLATE_DEFAULT,
options=[
ConfigurationOption(
- URL_TEMPLATE_DEFAULT, _('S3 Default: https://{bucket}.s3.{region}.amazonaws.com/{file}')),
- ConfigurationOption(
- URL_TEMPLATE_HTTPS, _('HTTPS: https://{bucket}/{file}')),
+ URL_TEMPLATE_DEFAULT,
+ _("S3 Default: https://{bucket}.s3.{region}.amazonaws.com/{file}"),
+ ),
ConfigurationOption(
- URL_TEMPLATE_HTTP, _('HTTP: http://{bucket}/{file}'))
- ]
+ URL_TEMPLATE_HTTPS, _("HTTPS: https://{bucket}/{file}")
+ ),
+ ConfigurationOption(URL_TEMPLATE_HTTP, _("HTTP: http://{bucket}/{file}")),
+ ],
)
@@ -267,7 +264,7 @@ class S3Uploader(MirrorUploader):
NAME = ExternalIntegration.S3
# AWS S3 host
- S3_HOST = 'amazonaws.com'
+ S3_HOST = "amazonaws.com"
SETTINGS = S3UploaderConfiguration.to_settings()
@@ -292,24 +289,24 @@ def __init__(self, integration, client_class=None, host=S3_HOST):
client_class = boto3.client
self._s3_region = integration.setting(
- S3UploaderConfiguration.S3_REGION).value_or_default(
- S3UploaderConfiguration.S3_DEFAULT_REGION)
+ S3UploaderConfiguration.S3_REGION
+ ).value_or_default(S3UploaderConfiguration.S3_DEFAULT_REGION)
self._s3_addressing_style = integration.setting(
- S3UploaderConfiguration.S3_ADDRESSING_STYLE).value_or_default(
- S3UploaderConfiguration.S3_DEFAULT_ADDRESSING_STYLE)
+ S3UploaderConfiguration.S3_ADDRESSING_STYLE
+ ).value_or_default(S3UploaderConfiguration.S3_DEFAULT_ADDRESSING_STYLE)
self._s3_presigned_url_expiration = integration.setting(
- S3UploaderConfiguration.S3_PRESIGNED_URL_EXPIRATION).value_or_default(
- S3UploaderConfiguration.S3_DEFAULT_PRESIGNED_URL_EXPIRATION)
+ S3UploaderConfiguration.S3_PRESIGNED_URL_EXPIRATION
+ ).value_or_default(S3UploaderConfiguration.S3_DEFAULT_PRESIGNED_URL_EXPIRATION)
if callable(client_class):
# Pass None into boto3 if we get an empty string.
- access_key = integration.username if integration.username != '' else None
- secret_key = integration.password if integration.password != '' else None
+ access_key = integration.username if integration.username != "" else None
+ secret_key = integration.password if integration.password != "" else None
config = Config(
signature_version=botocore.UNSIGNED,
- s3={'addressing_style': self._s3_addressing_style}
+ s3={"addressing_style": self._s3_addressing_style},
)
# NOTE: Unfortunately, boto ignores credentials (aws_access_key_id, aws_secret_access_key)
# when using botocore.UNSIGNED signature version and doesn't authenticate the client in this case.
@@ -317,14 +314,14 @@ def __init__(self, integration, client_class=None, host=S3_HOST):
# - the first client WITHOUT authentication which is used for generating unsigned URLs
# - the second client WITH authentication used for working with S3: uploading files, etc.
self._s3_link_client = client_class(
- 's3',
+ "s3",
region_name=self._s3_region,
aws_access_key_id=None,
aws_secret_access_key=None,
- config=config
+ config=config,
)
self.client = client_class(
- 's3',
+ "s3",
region_name=self._s3_region,
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
@@ -333,15 +330,15 @@ def __init__(self, integration, client_class=None, host=S3_HOST):
self.client = client_class
self.url_transform = integration.setting(
- S3UploaderConfiguration.URL_TEMPLATE_KEY).value_or_default(
- S3UploaderConfiguration.URL_TEMPLATE_DEFAULT)
+ S3UploaderConfiguration.URL_TEMPLATE_KEY
+ ).value_or_default(S3UploaderConfiguration.URL_TEMPLATE_DEFAULT)
# Transfer information about bucket names from the
# ExternalIntegration to the S3Uploader object, so we don't
# have to keep the ExternalIntegration around.
self.buckets = dict()
for setting in integration.settings:
- if setting.key.endswith('_bucket'):
+ if setting.key.endswith("_bucket"):
self.buckets[setting.key] = setting.value
def _generate_s3_url(self, bucket, path):
@@ -360,21 +357,16 @@ def _generate_s3_url(self, bucket, path):
# However, boto3 doesn't allow us to pass the key as an empty string.
# As a workaround we set it to a dummy string and later remove it from the generated URL
if not path:
- key = 'dummy'
+ key = "dummy"
url = self._s3_link_client.generate_presigned_url(
- 'get_object',
- ExpiresIn=0,
- Params={
- 'Bucket': bucket,
- 'Key': key
- }
+ "get_object", ExpiresIn=0, Params={"Bucket": bucket, "Key": key}
)
# If the path was an empty string we need to strip out trailing dummy string ending up with a URL
# pointing at the root directory of the bucket
if not path:
- url = url.replace('/' + key, '/')
+ url = url.replace("/" + key, "/")
return url
@@ -396,12 +388,9 @@ def sign_url(self, url, expiration=None):
bucket, key = self.split_url(url)
url = self.client.generate_presigned_url(
- 'get_object',
+ "get_object",
ExpiresIn=int(expiration),
- Params={
- 'Bucket': bucket,
- 'Key': key
- }
+ Params={"Bucket": bucket, "Key": key},
)
return url
@@ -412,7 +401,7 @@ def get_bucket(self, bucket_key):
def url(self, bucket, path):
"""The URL to a resource on S3 identified by bucket and path."""
- custom_url = bucket.startswith('http://') or bucket.startswith('https://')
+ custom_url = bucket.startswith("http://") or bucket.startswith("https://")
if isinstance(path, list):
# This is a list of key components that need to be quoted
@@ -420,14 +409,14 @@ def url(self, bucket, path):
path = self.key_join(path, encode=custom_url)
if isinstance(path, bytes):
path = path.decode("utf-8")
- if path.startswith('/'):
+ if path.startswith("/"):
path = path[1:]
if custom_url:
url = bucket
- if not url.endswith('/'):
- url += '/'
+ if not url.endswith("/"):
+ url += "/"
return url + path
else:
@@ -448,15 +437,15 @@ def cover_image_root(self, bucket, data_source, scaled_size=None):
data_source_name = data_source.name
parts.append(data_source_name)
url = self.url(bucket, parts)
- if not url.endswith('/'):
- url += '/'
+ if not url.endswith("/"):
+ url += "/"
return url
def content_root(self, bucket):
"""The root URL to the S3 location of hosted content of
the given type.
"""
- return self.url(bucket, '/')
+ return self.url(bucket, "/")
def marc_file_root(self, bucket, library):
url = self.url(bucket, [library.short_name])
@@ -475,7 +464,7 @@ def key_join(self, key, encode=True):
:return: A string that can be used as an S3 key.
"""
if isinstance(key, str):
- parts = key.split('/')
+ parts = key.split("/")
else:
parts = key
new_parts = []
@@ -487,18 +476,26 @@ def key_join(self, key, encode=True):
part = quote(str(part))
new_parts.append(part)
- return '/'.join(new_parts)
+ return "/".join(new_parts)
- def book_url(self, identifier, extension='.epub', open_access=True,
- data_source=None, title=None):
+ def book_url(
+ self,
+ identifier,
+ extension=".epub",
+ open_access=True,
+ data_source=None,
+ title=None,
+ ):
"""The path to the hosted EPUB file for the given identifier."""
bucket = self.get_bucket(
- S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY if open_access
- else S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY)
+ S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY
+ if open_access
+ else S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY
+ )
root = self.content_root(bucket)
- if not extension.startswith('.'):
- extension = '.' + extension
+ if not extension.startswith("."):
+ extension = "." + extension
parts = []
if data_source:
@@ -514,8 +511,7 @@ def book_url(self, identifier, extension='.epub', open_access=True,
parts.append(filename + extension)
return root + self.key_join(parts)
- def cover_image_url(self, data_source, identifier, filename,
- scaled_size=None):
+ def cover_image_url(self, data_source, identifier, filename, scaled_size=None):
"""The path to the hosted cover image for the given identifier."""
bucket = self.get_bucket(S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY)
root = self.cover_image_root(bucket, data_source, scaled_size)
@@ -549,7 +545,7 @@ def split_url(self, url, unquote=True):
scheme, netloc, path, query, fragment = urlsplit(url)
if self.is_self_url(url):
- host_parts = netloc.split('.')
+ host_parts = netloc.split(".")
host_parts_count = len(host_parts)
# 1. Path-style requests
@@ -561,15 +557,16 @@ def split_url(self, url, unquote=True):
# 2.2. Endpoints with s3-region: https://{bucket}.s3-{region}.amazonaws.com/{path}
# 2.3. Endpoints with s3.region: https://{bucket}.s3.{region}.amazonaws.com/{path}
- if host_parts_count == 3 or \
- (host_parts_count == 4 and host_parts[0] == 's3'):
- if path.startswith('/'):
+ if host_parts_count == 3 or (
+ host_parts_count == 4 and host_parts[0] == "s3"
+ ):
+ if path.startswith("/"):
path = path[1:]
- bucket, filename = path.split('/', 1)
+ bucket, filename = path.split("/", 1)
else:
bucket = host_parts[0]
- if path.startswith('/'):
+ if path.startswith("/"):
path = path[1:]
filename = path
@@ -626,7 +623,7 @@ def mirror_one(self, representation, mirror_to, collection=None):
Fileobj=fh,
Bucket=bucket,
Key=remote_filename,
- ExtraArgs=dict(ContentType=media_type)
+ ExtraArgs=dict(ContentType=media_type),
)
# Since upload_fileobj completed without a problem, we
@@ -640,8 +637,7 @@ def mirror_one(self, representation, mirror_to, collection=None):
if representation.url != mirror_url:
source = representation.url
if source:
- logging.info("MIRRORED %s => %s",
- source, representation.mirror_url)
+ logging.info("MIRRORED %s => %s", source, representation.mirror_url)
else:
logging.info("MIRRORED %s", representation.mirror_url)
except (BotoCoreError, ClientError) as e:
@@ -651,14 +647,14 @@ def mirror_one(self, representation, mirror_to, collection=None):
# the best thing to do is treat this as a transient
# error and try again later. There's no scenario where
# giving up is the right move.
- logging.error(
- "Error uploading %s: %r", mirror_to, e, exc_info=e
- )
+ logging.error("Error uploading %s: %r", mirror_to, e, exc_info=e)
finally:
fh.close()
@contextmanager
- def multipart_upload(self, representation, mirror_to, upload_class=MultipartS3Upload):
+ def multipart_upload(
+ self, representation, mirror_to, upload_class=MultipartS3Upload
+ ):
upload = upload_class(self, representation, mirror_to)
try:
yield upload
@@ -675,23 +671,23 @@ def multipart_upload(self, representation, mirror_to, upload_class=MultipartS3Up
class MinIOUploaderConfiguration(ConfigurationGrouping):
- ENDPOINT_URL = 'ENDPOINT_URL'
+ ENDPOINT_URL = "ENDPOINT_URL"
endpoint_url = ConfigurationMetadata(
key=ENDPOINT_URL,
- label=_('Endpoint URL'),
- description=_(
- 'MinIO\'s endpoint URL'
- ),
+ label=_("Endpoint URL"),
+ description=_("MinIO's endpoint URL"),
type=ConfigurationAttributeType.TEXT,
- required=True
+ required=True,
)
class MinIOUploader(S3Uploader):
NAME = ExternalIntegration.MINIO
- SETTINGS = S3Uploader.SETTINGS + [MinIOUploaderConfiguration.endpoint_url.to_settings()]
+ SETTINGS = S3Uploader.SETTINGS + [
+ MinIOUploaderConfiguration.endpoint_url.to_settings()
+ ]
def __init__(self, integration, client_class=None):
"""Instantiate an S3Uploader from an ExternalIntegration.
@@ -702,7 +698,8 @@ def __init__(self, integration, client_class=None):
instead of boto3.client.
"""
endpoint_url = integration.setting(
- MinIOUploaderConfiguration.ENDPOINT_URL).value
+ MinIOUploaderConfiguration.ENDPOINT_URL
+ ).value
_, host, _, _, _ = urlsplit(endpoint_url)
@@ -726,10 +723,10 @@ class MockS3Uploader(S3Uploader):
"""A dummy uploader for use in tests."""
buckets = {
- S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'test-cover-bucket',
- S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'test-content-bucket',
- S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'test-content-bucket',
- S3UploaderConfiguration.MARC_BUCKET_KEY: 'test-marc-bucket',
+ S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "test-cover-bucket",
+ S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "test-content-bucket",
+ S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "test-content-bucket",
+ S3UploaderConfiguration.MARC_BUCKET_KEY: "test-marc-bucket",
}
def __init__(self, fail=False, *args, **kwargs):
@@ -741,24 +738,24 @@ def __init__(self, fail=False, *args, **kwargs):
self._s3_addressing_style = S3UploaderConfiguration.S3_DEFAULT_ADDRESSING_STYLE
config = Config(
signature_version=botocore.UNSIGNED,
- s3={'addressing_style': self._s3_addressing_style}
+ s3={"addressing_style": self._s3_addressing_style},
)
self._s3_link_client = boto3.client(
- 's3',
+ "s3",
region_name=self._s3_region,
aws_access_key_id=None,
aws_secret_access_key=None,
- config=config
+ config=config,
)
self.client = boto3.client(
- 's3',
+ "s3",
region_name=self._s3_region,
aws_access_key_id=None,
aws_secret_access_key=None,
)
def mirror_one(self, representation, **kwargs):
- mirror_to = kwargs['mirror_to']
+ mirror_to = kwargs["mirror_to"]
self.uploaded.append(representation)
self.destinations.append(mirror_to)
self.content.append(representation.content)
@@ -795,8 +792,15 @@ class MockS3Client(object):
boto3 client.
"""
- def __init__(self, service, region_name, aws_access_key_id, aws_secret_access_key, config=None):
- assert service == 's3'
+ def __init__(
+ self,
+ service,
+ region_name,
+ aws_access_key_id,
+ aws_secret_access_key,
+ config=None,
+ ):
+ assert service == "s3"
self.region_name = region_name
self.access_key = aws_access_key_id
self.secret_key = aws_secret_access_key
@@ -832,9 +836,6 @@ def abort_multipart_upload(self, **kwargs):
return None
def generate_presigned_url(
- self,
- ClientMethod,
- Params=None,
- ExpiresIn=3600,
- HttpMethod=None):
+ self, ClientMethod, Params=None, ExpiresIn=3600, HttpMethod=None
+ ):
return None
diff --git a/scripts.py b/scripts.py
index b14bce5c3..4f432b71c 100644
--- a/scripts.py
+++ b/scripts.py
@@ -8,44 +8,27 @@
import traceback
import unicodedata
import uuid
-from pdb import set_trace
from collections import defaultdict
from enum import Enum
-from sqlalchemy import (
- exists,
- and_,
- text,
-)
+from pdb import set_trace
+
+from sqlalchemy import and_, exists, text
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm import Session
-from sqlalchemy.orm.exc import (
- NoResultFound,
- MultipleResultsFound,
-)
+from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound
-from .config import Configuration, CannotLoadConfiguration
-from .coverage import (
- CollectionCoverageProviderJob,
- CoverageProviderProgress,
-)
-from .external_search import (
- ExternalSearchIndex,
- Filter,
- SearchIndexCoverageProvider,
-)
+from .config import CannotLoadConfiguration, Configuration
+from .coverage import CollectionCoverageProviderJob, CoverageProviderProgress
+from .external_search import ExternalSearchIndex, Filter, SearchIndexCoverageProvider
from .lane import Lane
from .metadata_layer import (
LinkData,
- ReplacementPolicy,
MetaToModelUtility,
+ ReplacementPolicy,
TimestampData,
)
from .mirror import MirrorUploader
from .model import (
- create,
- get_one,
- get_one_or_create,
- production_session,
BaseCoverageRecord,
CachedFeed,
Collection,
@@ -69,30 +52,22 @@
Timestamp,
Work,
WorkCoverageRecord,
+ create,
+ get_one,
+ get_one_or_create,
+ production_session,
site_configuration_has_changed,
)
from .model.configuration import ExternalIntegrationLink
-from .monitor import (
- CollectionMonitor,
- ReaperMonitor,
-)
-from .opds_import import (
- OPDSImportMonitor,
- OPDSImporter,
-)
+from .monitor import CollectionMonitor, ReaperMonitor
+from .opds_import import OPDSImporter, OPDSImportMonitor
from .util import fast_query_count
-from .util.personal_names import (
- contributor_name_match_ratio,
- display_name_to_sort_name
-)
-from .util.worker_pools import (
- DatabasePool,
-)
from .util.datetime_helpers import strptime_utc, to_utc, utc_now
+from .util.personal_names import contributor_name_match_ratio, display_name_to_sort_name
+from .util.worker_pools import DatabasePool
class Script(object):
-
@property
def _db(self):
if not hasattr(self, "_session"):
@@ -106,11 +81,11 @@ def script_name(self):
This is either the .name of the Script object or the name of
the class.
"""
- return getattr(self, 'name', self.__class__.__name__)
+ return getattr(self, "name", self.__class__.__name__)
@property
def log(self):
- if not hasattr(self, '_log'):
+ if not hasattr(self, "_log"):
self._log = logging.getLogger(self.script_name)
return self._log
@@ -132,8 +107,8 @@ def parse_time(cls, time_string):
"""Try to pass the given string as a time."""
if not time_string:
return None
- for format in ('%Y-%m-%d', '%m/%d/%Y', '%Y%m%d'):
- for hours in ('', ' %H:%M:%S'):
+ for format in ("%Y-%m-%d", "%m/%d/%Y", "%Y%m%d"):
+ for hours in ("", " %H:%M:%S"):
full_format = format + hours
try:
parsed = strptime_utc(time_string, full_format)
@@ -162,10 +137,7 @@ def run(self):
timestamp_data = None
self.update_timestamp(timestamp_data, start_time, None)
except Exception as e:
- logging.error(
- "Fatal exception while running script: %s", e,
- exc_info=e
- )
+ logging.error("Fatal exception while running script: %s", e, exc_info=e)
stack_trace = traceback.format_exc()
self.update_timestamp(None, start_time, stack_trace)
raise
@@ -214,13 +186,16 @@ def update_timestamp(self, timestamp_data, start, exception):
if timestamp_data is None:
timestamp_data = TimestampData()
timestamp_data.finalize(
- self.script_name, Timestamp.SCRIPT_TYPE, self.timestamp_collection,
- start=start, exception=exception
+ self.script_name,
+ Timestamp.SCRIPT_TYPE,
+ self.timestamp_collection,
+ start=start,
+ exception=exception,
)
timestamp_data.apply(self._db)
-class RunMonitorScript(Script):
+class RunMonitorScript(Script):
def __init__(self, monitor, _db=None, **kwargs):
super(RunMonitorScript, self).__init__(_db)
if issubclass(monitor, CollectionMonitor):
@@ -288,7 +263,10 @@ def do_run(self):
monitor.exception = e
self.log.error(
"Error running monitor %s for collection %s: %s",
- self.name, collection_name, e, exc_info=e
+ self.name,
+ collection_name,
+ e,
+ exc_info=e,
)
@@ -303,6 +281,7 @@ def monitors(self, **kwargs):
class RunCoverageProvidersScript(Script):
"""Alternate between multiple coverage providers."""
+
def __init__(self, providers, _db=None):
super(RunCoverageProvidersScript, self).__init__(_db=_db)
self.providers = []
@@ -314,7 +293,7 @@ def __init__(self, providers, _db=None):
def do_run(self):
providers = list(self.providers)
if not providers:
- self.log.info('No CoverageProviders to run.')
+ self.log.info("No CoverageProviders to run.")
progress = []
while providers:
@@ -328,7 +307,8 @@ def do_run(self):
except Exception as e:
self.log.error(
"Error in %r, moving on to next CoverageProvider.",
- provider, exc_info=e
+ provider,
+ exc_info=e,
)
self.log.debug("Completed %s", provider.service_name)
@@ -340,6 +320,7 @@ class RunCollectionCoverageProviderScript(RunCoverageProvidersScript):
"""Run the same CoverageProvider code for all Collections that
get their licenses from the appropriate place.
"""
+
def __init__(self, provider_class, _db=None, providers=None, **kwargs):
_db = _db or self._db
providers = providers or list()
@@ -356,9 +337,7 @@ class RunThreadedCollectionCoverageProviderScript(Script):
DEFAULT_WORKER_SIZE = 5
- def __init__(self, provider_class, worker_size=None, _db=None,
- **provider_kwargs
- ):
+ def __init__(self, provider_class, worker_size=None, _db=None, **provider_kwargs):
super(RunThreadedCollectionCoverageProviderScript, self).__init__(_db)
self.worker_size = worker_size or self.DEFAULT_WORKER_SIZE
@@ -389,9 +368,7 @@ def run(self, pool=None):
with (
pool or DatabasePool(self.worker_size, self.session_factory)
) as job_queue:
- query_size, batch_size = self.get_query_and_batch_sizes(
- provider
- )
+ query_size, batch_size = self.get_query_and_batch_sizes(provider)
# Without a commit, the query to count which items need
# coverage hangs in the database, blocking the threads.
self._db.commit()
@@ -402,13 +379,13 @@ def run(self, pool=None):
# value as its complets. It woudl be better if all the
# jobs could share a single 'progress' object.
while offset < query_size:
- progress = CoverageProviderProgress(
- start=utc_now()
- )
+ progress = CoverageProviderProgress(start=utc_now())
progress.offset = offset
job = CollectionCoverageProviderJob(
- collection, self.provider_class, progress,
- **self.provider_kwargs
+ collection,
+ self.provider_class,
+ progress,
+ **self.provider_kwargs,
)
job_queue.put(job)
offset += batch_size
@@ -450,15 +427,18 @@ class IdentifierInputScript(InputScript):
DATABASE_ID = "Database ID"
@classmethod
- def parse_command_line(cls, _db=None, cmd_args=None, stdin=sys.stdin,
- *args, **kwargs):
+ def parse_command_line(
+ cls, _db=None, cmd_args=None, stdin=sys.stdin, *args, **kwargs
+ ):
parser = cls.arg_parser()
parsed = parser.parse_args(cmd_args)
stdin = cls.read_stdin_lines(stdin)
return cls.look_up_identifiers(_db, parsed, stdin, *args, **kwargs)
@classmethod
- def look_up_identifiers(cls, _db, parsed, stdin_identifier_strings, *args, **kwargs):
+ def look_up_identifiers(
+ cls, _db, parsed, stdin_identifier_strings, *args, **kwargs
+ ):
"""Turn identifiers as specified on the command line into
real database Identifier objects.
"""
@@ -469,12 +449,14 @@ def look_up_identifiers(cls, _db, parsed, stdin_identifier_strings, *args, **kwa
# We can also call parse_identifier_list.
identifier_strings = parsed.identifier_strings
if stdin_identifier_strings:
- identifier_strings = (
- identifier_strings + stdin_identifier_strings
- )
+ identifier_strings = identifier_strings + stdin_identifier_strings
parsed.identifiers = cls.parse_identifier_list(
- _db, parsed.identifier_type, data_source,
- identifier_strings, *args, **kwargs
+ _db,
+ parsed.identifier_type,
+ data_source,
+ identifier_strings,
+ *args,
+ **kwargs,
)
else:
# The script can call parse_identifier_list later if it
@@ -486,23 +468,24 @@ def look_up_identifiers(cls, _db, parsed, stdin_identifier_strings, *args, **kwa
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--identifier-type',
- help='Process identifiers of this type. If IDENTIFIER is not specified, all identifiers of this type will be processed. To name identifiers by their database ID, use --identifier-type="Database ID"'
+ "--identifier-type",
+ help='Process identifiers of this type. If IDENTIFIER is not specified, all identifiers of this type will be processed. To name identifiers by their database ID, use --identifier-type="Database ID"',
)
parser.add_argument(
- '--identifier-data-source',
- help='Process only identifiers which have a LicensePool associated with this DataSource'
+ "--identifier-data-source",
+ help="Process only identifiers which have a LicensePool associated with this DataSource",
)
parser.add_argument(
- 'identifier_strings',
- help='A specific identifier to process.',
- metavar='IDENTIFIER', nargs='*'
+ "identifier_strings",
+ help="A specific identifier to process.",
+ metavar="IDENTIFIER",
+ nargs="*",
)
return parser
@classmethod
def parse_identifier_list(
- cls, _db, identifier_type, data_source, arguments, autocreate=False
+ cls, _db, identifier_type, data_source, arguments, autocreate=False
):
"""Turn a list of identifiers into a list of Identifier objects.
@@ -520,16 +503,21 @@ def parse_identifier_list(
identifiers = []
if not identifier_type:
- raise ValueError("No identifier type specified! Use '--identifier-type=\"Database ID\"' to name identifiers by database ID.")
+ raise ValueError(
+ "No identifier type specified! Use '--identifier-type=\"Database ID\"' to name identifiers by database ID."
+ )
if len(arguments) == 0:
if data_source:
- identifiers = _db.query(Identifier).\
- join(Identifier.licensed_through).\
- filter(
- Identifier.type==identifier_type,
- LicensePool.data_source==data_source
- ).all()
+ identifiers = (
+ _db.query(Identifier)
+ .join(Identifier.licensed_through)
+ .filter(
+ Identifier.type == identifier_type,
+ LicensePool.data_source == data_source,
+ )
+ .all()
+ )
return identifiers
for arg in arguments:
@@ -546,9 +534,7 @@ def parse_identifier_list(
_db, identifier_type, arg, autocreate=autocreate
)
if not identifier:
- logging.warn(
- "Could not load identifier %s/%s", identifier_type, arg
- )
+ logging.warn("Could not load identifier %s/%s", identifier_type, arg)
if identifier:
identifiers.append(identifier)
return identifiers
@@ -556,23 +542,24 @@ def parse_identifier_list(
class LibraryInputScript(InputScript):
"""A script that operates on one or more Libraries."""
+
@classmethod
- def parse_command_line(cls, _db=None, cmd_args=None,
- *args, **kwargs):
+ def parse_command_line(cls, _db=None, cmd_args=None, *args, **kwargs):
parser = cls.arg_parser(_db)
parsed = parser.parse_args(cmd_args)
return cls.look_up_libraries(_db, parsed, *args, **kwargs)
@classmethod
- def arg_parser(cls, _db, multiple_libraries = True):
+ def arg_parser(cls, _db, multiple_libraries=True):
parser = argparse.ArgumentParser()
library_names = sorted(l.short_name for l in _db.query(Library))
library_names = '"' + '", "'.join(library_names) + '"'
parser.add_argument(
- 'libraries',
- help='Name of a specific library to process. Libraries on this system: %s' % library_names,
- metavar='SHORT_NAME',
- nargs='*' if multiple_libraries else 1
+ "libraries",
+ help="Name of a specific library to process. Libraries on this system: %s"
+ % library_names,
+ metavar="SHORT_NAME",
+ nargs="*" if multiple_libraries else 1,
)
return parser
@@ -613,7 +600,7 @@ def parse_library_list(cls, _db, arguments):
continue
for field in (Library.short_name, Library.name):
try:
- library = _db.query(Library).filter(field==arg).one()
+ library = _db.query(Library).filter(field == arg).one()
except NoResultFound:
continue
except MultipleResultsFound:
@@ -622,9 +609,7 @@ def parse_library_list(cls, _db, arguments):
libraries.append(library)
break
else:
- logging.warn(
- "Could not find library %s", arg
- )
+ logging.warn("Could not find library %s", arg)
return libraries
def do_run(self, *args, **kwargs):
@@ -643,22 +628,26 @@ class PatronInputScript(LibraryInputScript):
"""A script that operates on one or more Patrons."""
@classmethod
- def parse_command_line(cls, _db=None, cmd_args=None, stdin=sys.stdin,
- *args, **kwargs):
+ def parse_command_line(
+ cls, _db=None, cmd_args=None, stdin=sys.stdin, *args, **kwargs
+ ):
parser = cls.arg_parser(_db)
parsed = parser.parse_args(cmd_args)
if stdin:
stdin = cls.read_stdin_lines(stdin)
- parsed = super(PatronInputScript, cls).look_up_libraries(_db, parsed, *args, **kwargs)
+ parsed = super(PatronInputScript, cls).look_up_libraries(
+ _db, parsed, *args, **kwargs
+ )
return cls.look_up_patrons(_db, parsed, stdin, *args, **kwargs)
@classmethod
def arg_parser(cls, _db):
parser = super(PatronInputScript, cls).arg_parser(_db, multiple_libraries=False)
parser.add_argument(
- 'identifiers',
- help='A specific patron identifier to process.',
- metavar='IDENTIFIER', nargs='+'
+ "identifiers",
+ help="A specific patron identifier to process.",
+ metavar="IDENTIFIER",
+ nargs="+",
)
return parser
@@ -671,9 +660,7 @@ def look_up_patrons(cls, _db, parsed, stdin_patron_strings, *args, **kwargs):
patron_strings = parsed.identifiers
library = parsed.libraries[0]
if stdin_patron_strings:
- patron_strings = (
- patron_strings + stdin_patron_strings
- )
+ patron_strings = patron_strings + stdin_patron_strings
parsed.patrons = cls.parse_patron_list(
_db, library, patron_strings, *args, **kwargs
)
@@ -697,13 +684,18 @@ def parse_patron_list(cls, _db, library, arguments):
for arg in arguments:
if not arg:
continue
- for field in (Patron.authorization_identifier, Patron.username,
- Patron.external_identifier):
+ for field in (
+ Patron.authorization_identifier,
+ Patron.username,
+ Patron.external_identifier,
+ ):
try:
- patron = _db.query(Patron)\
- .filter(field==arg)\
- .filter(Patron.library_id==library.id)\
+ patron = (
+ _db.query(Patron)
+ .filter(field == arg)
+ .filter(Patron.library_id == library.id)
.one()
+ )
except NoResultFound:
continue
except MultipleResultsFound:
@@ -712,9 +704,7 @@ def parse_patron_list(cls, _db, library, arguments):
patrons.append(patron)
break
else:
- logging.warn(
- "Could not find patron %s", arg
- )
+ logging.warn("Could not find patron %s", arg)
return patrons
def do_run(self, *args, **kwargs):
@@ -734,6 +724,7 @@ class LaneSweeperScript(LibraryInputScript):
def process_library(self, library):
from .lane import WorkList
+
top_level = WorkList.top_level_for_library(self._db, library)
queue = [top_level]
while queue:
@@ -754,11 +745,12 @@ def should_process_lane(self, lane):
def process_lane(self, lane):
pass
+
class CustomListSweeperScript(LibraryInputScript):
"""Do something to each custom list in a library."""
def process_library(self, library):
- lists = self._db.query(CustomList).filter(CustomList.library_id==library.id)
+ lists = self._db.query(CustomList).filter(CustomList.library_id == library.id)
for l in lists:
self.process_custom_list(l)
self._db.commit()
@@ -778,13 +770,10 @@ class SubjectInputScript(Script):
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
+ parser.add_argument("--subject-type", help="Process subjects of this type")
parser.add_argument(
- '--subject-type',
- help='Process subjects of this type'
- )
- parser.add_argument(
- '--subject-filter',
- help='Process subjects whose names or identifiers match this substring'
+ "--subject-filter",
+ help="Process subjects whose names or identifiers match this substring",
)
return parser
@@ -796,14 +785,13 @@ class RunCoverageProviderScript(IdentifierInputScript):
def arg_parser(cls):
parser = IdentifierInputScript.arg_parser()
parser.add_argument(
- '--cutoff-time',
- help='Update existing coverage records if they were originally created after this time.'
+ "--cutoff-time",
+ help="Update existing coverage records if they were originally created after this time.",
)
return parser
@classmethod
- def parse_command_line(cls, _db, cmd_args=None, stdin=sys.stdin,
- *args, **kwargs):
+ def parse_command_line(cls, _db, cmd_args=None, stdin=sys.stdin, *args, **kwargs):
parser = cls.arg_parser()
parsed = parser.parse_args(cmd_args)
stdin = cls.read_stdin_lines(stdin)
@@ -812,7 +800,9 @@ def parse_command_line(cls, _db, cmd_args=None, stdin=sys.stdin,
parsed.cutoff_time = cls.parse_time(parsed.cutoff_time)
return parsed
- def __init__(self, provider, _db=None, cmd_args=None, *provider_args, **provider_kwargs):
+ def __init__(
+ self, provider, _db=None, cmd_args=None, *provider_args, **provider_kwargs
+ ):
super(RunCoverageProviderScript, self).__init__(_db)
parsed_args = self.parse_command_line(self._db, cmd_args)
@@ -833,14 +823,11 @@ def __init__(self, provider, _db=None, cmd_args=None, *provider_args, **provider
kwargs.update(provider_kwargs)
provider = provider(
- self._db, *provider_args,
- cutoff_time=parsed_args.cutoff_time,
- **kwargs
+ self._db, *provider_args, cutoff_time=parsed_args.cutoff_time, **kwargs
)
self.provider = provider
self.name = self.provider.service_name
-
def extract_additional_command_line_arguments(self):
"""A hook method for subclasses.
@@ -851,31 +838,32 @@ def extract_additional_command_line_arguments(self):
(as opposed to WorkCoverageProvider).
"""
return {
- "input_identifiers" : self.identifiers,
+ "input_identifiers": self.identifiers,
}
-
def do_run(self):
if self.identifiers:
self.provider.run_on_specific_identifiers(self.identifiers)
else:
self.provider.run()
+
class ShowLibrariesScript(Script):
"""Show information about the libraries on a server."""
name = "List the libraries on this server."
+
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--short-name',
- help='Only display information for the library with the given short name',
+ "--short-name",
+ help="Only display information for the library with the given short name",
)
parser.add_argument(
- '--show-secrets',
- help='Print out secrets associated with the library.',
- action='store_true'
+ "--show-secrets",
+ help="Print out secrets associated with the library.",
+ action="store_true",
)
return parser
@@ -883,43 +871,33 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
_db = _db or self._db
args = self.parse_command_line(_db, cmd_args=cmd_args)
if args.short_name:
- library = get_one(
- _db, Library, short_name=args.short_name
- )
+ library = get_one(_db, Library, short_name=args.short_name)
libraries = [library]
else:
libraries = _db.query(Library).order_by(Library.name).all()
if not libraries:
output.write("No libraries found.\n")
for library in libraries:
- output.write(
- "\n".join(
- library.explain(
- include_secrets=args.show_secrets
- )
- )
- )
+ output.write("\n".join(library.explain(include_secrets=args.show_secrets)))
output.write("\n")
class ConfigurationSettingScript(Script):
-
@classmethod
def _parse_setting(self, setting):
"""Parse a command-line setting option into a key-value pair."""
- if not '=' in setting:
+ if not "=" in setting:
raise ValueError(
- 'Incorrect format for setting: "%s". Should be "key=value"'
- % setting
+ 'Incorrect format for setting: "%s". Should be "key=value"' % setting
)
- return setting.split('=', 1)
+ return setting.split("=", 1)
@classmethod
def add_setting_argument(self, parser, help):
"""Modify an ArgumentParser to indicate that the script takes
command-line settings.
"""
- parser.add_argument('--setting', help=help, action="append")
+ parser.add_argument("--setting", help=help, action="append")
def apply_settings(self, settings, obj):
"""Treat `settings` as a list of command-line argument settings,
@@ -939,27 +917,27 @@ def __init__(self, _db=None, config=Configuration):
self.config = config
super(ConfigureSiteScript, self).__init__(_db=_db)
-
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--show-secrets',
+ "--show-secrets",
help="Include secrets when displaying site settings.",
action="store_true",
- default=False
+ default=False,
)
cls.add_setting_argument(
parser,
- 'Set a site-wide setting, such as default_nongrouped_feed_max_age. Format: --setting="default_nongrouped_feed_max_age=1200"'
+ 'Set a site-wide setting, such as default_nongrouped_feed_max_age. Format: --setting="default_nongrouped_feed_max_age=1200"',
)
parser.add_argument(
- '--force',
+ "--force",
help="Set a site-wide setting even if the key isn't a known setting.",
- dest='force', action='store_true'
+ dest="force",
+ action="store_true",
)
return parser
@@ -970,7 +948,9 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
if args.setting:
for setting in args.setting:
key, value = self._parse_setting(setting)
- if not args.force and not key in [s.get("key") for s in self.config.SITEWIDE_SETTINGS]:
+ if not args.force and not key in [
+ s.get("key") for s in self.config.SITEWIDE_SETTINGS
+ ]:
raise ValueError(
"'%s' is not a known site-wide setting. Use --force to set it anyway."
% key
@@ -978,27 +958,29 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
else:
ConfigurationSetting.sitewide(_db, key).value = value
output.write(
- "\n".join(ConfigurationSetting.explain(
- _db, include_secrets=args.show_secrets
- ))
+ "\n".join(
+ ConfigurationSetting.explain(_db, include_secrets=args.show_secrets)
+ )
)
site_configuration_has_changed(_db)
_db.commit()
+
class ConfigureLibraryScript(ConfigurationSettingScript):
"""Create a library or change its settings."""
+
name = "Change a library's settings"
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--name',
- help='Official name of the library',
+ "--name",
+ help="Official name of the library",
)
parser.add_argument(
- '--short-name',
- help='Short name of the library',
+ "--short-name",
+ help="Short name of the library",
)
cls.add_setting_argument(
parser,
@@ -1010,9 +992,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
_db = _db or self._db
args = self.parse_command_line(_db, cmd_args=cmd_args)
if not args.short_name:
- raise ValueError(
- "You must identify the library by its short name."
- )
+ raise ValueError("You must identify the library by its short name.")
# Are we talking about an existing library?
libraries = _db.query(Library).all()
@@ -1025,10 +1005,12 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
else:
# No existing library. Make one.
library, ignore = get_one_or_create(
- _db, Library, create_method_kwargs=dict(
+ _db,
+ Library,
+ create_method_kwargs=dict(
uuid=str(uuid.uuid4()),
short_name=args.short_name,
- )
+ ),
)
if args.name:
@@ -1047,17 +1029,18 @@ class ShowCollectionsScript(Script):
"""Show information about the collections on a server."""
name = "List the collections on this server."
+
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--name',
- help='Only display information for the collection with the given name',
+ "--name",
+ help="Only display information for the collection with the given name",
)
parser.add_argument(
- '--show-secrets',
- help='Display secret values such as passwords.',
- action='store_true'
+ "--show-secrets",
+ help="Display secret values such as passwords.",
+ action="store_true",
)
return parser
@@ -1070,9 +1053,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
if collection:
collections = [collection]
else:
- output.write(
- "Could not locate collection by name: %s" % name
- )
+ output.write("Could not locate collection by name: %s" % name)
collections = []
else:
collections = _db.query(Collection).order_by(Collection.name).all()
@@ -1080,9 +1061,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
output.write("No collections found.\n")
for collection in collections:
output.write(
- "\n".join(
- collection.explain(include_secrets=args.show_secrets)
- )
+ "\n".join(collection.explain(include_secrets=args.show_secrets))
)
output.write("\n")
@@ -1091,17 +1070,18 @@ class ShowIntegrationsScript(Script):
"""Show information about the external integrations on a server."""
name = "List the external integrations on this server."
+
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--name',
- help='Only display information for the integration with the given name or ID',
+ "--name",
+ help="Only display information for the integration with the given name or ID",
)
parser.add_argument(
- '--show-secrets',
- help='Display secret values such as passwords.',
- action='store_true'
+ "--show-secrets",
+ help="Display secret values such as passwords.",
+ action="store_true",
)
return parser
@@ -1116,26 +1096,26 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
if integration:
integrations = [integration]
else:
- output.write(
- "Could not locate integration by name or ID: %s\n" % args
- )
+ output.write("Could not locate integration by name or ID: %s\n" % args)
integrations = []
else:
- integrations = _db.query(ExternalIntegration).order_by(
- ExternalIntegration.name, ExternalIntegration.id).all()
+ integrations = (
+ _db.query(ExternalIntegration)
+ .order_by(ExternalIntegration.name, ExternalIntegration.id)
+ .all()
+ )
if not integrations:
output.write("No integrations found.\n")
for integration in integrations:
output.write(
- "\n".join(
- integration.explain(include_secrets=args.show_secrets)
- )
+ "\n".join(integration.explain(include_secrets=args.show_secrets))
)
output.write("\n")
class ConfigureCollectionScript(ConfigurationSettingScript):
"""Create a collection or change its settings."""
+
name = "Change a collection's settings"
@classmethod
@@ -1146,31 +1126,26 @@ def parse_command_line(cls, _db=None, cmd_args=None):
@classmethod
def arg_parser(cls, _db):
parser = argparse.ArgumentParser()
+ parser.add_argument("--name", help="Name of the collection", required=True)
parser.add_argument(
- '--name',
- help='Name of the collection',
- required=True
+ "--protocol",
+ help='Protocol to use to get the licenses. Possible values: "%s"'
+ % ('", "'.join(ExternalIntegration.LICENSE_PROTOCOLS)),
)
parser.add_argument(
- '--protocol',
- help='Protocol to use to get the licenses. Possible values: "%s"' % (
- '", "'.join(ExternalIntegration.LICENSE_PROTOCOLS)
- )
- )
- parser.add_argument(
- '--external-account-id',
+ "--external-account-id",
help='The ID of this collection according to the license source. Sometimes called a "library ID".',
)
parser.add_argument(
- '--url',
- help='Run the acquisition protocol against this URL.',
+ "--url",
+ help="Run the acquisition protocol against this URL.",
)
parser.add_argument(
- '--username',
+ "--username",
help='Use this username to authenticate with the license protocol. Sometimes called a "key".',
)
parser.add_argument(
- '--password',
+ "--password",
help='Use this password to authenticate with the license protocol. Sometimes called a "secret".',
)
cls.add_setting_argument(
@@ -1180,8 +1155,9 @@ def arg_parser(cls, _db):
library_names = cls._library_names(_db)
if library_names:
parser.add_argument(
- '--library',
- help='Associate this collection with the given library. Possible libraries: %s' % library_names,
+ "--library",
+ help="Associate this collection with the given library. Possible libraries: %s"
+ % library_names,
action="append",
)
@@ -1190,8 +1166,8 @@ def arg_parser(cls, _db):
@classmethod
def _library_names(self, _db):
"""Return a string that lists known library names."""
- library_names = [x.short_name for x in _db.query(
- Library).order_by(Library.short_name)
+ library_names = [
+ x.short_name for x in _db.query(Library).order_by(Library.short_name)
]
if library_names:
return '"' + '", "'.join(library_names) + '"'
@@ -1205,7 +1181,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
protocol = None
name = args.name
protocol = args.protocol
- collection = get_one(_db, Collection, Collection.name==name)
+ collection = get_one(_db, Collection, Collection.name == name)
if not collection:
if protocol:
collection, is_new = Collection.by_name_and_protocol(
@@ -1215,7 +1191,8 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
# We didn't find a Collection, and we don't have a protocol,
# so we can't create a new Collection.
raise ValueError(
- 'No collection called "%s". You can create it, but you must specify a protocol.' % name
+ 'No collection called "%s". You can create it, but you must specify a protocol.'
+ % name
)
integration = collection.external_integration
if protocol:
@@ -1231,7 +1208,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
integration.password = args.password
self.apply_settings(args.setting, integration)
- if hasattr(args, 'library'):
+ if hasattr(args, "library"):
for name in args.library:
library = get_one(_db, Library, short_name=name)
if not library:
@@ -1251,6 +1228,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
class ConfigureIntegrationScript(ConfigurationSettingScript):
"""Create a integration or change its settings."""
+
name = "Create a site-wide integration or change an integration's settings"
@classmethod
@@ -1262,22 +1240,24 @@ def parse_command_line(cls, _db=None, cmd_args=None):
def arg_parser(cls, _db):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--name',
- help='Name of the integration',
+ "--name",
+ help="Name of the integration",
)
parser.add_argument(
- '--id',
- help='ID of the integration, if it has no name',
+ "--id",
+ help="ID of the integration, if it has no name",
)
parser.add_argument(
- '--protocol', help='Protocol used by the integration.',
+ "--protocol",
+ help="Protocol used by the integration.",
)
parser.add_argument(
- '--goal', help='Goal of the integration',
+ "--goal",
+ help="Goal of the integration",
)
cls.add_setting_argument(
parser,
- 'Set a configuration value on the integration. Format: --setting="key=value"'
+ 'Set a configuration value on the integration. Format: --setting="key=value"',
)
return parser
@@ -1291,7 +1271,7 @@ def _integration(self, _db, id, name, protocol, goal):
integration = None
if id:
integration = get_one(
- _db, ExternalIntegration, ExternalIntegration.id==id
+ _db, ExternalIntegration, ExternalIntegration.id == id
)
if not integration:
raise ValueError("No integration with ID %s." % id)
@@ -1299,7 +1279,8 @@ def _integration(self, _db, id, name, protocol, goal):
integration = get_one(_db, ExternalIntegration, name=name)
if not integration and not (protocol and goal):
raise ValueError(
- 'No integration with name "%s". To create it, you must also provide protocol and goal.' % name
+ 'No integration with name "%s". To create it, you must also provide protocol and goal.'
+ % name
)
if not integration and (protocol and goal):
integration, is_new = get_one_or_create(
@@ -1332,12 +1313,13 @@ class ShowLanesScript(Script):
"""Show information about the lanes on a server."""
name = "List the lanes on this server."
+
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--id',
- help='Only display information for the lane with the given ID',
+ "--id",
+ help="Only display information for the lane with the given ID",
)
return parser
@@ -1350,24 +1332,20 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
if lane:
lanes = [lane]
else:
- output.write(
- "Could not locate lane with id: %s" % id
- )
+ output.write("Could not locate lane with id: %s" % id)
lanes = []
else:
lanes = _db.query(Lane).order_by(Lane.id).all()
if not lanes:
output.write("No lanes found.\n")
for lane in lanes:
- output.write(
- "\n".join(
- lane.explain()
- )
- )
+ output.write("\n".join(lane.explain()))
output.write("\n\n")
+
class ConfigureLaneScript(ConfigurationSettingScript):
"""Create a lane or change its settings."""
+
name = "Change a lane's settings"
@classmethod
@@ -1379,32 +1357,33 @@ def parse_command_line(cls, _db=None, cmd_args=None):
def arg_parser(cls, _db):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--id',
- help='ID of the lane, if editing an existing lane.',
+ "--id",
+ help="ID of the lane, if editing an existing lane.",
)
parser.add_argument(
- '--library-short-name',
- help='Short name of the library for this lane. Possible values: %s' % cls._library_names(_db),
+ "--library-short-name",
+ help="Short name of the library for this lane. Possible values: %s"
+ % cls._library_names(_db),
)
parser.add_argument(
- '--parent-id',
+ "--parent-id",
help="The ID of this lane's parent lane",
)
parser.add_argument(
- '--priority',
+ "--priority",
help="The lane's priority",
)
parser.add_argument(
- '--display-name',
- help='The lane name that will be displayed to patrons.',
+ "--display-name",
+ help="The lane name that will be displayed to patrons.",
)
return parser
@classmethod
def _library_names(self, _db):
"""Return a string that lists known library names."""
- library_names = [x.short_name for x in _db.query(
- Library).order_by(Library.short_name)
+ library_names = [
+ x.short_name for x in _db.query(Library).order_by(Library.short_name)
]
if library_names:
return '"' + '", "'.join(library_names) + '"'
@@ -1421,7 +1400,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
if args.library_short_name:
library = get_one(_db, Library, short_name=args.library_short_name)
if not library:
- raise ValueError("No such library: \"%s\"." % args.library_short_name)
+ raise ValueError('No such library: "%s".' % args.library_short_name)
lane, is_new = create(_db, Lane, library=library)
else:
raise ValueError("Library short name is required to create a new lane.")
@@ -1438,6 +1417,7 @@ def do_run(self, _db=None, cmd_args=None, output=sys.stdout):
output.write("\n".join(lane.explain()))
output.write("\n")
+
class AddClassificationScript(IdentifierInputScript):
name = "Add a classification to an identifier"
@@ -1445,42 +1425,39 @@ class AddClassificationScript(IdentifierInputScript):
def arg_parser(cls):
parser = IdentifierInputScript.arg_parser()
parser.add_argument(
- '--subject-type',
- help='The type of the subject to add to each identifier.',
- required=True
+ "--subject-type",
+ help="The type of the subject to add to each identifier.",
+ required=True,
)
parser.add_argument(
- '--subject-identifier',
- help='The identifier of the subject to add to each identifier.'
+ "--subject-identifier",
+ help="The identifier of the subject to add to each identifier.",
)
parser.add_argument(
- '--subject-name',
- help='The name of the subject to add to each identifier.'
+ "--subject-name", help="The name of the subject to add to each identifier."
)
parser.add_argument(
- '--data-source',
- help='The data source to use when classifying.',
- default=DataSource.MANUAL
+ "--data-source",
+ help="The data source to use when classifying.",
+ default=DataSource.MANUAL,
)
parser.add_argument(
- '--weight',
- help='The weight to use when classifying.',
+ "--weight",
+ help="The weight to use when classifying.",
type=int,
- default=1000
+ default=1000,
)
parser.add_argument(
- '--create-subject',
+ "--create-subject",
help="Add the subject to the database if it doesn't already exist",
- action='store_const',
- const=True
+ action="store_const",
+ const=True,
)
return parser
def __init__(self, _db=None, cmd_args=None, stdin=sys.stdin):
super(AddClassificationScript, self).__init__(_db=_db)
- args = self.parse_command_line(
- self._db, cmd_args=cmd_args, stdin=stdin
- )
+ args = self.parse_command_line(self._db, cmd_args=cmd_args, stdin=stdin)
self.identifier_type = args.identifier_type
self.identifiers = args.identifiers
subject_type = args.subject_type
@@ -1493,8 +1470,11 @@ def __init__(self, _db=None, cmd_args=None, stdin=sys.stdin):
self.data_source = DataSource.lookup(self._db, args.data_source)
self.weight = args.weight
self.subject, ignore = Subject.lookup(
- self._db, subject_type, subject_identifier, subject_name,
- autocreate=args.create_subject
+ self._db,
+ subject_type,
+ subject_identifier,
+ subject_name,
+ autocreate=args.create_subject,
)
def do_run(self):
@@ -1513,9 +1493,11 @@ def do_run(self):
if self.subject:
for identifier in self.identifiers:
identifier.classify(
- self.data_source, self.subject.type,
- self.subject.identifier, self.subject.name,
- self.weight
+ self.data_source,
+ self.subject.type,
+ self.subject.identifier,
+ self.subject.name,
+ self.weight,
)
work = identifier.work
if work:
@@ -1528,26 +1510,26 @@ class WorkProcessingScript(IdentifierInputScript):
name = "Work processing script"
- def __init__(self, force=False, batch_size=10, _db=None,
- cmd_args=None, stdin=sys.stdin
+ def __init__(
+ self, force=False, batch_size=10, _db=None, cmd_args=None, stdin=sys.stdin
):
super(WorkProcessingScript, self).__init__(_db=_db)
- args = self.parse_command_line(
- self._db, cmd_args=cmd_args, stdin=stdin
- )
+ args = self.parse_command_line(self._db, cmd_args=cmd_args, stdin=stdin)
self.identifier_type = args.identifier_type
self.data_source = args.identifier_data_source
self.identifiers = self.parse_identifier_list(
- self._db, self.identifier_type, self.data_source,
- args.identifier_strings
+ self._db, self.identifier_type, self.data_source, args.identifier_strings
)
self.batch_size = batch_size
self.query = self.make_query(
- self._db, self.identifier_type, self.identifiers, self.data_source,
- log=self.log
+ self._db,
+ self.identifier_type,
+ self.identifiers,
+ self.data_source,
+ log=self.log,
)
self.force = force
@@ -1555,37 +1537,27 @@ def __init__(self, force=False, batch_size=10, _db=None,
def make_query(cls, _db, identifier_type, identifiers, data_source, log=None):
query = _db.query(Work)
if identifiers or identifier_type:
- query = query.join(Work.license_pools).join(
- LicensePool.identifier
- )
+ query = query.join(Work.license_pools).join(LicensePool.identifier)
if identifiers:
if log:
- log.info(
- 'Restricted to %d specific identifiers.' % len(identifiers)
- )
+ log.info("Restricted to %d specific identifiers." % len(identifiers))
query = query.filter(
LicensePool.identifier_id.in_([x.id for x in identifiers])
)
elif data_source:
if log:
- log.info(
- 'Restricted to identifiers from DataSource "%s".', data_source
- )
+ log.info('Restricted to identifiers from DataSource "%s".', data_source)
source = DataSource.lookup(_db, data_source)
- query = query.filter(LicensePool.data_source==source)
+ query = query.filter(LicensePool.data_source == source)
if identifier_type:
if log:
- log.info(
- 'Restricted to identifier type "%s".' % identifier_type
- )
- query = query.filter(Identifier.type==identifier_type)
+ log.info('Restricted to identifier type "%s".' % identifier_type)
+ query = query.filter(Identifier.type == identifier_type)
if log:
- log.info(
- "Processing %d works.", query.count()
- )
+ log.info("Processing %d works.", query.count())
return query.order_by(Work.id)
def do_run(self):
@@ -1619,7 +1591,7 @@ def make_query(self, _db, identifier_type, identifiers, data_source, log=None):
# We actually process LicensePools, not Works.
qu = _db.query(LicensePool).join(LicensePool.identifier)
if identifier_type:
- qu = qu.filter(Identifier.type==identifier_type)
+ qu = qu.filter(Identifier.type == identifier_type)
if identifiers:
qu = qu.filter(
Identifier.identifier.in_([x.identifier for x in identifiers])
@@ -1634,12 +1606,12 @@ def process_work(self, work):
def do_run(self):
super(WorkConsolidationScript, self).do_run()
- qu = self._db.query(Work).outerjoin(Work.license_pools).filter(
- LicensePool.id==None
- )
- self.log.info(
- "Deleting %d Works that have no LicensePools." % qu.count()
+ qu = (
+ self._db.query(Work)
+ .outerjoin(Work.license_pools)
+ .filter(LicensePool.id == None)
)
+ self.log.info("Deleting %d Works that have no LicensePools." % qu.count())
for i in qu:
self._db.delete(i)
self._db.commit()
@@ -1658,9 +1630,9 @@ def process_work(self, work):
class WorkClassificationScript(WorkPresentationScript):
- """Recalculate the classification--and nothing else--for Work objects.
- """
- name = "Recalculate the classification for works that need it."""
+ """Recalculate the classification--and nothing else--for Work objects."""
+
+ name = "Recalculate the classification for works that need it." ""
policy = PresentationCalculationPolicy(
choose_edition=False,
@@ -1683,7 +1655,7 @@ class ReclassifyWorksForUncheckedSubjectsScript(WorkClassificationScript):
Subjects because the rules for processing them changed.
"""
- name = "Reclassify works that use unchecked subjects."""
+ name = "Reclassify works that use unchecked subjects." ""
policy = WorkClassificationScript.policy
@@ -1723,22 +1695,26 @@ class CustomListManagementScript(Script):
MembershipManager.
"""
- def __init__(self, manager_class,
- data_source_name, list_identifier, list_name,
- primary_language, description,
- **manager_kwargs
- ):
+ def __init__(
+ self,
+ manager_class,
+ data_source_name,
+ list_identifier,
+ list_name,
+ primary_language,
+ description,
+ **manager_kwargs,
+ ):
data_source = DataSource.lookup(self._db, data_source_name)
self.custom_list, is_new = get_one_or_create(
- self._db, CustomList,
+ self._db,
+ CustomList,
data_source_id=data_source.id,
foreign_identifier=list_identifier,
)
self.custom_list.primary_language = primary_language
self.custom_list.description = description
- self.membership_manager = manager_class(
- self.custom_list, **manager_kwargs
- )
+ self.membership_manager = manager_class(self.custom_list, **manager_kwargs)
def run(self):
self.membership_manager.update()
@@ -1746,9 +1722,9 @@ def run(self):
class CollectionType(Enum):
- OPEN_ACCESS = 'OPEN_ACCESS'
- PROTECTED_ACCESS = 'PROTECTED_ACCESS'
- LCP = 'LCP'
+ OPEN_ACCESS = "OPEN_ACCESS"
+ PROTECTED_ACCESS = "PROTECTED_ACCESS"
+ LCP = "LCP"
def __str__(self):
return self.name
@@ -1780,30 +1756,32 @@ def look_up_collections(cls, _db, parsed, *args, **kwargs):
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '--collection',
- help='Collection to use',
- dest='collection_names',
- metavar='NAME', action='append', default=[]
+ "--collection",
+ help="Collection to use",
+ dest="collection_names",
+ metavar="NAME",
+ action="append",
+ default=[],
)
parser.add_argument(
- '--collection-type',
- help='Collection type. Valid values are: OPEN_ACCESS (default), PROTECTED_ACCESS.',
+ "--collection-type",
+ help="Collection type. Valid values are: OPEN_ACCESS (default), PROTECTED_ACCESS.",
type=CollectionType,
choices=list(CollectionType),
- default=CollectionType.OPEN_ACCESS
+ default=CollectionType.OPEN_ACCESS,
)
return parser
class CollectionArgumentsScript(CollectionInputScript):
-
@classmethod
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- 'collection_names',
- help='One or more collection names.',
- metavar='COLLECTION', nargs='*'
+ "collection_names",
+ help="One or more collection names.",
+ metavar="COLLECTION",
+ nargs="*",
)
return parser
@@ -1832,8 +1810,8 @@ def __init__(self, monitor_class, _db=None, cmd_args=None, **kwargs):
self.monitor_class = monitor_class
self.name = self.monitor_class.SERVICE_NAME
parsed = vars(self.parse_command_line(self._db, cmd_args=cmd_args))
- parsed.pop('collection_names', None)
- self.collections = parsed.pop('collections', None)
+ parsed.pop("collection_names", None)
+ self.collections = parsed.pop("collections", None)
self.kwargs.update(parsed)
def monitors(self, **kwargs):
@@ -1849,8 +1827,15 @@ class OPDSImportScript(CollectionInputScript):
MONITOR_CLASS = OPDSImportMonitor
PROTOCOL = ExternalIntegration.OPDS_IMPORT
- def __init__(self, _db=None, importer_class=None, monitor_class=None,
- protocol=None, *args, **kwargs):
+ def __init__(
+ self,
+ _db=None,
+ importer_class=None,
+ monitor_class=None,
+ protocol=None,
+ *args,
+ **kwargs,
+ ):
super(OPDSImportScript, self).__init__(_db, *args, **kwargs)
self.importer_class = importer_class or self.IMPORTER_CLASS
self.monitor_class = monitor_class or self.MONITOR_CLASS
@@ -1861,22 +1846,28 @@ def __init__(self, _db=None, importer_class=None, monitor_class=None,
def arg_parser(cls):
parser = CollectionInputScript.arg_parser()
parser.add_argument(
- '--force',
- help='Import the feed from scratch, even if it seems like it was already imported.',
- dest='force', action='store_true'
+ "--force",
+ help="Import the feed from scratch, even if it seems like it was already imported.",
+ dest="force",
+ action="store_true",
)
return parser
def do_run(self, cmd_args=None):
parsed = self.parse_command_line(self._db, cmd_args=cmd_args)
- collections = parsed.collections or Collection.by_protocol(self._db, self.protocol)
+ collections = parsed.collections or Collection.by_protocol(
+ self._db, self.protocol
+ )
for collection in collections:
self.run_monitor(collection, force=parsed.force)
def run_monitor(self, collection, force=None):
monitor = self.monitor_class(
- self._db, collection, import_class=self.importer_class,
- force_reimport=force, **self.importer_kwargs
+ self._db,
+ collection,
+ import_class=self.importer_class,
+ force_reimport=force,
+ **self.importer_kwargs,
)
monitor.run()
@@ -1898,10 +1889,14 @@ def do_run(self, cmd_args=None):
collections = self._db.query(Collection).all()
# But only process collections that have an associated MirrorUploader.
- for collection, policy in self.collections_with_uploader(collections, collection_type):
+ for collection, policy in self.collections_with_uploader(
+ collections, collection_type
+ ):
self.process_collection(collection, policy)
- def collections_with_uploader(self, collections, collection_type=CollectionType.OPEN_ACCESS):
+ def collections_with_uploader(
+ self, collections, collection_type=CollectionType.OPEN_ACCESS
+ ):
"""Filter out collections that have no MirrorUploader.
:yield: 2-tuples (Collection, ReplacementPolicy). The
@@ -1912,25 +1907,21 @@ def collections_with_uploader(self, collections, collection_type=CollectionType.
covers = MirrorUploader.for_collection(
collection, ExternalIntegrationLink.COVERS
)
- books_mirror_type = \
- ExternalIntegrationLink.OPEN_ACCESS_BOOKS \
- if collection_type == CollectionType.OPEN_ACCESS \
+ books_mirror_type = (
+ ExternalIntegrationLink.OPEN_ACCESS_BOOKS
+ if collection_type == CollectionType.OPEN_ACCESS
else ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS
- books = MirrorUploader.for_collection(
- collection,
- books_mirror_type
)
+ books = MirrorUploader.for_collection(collection, books_mirror_type)
if covers or books:
mirrors = {
ExternalIntegrationLink.COVERS: covers,
- books_mirror_type: books
+ books_mirror_type: books,
}
policy = self.replacement_policy(mirrors)
yield collection, policy
else:
- self.log.info(
- "Skipping %r as it has no MirrorUploader.", collection
- )
+ self.log.info("Skipping %r as it has no MirrorUploader.", collection)
@classmethod
def replacement_policy(cls, mirrors):
@@ -1938,7 +1929,8 @@ def replacement_policy(cls, mirrors):
given mirrors.
"""
return ReplacementPolicy(
- mirrors=mirrors, link_content=True,
+ mirrors=mirrors,
+ link_content=True,
even_if_not_apparently_updated=True,
http_get=Representation.cautious_http_get,
)
@@ -1979,9 +1971,9 @@ def derive_rights_status(cls, license_pool, resource):
# this particular resource, but if every
# LicensePoolDeliveryMechanism has the same rights
# status, we can assume it's that one.
- statuses = list(set([
- x.rights_status for x in license_pool.delivery_mechanisms
- ]))
+ statuses = list(
+ set([x.rights_status for x in license_pool.delivery_mechanisms])
+ )
if len(statuses) == 1:
[rights_status] = statuses
if rights_status:
@@ -1994,9 +1986,12 @@ def process_item(self, collection, link_obj, policy):
"""
identifier = link_obj.identifier
license_pool, ignore = LicensePool.for_foreign_id(
- self._db, collection.data_source,
- identifier.type, identifier.identifier,
- collection=collection, autocreate=False
+ self._db,
+ collection.data_source,
+ identifier.type,
+ identifier.identifier,
+ collection=collection,
+ autocreate=False,
)
if not license_pool:
# This shouldn't happen.
@@ -2010,7 +2005,8 @@ def process_item(self, collection, link_obj, policy):
rights_status = self.derive_rights_status(license_pool, resource)
if not rights_status:
self.log.warn(
- "Could not unambiguously determine rights status for %r, skipping.", link_obj
+ "Could not unambiguously determine rights status for %r, skipping.",
+ link_obj,
)
return
else:
@@ -2021,15 +2017,16 @@ def process_item(self, collection, link_obj, policy):
# Mock up a LinkData that MetaToModelUtility can use to
# mirror this link (or decide not to mirror it).
linkdata = LinkData(
- rel=link_obj.rel,
- href=resource.url,
- rights_uri=rights_status
+ rel=link_obj.rel, href=resource.url, rights_uri=rights_status
)
# Mirror the link (or not).
self.MIRROR_UTILITY.mirror_link(
- model_object=license_pool, data_source=collection.data_source,
- link=linkdata, link_obj=link_obj, policy=policy
+ model_object=license_pool,
+ data_source=collection.data_source,
+ link=linkdata,
+ link_obj=link_obj,
+ policy=policy,
)
@@ -2051,10 +2048,10 @@ class DatabaseMigrationScript(Script):
MIGRATION_WITH_COUNTER = re.compile("\d{8}-(\d+)-(.)+\.(py|sql)")
# There are some SQL commands that can't be run inside a transaction.
- TRANSACTIONLESS_COMMANDS = ['alter type']
+ TRANSACTIONLESS_COMMANDS = ["alter type"]
- TRANSACTION_PER_STATEMENT = 'SIMPLYE_MIGRATION_TRANSACTION_PER_STATEMENT'
- DO_NOT_EXECUTE = 'SIMPLYE_MIGRATION_DO_NOT_EXECUTE'
+ TRANSACTION_PER_STATEMENT = "SIMPLYE_MIGRATION_TRANSACTION_PER_STATEMENT"
+ DO_NOT_EXECUTE = "SIMPLYE_MIGRATION_DO_NOT_EXECUTE"
class TimestampInfo(object):
"""Act like a ORM Timestamp object, but with no database connection."""
@@ -2075,8 +2072,8 @@ def find(cls, script, service):
# 2.3.0 - 'timestamp' field renamed to 'finish'
exception = None
for sql in (
- "SELECT finish, counter FROM timestamps WHERE service=:service LIMIT 1;",
- "SELECT timestamp, counter FROM timestamps WHERE service=:service LIMIT 1;",
+ "SELECT finish, counter FROM timestamps WHERE service=:service LIMIT 1;",
+ "SELECT timestamp, counter FROM timestamps WHERE service=:service LIMIT 1;",
):
_db = script._db
try:
@@ -2092,7 +2089,9 @@ def find(cls, script, service):
# The database connection is now tainted; we must
# create a new one.
logging.error(
- "Got a database error obtaining the timestamp for %s. Hopefully the timestamps table itself must be migrated and this is all according to plan.", service, exc_info=e
+ "Got a database error obtaining the timestamp for %s. Hopefully the timestamps table itself must be migrated and this is all according to plan.",
+ service,
+ exc_info=e,
)
_db.close()
script._session = production_session(initialize_data=False)
@@ -2133,8 +2132,7 @@ def save(self, _db):
self.update(_db, self.finish, self.counter)
def update(self, _db, finish, counter, migration_name=None):
- """Saves a TimestampInfo object to the database.
- """
+ """Saves a TimestampInfo object to the database."""
# Reset values locally.
self.finish = to_utc(finish)
self.counter = counter
@@ -2144,14 +2142,16 @@ def update(self, _db, finish, counter, migration_name=None):
" where service=:service"
)
values = dict(
- finish=self.finish, counter=self.counter,
+ finish=self.finish,
+ counter=self.counter,
service=self.service,
)
_db.execute(text(sql), values)
_db.flush()
message = "%s Timestamp stamped at %s" % (
- self.service, self.finish.strftime('%Y-%m-%d')
+ self.service,
+ self.finish.strftime("%Y-%m-%d"),
)
if migration_name:
message += " for %s" % migration_name
@@ -2161,20 +2161,30 @@ def update(self, _db, finish, counter, migration_name=None):
def arg_parser(cls):
parser = argparse.ArgumentParser()
parser.add_argument(
- '-d', '--last-run-date',
- help=('A date string representing the last migration file '
- 'run against your database, formatted as YYYY-MM-DD')
+ "-d",
+ "--last-run-date",
+ help=(
+ "A date string representing the last migration file "
+ "run against your database, formatted as YYYY-MM-DD"
+ ),
)
parser.add_argument(
- '-c', '--last-run-counter', type=int,
- help=('An optional digit representing the counter of the last '
- 'migration run against your database. Only necessary if '
- 'multiple migrations were created on the same date.')
+ "-c",
+ "--last-run-counter",
+ type=int,
+ help=(
+ "An optional digit representing the counter of the last "
+ "migration run against your database. Only necessary if "
+ "multiple migrations were created on the same date."
+ ),
)
parser.add_argument(
- '--python-only', action='store_true',
- help=('Only run python migrations since the given timestamp or the'
- 'most recent python timestamp')
+ "--python-only",
+ action="store_true",
+ help=(
+ "Only run python migrations since the given timestamp or the"
+ "most recent python timestamp"
+ ),
)
return parser
@@ -2205,7 +2215,7 @@ def compare_migrations(first):
migrations are sorted by counter (asc).
"""
key = []
- if first.endswith('.py'):
+ if first.endswith(".py"):
key.append(1)
else:
key.append(-1)
@@ -2234,8 +2244,8 @@ def directories_by_priority(self):
and its container server, organized in priority order (core first)
"""
current_dir = os.path.split(os.path.abspath(__file__))[0]
- core = os.path.join(current_dir, 'migration')
- server = os.path.join(os.path.split(current_dir)[0], 'migration')
+ core = os.path.join(current_dir, "migration")
+ server = os.path.join(os.path.split(current_dir)[0], "migration")
# Core is listed first, since core makes changes to the core database
# schema. Server migrations generally fix bugs or otherwise update
@@ -2298,9 +2308,7 @@ def run(self, test_db=None, test=False, cmd_args=None):
last_run_date = parsed.last_run_date
last_run_counter = parsed.last_run_counter
if last_run_date:
- timestamp = self.TimestampInfo(
- self.name, last_run_date, last_run_counter
- )
+ timestamp = self.TimestampInfo(self.name, last_run_date, last_run_counter)
# Save the timestamp at this point. This will set back the clock
# in the case that the input last_run_date/counter is before the
# existing Timestamp.finish / Timestamp.counter.
@@ -2316,7 +2324,7 @@ def run(self, test_db=None, test=False, cmd_args=None):
if not timestamp or not self.overall_timestamp:
# There's no timestamp in the database! Raise an error.
print("")
- print (
+ print(
"NO TIMESTAMP FOUND. Either initialize your untouched database "
"with the script `core/bin/initialize_database` OR run this "
"script with a timestamp that indicates the last migration run "
@@ -2333,9 +2341,7 @@ def run(self, test_db=None, test=False, cmd_args=None):
print("%d new migrations found." % len(new_migrations))
for migration in new_migrations:
print(" - %s" % migration)
- self.run_migrations(
- new_migrations, migrations_by_dir, timestamp
- )
+ self.run_migrations(new_migrations, migrations_by_dir, timestamp)
self._db.commit()
else:
print("No new migrations found. Your database is up-to-date.")
@@ -2349,9 +2355,9 @@ def fetch_migration_files(self):
migrations = list()
migrations_by_dir = defaultdict(list)
- extensions = ['.py']
+ extensions = [".py"]
if not self.python_only:
- extensions.insert(0, '.sql')
+ extensions.insert(0, ".sql")
for directory in self.directories_by_priority:
# In the case of tests, the container server migration directory
@@ -2369,10 +2375,11 @@ def get_new_migrations(self, timestamp, migrations):
"""Return a list of migration filenames, representing migrations
created since the timestamp
"""
- last_run = timestamp.finish.strftime('%Y%m%d')
+ last_run = timestamp.finish.strftime("%Y%m%d")
migrations = self.sort_migrations(migrations)
- new_migrations = [migration for migration in migrations
- if int(migration[:8]) >= int(last_run)]
+ new_migrations = [
+ migration for migration in migrations if int(migration[:8]) >= int(last_run)
+ ]
# Multiple migrations run on the same day have an additional digit
# after the date and a dash, eg:
@@ -2402,17 +2409,17 @@ def _is_matching_migration(self, migration_file, timestamp):
is_match = False
is_after_timestamp = False
- timestamp_str = timestamp.finish.strftime('%Y%m%d')
+ timestamp_str = timestamp.finish.strftime("%Y%m%d")
counter = timestamp.counter
- if migration_file[:8]>=timestamp_str:
- if migration_file[:8]>timestamp_str:
+ if migration_file[:8] >= timestamp_str:
+ if migration_file[:8] > timestamp_str:
is_after_timestamp = True
elif counter:
count = self.MIGRATION_WITH_COUNTER.search(migration_file)
if count:
migration_num = int(count.groups()[0])
- if migration_num==counter:
+ if migration_num == counter:
is_match = True
if migration_num > counter:
is_after_timestamp = True
@@ -2448,7 +2455,7 @@ def raise_error(migration_path, message, code=1):
raise_error(
full_migration_path,
"Migration raised error code '%d'" % se.code,
- code=se.code
+ code=se.code,
)
# Sometimes a migration isn't relevant and it
@@ -2467,29 +2474,35 @@ def _run_migration(self, migration_path, timestamp):
migration_filename = os.path.split(migration_path)[1]
ok_to_execute = True
- if migration_path.endswith('.sql'):
+ if migration_path.endswith(".sql"):
with open(migration_path) as clause:
sql = clause.read()
- transactionless = any([c for c in self.TRANSACTIONLESS_COMMANDS if c in sql.lower()])
- one_tx_per_statement = bool(self.TRANSACTION_PER_STATEMENT.lower() in sql.lower())
+ transactionless = any(
+ [c for c in self.TRANSACTIONLESS_COMMANDS if c in sql.lower()]
+ )
+ one_tx_per_statement = bool(
+ self.TRANSACTION_PER_STATEMENT.lower() in sql.lower()
+ )
ok_to_execute = not bool(self.DO_NOT_EXECUTE.lower() in sql.lower())
if ok_to_execute:
if transactionless:
new_session = self._run_migration_without_transaction(sql)
elif one_tx_per_statement:
- commands = self._extract_statements_from_sql_file(migration_path)
+ commands = self._extract_statements_from_sql_file(
+ migration_path
+ )
for command in commands:
self._db.execute(f"BEGIN;{command}COMMIT;")
else:
# By wrapping the action in a transation, we can avoid
# rolling over errors and losing data in files
# with multiple interrelated SQL actions.
- sql = 'BEGIN;\n%s\nCOMMIT;' % sql
+ sql = "BEGIN;\n%s\nCOMMIT;" % sql
self._db.execute(sql)
- if migration_path.endswith('.py'):
+ if migration_path.endswith(".py"):
module_name = migration_filename[:-3]
subprocess.call(migration_path)
@@ -2508,20 +2521,20 @@ def _extract_statements_from_sql_file(self, filepath):
sql_file_lines = f.readlines()
sql_commands = []
- current_command = ''
+ current_command = ""
for line in sql_file_lines:
- if line.strip().startswith('--'):
+ if line.strip().startswith("--"):
continue
else:
- if current_command == '':
+ if current_command == "":
current_command = line.strip()
else:
- current_command = current_command + ' ' + line.strip()
+ current_command = current_command + " " + line.strip()
- if current_command.endswith(';'):
+ if current_command.endswith(";"):
sql_commands.append(current_command)
- current_command = ''
+ current_command = ""
return sql_commands
@@ -2538,15 +2551,18 @@ def _run_migration_without_transaction(self, sql_statement):
# In the case of 'ALTER TYPE' (at least), running commands
# simultaneously raises psycopg2.InternalError ending with 'cannot be
# executed from a fuction or multi-command string'
- sql_commands = [command.strip()+';'
- for command in sql_statement.split(';')
- if command.strip()]
+ sql_commands = [
+ command.strip() + ";"
+ for command in sql_statement.split(";")
+ if command.strip()
+ ]
# Run each command in the sql statement right up against the
# database: no transactions, no guardrails.
for command in sql_commands:
- connection.execution_options(isolation_level='AUTOCOMMIT')\
- .execute(text(command))
+ connection.execution_options(isolation_level="AUTOCOMMIT").execute(
+ text(command)
+ )
# Update the script's Session to a new one that has the changed schema
# and other important info.
@@ -2567,11 +2583,13 @@ def update_timestamps(self, migration_file):
if match:
counter = int(match.groups()[0])
- if migration_file.endswith('py') and self.python_timestamp:
+ if migration_file.endswith("py") and self.python_timestamp:
# This is a python migration. Update the python timestamp.
self.python_timestamp.update(
- self._db, finish=last_run_date,
- counter=counter, migration_name=migration_file
+ self._db,
+ finish=last_run_date,
+ counter=counter,
+ migration_name=migration_file,
)
# Nothing to update
@@ -2585,19 +2603,23 @@ def update_timestamps(self, migration_file):
return
# The dates of the scrips are the same so compare the counters
- if finish_timestamp==last_run_date:
+ if finish_timestamp == last_run_date:
# The current script has no counter, so it's the same script that ran
# or an earlier script that ran
if counter is None:
return
# The previous script has a higher counter
- if (self.overall_timestamp.counter is not None and
- self.overall_timestamp.counter > counter):
+ if (
+ self.overall_timestamp.counter is not None
+ and self.overall_timestamp.counter > counter
+ ):
return
self.overall_timestamp.update(
- self._db, finish=last_run_date,
- counter=counter, migration_name=migration_file
+ self._db,
+ finish=last_run_date,
+ counter=counter,
+ migration_name=migration_file,
)
@@ -2611,8 +2633,10 @@ class DatabaseMigrationInitializationScript(DatabaseMigrationScript):
def arg_parser(cls):
parser = super(DatabaseMigrationInitializationScript, cls).arg_parser()
parser.add_argument(
- '-f', '--force', action='store_true',
- help="Force reset the initialization, ignoring any existing timestamps."
+ "-f",
+ "--force",
+ action="store_true",
+ help="Force reset the initialization, ignoring any existing timestamps.",
)
return parser
@@ -2623,7 +2647,8 @@ def run(self, cmd_args=None):
if last_run_counter and not last_run_date:
raise ValueError(
- "Timestamp.counter must be reset alongside Timestamp.finish")
+ "Timestamp.counter must be reset alongside Timestamp.finish"
+ )
existing_timestamp = get_one(self._db, Timestamp, service=self.name)
if existing_timestamp and existing_timestamp.finish:
@@ -2632,21 +2657,28 @@ def run(self, cmd_args=None):
if parsed.force:
self.log.warn(
"Overwriting existing %s timestamp: %r",
- self.name, existing_timestamp)
+ self.name,
+ existing_timestamp,
+ )
else:
raise RuntimeError(
- "%s timestamp already exists: %r. Use --force to update." %
- (self.name, existing_timestamp))
+ "%s timestamp already exists: %r. Use --force to update."
+ % (self.name, existing_timestamp)
+ )
# Initialize the required timestamps with the Space Jam release date.
- init_timestamp = self.parse_time('1996-11-15')
+ init_timestamp = self.parse_time("1996-11-15")
overall_timestamp = existing_timestamp or Timestamp.stamp(
- _db=self._db, service=self.SERVICE_NAME,
- service_type=Timestamp.SCRIPT_TYPE, finish=init_timestamp
+ _db=self._db,
+ service=self.SERVICE_NAME,
+ service_type=Timestamp.SCRIPT_TYPE,
+ finish=init_timestamp,
)
python_timestamp = Timestamp.stamp(
- _db=self._db, service=self.PY_TIMESTAMP_SERVICE_NAME,
- service_type=Timestamp.SCRIPT_TYPE, finish=init_timestamp
+ _db=self._db,
+ service=self.PY_TIMESTAMP_SERVICE_NAME,
+ service_type=Timestamp.SCRIPT_TYPE,
+ finish=init_timestamp,
)
if last_run_date:
@@ -2658,8 +2690,8 @@ def run(self, cmd_args=None):
return
migrations = self.sort_migrations(self.fetch_migration_files()[0])
- py_migrations = [m for m in migrations if m.endswith('.py')]
- sql_migrations = [m for m in migrations if m.endswith('.sql')]
+ py_migrations = [m for m in migrations if m.endswith(".py")]
+ sql_migrations = [m for m in migrations if m.endswith(".sql")]
most_recent_sql_migration = sql_migrations[-1]
most_recent_python_migration = py_migrations[-1]
@@ -2670,7 +2702,7 @@ def run(self, cmd_args=None):
class CheckContributorNamesInDB(IdentifierInputScript):
- """ Checks that contributor sort_names are display_names in
+ """Checks that contributor sort_names are display_names in
"last name, comma, other names" format.
Read contributors edition by edition, so that can, if necessary,
@@ -2685,8 +2717,7 @@ class CheckContributorNamesInDB(IdentifierInputScript):
"""
COMPLAINT_SOURCE = "CheckContributorNamesInDB"
- COMPLAINT_TYPE = "http://librarysimplified.org/terms/problem/wrong-author";
-
+ COMPLAINT_TYPE = "http://librarysimplified.org/terms/problem/wrong-author"
def __init__(self, _db=None, cmd_args=None, stdin=sys.stdin):
super(CheckContributorNamesInDB, self).__init__(_db=_db)
@@ -2695,7 +2726,6 @@ def __init__(self, _db=None, cmd_args=None, stdin=sys.stdin):
_db=self._db, cmd_args=cmd_args, stdin=stdin
)
-
@classmethod
def make_query(self, _db, identifier_type, identifiers, log=None):
query = _db.query(Edition)
@@ -2707,31 +2737,27 @@ def make_query(self, _db, identifier_type, identifiers, log=None):
if identifiers:
if log:
- log.info(
- 'Restricted to %d specific identifiers.' % len(identifiers)
- )
+ log.info("Restricted to %d specific identifiers." % len(identifiers))
query = query.filter(
Edition.primary_identifier_id.in_([x.id for x in identifiers])
)
if identifier_type:
if log:
- log.info(
- 'Restricted to identifier type "%s".' % identifier_type
- )
- query = query.filter(Identifier.type==identifier_type)
+ log.info('Restricted to identifier type "%s".' % identifier_type)
+ query = query.filter(Identifier.type == identifier_type)
if log:
- log.info(
- "Processing %d editions.", query.count()
- )
+ log.info("Processing %d editions.", query.count())
return query.order_by(Edition.id)
-
def do_run(self, batch_size=10):
self.query = self.make_query(
- self._db, self.parsed_args.identifier_type, self.parsed_args.identifiers, self.log
+ self._db,
+ self.parsed_args.identifier_type,
+ self.parsed_args.identifiers,
+ self.log,
)
editions = True
@@ -2746,13 +2772,14 @@ def do_run(self, batch_size=10):
for edition in editions:
if edition.contributions:
for contribution in edition.contributions:
- self.process_contribution_local(self._db, contribution, self.log)
+ self.process_contribution_local(
+ self._db, contribution, self.log
+ )
offset += batch_size
self._db.commit()
self._db.commit()
-
def process_contribution_local(self, _db, contribution, log=None):
if not contribution or not contribution.edition:
return
@@ -2762,10 +2789,18 @@ def process_contribution_local(self, _db, contribution, log=None):
identifier = contribution.edition.primary_identifier
if contributor.sort_name and contributor.display_name:
- computed_sort_name_local_new = unicodedata.normalize("NFKD", str(display_name_to_sort_name(contributor.display_name)))
+ computed_sort_name_local_new = unicodedata.normalize(
+ "NFKD", str(display_name_to_sort_name(contributor.display_name))
+ )
# Did HumanName parser produce a differet result from the plain comma replacement?
- if (contributor.sort_name.strip().lower() != computed_sort_name_local_new.strip().lower()):
- error_message_detail = "Contributor[id=%s].sort_name is oddly different from computed_sort_name, human intervention required." % contributor.id
+ if (
+ contributor.sort_name.strip().lower()
+ != computed_sort_name_local_new.strip().lower()
+ ):
+ error_message_detail = (
+ "Contributor[id=%s].sort_name is oddly different from computed_sort_name, human intervention required."
+ % contributor.id
+ )
# computed names don't match. by how much? if it's a matter of a comma or a misplaced
# suffix, we can fix without asking for human intervention. if the names are very different,
@@ -2777,28 +2812,50 @@ def process_contribution_local(self, _db, contribution, log=None):
# it probably means that a human metadata professional had added an explanation/expansion to the
# sort_name, s.a. "Bob A. Jones" --> "Bob A. (Allan) Jones", and we'd rather not replace this data
# with the "Jones, Bob A." that the auto-algorigthm would generate.
- length_difference = len(contributor.sort_name.strip()) - len(computed_sort_name_local_new.strip())
+ length_difference = len(contributor.sort_name.strip()) - len(
+ computed_sort_name_local_new.strip()
+ )
if abs(length_difference) > 3:
- return self.process_local_mismatch(_db=_db, contribution=contribution,
- computed_sort_name=computed_sort_name_local_new, error_message_detail=error_message_detail, log=log)
+ return self.process_local_mismatch(
+ _db=_db,
+ contribution=contribution,
+ computed_sort_name=computed_sort_name_local_new,
+ error_message_detail=error_message_detail,
+ log=log,
+ )
- match_ratio = contributor_name_match_ratio(contributor.sort_name, computed_sort_name_local_new, normalize_names=False)
+ match_ratio = contributor_name_match_ratio(
+ contributor.sort_name,
+ computed_sort_name_local_new,
+ normalize_names=False,
+ )
- if (match_ratio < 40):
+ if match_ratio < 40:
# ask a human. this kind of score can happen when the sort_name is a transliteration of the display_name,
# and is non-trivial to fix.
- self.process_local_mismatch(_db=_db, contribution=contribution,
- computed_sort_name=computed_sort_name_local_new, error_message_detail=error_message_detail, log=log)
+ self.process_local_mismatch(
+ _db=_db,
+ contribution=contribution,
+ computed_sort_name=computed_sort_name_local_new,
+ error_message_detail=error_message_detail,
+ log=log,
+ )
else:
# we can fix it!
- output = "%s|\t%s|\t%s|\t%s|\tlocal_fix" % (contributor.id, contributor.sort_name, contributor.display_name, computed_sort_name_local_new)
+ output = "%s|\t%s|\t%s|\t%s|\tlocal_fix" % (
+ contributor.id,
+ contributor.sort_name,
+ contributor.display_name,
+ computed_sort_name_local_new,
+ )
print(output.encode("utf8"))
- self.set_contributor_sort_name(computed_sort_name_local_new, contribution)
-
+ self.set_contributor_sort_name(
+ computed_sort_name_local_new, contribution
+ )
@classmethod
def set_contributor_sort_name(cls, sort_name, contribution):
- """ Sets the contributor.sort_name and associated edition.author_name to the passed-in value. """
+ """Sets the contributor.sort_name and associated edition.author_name to the passed-in value."""
contribution.contributor.sort_name = sort_name
# also change edition.sort_author, if the author was primary
@@ -2808,23 +2865,32 @@ def set_contributor_sort_name(cls, sort_name, contribution):
# If this author appears as Primary Author anywhere on the edition, then change edition.sort_author.
edition_contributions = contribution.edition.contributions
for edition_contribution in edition_contributions:
- if ((edition_contribution.role == Contributor.PRIMARY_AUTHOR_ROLE) and
- (edition_contribution.contributor.display_name == contribution.contributor.display_name)):
+ if (edition_contribution.role == Contributor.PRIMARY_AUTHOR_ROLE) and (
+ edition_contribution.contributor.display_name
+ == contribution.contributor.display_name
+ ):
contribution.edition.sort_author = sort_name
-
- def process_local_mismatch(self, _db, contribution, computed_sort_name, error_message_detail, log=None):
+ def process_local_mismatch(
+ self, _db, contribution, computed_sort_name, error_message_detail, log=None
+ ):
"""
Determines if a problem is to be investigated further or recorded as a Complaint,
to be solved by a human. In this class, it's always a complaint. In the overridden
method in the child class in metadata_wrangler code, we sometimes go do a web query.
"""
- self.register_problem(source=self.COMPLAINT_SOURCE, contribution=contribution,
- computed_sort_name=computed_sort_name, error_message_detail=error_message_detail, log=log)
-
+ self.register_problem(
+ source=self.COMPLAINT_SOURCE,
+ contribution=contribution,
+ computed_sort_name=computed_sort_name,
+ error_message_detail=error_message_detail,
+ log=log,
+ )
@classmethod
- def register_problem(cls, source, contribution, computed_sort_name, error_message_detail, log=None):
+ def register_problem(
+ cls, source, contribution, computed_sort_name, error_message_detail, log=None
+ ):
"""
Make a Complaint in the database, so a human can take a look at this Contributor's name
and resolve whatever the complex issue that got us here.
@@ -2834,8 +2900,16 @@ def register_problem(cls, source, contribution, computed_sort_name, error_messag
pools = contribution.edition.is_presentation_for
try:
- complaint, is_new = Complaint.register(pools[0], cls.COMPLAINT_TYPE, source, error_message_detail)
- output = "%s|\t%s|\t%s|\t%s|\tcomplain|\t%s" % (contributor.id, contributor.sort_name, contributor.display_name, computed_sort_name, source)
+ complaint, is_new = Complaint.register(
+ pools[0], cls.COMPLAINT_TYPE, source, error_message_detail
+ )
+ output = "%s|\t%s|\t%s|\t%s|\tcomplain|\t%s" % (
+ contributor.id,
+ contributor.sort_name,
+ contributor.display_name,
+ computed_sort_name,
+ source,
+ )
print(output.encode("utf8"))
except ValueError as e:
# log and move on, don't stop run
@@ -2845,10 +2919,6 @@ def register_problem(cls, source, contribution, computed_sort_name, error_messag
return success
-
-
-
-
class Explain(IdentifierInputScript):
"""Explain everything known about a given work."""
@@ -2873,20 +2943,28 @@ def do_run(self, cmd_args=None, stdin=sys.stdin, stdout=sys.stdout):
def write(self, s):
"""Write a string to self.stdout."""
- if not s.endswith('\n'):
- s += '\n'
+ if not s.endswith("\n"):
+ s += "\n"
self.stdout.write(s)
def explain(self, _db, edition, presentation_calculation_policy=None):
- if edition.medium not in ('Book', 'Audio'):
+ if edition.medium not in ("Book", "Audio"):
# we haven't yet decided what to display for you
return
# Tell about the Edition record.
- output = "%s (%s, %s) according to %s" % (edition.title, edition.author, edition.medium, edition.data_source.name)
+ output = "%s (%s, %s) according to %s" % (
+ edition.title,
+ edition.author,
+ edition.medium,
+ edition.data_source.name,
+ )
self.write(output)
self.write(" Permanent work ID: %s" % edition.permanent_work_id)
- self.write(" Metadata URL: %s " % (self.METADATA_URL_TEMPLATE % edition.primary_identifier.urn))
+ self.write(
+ " Metadata URL: %s "
+ % (self.METADATA_URL_TEMPLATE % edition.primary_identifier.urn)
+ )
seen = set()
self.explain_identifier(edition.primary_identifier, True, seen, 1, 0)
@@ -2913,19 +2991,21 @@ def explain(self, _db, edition, presentation_calculation_policy=None):
# Note: Can change DB state.
if work and presentation_calculation_policy is not None:
- print("!!! About to calculate presentation!")
- work.calculate_presentation(policy=presentation_calculation_policy)
- print("!!! All done!")
- print()
- print("After recalculating presentation:")
- self.explain_work(work)
-
+ print("!!! About to calculate presentation!")
+ work.calculate_presentation(policy=presentation_calculation_policy)
+ print("!!! All done!")
+ print()
+ print("After recalculating presentation:")
+ self.explain_work(work)
def explain_contribution(self, contribution):
contributor_id = contribution.contributor.id
contributor_sort_name = contribution.contributor.sort_name
contributor_display_name = contribution.contributor.display_name
- self.write(" Contributor[%s]: contributor_sort_name=%s, contributor_display_name=%s, " % (contributor_id, contributor_sort_name, contributor_display_name))
+ self.write(
+ " Contributor[%s]: contributor_sort_name=%s, contributor_display_name=%s, "
+ % (contributor_id, contributor_sort_name, contributor_display_name)
+ )
def explain_identifier(self, identifier, primary, seen, strength, level):
indent = " " * level
@@ -2935,11 +3015,15 @@ def explain_identifier(self, identifier, primary, seen, strength, level):
ident = "Identifier"
if primary:
strength = 1
- self.write("%s %s: %s/%s (q=%s)" % (indent, ident, identifier.type, identifier.identifier, strength))
+ self.write(
+ "%s %s: %s/%s (q=%s)"
+ % (indent, ident, identifier.type, identifier.identifier, strength)
+ )
_db = Session.object_session(identifier)
classifications = Identifier.classifications_for_identifier_ids(
- _db, [identifier.id])
+ _db, [identifier.id]
+ )
for classification in classifications:
subject = classification.subject
genre = subject.genre
@@ -2947,18 +3031,19 @@ def explain_identifier(self, identifier, primary, seen, strength, level):
genre = genre.name
else:
genre = "(!genre)"
- #print("%s %s says: %s/%s %s w=%s" % (
+ # print("%s %s says: %s/%s %s w=%s" % (
# indent, classification.data_source.name,
# subject.identifier, subject.name, genre, classification.weight
- #))
+ # ))
seen.add(identifier)
for equivalency in identifier.equivalencies:
if equivalency.id in seen:
continue
seen.add(equivalency.id)
output = equivalency.output
- self.explain_identifier(output, False, seen,
- equivalency.strength, level+1)
+ self.explain_identifier(
+ output, False, seen, equivalency.strength, level + 1
+ )
if primary:
crs = identifier.coverage_records
if crs:
@@ -2988,14 +3073,23 @@ def explain_license_pool(self, pool):
self.write(" %s %s/%s" % (fulfillable, dm.content_type, dm.drm_scheme))
else:
self.write(" No delivery mechanisms.")
- self.write(" %s owned, %d available, %d holds, %d reserves" % (
- pool.licenses_owned, pool.licenses_available, pool.patrons_in_hold_queue, pool.licenses_reserved
- ))
+ self.write(
+ " %s owned, %d available, %d holds, %d reserves"
+ % (
+ pool.licenses_owned,
+ pool.licenses_available,
+ pool.patrons_in_hold_queue,
+ pool.licenses_reserved,
+ )
+ )
def explain_work(self, work):
self.write("Work info:")
if work.presentation_edition:
- self.write(" Identifier of presentation edition: %r" % work.presentation_edition.primary_identifier)
+ self.write(
+ " Identifier of presentation edition: %r"
+ % work.presentation_edition.primary_identifier
+ )
else:
self.write(" No presentation edition.")
self.write(" Fiction: %s" % work.fiction)
@@ -3012,7 +3106,7 @@ def explain_work(self, work):
if pool.collection:
collection = pool.collection.name
else:
- collection = '!collection'
+ collection = "!collection"
self.write(" %s: %r %s" % (active, pool.identifier, collection))
wcrs = sorted(work.coverage_records, key=lambda x: x.timestamp)
if wcrs:
@@ -3022,8 +3116,7 @@ def explain_work(self, work):
def explain_coverage_record(self, cr):
self._explain_coverage_record(
- cr.timestamp, cr.data_source, cr.operation, cr.status,
- cr.exception
+ cr.timestamp, cr.data_source, cr.operation, cr.status, cr.exception
)
def explain_work_coverage_record(self, cr):
@@ -3031,25 +3124,25 @@ def explain_work_coverage_record(self, cr):
cr.timestamp, None, cr.operation, cr.status, cr.exception
)
- def _explain_coverage_record(self, timestamp, data_source, operation,
- status, exception):
+ def _explain_coverage_record(
+ self, timestamp, data_source, operation, status, exception
+ ):
timestamp = timestamp.strftime(self.TIME_FORMAT)
if data_source:
- data_source = data_source.name + ' | '
+ data_source = data_source.name + " | "
else:
- data_source = ''
+ data_source = ""
if operation:
- operation = operation + ' | '
+ operation = operation + " | "
else:
- operation = ''
+ operation = ""
if exception:
- exception = ' | ' + exception
+ exception = " | " + exception
else:
- exception = ''
- self.write(" %s | %s%s%s%s" % (
- timestamp, data_source, operation, status,
- exception
- ))
+ exception = ""
+ self.write(
+ " %s | %s%s%s%s" % (timestamp, data_source, operation, status, exception)
+ )
class WhereAreMyBooksScript(CollectionInputScript):
@@ -3058,6 +3151,7 @@ class WhereAreMyBooksScript(CollectionInputScript):
This is a common problem on a new installation or when a new collection
is being configured.
"""
+
def __init__(self, _db=None, output=None, search=None):
_db = _db or self._db
super(WhereAreMyBooksScript, self).__init__(_db)
@@ -3065,7 +3159,9 @@ def __init__(self, _db=None, output=None, search=None):
try:
self.search = search or ExternalSearchIndex(_db)
except CannotLoadConfiguration:
- self.out("Here's your problem: the search integration is missing or misconfigured.")
+ self.out(
+ "Here's your problem: the search integration is missing or misconfigured."
+ )
raise
def out(self, s, *args):
@@ -3114,7 +3210,8 @@ def delete_cached_feeds(self):
)
page_feeds_count = page_feeds.count()
self.out(
- "%d feeds in cachedfeeds table, not counting grouped feeds.", page_feeds_count
+ "%d feeds in cachedfeeds table, not counting grouped feeds.",
+ page_feeds_count,
)
if page_feeds_count:
self.out(" Deleting them all.")
@@ -3124,12 +3221,14 @@ def delete_cached_feeds(self):
def explain_collection(self, collection):
self.out('Examining collection "%s"', collection.name)
- base = self._db.query(Work).join(LicensePool).filter(
- LicensePool.collection==collection
+ base = (
+ self._db.query(Work)
+ .join(LicensePool)
+ .filter(LicensePool.collection == collection)
)
- ready = base.filter(Work.presentation_ready==True)
- unready = base.filter(Work.presentation_ready==False)
+ ready = base.filter(Work.presentation_ready == True)
+ unready = base.filter(Work.presentation_ready == False)
ready_count = ready.count()
unready_count = unready.count()
@@ -3140,22 +3239,23 @@ def explain_collection(self, collection):
LPDM = LicensePoolDeliveryMechanism
no_delivery_mechanisms = base.filter(
~exists().where(
- and_(LicensePool.data_source_id==LPDM.data_source_id,
- LicensePool.identifier_id==LPDM.identifier_id)
+ and_(
+ LicensePool.data_source_id == LPDM.data_source_id,
+ LicensePool.identifier_id == LPDM.identifier_id,
+ )
)
).count()
if no_delivery_mechanisms > 0:
self.out(
" %d works are missing delivery mechanisms and won't show up.",
- no_delivery_mechanisms
+ no_delivery_mechanisms,
)
# Check if the license pools are suppressed.
- suppressed = base.filter(LicensePool.suppressed==True).count()
+ suppressed = base.filter(LicensePool.suppressed == True).count()
if suppressed > 0:
self.out(
- " %d works have suppressed LicensePools and won't show up.",
- suppressed
+ " %d works have suppressed LicensePools and won't show up.", suppressed
)
# Check if the pools have available licenses.
@@ -3165,14 +3265,13 @@ def explain_collection(self, collection):
if not_owned > 0:
self.out(
" %d non-open-access works have no owned licenses and won't show up.",
- not_owned
+ not_owned,
)
filter = Filter(collections=[collection])
count = self.search.count_works(filter)
self.out(
- " %d works in the search index, expected around %d.",
- count, ready_count
+ " %d works in the search index, expected around %d.", count, ready_count
)
@@ -3182,6 +3281,7 @@ class ListCollectionMetadataIdentifiersScript(CollectionInputScript):
This script is helpful for accounting for and tracking collections on
the metadata wrangler.
"""
+
def __init__(self, _db=None, output=None):
_db = _db or self._db
super(ListCollectionMetadataIdentifiersScript, self).__init__(_db)
@@ -3200,19 +3300,18 @@ def do_run(self, collections=None):
if collection_ids:
collections = collections.filter(Collection.id.in_(collection_ids))
- self.output.write('COLLECTIONS\n')
- self.output.write('='*50+'\n')
+ self.output.write("COLLECTIONS\n")
+ self.output.write("=" * 50 + "\n")
+
def add_line(id, name, protocol, metadata_identifier):
- line = '(%s) %s/%s => %s\n' % (
- id, name, protocol, metadata_identifier
- )
+ line = "(%s) %s/%s => %s\n" % (id, name, protocol, metadata_identifier)
self.output.write(line)
count = 0
for collection in collections:
if not count:
# Add a format line.
- add_line('id', 'name', 'protocol', 'metadata_identifier')
+ add_line("id", "name", "protocol", "metadata_identifier")
count += 1
add_line(
@@ -3222,11 +3321,10 @@ def add_line(id, name, protocol, metadata_identifier):
collection.metadata_identifier,
)
- self.output.write('\n%d collections found.\n' % count)
+ self.output.write("\n%d collections found.\n" % count)
class UpdateLaneSizeScript(LaneSweeperScript):
-
def should_process_lane(self, lane):
"""We don't want to process generic WorkLists -- there's nowhere
to store the data.
@@ -3248,25 +3346,24 @@ class RemovesSearchCoverage(object):
"""Mix-in class for a script that might remove all coverage records
for the search engine.
"""
+
def remove_search_coverage_records(self):
"""Delete all search coverage records from the database.
:return: The number of records deleted.
"""
wcr = WorkCoverageRecord
- clause = wcr.operation==wcr.UPDATE_SEARCH_INDEX_OPERATION
+ clause = wcr.operation == wcr.UPDATE_SEARCH_INDEX_OPERATION
count = self._db.query(wcr).filter(clause).count()
self._db.execute(wcr.__table__.delete().where(clause))
return count
-class RebuildSearchIndexScript(
- RunWorkCoverageProviderScript, RemovesSearchCoverage
-):
+class RebuildSearchIndexScript(RunWorkCoverageProviderScript, RemovesSearchCoverage):
"""Completely delete the search index and recreate it."""
def __init__(self, *args, **kwargs):
- search = kwargs.get('search_index_client', None)
+ search = kwargs.get("search_index_client", None)
self.search = search or ExternalSearchIndex(self._db)
super(RebuildSearchIndexScript, self).__init__(
SearchIndexCoverageProvider, *args, **kwargs
@@ -3292,17 +3389,17 @@ class SearchIndexCoverageRemover(TimestampScript, RemovesSearchCoverage):
This guarantees the SearchIndexCoverageProvider will add
fresh coverage for every Work the next time it runs.
"""
+
def do_run(self):
count = self.remove_search_coverage_records()
return TimestampData(
- achievements="Coverage records deleted: %(deleted)d" % dict(
- deleted=count
- )
+ achievements="Coverage records deleted: %(deleted)d" % dict(deleted=count)
)
class MockStdin(object):
"""Mock a list of identifiers passed in on standard input."""
+
def __init__(self, *lines):
self.lines = lines
diff --git a/selftest.py b/selftest.py
index 74b26c705..96d306479 100644
--- a/selftest.py
+++ b/selftest.py
@@ -1,12 +1,13 @@
"""Define the interfaces used by ExternalIntegration self-tests.
"""
-from .util.http import IntegrationException
import json
import logging
import traceback
-from .util.opds_writer import AtomFeed
from .util.datetime_helpers import utc_now
+from .util.http import IntegrationException
+from .util.opds_writer import AtomFeed
+
class SelfTestResult(object):
"""The result of running a single self-test.
@@ -45,31 +46,34 @@ def to_dict(self):
# Time formatting method
f = AtomFeed._strftime
if self.exception:
- exception = { "class": self.exception.__class__.__name__,
- "message": str(self.exception),
- "debug_message" : self.debug_message }
+ exception = {
+ "class": self.exception.__class__.__name__,
+ "message": str(self.exception),
+ "debug_message": self.debug_message,
+ }
else:
exception = None
value = dict(
- name=self.name, success=self.success,
+ name=self.name,
+ success=self.success,
duration=self.duration,
exception=exception,
)
if self.start:
- value['start'] = f(self.start)
+ value["start"] = f(self.start)
if self.end:
- value['end'] = f(self.end)
+ value["end"] = f(self.end)
if self.collection:
- value['collection'] = self.collection.name
+ value["collection"] = self.collection.name
# String results will be displayed in a fixed-width font.
# Lists of strings will be hidden behind an expandable toggle.
# Other return values have no defined method of display.
if isinstance(self.result, str) or isinstance(self.result, list):
- value['result'] = self.result
+ value["result"] = self.result
else:
- value['result'] = None
+ value["result"] = None
return value
def __repr__(self):
@@ -77,7 +81,7 @@ def __repr__(self):
if isinstance(self.exception, IntegrationException):
exception = " exception=%r debug=%r" % (
str(self.exception),
- self.debug_message
+ self.debug_message,
)
else:
exception = " exception=%r" % self.exception
@@ -88,8 +92,12 @@ def __repr__(self):
else:
collection = ""
return "" % (
- self.name, collection, self.duration, self.success,
- exception, self.result
+ self.name,
+ collection,
+ self.duration,
+ self.success,
+ exception,
+ self.result,
)
@property
@@ -97,14 +105,14 @@ def duration(self):
"""How long the test took to run."""
if not self.start or not self.end:
return 0
- return (self.end-self.start).total_seconds()
+ return (self.end - self.start).total_seconds()
@property
def debug_message(self):
"""The debug message associated with the Exception, if any."""
if not self.exception:
return None
- return getattr(self.exception, 'debug_message', None)
+ return getattr(self.exception, "debug_message", None)
class HasSelfTests(object):
@@ -114,7 +122,7 @@ class HasSelfTests(object):
# Self-test results are stored in a ConfigurationSetting with this name,
# associated with the appropriate ExternalIntegration.
- SELF_TEST_RESULTS_SETTING = 'self_test_results'
+ SELF_TEST_RESULTS_SETTING = "self_test_results"
@classmethod
def run_self_tests(cls, _db, constructor_method=None, *args, **kwargs):
@@ -169,7 +177,6 @@ def run_self_tests(cls, _db, constructor_method=None, *args, **kwargs):
)
results.append(failure)
-
end = utc_now()
# Format the results in a useful way.
@@ -177,8 +184,8 @@ def run_self_tests(cls, _db, constructor_method=None, *args, **kwargs):
value = dict(
start=AtomFeed._strftime(start),
end=AtomFeed._strftime(end),
- duration = (end-start).total_seconds(),
- results = [x.to_dict for x in results]
+ duration=(end - start).total_seconds(),
+ results=[x.to_dict for x in results],
)
# Store the formatted results in the database, if we can find
# a place to store them.
@@ -193,9 +200,7 @@ def run_self_tests(cls, _db, constructor_method=None, *args, **kwargs):
integration = instance.external_integration(_db)
if integration:
- integration.setting(
- cls.SELF_TEST_RESULTS_SETTING
- ).value = json.dumps(value)
+ integration.setting(cls.SELF_TEST_RESULTS_SETTING).value = json.dumps(value)
return value, results
@@ -210,12 +215,16 @@ def prior_test_results(cls, _db, constructor_method=None, *args, **kwargs):
instance = constructor_method(*args, **kwargs)
from .external_search import ExternalSearchIndex
+
if isinstance(instance, ExternalSearchIndex):
integration = instance.search_integration(_db)
else:
integration = instance.external_integration(_db)
if integration:
- return integration.setting(cls.SELF_TEST_RESULTS_SETTING).json_value or "No results yet"
+ return (
+ integration.setting(cls.SELF_TEST_RESULTS_SETTING).json_value
+ or "No results yet"
+ )
def external_integration(self, _db):
"""Locate the ExternalIntegration associated with this object.
diff --git a/testing.py b/testing.py
index 35166019f..9c9ab3182 100644
--- a/testing.py
+++ b/testing.py
@@ -1,38 +1,45 @@
-from datetime import timedelta
+import inspect
import json
import logging
import os
import shutil
-import time
import tempfile
+import time
import uuid
-from psycopg2.errors import UndefinedTable
-import pytest
-from sqlalchemy.orm.session import Session
-from sqlalchemy.exc import ProgrammingError
+from datetime import timedelta
from pdb import set_trace
+
import mock
-import inspect
+import pytest
+from psycopg2.errors import UndefinedTable
+from sqlalchemy.exc import ProgrammingError
+from sqlalchemy.orm.session import Session
+from . import external_search
+from .classifier import Classifier
from .config import Configuration
-from .lane import (
- Lane,
+from .coverage import (
+ BibliographicCoverageProvider,
+ CollectionCoverageProvider,
+ CoverageFailure,
+ IdentifierCoverageProvider,
+ WorkCoverageProvider,
)
-from .model.constants import MediaTypes
-from .model import (
- Base,
- PresentationCalculationPolicy,
- SessionManager,
- get_one_or_create,
- create,
+from .external_search import (
+ ExternalSearchIndex,
+ MockExternalSearchIndex,
+ SearchIndexCoverageProvider,
)
+from .lane import Lane
+from .log import LogConfiguration
from .model import (
- CoverageRecord,
+ Base,
Classification,
Collection,
Complaint,
ConfigurationSetting,
Contributor,
+ CoverageRecord,
Credential,
CustomList,
DataSource,
@@ -49,36 +56,27 @@
LicensePool,
LicensePoolDeliveryMechanism,
Patron,
+ PresentationCalculationPolicy,
Representation,
Resource,
RightsStatus,
+ SessionManager,
Subject,
Work,
WorkCoverageRecord,
+ create,
+ get_one_or_create,
)
from .model.configuration import ExternalIntegrationLink
-from .classifier import Classifier
-from .coverage import (
- BibliographicCoverageProvider,
- CollectionCoverageProvider,
- IdentifierCoverageProvider,
- CoverageFailure,
- WorkCoverageProvider,
-)
-
-from .external_search import (
- MockExternalSearchIndex,
- ExternalSearchIndex,
- SearchIndexCoverageProvider,
-)
-from .log import LogConfiguration
-from . import external_search
+from .model.constants import MediaTypes
from .util.datetime_helpers import datetime_utc, utc_now
+
class LogCaptureHandler(logging.Handler):
"""A `logging.Handler` context manager that captures the messages
of emitted log records in the context of the specified `logger`.
"""
+
_level_names = logging._levelToName.values()
@staticmethod
@@ -114,15 +112,17 @@ def emit(self, record):
self._records[level].append(record.getMessage())
def reset(self):
- """Empty the message accumulators.
- """
+ """Empty the message accumulators."""
self._records = {level: [] for level in self.LEVEL_NAMES}
def __getitem__(self, item):
if item in self.LEVEL_NAMES:
return self._records[item]
else:
- message = "'%s' object has no attribute '%s'" % (self.__class__.__name__, item)
+ message = "'%s' object has no attribute '%s'" % (
+ self.__class__.__name__,
+ item,
+ )
raise AttributeError(message)
def __getattr__(self, item):
@@ -166,7 +166,10 @@ def teardown_class(cls):
shutil.rmtree(cls.tmp_data_dir)
else:
- logging.warn("Cowardly refusing to remove 'temporary' directory %s" % cls.tmp_data_dir)
+ logging.warn(
+ "Cowardly refusing to remove 'temporary' directory %s"
+ % cls.tmp_data_dir
+ )
Configuration.instance[Configuration.DATA_DIRECTORY] = cls.old_data_dir
@@ -177,7 +180,10 @@ def search_mock(self, request):
if elasticsearch_mark is not None:
self.search_mock = None
else:
- self.search_mock = mock.patch(external_search.__name__ + ".ExternalSearchIndex", MockExternalSearchIndex)
+ self.search_mock = mock.patch(
+ external_search.__name__ + ".ExternalSearchIndex",
+ MockExternalSearchIndex,
+ )
self.search_mock.start()
yield
if self.search_mock:
@@ -193,7 +199,10 @@ def setup_method(self):
self.time_counter = datetime_utc(2014, 1, 1)
self.isbns = [
- "9780674368279", "0636920028468", "9781936460236", "9780316075978"
+ "9780674368279",
+ "0636920028468",
+ "9781936460236",
+ "9780316075978",
]
def teardown_method(self):
@@ -218,23 +227,23 @@ def teardown_method(self):
# Also roll back any record of those changes in the
# Configuration instance.
for key in [
- Configuration.SITE_CONFIGURATION_LAST_UPDATE,
- Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE
+ Configuration.SITE_CONFIGURATION_LAST_UPDATE,
+ Configuration.LAST_CHECKED_FOR_SITE_CONFIGURATION_UPDATE,
]:
if key in Configuration.instance:
- del(Configuration.instance[key])
+ del Configuration.instance[key]
def time_eq(self, a, b):
"Assert that two times are *approximately* the same -- within 2 seconds."
if a < b:
- delta = b-a
+ delta = b - a
else:
- delta = a-b
+ delta = a - b
total_seconds = delta.total_seconds()
- assert (total_seconds < 2), ("Delta was too large: %.2f seconds." % total_seconds)
+ assert total_seconds < 2, "Delta was too large: %.2f seconds." % total_seconds
def shortDescription(self):
- return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.
+ return None # Stop nosetests displaying docstrings instead of class names when verbosity level >= 2.
@property
def _id(self):
@@ -263,8 +272,7 @@ def _patron(self, external_identifier=None, library=None):
external_identifier = external_identifier or self._str
library = library or self._default_library
return get_one_or_create(
- self._db, Patron, external_identifier=external_identifier,
- library=library
+ self._db, Patron, external_identifier=external_identifier, library=library
)[0]
def _contributor(self, sort_name=None, name=None, **kw_args):
@@ -279,25 +287,24 @@ def _identifier(self, identifier_type=Identifier.GUTENBERG_ID, foreign_id=None):
return Identifier.for_foreign_id(self._db, identifier_type, id)[0]
def _edition(
- self,
- data_source_name=DataSource.GUTENBERG,
- identifier_type=Identifier.GUTENBERG_ID,
- with_license_pool=False,
- with_open_access_download=False,
- title=None,
- language="eng",
- authors=None,
- identifier_id=None,
- series=None,
- collection=None,
- publication_date=None,
- self_hosted=False,
- unlimited_access=False
+ self,
+ data_source_name=DataSource.GUTENBERG,
+ identifier_type=Identifier.GUTENBERG_ID,
+ with_license_pool=False,
+ with_open_access_download=False,
+ title=None,
+ language="eng",
+ authors=None,
+ identifier_id=None,
+ series=None,
+ collection=None,
+ publication_date=None,
+ self_hosted=False,
+ unlimited_access=False,
):
id = identifier_id or self._str
source = DataSource.lookup(self._db, data_source_name)
- wr = Edition.for_foreign_id(
- self._db, source, identifier_type, id)[0]
+ wr = Edition.for_foreign_id(self._db, source, identifier_type, id)[0]
if not title:
title = self._str
wr.title = str(title)
@@ -312,10 +319,12 @@ def _edition(
authors = [authors]
if authors:
primary_author_name = str(authors[0])
- contributor = wr.add_contributor(primary_author_name, Contributor.PRIMARY_AUTHOR_ROLE)
+ contributor = wr.add_contributor(
+ primary_author_name, Contributor.PRIMARY_AUTHOR_ROLE
+ )
# add_contributor assumes authors[0] is a sort_name,
# but it may be a display name. If so, set that field as well.
- if not contributor.display_name and ',' not in primary_author_name:
+ if not contributor.display_name and "," not in primary_author_name:
contributor.display_name = primary_author_name
wr.author = primary_author_name
@@ -326,11 +335,12 @@ def _edition(
if with_license_pool or with_open_access_download:
pool = self._licensepool(
- wr, data_source_name=data_source_name,
+ wr,
+ data_source_name=data_source_name,
with_open_access_download=with_open_access_download,
collection=collection,
self_hosted=self_hosted,
- unlimited_access=unlimited_access
+ unlimited_access=unlimited_access,
)
pool.set_presentation_edition()
@@ -338,22 +348,22 @@ def _edition(
return wr
def _work(
- self,
- title=None,
- authors=None,
- genre=None,
- language=None,
- audience=None,
- fiction=True,
- with_license_pool=False,
- with_open_access_download=False,
- quality=0.5,
- series=None,
- presentation_edition=None,
- collection=None,
- data_source_name=None,
- self_hosted=False,
- unlimited_access=False
+ self,
+ title=None,
+ authors=None,
+ genre=None,
+ language=None,
+ audience=None,
+ fiction=True,
+ with_license_pool=False,
+ with_open_access_download=False,
+ quality=0.5,
+ series=None,
+ presentation_edition=None,
+ collection=None,
+ data_source_name=None,
+ self_hosted=False,
+ unlimited_access=False,
):
"""Create a Work.
@@ -381,7 +391,8 @@ def _work(
if not presentation_edition:
new_edition = True
presentation_edition = self._edition(
- title=title, language=language,
+ title=title,
+ language=language,
authors=authors,
with_license_pool=with_license_pool,
with_open_access_download=with_open_access_download,
@@ -389,7 +400,7 @@ def _work(
series=series,
collection=collection,
self_hosted=self_hosted,
- unlimited_access=unlimited_access
+ unlimited_access=unlimited_access,
)
if with_license_pool:
presentation_edition, pool = presentation_edition
@@ -406,10 +417,13 @@ def _work(
else:
pools = presentation_edition.license_pools
work, ignore = get_one_or_create(
- self._db, Work, create_method_kwargs=dict(
- audience=audience,
- fiction=fiction,
- quality=quality), id=self._id)
+ self._db,
+ Work,
+ create_method_kwargs=dict(
+ audience=audience, fiction=fiction, quality=quality
+ ),
+ id=self._id,
+ )
if genre:
if not isinstance(genre, Genre):
genre, ignore = Genre.lookup(self._db, genre, autocreate=True)
@@ -434,21 +448,29 @@ def _work(
return work
- def _lane(self, display_name=None, library=None,
- parent=None, genres=None, languages=None,
- fiction=None, inherit_parent_restrictions=True
+ def _lane(
+ self,
+ display_name=None,
+ library=None,
+ parent=None,
+ genres=None,
+ languages=None,
+ fiction=None,
+ inherit_parent_restrictions=True,
):
display_name = display_name or self._str
library = library or self._default_library
lane, is_new = create(
- self._db, Lane,
+ self._db,
+ Lane,
library=library,
- parent=parent, display_name=display_name,
+ parent=parent,
+ display_name=display_name,
fiction=fiction,
- inherit_parent_restrictions=inherit_parent_restrictions
+ inherit_parent_restrictions=inherit_parent_restrictions,
)
if is_new and parent:
- lane.priority = len(parent.sublanes)-1
+ lane.priority = len(parent.sublanes) - 1
if genres:
if not isinstance(genres, list):
genres = [genres]
@@ -480,68 +502,77 @@ def _add_generic_delivery_mechanism(self, license_pool):
content_type = Representation.EPUB_MEDIA_TYPE
drm_scheme = DeliveryMechanism.NO_DRM
return LicensePoolDeliveryMechanism.set(
- data_source, identifier, content_type, drm_scheme,
- RightsStatus.IN_COPYRIGHT
+ data_source, identifier, content_type, drm_scheme, RightsStatus.IN_COPYRIGHT
)
- def _coverage_record(self, edition, coverage_source, operation=None,
- status=CoverageRecord.SUCCESS, collection=None, exception=None,
+ def _coverage_record(
+ self,
+ edition,
+ coverage_source,
+ operation=None,
+ status=CoverageRecord.SUCCESS,
+ collection=None,
+ exception=None,
):
if isinstance(edition, Identifier):
identifier = edition
else:
identifier = edition.primary_identifier
record, ignore = get_one_or_create(
- self._db, CoverageRecord,
+ self._db,
+ CoverageRecord,
identifier=identifier,
data_source=coverage_source,
operation=operation,
collection=collection,
- create_method_kwargs = dict(
+ create_method_kwargs=dict(
timestamp=utc_now(),
status=status,
exception=exception,
- )
+ ),
)
return record
- def _work_coverage_record(self, work, operation=None,
- status=CoverageRecord.SUCCESS):
+ def _work_coverage_record(
+ self, work, operation=None, status=CoverageRecord.SUCCESS
+ ):
record, ignore = get_one_or_create(
- self._db, WorkCoverageRecord,
+ self._db,
+ WorkCoverageRecord,
work=work,
operation=operation,
- create_method_kwargs = dict(
+ create_method_kwargs=dict(
timestamp=utc_now(),
status=status,
- )
+ ),
)
return record
def _licensepool(
- self,
- edition,
- open_access=True,
- data_source_name=DataSource.GUTENBERG,
- with_open_access_download=False,
- set_edition_as_presentation=False,
- collection=None,
- self_hosted=False,
- unlimited_access=False
+ self,
+ edition,
+ open_access=True,
+ data_source_name=DataSource.GUTENBERG,
+ with_open_access_download=False,
+ set_edition_as_presentation=False,
+ collection=None,
+ self_hosted=False,
+ unlimited_access=False,
):
source = DataSource.lookup(self._db, data_source_name)
if not edition:
edition = self._edition(data_source_name)
collection = collection or self._default_collection
pool, ignore = get_one_or_create(
- self._db, LicensePool,
+ self._db,
+ LicensePool,
create_method_kwargs=dict(open_access=open_access),
identifier=edition.primary_identifier,
data_source=source,
collection=collection,
availability_time=utc_now(),
self_hosted=self_hosted,
- unlimited_access=unlimited_access
+ unlimited_access=unlimited_access,
)
if set_edition_as_presentation:
@@ -552,8 +583,7 @@ def _licensepool(
url = "http://foo.com/" + self._str
media_type = MediaTypes.EPUB_MEDIA_TYPE
link, new = pool.identifier.add_link(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, url,
- source, media_type
+ Hyperlink.OPEN_ACCESS_DOWNLOAD, url, source, media_type
)
# Add a DeliveryMechanism for this download
@@ -565,7 +595,8 @@ def _licensepool(
)
representation, is_new = self._representation(
- url, media_type, "Dummy content", mirrored=True)
+ url, media_type, "Dummy content", mirrored=True
+ )
link.resource.representation = representation
else:
# Add a DeliveryMechanism for this licensepool
@@ -573,7 +604,7 @@ def _licensepool(
MediaTypes.EPUB_MEDIA_TYPE,
DeliveryMechanism.ADOBE_DRM,
RightsStatus.UNKNOWN,
- None
+ None,
)
if not unlimited_access:
@@ -581,25 +612,35 @@ def _licensepool(
return pool
- def _license(self, pool, identifier=None, checkout_url=None, status_url=None,
- expires=None, remaining_checkouts=None, concurrent_checkouts=None):
+ def _license(
+ self,
+ pool,
+ identifier=None,
+ checkout_url=None,
+ status_url=None,
+ expires=None,
+ remaining_checkouts=None,
+ concurrent_checkouts=None,
+ ):
identifier = identifier or self._str
checkout_url = checkout_url or self._str
status_url = status_url or self._str
license, ignore = get_one_or_create(
- self._db, License, identifier=identifier, license_pool=pool,
+ self._db,
+ License,
+ identifier=identifier,
+ license_pool=pool,
checkout_url=checkout_url,
- status_url=status_url, expires=expires,
+ status_url=status_url,
+ expires=expires,
remaining_checkouts=remaining_checkouts,
concurrent_checkouts=concurrent_checkouts,
)
return license
- def _representation(self, url=None, media_type=None, content=None,
- mirrored=False):
+ def _representation(self, url=None, media_type=None, content=None, mirrored=False):
url = url or "http://foo.com/" + self._str
- repr, is_new = get_one_or_create(
- self._db, Representation, url=url)
+ repr, is_new = get_one_or_create(self._db, Representation, url=url)
repr.media_type = media_type
if media_type and content:
if isinstance(content, str):
@@ -611,24 +652,28 @@ def _representation(self, url=None, media_type=None, content=None,
repr.mirrored_at = utc_now()
return repr, is_new
- def _customlist(self, foreign_identifier=None,
- name=None,
- data_source_name=DataSource.NYT, num_entries=1,
- entries_exist_as_works=True
+ def _customlist(
+ self,
+ foreign_identifier=None,
+ name=None,
+ data_source_name=DataSource.NYT,
+ num_entries=1,
+ entries_exist_as_works=True,
):
data_source = DataSource.lookup(self._db, data_source_name)
foreign_identifier = foreign_identifier or self._str
now = utc_now()
customlist, ignore = get_one_or_create(
- self._db, CustomList,
+ self._db,
+ CustomList,
create_method_kwargs=dict(
created=now,
updated=now,
name=name or self._str,
description=self._str,
- ),
+ ),
data_source=data_source,
- foreign_identifier=foreign_identifier
+ foreign_identifier=foreign_identifier,
)
editions = []
@@ -637,26 +682,21 @@ def _customlist(self, foreign_identifier=None,
work = self._work(with_open_access_download=True)
edition = work.presentation_edition
else:
- edition = self._edition(
- data_source_name, title="Item %s" % i)
- edition.permanent_work_id="Permanent work ID %s" % self._str
- customlist.add_entry(
- edition, "Annotation %s" % i, first_appearance=now)
+ edition = self._edition(data_source_name, title="Item %s" % i)
+ edition.permanent_work_id = "Permanent work ID %s" % self._str
+ customlist.add_entry(edition, "Annotation %s" % i, first_appearance=now)
editions.append(edition)
return customlist, editions
def _complaint(self, license_pool, type, source, detail, resolved=None):
complaint, is_new = Complaint.register(
- license_pool,
- type,
- source,
- detail,
- resolved
+ license_pool, type, source, detail, resolved
)
return complaint
- def _credential(self, data_source_name=DataSource.GUTENBERG,
- type=None, patron=None):
+ def _credential(
+ self, data_source_name=DataSource.GUTENBERG, type=None, patron=None
+ ):
data_source = DataSource.lookup(self._db, data_source_name)
type = type or self._str
patron = patron or self._patron()
@@ -665,8 +705,8 @@ def _credential(self, data_source_name=DataSource.GUTENBERG,
)
return credential
- def _external_integration(self, protocol, goal=None, settings=None,
- libraries=None, **kwargs
+ def _external_integration(
+ self, protocol, goal=None, settings=None, libraries=None, **kwargs
):
integration = None
if not libraries:
@@ -690,7 +730,8 @@ def _external_integration(self, protocol, goal=None, settings=None,
# Otherwise, create a brand new integration specifically
# for the library.
integration = ExternalIntegration(
- protocol=protocol, goal=goal,
+ protocol=protocol,
+ goal=goal,
)
integration.libraries.extend(libraries)
self._db.add(integration)
@@ -704,28 +745,38 @@ def _external_integration(self, protocol, goal=None, settings=None,
return integration
- def _external_integration_link(self, integration=None, library=None,
- other_integration=None, purpose="covers_mirror"):
+ def _external_integration_link(
+ self,
+ integration=None,
+ library=None,
+ other_integration=None,
+ purpose="covers_mirror",
+ ):
integration = integration or self._external_integration("some protocol")
- other_integration = other_integration or self._external_integration("some other protocol")
+ other_integration = other_integration or self._external_integration(
+ "some other protocol"
+ )
library_id = library.id if library else None
external_integration_link, ignore = get_one_or_create(
- self._db, ExternalIntegrationLink,
+ self._db,
+ ExternalIntegrationLink,
library_id=library_id,
external_integration_id=integration.id,
other_integration_id=other_integration.id,
- purpose=purpose
+ purpose=purpose,
)
return external_integration_link
def _delegated_patron_identifier(
- self, library_uri=None, patron_identifier=None,
- identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
- identifier=None
+ self,
+ library_uri=None,
+ patron_identifier=None,
+ identifier_type=DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID,
+ identifier=None,
):
"""Create a sample DelegatedPatronIdentifier"""
library_uri = library_uri or self._url
@@ -735,16 +786,17 @@ def _delegated_patron_identifier(
else:
if not identifier:
identifier = self._str
+
def make_id():
return identifier
+
patron, is_new = DelegatedPatronIdentifier.get_one_or_create(
- self._db, library_uri, patron_identifier, identifier_type,
- make_id
+ self._db, library_uri, patron_identifier, identifier_type, make_id
)
return patron
def _sample_ecosystem(self):
- """ Creates an ecosystem of some sample work, pool, edition, and author
+ """Creates an ecosystem of some sample work, pool, edition, and author
objects that all know each other.
"""
# make some authors
@@ -753,21 +805,36 @@ def _sample_ecosystem(self):
[alice], ignore = Contributor.lookup(self._db, "Adder, Alice")
alice.family_name, alice.display_name = alice.default_names()
- edition_std_ebooks, pool_std_ebooks = self._edition(DataSource.STANDARD_EBOOKS, Identifier.URI,
- with_license_pool=True, with_open_access_download=True, authors=[])
+ edition_std_ebooks, pool_std_ebooks = self._edition(
+ DataSource.STANDARD_EBOOKS,
+ Identifier.URI,
+ with_license_pool=True,
+ with_open_access_download=True,
+ authors=[],
+ )
edition_std_ebooks.title = "The Standard Ebooks Title"
edition_std_ebooks.subtitle = "The Standard Ebooks Subtitle"
edition_std_ebooks.add_contributor(alice, Contributor.AUTHOR_ROLE)
- edition_git, pool_git = self._edition(DataSource.PROJECT_GITENBERG, Identifier.GUTENBERG_ID,
- with_license_pool=True, with_open_access_download=True, authors=[])
+ edition_git, pool_git = self._edition(
+ DataSource.PROJECT_GITENBERG,
+ Identifier.GUTENBERG_ID,
+ with_license_pool=True,
+ with_open_access_download=True,
+ authors=[],
+ )
edition_git.title = "The GItenberg Title"
edition_git.subtitle = "The GItenberg Subtitle"
edition_git.add_contributor(bob, Contributor.AUTHOR_ROLE)
edition_git.add_contributor(alice, Contributor.AUTHOR_ROLE)
- edition_gut, pool_gut = self._edition(DataSource.GUTENBERG, Identifier.GUTENBERG_ID,
- with_license_pool=True, with_open_access_download=True, authors=[])
+ edition_gut, pool_gut = self._edition(
+ DataSource.GUTENBERG,
+ Identifier.GUTENBERG_ID,
+ with_license_pool=True,
+ with_open_access_download=True,
+ authors=[],
+ )
edition_gut.title = "The GUtenberg Title"
edition_gut.subtitle = "The GUtenberg Subtitle"
edition_gut.add_contributor(bob, Contributor.AUTHOR_ROLE)
@@ -779,9 +846,17 @@ def _sample_ecosystem(self):
work.calculate_presentation()
- return (work, pool_std_ebooks, pool_git, pool_gut,
- edition_std_ebooks, edition_git, edition_gut, alice, bob)
-
+ return (
+ work,
+ pool_std_ebooks,
+ pool_git,
+ pool_gut,
+ edition_std_ebooks,
+ edition_git,
+ edition_gut,
+ alice,
+ bob,
+ )
def print_database_instance(self):
"""
@@ -800,15 +875,16 @@ def test_name(self):
self.print_database_instance() # TODO: remove before prod
[code...]
"""
- if not 'TESTING' in os.environ:
+ if not "TESTING" in os.environ:
# we are on production, abort, abort!
- logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production.")
+ logging.warn(
+ "Forgot to remove call to testing.py:DatabaseTest.print_database_instance() before pushing to production."
+ )
return
DatabaseTest.print_database_class(self._db)
return
-
@classmethod
def print_database_class(cls, db_connection):
"""
@@ -835,9 +911,11 @@ def print_database_class(cls, db_connection):
TODO: remove before prod
"""
- if not 'TESTING' in os.environ:
+ if not "TESTING" in os.environ:
# we are on production, abort, abort!
- logging.warn("Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production.")
+ logging.warn(
+ "Forgot to remove call to testing.py:DatabaseTest.print_database_class() before pushing to production."
+ )
return
works = db_connection.query(Work).all()
@@ -847,13 +925,13 @@ def print_database_class(cls, db_connection):
data_sources = db_connection.query(DataSource).all()
representations = db_connection.query(Representation).all()
- if (not works):
+ if not works:
print("NO Work found")
for wCount, work in enumerate(works):
# pipe character at end of line helps see whitespace issues
print("Work[%s]=%s|" % (wCount, work))
- if (not work.license_pools):
+ if not work.license_pools:
print(" NO Work.LicensePool found")
for lpCount, license_pool in enumerate(work.license_pools):
print(" Work.LicensePool[%s]=%s|" % (lpCount, license_pool))
@@ -861,51 +939,59 @@ def print_database_class(cls, db_connection):
print(" Work.presentation_edition=%s|" % work.presentation_edition)
print("__________________________________________________________________\n")
- if (not identifiers):
+ if not identifiers:
print("NO Identifier found")
for iCount, identifier in enumerate(identifiers):
print("Identifier[%s]=%s|" % (iCount, identifier))
print(" Identifier.licensed_through=%s|" % identifier.licensed_through)
print("__________________________________________________________________\n")
- if (not license_pools):
+ if not license_pools:
print("NO LicensePool found")
for index, license_pool in enumerate(license_pools):
print("LicensePool[%s]=%s|" % (index, license_pool))
print(" LicensePool.work_id=%s|" % license_pool.work_id)
print(" LicensePool.data_source_id=%s|" % license_pool.data_source_id)
print(" LicensePool.identifier_id=%s|" % license_pool.identifier_id)
- print(" LicensePool.presentation_edition_id=%s|" % license_pool.presentation_edition_id)
+ print(
+ " LicensePool.presentation_edition_id=%s|"
+ % license_pool.presentation_edition_id
+ )
print(" LicensePool.superceded=%s|" % license_pool.superceded)
print(" LicensePool.suppressed=%s|" % license_pool.suppressed)
print("__________________________________________________________________\n")
- if (not editions):
+ if not editions:
print("NO Edition found")
for index, edition in enumerate(editions):
# pipe character at end of line helps see whitespace issues
print("Edition[%s]=%s|" % (index, edition))
- print(" Edition.primary_identifier_id=%s|" % edition.primary_identifier_id)
+ print(
+ " Edition.primary_identifier_id=%s|" % edition.primary_identifier_id
+ )
print(" Edition.permanent_work_id=%s|" % edition.permanent_work_id)
- if (edition.data_source):
+ if edition.data_source:
print(" Edition.data_source.id=%s|" % edition.data_source.id)
print(" Edition.data_source.name=%s|" % edition.data_source.name)
else:
print(" No Edition.data_source.")
- if (edition.license_pool):
+ if edition.license_pool:
print(" Edition.license_pool.id=%s|" % edition.license_pool.id)
else:
print(" No Edition.license_pool.")
print(" Edition.title=%s|" % edition.title)
print(" Edition.author=%s|" % edition.author)
- if (not edition.author_contributors):
+ if not edition.author_contributors:
print(" NO Edition.author_contributor found")
for acCount, author_contributor in enumerate(edition.author_contributors):
- print(" Edition.author_contributor[%s]=%s|" % (acCount, author_contributor))
+ print(
+ " Edition.author_contributor[%s]=%s|"
+ % (acCount, author_contributor)
+ )
print("__________________________________________________________________\n")
- if (not data_sources):
+ if not data_sources:
print("NO DataSource found")
for index, data_source in enumerate(data_sources):
print("DataSource[%s]=%s|" % (index, data_source))
@@ -917,35 +1003,48 @@ def print_database_class(cls, db_connection):
print(" DataSource.links=%s|" % data_source.links)
print("__________________________________________________________________\n")
- if (not representations):
+ if not representations:
print("NO Representation found")
for index, representation in enumerate(representations):
print("Representation[%s]=%s|" % (index, representation))
print(" Representation.id=%s|" % representation.id)
print(" Representation.url=%s|" % representation.url)
print(" Representation.mirror_url=%s|" % representation.mirror_url)
- print(" Representation.fetch_exception=%s|" % representation.fetch_exception)
- print(" Representation.mirror_exception=%s|" % representation.mirror_exception)
+ print(
+ " Representation.fetch_exception=%s|"
+ % representation.fetch_exception
+ )
+ print(
+ " Representation.mirror_exception=%s|"
+ % representation.mirror_exception
+ )
return
-
def _library(self, name=None, short_name=None):
- name=name or self._str
+ name = name or self._str
short_name = short_name or self._str
library, ignore = get_one_or_create(
- self._db, Library, name=name, short_name=short_name,
+ self._db,
+ Library,
+ name=name,
+ short_name=short_name,
create_method_kwargs=dict(uuid=str(uuid.uuid4())),
)
return library
- def _collection(self, name=None, protocol=ExternalIntegration.OPDS_IMPORT,
- external_account_id=None, url=None, username=None,
- password=None, data_source_name=None):
+ def _collection(
+ self,
+ name=None,
+ protocol=ExternalIntegration.OPDS_IMPORT,
+ external_account_id=None,
+ url=None,
+ username=None,
+ password=None,
+ data_source_name=None,
+ ):
name = name or self._str
- collection, ignore = get_one_or_create(
- self._db, Collection, name=name
- )
+ collection, ignore = get_one_or_create(self._db, Collection, name=name)
collection.external_account_id = external_account_id
integration = collection.create_external_integration(protocol)
integration.goal = ExternalIntegration.LICENSE_GOAL
@@ -964,7 +1063,7 @@ def _default_library(self):
By default, the `_default_collection` will be associated with
the default library.
"""
- if not hasattr(self, '_default__library'):
+ if not hasattr(self, "_default__library"):
self._default__library = self.make_default_library(self._db)
return self._default__library
@@ -978,7 +1077,7 @@ def _default_collection(self):
self._default_collection instead of calling self.collection()
saves time.
"""
- if not hasattr(self, '_default__collection'):
+ if not hasattr(self, "_default__collection"):
self._default__collection = self._default_library.collections[0]
return self._default__collection
@@ -990,10 +1089,13 @@ def make_default_library(cls, _db):
within a DatabaseTest subclass.
"""
library, ignore = get_one_or_create(
- _db, Library, create_method_kwargs=dict(
+ _db,
+ Library,
+ create_method_kwargs=dict(
uuid=str(uuid.uuid4()),
name="default",
- ), short_name="default"
+ ),
+ short_name="default",
)
collection, ignore = get_one_or_create(
_db, Collection, name="Default Collection"
@@ -1013,19 +1115,23 @@ def _integration_client(self, url=None, shared_secret=None):
url = url or self._url
secret = shared_secret or "secret"
return get_one_or_create(
- self._db, IntegrationClient, shared_secret=secret,
- create_method_kwargs=dict(url=url)
+ self._db,
+ IntegrationClient,
+ shared_secret=secret,
+ create_method_kwargs=dict(url=url),
)[0]
def _subject(self, type, identifier):
- return get_one_or_create(
- self._db, Subject, type=type, identifier=identifier
- )[0]
+ return get_one_or_create(self._db, Subject, type=type, identifier=identifier)[0]
def _classification(self, identifier, subject, data_source, weight=1):
return get_one_or_create(
- self._db, Classification, identifier=identifier, subject=subject,
- data_source=data_source, weight=weight
+ self._db,
+ Classification,
+ identifier=identifier,
+ subject=subject,
+ data_source=data_source,
+ weight=weight,
)[0]
def sample_cover_path(self, name):
@@ -1039,8 +1145,7 @@ def sample_cover_representation(self, name):
"""A Representation of the sample cover with the given filename."""
sample_cover_path = self.sample_cover_path(name)
return self._representation(
- media_type="image/png",
- content=open(sample_cover_path, 'rb').read()
+ media_type="image/png", content=open(sample_cover_path, "rb").read()
)[0]
@@ -1068,7 +1173,9 @@ class ExternalSearchTest(DatabaseTest):
to ensure that it works well overall, with a realistic index.
"""
- SIMPLIFIED_TEST_ELASTICSEARCH = os.environ.get('SIMPLIFIED_TEST_ELASTICSEARCH', 'http://localhost:9200')
+ SIMPLIFIED_TEST_ELASTICSEARCH = os.environ.get(
+ "SIMPLIFIED_TEST_ELASTICSEARCH", "http://localhost:9200"
+ )
def setup_method(self):
@@ -1083,9 +1190,9 @@ def setup_method(self):
goal=ExternalIntegration.SEARCH_GOAL,
url=self.SIMPLIFIED_TEST_ELASTICSEARCH,
settings={
- ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY : 'test_index',
- ExternalSearchIndex.TEST_SEARCH_TERM_KEY : 'test_search_term',
- }
+ ExternalSearchIndex.WORKS_INDEX_PREFIX_KEY: "test_index",
+ ExternalSearchIndex.TEST_SEARCH_TERM_KEY: "test_search_term",
+ },
)
try:
@@ -1094,7 +1201,7 @@ def setup_method(self):
self.search = None
logging.error(
"Unable to set up elasticsearch index, search tests will be skipped.",
- exc_info=e
+ exc_info=e,
)
def setup_index(self, new_index):
@@ -1106,9 +1213,7 @@ def teardown_method(self):
if self.search:
# Delete the works_index, which is almost always created.
if self.search.works_index:
- self.search.indices.delete(
- self.search.works_index, ignore=[404]
- )
+ self.search.indices.delete(self.search.works_index, ignore=[404])
# Delete any other indexes created over the course of the test.
for index in self.indexes:
self.search.indices.delete(index, ignore=[404])
@@ -1120,8 +1225,7 @@ def default_work(self, *args, **kwargs):
in the default collection.
"""
work = self._work(
- *args, with_license_pool=True,
- collection=self._default_collection, **kwargs
+ *args, with_license_pool=True, collection=self._default_collection, **kwargs
)
work.set_presentation_ready()
return work
@@ -1178,16 +1282,21 @@ def _assert_works(self, description, expect, actual, should_be_ordered=True):
expect_compare = set(expect_compare)
actual_compare = set(actual_compare)
- assert expect_compare == actual_compare, \
- "%r did not find %d works\n (%s/%s).\nInstead found %d\n (%s/%s)" % (
- description,
- len(expect), ", ".join(map(str, expect_ids)),
- ", ".join(expect_titles),
- len(actual), ", ".join(map(str, actual_ids)),
- ", ".join(actual_titles)
- )
+ assert (
+ expect_compare == actual_compare
+ ), "%r did not find %d works\n (%s/%s).\nInstead found %d\n (%s/%s)" % (
+ description,
+ len(expect),
+ ", ".join(map(str, expect_ids)),
+ ", ".join(expect_titles),
+ len(actual),
+ ", ".join(map(str, actual_ids)),
+ ", ".join(actual_titles),
+ )
- def _expect_results(self, expect, query_string=None, filter=None, pagination=None, **kwargs):
+ def _expect_results(
+ self, expect, query_string=None, filter=None, pagination=None, **kwargs
+ ):
"""Helper function to call query_works() and verify that it
returns certain work IDs.
@@ -1199,15 +1308,13 @@ def _expect_results(self, expect, query_string=None, filter=None, pagination=Non
"""
if isinstance(expect, Work):
expect = [expect]
- should_be_ordered = kwargs.pop('ordered', True)
+ should_be_ordered = kwargs.pop("ordered", True)
hits = self.search.query_works(
query_string, filter, pagination, debug=True, **kwargs
)
query_args = (query_string, filter, pagination)
- self._compare_hits(
- expect, hits, query_args, should_be_ordered, **kwargs
- )
+ self._compare_hits(expect, hits, query_args, should_be_ordered, **kwargs)
def _expect_results_multi(self, expect, queries, **kwargs):
"""Helper function to call query_works_multi() and verify that it
@@ -1223,22 +1330,16 @@ def _expect_results_multi(self, expect, queries, **kwargs):
then those exact results must come up, but their order is
not what's being tested.
"""
- should_be_ordered = kwargs.pop('ordered', True)
- resultset = list(
- self.search.query_works_multi(
- queries, debug=True, **kwargs
- )
- )
+ should_be_ordered = kwargs.pop("ordered", True)
+ resultset = list(self.search.query_works_multi(queries, debug=True, **kwargs))
for i, expect_one_query in enumerate(expect):
hits = resultset[i]
query_args = queries[i]
self._compare_hits(
- expect_one_query, hits, query_args,
- should_be_ordered, **kwargs
+ expect_one_query, hits, query_args, should_be_ordered, **kwargs
)
- def _compare_hits(self, expect, hits, query_args,
- should_be_ordered=True, **kwargs):
+ def _compare_hits(self, expect, hits, query_args, should_be_ordered=True, **kwargs):
query_string, filter, pagination = query_args
results = [x.work_id for x in hits]
actual = self._db.query(Work).filter(Work.id.in_(results)).all()
@@ -1249,8 +1350,7 @@ def _compare_hits(self, expect, hits, query_args,
for w in actual:
works_by_id[w.id] = w
actual = [
- works_by_id[result] for result in results
- if result in works_by_id
+ works_by_id[result] for result in results if result in works_by_id
]
query_args = (query_string, filter, pagination)
@@ -1267,6 +1367,7 @@ def _compare_hits(self, expect, hits, query_args,
class MockCoverageProvider(object):
"""Mixin class for mock CoverageProviders that defines common constants."""
+
SERVICE_NAME = "Generic mock CoverageProvider"
# Whenever a CoverageRecord is created, the data_source of that
@@ -1282,8 +1383,7 @@ class MockCoverageProvider(object):
PROTOCOL = ExternalIntegration.OPDS_IMPORT
-class InstrumentedCoverageProvider(MockCoverageProvider,
- IdentifierCoverageProvider):
+class InstrumentedCoverageProvider(MockCoverageProvider, IdentifierCoverageProvider):
"""A CoverageProvider that keeps track of every item it tried
to cover.
"""
@@ -1297,11 +1397,11 @@ def process_item(self, item):
return item
-class InstrumentedWorkCoverageProvider(MockCoverageProvider,
- WorkCoverageProvider):
+class InstrumentedWorkCoverageProvider(MockCoverageProvider, WorkCoverageProvider):
"""A WorkCoverageProvider that keeps track of every item it tried
to cover.
"""
+
def __init__(self, _db, *args, **kwargs):
super(InstrumentedWorkCoverageProvider, self).__init__(_db, *args, **kwargs)
self.attempts = []
@@ -1310,9 +1410,12 @@ def process_item(self, item):
self.attempts.append(item)
return item
-class AlwaysSuccessfulCollectionCoverageProvider(MockCoverageProvider,
- CollectionCoverageProvider):
+
+class AlwaysSuccessfulCollectionCoverageProvider(
+ MockCoverageProvider, CollectionCoverageProvider
+):
"""A CollectionCoverageProvider that does nothing and always succeeds."""
+
SERVICE_NAME = "Always successful (collection)"
def process_item(self, item):
@@ -1321,15 +1424,19 @@ def process_item(self, item):
class AlwaysSuccessfulCoverageProvider(InstrumentedCoverageProvider):
"""A CoverageProvider that does nothing and always succeeds."""
+
SERVICE_NAME = "Always successful"
+
class AlwaysSuccessfulWorkCoverageProvider(InstrumentedWorkCoverageProvider):
"""A WorkCoverageProvider that does nothing and always succeeds."""
+
SERVICE_NAME = "Always successful (works)"
class AlwaysSuccessfulBibliographicCoverageProvider(
- MockCoverageProvider, BibliographicCoverageProvider):
+ MockCoverageProvider, BibliographicCoverageProvider
+):
"""A BibliographicCoverageProvider that does nothing and is always
successful.
@@ -1337,6 +1444,7 @@ class AlwaysSuccessfulBibliographicCoverageProvider(
LicensePool in place beforehand. Otherwise the process will fail
during handle_success().
"""
+
SERVICE_NAME = "Always successful (bibliographic)"
def process_item(self, identifier):
@@ -1345,26 +1453,29 @@ def process_item(self, identifier):
class NeverSuccessfulCoverageProvider(InstrumentedCoverageProvider):
"""A CoverageProvider that does nothing and always fails."""
+
SERVICE_NAME = "Never successful"
def __init__(self, *args, **kwargs):
- super(NeverSuccessfulCoverageProvider, self).__init__(
- *args, **kwargs
- )
- self.transient = kwargs.get('transient') or False
+ super(NeverSuccessfulCoverageProvider, self).__init__(*args, **kwargs)
+ self.transient = kwargs.get("transient") or False
def process_item(self, item):
self.attempts.append(item)
return self.failure(item, "What did you expect?", self.transient)
+
class NeverSuccessfulWorkCoverageProvider(InstrumentedWorkCoverageProvider):
SERVICE_NAME = "Never successful (works)"
+
def process_item(self, item):
self.attempts.append(item)
return self.failure(item, "What did you expect?", False)
+
class NeverSuccessfulBibliographicCoverageProvider(
- MockCoverageProvider, BibliographicCoverageProvider):
+ MockCoverageProvider, BibliographicCoverageProvider
+):
"""Simulates a BibliographicCoverageProvider that's never successful."""
SERVICE_NAME = "Never successful (bibliographic)"
@@ -1375,40 +1486,48 @@ def process_item(self, identifier):
class BrokenCoverageProvider(InstrumentedCoverageProvider):
SERVICE_NAME = "Broken"
+
def process_item(self, item):
raise Exception("I'm too broken to even return a CoverageFailure.")
class BrokenBibliographicCoverageProvider(
- BrokenCoverageProvider, BibliographicCoverageProvider):
+ BrokenCoverageProvider, BibliographicCoverageProvider
+):
SERVICE_NAME = "Broken (bibliographic)"
class TransientFailureCoverageProvider(InstrumentedCoverageProvider):
SERVICE_NAME = "Never successful (transient)"
+
def process_item(self, item):
self.attempts.append(item)
return self.failure(item, "Oops!", True)
+
class TransientFailureWorkCoverageProvider(InstrumentedWorkCoverageProvider):
SERVICE_NAME = "Never successful (transient, works)"
+
def process_item(self, item):
self.attempts.append(item)
return self.failure(item, "Oops!", True)
+
class TaskIgnoringCoverageProvider(InstrumentedCoverageProvider):
"""A coverage provider that ignores all work given to it."""
+
SERVICE_NAME = "I ignore all work."
+
def process_batch(self, batch):
return []
-class DummyCanonicalizeLookupResponse(object):
+class DummyCanonicalizeLookupResponse(object):
@classmethod
def success(cls, result):
r = cls()
r.status_code = 200
- r.headers = { "Content-Type" : "text/plain" }
+ r.headers = {"Content-Type": "text/plain"}
r.content = result
return r
@@ -1418,26 +1537,26 @@ def failure(cls):
r.status_code = 404
return r
-class DummyMetadataClient(object):
+class DummyMetadataClient(object):
def __init__(self):
self.lookups = {}
def canonicalize_author_name(self, primary_identifier, display_author):
if display_author in self.lookups:
- return DummyCanonicalizeLookupResponse.success(
- self.lookups[display_author])
+ return DummyCanonicalizeLookupResponse.success(self.lookups[display_author])
else:
return DummyCanonicalizeLookupResponse.failure()
-class DummyHTTPClient(object):
+class DummyHTTPClient(object):
def __init__(self):
self.responses = []
self.requests = []
- def queue_response(self, response_code, media_type="text/html",
- other_headers=None, content=''):
+ def queue_response(
+ self, response_code, media_type="text/html", other_headers=None, content=""
+ ):
"""Queue a response of the type produced by
Representation.simple_http_get.
"""
@@ -1454,13 +1573,12 @@ def queue_response(self, response_code, media_type="text/html",
self.responses.append((response_code, headers, content))
def queue_requests_response(
- self, response_code, media_type="text/html",
- other_headers=None, content=''
+ self, response_code, media_type="text/html", other_headers=None, content=""
):
"""Queue a response of the type produced by HTTP.get_with_timeout."""
headers = dict(other_headers or {})
if media_type:
- headers['Content-Type'] = media_type
+ headers["Content-Type"] = media_type
response = MockRequestsResponse(response_code, headers, content)
self.responses.append(response)
@@ -1477,6 +1595,7 @@ class MockRequestsRequest(object):
"""A mock object that simulates an HTTP request from the
`requests` library.
"""
+
def __init__(self, url, method="GET", headers=None):
self.url = url
self.method = method
@@ -1487,9 +1606,8 @@ class MockRequestsResponse(object):
"""A mock object that simulates an HTTP response from the
`requests` library.
"""
- def __init__(
- self, status_code, headers={}, content=None, url=None, request=None
- ):
+
+ def __init__(self, status_code, headers={}, content=None, url=None, request=None):
self.status_code = status_code
self.headers = headers
# We want to enforce that the mocked content is a bytestring
@@ -1528,7 +1646,7 @@ def raise_for_status(self):
@pytest.fixture(autouse=True, scope="session")
def session_fixture():
# This will make sure we always connect to the test database.
- os.environ['TESTING'] = 'true'
+ os.environ["TESTING"] = "true"
# Ensure that the log configuration starts in a known state.
LogConfiguration.initialize(None, testing=True)
@@ -1540,8 +1658,8 @@ def session_fixture():
yield
- if 'TESTING' in os.environ:
- del os.environ['TESTING']
+ if "TESTING" in os.environ:
+ del os.environ["TESTING"]
def pytest_configure(config):
@@ -1549,6 +1667,4 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "elasticsearch: mark test as requiring elasticsearch"
)
- config.addinivalue_line(
- "markers", "minio: mark test as requiring minio"
- )
+ config.addinivalue_line("markers", "minio: mark test as requiring minio")
diff --git a/tests/classifiers/test_age.py b/tests/classifiers/test_age.py
index 62a854e3e..0dfb9bb04 100644
--- a/tests/classifiers/test_age.py
+++ b/tests/classifiers/test_age.py
@@ -1,17 +1,13 @@
-from ...classifier import (
- Classifier,
- AgeOrGradeClassifier,
- LCSHClassifier as LCSH,
-)
-
+from ...classifier import AgeOrGradeClassifier, Classifier
+from ...classifier import LCSHClassifier as LCSH
from ...classifier.age import (
+ AgeClassifier,
GradeLevelClassifier,
InterestLevelClassifier,
- AgeClassifier,
)
-class TestTargetAge(object):
+class TestTargetAge(object):
def test_range_tuple_swaps_mismatched_ages(self):
"""If for whatever reason a Classifier decides that something is from
ages 6 to 5, the Classifier.range_tuple() method will automatically
@@ -21,115 +17,125 @@ def test_range_tuple_swaps_mismatched_ages(self):
but if it does happen, range_tuple() will stop it from causing
downstream problems.
"""
- range1 = Classifier.range_tuple(5,6)
- range2 = Classifier.range_tuple(6,5)
+ range1 = Classifier.range_tuple(5, 6)
+ range2 = Classifier.range_tuple(6, 5)
assert range2 == range1
assert 5 == range2[0]
assert 6 == range2[1]
# If one of the target ages is None, it's left alone.
- r = Classifier.range_tuple(None,6)
+ r = Classifier.range_tuple(None, 6)
assert None == r[0]
assert 6 == r[1]
- r = Classifier.range_tuple(18,None)
+ r = Classifier.range_tuple(18, None)
assert 18 == r[0]
assert None == r[1]
-
def test_age_from_grade_classifier(self):
def f(t):
return GradeLevelClassifier.target_age(t, None)
- assert (
- Classifier.range_tuple(5,6) ==
- GradeLevelClassifier.target_age(None, "grades 0-1"))
- assert (4,7) == f("pk - 2")
- assert (5,7) == f("grades k-2")
- assert (6,6) == f("first grade")
- assert (6,6) == f("1st grade")
- assert (6,6) == f("grade 1")
- assert (7,7) == f("second grade")
- assert (7,7) == f("2nd grade")
- assert (8,8) == f("third grade")
- assert (9,9) == f("fourth grade")
- assert (10,10) == f("fifth grade")
- assert (11,11) == f("sixth grade")
- assert (12,12) == f("7th grade")
- assert (13,13) == f("grade 8")
- assert (14,14) == f("9th grade")
- assert (15,17) == f("grades 10-12")
- assert (6,6) == f("grades 00-01")
- assert (8,12) == f("grades 03-07")
- assert (8,12) == f("3-07")
- assert (8,10) == f("5 - 3")
- assert (17,17) == f("12th grade")
+
+ assert Classifier.range_tuple(5, 6) == GradeLevelClassifier.target_age(
+ None, "grades 0-1"
+ )
+ assert (4, 7) == f("pk - 2")
+ assert (5, 7) == f("grades k-2")
+ assert (6, 6) == f("first grade")
+ assert (6, 6) == f("1st grade")
+ assert (6, 6) == f("grade 1")
+ assert (7, 7) == f("second grade")
+ assert (7, 7) == f("2nd grade")
+ assert (8, 8) == f("third grade")
+ assert (9, 9) == f("fourth grade")
+ assert (10, 10) == f("fifth grade")
+ assert (11, 11) == f("sixth grade")
+ assert (12, 12) == f("7th grade")
+ assert (13, 13) == f("grade 8")
+ assert (14, 14) == f("9th grade")
+ assert (15, 17) == f("grades 10-12")
+ assert (6, 6) == f("grades 00-01")
+ assert (8, 12) == f("grades 03-07")
+ assert (8, 12) == f("3-07")
+ assert (8, 10) == f("5 - 3")
+ assert (17, 17) == f("12th grade")
# target_age() will assume that a number it sees is talking
# about a grade level, unless require_explicit_grade_marker is
# True.
- assert (14,17) == f("Children's Audio - 9-12")
- assert (7,9) == GradeLevelClassifier.target_age("2-4", None, False)
- assert (None,None) == GradeLevelClassifier.target_age("2-4", None, True)
- assert (None,None) == GradeLevelClassifier.target_age(
- "Children's Audio - 9-12", None, True)
-
- assert (None,None) == GradeLevelClassifier.target_age("grade 50", None)
- assert (None,None) == GradeLevelClassifier.target_age("road grades -- history", None)
- assert (None,None) == GradeLevelClassifier.target_age(None, None)
+ assert (14, 17) == f("Children's Audio - 9-12")
+ assert (7, 9) == GradeLevelClassifier.target_age("2-4", None, False)
+ assert (None, None) == GradeLevelClassifier.target_age("2-4", None, True)
+ assert (None, None) == GradeLevelClassifier.target_age(
+ "Children's Audio - 9-12", None, True
+ )
+
+ assert (None, None) == GradeLevelClassifier.target_age("grade 50", None)
+ assert (None, None) == GradeLevelClassifier.target_age(
+ "road grades -- history", None
+ )
+ assert (None, None) == GradeLevelClassifier.target_age(None, None)
def test_age_from_age_classifier(self):
def f(t):
return AgeClassifier.target_age(t, None)
- assert (9,12) == f("Ages 9-12")
- assert (9,13) == f("9 and up")
- assert (9,13) == f("9 and up.")
- assert (9,13) == f("9+")
- assert (9,13) == f("9+.")
- assert (None,None) == f("900-901")
- assert (9,12) == f("9-12")
- assert (9,9) == f("9 years")
- assert (9,12) == f("9 - 12 years")
- assert (12,14) == f("12 - 14")
- assert (12,14) == f("14 - 12")
- assert (0,3) == f("0-3")
- assert (5,8) == f("05 - 08")
- assert (None,None) == f("K-3")
+
+ assert (9, 12) == f("Ages 9-12")
+ assert (9, 13) == f("9 and up")
+ assert (9, 13) == f("9 and up.")
+ assert (9, 13) == f("9+")
+ assert (9, 13) == f("9+.")
+ assert (None, None) == f("900-901")
+ assert (9, 12) == f("9-12")
+ assert (9, 9) == f("9 years")
+ assert (9, 12) == f("9 - 12 years")
+ assert (12, 14) == f("12 - 14")
+ assert (12, 14) == f("14 - 12")
+ assert (0, 3) == f("0-3")
+ assert (5, 8) == f("05 - 08")
+ assert (None, None) == f("K-3")
assert (18, 18) == f("Age 18+")
# This could be improved but I've never actually seen a
# classification like this.
assert (16, 16) == f("up to age 16")
- assert (None,None) == AgeClassifier.target_age("K-3", None, True)
- assert (None,None) == AgeClassifier.target_age("9-12", None, True)
- assert (9,13) == AgeClassifier.target_age("9 and up", None, True)
- assert (7,9) == AgeClassifier.target_age("7 years and up.", None, True)
+ assert (None, None) == AgeClassifier.target_age("K-3", None, True)
+ assert (None, None) == AgeClassifier.target_age("9-12", None, True)
+ assert (9, 13) == AgeClassifier.target_age("9 and up", None, True)
+ assert (7, 9) == AgeClassifier.target_age("7 years and up.", None, True)
def test_age_from_keyword_classifier(self):
def f(t):
return LCSH.target_age(t, None)
- assert (5,5) == f("Interest age: from c 5 years")
- assert (9,12) == f("Children's Books / 9-12 Years")
- assert (9,12) == f("Ages 9-12")
- assert (9,12) == f("Age 9-12")
- assert (9,12) == f("Children's Books/Ages 9-12 Fiction")
- assert (4,8) == f("Children's Books / 4-8 Years")
- assert (0,2) == f("For children c 0-2 years")
- assert (12,14) == f("Children: Young Adult (Gr. 7-9)")
- assert (8,10) == f("Grades 3-5 (Common Core History: The Alexandria Plan)")
- assert (9,11) == f("Children: Grades 4-6")
-
- assert (0,3) == f("Baby-3 Years")
-
- assert (None,None) == f("Children's Audio - 9-12") # Doesn't specify grade or years
- assert (None,None) == f("Children's 9-12 - Literature - Classics / Contemporary")
- assert (None,None) == f("Third-graders")
- assert (None,None) == f("First graders")
- assert (None,None) == f("Fifth grade (Education)--Curricula")
+
+ assert (5, 5) == f("Interest age: from c 5 years")
+ assert (9, 12) == f("Children's Books / 9-12 Years")
+ assert (9, 12) == f("Ages 9-12")
+ assert (9, 12) == f("Age 9-12")
+ assert (9, 12) == f("Children's Books/Ages 9-12 Fiction")
+ assert (4, 8) == f("Children's Books / 4-8 Years")
+ assert (0, 2) == f("For children c 0-2 years")
+ assert (12, 14) == f("Children: Young Adult (Gr. 7-9)")
+ assert (8, 10) == f("Grades 3-5 (Common Core History: The Alexandria Plan)")
+ assert (9, 11) == f("Children: Grades 4-6")
+
+ assert (0, 3) == f("Baby-3 Years")
+
+ assert (None, None) == f(
+ "Children's Audio - 9-12"
+ ) # Doesn't specify grade or years
+ assert (None, None) == f(
+ "Children's 9-12 - Literature - Classics / Contemporary"
+ )
+ assert (None, None) == f("Third-graders")
+ assert (None, None) == f("First graders")
+ assert (None, None) == f("Fifth grade (Education)--Curricula")
def test_audience_from_age_classifier(self):
def f(t):
return AgeClassifier.audience(t, None)
+
assert Classifier.AUDIENCE_CHILDREN == f("Age 5")
assert Classifier.AUDIENCE_ADULT == f("Age 18+")
assert None == f("Ages Of Man")
@@ -142,23 +148,24 @@ def f(t):
def test_audience_from_age_or_grade_classifier(self):
def f(t):
return AgeOrGradeClassifier.audience(t, None)
- assert Classifier.AUDIENCE_CHILDREN == f(
- "Children's - Kindergarten, Age 5-6")
+
+ assert Classifier.AUDIENCE_CHILDREN == f("Children's - Kindergarten, Age 5-6")
def test_age_from_age_or_grade_classifier(self):
def f(t):
t = AgeOrGradeClassifier.scrub_identifier(t)
return AgeOrGradeClassifier.target_age(t, None)
- assert (5,6) == f("Children's - Kindergarten, Age 5-6")
- assert (5,5) == f("Children's - Kindergarten")
- assert (9,12) == f("Ages 9-12")
+ assert (5, 6) == f("Children's - Kindergarten, Age 5-6")
+ assert (5, 5) == f("Children's - Kindergarten")
+ assert (9, 12) == f("Ages 9-12")
-class TestInterestLevelClassifier(object):
+class TestInterestLevelClassifier(object):
def test_audience(self):
def f(t):
return InterestLevelClassifier.audience(t, None)
+
assert Classifier.AUDIENCE_CHILDREN == f("lg")
assert Classifier.AUDIENCE_CHILDREN == f("mg")
assert Classifier.AUDIENCE_CHILDREN == f("mg+")
@@ -167,7 +174,8 @@ def f(t):
def test_target_age(self):
def f(t):
return InterestLevelClassifier.target_age(t, None)
- assert (5,8) == f("lg")
- assert (9,13) == f("mg")
- assert (9,13) == f("mg+")
- assert (14,17) == f("ug")
+
+ assert (5, 8) == f("lg")
+ assert (9, 13) == f("mg")
+ assert (9, 13) == f("mg+")
+ assert (14, 17) == f("ug")
diff --git a/tests/classifiers/test_bic.py b/tests/classifiers/test_bic.py
index e2128b3dc..f2bcec0ac 100644
--- a/tests/classifiers/test_bic.py
+++ b/tests/classifiers/test_bic.py
@@ -2,8 +2,8 @@
from ...classifier import *
from ...classifier.bic import BICClassifier as BIC
-class TestBIC(object):
+class TestBIC(object):
def test_is_fiction(self):
def fic(bic):
return BIC.is_fiction(BIC.scrub_identifier(bic), None)
@@ -16,6 +16,7 @@ def fic(bic):
def test_audience(self):
young_adult = Classifier.AUDIENCE_YOUNG_ADULT
adult = Classifier.AUDIENCE_ADULT
+
def aud(bic):
return BIC.audience(BIC.scrub_identifier(bic), None)
@@ -25,19 +26,12 @@ def aud(bic):
def test_genre(self):
def gen(bic):
return BIC.genre(BIC.scrub_identifier(bic), None)
- assert (classifier.Art_Design ==
- gen("A"))
- assert (classifier.Art_Design ==
- gen("AB"))
- assert (classifier.Music ==
- gen("AV"))
- assert (classifier.Fantasy ==
- gen("FM"))
- assert (classifier.Economics ==
- gen("KC"))
- assert (classifier.Short_Stories ==
- gen("FYB"))
- assert (classifier.Music ==
- gen("YNC"))
- assert (classifier.European_History ==
- gen("HBJD"))
+
+ assert classifier.Art_Design == gen("A")
+ assert classifier.Art_Design == gen("AB")
+ assert classifier.Music == gen("AV")
+ assert classifier.Fantasy == gen("FM")
+ assert classifier.Economics == gen("KC")
+ assert classifier.Short_Stories == gen("FYB")
+ assert classifier.Music == gen("YNC")
+ assert classifier.European_History == gen("HBJD")
diff --git a/tests/classifiers/test_bisac.py b/tests/classifiers/test_bisac.py
index 132a0d313..f292fca9f 100644
--- a/tests/classifiers/test_bisac.py
+++ b/tests/classifiers/test_bisac.py
@@ -1,13 +1,11 @@
import re
import pytest
-from ...classifier import (
- BISACClassifier,
- Classifier,
-)
+
+from ...classifier import BISACClassifier, Classifier
from ...classifier.bisac import (
- MatchingRule,
RE,
+ MatchingRule,
anything,
fiction,
juvenile,
@@ -17,8 +15,8 @@
ya,
)
-class TestMatchingRule(object):
+class TestMatchingRule(object):
def test_registered_object_returned_on_match(self):
o = object()
rule = MatchingRule(o, "Fiction")
@@ -27,18 +25,16 @@ def test_registered_object_returned_on_match(self):
# You can't create a MatchingRule that returns None on
# match, since that's the value returned on non-match.
- pytest.raises(
- ValueError, MatchingRule, None, "Fiction"
- )
+ pytest.raises(ValueError, MatchingRule, None, "Fiction")
def test_string_match(self):
- rule = MatchingRule(True, 'Fiction')
+ rule = MatchingRule(True, "Fiction")
assert True == rule.match("fiction", "westerns")
assert None == rule.match("nonfiction", "westerns")
assert None == rule.match("all books", "fiction")
def test_regular_expression_match(self):
- rule = MatchingRule(True, RE('F.*O'))
+ rule = MatchingRule(True, RE("F.*O"))
assert True == rule.match("food")
assert True == rule.match("flapjacks and oatmeal")
assert None == rule.match("good", "food")
@@ -84,8 +80,10 @@ def test_fiction_match(self):
def test_anything_match(self):
# 'anything' can go up front.
- rule = MatchingRule(True, anything, 'Penguins')
- assert True == rule.match("juvenile fiction", "science fiction", "antarctica", "animals", "penguins")
+ rule = MatchingRule(True, anything, "Penguins")
+ assert True == rule.match(
+ "juvenile fiction", "science fiction", "antarctica", "animals", "penguins"
+ )
assert True == rule.match("fiction", "penguins")
assert True == rule.match("nonfiction", "penguins")
assert True == rule.match("penguins")
@@ -93,21 +91,25 @@ def test_anything_match(self):
# 'anything' can go in the middle, even after another special
# match rule.
- rule = MatchingRule(True, fiction, anything, 'Penguins')
- assert True == rule.match("juvenile fiction", "science fiction", "antarctica", "animals", "penguins")
+ rule = MatchingRule(True, fiction, anything, "Penguins")
+ assert True == rule.match(
+ "juvenile fiction", "science fiction", "antarctica", "animals", "penguins"
+ )
assert True == rule.match("fiction", "penguins")
assert None == rule.match("fiction", "geese")
# It's redundant, but 'anything' can go last.
- rule = MatchingRule(True, anything, 'Penguins', anything)
- assert True == rule.match("juvenile fiction", "science fiction", "antarctica", "animals", "penguins")
+ rule = MatchingRule(True, anything, "Penguins", anything)
+ assert True == rule.match(
+ "juvenile fiction", "science fiction", "antarctica", "animals", "penguins"
+ )
assert True == rule.match("fiction", "penguins", "more penguins")
assert True == rule.match("penguins")
assert None == rule.match("geese")
def test_something_match(self):
# 'something' can go anywhere.
- rule = MatchingRule(True, something, 'Penguins', something, something)
+ rule = MatchingRule(True, something, "Penguins", something, something)
assert True == rule.match("juvenile fiction", "penguins", "are", "great")
assert True == rule.match("penguins", "penguins", "i said", "penguins")
@@ -125,10 +127,14 @@ def __init__(self, identifier, name):
class TestBISACClassifier(object):
-
def _subject(self, identifier, name):
subject = MockSubject(identifier, name)
- subject.genre, subject.audience, subject.target_age, subject.fiction = BISACClassifier.classify(subject)
+ (
+ subject.genre,
+ subject.audience,
+ subject.target_age,
+ subject.fiction,
+ ) = BISACClassifier.classify(subject)
return subject
def genre_is(self, name, expect):
@@ -149,15 +155,11 @@ def test_every_rule_fires(self):
subjects.append(self._subject(identifier, name))
for i in BISACClassifier.FICTION:
if i.caught == []:
- raise Exception(
- "Fiction rule %s didn't catch anything!" % i.ruleset
- )
+ raise Exception("Fiction rule %s didn't catch anything!" % i.ruleset)
for i in BISACClassifier.GENRE:
if i.caught == []:
- raise Exception(
- "Genre rule %s didn't catch anything!" % i.ruleset
- )
+ raise Exception("Genre rule %s didn't catch anything!" % i.ruleset)
need_fiction = []
need_audience = []
@@ -170,8 +172,9 @@ def test_every_rule_fires(self):
# We determined fiction/nonfiction status for every BISAC
# subject except for humor, drama, and poetry.
for subject in need_fiction:
- assert any(subject.name.lower().startswith(x)
- for x in ['humor', 'drama', 'poetry'])
+ assert any(
+ subject.name.lower().startswith(x) for x in ["humor", "drama", "poetry"]
+ )
# We determined the target audience for every BISAC subject.
assert [] == need_audience
@@ -202,16 +205,24 @@ def test_genre_spot_checks(self):
genre_is("Fiction / African American / Urban", "Urban Fiction")
genre_is("Fiction / Urban", None)
genre_is("History / Native American", "United States History")
- genre_is("History / Modern / 17th Century", "Renaissance & Early Modern History")
+ genre_is(
+ "History / Modern / 17th Century", "Renaissance & Early Modern History"
+ )
genre_is("Biography & Autobiography / Composers & Musicians", "Music"),
- genre_is("Biography & Autobiography / Entertainment & Performing Arts", "Entertainment"),
+ genre_is(
+ "Biography & Autobiography / Entertainment & Performing Arts",
+ "Entertainment",
+ ),
genre_is("Fiction / Christian", "Religious Fiction"),
genre_is("Juvenile Nonfiction / Science & Nature / Fossils", "Nature")
genre_is("Juvenile Nonfiction / Science & Nature / Physics", "Science")
genre_is("Juvenile Nonfiction / Science & Nature / General", "Science")
genre_is("Juvenile Fiction / Social Issues / General", "Life Strategies")
genre_is("Juvenile Nonfiction / Social Issues / Pregnancy", "Life Strategies")
- genre_is("Juvenile Nonfiction / Religious / Christian / Social Issues", "Christianity")
+ genre_is(
+ "Juvenile Nonfiction / Religious / Christian / Social Issues",
+ "Christianity",
+ )
genre_is("Young Adult Fiction / Zombies", "Horror")
genre_is("Young Adult Fiction / Superheroes", "Suspense/Thriller")
@@ -224,9 +235,7 @@ def test_genre_spot_checks(self):
# Grandfathered in from an older test to validate that the new
# BISAC algorithm gives the same results as the old one.
genre_is("JUVENILE FICTION / Dystopian", "Dystopian SF")
- genre_is("JUVENILE FICTION / Stories in Verse (see also Poetry)",
- "Poetry")
-
+ genre_is("JUVENILE FICTION / Stories in Verse (see also Poetry)", "Poetry")
def test_deprecated_bisac_terms(self):
"""These BISAC terms have been deprecated. We classify them
@@ -237,7 +246,7 @@ def test_deprecated_bisac_terms(self):
self.genre_is("Technology / Fire", "Technology")
self.genre_is(
"Young Adult Nonfiction / Social Situations / Junior Prom",
- "Life Strategies"
+ "Life Strategies",
)
def test_non_bisac_classified_as_keywords(self):
@@ -281,9 +290,7 @@ def fiction_is(name, expect):
fiction_is("YOUNG ADULT FICTION / Lifestyles / Country Life", True)
fiction_is("HISTORY / General", False)
-
def test_audience_spot_checks(self):
-
def audience_is(name, expect):
subject = self._subject("", name)
assert expect == subject.audience
@@ -305,23 +312,18 @@ def audience_is(name, expect):
audience_is("YOUNG ADULT FICTION / Action & Adventure / General", ya)
def test_target_age_spot_checks(self):
-
def target_age_is(name, expect):
subject = self._subject("", name)
assert expect == subject.target_age
# These are the only BISAC classifications with implied target
# ages.
- for check in ('Fiction', 'Nonfiction'):
- target_age_is("Juvenile %s / Readers / Beginner" % check,
- (0,4))
- target_age_is("Juvenile %s / Readers / Intermediate" % check,
- (5,7))
- target_age_is("Juvenile %s / Readers / Chapter Books" % check,
- (8,13))
+ for check in ("Fiction", "Nonfiction"):
+ target_age_is("Juvenile %s / Readers / Beginner" % check, (0, 4))
+ target_age_is("Juvenile %s / Readers / Intermediate" % check, (5, 7))
+ target_age_is("Juvenile %s / Readers / Chapter Books" % check, (8, 13))
target_age_is(
- "Juvenile %s / Religious / Christian / Early Readers" % check,
- (5,7)
+ "Juvenile %s / Religious / Christian / Early Readers" % check, (5, 7)
)
# In all other cases, the classifier will fall back to the
@@ -329,8 +331,7 @@ def target_age_is(name, expect):
target_age_is("Fiction / Science Fiction / Erotica", (18, None))
target_age_is("Fiction / Science Fiction", (18, None))
target_age_is("Juvenile Fiction / Science Fiction", (None, None))
- target_age_is("Young Adult Fiction / Science Fiction / General",
- (14, 17))
+ target_age_is("Young Adult Fiction / Science Fiction / General", (14, 17))
def test_feedbooks_bisac(self):
"""Feedbooks uses a system based on BISAC but with different
@@ -356,25 +357,36 @@ def test_scrub_identifier(self):
# the canonical name is also returned. This will override
# any other name associated with the subject for classification
# purposes.
- assert (("FIC015000", "Fiction / Horror") ==
- BISACClassifier.scrub_identifier("FBFIC015000"))
+ assert ("FIC015000", "Fiction / Horror") == BISACClassifier.scrub_identifier(
+ "FBFIC015000"
+ )
def test_scrub_name(self):
"""Sometimes a data provider sends BISAC names that contain extra or
nonstandard characters. We store the data as it was provided to us,
but when it's time to classify things, we normalize it.
"""
+
def scrubbed(before, after):
assert after == BISACClassifier.scrub_name(before)
- scrubbed("ART/Collections Catalogs Exhibitions/",
- ["art", "collections, catalogs, exhibitions"])
- scrubbed("ARCHITECTURE|History|Contemporary|",
- ["architecture", "history", "contemporary"])
- scrubbed("BIOGRAPHY & AUTOBIOGRAPHY / Editors, Journalists, Publishers",
- ["biography & autobiography", "editors, journalists, publishers"])
- scrubbed("EDUCATION/Teaching Methods & Materials/Arts & Humanities */",
- ["education", "teaching methods & materials",
- "arts & humanities"])
- scrubbed("JUVENILE FICTION / Family / General (see also headings under Social Issues)",
- ["juvenile fiction", "family", "general"])
+ scrubbed(
+ "ART/Collections Catalogs Exhibitions/",
+ ["art", "collections, catalogs, exhibitions"],
+ )
+ scrubbed(
+ "ARCHITECTURE|History|Contemporary|",
+ ["architecture", "history", "contemporary"],
+ )
+ scrubbed(
+ "BIOGRAPHY & AUTOBIOGRAPHY / Editors, Journalists, Publishers",
+ ["biography & autobiography", "editors, journalists, publishers"],
+ )
+ scrubbed(
+ "EDUCATION/Teaching Methods & Materials/Arts & Humanities */",
+ ["education", "teaching methods & materials", "arts & humanities"],
+ )
+ scrubbed(
+ "JUVENILE FICTION / Family / General (see also headings under Social Issues)",
+ ["juvenile fiction", "family", "general"],
+ )
diff --git a/tests/classifiers/test_classifier.py b/tests/classifiers/test_classifier.py
index bbf003ebc..002672b10 100644
--- a/tests/classifiers/test_classifier.py
+++ b/tests/classifiers/test_classifier.py
@@ -1,46 +1,37 @@
"""Test logic surrounding classification schemes."""
-from ...testing import DatabaseTest
from collections import Counter
+
from psycopg2.extras import NumericRange
-from ...model import (
- Genre,
- DataSource,
- Subject,
- Classification,
-)
from ... import classifier
from ...classifier import (
- Classifier,
- Lowercased,
- WorkClassifier,
- Lowercased,
- fiction_genres,
- nonfiction_genres,
- GenreData,
- FreeformAudienceClassifier,
- )
-
+ Classifier,
+ FreeformAudienceClassifier,
+ GenreData,
+ Lowercased,
+ WorkClassifier,
+ fiction_genres,
+ nonfiction_genres,
+)
from ...classifier.age import (
AgeClassifier,
GradeLevelClassifier,
InterestLevelClassifier,
)
from ...classifier.ddc import DeweyDecimalClassifier as DDC
-from ...classifier.keyword import (
- LCSHClassifier as LCSH,
- FASTClassifier as FAST,
-)
+from ...classifier.keyword import FASTClassifier as FAST
+from ...classifier.keyword import LCSHClassifier as LCSH
from ...classifier.lcc import LCCClassifier as LCC
from ...classifier.simplified import SimplifiedGenreClassifier
+from ...model import Classification, DataSource, Genre, Subject
+from ...testing import DatabaseTest
genres = dict()
GenreData.populate(globals(), genres, fiction_genres, nonfiction_genres)
class TestLowercased(object):
-
def test_constructor(self):
l = Lowercased("A string")
@@ -64,7 +55,6 @@ def test_constructor(self):
class TestGenreData(object):
-
def test_fiction_default(self):
# In general, genres are restricted to either fiction or
# nonfiction.
@@ -73,21 +63,20 @@ def test_fiction_default(self):
class TestClassifier(object):
-
def test_default_target_age_for_audience(self):
- assert (
- (None, None) ==
- Classifier.default_target_age_for_audience(Classifier.AUDIENCE_CHILDREN))
- assert (
- (14, 17) ==
- Classifier.default_target_age_for_audience(Classifier.AUDIENCE_YOUNG_ADULT))
- assert (
- (18, None) ==
- Classifier.default_target_age_for_audience(Classifier.AUDIENCE_ADULT))
- assert (
- (18, None) ==
- Classifier.default_target_age_for_audience(Classifier.AUDIENCE_ADULTS_ONLY))
+ assert (None, None) == Classifier.default_target_age_for_audience(
+ Classifier.AUDIENCE_CHILDREN
+ )
+ assert (14, 17) == Classifier.default_target_age_for_audience(
+ Classifier.AUDIENCE_YOUNG_ADULT
+ )
+ assert (18, None) == Classifier.default_target_age_for_audience(
+ Classifier.AUDIENCE_ADULT
+ )
+ assert (18, None) == Classifier.default_target_age_for_audience(
+ Classifier.AUDIENCE_ADULTS_ONLY
+ )
def test_default_audience_for_target_age(self):
def aud(low, high, expect):
@@ -119,6 +108,7 @@ def aud(low, high, expect):
def test_and_up(self):
"""Test the code that determines what "x and up" actually means."""
+
def u(young, keyword):
return Classifier.and_up(young, keyword)
@@ -132,7 +122,6 @@ def u(young, keyword):
assert 17 == u(14, "14+.")
assert 18 == u(18, "18+")
-
def test_scrub_identifier_can_override_name(self):
"""Test the ability of scrub_identifier to override the name
of the subject for classification purposes.
@@ -140,12 +129,14 @@ def test_scrub_identifier_can_override_name(self):
This is used e.g. in the BISACClassifier to ensure that a known BISAC
code is always mapped to its canonical name.
"""
+
class SetsNameForOneIdentifier(Classifier):
"A Classifier that insists on a certain name for one specific identifier"
+
@classmethod
def scrub_identifier(self, identifier):
- if identifier == 'A':
- return ('A', 'Use this name!')
+ if identifier == "A":
+ return ("A", "Use this name!")
else:
return identifier
@@ -172,7 +163,6 @@ def test_scrub_name(self):
class TestClassifierLookup(object):
-
def test_lookup(self):
assert DDC == Classifier.lookup(Classifier.DDC)
assert LCC == Classifier.lookup(Classifier.LCC)
@@ -181,15 +171,14 @@ def test_lookup(self):
assert GradeLevelClassifier == Classifier.lookup(Classifier.GRADE_LEVEL)
assert AgeClassifier == Classifier.lookup(Classifier.AGE_RANGE)
assert InterestLevelClassifier == Classifier.lookup(Classifier.INTEREST_LEVEL)
- assert None == Classifier.lookup('no-such-key')
+ assert None == Classifier.lookup("no-such-key")
-class TestNestedSubgenres(object):
+class TestNestedSubgenres(object):
def test_parents(self):
- assert ([classifier.Romance] ==
- list(classifier.Romantic_Suspense.parents))
+ assert [classifier.Romance] == list(classifier.Romantic_Suspense.parents)
- #eq_([classifier.Crime_Thrillers_Mystery, classifier.Mystery],
+ # eq_([classifier.Crime_Thrillers_Mystery, classifier.Mystery],
# list(classifier.Police_Procedurals.parents))
def test_self_and_subgenres(self):
@@ -198,13 +187,19 @@ def test_self_and_subgenres(self):
# - Historical Fantasy
# - Urban Fantasy
assert (
- set([classifier.Fantasy, classifier.Epic_Fantasy,
- classifier.Historical_Fantasy, classifier.Urban_Fantasy,
- ]) ==
- set(list(classifier.Fantasy.self_and_subgenres)))
+ set(
+ [
+ classifier.Fantasy,
+ classifier.Epic_Fantasy,
+ classifier.Historical_Fantasy,
+ classifier.Urban_Fantasy,
+ ]
+ )
+ == set(list(classifier.Fantasy.self_and_subgenres))
+ )
-class TestConsolidateWeights(object):
+class TestConsolidateWeights(object):
def test_consolidate(self):
# Asian History is a subcategory of the top-level category History.
weights = dict()
@@ -258,37 +253,44 @@ def test_consolidate_fails_when_threshold_not_met(self):
assert 100 == w2[classifier.History]
assert 1 == w2[classifier.Middle_East_History]
+
class TestFreeformAudienceClassifier(DatabaseTest):
def test_audience(self):
def audience(aud):
# The second param, `name`, is not used in the audience method
return FreeformAudienceClassifier.audience(aud, None)
- for val in ['children', 'pre-adolescent', 'beginning reader']:
+ for val in ["children", "pre-adolescent", "beginning reader"]:
assert Classifier.AUDIENCE_CHILDREN == audience(val)
- for val in ['young adult', 'ya', 'teenagers', 'adolescent', 'early adolescents']:
+ for val in [
+ "young adult",
+ "ya",
+ "teenagers",
+ "adolescent",
+ "early adolescents",
+ ]:
assert Classifier.AUDIENCE_YOUNG_ADULT == audience(val)
- assert audience('adult') == Classifier.AUDIENCE_ADULT
- assert audience('adults only') == Classifier.AUDIENCE_ADULTS_ONLY
- assert audience('all ages') == Classifier.AUDIENCE_ALL_AGES
- assert audience('research') == Classifier.AUDIENCE_RESEARCH
+ assert audience("adult") == Classifier.AUDIENCE_ADULT
+ assert audience("adults only") == Classifier.AUDIENCE_ADULTS_ONLY
+ assert audience("all ages") == Classifier.AUDIENCE_ALL_AGES
+ assert audience("research") == Classifier.AUDIENCE_RESEARCH
- assert audience('books for all ages') == None
+ assert audience("books for all ages") == None
def test_target_age(self):
def target_age(age):
return FreeformAudienceClassifier.target_age(age, None)
- assert target_age('beginning reader') == (5, 8)
- assert target_age('pre-adolescent') == (9, 12)
- assert target_age('all ages') == (Classifier.ALL_AGES_AGE_CUTOFF, None)
+ assert target_age("beginning reader") == (5, 8)
+ assert target_age("pre-adolescent") == (9, 12)
+ assert target_age("all ages") == (Classifier.ALL_AGES_AGE_CUTOFF, None)
- assert target_age('babies') == (None, None)
+ assert target_age("babies") == (None, None)
-class TestWorkClassifier(DatabaseTest):
+class TestWorkClassifier(DatabaseTest):
def setup_method(self):
super(TestWorkClassifier, self).setup_method()
self.work = self._work(with_license_pool=True)
@@ -375,28 +377,28 @@ def test_no_children_or_ya_signal_from_distributor_implies_book_is_for_adults(se
# to 500.
i = self.identifier
source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- for subject in ('Nonfiction', 'Science Fiction', 'History'):
+ for subject in ("Nonfiction", "Science Fiction", "History"):
c = i.classify(source, Subject.OVERDRIVE, subject, weight=1000)
self.classifier.add(c)
# There's a little bit of evidence that it's a children's book,
# but not enough to outweight the distributor's silence.
- c2 = self.identifier.classify(
- source, Subject.TAG, "Children's books", weight=1
- )
+ c2 = self.identifier.classify(source, Subject.TAG, "Children's books", weight=1)
self.classifier.add(c2)
self.classifier.prepare_to_classify()
# Overdrive classifications are regarded as 50 times more reliable
# than their actual weight, as per Classification.scaled_weight
assert 50000 == self.classifier.audience_weights[Classifier.AUDIENCE_ADULT]
- def test_adults_only_indication_from_distributor_has_no_implication_for_audience(self):
+ def test_adults_only_indication_from_distributor_has_no_implication_for_audience(
+ self,
+ ):
# Create some classifications that end up in
# direct_from_license_source, one of which implies the book is
# for adults only.
i = self.identifier
source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- for subject in ('Erotic Literature', 'Science Fiction', 'History'):
+ for subject in ("Erotic Literature", "Science Fiction", "History"):
c = i.classify(source, Subject.OVERDRIVE, subject, weight=1)
self.classifier.add(c)
@@ -420,12 +422,16 @@ def test_no_signal_from_distributor_has_no_implication_for_audience(self):
# distributor.
assert {} == self.classifier.audience_weights
- def test_children_or_ya_signal_from_distributor_has_no_immediate_implication_for_audience(self):
+ def test_children_or_ya_signal_from_distributor_has_no_immediate_implication_for_audience(
+ self,
+ ):
# This work has a classification direct from the distributor
# that implies the book is for children, so no conclusions are
# drawn in the prepare_to_classify() step.
source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- c = self.identifier.classify(source, Subject.OVERDRIVE, "Picture Books", weight=1000)
+ c = self.identifier.classify(
+ source, Subject.OVERDRIVE, "Picture Books", weight=1000
+ )
self.classifier.prepare_to_classify()
assert {} == self.classifier.audience_weights
@@ -444,10 +450,7 @@ def test_juvenile_classification_is_split_between_children_and_ya(self):
# (This classification has no bearing on audience and its
# weight will be ignored.)
- c2 = i.classify(
- source, Subject.TAG, "Pets",
- weight=1000
- )
+ c2 = i.classify(source, Subject.TAG, "Pets", weight=1000)
self.classifier.add(c2)
self.classifier.prepare_to_classify
genres, fiction, audience, target_age = self.classifier.classify()
@@ -481,9 +484,9 @@ def test_childrens_book_when_evidence_is_overwhelming(self):
# The evidence that this is a children's book is strong but
# not overwhelming.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 10,
- Classifier.AUDIENCE_ADULTS_ONLY : 1,
- Classifier.AUDIENCE_CHILDREN : 22,
+ Classifier.AUDIENCE_ADULT: 10,
+ Classifier.AUDIENCE_ADULTS_ONLY: 1,
+ Classifier.AUDIENCE_CHILDREN: 22,
}
assert Classifier.AUDIENCE_ADULT == self.classifier.audience()
@@ -505,73 +508,70 @@ def test_ya_book_when_childrens_and_ya_combined_beat_adult(self):
# but it's more accurate than 'adult' and less likely to be
# a costly mistake than 'children'.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 9,
- Classifier.AUDIENCE_ADULTS_ONLY : 0,
- Classifier.AUDIENCE_CHILDREN : 10,
- Classifier.AUDIENCE_YOUNG_ADULT : 9,
+ Classifier.AUDIENCE_ADULT: 9,
+ Classifier.AUDIENCE_ADULTS_ONLY: 0,
+ Classifier.AUDIENCE_CHILDREN: 10,
+ Classifier.AUDIENCE_YOUNG_ADULT: 9,
}
assert Classifier.AUDIENCE_YOUNG_ADULT == self.classifier.audience()
def test_genre_may_restrict_audience(self):
# The audience info says this is a YA book.
- self.classifier.audience_weights = {
- Classifier.AUDIENCE_YOUNG_ADULT : 1000
- }
+ self.classifier.audience_weights = {Classifier.AUDIENCE_YOUNG_ADULT: 1000}
# Without any genre information, it's classified as YA.
assert Classifier.AUDIENCE_YOUNG_ADULT == self.classifier.audience()
# But if it's Erotica, it is always classified as Adults Only.
- genres = { classifier.Erotica : 50,
- classifier.Science_Fiction: 50}
+ genres = {classifier.Erotica: 50, classifier.Science_Fiction: 50}
assert Classifier.AUDIENCE_ADULTS_ONLY == self.classifier.audience(genres)
-
+
def test_all_ages_audience(self):
# If the All Ages weight is more than the total adult weight and
# the total juvenile weight, then assign all ages as the audience.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 50,
- Classifier.AUDIENCE_ADULTS_ONLY : 30,
- Classifier.AUDIENCE_ALL_AGES : 100,
- Classifier.AUDIENCE_CHILDREN : 30,
- Classifier.AUDIENCE_YOUNG_ADULT : 40,
+ Classifier.AUDIENCE_ADULT: 50,
+ Classifier.AUDIENCE_ADULTS_ONLY: 30,
+ Classifier.AUDIENCE_ALL_AGES: 100,
+ Classifier.AUDIENCE_CHILDREN: 30,
+ Classifier.AUDIENCE_YOUNG_ADULT: 40,
}
assert Classifier.AUDIENCE_ALL_AGES == self.classifier.audience()
# This works even if 'Children' looks much better than 'Adult'.
# 'All Ages' looks even better than that, so it wins.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 1,
- Classifier.AUDIENCE_ADULTS_ONLY : 0,
- Classifier.AUDIENCE_ALL_AGES : 1000,
- Classifier.AUDIENCE_CHILDREN : 30,
- Classifier.AUDIENCE_YOUNG_ADULT : 29,
+ Classifier.AUDIENCE_ADULT: 1,
+ Classifier.AUDIENCE_ADULTS_ONLY: 0,
+ Classifier.AUDIENCE_ALL_AGES: 1000,
+ Classifier.AUDIENCE_CHILDREN: 30,
+ Classifier.AUDIENCE_YOUNG_ADULT: 29,
}
assert Classifier.AUDIENCE_ALL_AGES == self.classifier.audience()
# If the All Ages weight is smaller than the total adult weight,
# the audience is adults.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 70,
- Classifier.AUDIENCE_ADULTS_ONLY : 10,
- Classifier.AUDIENCE_ALL_AGES : 79,
- Classifier.AUDIENCE_CHILDREN : 30,
- Classifier.AUDIENCE_YOUNG_ADULT : 40,
+ Classifier.AUDIENCE_ADULT: 70,
+ Classifier.AUDIENCE_ADULTS_ONLY: 10,
+ Classifier.AUDIENCE_ALL_AGES: 79,
+ Classifier.AUDIENCE_CHILDREN: 30,
+ Classifier.AUDIENCE_YOUNG_ADULT: 40,
}
assert Classifier.AUDIENCE_ADULT == self.classifier.audience()
-
+
def test_research_audience(self):
# If the research weight is larger than the total adult weight +
# all ages weight and larger than the total juvenile weight +
# all ages weight, then assign research as the audience
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 50,
- Classifier.AUDIENCE_ADULTS_ONLY : 30,
- Classifier.AUDIENCE_ALL_AGES : 10,
- Classifier.AUDIENCE_CHILDREN : 30,
- Classifier.AUDIENCE_YOUNG_ADULT : 150,
- Classifier.AUDIENCE_RESEARCH : 200,
+ Classifier.AUDIENCE_ADULT: 50,
+ Classifier.AUDIENCE_ADULTS_ONLY: 30,
+ Classifier.AUDIENCE_ALL_AGES: 10,
+ Classifier.AUDIENCE_CHILDREN: 30,
+ Classifier.AUDIENCE_YOUNG_ADULT: 150,
+ Classifier.AUDIENCE_RESEARCH: 200,
}
assert Classifier.AUDIENCE_RESEARCH == self.classifier.audience()
@@ -579,16 +579,15 @@ def test_research_audience(self):
# and all ages weight or total juvenile weight and all ages weight,
# then we get those audience values instead.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 80,
- Classifier.AUDIENCE_ADULTS_ONLY : 10,
- Classifier.AUDIENCE_ALL_AGES : 20,
- Classifier.AUDIENCE_CHILDREN : 35,
- Classifier.AUDIENCE_YOUNG_ADULT : 40,
- Classifier.AUDIENCE_RESEARCH : 100,
+ Classifier.AUDIENCE_ADULT: 80,
+ Classifier.AUDIENCE_ADULTS_ONLY: 10,
+ Classifier.AUDIENCE_ALL_AGES: 20,
+ Classifier.AUDIENCE_CHILDREN: 35,
+ Classifier.AUDIENCE_YOUNG_ADULT: 40,
+ Classifier.AUDIENCE_RESEARCH: 100,
}
assert Classifier.AUDIENCE_ADULT == self.classifier.audience()
-
def test_format_classification_from_license_source_is_used(self):
# This book will be classified as a comic book, because
# the "comic books" classification comes from its license source.
@@ -614,11 +613,11 @@ def test_childrens_book_when_no_evidence_for_adult_book(self):
# buckets, so minimal evidence in the 'children' bucket is
# sufficient to be confident.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 0,
- Classifier.AUDIENCE_ADULTS_ONLY : 0,
- Classifier.AUDIENCE_CHILDREN : 1,
- Classifier.AUDIENCE_RESEARCH : 0,
- Classifier.AUDIENCE_ALL_AGES : 0,
+ Classifier.AUDIENCE_ADULT: 0,
+ Classifier.AUDIENCE_ADULTS_ONLY: 0,
+ Classifier.AUDIENCE_CHILDREN: 1,
+ Classifier.AUDIENCE_RESEARCH: 0,
+ Classifier.AUDIENCE_ALL_AGES: 0,
}
assert Classifier.AUDIENCE_CHILDREN == self.classifier.audience()
@@ -627,11 +626,11 @@ def test_adults_only_threshold(self):
# majority, but it's high enough that we classify this work as
# 'adults only' to be safe.
self.classifier.audience_weights = {
- Classifier.AUDIENCE_ADULT : 4,
- Classifier.AUDIENCE_ADULTS_ONLY : 2,
- Classifier.AUDIENCE_CHILDREN : 4,
- Classifier.AUDIENCE_RESEARCH : 0,
- Classifier.AUDIENCE_ALL_AGES : 0,
+ Classifier.AUDIENCE_ADULT: 4,
+ Classifier.AUDIENCE_ADULTS_ONLY: 2,
+ Classifier.AUDIENCE_CHILDREN: 4,
+ Classifier.AUDIENCE_RESEARCH: 0,
+ Classifier.AUDIENCE_ALL_AGES: 0,
}
assert Classifier.AUDIENCE_ADULTS_ONLY == self.classifier.audience()
@@ -659,9 +658,7 @@ def test_target_age_weight_scaling(self):
# We have a louder but less reliable signal that this is a
# book for eleven-year-olds.
oclc = DataSource.lookup(self._db, DataSource.OCLC)
- c2 = self.identifier.classify(
- oclc, Subject.TAG, "Grade 6", weight=3
- )
+ c2 = self.identifier.classify(oclc, Subject.TAG, "Grade 6", weight=3)
self.classifier.add(c2)
# Both signals make it into the dataset, but they are weighted
@@ -675,9 +672,7 @@ def test_target_age_weight_scaling(self):
# And this affects the target age we choose.
a = self.classifier.target_age(Classifier.AUDIENCE_CHILDREN)
- assert (
- (5,8) ==
- self.classifier.target_age(Classifier.AUDIENCE_CHILDREN))
+ assert (5, 8) == self.classifier.target_age(Classifier.AUDIENCE_CHILDREN)
def test_target_age_errs_towards_wider_span(self):
i = self._identifier()
@@ -686,8 +681,9 @@ def test_target_age_errs_towards_wider_span(self):
c2 = i.classify(source, Subject.AGE_RANGE, "6-7", weight=1)
overdrive_edition, lp = self._edition(
- data_source_name=source.name, with_license_pool=True,
- identifier_id=i.identifier
+ data_source_name=source.name,
+ with_license_pool=True,
+ identifier_id=i.identifier,
)
self.classifier.work = self._work(presentation_edition=overdrive_edition)
for classification in i.classifications:
@@ -695,7 +691,7 @@ def test_target_age_errs_towards_wider_span(self):
genres, fiction, audience, target_age = self.classifier.classify()
assert Classifier.AUDIENCE_CHILDREN == audience
- assert (6,9) == target_age
+ assert (6, 9) == target_age
def test_fiction_status_restricts_genre(self):
# Classify a book to imply that it's 50% science fiction and
@@ -753,8 +749,7 @@ def test_overdrive_juvenile_implicit_target_age(self):
# target age range of 9-12.
i = self.identifier
source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- c = i.classify(source, Subject.OVERDRIVE, "Juvenile Fiction",
- weight=1)
+ c = i.classify(source, Subject.OVERDRIVE, "Juvenile Fiction", weight=1)
self.classifier.add(c)
self.classifier.prepare_to_classify()
assert [9] == list(self.classifier.target_age_lower_weights.keys())
@@ -810,29 +805,43 @@ def test_classify_sets_minimum_age_high_if_minimum_lower_than_maximum(self):
assert 10 == target_age[1]
def test_classify_uses_default_fiction_status(self):
- genres, fiction, audience, target_age = self.classifier.classify(default_fiction=True)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_fiction=True
+ )
assert True == fiction
- genres, fiction, audience, target_age = self.classifier.classify(default_fiction=False)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_fiction=False
+ )
assert False == fiction
- genres, fiction, audience, target_age = self.classifier.classify(default_fiction=None)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_fiction=None
+ )
assert None == fiction
# The default isn't used if there's any information about the fiction status.
self.classifier.fiction_weights[False] = 1
- genres, fiction, audience, target_age = self.classifier.classify(default_fiction=None)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_fiction=None
+ )
assert False == fiction
def test_classify_uses_default_audience(self):
genres, fiction, audience, target_age = self.classifier.classify()
assert None == audience
- genres, fiction, audience, target_age = self.classifier.classify(default_audience=Classifier.AUDIENCE_ADULT)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_audience=Classifier.AUDIENCE_ADULT
+ )
assert Classifier.AUDIENCE_ADULT == audience
- genres, fiction, audience, target_age = self.classifier.classify(default_audience=Classifier.AUDIENCE_CHILDREN)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_audience=Classifier.AUDIENCE_CHILDREN
+ )
assert Classifier.AUDIENCE_CHILDREN == audience
# The default isn't used if there's any information about the audience.
self.classifier.audience_weights[Classifier.AUDIENCE_ADULT] = 1
- genres, fiction, audience, target_age = self.classifier.classify(default_audience=None)
+ genres, fiction, audience, target_age = self.classifier.classify(
+ default_audience=None
+ )
assert Classifier.AUDIENCE_ADULT == audience
def test_classify(self):
@@ -840,7 +849,9 @@ def test_classify(self):
# do an overall test to verify that classify() returns a 4-tuple
# (genres, fiction, audience, target_age)
- self.work.presentation_edition.title = "Science Fiction: A Comprehensive History"
+ self.work.presentation_edition.title = (
+ "Science Fiction: A Comprehensive History"
+ )
i = self.identifier
source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
c1 = i.classify(source, Subject.OVERDRIVE, "History", weight=10)
@@ -859,7 +870,7 @@ def test_classify(self):
assert "History" == list(genres.keys())[0].name
assert False == fiction
assert Classifier.AUDIENCE_YOUNG_ADULT == audience
- assert (12,17) == target_age
+ assert (12, 17) == target_age
def test_top_tier_values(self):
c = Counter()
@@ -868,9 +879,9 @@ def test_top_tier_values(self):
c = Counter(["a"])
assert set(["a"]) == WorkClassifier.top_tier_values(c)
- c = Counter([1,1,1,2,2,3,4,4,4])
- assert set([1,4]) == WorkClassifier.top_tier_values(c)
- c = Counter([1,1,1,2])
+ c = Counter([1, 1, 1, 2, 2, 3, 4, 4, 4])
+ assert set([1, 4]) == WorkClassifier.top_tier_values(c)
+ c = Counter([1, 1, 1, 2])
assert set([1]) == WorkClassifier.top_tier_values(c)
def test_duplicate_classification_ignored(self):
@@ -908,11 +919,14 @@ def test_staff_genre_overrides_others(self):
source = DataSource.lookup(self._db, DataSource.AXIS_360)
staff_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
classification1 = self._classification(
- identifier=self.identifier, subject=subject1,
- data_source=source, weight=10)
+ identifier=self.identifier, subject=subject1, data_source=source, weight=10
+ )
classification2 = self._classification(
- identifier=self.identifier, subject=subject2,
- data_source=staff_source, weight=1)
+ identifier=self.identifier,
+ subject=subject2,
+ data_source=staff_source,
+ weight=1,
+ )
self.classifier.add(classification1)
self.classifier.add(classification2)
(genre_weights, fiction, audience, target_age) = self.classifier.classify()
@@ -925,15 +939,17 @@ def test_staff_none_genre_overrides_others(self):
subject1 = self._subject(type="type1", identifier="subject1")
subject1.genre = genre1
subject2 = self._subject(
- type=Subject.SIMPLIFIED_GENRE,
- identifier=SimplifiedGenreClassifier.NONE
+ type=Subject.SIMPLIFIED_GENRE, identifier=SimplifiedGenreClassifier.NONE
)
classification1 = self._classification(
- identifier=self.identifier, subject=subject1,
- data_source=source, weight=10)
+ identifier=self.identifier, subject=subject1, data_source=source, weight=10
+ )
classification2 = self._classification(
- identifier=self.identifier, subject=subject2,
- data_source=staff_source, weight=1)
+ identifier=self.identifier,
+ subject=subject2,
+ data_source=staff_source,
+ weight=1,
+ )
self.classifier.add(classification1)
self.classifier.add(classification2)
(genre_weights, fiction, audience, target_age) = self.classifier.classify()
@@ -947,18 +963,20 @@ def test_staff_fiction_overrides_others(self):
subject2 = self._subject(type="type2", identifier="Psychology")
subject2.fiction = False
subject3 = self._subject(
- type=Subject.SIMPLIFIED_FICTION_STATUS,
- identifier="Fiction"
+ type=Subject.SIMPLIFIED_FICTION_STATUS, identifier="Fiction"
)
classification1 = self._classification(
- identifier=self.identifier, subject=subject1,
- data_source=source, weight=10)
+ identifier=self.identifier, subject=subject1, data_source=source, weight=10
+ )
classification2 = self._classification(
- identifier=self.identifier, subject=subject2,
- data_source=source, weight=10)
+ identifier=self.identifier, subject=subject2, data_source=source, weight=10
+ )
classification3 = self._classification(
- identifier=self.identifier, subject=subject3,
- data_source=staff_source, weight=1)
+ identifier=self.identifier,
+ subject=subject3,
+ data_source=staff_source,
+ weight=1,
+ )
self.classifier.add(classification1)
self.classifier.add(classification2)
self.classifier.add(classification3)
@@ -973,19 +991,25 @@ def test_staff_audience_overrides_others(self):
subject1.audience = "Adult"
subject2 = self._subject(type="type2", identifier="subject2")
subject2.audience = "Adult"
- subject3 = self._subject(
- type=Subject.FREEFORM_AUDIENCE,
- identifier="Children"
- )
+ subject3 = self._subject(type=Subject.FREEFORM_AUDIENCE, identifier="Children")
classification1 = self._classification(
- identifier=pool.identifier, subject=subject1,
- data_source=license_source, weight=10)
+ identifier=pool.identifier,
+ subject=subject1,
+ data_source=license_source,
+ weight=10,
+ )
classification2 = self._classification(
- identifier=pool.identifier, subject=subject2,
- data_source=license_source, weight=10)
+ identifier=pool.identifier,
+ subject=subject2,
+ data_source=license_source,
+ weight=10,
+ )
classification3 = self._classification(
- identifier=pool.identifier, subject=subject3,
- data_source=staff_source, weight=1)
+ identifier=pool.identifier,
+ subject=subject3,
+ data_source=staff_source,
+ weight=1,
+ )
self.classifier.add(classification1)
self.classifier.add(classification2)
self.classifier.add(classification3)
@@ -1001,19 +1025,19 @@ def test_staff_target_age_overrides_others(self):
subject2 = self._subject(type="type2", identifier="subject2")
subject2.target_age = NumericRange(6, 8, "[)")
subject2.weight_as_indicator_of_target_age = 1
- subject3 = self._subject(
- type=Subject.AGE_RANGE,
- identifier="10-13"
- )
+ subject3 = self._subject(type=Subject.AGE_RANGE, identifier="10-13")
classification1 = self._classification(
- identifier=self.identifier, subject=subject1,
- data_source=source, weight=10)
+ identifier=self.identifier, subject=subject1, data_source=source, weight=10
+ )
classification2 = self._classification(
- identifier=self.identifier, subject=subject2,
- data_source=source, weight=10)
+ identifier=self.identifier, subject=subject2, data_source=source, weight=10
+ )
classification3 = self._classification(
- identifier=self.identifier, subject=subject3,
- data_source=staff_source, weight=1)
+ identifier=self.identifier,
+ subject=subject3,
+ data_source=staff_source,
+ weight=1,
+ )
self.classifier.add(classification1)
self.classifier.add(classification2)
self.classifier.add(classification3)
@@ -1022,14 +1046,14 @@ def test_staff_target_age_overrides_others(self):
def test_not_inclusive_target_age(self):
staff_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
- subject = self._subject(
- type=Subject.AGE_RANGE,
- identifier="10-12"
- )
+ subject = self._subject(type=Subject.AGE_RANGE, identifier="10-12")
subject.target_age = NumericRange(9, 13, "()")
classification = self._classification(
- identifier=self.identifier, subject=subject,
- data_source=staff_source, weight=1)
+ identifier=self.identifier,
+ subject=subject,
+ data_source=staff_source,
+ weight=1,
+ )
self.classifier.add(classification)
(genre_weights, fiction, audience, target_age) = self.classifier.classify()
assert (10, 12) == target_age
diff --git a/tests/classifiers/test_ddc.py b/tests/classifiers/test_ddc.py
index 32727a72f..9c74ca357 100644
--- a/tests/classifiers/test_ddc.py
+++ b/tests/classifiers/test_ddc.py
@@ -1,9 +1,9 @@
-from ...classifier.ddc import DeweyDecimalClassifier as DDC
-from ...classifier import *
from ... import classifier
+from ...classifier import *
+from ...classifier.ddc import DeweyDecimalClassifier as DDC
-class TestDewey(object):
+class TestDewey(object):
def test_name_for(self):
assert "General statistics of Europe" == DDC.name_for("314")
assert "Biography" == DDC.name_for("B")
@@ -30,16 +30,13 @@ def aud(identifier):
assert None == aud("FIC")
assert None == aud("Fic")
-
# We could derive audience=Adult from the lack of a
# distinguishing "J" or "E" here, but we've seen this go
# wrong, and it's not terribly important overall, so we don't.
assert None == aud("B")
assert None == aud("400")
-
def test_is_fiction(self):
-
def fic(identifier):
return DDC.is_fiction(*DDC.scrub_identifier(identifier))
diff --git a/tests/classifiers/test_keyword.py b/tests/classifiers/test_keyword.py
index e8bf26896..944c9c38f 100644
--- a/tests/classifiers/test_keyword.py
+++ b/tests/classifiers/test_keyword.py
@@ -1,13 +1,11 @@
from ... import classifier
from ...classifier import *
-from ...classifier.keyword import (
- KeywordBasedClassifier as Keyword,
- LCSHClassifier as LCSH,
- FASTClassifier as FAST,
-)
+from ...classifier.keyword import FASTClassifier as FAST
+from ...classifier.keyword import KeywordBasedClassifier as Keyword
+from ...classifier.keyword import LCSHClassifier as LCSH
-class TestLCSH(object):
+class TestLCSH(object):
def test_is_fiction(self):
def fic(lcsh):
return LCSH.is_fiction(None, LCSH.scrub_name(lcsh))
@@ -25,9 +23,9 @@ def fic(lcsh):
assert None == fic("Kentucky")
assert None == fic("Social life and customs")
-
def test_audience(self):
child = Classifier.AUDIENCE_CHILDREN
+
def aud(lcsh):
return LCSH.audience(None, LCSH.scrub_name(lcsh))
@@ -39,6 +37,7 @@ def aud(lcsh):
assert None == aud("Runaway children")
assert None == aud("Humor")
+
class TestKeyword(object):
def genre(self, keyword):
scrub = Keyword.scrub_identifier(keyword)
@@ -54,11 +53,11 @@ def test_higher_tier_wins(self):
assert classifier.Romance == self.genre("Regency romances")
def test_audience(self):
- assert (Classifier.AUDIENCE_YOUNG_ADULT ==
- Keyword.audience(None, "Teens / Fiction"))
+ assert Classifier.AUDIENCE_YOUNG_ADULT == Keyword.audience(
+ None, "Teens / Fiction"
+ )
- assert (Classifier.AUDIENCE_YOUNG_ADULT ==
- Keyword.audience(None, "teen books"))
+ assert Classifier.AUDIENCE_YOUNG_ADULT == Keyword.audience(None, "teen books")
def test_subgenre_wins_over_genre(self):
# Asian_History wins over History, even though they both
@@ -75,15 +74,18 @@ def test_children_audience_implies_no_genre(self):
assert None == self.genre("Children's Books")
def test_young_adult_wins_over_children(self):
- assert (Classifier.AUDIENCE_YOUNG_ADULT ==
- Keyword.audience(None, "children's books - young adult fiction"))
+ assert Classifier.AUDIENCE_YOUNG_ADULT == Keyword.audience(
+ None, "children's books - young adult fiction"
+ )
def test_juvenile_romance_means_young_adult(self):
- assert (Classifier.AUDIENCE_YOUNG_ADULT ==
- Keyword.audience(None, "juvenile fiction / love & romance"))
+ assert Classifier.AUDIENCE_YOUNG_ADULT == Keyword.audience(
+ None, "juvenile fiction / love & romance"
+ )
- assert (Classifier.AUDIENCE_YOUNG_ADULT ==
- Keyword.audience(None, "teenage romance"))
+ assert Classifier.AUDIENCE_YOUNG_ADULT == Keyword.audience(
+ None, "teenage romance"
+ )
def test_audience_match(self):
(audience, match) = Keyword.audience_match("teen books")
@@ -108,40 +110,34 @@ def test_improvements(self):
since the original work.
"""
# was Literary Fiction
- assert (classifier.Science_Fiction ==
- Keyword.genre(None, "Science Fiction - General"))
+ assert classifier.Science_Fiction == Keyword.genre(
+ None, "Science Fiction - General"
+ )
# Was General Fiction (!)
- assert (classifier.Science_Fiction ==
- Keyword.genre(None, "Science Fiction"))
+ assert classifier.Science_Fiction == Keyword.genre(None, "Science Fiction")
- assert (classifier.Science_Fiction ==
- Keyword.genre(None, "Speculative Fiction"))
+ assert classifier.Science_Fiction == Keyword.genre(None, "Speculative Fiction")
- assert (classifier.Social_Sciences ==
- Keyword.genre(None, "Social Sciences"))
+ assert classifier.Social_Sciences == Keyword.genre(None, "Social Sciences")
- assert (classifier.Social_Sciences ==
- Keyword.genre(None, "Social Science"))
+ assert classifier.Social_Sciences == Keyword.genre(None, "Social Science")
- assert (classifier.Social_Sciences ==
- Keyword.genre(None, "Human Science"))
+ assert classifier.Social_Sciences == Keyword.genre(None, "Human Science")
# was genreless
- assert (classifier.Short_Stories ==
- Keyword.genre(None, "Short Stories"))
+ assert classifier.Short_Stories == Keyword.genre(None, "Short Stories")
# was Military History
- assert (classifier.Military_SF ==
- Keyword.genre(None, "Interstellar Warfare"))
+ assert classifier.Military_SF == Keyword.genre(None, "Interstellar Warfare")
# was Fantasy
- assert (classifier.Games ==
- Keyword.genre(None, "Games / Role Playing & Fantasy"))
+ assert classifier.Games == Keyword.genre(None, "Games / Role Playing & Fantasy")
# This isn't perfect but it covers most cases.
- assert (classifier.Media_Tie_in_SF ==
- Keyword.genre(None, "TV, Movie, Video game adaptations"))
+ assert classifier.Media_Tie_in_SF == Keyword.genre(
+ None, "TV, Movie, Video game adaptations"
+ )
# Previously only 'nonfiction' was recognized.
assert False == Keyword.is_fiction(None, "Non-Fiction")
diff --git a/tests/classifiers/test_lcc.py b/tests/classifiers/test_lcc.py
index 3e1b1a83f..4cb3261a3 100644
--- a/tests/classifiers/test_lcc.py
+++ b/tests/classifiers/test_lcc.py
@@ -2,8 +2,8 @@
from ...classifier import *
from ...classifier.lcc import LCCClassifier as LCC
-class TestLCC(object):
+class TestLCC(object):
def test_name_for(self):
child = Classifier.AUDIENCE_CHILDREN
@@ -13,7 +13,7 @@ def test_name_for(self):
assert "English literature" == LCC.name_for("PR")
assert "Fiction and juvenile belles lettres" == LCC.name_for("PZ")
assert "HISTORY OF THE AMERICAS" == LCC.name_for("E")
- assert 'Literature (General)' == LCC.name_for("PN")
+ assert "Literature (General)" == LCC.name_for("PN")
assert None == LCC.name_for("no-such-key")
def test_audience(self):
@@ -36,10 +36,10 @@ def aud(identifier):
assert None == aud("PA")
assert None == aud("J821.8 CARRIKK")
-
def test_is_fiction(self):
def fic(lcc):
return LCC.is_fiction(LCC.scrub_identifier(lcc), None)
+
assert False == fic("A")
assert False == fic("AB")
assert False == fic("PA")
diff --git a/tests/classifiers/test_overdrive.py b/tests/classifiers/test_overdrive.py
index 7f2641767..7b676d086 100644
--- a/tests/classifiers/test_overdrive.py
+++ b/tests/classifiers/test_overdrive.py
@@ -2,30 +2,30 @@
from ...classifier import *
from ...classifier.overdrive import OverdriveClassifier as Overdrive
-class TestOverdriveClassifier(object):
+class TestOverdriveClassifier(object):
def test_lookup(self):
assert Overdrive == Classifier.lookup(Classifier.OVERDRIVE)
def test_scrub_identifier(self):
scrub = Overdrive.scrub_identifier
- assert ("Foreign Language Study" ==
- scrub("Foreign Language Study - Italian"))
- assert ("Foreign Language Study" ==
- scrub("Foreign Language Study - Klingon"))
+ assert "Foreign Language Study" == scrub("Foreign Language Study - Italian")
+ assert "Foreign Language Study" == scrub("Foreign Language Study - Klingon")
assert "Foreign Affairs" == scrub("Foreign Affairs")
def test_target_age(self):
def a(x, y):
- return Overdrive.target_age(x,y)
- assert (0,4) == a("Picture Book Nonfiction", None)
- assert (5,8) == a("Beginning Reader", None)
- assert (12,17) == a("Young Adult Fiction", None)
- assert (None,None) == a("Fiction", None)
+ return Overdrive.target_age(x, y)
+
+ assert (0, 4) == a("Picture Book Nonfiction", None)
+ assert (5, 8) == a("Beginning Reader", None)
+ assert (12, 17) == a("Young Adult Fiction", None)
+ assert (None, None) == a("Fiction", None)
def test_audience(self):
def a(identifier):
return Overdrive.audience(identifier, None)
+
assert Classifier.AUDIENCE_CHILDREN == a("Picture Books")
assert Classifier.AUDIENCE_CHILDREN == a("Beginning Reader")
assert Classifier.AUDIENCE_CHILDREN == a("Children's Video")
@@ -92,6 +92,7 @@ def f(identifier):
def test_genre(self):
"""Check the fiction status and genre of every known Overdrive
subject."""
+
def g(x, fiction=None):
genre = Overdrive.genre(x, None, fiction=fiction)
if genre:
@@ -165,7 +166,7 @@ def g(x, fiction=None):
assert "Social Sciences" == g("Gender Studies")
assert "Social Sciences" == g("Genealogy")
- assert None == g("Geography") # This is all over the place.
+ assert None == g("Geography") # This is all over the place.
assert "Reference & Study Aids" == g("Grammar & Language Usage")
assert "Health & Diet" == g("Health & Fitness")
assert "Historical Fiction" == g("Historical Fiction")
@@ -175,14 +176,14 @@ def g(x, fiction=None):
assert None == g("Human Rights")
assert "Humorous Fiction" == g("Humor (Fiction)")
assert "Humorous Nonfiction" == g("Humor (Nonfiction)")
- assert None == g("Inspirational") # Mix of Christian nonfiction and fiction
+ assert None == g("Inspirational") # Mix of Christian nonfiction and fiction
assert None == g("Journalism")
assert "Judaism" == g("Judaica")
assert None == g("Juvenile Fiction")
assert None == g("Juvenile Literature")
assert None == g("Juvenile Nonfiction")
assert "Literary Criticism" == g("Language Arts")
- assert None == g("Latin") # A language, not a genre
+ assert None == g("Latin") # A language, not a genre
assert "Law" == g("Law")
assert "Short Stories" == g("Literary Anthologies")
assert "Literary Criticism" == g("Literary Criticism")
@@ -193,7 +194,7 @@ def g(x, fiction=None):
assert "Social Sciences" == g("Media Studies")
assert "Medical" == g("Medical")
assert "Military History" == g("Military")
- assert None == g("Multi-Cultural") # All over the place
+ assert None == g("Multi-Cultural") # All over the place
assert "Music" == g("Music")
assert "Mystery" == g("Mystery")
assert "Folklore" == g("Mythology")
diff --git a/tests/classifiers/test_simplified.py b/tests/classifiers/test_simplified.py
index 87c93e070..b9e2dd50f 100644
--- a/tests/classifiers/test_simplified.py
+++ b/tests/classifiers/test_simplified.py
@@ -1,8 +1,8 @@
from ... import classifier
from ...classifier import *
-class TestSimplifiedGenreClassifier(object):
+class TestSimplifiedGenreClassifier(object):
def test_scrub_identifier(self):
"""The URI for a Library Simplified genre is treated the same as
the genre itself.
diff --git a/tests/conftest.py b/tests/conftest.py
index f34090a52..ca859b038 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,10 +1,9 @@
-
# This is kind of janky, but we import the session fixture
# into these tests here. Plugins need absolute import paths
# and we don't have a package structure that gives us a reliable
# import path, so we construct one.
# todo: reorg core file structure so we have a reliable package name
-from os.path import abspath, dirname, basename
+from os.path import abspath, basename, dirname
# Pull in the session_fixture defined in core/testing.py
# which does the database setup and initialization
diff --git a/tests/lcp/test_credential.py b/tests/lcp/test_credential.py
index 2b68f6bea..a8c1a9701 100644
--- a/tests/lcp/test_credential.py
+++ b/tests/lcp/test_credential.py
@@ -2,10 +2,10 @@
from mock import patch
from parameterized import parameterized
-from ...testing import DatabaseTest
from ...lcp.credential import LCPCredentialFactory, LCPCredentialType
from ...lcp.exceptions import LCPError
from ...model import Credential, DataSource
+from ...testing import DatabaseTest
class TestCredentialFactory(DatabaseTest):
@@ -14,27 +14,33 @@ def setup_method(self):
self._factory = LCPCredentialFactory()
self._patron = self._patron()
- self._data_source = DataSource.lookup(self._db, DataSource.INTERNAL_PROCESSING, autocreate=True)
+ self._data_source = DataSource.lookup(
+ self._db, DataSource.INTERNAL_PROCESSING, autocreate=True
+ )
- @parameterized.expand([
- (
- 'get_patron_id',
+ @parameterized.expand(
+ [
+ (
+ "get_patron_id",
LCPCredentialType.PATRON_ID.value,
- 'get_patron_id',
- '52a190d1-cd69-4794-9d7a-1ec50392697f'
- ),
- (
- 'get_patron_passphrase',
+ "get_patron_id",
+ "52a190d1-cd69-4794-9d7a-1ec50392697f",
+ ),
+ (
+ "get_patron_passphrase",
LCPCredentialType.LCP_PASSPHRASE.value,
- 'get_patron_passphrase',
- '52a190d1-cd69-4794-9d7a-1ec50392697f'
- )
- ])
+ "get_patron_passphrase",
+ "52a190d1-cd69-4794-9d7a-1ec50392697f",
+ ),
+ ]
+ )
def test_getter(self, _, credential_type, method_name, expected_result):
# Arrange
credential = Credential(credential=expected_result)
- with patch.object(Credential, 'persistent_token_create') as persistent_token_create_mock:
+ with patch.object(
+ Credential, "persistent_token_create"
+ ) as persistent_token_create_mock:
persistent_token_create_mock.return_value = (credential, True)
method = getattr(self._factory, method_name)
@@ -45,7 +51,8 @@ def test_getter(self, _, credential_type, method_name, expected_result):
# Assert
assert result == expected_result
persistent_token_create_mock.assert_called_once_with(
- self._db, self._data_source, credential_type, self._patron, None)
+ self._db, self._data_source, credential_type, self._patron, None
+ )
def test_get_hashed_passphrase_raises_exception_when_there_is_no_passphrase(self):
# Act, assert
@@ -54,7 +61,7 @@ def test_get_hashed_passphrase_raises_exception_when_there_is_no_passphrase(self
def test_get_hashed_passphrase_returns_existing_hashed_passphrase(self):
# Arrange
- expected_result = '12345'
+ expected_result = "12345"
# Act
self._factory.set_hashed_passphrase(self._db, self._patron, expected_result)
diff --git a/tests/models/test_admin.py b/tests/models/test_admin.py
index 1cb0f1c9a..06299ca37 100644
--- a/tests/models/test_admin.py
+++ b/tests/models/test_admin.py
@@ -1,11 +1,10 @@
# encoding: utf-8
import pytest
-from ...testing import DatabaseTest
+
from ...model import create
-from ...model.admin import (
- Admin,
- AdminRole,
-)
+from ...model.admin import Admin, AdminRole
+from ...testing import DatabaseTest
+
class TestAdmin(DatabaseTest):
def setup_method(self):
@@ -15,7 +14,7 @@ def setup_method(self):
def test_password_hashed(self):
pytest.raises(NotImplementedError, lambda: self.admin.password)
- assert self.admin.password_hashed.startswith('$2a$')
+ assert self.admin.password_hashed.startswith("$2a$")
def test_with_password(self):
self._db.delete(self.admin)
@@ -124,8 +123,8 @@ def test_can_see_collection(self):
c2.libraries += [self._default_library]
# The admin has no roles yet.
- assert False == self.admin.can_see_collection(c1);
- assert False == self.admin.can_see_collection(c2);
+ assert False == self.admin.can_see_collection(c1)
+ assert False == self.admin.can_see_collection(c2)
self.admin.add_role(AdminRole.SYSTEM_ADMIN)
assert True == self.admin.can_see_collection(c1)
@@ -133,24 +132,24 @@ def test_can_see_collection(self):
self.admin.remove_role(AdminRole.SYSTEM_ADMIN)
self.admin.add_role(AdminRole.SITEWIDE_LIBRARY_MANAGER)
- assert False == self.admin.can_see_collection(c1);
- assert True == self.admin.can_see_collection(c2);
+ assert False == self.admin.can_see_collection(c1)
+ assert True == self.admin.can_see_collection(c2)
self.admin.remove_role(AdminRole.SITEWIDE_LIBRARY_MANAGER)
self.admin.add_role(AdminRole.SITEWIDE_LIBRARIAN)
- assert False == self.admin.can_see_collection(c1);
- assert True == self.admin.can_see_collection(c2);
+ assert False == self.admin.can_see_collection(c1)
+ assert True == self.admin.can_see_collection(c2)
self.admin.remove_role(AdminRole.SITEWIDE_LIBRARIAN)
self.admin.add_role(AdminRole.LIBRARY_MANAGER, self._default_library)
- assert False == self.admin.can_see_collection(c1);
- assert True == self.admin.can_see_collection(c2);
+ assert False == self.admin.can_see_collection(c1)
+ assert True == self.admin.can_see_collection(c2)
self.admin.remove_role(AdminRole.LIBRARY_MANAGER, self._default_library)
self.admin.add_role(AdminRole.LIBRARIAN, self._default_library)
- assert False == self.admin.can_see_collection(c1);
- assert True == self.admin.can_see_collection(c2);
+ assert False == self.admin.can_see_collection(c1)
+ assert True == self.admin.can_see_collection(c2)
self.admin.remove_role(AdminRole.LIBRARIAN, self._default_library)
- assert False == self.admin.can_see_collection(c1);
- assert False == self.admin.can_see_collection(c2);
+ assert False == self.admin.can_see_collection(c1)
+ assert False == self.admin.can_see_collection(c2)
diff --git a/tests/models/test_appeal.py b/tests/models/test_appeal.py
index 8d2ca0130..7babf224a 100644
--- a/tests/models/test_appeal.py
+++ b/tests/models/test_appeal.py
@@ -1,13 +1,8 @@
-from ...testing import (
- DatabaseTest,
-)
+from ...model import Work
+from ...testing import DatabaseTest
-from ...model import (
- Work,
-)
class TestAppealAssignment(DatabaseTest):
-
def test_assign_appeals(self):
work = self._work()
work.assign_appeals(0.50, 0.25, 0.20, 0.05)
diff --git a/tests/models/test_cachedfeed.py b/tests/models/test_cachedfeed.py
index cfd86fc29..d6d25f102 100644
--- a/tests/models/test_cachedfeed.py
+++ b/tests/models/test_cachedfeed.py
@@ -1,23 +1,20 @@
# encoding: utf-8
-import pytest
import datetime
-from ...testing import DatabaseTest
+
+import pytest
+
from ...classifier import Classifier
-from ...lane import (
- Facets,
- Pagination,
- Lane,
- WorkList,
-)
+from ...lane import Facets, Lane, Pagination, WorkList
from ...model.cachedfeed import CachedFeed
from ...model.configuration import ConfigurationSetting
from ...opds import AcquisitionFeed
+from ...testing import DatabaseTest
+from ...util.datetime_helpers import utc_now
from ...util.flask_util import OPDSFeedResponse
from ...util.opds_writer import OPDSFeed
-from ...util.datetime_helpers import utc_now
-class MockFeedGenerator(object):
+class MockFeedGenerator(object):
def __init__(self):
self.calls = []
@@ -27,7 +24,6 @@ def __call__(self):
class TestCachedFeed(DatabaseTest):
-
def test_fetch(self):
# Verify that CachedFeed.fetch looks in the database for a
# matching CachedFeed
@@ -46,6 +42,7 @@ class Mock(CachedFeed):
def _prepare_keys(cls, *args):
cls._prepare_keys_called_with = args
return cls._keys
+
# _prepare_keys always returns this named tuple. Manipulate its
# members to test different bits of fetch().
_keys = CachedFeed.CachedFeedKeys(
@@ -54,14 +51,15 @@ def _prepare_keys(cls, *args):
work=work,
lane_id=lane.id,
unique_key="unique key",
- facets_key='facets',
- pagination_key='pagination',
+ facets_key="facets",
+ pagination_key="pagination",
)
@classmethod
def max_cache_age(cls, *args):
cls.max_cache_age_called_with = args
return cls.MAX_CACHE_AGE
+
# max_cache_age always returns whatever value is stored here.
MAX_CACHE_AGE = 42
@@ -69,8 +67,10 @@ def max_cache_age(cls, *args):
def _should_refresh(cls, *args):
cls._should_refresh_called_with = args
return cls.SHOULD_REFRESH
+
# _should_refresh always returns whatever value is stored here.
SHOULD_REFRESH = True
+
m = Mock.fetch
def clear_helpers():
@@ -78,6 +78,7 @@ def clear_helpers():
Mock._prepare_keys_called_with = None
Mock.max_cache_age_called_with = None
Mock._should_refresh_called_with = None
+
clear_helpers()
# Define the hook function that is called whenever
@@ -90,8 +91,7 @@ def clear_helpers():
pagination = object()
max_age = object()
result1 = m(
- self._db, worklist, facets, pagination, refresher, max_age,
- raw=True
+ self._db, worklist, facets, pagination, refresher, max_age, raw=True
)
now = utc_now()
assert isinstance(result1, CachedFeed)
@@ -121,21 +121,25 @@ def clear_helpers():
# We called _prepare_keys with all the necessary information
# to create a named tuple.
assert (
- (self._db, worklist, facets, pagination) ==
- Mock._prepare_keys_called_with)
+ self._db,
+ worklist,
+ facets,
+ pagination,
+ ) == Mock._prepare_keys_called_with
# We then called max_cache_age on the WorkList, the page
# type, and the max_age object passed in to fetch().
assert (
- (worklist, "mock type", facets, max_age) ==
- Mock.max_cache_age_called_with)
+ worklist,
+ "mock type",
+ facets,
+ max_age,
+ ) == Mock.max_cache_age_called_with
# Then we called _should_refresh with the feed retrieved from
# the database (which was None), and the return value of
# max_cache_age.
- assert (
- (None, 42) ==
- Mock._should_refresh_called_with)
+ assert (None, 42) == Mock._should_refresh_called_with
# Since _should_refresh is hard-coded to return True, we then
# called refresher() to generate a feed and created a new
@@ -147,8 +151,7 @@ def clear_helpers():
# refresher() will be called again.
clear_helpers()
result2 = m(
- self._db, worklist, facets, pagination, refresher, max_age,
- raw=True
+ self._db, worklist, facets, pagination, refresher, max_age, raw=True
)
# The CachedFeed from before was reused.
@@ -162,16 +165,13 @@ def clear_helpers():
# Since there was a matching CachedFeed in the database
# already, that CachedFeed was passed into _should_refresh --
# previously this value was None.
- assert (
- (result1, 42) ==
- Mock._should_refresh_called_with)
+ assert (result1, 42) == Mock._should_refresh_called_with
# Now try the scenario where the feed does not need to be refreshed.
clear_helpers()
Mock.SHOULD_REFRESH = False
result3 = m(
- self._db, worklist, facets, pagination, refresher, max_age,
- raw=True
+ self._db, worklist, facets, pagination, refresher, max_age, raw=True
)
# Not only do we have the same CachedFeed as before, but its
@@ -184,17 +184,12 @@ def clear_helpers():
# cached feed before forging ahead.
Mock.MAX_CACHE_AGE = 0
clear_helpers()
- m(
- self._db, worklist, facets, pagination, refresher, max_age,
- raw=True
- )
+ m(self._db, worklist, facets, pagination, refresher, max_age, raw=True)
# A matching CachedFeed exists in the database, but we didn't
# even look for it, because we knew we'd be looking it up
# again after feed generation.
- assert (
- (None, 0) ==
- Mock._should_refresh_called_with)
+ assert (None, 0) == Mock._should_refresh_called_with
def test_no_race_conditions(self):
# Why do we look up a CachedFeed again after feed generation?
@@ -226,10 +221,8 @@ def simultaneous_refresher():
# refresher is running.
def other_thread_refresher():
return "Another thread made a feed."
- m(
- self._db, wl, facets, pagination, other_thread_refresher, 0,
- raw=True
- )
+
+ m(self._db, wl, facets, pagination, other_thread_refresher, 0, raw=True)
return "Then this thread made a feed."
@@ -237,8 +230,7 @@ def other_thread_refresher():
# CachedFeed.fetch() _again_, which will call
# other_thread_refresher().
result = m(
- self._db, wl, facets, pagination, simultaneous_refresher, 0,
- raw=True
+ self._db, wl, facets, pagination, simultaneous_refresher, 0, raw=True
)
# We ended up with a single CachedFeed containing the
@@ -254,16 +246,20 @@ def other_thread_refresher():
now = utc_now()
tomorrow = now + datetime.timedelta(days=1)
yesterday = now - datetime.timedelta(days=1)
+
def tomorrow_vs_now():
result.content = "Someone in the background set tomorrow's content."
result.timestamp = tomorrow
return "Today's content can't compete."
+
tomorrow_result = m(
self._db, wl, facets, pagination, tomorrow_vs_now, 0, raw=True
)
assert tomorrow_result == result
- assert ("Someone in the background set tomorrow's content." ==
- tomorrow_result.content)
+ assert (
+ "Someone in the background set tomorrow's content."
+ == tomorrow_result.content
+ )
assert tomorrow_result.timestamp == tomorrow
# Here, the other thread sets .timestamp to a date in the past, and
@@ -272,9 +268,8 @@ def yesterday_vs_now():
result.content = "Someone in the background set yesterday's content."
result.timestamp = yesterday
return "Today's content is fresher."
- now_result = m(
- self._db, wl, facets, pagination, yesterday_vs_now, 0, raw=True
- )
+
+ now_result = m(self._db, wl, facets, pagination, yesterday_vs_now, 0, raw=True)
# We got the same CachedFeed we've been getting this whole
# time, but the outdated data set by the 'background thread'
@@ -297,9 +292,15 @@ def timestamp_cleared_in_background():
result.timestamp = None
return "Non-weird content."
+
result2 = m(
- self._db, wl, facets, pagination, timestamp_cleared_in_background,
- 0, raw=True
+ self._db,
+ wl,
+ facets,
+ pagination,
+ timestamp_cleared_in_background,
+ 0,
+ raw=True,
)
now = utc_now()
@@ -319,9 +320,9 @@ def content_cleared_in_background():
result2.timestamp = tomorrow
return "Non-weird content."
+
result3 = m(
- self._db, wl, facets, pagination, content_cleared_in_background, 0,
- raw=True
+ self._db, wl, facets, pagination, content_cleared_in_background, 0, raw=True
)
now = utc_now()
@@ -349,10 +350,9 @@ def test_response_format(self):
def refresh():
return "Here's a feed."
- private=object()
+ private = object()
r = CachedFeed.fetch(
- self._db, wl, facets, pagination, refresh, max_age=102,
- private=private
+ self._db, wl, facets, pagination, refresh, max_age=102, private=private
)
assert isinstance(r, OPDSFeedResponse)
assert 200 == r.status_code
@@ -370,8 +370,7 @@ def refresh():
# Try it again as a cache hit.
r = CachedFeed.fetch(
- self._db, wl, facets, pagination, refresh, max_age=102,
- private=private
+ self._db, wl, facets, pagination, refresh, max_age=102, private=private
)
assert isinstance(r, OPDSFeedResponse)
assert 200 == r.status_code
@@ -383,8 +382,13 @@ def refresh():
# applies to the _database_ cache. The client is told to cache
# the feed for the default period.
r = CachedFeed.fetch(
- self._db, wl, facets, pagination, refresh,
- max_age=CachedFeed.CACHE_FOREVER, private=private
+ self._db,
+ wl,
+ facets,
+ pagination,
+ refresh,
+ max_age=CachedFeed.CACHE_FOREVER,
+ private=private,
)
assert isinstance(r, OPDSFeedResponse)
assert OPDSFeed.DEFAULT_MAX_AGE == r.max_age
@@ -393,12 +397,11 @@ def refresh():
# has root lanes, `private` is always set to True, even if we
# asked for the opposite.
from unittest.mock import PropertyMock, patch
+
from ...model import Library
+
Library._has_root_lane_cache[self._default_library.id] = True
- r = CachedFeed.fetch(
- self._db, wl, facets, pagination, refresh,
- private=False
- )
+ r = CachedFeed.fetch(self._db, wl, facets, pagination, refresh, private=False)
assert isinstance(r, OPDSFeedResponse)
assert True == r.private
@@ -442,6 +445,7 @@ def test_max_cache_age(self):
# Otherwise, the faceting object gets a chance to weigh in.
class MockFacets(object):
max_cache_age = 22
+
facets = MockFacets()
assert 22 == m(None, "feed type", facets=facets)
@@ -484,6 +488,7 @@ def test__prepare_keys(self):
# First, prepare some mock classes.
class MockCachedFeed(CachedFeed):
feed_type_called_with = None
+
@classmethod
def feed_type(cls, worklist, facets):
cls.feed_type_called_with = (worklist, facets)
@@ -512,8 +517,8 @@ class MockPagination(object):
assert None == keys.work
assert lane.id == keys.lane_id
assert None == keys.unique_key
- assert '' == keys.facets_key
- assert '' == keys.pagination_key
+ assert "" == keys.facets_key
+ assert "" == keys.pagination_key
# When pagination and/or facets are available, facets_key and
# pagination_key are set appropriately.
@@ -530,8 +535,10 @@ class MockPagination(object):
# but keys.unique_key is set to worklist.unique_key.
worklist = WorkList()
worklist.initialize(
- library=self._default_library, display_name="wl",
- languages=["eng", "spa"], audiences=[Classifier.AUDIENCE_CHILDREN]
+ library=self._default_library,
+ display_name="wl",
+ languages=["eng", "spa"],
+ audiences=[Classifier.AUDIENCE_CHILDREN],
)
keys = m(self._db, worklist, None, None)
@@ -541,8 +548,8 @@ class MockPagination(object):
assert None == keys.lane_id
assert "wl-eng,spa-Children" == keys.unique_key
assert keys.unique_key == worklist.unique_key
- assert '' == keys.facets_key
- assert '' == keys.pagination_key
+ assert "" == keys.facets_key
+ assert "" == keys.pagination_key
# When a WorkList is associated with a specific .work,
# that information is included as keys.work.
@@ -565,14 +572,10 @@ def __init__(self, timestamp):
now = utc_now()
# This feed was generated five minutes ago.
- five_minutes_old = MockCachedFeed(
- now - datetime.timedelta(minutes=5)
- )
+ five_minutes_old = MockCachedFeed(now - datetime.timedelta(minutes=5))
# This feed was generated a thousand years ago.
- ancient = MockCachedFeed(
- now - datetime.timedelta(days=1000*365)
- )
+ ancient = MockCachedFeed(now - datetime.timedelta(days=1000 * 365))
# If we intend to cache forever, then even a thousand-year-old
# feed shouldn't be refreshed.
@@ -589,13 +592,12 @@ def __init__(self, timestamp):
assert True == m(five_minutes_old, 0)
assert True == m(five_minutes_old, 1)
-
# Realistic end-to-end tests.
def test_lifecycle_with_lane(self):
facets = Facets.default(self._default_library)
pagination = Pagination.default()
- lane = self._lane("My Lane", languages=['eng','chi'])
+ lane = self._lane("My Lane", languages=["eng", "chi"])
# Fetch a cached feed from the database. It comes out updated.
refresher = MockFeedGenerator()
@@ -645,7 +647,5 @@ def test_lifecycle_with_worklist(self):
assert "This is feed #2" == feed.content
# The special constant CACHE_FOREVER means it's always cached.
- feed = CachedFeed.fetch(
- *args, max_age=CachedFeed.CACHE_FOREVER, raw=True
- )
+ feed = CachedFeed.fetch(*args, max_age=CachedFeed.CACHE_FOREVER, raw=True)
assert "This is feed #2" == feed.content
diff --git a/tests/models/test_circulationevent.py b/tests/models/test_circulationevent.py
index 762f1a646..3c8014821 100644
--- a/tests/models/test_circulationevent.py
+++ b/tests/models/test_circulationevent.py
@@ -1,35 +1,29 @@
# encoding: utf-8
-import pytest
import datetime
+
+import pytest
from sqlalchemy.exc import IntegrityError
-from ...testing import DatabaseTest
-from ...model import (
- create,
- get_one_or_create
-)
+
+from ...model import create, get_one_or_create
from ...model.circulationevent import CirculationEvent
from ...model.datasource import DataSource
from ...model.identifier import Identifier
from ...model.licensing import LicensePool
-from ...util.datetime_helpers import (
- datetime_utc,
- strptime_utc,
- to_utc,
- utc_now,
-)
+from ...testing import DatabaseTest
+from ...util.datetime_helpers import datetime_utc, strptime_utc, to_utc, utc_now
-class TestCirculationEvent(DatabaseTest):
+class TestCirculationEvent(DatabaseTest):
def _event_data(self, **kwargs):
for k, default in (
- ("source", DataSource.OVERDRIVE),
- ("id_type", Identifier.OVERDRIVE_ID),
- ("start", utc_now()),
- ("type", CirculationEvent.DISTRIBUTOR_LICENSE_ADD),
+ ("source", DataSource.OVERDRIVE),
+ ("id_type", Identifier.OVERDRIVE_ID),
+ ("start", utc_now()),
+ ("type", CirculationEvent.DISTRIBUTOR_LICENSE_ADD),
):
kwargs.setdefault(k, default)
- if 'old_value' in kwargs and 'new_value' in kwargs:
- kwargs['delta'] = kwargs['new_value'] - kwargs['old_value']
+ if "old_value" in kwargs and "new_value" in kwargs:
+ kwargs["delta"] = kwargs["new_value"] - kwargs["old_value"]
return kwargs
def _get_datetime(self, data, key):
@@ -52,13 +46,13 @@ def from_dict(self, data):
_db = self._db
# Identify the source of the event.
- source_name = data['source']
+ source_name = data["source"]
source = DataSource.lookup(_db, source_name)
# Identify which LicensePool the event is talking about.
- foreign_id = data['id']
+ foreign_id = data["id"]
identifier_type = source.primary_identifier_type
- collection = data['collection']
+ collection = data["collection"]
license_pool, was_new = LicensePool.for_foreign_id(
_db, source, identifier_type, foreign_id, collection=collection
@@ -66,20 +60,21 @@ def from_dict(self, data):
# Finally, gather some information about the event itself.
type = data.get("type")
- start = self._get_datetime(data, 'start')
- end = self._get_datetime(data, 'end')
- old_value = self._get_int(data, 'old_value')
- new_value = self._get_int(data, 'new_value')
- delta = self._get_int(data, 'delta')
+ start = self._get_datetime(data, "start")
+ end = self._get_datetime(data, "end")
+ old_value = self._get_int(data, "old_value")
+ new_value = self._get_int(data, "new_value")
+ delta = self._get_int(data, "delta")
event, was_new = get_one_or_create(
- _db, CirculationEvent, license_pool=license_pool,
- type=type, start=start,
+ _db,
+ CirculationEvent,
+ license_pool=license_pool,
+ type=type,
+ start=start,
create_method_kwargs=dict(
- old_value=old_value,
- new_value=new_value,
- delta=delta,
- end=end)
- )
+ old_value=old_value, new_value=new_value, delta=delta, end=end
+ ),
+ )
return event, was_new
def test_new_title(self):
@@ -127,9 +122,15 @@ def test_log(self):
m = CirculationEvent.log
event, is_new = m(
- self._db, license_pool=pool, event_name=event_name,
- library=library, old_value=old_value, new_value=new_value,
- start=start, end=end, location=location
+ self._db,
+ license_pool=pool,
+ event_name=event_name,
+ library=library,
+ old_value=old_value,
+ new_value=new_value,
+ start=start,
+ end=end,
+ location=location,
)
assert True == is_new
assert pool == event.license_pool
@@ -143,13 +144,16 @@ def test_log(self):
# library, event name, and start date, that event is returned
# unchanged.
event, is_new = m(
- self._db, license_pool=pool, event_name=event_name,
- library=library, start=start,
-
+ self._db,
+ license_pool=pool,
+ event_name=event_name,
+ library=library,
+ start=start,
# These values will be ignored.
- old_value=500, new_value=200,
+ old_value=500,
+ new_value=200,
end=utc_now(),
- location="another location"
+ location="another location",
)
assert False == is_new
assert pool == event.license_pool
@@ -163,9 +167,14 @@ def test_log(self):
# is the most common case, so basically a new event will be
# created each time you call log().
event, is_new = m(
- self._db, license_pool=pool, event_name=event_name,
- library=library, old_value=old_value, new_value=new_value,
- end=end, location=location
+ self._db,
+ license_pool=pool,
+ event_name=event_name,
+ library=library,
+ old_value=old_value,
+ new_value=new_value,
+ end=end,
+ location=location,
)
assert (utc_now() - event.start).total_seconds() < 2
assert True == is_new
@@ -181,7 +190,8 @@ def test_uniqueness_constraints_no_library(self):
pool = self._licensepool(edition=None)
now = utc_now()
kwargs = dict(
- license_pool=pool, type=CirculationEvent.DISTRIBUTOR_TITLE_ADD,
+ license_pool=pool,
+ type=CirculationEvent.DISTRIBUTOR_TITLE_ADD,
)
event = create(self._db, CirculationEvent, start=now, **kwargs)
@@ -193,8 +203,7 @@ def test_uniqueness_constraints_no_library(self):
# Reuse the timestamp and you get an IntegrityError which ruins the
# entire transaction.
pytest.raises(
- IntegrityError, create, self._db, CirculationEvent, start=now,
- **kwargs
+ IntegrityError, create, self._db, CirculationEvent, start=now, **kwargs
)
self._db.rollback()
@@ -218,7 +227,6 @@ def test_uniqueness_constraints_with_library(self):
# Reuse the timestamp and you get an IntegrityError which ruins the
# entire transaction.
pytest.raises(
- IntegrityError, create, self._db, CirculationEvent, start=now,
- **kwargs
+ IntegrityError, create, self._db, CirculationEvent, start=now, **kwargs
)
self._db.rollback()
diff --git a/tests/models/test_classification.py b/tests/models/test_classification.py
index 34178823f..0e5abb845 100644
--- a/tests/models/test_classification.py
+++ b/tests/models/test_classification.py
@@ -2,20 +2,14 @@
import pytest
from psycopg2.extras import NumericRange
from sqlalchemy.exc import IntegrityError
-from ...testing import DatabaseTest
+
from ...classifier import Classifier
-from ...model import (
- create,
- get_one,
- get_one_or_create,
-)
-from ...model.classification import (
- Subject,
- Genre,
-)
+from ...model import create, get_one, get_one_or_create
+from ...model.classification import Genre, Subject
+from ...testing import DatabaseTest
-class TestSubject(DatabaseTest):
+class TestSubject(DatabaseTest):
def test_lookup_errors(self):
"""Subject.lookup will complain if you don't give it
enough information to find a Subject.
@@ -25,15 +19,16 @@ def test_lookup_errors(self):
assert "Cannot look up Subject with no type." in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
Subject.lookup(self._db, Subject.TAG, None, None)
- assert "Cannot look up Subject when neither identifier nor name is provided." in str(excinfo.value)
+ assert (
+ "Cannot look up Subject when neither identifier nor name is provided."
+ in str(excinfo.value)
+ )
def test_lookup_autocreate(self):
# By default, Subject.lookup creates a Subject that doesn't exist.
identifier = self._str
name = self._str
- subject, was_new = Subject.lookup(
- self._db, Subject.TAG, identifier, name
- )
+ subject, was_new = Subject.lookup(self._db, Subject.TAG, identifier, name)
assert True == was_new
assert identifier == subject.identifier
assert name == subject.name
@@ -64,11 +59,13 @@ def test_lookup_by_name(self):
def test_assign_to_genre_can_remove_genre(self):
# Here's a Subject that identifies children's books.
- subject, was_new = Subject.lookup(self._db, Subject.TAG, "Children's books", None)
+ subject, was_new = Subject.lookup(
+ self._db, Subject.TAG, "Children's books", None
+ )
# The genre and audience data for this Subject is totally wrong.
subject.audience = Classifier.AUDIENCE_ADULT
- subject.target_age = NumericRange(1,10)
+ subject.target_age = NumericRange(1, 10)
subject.fiction = False
sf, ignore = Genre.lookup(self._db, "Science Fiction")
subject.genre = sf
@@ -76,12 +73,12 @@ def test_assign_to_genre_can_remove_genre(self):
# But calling assign_to_genre() will fix it.
subject.assign_to_genre()
assert Classifier.AUDIENCE_CHILDREN == subject.audience
- assert NumericRange(None, None, '[]') == subject.target_age
+ assert NumericRange(None, None, "[]") == subject.target_age
assert None == subject.genre
assert None == subject.fiction
-class TestGenre(DatabaseTest):
+class TestGenre(DatabaseTest):
def test_full_table_cache(self):
"""We use Genre as a convenient way of testing
HasFullTableCache.populate_cache, which requires a real
@@ -124,8 +121,8 @@ def test_by_id(self):
def test_by_cache_key_miss_triggers_create_function(self):
_db = self._db
- class Factory(object):
+ class Factory(object):
def __init__(self):
self.called = False
@@ -155,8 +152,7 @@ def test_by_cache_key_miss_when_cache_is_reset_populates_cache(self):
# Call Genreby_cache_key...
drama, is_new = Genre.by_cache_key(
- self._db, "Drama",
- lambda: get_one_or_create(self._db, Genre, name="Drama")
+ self._db, "Drama", lambda: get_one_or_create(self._db, Genre, name="Drama")
)
assert "Drama" == drama.name
assert False == is_new
@@ -171,11 +167,10 @@ def test_by_cache_key_hit_returns_cached_object(self):
# function will be called and raise an exception.
def exploding_create_hook():
raise Exception("Kaboom")
+
drama, ignore = get_one_or_create(self._db, Genre, name="Drama")
- Genre._cache = { "Drama": drama }
- drama2, is_new = Genre.by_cache_key(
- self._db, "Drama", exploding_create_hook
- )
+ Genre._cache = {"Drama": drama}
+ drama2, is_new = Genre.by_cache_key(self._db, "Drama", exploding_create_hook)
# The object was already in the cache, so we just looked it up.
# No exception.
@@ -196,9 +191,7 @@ def test_default_fiction(self):
assert False == nonfiction.default_fiction
# Create a previously unknown genre.
- genre, ignore = Genre.lookup(
- self._db, "Some Weird Genre", autocreate=True
- )
+ genre, ignore = Genre.lookup(self._db, "Some Weird Genre", autocreate=True)
# We don't know its default fiction status.
assert None == genre.default_fiction
diff --git a/tests/models/test_collection.py b/tests/models/test_collection.py
index 5590bbc95..6d902f059 100644
--- a/tests/models/test_collection.py
+++ b/tests/models/test_collection.py
@@ -1,44 +1,38 @@
# encoding: utf-8
-import pytest
-from mock import create_autospec, MagicMock
import datetime
import json
-from ...testing import DatabaseTest
+import pytest
+from mock import MagicMock, create_autospec
+
from ...config import Configuration
-from ...model import (
- create,
- get_one_or_create,
-)
-from ...model.coverage import (
- CoverageRecord,
- WorkCoverageRecord,
-)
+from ...model import create, get_one_or_create
from ...model.circulationevent import CirculationEvent
-from ...model.collection import Collection, HasExternalIntegrationPerCollection, CollectionConfigurationStorage
+from ...model.collection import (
+ Collection,
+ CollectionConfigurationStorage,
+ HasExternalIntegrationPerCollection,
+)
from ...model.complaint import Complaint
from ...model.configuration import (
ConfigurationSetting,
ExternalIntegration,
- ExternalIntegrationLink)
+ ExternalIntegrationLink,
+)
+from ...model.coverage import CoverageRecord, WorkCoverageRecord
from ...model.customlist import CustomList
from ...model.datasource import DataSource
from ...model.edition import Edition
from ...model.hasfulltablecache import HasFullTableCache
from ...model.identifier import Identifier
-from ...model.licensing import (
- Hold,
- Loan,
- License,
- LicensePool,
-)
+from ...model.licensing import Hold, License, LicensePool, Loan
from ...model.work import Work
-from ...util.string_helpers import base64
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
+from ...util.string_helpers import base64
class TestCollection(DatabaseTest):
-
def setup_method(self):
super(TestCollection, self).setup_method()
self.collection = self._collection(
@@ -74,8 +68,12 @@ def test_by_name_and_protocol(self):
# You'll get an exception if you look up an existing name
# but the protocol doesn't match.
with pytest.raises(ValueError) as excinfo:
- Collection.by_name_and_protocol(self._db, name, ExternalIntegration.BIBLIOTHECA)
- assert 'Collection "A name" does not use protocol "Bibliotheca".' in str(excinfo.value)
+ Collection.by_name_and_protocol(
+ self._db, name, ExternalIntegration.BIBLIOTHECA
+ )
+ assert 'Collection "A name" does not use protocol "Bibliotheca".' in str(
+ excinfo.value
+ )
def test_by_protocol(self):
"""Verify the ability to find all collections that implement
@@ -86,17 +84,17 @@ def test_by_protocol(self):
c1 = self._collection(self._str, protocol=overdrive)
c1.parent = self.collection
c2 = self._collection(self._str, protocol=bibliotheca)
- assert (set([self.collection, c1]) ==
- set(Collection.by_protocol(self._db, overdrive).all()))
- assert (([c2]) ==
- Collection.by_protocol(self._db, bibliotheca).all())
- assert (set([self.collection, c1, c2]) ==
- set(Collection.by_protocol(self._db, None).all()))
+ assert set([self.collection, c1]) == set(
+ Collection.by_protocol(self._db, overdrive).all()
+ )
+ assert ([c2]) == Collection.by_protocol(self._db, bibliotheca).all()
+ assert set([self.collection, c1, c2]) == set(
+ Collection.by_protocol(self._db, None).all()
+ )
# A collection marked for deletion is filtered out.
c1.marked_for_deletion = True
- assert ([self.collection] ==
- Collection.by_protocol(self._db, overdrive).all())
+ assert [self.collection] == Collection.by_protocol(self._db, overdrive).all()
def test_by_datasource(self):
"""Collections can be found by their associated DataSource"""
@@ -104,13 +102,13 @@ def test_by_datasource(self):
c2 = self._collection(data_source_name=DataSource.OVERDRIVE)
# Using the DataSource name
- assert (set([c1]) ==
- set(Collection.by_datasource(self._db, DataSource.GUTENBERG).all()))
+ assert set([c1]) == set(
+ Collection.by_datasource(self._db, DataSource.GUTENBERG).all()
+ )
# Using the DataSource itself
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- assert (set([c2]) ==
- set(Collection.by_datasource(self._db, overdrive).all()))
+ assert set([c2]) == set(Collection.by_datasource(self._db, overdrive).all())
# A collection marked for deletion is filtered out.
c2.marked_for_deletion = True
@@ -131,12 +129,10 @@ def test_parents(self):
def test_create_external_integration(self):
# A newly created Collection has no associated ExternalIntegration.
- collection, ignore = get_one_or_create(
- self._db, Collection, name=self._str
- )
+ collection, ignore = get_one_or_create(self._db, Collection, name=self._str)
assert None == collection.external_integration_id
with pytest.raises(ValueError) as excinfo:
- getattr(collection, 'external_integration')
+ getattr(collection, "external_integration")
assert "No known external integration for collection" in str(excinfo.value)
# We can create one with create_external_integration().
@@ -153,9 +149,11 @@ def test_create_external_integration(self):
# If we try to initialize an ExternalIntegration with a different
# protocol, we get an error.
with pytest.raises(ValueError) as excinfo:
- collection.create_external_integration(protocol = "blah")
- assert "Located ExternalIntegration, but its protocol (Overdrive) does not match desired protocol (blah)." \
+ collection.create_external_integration(protocol="blah")
+ assert (
+ "Located ExternalIntegration, but its protocol (Overdrive) does not match desired protocol (blah)."
in str(excinfo.value)
+ )
def test_unique_account_id(self):
@@ -205,11 +203,16 @@ def test_change_protocol(self):
# We can't change the child's protocol to a value that contradicts
# the parent's protocol.
child.protocol = overdrive
+
def set_child_protocol():
child.protocol = bibliotheca
+
with pytest.raises(ValueError) as excinfo:
set_child_protocol()
- assert "Proposed new protocol (Bibliotheca) contradicts parent collection's protocol (Overdrive)." in str(excinfo.value)
+ assert (
+ "Proposed new protocol (Bibliotheca) contradicts parent collection's protocol (Overdrive)."
+ in str(excinfo.value)
+ )
# If we change the parent's protocol, the children are
# automatically updated.
@@ -251,19 +254,22 @@ def test_default_loan_period(self):
# The default when no value is set.
assert (
- Collection.STANDARD_DEFAULT_LOAN_PERIOD ==
- self.collection.default_loan_period(library, ebook))
+ Collection.STANDARD_DEFAULT_LOAN_PERIOD
+ == self.collection.default_loan_period(library, ebook)
+ )
assert (
- Collection.STANDARD_DEFAULT_LOAN_PERIOD ==
- self.collection.default_loan_period(library, audio))
+ Collection.STANDARD_DEFAULT_LOAN_PERIOD
+ == self.collection.default_loan_period(library, audio)
+ )
# Set a value, and it's used.
self.collection.default_loan_period_setting(library, ebook).value = 604
assert 604 == self.collection.default_loan_period(library)
assert (
- Collection.STANDARD_DEFAULT_LOAN_PERIOD ==
- self.collection.default_loan_period(library, audio))
+ Collection.STANDARD_DEFAULT_LOAN_PERIOD
+ == self.collection.default_loan_period(library, audio)
+ )
self.collection.default_loan_period_setting(library, audio).value = 606
assert 606 == self.collection.default_loan_period(library, audio)
@@ -275,19 +281,22 @@ def test_default_loan_period(self):
# The default when no value is set.
assert (
- Collection.STANDARD_DEFAULT_LOAN_PERIOD ==
- self.collection.default_loan_period(client, ebook))
+ Collection.STANDARD_DEFAULT_LOAN_PERIOD
+ == self.collection.default_loan_period(client, ebook)
+ )
assert (
- Collection.STANDARD_DEFAULT_LOAN_PERIOD ==
- self.collection.default_loan_period(client, audio))
+ Collection.STANDARD_DEFAULT_LOAN_PERIOD
+ == self.collection.default_loan_period(client, audio)
+ )
# Set a value, and it's used.
self.collection.default_loan_period_setting(client, ebook).value = 347
assert 347 == self.collection.default_loan_period(client)
assert (
- Collection.STANDARD_DEFAULT_LOAN_PERIOD ==
- self.collection.default_loan_period(client, audio))
+ Collection.STANDARD_DEFAULT_LOAN_PERIOD
+ == self.collection.default_loan_period(client, audio)
+ )
self.collection.default_loan_period_setting(client, audio).value = 349
assert 349 == self.collection.default_loan_period(client, audio)
@@ -301,8 +310,9 @@ def test_default_reservation_period(self):
library = self._default_library
# The default when no value is set.
assert (
- Collection.STANDARD_DEFAULT_RESERVATION_PERIOD ==
- self.collection.default_reservation_period)
+ Collection.STANDARD_DEFAULT_RESERVATION_PERIOD
+ == self.collection.default_reservation_period
+ )
# Set a value, and it's used.
self.collection.default_reservation_period = 601
@@ -342,7 +352,7 @@ def test_explain(self):
about a Collection.
"""
library = self._default_library
- library.name="The only library"
+ library.name = "The only library"
library.short_name = "only one"
library.collections.append(self.collection)
@@ -353,15 +363,15 @@ def test_explain(self):
setting = self.collection.external_integration.set_setting("setting", "value")
data = self.collection.explain()
- assert (['Name: "test collection"',
- 'Protocol: "Overdrive"',
- 'Used by library: "only one"',
- 'External account ID: "id"',
- 'Setting "setting": "value"',
- 'Setting "url": "url"',
- 'Setting "username": "username"',
- ] ==
- data)
+ assert [
+ 'Name: "test collection"',
+ 'Protocol: "Overdrive"',
+ 'Used by library: "only one"',
+ 'External account ID: "id"',
+ 'Setting "setting": "value"',
+ 'Setting "url": "url"',
+ 'Setting "username": "username"',
+ ] == data
with_password = self.collection.explain(include_secrets=True)
assert 'Setting "password": "password"' in with_password
@@ -371,51 +381,53 @@ def test_explain(self):
child = Collection(
name="Child", parent=self.collection, external_account_id="id2"
)
- child.create_external_integration(
- protocol=ExternalIntegration.OVERDRIVE
- )
+ child.create_external_integration(protocol=ExternalIntegration.OVERDRIVE)
data = child.explain()
- assert (['Name: "Child"',
- 'Parent: test collection',
- 'Protocol: "Overdrive"',
- 'External account ID: "id2"'] ==
- data)
+ assert [
+ 'Name: "Child"',
+ "Parent: test collection",
+ 'Protocol: "Overdrive"',
+ 'External account ID: "id2"',
+ ] == data
def test_metadata_identifier(self):
# If the collection doesn't have its unique identifier, an error
# is raised.
- pytest.raises(ValueError, getattr, self.collection, 'metadata_identifier')
+ pytest.raises(ValueError, getattr, self.collection, "metadata_identifier")
def build_expected(protocol, unique_id):
encode = base64.urlsafe_b64encode
- encoded = [
- encode(value)
- for value in [protocol, unique_id]
- ]
- joined = ':'.join(encoded)
+ encoded = [encode(value) for value in [protocol, unique_id]]
+ joined = ":".join(encoded)
return encode(joined)
# With a unique identifier, we get back the expected identifier.
- self.collection.external_account_id = 'id'
- expected = build_expected(ExternalIntegration.OVERDRIVE, 'id')
+ self.collection.external_account_id = "id"
+ expected = build_expected(ExternalIntegration.OVERDRIVE, "id")
assert expected == self.collection.metadata_identifier
# If there's a parent, its unique id is incorporated into the result.
child = self._collection(
- name="Child", protocol=ExternalIntegration.OPDS_IMPORT,
- external_account_id=self._url
+ name="Child",
+ protocol=ExternalIntegration.OPDS_IMPORT,
+ external_account_id=self._url,
)
child.parent = self.collection
- expected = build_expected(ExternalIntegration.OPDS_IMPORT, 'id+%s' % child.external_account_id)
+ expected = build_expected(
+ ExternalIntegration.OPDS_IMPORT, "id+%s" % child.external_account_id
+ )
assert expected == child.metadata_identifier
# If it's an OPDS_IMPORT collection with a url external_account_id,
# closing '/' marks are removed.
opds = self._collection(
- name='OPDS', protocol=ExternalIntegration.OPDS_IMPORT,
- external_account_id=(self._url+'/')
+ name="OPDS",
+ protocol=ExternalIntegration.OPDS_IMPORT,
+ external_account_id=(self._url + "/"),
+ )
+ expected = build_expected(
+ ExternalIntegration.OPDS_IMPORT, opds.external_account_id[:-1]
)
- expected = build_expected(ExternalIntegration.OPDS_IMPORT, opds.external_account_id[:-1])
assert expected == opds.metadata_identifier
def test_from_metadata_identifier(self):
@@ -425,24 +437,29 @@ def test_from_metadata_identifier(self):
# A ValueError results if we try to look up using an invalid
# identifier.
with pytest.raises(ValueError) as excinfo:
- Collection.from_metadata_identifier(self._db, "not a real identifier", data_source = data_source)
- assert "Metadata identifier 'not a real identifier' is invalid: Incorrect padding" in str(excinfo.value)
+ Collection.from_metadata_identifier(
+ self._db, "not a real identifier", data_source=data_source
+ )
+ assert (
+ "Metadata identifier 'not a real identifier' is invalid: Incorrect padding"
+ in str(excinfo.value)
+ )
# Of if we pass in the empty string.
with pytest.raises(ValueError) as excinfo:
- Collection.from_metadata_identifier(self._db, "", data_source = data_source)
+ Collection.from_metadata_identifier(self._db, "", data_source=data_source)
assert "No metadata identifier provided" in str(excinfo.value)
# No new data source was created.
def new_data_source():
return DataSource.lookup(self._db, data_source)
+
assert None == new_data_source()
# If a mirrored collection doesn't exist, it is created.
- self.collection.external_account_id = 'id'
+ self.collection.external_account_id = "id"
mirror_collection, is_new = Collection.from_metadata_identifier(
- self._db, self.collection.metadata_identifier,
- data_source=data_source
+ self._db, self.collection.metadata_identifier, data_source=data_source
)
assert True == is_new
assert self.collection.metadata_identifier == mirror_collection.name
@@ -458,8 +475,7 @@ def new_data_source():
# If the mirrored collection already exists, it is returned.
collection = self._collection(external_account_id=self._url)
mirror_collection = create(
- self._db, Collection,
- name=collection.metadata_identifier
+ self._db, Collection, name=collection.metadata_identifier
)[0]
mirror_collection.create_external_integration(collection.protocol)
@@ -513,33 +529,31 @@ def test_unresolved_catalog(self):
# A 'resolved' identifier that doesn't have a work yet.
# (This isn't supposed to happen, but jic.)
source = DataSource.lookup(self._db, DataSource.GUTENBERG)
- operation = 'test-thyself'
+ operation = "test-thyself"
resolved_id = self._identifier()
self._coverage_record(
- resolved_id, source, operation=operation,
- status=CoverageRecord.SUCCESS
+ resolved_id, source, operation=operation, status=CoverageRecord.SUCCESS
)
# An unresolved identifier--we tried to resolve it, but
# it all fell apart.
unresolved_id = self._identifier()
self._coverage_record(
- unresolved_id, source, operation=operation,
- status=CoverageRecord.TRANSIENT_FAILURE
+ unresolved_id,
+ source,
+ operation=operation,
+ status=CoverageRecord.TRANSIENT_FAILURE,
)
# An identifier with a Work already.
id_with_work = self._work().presentation_edition.primary_identifier
-
- self.collection.catalog_identifiers([
- pure_id, resolved_id, unresolved_id, id_with_work
- ])
-
- result = self.collection.unresolved_catalog(
- self._db, source.name, operation
+ self.collection.catalog_identifiers(
+ [pure_id, resolved_id, unresolved_id, id_with_work]
)
+ result = self.collection.unresolved_catalog(self._db, source.name, operation)
+
# Only the failing identifier is in the query.
assert [unresolved_id] == result.all()
@@ -556,13 +570,17 @@ def test_disassociate_library(self):
integration = collection.external_integration
setting1 = integration.set_setting("integration setting", "value2")
setting2 = ConfigurationSetting.for_library_and_externalintegration(
- self._db, "default_library+integration setting",
- self._default_library, integration,
+ self._db,
+ "default_library+integration setting",
+ self._default_library,
+ integration,
)
setting2.value = "value2"
setting3 = ConfigurationSetting.for_library_and_externalintegration(
- self._db, "other_library+integration setting",
- other_library, integration,
+ self._db,
+ "other_library+integration setting",
+ other_library,
+ integration,
)
setting3.value = "value3"
@@ -617,12 +635,8 @@ def test_licensepools_with_works_updated_since(self):
# This Work is catalogued in another catalog and will never show up.
collection2 = self._collection()
- in_other_catalog = self._work(
- with_license_pool=True, collection=collection2
- )
- collection2.catalog_identifier(
- in_other_catalog.license_pools[0].identifier
- )
+ in_other_catalog = self._work(with_license_pool=True, collection=collection2)
+ collection2.catalog_identifier(in_other_catalog.license_pools[0].identifier)
# When no timestamp is passed, all LicensePeols in the catalog
# are returned, in order of the WorkCoverageRecord
@@ -634,12 +648,12 @@ def test_licensepools_with_works_updated_since(self):
# When a timestamp is passed, only LicensePools whose works
# have been updated since then will be returned.
[w1_coverage_record] = [
- c for c in w1.coverage_records
+ c
+ for c in w1.coverage_records
if c.operation == WorkCoverageRecord.GENERATE_OPDS_OPERATION
]
w1_coverage_record.timestamp = utc_now()
- assert (
- [w1] == [x.work for x in m(self._db, timestamp)])
+ assert [w1] == [x.work for x in m(self._db, timestamp)]
def test_isbns_updated_since(self):
i1 = self._identifier(identifier_type=Identifier.ISBN, foreign_id=self._isbn)
@@ -719,8 +733,8 @@ def test_custom_lists(self):
identifier = pool.identifier
staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
staff_edition, ignore = Edition.for_foreign_id(
- self._db, staff_data_source,
- identifier.type, identifier.identifier)
+ self._db, staff_data_source, identifier.type, identifier.identifier
+ )
staff_edition.title = self._str
work.calculate_presentation()
@@ -743,17 +757,20 @@ def test_restrict_to_ready_deliverable_works(self):
"""
# Create two audiobooks and one ebook.
overdrive_audiobook = self._work(
- data_source_name=DataSource.OVERDRIVE, with_license_pool=True,
- title="Overdrive Audiobook"
+ data_source_name=DataSource.OVERDRIVE,
+ with_license_pool=True,
+ title="Overdrive Audiobook",
)
overdrive_audiobook.presentation_edition.medium = Edition.AUDIO_MEDIUM
overdrive_ebook = self._work(
- data_source_name=DataSource.OVERDRIVE, with_license_pool=True,
+ data_source_name=DataSource.OVERDRIVE,
+ with_license_pool=True,
title="Overdrive Ebook",
)
feedbooks_audiobook = self._work(
- data_source_name=DataSource.FEEDBOOKS, with_license_pool=True,
- title="Feedbooks Audiobook"
+ data_source_name=DataSource.FEEDBOOKS,
+ with_license_pool=True,
+ title="Feedbooks Audiobook",
)
feedbooks_audiobook.presentation_edition.medium = Edition.AUDIO_MEDIUM
@@ -762,13 +779,13 @@ def test_restrict_to_ready_deliverable_works(self):
data_source_name=DataSource.LCP,
title="Self-hosted LCP book",
with_license_pool=True,
- self_hosted=True
+ self_hosted=True,
)
unlimited_access_book = self._work(
data_source_name=DataSource.LCP,
title="Self-hosted LCP book",
with_license_pool=True,
- unlimited_access=True
+ unlimited_access=True,
)
def expect(qu, works):
@@ -776,39 +793,48 @@ def expect(qu, works):
restrict_to_ready_deliverable_works(), then verify that
the query returns the works expected by `works`.
"""
- restricted_query = Collection.restrict_to_ready_deliverable_works(
- qu
- )
+ restricted_query = Collection.restrict_to_ready_deliverable_works(qu)
expect_ids = [x.id for x in works]
actual_ids = [x.id for x in restricted_query]
assert set(expect_ids) == set(actual_ids)
+
# Here's the setting which controls which data sources should
# have their audiobooks excluded.
setting = ConfigurationSetting.sitewide(
self._db, Configuration.EXCLUDED_AUDIO_DATA_SOURCES
)
- qu = self._db.query(Work).join(Work.license_pools).join(
- Work.presentation_edition
+ qu = (
+ self._db.query(Work)
+ .join(Work.license_pools)
+ .join(Work.presentation_edition)
)
# When its value is set to the empty list, every work shows
# up.
setting.value = json.dumps([])
expect(
- qu, [overdrive_ebook, overdrive_audiobook, feedbooks_audiobook, self_hosted_lcp_book, unlimited_access_book]
+ qu,
+ [
+ overdrive_ebook,
+ overdrive_audiobook,
+ feedbooks_audiobook,
+ self_hosted_lcp_book,
+ unlimited_access_book,
+ ],
)
# Putting a data source in the list excludes its audiobooks, but
# not its ebooks.
setting.value = json.dumps([DataSource.OVERDRIVE])
expect(
qu,
- [overdrive_ebook, feedbooks_audiobook, self_hosted_lcp_book, unlimited_access_book]
- )
- setting.value = json.dumps(
- [DataSource.OVERDRIVE, DataSource.FEEDBOOKS]
- )
- expect(
- qu, [overdrive_ebook, self_hosted_lcp_book, unlimited_access_book]
+ [
+ overdrive_ebook,
+ feedbooks_audiobook,
+ self_hosted_lcp_book,
+ unlimited_access_book,
+ ],
)
+ setting.value = json.dumps([DataSource.OVERDRIVE, DataSource.FEEDBOOKS])
+ expect(qu, [overdrive_ebook, self_hosted_lcp_book, unlimited_access_book])
def test_delete(self):
"""Verify that Collection.delete will only operate on collections
@@ -826,20 +852,30 @@ def test_delete(self):
integration = collection.external_integration
setting1 = integration.set_setting("integration setting", "value2")
setting2 = ConfigurationSetting.for_library_and_externalintegration(
- self._db, "library+integration setting",
- self._default_library, integration,
+ self._db,
+ "library+integration setting",
+ self._default_library,
+ integration,
)
setting2.value = "value2"
# Also it has links to another independent ExternalIntegration (S3 storage in this case).
s3_storage = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL, libraries=[self._default_library]
+ ExternalIntegration.S3,
+ ExternalIntegration.STORAGE_GOAL,
+ libraries=[self._default_library],
)
link1 = self._external_integration_link(
- integration, self._default_library, s3_storage, ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS
+ integration,
+ self._default_library,
+ s3_storage,
+ ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
)
link2 = self._external_integration_link(
- integration, self._default_library, s3_storage, ExternalIntegrationLink.COVERS
+ integration,
+ self._default_library,
+ s3_storage,
+ ExternalIntegrationLink.COVERS,
)
integration.links.append(link1)
@@ -859,8 +895,7 @@ def test_delete(self):
# And a Complaint.
complaint, is_new = Complaint.register(
- pool, list(Complaint.VALID_TYPES)[0],
- source=None, detail=None
+ pool, list(Complaint.VALID_TYPES)[0], source=None, detail=None
)
# And a CirculationEvent.
@@ -880,15 +915,20 @@ def test_delete(self):
# Works are removed from the search index.
class MockExternalSearchIndex(object):
removed = []
+
def remove_work(self, work):
self.removed.append(work)
+
index = MockExternalSearchIndex()
# delete() will not work on a collection that's not marked for
# deletion.
with pytest.raises(Exception) as excinfo:
collection.delete()
- assert "Cannot delete %s: it is not marked for deletion." % collection.name in str(excinfo.value)
+ assert (
+ "Cannot delete %s: it is not marked for deletion." % collection.name
+ in str(excinfo.value)
+ )
# Delete the collection.
collection.marked_for_deletion = True
@@ -964,22 +1004,26 @@ def test_only_name_is_required(self):
"""Test that only name is a required field on
the Collection class.
"""
- collection = create(
- self._db, Collection, name='banana'
- )[0]
+ collection = create(self._db, Collection, name="banana")[0]
assert True == isinstance(collection, Collection)
class TestCollectionConfigurationStorage(DatabaseTest):
def test_load(self):
# Arrange
- lcp_collection = self._collection('Test Collection', DataSource.LCP)
+ lcp_collection = self._collection("Test Collection", DataSource.LCP)
external_integration = lcp_collection.external_integration
- external_integration_association = create_autospec(spec=HasExternalIntegrationPerCollection)
- external_integration_association.collection_external_integration = MagicMock(return_value=external_integration)
- storage = CollectionConfigurationStorage(external_integration_association, lcp_collection)
- setting_name = 'Test'
- expected_result = 'Test'
+ external_integration_association = create_autospec(
+ spec=HasExternalIntegrationPerCollection
+ )
+ external_integration_association.collection_external_integration = MagicMock(
+ return_value=external_integration
+ )
+ storage = CollectionConfigurationStorage(
+ external_integration_association, lcp_collection
+ )
+ setting_name = "Test"
+ expected_result = "Test"
# Act
storage.save(self._db, setting_name, expected_result)
diff --git a/tests/models/test_complaint.py b/tests/models/test_complaint.py
index 2f7ae8ad7..dd74fa121 100644
--- a/tests/models/test_complaint.py
+++ b/tests/models/test_complaint.py
@@ -1,44 +1,37 @@
# encoding: utf-8
import pytest
-from ...testing import DatabaseTest
from ...model.complaint import Complaint
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
-class TestComplaint(DatabaseTest):
+class TestComplaint(DatabaseTest):
def setup_method(self):
super(TestComplaint, self).setup_method()
self.edition, self.pool = self._edition(with_license_pool=True)
self.type = "http://librarysimplified.org/terms/problem/wrong-genre"
def test_for_license_pool(self):
- work_complaint, is_new = Complaint.register(
- self.pool, self.type, "yes", "okay"
- )
+ work_complaint, is_new = Complaint.register(self.pool, self.type, "yes", "okay")
- lp_type = self.type.replace('wrong-genre', 'cannot-render')
- lp_complaint, is_new = Complaint.register(
- self.pool, lp_type, "yes", "okay")
+ lp_type = self.type.replace("wrong-genre", "cannot-render")
+ lp_complaint, is_new = Complaint.register(self.pool, lp_type, "yes", "okay")
assert False == work_complaint.for_license_pool
assert True == lp_complaint.for_license_pool
def test_success(self):
- complaint, is_new = Complaint.register(
- self.pool, self.type, "foo", "bar"
- )
+ complaint, is_new = Complaint.register(self.pool, self.type, "foo", "bar")
assert True == is_new
assert self.type == complaint.type
assert "foo" == complaint.source
assert "bar" == complaint.detail
- assert abs(utc_now() -complaint.timestamp).seconds < 3
+ assert abs(utc_now() - complaint.timestamp).seconds < 3
# A second complaint from the same source is folded into the
# original complaint.
- complaint2, is_new = Complaint.register(
- self.pool, self.type, "foo", "baz"
- )
+ complaint2, is_new = Complaint.register(self.pool, self.type, "foo", "baz")
assert False == is_new
assert complaint.id == complaint2.id
assert "baz" == complaint.detail
@@ -46,18 +39,14 @@ def test_success(self):
assert 1 == len(self.pool.complaints)
def test_success_no_source(self):
- complaint, is_new = Complaint.register(
- self.pool, self.type, None, None
- )
+ complaint, is_new = Complaint.register(self.pool, self.type, None, None)
assert True == is_new
assert self.type == complaint.type
assert None == complaint.source
# A second identical complaint from no source is treated as a
# separate complaint.
- complaint2, is_new = Complaint.register(
- self.pool, self.type, None, None
- )
+ complaint2, is_new = Complaint.register(self.pool, self.type, None, None)
assert True == is_new
assert None == complaint.source
assert complaint2.id != complaint.id
@@ -65,15 +54,11 @@ def test_success_no_source(self):
assert 2 == len(self.pool.complaints)
def test_failure_no_licensepool(self):
- pytest.raises(
- ValueError, Complaint.register, self.pool, type, None, None
- )
+ pytest.raises(ValueError, Complaint.register, self.pool, type, None, None)
def test_unrecognized_type(self):
type = "http://librarysimplified.org/terms/problem/no-such-error"
- pytest.raises(
- ValueError, Complaint.register, self.pool, type, None, None
- )
+ pytest.raises(ValueError, Complaint.register, self.pool, type, None, None)
def test_register_resolved(self):
complaint, is_new = Complaint.register(
@@ -83,22 +68,18 @@ def test_register_resolved(self):
assert self.type == complaint.type
assert "foo" == complaint.source
assert "bar" == complaint.detail
- assert abs(utc_now() -complaint.timestamp).seconds < 3
- assert abs(utc_now() -complaint.resolved).seconds < 3
+ assert abs(utc_now() - complaint.timestamp).seconds < 3
+ assert abs(utc_now() - complaint.resolved).seconds < 3
# A second complaint from the same source is not folded into the same complaint.
- complaint2, is_new = Complaint.register(
- self.pool, self.type, "foo", "baz"
- )
+ complaint2, is_new = Complaint.register(self.pool, self.type, "foo", "baz")
assert True == is_new
assert complaint2.id != complaint.id
assert "baz" == complaint2.detail
assert 2 == len(self.pool.complaints)
def test_resolve(self):
- complaint, is_new = Complaint.register(
- self.pool, self.type, "foo", "bar"
- )
+ complaint, is_new = Complaint.register(self.pool, self.type, "foo", "bar")
complaint.resolve()
assert complaint.resolved != None
assert abs(utc_now() - complaint.resolved).seconds < 3
diff --git a/tests/models/test_configuration.py b/tests/models/test_configuration.py
index 0ba78f07e..30ce7972f 100644
--- a/tests/models/test_configuration.py
+++ b/tests/models/test_configuration.py
@@ -28,23 +28,20 @@
class TestConfigurationSetting(DatabaseTest):
-
def test_is_secret(self):
"""Some configuration settings are considered secrets,
and some are not.
"""
m = ConfigurationSetting._is_secret
- assert True == m('secret')
- assert True == m('password')
- assert True == m('its_a_secret_to_everybody')
- assert True == m('the_password')
- assert True == m('password_for_the_account')
- assert False == m('public_information')
-
- assert (True ==
- ConfigurationSetting.sitewide(self._db, "secret_key").is_secret)
- assert (False ==
- ConfigurationSetting.sitewide(self._db, "public_key").is_secret)
+ assert True == m("secret")
+ assert True == m("password")
+ assert True == m("its_a_secret_to_everybody")
+ assert True == m("the_password")
+ assert True == m("password_for_the_account")
+ assert False == m("public_information")
+
+ assert True == ConfigurationSetting.sitewide(self._db, "secret_key").is_secret
+ assert False == ConfigurationSetting.sitewide(self._db, "public_key").is_secret
def test_value_or_default(self):
integration, ignore = create(
@@ -82,8 +79,10 @@ def test_value_inheritance(self):
# Here's an integration, let's say the SIP2 authentication mechanism
sip, ignore = create(
- self._db, ExternalIntegration,
- goal=ExternalIntegration.PATRON_AUTH_GOAL, protocol="SIP2"
+ self._db,
+ ExternalIntegration,
+ goal=ExternalIntegration.PATRON_AUTH_GOAL,
+ protocol="SIP2",
)
# It happens to a ConfigurationSetting for the same key used
@@ -119,8 +118,10 @@ def test_value_inheritance(self):
# prefix. This is set on the combination of a library and a
# SIP2 integration.
key = "patron_identifier_prefix"
- library_patron_prefix_conf = ConfigurationSetting.for_library_and_externalintegration(
- self._db, key, library, sip
+ library_patron_prefix_conf = (
+ ConfigurationSetting.for_library_and_externalintegration(
+ self._db, key, library, sip
+ )
)
assert None == library_patron_prefix_conf.value
@@ -165,9 +166,12 @@ def test_duplicate(self):
assert setting.id == setting2.id
pytest.raises(
IntegrityError,
- create, self._db, ConfigurationSetting,
+ create,
+ self._db,
+ ConfigurationSetting,
key=key,
- library=library, external_integration=integration
+ library=library,
+ external_integration=integration,
)
def test_relationships(self):
@@ -241,15 +245,16 @@ def test_no_orphan_delete_cascade(self):
# That was a weird thing to do, but the ConfigurationSettings
# are still in the database.
for cs in for_library, for_integration:
- assert (
- cs == get_one(self._db, ConfigurationSetting, id=cs.id))
-
- @parameterized.expand([
- ('no value', None, None),
- ('stringable value', 1, '1'),
- ('string value', 'snowman', 'snowman'),
- ('bytes value', '☃'.encode("utf8"), '☃'),
- ])
+ assert cs == get_one(self._db, ConfigurationSetting, id=cs.id)
+
+ @parameterized.expand(
+ [
+ ("no value", None, None),
+ ("stringable value", 1, "1"),
+ ("string value", "snowman", "snowman"),
+ ("bytes value", "☃".encode("utf8"), "☃"),
+ ]
+ )
def test_setter(self, _, set_to, expect):
# Values are converted into Unicode strings on the way in to
# the 'value' setter.
@@ -261,11 +266,11 @@ def test_stored_bytes_value(self):
bytes_setting = ConfigurationSetting.sitewide(self._db, "bytes_setting")
assert bytes_setting.value is None
- bytes_setting.value = '1234 ☃'.encode('utf8')
- assert '1234 ☃' == bytes_setting.value
+ bytes_setting.value = "1234 ☃".encode("utf8")
+ assert "1234 ☃" == bytes_setting.value
with pytest.raises(UnicodeDecodeError):
- bytes_setting.value = b'\x80'
+ bytes_setting.value = b"\x80"
def test_int_value(self):
number = ConfigurationSetting.sitewide(self._db, "number")
@@ -292,7 +297,7 @@ def test_json_value(self):
assert None == jsondata.int_value
jsondata.value = "[1,2]"
- assert [1,2] == jsondata.json_value
+ assert [1, 2] == jsondata.json_value
jsondata.value = "tra la la"
pytest.raises(ValueError, lambda: jsondata.json_value)
@@ -307,8 +312,7 @@ def test_excluded_audio_data_sources(self):
# the return value of the method is AUDIO_EXCLUSIONS -- whatever
# the default is for the current version of the circulation manager.
assert None == setting.value
- assert (ConfigurationSetting.EXCLUDED_AUDIO_DATA_SOURCES_DEFAULT ==
- m(self._db))
+ assert ConfigurationSetting.EXCLUDED_AUDIO_DATA_SOURCES_DEFAULT == m(self._db)
# When an explicit value for the ConfigurationSetting, is set, that
# value is interpreted as JSON and returned.
setting.value = "[]"
@@ -330,15 +334,14 @@ def test_explain(self):
nonsecret_setting='2'"""
assert expect == "\n".join(actual)
- without_secrets = "\n".join(ConfigurationSetting.explain(
- self._db, include_secrets=False
- ))
- assert 'a_secret' not in without_secrets
- assert 'nonsecret_setting' in without_secrets
+ without_secrets = "\n".join(
+ ConfigurationSetting.explain(self._db, include_secrets=False)
+ )
+ assert "a_secret" not in without_secrets
+ assert "nonsecret_setting" in without_secrets
class TestUniquenessConstraints(DatabaseTest):
-
def test_duplicate_sitewide_setting(self):
# You can't create two sitewide settings with the same key.
c1 = ConfigurationSetting(key="key", value="value1")
@@ -381,14 +384,18 @@ def test_duplicate_library_integration_setting(self):
# different ways for the same key.
integration = self._external_integration(self._str)
c1 = ConfigurationSetting(
- key="key", value="value1", library=self._default_library,
- external_integration=integration
+ key="key",
+ value="value1",
+ library=self._default_library,
+ external_integration=integration,
)
self._db.add(c1)
self._db.flush()
c2 = ConfigurationSetting(
- key="key", value="value1", library=self._default_library,
- external_integration=integration
+ key="key",
+ value="value1",
+ library=self._default_library,
+ external_integration=integration,
)
self._db.add(c2)
pytest.raises(IntegrityError, self._db.flush)
@@ -400,29 +407,39 @@ def test_collection_mirror_settings(self):
assert settings[0]["key"] == ExternalIntegrationLink.COVERS_KEY
assert settings[0]["label"] == "Covers Mirror"
- assert (settings[0]["options"][0]['key'] ==
- ExternalIntegrationLink.NO_MIRROR_INTEGRATION)
- assert (settings[0]["options"][0]['label'] ==
- _("None - Do not mirror cover images"))
+ assert (
+ settings[0]["options"][0]["key"]
+ == ExternalIntegrationLink.NO_MIRROR_INTEGRATION
+ )
+ assert settings[0]["options"][0]["label"] == _(
+ "None - Do not mirror cover images"
+ )
assert settings[1]["key"] == ExternalIntegrationLink.OPEN_ACCESS_BOOKS_KEY
assert settings[1]["label"] == "Open Access Books Mirror"
- assert (settings[1]["options"][0]['key'] ==
- ExternalIntegrationLink.NO_MIRROR_INTEGRATION)
- assert (settings[1]["options"][0]['label'] ==
- _("None - Do not mirror free books"))
+ assert (
+ settings[1]["options"][0]["key"]
+ == ExternalIntegrationLink.NO_MIRROR_INTEGRATION
+ )
+ assert settings[1]["options"][0]["label"] == _(
+ "None - Do not mirror free books"
+ )
assert settings[2]["key"] == ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS_KEY
assert settings[2]["label"] == "Protected Access Books Mirror"
- assert (settings[2]["options"][0]['key'] ==
- ExternalIntegrationLink.NO_MIRROR_INTEGRATION)
- assert (settings[2]["options"][0]['label'] ==
- _("None - Do not mirror self-hosted, commercially licensed books"))
-
+ assert (
+ settings[2]["options"][0]["key"]
+ == ExternalIntegrationLink.NO_MIRROR_INTEGRATION
+ )
+ assert settings[2]["options"][0]["label"] == _(
+ "None - Do not mirror self-hosted, commercially licensed books"
+ )
+
def test_relationships(self):
# Create a collection with two storage external integrations.
collection = self._collection(
- name="Collection", protocol=ExternalIntegration.OVERDRIVE,
+ name="Collection",
+ protocol=ExternalIntegration.OVERDRIVE,
)
storage1 = self._external_integration(
@@ -433,7 +450,8 @@ def test_relationships(self):
name="integration2",
protocol=ExternalIntegration.S3,
goal=ExternalIntegration.STORAGE_GOAL,
- username="username", password="password",
+ username="username",
+ password="password",
)
# Two external integration links need to be created to associate
@@ -441,15 +459,18 @@ def test_relationships(self):
# external integrations.
s1_external_integration_link = self._external_integration_link(
integration=collection.external_integration,
- other_integration=storage1, purpose="covers_mirror"
+ other_integration=storage1,
+ purpose="covers_mirror",
)
s2_external_integration_link = self._external_integration_link(
integration=collection.external_integration,
- other_integration=storage2, purpose="books_mirror"
+ other_integration=storage2,
+ purpose="books_mirror",
)
- qu = self._db.query(ExternalIntegrationLink
- ).order_by(ExternalIntegrationLink.other_integration_id)
+ qu = self._db.query(ExternalIntegrationLink).order_by(
+ ExternalIntegrationLink.other_integration_id
+ )
external_integration_links = qu.all()
assert len(external_integration_links) == 2
@@ -469,7 +490,6 @@ def test_relationships(self):
class TestExternalIntegration(DatabaseTest):
-
def setup_method(self):
super(TestExternalIntegration, self).setup_method()
self.external_integration, ignore = create(
@@ -493,8 +513,9 @@ def test_for_library_and_goal(self):
# also starts returning it.
self.external_integration.libraries.append(self._default_library)
assert [self.external_integration] == qu.all()
- assert (self.external_integration ==
- get_one(self._db, self._default_library, goal))
+ assert self.external_integration == get_one(
+ self._db, self._default_library, goal
+ )
# Create another, similar ExternalIntegration. By itself, this
# has no effect.
@@ -502,8 +523,9 @@ def test_for_library_and_goal(self):
self._db, ExternalIntegration, goal=goal, protocol=self._str
)
assert [self.external_integration] == qu.all()
- assert (self.external_integration ==
- get_one(self._db, self._default_library, goal))
+ assert self.external_integration == get_one(
+ self._db, self._default_library, goal
+ )
# Associate that ExternalIntegration with the library, and
# the query starts picking it up, and one_for_library_and_goal
@@ -512,18 +534,23 @@ def test_for_library_and_goal(self):
assert set([self.external_integration, integration2]) == set(qu.all())
with pytest.raises(CannotLoadConfiguration) as excinfo:
get_one(self._db, self._default_library, goal)
- assert "Library {} defines multiple integrations with goal {}".format(self._default_library.name, goal) \
- in str(excinfo.value)
-
+ assert "Library {} defines multiple integrations with goal {}".format(
+ self._default_library.name, goal
+ ) in str(excinfo.value)
+
def test_for_collection_and_purpose(self):
wrong_purpose = "isbn"
collection = self._collection()
with pytest.raises(CannotLoadConfiguration) as excinfo:
- ExternalIntegration.for_collection_and_purpose(self._db, collection, wrong_purpose)
- assert "No storage integration for collection '%s' and purpose '%s' is configured" \
- % (collection.name, wrong_purpose) \
+ ExternalIntegration.for_collection_and_purpose(
+ self._db, collection, wrong_purpose
+ )
+ assert (
+ "No storage integration for collection '%s' and purpose '%s' is configured"
+ % (collection.name, wrong_purpose)
in str(excinfo.value)
+ )
external_integration = self._external_integration("some protocol")
collection.external_integration_id = external_integration.id
@@ -615,9 +642,7 @@ def test_set_key_value_pair(self):
assert setting2 == self.external_integration.setting("website_id")
def test_explain(self):
- integration = self._external_integration(
- "protocol", "goal"
- )
+ integration = self._external_integration("protocol", "goal")
integration.name = "The Integration"
integration.url = "http://url/"
integration.username = "someuser"
@@ -642,14 +667,17 @@ def test_explain(self):
# If we decline to pass in a library, we get information about how
# each library in the system configures this integration.
- expect = """ID: %s
+ expect = (
+ """ID: %s
Name: The Integration
Protocol/Goal: protocol/goal
library-specific='value1' (applies only to First Library)
library-specific='value2' (applies only to Second Library)
somesetting='somevalue'
url='http://url/'
-username='someuser'""" % integration.id
+username='someuser'"""
+ % integration.id
+ )
actual = integration.explain()
assert expect == "\n".join(actual)
@@ -675,18 +703,28 @@ def test_custom_accept_header(self):
def test_delete(self):
"""Ensure that ExternalIntegration.delete clears all orphan ExternalIntegrationLinks."""
integration1 = self._external_integration(
- ExternalIntegration.MANUAL, ExternalIntegration.LICENSE_GOAL, libraries=[self._default_library]
+ ExternalIntegration.MANUAL,
+ ExternalIntegration.LICENSE_GOAL,
+ libraries=[self._default_library],
)
integration2 = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL, libraries=[self._default_library]
+ ExternalIntegration.S3,
+ ExternalIntegration.STORAGE_GOAL,
+ libraries=[self._default_library],
)
# Set up a a link associating integration2 with integration1.
link1 = self._external_integration_link(
- integration1, self._default_library, integration2, ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS
+ integration1,
+ self._default_library,
+ integration2,
+ ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
)
link2 = self._external_integration_link(
- integration1, self._default_library, integration2, ExternalIntegrationLink.COVERS
+ integration1,
+ self._default_library,
+ integration2,
+ ExternalIntegrationLink.COVERS,
)
# Delete integration1.
@@ -703,80 +741,80 @@ def test_delete(self):
assert integration2 in external_integrations
-SETTING1_KEY = 'setting1'
-SETTING1_LABEL = 'Setting 1\'s label'
-SETTING1_DESCRIPTION = 'Setting 1\'s description'
+SETTING1_KEY = "setting1"
+SETTING1_LABEL = "Setting 1's label"
+SETTING1_DESCRIPTION = "Setting 1's description"
SETTING1_TYPE = ConfigurationAttributeType.TEXT
SETTING1_REQUIRED = False
-SETTING1_DEFAULT = '12345'
-SETTING1_CATEGORY = 'Settings'
+SETTING1_DEFAULT = "12345"
+SETTING1_CATEGORY = "Settings"
-SETTING2_KEY = 'setting2'
-SETTING2_LABEL = 'Setting 2\'s label'
-SETTING2_DESCRIPTION = 'Setting 2\'s description'
+SETTING2_KEY = "setting2"
+SETTING2_LABEL = "Setting 2's label"
+SETTING2_DESCRIPTION = "Setting 2's description"
SETTING2_TYPE = ConfigurationAttributeType.SELECT
SETTING2_REQUIRED = False
-SETTING2_DEFAULT = 'value1'
+SETTING2_DEFAULT = "value1"
SETTING2_OPTIONS = [
- ConfigurationOption('key1', 'value1'),
- ConfigurationOption('key2', 'value2'),
- ConfigurationOption('key3', 'value3')
+ ConfigurationOption("key1", "value1"),
+ ConfigurationOption("key2", "value2"),
+ ConfigurationOption("key3", "value3"),
]
-SETTING2_CATEGORY = 'Settings'
+SETTING2_CATEGORY = "Settings"
class TestConfiguration(ConfigurationGrouping):
setting1 = ConfigurationMetadata(
- key='setting1',
+ key="setting1",
label=SETTING1_LABEL,
description=SETTING1_DESCRIPTION,
type=SETTING1_TYPE,
required=SETTING1_REQUIRED,
default=SETTING1_DEFAULT,
- category=SETTING1_CATEGORY
+ category=SETTING1_CATEGORY,
)
setting2 = ConfigurationMetadata(
- key='setting2',
+ key="setting2",
label=SETTING2_LABEL,
description=SETTING2_DESCRIPTION,
type=SETTING2_TYPE,
required=SETTING2_REQUIRED,
default=SETTING2_DEFAULT,
options=SETTING2_OPTIONS,
- category=SETTING2_CATEGORY
+ category=SETTING2_CATEGORY,
)
class ConfigurationWithBooleanProperty(ConfigurationGrouping):
boolean_setting = ConfigurationMetadata(
- key='boolean_setting',
- label='Boolean Setting',
- description='Boolean Setting',
+ key="boolean_setting",
+ label="Boolean Setting",
+ description="Boolean Setting",
type=ConfigurationAttributeType.SELECT,
required=True,
- default='true',
+ default="true",
options=[
- ConfigurationOption('true', 'True'),
- ConfigurationOption('false', 'False')
- ]
+ ConfigurationOption("true", "True"),
+ ConfigurationOption("false", "False"),
+ ],
)
class TestConfiguration2(ConfigurationGrouping):
setting1 = ConfigurationMetadata(
- key='setting1',
+ key="setting1",
label=SETTING1_LABEL,
description=SETTING1_DESCRIPTION,
type=SETTING1_TYPE,
required=SETTING1_REQUIRED,
default=SETTING1_DEFAULT,
category=SETTING1_CATEGORY,
- index=1
+ index=1,
)
setting2 = ConfigurationMetadata(
- key='setting2',
+ key="setting2",
label=SETTING2_LABEL,
description=SETTING2_DESCRIPTION,
type=SETTING2_TYPE,
@@ -784,18 +822,15 @@ class TestConfiguration2(ConfigurationGrouping):
default=SETTING2_DEFAULT,
options=SETTING2_OPTIONS,
category=SETTING2_CATEGORY,
- index=0
+ index=0,
)
class TestConfigurationOption(object):
def test_to_settings(self):
# Arrange
- option = ConfigurationOption('key1', 'value1')
- expected_result = {
- 'key': 'key1',
- 'label': 'value1'
- }
+ option = ConfigurationOption("key1", "value1")
+ expected_result = {"key": "key1", "label": "value1"}
# Act
result = option.to_settings()
@@ -806,11 +841,12 @@ def test_to_settings(self):
def test_from_enum(self):
# Arrange
class TestEnum(Enum):
- LABEL1 = 'KEY1'
- LABEL2 = 'KEY2'
+ LABEL1 = "KEY1"
+ LABEL2 = "KEY2"
+
expected_result = [
- ConfigurationOption('KEY1', 'LABEL1'),
- ConfigurationOption('KEY2', 'LABEL2')
+ ConfigurationOption("KEY1", "LABEL1"),
+ ConfigurationOption("KEY2", "LABEL2"),
]
# Act
@@ -821,10 +857,9 @@ class TestEnum(Enum):
class TestConfigurationGrouping(object):
- @parameterized.expand([
- ('setting1', 'setting1', 12345),
- ('setting2', 'setting2', '12345')
- ])
+ @parameterized.expand(
+ [("setting1", "setting1", 12345), ("setting2", "setting2", "12345")]
+ )
def test_getters(self, _, setting_name, expected_value):
# Arrange
configuration_storage = create_autospec(spec=ConfigurationStorage)
@@ -839,10 +874,9 @@ def test_getters(self, _, setting_name, expected_value):
assert setting_value == expected_value
configuration_storage.load.assert_called_once_with(db, setting_name)
- @parameterized.expand([
- ('setting1', 'setting1', 12345),
- ('setting2', 'setting2', '12345')
- ])
+ @parameterized.expand(
+ [("setting1", "setting1", 12345), ("setting2", "setting2", "12345")]
+ )
def test_setters(self, _, setting_name, expected_value):
# Arrange
configuration_storage = create_autospec(spec=ConfigurationStorage)
@@ -854,7 +888,9 @@ def test_setters(self, _, setting_name, expected_value):
setattr(configuration, setting_name, expected_value)
# Assert
- configuration_storage.save.assert_called_once_with(db, setting_name, expected_value)
+ configuration_storage.save.assert_called_once_with(
+ db, setting_name, expected_value
+ )
def test_to_settings_considers_default_indices(self):
# Act
@@ -865,7 +901,10 @@ def test_to_settings_considers_default_indices(self):
assert settings[0][ConfigurationAttribute.KEY.value] == SETTING1_KEY
assert settings[0][ConfigurationAttribute.LABEL.value] == SETTING1_LABEL
- assert settings[0][ConfigurationAttribute.DESCRIPTION.value] == SETTING1_DESCRIPTION
+ assert (
+ settings[0][ConfigurationAttribute.DESCRIPTION.value]
+ == SETTING1_DESCRIPTION
+ )
assert settings[0][ConfigurationAttribute.TYPE.value] == None
assert settings[0][ConfigurationAttribute.REQUIRED.value] == SETTING1_REQUIRED
assert settings[0][ConfigurationAttribute.DEFAULT.value] == SETTING1_DEFAULT
@@ -873,11 +912,16 @@ def test_to_settings_considers_default_indices(self):
assert settings[1][ConfigurationAttribute.KEY.value] == SETTING2_KEY
assert settings[1][ConfigurationAttribute.LABEL.value] == SETTING2_LABEL
- assert settings[1][ConfigurationAttribute.DESCRIPTION.value] == SETTING2_DESCRIPTION
+ assert (
+ settings[1][ConfigurationAttribute.DESCRIPTION.value]
+ == SETTING2_DESCRIPTION
+ )
assert settings[1][ConfigurationAttribute.TYPE.value] == SETTING2_TYPE.value
assert settings[1][ConfigurationAttribute.REQUIRED.value] == SETTING2_REQUIRED
assert settings[1][ConfigurationAttribute.DEFAULT.value] == SETTING2_DEFAULT
- assert settings[1][ConfigurationAttribute.OPTIONS.value] == [option.to_settings() for option in SETTING2_OPTIONS]
+ assert settings[1][ConfigurationAttribute.OPTIONS.value] == [
+ option.to_settings() for option in SETTING2_OPTIONS
+ ]
assert settings[1][ConfigurationAttribute.CATEGORY.value] == SETTING2_CATEGORY
def test_to_settings_considers_explicit_indices(self):
@@ -889,16 +933,24 @@ def test_to_settings_considers_explicit_indices(self):
assert settings[0][ConfigurationAttribute.KEY.value] == SETTING2_KEY
assert settings[0][ConfigurationAttribute.LABEL.value] == SETTING2_LABEL
- assert settings[0][ConfigurationAttribute.DESCRIPTION.value] == SETTING2_DESCRIPTION
+ assert (
+ settings[0][ConfigurationAttribute.DESCRIPTION.value]
+ == SETTING2_DESCRIPTION
+ )
assert settings[0][ConfigurationAttribute.TYPE.value] == SETTING2_TYPE.value
assert settings[0][ConfigurationAttribute.REQUIRED.value] == SETTING2_REQUIRED
assert settings[0][ConfigurationAttribute.DEFAULT.value] == SETTING2_DEFAULT
- assert settings[0][ConfigurationAttribute.OPTIONS.value] == [option.to_settings() for option in SETTING2_OPTIONS]
+ assert settings[0][ConfigurationAttribute.OPTIONS.value] == [
+ option.to_settings() for option in SETTING2_OPTIONS
+ ]
assert settings[0][ConfigurationAttribute.CATEGORY.value] == SETTING2_CATEGORY
assert settings[1][ConfigurationAttribute.KEY.value] == SETTING1_KEY
assert settings[1][ConfigurationAttribute.LABEL.value] == SETTING1_LABEL
- assert settings[1][ConfigurationAttribute.DESCRIPTION.value] == SETTING1_DESCRIPTION
+ assert (
+ settings[1][ConfigurationAttribute.DESCRIPTION.value]
+ == SETTING1_DESCRIPTION
+ )
assert settings[1][ConfigurationAttribute.TYPE.value] == None
assert settings[1][ConfigurationAttribute.REQUIRED.value] == SETTING1_REQUIRED
assert settings[1][ConfigurationAttribute.DEFAULT.value] == SETTING1_DEFAULT
@@ -906,15 +958,19 @@ def test_to_settings_considers_explicit_indices(self):
class TestBooleanConfigurationMetadata(DatabaseTest):
- @parameterized.expand([
- ('true', 'true', True),
- ('t', 't', True),
- ('yes', 'yes', True),
- ('y', 'y', True),
- (1, 1, False),
- ('false', 'false', False),
- ])
- def test_configuration_metadata_correctly_recognize_bool_values(self, _, value, expected_result):
+ @parameterized.expand(
+ [
+ ("true", "true", True),
+ ("t", "t", True),
+ ("yes", "yes", True),
+ ("y", "y", True),
+ (1, 1, False),
+ ("false", "false", False),
+ ]
+ )
+ def test_configuration_metadata_correctly_recognize_bool_values(
+ self, _, value, expected_result
+ ):
"""Ensure that ConfigurationMetadata.to_bool correctly translates different values into boolean (True/False).
:param _: Name of the test case
@@ -927,14 +983,18 @@ def test_configuration_metadata_correctly_recognize_bool_values(self, _, value,
:type expected_result: bool
"""
# Arrange
- external_integration = self._external_integration('test')
+ external_integration = self._external_integration("test")
external_integration_association = create_autospec(spec=HasExternalIntegration)
- external_integration_association.external_integration = MagicMock(return_value=external_integration)
+ external_integration_association.external_integration = MagicMock(
+ return_value=external_integration
+ )
configuration_storage = ConfigurationStorage(external_integration_association)
- configuration = ConfigurationWithBooleanProperty(configuration_storage, self._db)
+ configuration = ConfigurationWithBooleanProperty(
+ configuration_storage, self._db
+ )
# We set a new value using ConfigurationMetadata.__set__
configuration.boolean_setting = value
diff --git a/tests/models/test_contributor.py b/tests/models/test_contributor.py
index ff8f0ff14..678a5d83f 100644
--- a/tests/models/test_contributor.py
+++ b/tests/models/test_contributor.py
@@ -1,20 +1,20 @@
# encoding: utf-8
-from ...testing import DatabaseTest
from ...model import get_one_or_create
from ...model.contributor import Contributor
from ...model.datasource import DataSource
from ...model.edition import Edition
from ...model.identifier import Identifier
+from ...testing import DatabaseTest
-class TestContributor(DatabaseTest):
+class TestContributor(DatabaseTest):
def test_marc_code_for_every_role_constant(self):
"""We have determined the MARC Role Code for every role
that's important enough we gave it a constant in the Contributor
class.
"""
for constant, value in list(Contributor.__dict__.items()):
- if not constant.endswith('_ROLE'):
+ if not constant.endswith("_ROLE"):
# Not a constant.
continue
assert value in Contributor.MARC_ROLE_CODES
@@ -49,9 +49,7 @@ def test_lookup_by_viaf_interchangeable(self):
bob2.lc = "foo"
self._db.commit()
assert bob1 != bob2
- [some_bob], new = Contributor.lookup(
- self._db, sort_name="Bob", lc="foo"
- )
+ [some_bob], new = Contributor.lookup(self._db, sort_name="Bob", lc="foo")
assert False == new
assert some_bob in (bob1, bob2)
@@ -82,10 +80,10 @@ def test_merge(self):
# Here's Bob.
[bob], ignore = Contributor.lookup(self._db, sort_name="Jones, Bob")
- bob.extra['foo'] = 'bar'
- bob.aliases = ['Bobby']
- bob.viaf = 'viaf'
- bob.lc = 'lc'
+ bob.extra["foo"] = "bar"
+ bob.aliases = ["Bobby"]
+ bob.viaf = "viaf"
+ bob.lc = "lc"
bob.display_name = "Bob Jones"
bob.family_name = "Bobb"
bob.wikipedia_name = "Bob_(Person)"
@@ -94,11 +92,13 @@ def test_merge(self):
data_source = DataSource.lookup(self._db, DataSource.GUTENBERG)
roberts_book, ignore = Edition.for_foreign_id(
- self._db, data_source, Identifier.GUTENBERG_ID, "1")
+ self._db, data_source, Identifier.GUTENBERG_ID, "1"
+ )
roberts_book.add_contributor(robert, Contributor.AUTHOR_ROLE)
bobs_book, ignore = Edition.for_foreign_id(
- self._db, data_source, Identifier.GUTENBERG_ID, "10")
+ self._db, data_source, Identifier.GUTENBERG_ID, "10"
+ )
bobs_book.add_contributor(bob, Contributor.AUTHOR_ROLE)
# In a shocking turn of events, it transpires that "Bob" and
@@ -108,11 +108,11 @@ def test_merge(self):
# 'Bob' is now listed as an alias for Robert, as is Bob's
# alias.
- assert ['Jones, Bob', 'Bobby'] == robert.aliases
+ assert ["Jones, Bob", "Bobby"] == robert.aliases
# The extra information associated with Bob is now associated
# with Robert.
- assert 'bar' == robert.extra['foo']
+ assert "bar" == robert.extra["foo"]
assert "viaf" == robert.viaf
assert "lc" == robert.lc
@@ -123,8 +123,9 @@ def test_merge(self):
# The standalone 'Bob' record has been removed from the database.
assert (
- [] ==
- self._db.query(Contributor).filter(Contributor.sort_name=="Bob").all())
+ []
+ == self._db.query(Contributor).filter(Contributor.sort_name == "Bob").all()
+ )
# Bob's book is now associated with 'Robert', not the standalone
# 'Bob' record.
@@ -136,10 +137,7 @@ def test_merge(self):
bob.merge_into(robert)
assert "Jones, Bob" == robert.sort_name
-
-
- def _names(self, in_name, out_family, out_display,
- default_display_name=None):
+ def _names(self, in_name, out_family, out_display, default_display_name=None):
f, d = Contributor._default_names(in_name, default_display_name)
assert f == out_family
assert d == out_display
@@ -147,30 +145,35 @@ def _names(self, in_name, out_family, out_display,
def test_default_names(self):
# Pass in a default display name and it will always be used.
- self._names("Jones, Bob", "Jones", "Sally Smith",
- default_display_name="Sally Smith")
+ self._names(
+ "Jones, Bob", "Jones", "Sally Smith", default_display_name="Sally Smith"
+ )
# Corporate names are untouched and get no family name.
self._names("Bob's Books.", None, "Bob's Books.")
self._names("Bob's Books, Inc.", None, "Bob's Books, Inc.")
self._names("Little, Brown & Co.", None, "Little, Brown & Co.")
- self._names("Philadelphia Broad Street Church (Philadelphia, Pa.)",
- None, "Philadelphia Broad Street Church")
+ self._names(
+ "Philadelphia Broad Street Church (Philadelphia, Pa.)",
+ None,
+ "Philadelphia Broad Street Church",
+ )
# Dates and other gibberish after a name is removed.
self._names("Twain, Mark, 1855-1910", "Twain", "Mark Twain")
self._names("Twain, Mark, ???-1910", "Twain", "Mark Twain")
self._names("Twain, Mark, circ. 1900", "Twain", "Mark Twain")
self._names("Twain, Mark, !@#!@", "Twain", "Mark Twain")
- self._names(
- "Coolbrith, Ina D. 1842?-1928", "Coolbrith", "Ina D. Coolbrith")
+ self._names("Coolbrith, Ina D. 1842?-1928", "Coolbrith", "Ina D. Coolbrith")
self._names("Caesar, Julius, 1st cent.", "Caesar", "Julius Caesar")
self._names("Arrian, 2nd cent.", "Arrian", "Arrian")
self._names("Hafiz, 14th cent.", "Hafiz", "Hafiz")
self._names("Hormel, Bob 1950?-", "Hormel", "Bob Hormel")
- self._names("Holland, Henry 1583-1650? Monumenta sepulchraria Sancti Pauli",
- "Holland", "Henry Holland")
-
+ self._names(
+ "Holland, Henry 1583-1650? Monumenta sepulchraria Sancti Pauli",
+ "Holland",
+ "Henry Holland",
+ )
# Suffixes stay on the end, except for "Mrs.", which goes
# to the front.
@@ -183,7 +186,6 @@ def test_default_names(self):
self._names("Twain, Mark", "Twain", "Mark Twain")
self._names("Geering, R. G.", "Geering", "R. G. Geering")
-
def test_sort_name(self):
bob, new = get_one_or_create(self._db, Contributor, sort_name=None)
assert None == bob.sort_name
diff --git a/tests/models/test_coverage.py b/tests/models/test_coverage.py
index a7b79fbe5..3f8095100 100644
--- a/tests/models/test_coverage.py
+++ b/tests/models/test_coverage.py
@@ -1,7 +1,6 @@
# encoding: utf-8
import datetime
-from ...testing import DatabaseTest
from ...model.coverage import (
BaseCoverageRecord,
CoverageRecord,
@@ -10,45 +9,37 @@
)
from ...model.datasource import DataSource
from ...model.identifier import Identifier
+from ...testing import DatabaseTest
from ...util.datetime_helpers import datetime_utc, utc_now
-class TestTimestamp(DatabaseTest):
+class TestTimestamp(DatabaseTest):
def test_lookup(self):
c1 = self._default_collection
c2 = self._collection()
# Create a timestamp.
- timestamp = Timestamp.stamp(
- self._db, "service", Timestamp.SCRIPT_TYPE, c1
- )
+ timestamp = Timestamp.stamp(self._db, "service", Timestamp.SCRIPT_TYPE, c1)
# Look it up.
- assert (
- timestamp ==
- Timestamp.lookup(self._db, "service", Timestamp.SCRIPT_TYPE, c1))
+ assert timestamp == Timestamp.lookup(
+ self._db, "service", Timestamp.SCRIPT_TYPE, c1
+ )
# There are a number of ways to _fail_ to look up this timestamp.
- assert (
- None ==
- Timestamp.lookup(
- self._db, "other service", Timestamp.SCRIPT_TYPE, c1
- ))
- assert (
- None ==
- Timestamp.lookup(self._db, "service", Timestamp.MONITOR_TYPE, c1))
- assert (
- None ==
- Timestamp.lookup(self._db, "service", Timestamp.SCRIPT_TYPE, c2))
+ assert None == Timestamp.lookup(
+ self._db, "other service", Timestamp.SCRIPT_TYPE, c1
+ )
+ assert None == Timestamp.lookup(self._db, "service", Timestamp.MONITOR_TYPE, c1)
+ assert None == Timestamp.lookup(self._db, "service", Timestamp.SCRIPT_TYPE, c2)
# value() works the same way as lookup() but returns the actual
# timestamp.finish value.
- assert (timestamp.finish ==
- Timestamp.value(self._db, "service", Timestamp.SCRIPT_TYPE, c1))
- assert (
- None ==
- Timestamp.value(self._db, "service", Timestamp.SCRIPT_TYPE, c2))
+ assert timestamp.finish == Timestamp.value(
+ self._db, "service", Timestamp.SCRIPT_TYPE, c1
+ )
+ assert None == Timestamp.value(self._db, "service", Timestamp.SCRIPT_TYPE, c2)
def test_stamp(self):
service = "service"
@@ -69,8 +60,7 @@ def test_stamp(self):
# Calling stamp() again will update the Timestamp.
stamp2 = Timestamp.stamp(
- self._db, service, type, achievements="yay",
- counter=100, exception="boo"
+ self._db, service, type, achievements="yay", counter=100, exception="boo"
)
assert stamp == stamp2
now = utc_now()
@@ -79,9 +69,9 @@ def test_stamp(self):
assert service == stamp.service
assert type == stamp.service_type
assert None == stamp.collection
- assert 'yay' == stamp.achievements
+ assert "yay" == stamp.achievements
assert 100 == stamp.counter
- assert 'boo' == stamp.exception
+ assert "boo" == stamp.exception
# Passing in a different collection will create a new Timestamp.
stamp3 = Timestamp.stamp(
@@ -93,9 +83,12 @@ def test_stamp(self):
# Passing in CLEAR_VALUE for start, end, or exception will
# clear an existing Timestamp.
stamp4 = Timestamp.stamp(
- self._db, service, type,
- start=Timestamp.CLEAR_VALUE, finish=Timestamp.CLEAR_VALUE,
- exception=Timestamp.CLEAR_VALUE
+ self._db,
+ service,
+ type,
+ start=Timestamp.CLEAR_VALUE,
+ finish=Timestamp.CLEAR_VALUE,
+ exception=Timestamp.CLEAR_VALUE,
)
assert stamp4 == stamp
assert None == stamp4.start
@@ -131,8 +124,12 @@ def test_update(self):
def to_data(self):
stamp = Timestamp.stamp(
- self._db, "service", Timestamp.SCRIPT_TYPE,
- collection=self._default_collection, counter=10, achivements="a"
+ self._db,
+ "service",
+ Timestamp.SCRIPT_TYPE,
+ collection=self._default_collection,
+ counter=10,
+ achivements="a",
)
data = stamp.to_data()
assert isinstance(data, TimestampData)
@@ -150,7 +147,6 @@ def to_data(self):
class TestBaseCoverageRecord(DatabaseTest):
-
def test_not_covered(self):
source = DataSource.lookup(self._db, DataSource.OCLC)
@@ -162,9 +158,7 @@ def test_not_covered(self):
success = self._identifier()
success_record = self._coverage_record(success, source)
- success_record.timestamp = (
- utc_now() - datetime.timedelta(seconds=3600)
- )
+ success_record.timestamp = utc_now() - datetime.timedelta(seconds=3600)
assert CoverageRecord.SUCCESS == success_record.status
transient = self._identifier()
@@ -175,7 +169,7 @@ def test_not_covered(self):
persistent = self._identifier()
persistent_record = self._coverage_record(
- persistent, source, status = BaseCoverageRecord.PERSISTENT_FAILURE
+ persistent, source, status=BaseCoverageRecord.PERSISTENT_FAILURE
)
assert CoverageRecord.PERSISTENT_FAILURE == persistent_record.status
@@ -196,15 +190,17 @@ def check_not_covered(expect, **kwargs):
# to count as 'coverage'.
check_not_covered(
[no_coverage],
- count_as_covered=[CoverageRecord.PERSISTENT_FAILURE,
- CoverageRecord.TRANSIENT_FAILURE,
- CoverageRecord.SUCCESS]
+ count_as_covered=[
+ CoverageRecord.PERSISTENT_FAILURE,
+ CoverageRecord.TRANSIENT_FAILURE,
+ CoverageRecord.SUCCESS,
+ ],
)
# Here, only success counts as 'coverage'.
check_not_covered(
[no_coverage, transient, persistent],
- count_as_covered=CoverageRecord.SUCCESS
+ count_as_covered=CoverageRecord.SUCCESS,
)
# We can also say that coverage doesn't count if it was achieved before
@@ -212,45 +208,45 @@ def check_not_covered(expect, **kwargs):
# of the 'success' record means that record still counts as covered.
check_not_covered(
[no_coverage, transient],
- count_as_not_covered_if_covered_before=success_record.timestamp
+ count_as_not_covered_if_covered_before=success_record.timestamp,
)
# But if we pass in a time one second later, the 'success'
# record no longer counts as covered.
- one_second_after = (
- success_record.timestamp + datetime.timedelta(seconds=1)
- )
+ one_second_after = success_record.timestamp + datetime.timedelta(seconds=1)
check_not_covered(
[success, no_coverage, transient],
- count_as_not_covered_if_covered_before=one_second_after
+ count_as_not_covered_if_covered_before=one_second_after,
)
-class TestCoverageRecord(DatabaseTest):
+class TestCoverageRecord(DatabaseTest):
def test_lookup(self):
source = DataSource.lookup(self._db, DataSource.OCLC)
edition = self._edition()
- operation = 'foo'
+ operation = "foo"
collection = self._default_collection
- record = self._coverage_record(edition, source, operation,
- collection=collection)
-
+ record = self._coverage_record(
+ edition, source, operation, collection=collection
+ )
# To find the CoverageRecord, edition, source, operation,
# and collection must all match.
- result = CoverageRecord.lookup(edition, source, operation,
- collection=collection)
+ result = CoverageRecord.lookup(
+ edition, source, operation, collection=collection
+ )
assert record == result
# You can substitute the Edition's primary identifier for the
# Edition iteslf.
lookup = CoverageRecord.lookup(
- edition.primary_identifier, source, operation,
- collection=self._default_collection
+ edition.primary_identifier,
+ source,
+ operation,
+ collection=self._default_collection,
)
assert lookup == record
-
# Omit the collection, and you find nothing.
result = CoverageRecord.lookup(edition, source, operation)
assert None == result
@@ -259,29 +255,29 @@ def test_lookup(self):
result = CoverageRecord.lookup(edition, source, collection=collection)
assert None == result
- result = CoverageRecord.lookup(edition, source, "other operation",
- collection=collection)
+ result = CoverageRecord.lookup(
+ edition, source, "other operation", collection=collection
+ )
assert None == result
# Same for data source.
other_source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- result = CoverageRecord.lookup(edition, other_source, operation,
- collection=collection)
+ result = CoverageRecord.lookup(
+ edition, other_source, operation, collection=collection
+ )
assert None == result
def test_add_for(self):
source = DataSource.lookup(self._db, DataSource.OCLC)
edition = self._edition()
- operation = 'foo'
+ operation = "foo"
record, is_new = CoverageRecord.add_for(edition, source, operation)
assert True == is_new
# If we call add_for again we get the same record back, but we
# can modify the timestamp.
a_week_ago = utc_now() - datetime.timedelta(days=7)
- record2, is_new = CoverageRecord.add_for(
- edition, source, operation, a_week_ago
- )
+ record2, is_new = CoverageRecord.add_for(edition, source, operation, a_week_ago)
assert record == record2
assert False == is_new
assert a_week_ago == record2.timestamp
@@ -300,15 +296,14 @@ def test_add_for(self):
# We can change the status.
record5, is_new = CoverageRecord.add_for(
- edition, source, operation,
- status=CoverageRecord.PERSISTENT_FAILURE
+ edition, source, operation, status=CoverageRecord.PERSISTENT_FAILURE
)
assert record5 == record
assert CoverageRecord.PERSISTENT_FAILURE == record.status
def test_bulk_add(self):
source = DataSource.lookup(self._db, DataSource.GUTENBERG)
- operation = 'testing'
+ operation = "testing"
# An untouched identifier.
i1 = self._identifier()
@@ -316,9 +311,11 @@ def test_bulk_add(self):
# An identifier that already has failing coverage.
covered = self._identifier()
existing = self._coverage_record(
- covered, source, operation=operation,
+ covered,
+ source,
+ operation=operation,
status=CoverageRecord.TRANSIENT_FAILURE,
- exception='Uh oh'
+ exception="Uh oh",
)
original_timestamp = existing.timestamp
@@ -339,7 +336,7 @@ def test_bulk_add(self):
assert [existing] == covered.coverage_records
assert CoverageRecord.TRANSIENT_FAILURE == existing.status
assert original_timestamp == existing.timestamp
- assert 'Uh oh' == existing.exception
+ assert "Uh oh" == existing.exception
# Newly untouched identifier.
i2 = self._identifier()
@@ -370,7 +367,7 @@ def test_bulk_add(self):
def test_bulk_add_with_collection(self):
source = DataSource.lookup(self._db, DataSource.GUTENBERG)
- operation = 'testing'
+ operation = "testing"
c1 = self._collection()
c2 = self._collection()
@@ -381,15 +378,17 @@ def test_bulk_add_with_collection(self):
# An identifier with coverage for a different collection.
covered = self._identifier()
existing = self._coverage_record(
- covered, source, operation=operation,
- status=CoverageRecord.TRANSIENT_FAILURE, collection=c1,
- exception='Danger, Will Robinson'
+ covered,
+ source,
+ operation=operation,
+ status=CoverageRecord.TRANSIENT_FAILURE,
+ collection=c1,
+ exception="Danger, Will Robinson",
)
original_timestamp = existing.timestamp
resulting_records, ignored_identifiers = CoverageRecord.bulk_add(
- [i1, covered], source, operation=operation, collection=c1,
- force=True
+ [i1, covered], source, operation=operation, collection=c1, force=True
)
assert 2 == len(resulting_records)
@@ -411,8 +410,12 @@ def test_bulk_add_with_collection(self):
# Bulk add for a different collection.
resulting_records, ignored_identifiers = CoverageRecord.bulk_add(
- [covered], source, operation=operation, collection=c2,
- status=CoverageRecord.TRANSIENT_FAILURE, exception='Oh no',
+ [covered],
+ source,
+ operation=operation,
+ collection=c2,
+ status=CoverageRecord.TRANSIENT_FAILURE,
+ exception="Oh no",
)
# A new record has been added to the identifier.
@@ -422,13 +425,13 @@ def test_bulk_add_with_collection(self):
assert CoverageRecord.TRANSIENT_FAILURE == new_record.status
assert source == new_record.data_source
assert operation == new_record.operation
- assert 'Oh no' == new_record.exception
+ assert "Oh no" == new_record.exception
-class TestWorkCoverageRecord(DatabaseTest):
+class TestWorkCoverageRecord(DatabaseTest):
def test_lookup(self):
work = self._work()
- operation = 'foo'
+ operation = "foo"
lookup = WorkCoverageRecord.lookup(work, operation)
assert None == lookup
@@ -442,16 +445,14 @@ def test_lookup(self):
def test_add_for(self):
work = self._work()
- operation = 'foo'
+ operation = "foo"
record, is_new = WorkCoverageRecord.add_for(work, operation)
assert True == is_new
# If we call add_for again we get the same record back, but we
# can modify the timestamp.
a_week_ago = utc_now() - datetime.timedelta(days=7)
- record2, is_new = WorkCoverageRecord.add_for(
- work, operation, a_week_ago
- )
+ record2, is_new = WorkCoverageRecord.add_for(work, operation, a_week_ago)
assert record == record2
assert False == is_new
assert a_week_ago == record2.timestamp
@@ -485,25 +486,24 @@ def test_bulk_add(self):
# for an irrelevant operation.
not_already_covered = self._work()
irrelevant_record, ignore = WorkCoverageRecord.add_for(
- not_already_covered, irrelevant_operation,
- status=WorkCoverageRecord.SUCCESS
+ not_already_covered, irrelevant_operation, status=WorkCoverageRecord.SUCCESS
)
# This Work will have its existing, relevant CoverageRecord
# updated.
already_covered = self._work()
previously_failed, ignore = WorkCoverageRecord.add_for(
- already_covered, operation,
+ already_covered,
+ operation,
status=WorkCoverageRecord.TRANSIENT_FAILURE,
)
- previously_failed.exception="Some exception"
+ previously_failed.exception = "Some exception"
# This work will not have a record created for it, because
# we're not passing it in to the method.
not_affected = self._work()
WorkCoverageRecord.add_for(
- not_affected, irrelevant_operation,
- status=WorkCoverageRecord.SUCCESS
+ not_affected, irrelevant_operation, status=WorkCoverageRecord.SUCCESS
)
# This work will not have its existing record updated, because
@@ -519,12 +519,14 @@ def test_bulk_add(self):
new_status = WorkCoverageRecord.REGISTERED
WorkCoverageRecord.bulk_add(
[not_already_covered, already_covered],
- operation, new_timestamp, status=new_status
+ operation,
+ new_timestamp,
+ status=new_status,
)
self._db.commit()
+
def relevant_records(work):
- return [x for x in work.coverage_records
- if x.operation == operation]
+ return [x for x in work.coverage_records if x.operation == operation]
# No coverage records were added or modified for works not
# passed in to the method.
diff --git a/tests/models/test_credential.py b/tests/models/test_credential.py
index 5824979aa..d9ee707f3 100644
--- a/tests/models/test_credential.py
+++ b/tests/models/test_credential.py
@@ -1,19 +1,20 @@
# encoding: utf-8
-import pytest
import datetime
+
+import pytest
from sqlalchemy.exc import IntegrityError
-from ...testing import DatabaseTest
from ...model.credential import (
Credential,
DelegatedPatronIdentifier,
DRMDeviceIdentifier,
)
from ...model.datasource import DataSource
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
-class TestCredentials(DatabaseTest):
+class TestCredentials(DatabaseTest):
def test_temporary_token(self):
# Create a temporary token good for one hour.
@@ -23,44 +24,53 @@ def test_temporary_token(self):
now = utc_now()
expect_expires = now + duration
token, is_new = Credential.temporary_token_create(
- self._db, data_source, "some random type", patron, duration)
+ self._db, data_source, "some random type", patron, duration
+ )
assert data_source == token.data_source
assert "some random type" == token.type
assert patron == token.patron
- expires_difference = abs((token.expires-expect_expires).seconds)
+ expires_difference = abs((token.expires - expect_expires).seconds)
assert expires_difference < 2
# Now try to look up the credential based solely on the UUID.
new_token = Credential.lookup_by_token(
- self._db, data_source, token.type, token.credential)
+ self._db, data_source, token.type, token.credential
+ )
assert new_token == token
# When we call lookup_and_expire_temporary_token, the token is automatically
# expired and we cannot use it anymore.
new_token = Credential.lookup_and_expire_temporary_token(
- self._db, data_source, token.type, token.credential)
+ self._db, data_source, token.type, token.credential
+ )
assert new_token == token
assert new_token.expires < now
new_token = Credential.lookup_by_token(
- self._db, data_source, token.type, token.credential)
+ self._db, data_source, token.type, token.credential
+ )
assert None == new_token
new_token = Credential.lookup_and_expire_temporary_token(
- self._db, data_source, token.type, token.credential)
+ self._db, data_source, token.type, token.credential
+ )
assert None == new_token
# A token with no expiration date is treated as expired...
token.expires = None
self._db.commit()
no_expiration_token = Credential.lookup_by_token(
- self._db, data_source, token.type, token.credential)
+ self._db, data_source, token.type, token.credential
+ )
assert None == no_expiration_token
# ...unless we specifically say we're looking for a persistent token.
no_expiration_token = Credential.lookup_by_token(
- self._db, data_source, token.type, token.credential,
- allow_persistent_token=True
+ self._db,
+ data_source,
+ token.type,
+ token.credential,
+ allow_persistent_token=True,
)
assert token == no_expiration_token
@@ -73,8 +83,12 @@ def test_specify_value_of_temporary_token(self):
duration = datetime.timedelta(hours=1)
data_source = DataSource.lookup(self._db, DataSource.ADOBE)
token, is_new = Credential.temporary_token_create(
- self._db, data_source, "some random type", patron, duration,
- "Some random value"
+ self._db,
+ data_source,
+ "some random type",
+ patron,
+ duration,
+ "Some random value",
)
assert "Some random value" == token.credential
@@ -83,13 +97,15 @@ def test_temporary_token_overwrites_old_token(self):
data_source = DataSource.lookup(self._db, DataSource.ADOBE)
patron = self._patron()
old_token, is_new = Credential.temporary_token_create(
- self._db, data_source, "some random type", patron, duration)
+ self._db, data_source, "some random type", patron, duration
+ )
assert True == is_new
old_credential = old_token.credential
# Creating a second temporary token overwrites the first.
token, is_new = Credential.temporary_token_create(
- self._db, data_source, "some random type", patron, duration)
+ self._db, data_source, "some random type", patron, duration
+ )
assert False == is_new
assert token.id == old_token.id
assert old_credential != token.credential
@@ -108,8 +124,11 @@ def test_persistent_token(self):
# Now try to look up the credential based solely on the UUID.
new_token = Credential.lookup_by_token(
- self._db, data_source, token.type, token.credential,
- allow_persistent_token=True
+ self._db,
+ data_source,
+ token.type,
+ token.credential,
+ allow_persistent_token=True,
)
assert new_token == token
credential = new_token.credential
@@ -118,8 +137,11 @@ def test_persistent_token(self):
# Credential object with the same .credential -- it doesn't
# expire.
again_token = Credential.lookup_by_token(
- self._db, data_source, token.type, token.credential,
- allow_persistent_token=True
+ self._db,
+ data_source,
+ token.type,
+ token.credential,
+ allow_persistent_token=True,
)
assert again_token == new_token
assert again_token.credential == credential
@@ -127,7 +149,8 @@ def test_persistent_token(self):
def test_cannot_look_up_nonexistent_token(self):
data_source = DataSource.lookup(self._db, DataSource.ADOBE)
new_token = Credential.lookup_by_token(
- self._db, data_source, "no such type", "no such credential")
+ self._db, data_source, "no such type", "no such credential"
+ )
assert None == new_token
def test_empty_token(self):
@@ -144,7 +167,14 @@ def test_empty_token(self):
# and the refresher method is not called.
def refresher(self):
raise Exception("Refresher method was called")
- args = self._db, data_source, token.type, None, refresher,
+
+ args = (
+ self._db,
+ data_source,
+ token.type,
+ None,
+ refresher,
+ )
again_token = Credential.lookup(
*args, allow_persistent_token=True, allow_empty_token=True
)
@@ -153,7 +183,9 @@ def refresher(self):
# If allow_empty_token is False, the refresher method is
# created.
with pytest.raises(Exception) as excinfo:
- Credential.lookup(*args, allow_persistent_token = True, allow_empty_token = False)
+ Credential.lookup(
+ *args, allow_persistent_token=True, allow_empty_token=False
+ )
assert "Refresher method was called" in str(excinfo.value)
def test_force_refresher_method(self):
@@ -186,7 +218,7 @@ def refresher(self):
# This call should run the refresher method.
with pytest.raises(Exception) as excinfo:
- Credential.lookup(*args, allow_persistent_token = True, force_refresh=True)
+ Credential.lookup(*args, allow_persistent_token=True, force_refresh=True)
assert "Refresher method was called" in str(excinfo.value)
def test_collection_token(self):
@@ -199,29 +231,47 @@ def test_collection_token(self):
type = "super secret"
# Create our credentials
- credential1 = Credential.lookup(self._db, data_source, type, patron, None, collection=collection1)
- credential2 = Credential.lookup(self._db, data_source, type, patron, None, collection=collection2)
- credential1.credential = 'test1'
- credential2.credential = 'test2'
+ credential1 = Credential.lookup(
+ self._db, data_source, type, patron, None, collection=collection1
+ )
+ credential2 = Credential.lookup(
+ self._db, data_source, type, patron, None, collection=collection2
+ )
+ credential1.credential = "test1"
+ credential2.credential = "test2"
# Make sure the text matches what we expect
- assert 'test1' == Credential.lookup(self._db, data_source, type, patron, None, collection=collection1).credential
- assert 'test2' == Credential.lookup(self._db, data_source, type, patron, None, collection=collection2).credential
+ assert (
+ "test1"
+ == Credential.lookup(
+ self._db, data_source, type, patron, None, collection=collection1
+ ).credential
+ )
+ assert (
+ "test2"
+ == Credential.lookup(
+ self._db, data_source, type, patron, None, collection=collection2
+ ).credential
+ )
# Make sure we don't get anything if we don't pass a collection
- assert None == Credential.lookup(self._db, data_source, type, patron, None).credential
+ assert (
+ None
+ == Credential.lookup(self._db, data_source, type, patron, None).credential
+ )
-class TestDelegatedPatronIdentifier(DatabaseTest):
+class TestDelegatedPatronIdentifier(DatabaseTest):
def test_get_one_or_create(self):
library_uri = self._url
patron_identifier = self._str
identifier_type = DelegatedPatronIdentifier.ADOBE_ACCOUNT_ID
+
def make_id():
return "id1"
+
identifier, is_new = DelegatedPatronIdentifier.get_one_or_create(
- self._db, library_uri, patron_identifier, identifier_type,
- make_id
+ self._db, library_uri, patron_identifier, identifier_type, make_id
)
assert True == is_new
assert library_uri == identifier.library_uri
@@ -233,6 +283,7 @@ def make_id():
# that raises an exception if called.
def explode():
raise Exception("I should never be called.")
+
identifier2, is_new = DelegatedPatronIdentifier.get_one_or_create(
self._db, library_uri, patron_identifier, identifier_type, explode
)
@@ -244,11 +295,10 @@ def explode():
class TestUniquenessConstraints(DatabaseTest):
-
def setup_method(self):
super(TestUniquenessConstraints, self).setup_method()
self.data_source = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- self.type = 'a credential type'
+ self.type = "a credential type"
self.patron = self._patron()
self.col1 = self._default_collection
self.col2 = self._collection()
@@ -256,15 +306,11 @@ def setup_method(self):
def test_duplicate_sitewide_credential(self):
# You can't create two credentials with the same data source,
# type, and token value.
- token = 'a token'
+ token = "a token"
- c1 = Credential(
- data_source=self.data_source, type=self.type, credential=token
- )
+ c1 = Credential(data_source=self.data_source, type=self.type, credential=token)
self._db.flush()
- c2 = Credential(
- data_source=self.data_source, type=self.type, credential=token
- )
+ c2 = Credential(data_source=self.data_source, type=self.type, credential=token)
pytest.raises(IntegrityError, self._db.flush)
def test_duplicate_patron_credential(self):
@@ -287,17 +333,23 @@ def test_duplicate_patron_collection_credential(self):
# collections are different.
c1 = Credential(
- data_source=self.data_source, type=self.type, patron=self.patron,
- collection=self.col1
+ data_source=self.data_source,
+ type=self.type,
+ patron=self.patron,
+ collection=self.col1,
)
c2 = Credential(
- data_source=self.data_source, type=self.type, patron=self.patron,
- collection=self.col2
+ data_source=self.data_source,
+ type=self.type,
+ patron=self.patron,
+ collection=self.col2,
)
self._db.flush()
c3 = Credential(
- data_source=self.data_source, type=self.type, patron=self.patron,
- collection=self.col1
+ data_source=self.data_source,
+ type=self.type,
+ patron=self.patron,
+ collection=self.col1,
)
pytest.raises(IntegrityError, self._db.flush)
@@ -315,13 +367,13 @@ def test_duplicate_collection_credential(self):
class TestDRMDeviceIdentifier(DatabaseTest):
-
def setup_method(self):
super(TestDRMDeviceIdentifier, self).setup_method()
self.data_source = DataSource.lookup(self._db, DataSource.ADOBE)
self.patron = self._patron()
self.credential, ignore = Credential.persistent_token_create(
- self._db, self.data_source, "Some Credential", self.patron)
+ self._db, self.data_source, "Some Credential", self.patron
+ )
def test_devices_for_credential(self):
device_id_1, new = self.credential.register_drm_device_identifier("foo")
@@ -335,7 +387,9 @@ def test_devices_for_credential(self):
device_id_3, new = self.credential.register_drm_device_identifier("bar")
- assert set([device_id_1, device_id_3]) == set(self.credential.drm_device_identifiers)
+ assert set([device_id_1, device_id_3]) == set(
+ self.credential.drm_device_identifiers
+ )
def test_deregister(self):
device, new = self.credential.register_drm_device_identifier("foo")
diff --git a/tests/models/test_customlist.py b/tests/models/test_customlist.py
index df983b25d..f3df94802 100644
--- a/tests/models/test_customlist.py
+++ b/tests/models/test_customlist.py
@@ -1,48 +1,50 @@
# encoding: utf-8
-import pytest
from pdb import set_trace
-from ...testing import DatabaseTest
+import pytest
+
from ...model import get_one_or_create
from ...model.coverage import WorkCoverageRecord
-from ...model.customlist import (
- CustomList,
- CustomListEntry,
-)
+from ...model.customlist import CustomList, CustomListEntry
from ...model.datasource import DataSource
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
-class TestCustomList(DatabaseTest):
+class TestCustomList(DatabaseTest):
def test_find(self):
source = DataSource.lookup(self._db, DataSource.NYT)
# When there's no CustomList to find, nothing is returned.
- result = CustomList.find(self._db, 'my-list', source)
+ result = CustomList.find(self._db, "my-list", source)
assert None == result
custom_list = self._customlist(
- foreign_identifier='a-list', name='My List', num_entries=0
+ foreign_identifier="a-list", name="My List", num_entries=0
)[0]
# A CustomList can be found by its foreign_identifier.
- result = CustomList.find(self._db, 'a-list', source)
+ result = CustomList.find(self._db, "a-list", source)
assert custom_list == result
# Or its name.
- result = CustomList.find(self._db, 'My List', source.name)
+ result = CustomList.find(self._db, "My List", source.name)
assert custom_list == result
# The list can also be found by name without a data source.
- result = CustomList.find(self._db, 'My List')
+ result = CustomList.find(self._db, "My List")
assert custom_list == result
# By default, we only find lists with no associated Library.
# If we look for a list from a library, there isn't one.
- result = CustomList.find(self._db, 'My List', source, library=self._default_library)
+ result = CustomList.find(
+ self._db, "My List", source, library=self._default_library
+ )
assert None == result
# If we add the Library to the list, it's returned.
custom_list.library = self._default_library
- result = CustomList.find(self._db, 'My List', source, library=self._default_library)
+ result = CustomList.find(
+ self._db, "My List", source, library=self._default_library
+ )
assert custom_list == result
def assert_reindexing_scheduled(self, work):
@@ -50,8 +52,9 @@ def assert_reindexing_scheduled(self, work):
indicates that it needs to have its search index updated.
"""
[needs_reindex] = work.coverage_records
- assert (WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION ==
- needs_reindex.operation)
+ assert (
+ WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION == needs_reindex.operation
+ )
assert WorkCoverageRecord.REGISTERED == needs_reindex.status
def test_add_entry(self):
@@ -124,7 +127,9 @@ def test_add_entry(self):
# If the entry already exists, the most_recent_appearance can be
# updated by passing in a later first_appearance.
later = utc_now()
- new_timed_entry = custom_list.add_entry(timed_edition, first_appearance=later)[0]
+ new_timed_entry = custom_list.add_entry(timed_edition, first_appearance=later)[
+ 0
+ ]
assert timed_entry == new_timed_entry
assert now == new_timed_entry.first_appearance
assert later == new_timed_entry.most_recent_appearance
@@ -242,7 +247,8 @@ def test_add_entry_work_equivalent_identifier(self):
w2 = self._work()
w1.presentation_edition.primary_identifier.equivalent_to(
w1.presentation_edition.data_source,
- w2.presentation_edition.primary_identifier, 1
+ w2.presentation_edition.primary_identifier,
+ 1,
)
custom_list, ignore = self._customlist(num_entries=0)
@@ -273,7 +279,9 @@ def test_remove_entry(self):
first.work.coverage_records = []
custom_list.remove_entry(first)
assert 2 == len(custom_list.entries)
- assert set([second, third]) == set([entry.edition for entry in custom_list.entries])
+ assert set([second, third]) == set(
+ [entry.edition for entry in custom_list.entries]
+ )
# And CustomList.updated and size are changed.
assert True == (custom_list.updated > now)
assert 2 == custom_list.size
@@ -317,7 +325,7 @@ def test_remove_entry(self):
def test_entries_for_work(self):
custom_list, editions = self._customlist(num_entries=2)
edition = editions[0]
- [entry] = [e for e in custom_list.entries if e.edition==edition]
+ [entry] = [e for e in custom_list.entries if e.edition == edition]
# The entry is returned when you search by Edition.
assert [entry] == list(custom_list.entries_for_work(edition))
@@ -338,12 +346,11 @@ def test_entries_for_work(self):
not_yet_equivalent = self._edition()
other_entry = custom_list.add_entry(not_yet_equivalent)[0]
edition.primary_identifier.equivalent_to(
- not_yet_equivalent.data_source,
- not_yet_equivalent.primary_identifier, 1
+ not_yet_equivalent.data_source, not_yet_equivalent.primary_identifier, 1
+ )
+ assert set([entry, other_entry]) == set(
+ custom_list.entries_for_work(not_yet_equivalent)
)
- assert (
- set([entry, other_entry]) ==
- set(custom_list.entries_for_work(not_yet_equivalent)))
def test_update_size(self):
list, ignore = self._customlist(num_entries=4)
@@ -354,7 +361,6 @@ def test_update_size(self):
class TestCustomListEntry(DatabaseTest):
-
def test_set_work(self):
# Start with a custom list with no entries
@@ -364,8 +370,10 @@ def test_set_work(self):
edition = self._edition()
entry, ignore = get_one_or_create(
- self._db, CustomListEntry,
- list_id=list.id, edition_id=edition.id,
+ self._db,
+ CustomListEntry,
+ list_id=list.id,
+ edition_id=edition.id,
)
assert edition == entry.edition
@@ -412,8 +420,7 @@ def test_update(self):
other_custom_list = self._customlist()[0]
[external_entry] = other_custom_list.entries
pytest.raises(
- ValueError, entry.update, self._db,
- equivalent_entries=[external_entry]
+ ValueError, entry.update, self._db, equivalent_entries=[external_entry]
)
# So is attempting to update an entry with other entries that
@@ -422,8 +429,7 @@ def test_update(self):
external_work_edition = external_work.presentation_edition
external_work_entry = custom_list.add_entry(external_work_edition)[0]
pytest.raises(
- ValueError, entry.update, self._db,
- equivalent_entries=[external_work_entry]
+ ValueError, entry.update, self._db, equivalent_entries=[external_work_entry]
)
# Okay, but with an actual equivalent entry...
@@ -448,17 +454,23 @@ def test_update(self):
assert entry.edition == work.presentation_edition
assert entry.work == equivalent.work
# The equivalent entry has been deleted.
- assert ([] == self._db.query(CustomListEntry).\
- filter(CustomListEntry.id==equivalent_entry.id).all())
+ assert (
+ []
+ == self._db.query(CustomListEntry)
+ .filter(CustomListEntry.id == equivalent_entry.id)
+ .all()
+ )
# The entry with the longest annotation wins the annotation awards.
long_annotation = "Wow books are so great especially when they're annotated."
longwinded = self._edition()
longwinded_entry = custom_list.add_entry(
- longwinded, annotation=long_annotation)[0]
+ longwinded, annotation=long_annotation
+ )[0]
identifier.equivalent_to(
- longwinded.data_source, longwinded.primary_identifier, 1)
+ longwinded.data_source, longwinded.primary_identifier, 1
+ )
entry.update(self._db, equivalent_entries=[longwinded_entry])
assert long_annotation == entry.annotation
assert longwinded_entry.most_recent_appearance == entry.most_recent_appearance
diff --git a/tests/models/test_datasource.py b/tests/models/test_datasource.py
index 8ccaf2dd5..92498210d 100644
--- a/tests/models/test_datasource.py
+++ b/tests/models/test_datasource.py
@@ -1,13 +1,14 @@
# encoding: utf-8
import pytest
from sqlalchemy.orm.exc import NoResultFound
-from ...testing import DatabaseTest
+
from ...model.datasource import DataSource
from ...model.hasfulltablecache import HasFullTableCache
from ...model.identifier import Identifier
+from ...testing import DatabaseTest
-class TestDataSource(DatabaseTest):
+class TestDataSource(DatabaseTest):
def test_lookup(self):
key = DataSource.GUTENBERG
@@ -49,8 +50,7 @@ def test_lookup_by_deprecated_name(self):
assert DataSource.BIBLIOTHECA != "3M"
def test_lookup_returns_none_for_nonexistent_source(self):
- assert None == DataSource.lookup(
- self._db, "No such data source " + self._str)
+ assert None == DataSource.lookup(self._db, "No such data source " + self._str)
def test_lookup_with_autocreate(self):
name = "Brand new data source " + self._str
@@ -79,11 +79,11 @@ def test_license_source_for(self):
assert DataSource.OVERDRIVE == source.name
def test_license_source_for_string(self):
- source = DataSource.license_source_for(
- self._db, Identifier.THREEM_ID)
+ source = DataSource.license_source_for(self._db, Identifier.THREEM_ID)
assert DataSource.THREEM == source.name
def test_license_source_fails_if_identifier_type_does_not_provide_licenses(self):
identifier = self._identifier(DataSource.MANUAL)
pytest.raises(
- NoResultFound, DataSource.license_source_for, self._db, identifier)
+ NoResultFound, DataSource.license_source_for, self._db, identifier
+ )
diff --git a/tests/models/test_edition.py b/tests/models/test_edition.py
index 3be50cd19..8fc11bc9a 100644
--- a/tests/models/test_edition.py
+++ b/tests/models/test_edition.py
@@ -1,25 +1,20 @@
# encoding: utf-8
import datetime
-from ...testing import DatabaseTest
-from ...model import (
- get_one_or_create,
- PresentationCalculationPolicy,
-)
+
+from ...model import PresentationCalculationPolicy, get_one_or_create
from ...model.constants import MediaTypes
-from ...model.coverage import CoverageRecord
from ...model.contributor import Contributor
+from ...model.coverage import CoverageRecord
from ...model.datasource import DataSource
from ...model.edition import Edition
from ...model.identifier import Identifier
from ...model.licensing import DeliveryMechanism
-from ...model.resource import (
- Hyperlink,
- Representation,
-)
+from ...model.resource import Hyperlink, Representation
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
-class TestEdition(DatabaseTest):
+class TestEdition(DatabaseTest):
def test_medium_from_media_type(self):
# Verify that we can guess a value for Edition.medium from a
# media type.
@@ -54,9 +49,7 @@ def test_author_contributors(self):
id = self._str
type = Identifier.GUTENBERG_ID
- edition, was_new = Edition.for_foreign_id(
- self._db, data_source, type, id
- )
+ edition, was_new = Edition.for_foreign_id(self._db, data_source, type, id)
# We've listed the same person as primary author and author.
[alice], ignore = Contributor.lookup(self._db, "Adder, Alice")
@@ -81,8 +74,7 @@ def test_for_foreign_id(self):
id = "549"
type = Identifier.GUTENBERG_ID
- record, was_new = Edition.for_foreign_id(
- self._db, data_source, type, id)
+ record, was_new = Edition.for_foreign_id(self._db, data_source, type, id)
assert data_source == record.data_source
identifier = record.primary_identifier
assert id == identifier.identifier
@@ -93,7 +85,8 @@ def test_for_foreign_id(self):
# We can get the same work record by providing only the name
# of the data source.
record, was_new = Edition.for_foreign_id(
- self._db, DataSource.GUTENBERG, type, id)
+ self._db, DataSource.GUTENBERG, type, id
+ )
assert data_source == record.data_source
assert identifier == record.primary_identifier
assert False == was_new
@@ -105,10 +98,12 @@ def test_missing_coverage_from(self):
# Here are two Gutenberg records.
g1, ignore = Edition.for_foreign_id(
- self._db, gutenberg, Identifier.GUTENBERG_ID, "1")
+ self._db, gutenberg, Identifier.GUTENBERG_ID, "1"
+ )
g2, ignore = Edition.for_foreign_id(
- self._db, gutenberg, Identifier.GUTENBERG_ID, "2")
+ self._db, gutenberg, Identifier.GUTENBERG_ID, "2"
+ )
# One of them has coverage from OCLC Classify
c1 = self._coverage_record(g1, oclc)
@@ -118,13 +113,15 @@ def test_missing_coverage_from(self):
# Here's a web record, just sitting there.
w, ignore = Edition.for_foreign_id(
- self._db, web, Identifier.URI, "http://www.foo.com/")
+ self._db, web, Identifier.URI, "http://www.foo.com/"
+ )
# missing_coverage_from picks up the Gutenberg record with no
# coverage from OCLC. It doesn't pick up the other
# Gutenberg record, and it doesn't pick up the web record.
[in_gutenberg_but_not_in_oclc] = Edition.missing_coverage_from(
- self._db, gutenberg, oclc).all()
+ self._db, gutenberg, oclc
+ ).all()
assert g2 == in_gutenberg_but_not_in_oclc
@@ -132,27 +129,36 @@ def test_missing_coverage_from(self):
# record that has coverage for that operation, but not the one
# that has generic OCLC coverage.
[has_generic_coverage_only] = Edition.missing_coverage_from(
- self._db, gutenberg, oclc, "some operation").all()
+ self._db, gutenberg, oclc, "some operation"
+ ).all()
assert g1 == has_generic_coverage_only
# We don't put web sites into OCLC, so this will pick up the
# web record (but not the Gutenberg record).
[in_web_but_not_in_oclc] = Edition.missing_coverage_from(
- self._db, web, oclc).all()
+ self._db, web, oclc
+ ).all()
assert w == in_web_but_not_in_oclc
# We don't use the web as a source of coverage, so this will
# return both Gutenberg records (but not the web record).
- assert [g1.id, g2.id] == sorted([x.id for x in Edition.missing_coverage_from(
- self._db, gutenberg, web)])
+ assert [g1.id, g2.id] == sorted(
+ [x.id for x in Edition.missing_coverage_from(self._db, gutenberg, web)]
+ )
def test_sort_by_priority(self):
# Make editions created by the license source, the metadata
# wrangler, and library staff.
- admin = self._edition(data_source_name=DataSource.LIBRARY_STAFF, with_license_pool=False)
- od = self._edition(data_source_name=DataSource.OVERDRIVE, with_license_pool=False)
- mw = self._edition(data_source_name=DataSource.METADATA_WRANGLER, with_license_pool=False)
+ admin = self._edition(
+ data_source_name=DataSource.LIBRARY_STAFF, with_license_pool=False
+ )
+ od = self._edition(
+ data_source_name=DataSource.OVERDRIVE, with_license_pool=False
+ )
+ mw = self._edition(
+ data_source_name=DataSource.METADATA_WRANGLER, with_license_pool=False
+ )
# Create an invalid edition with no data source. (This shouldn't
# happen.)
@@ -184,15 +190,15 @@ def test_equivalent_identifiers(self):
identifier.equivalent_to(data_source, edition.primary_identifier, 0.6)
- policy = PresentationCalculationPolicy(
- equivalent_identifier_threshold=0.5
+ policy = PresentationCalculationPolicy(equivalent_identifier_threshold=0.5)
+ assert set([identifier, edition.primary_identifier]) == set(
+ edition.equivalent_identifiers(policy=policy)
)
- assert (set([identifier, edition.primary_identifier]) ==
- set(edition.equivalent_identifiers(policy=policy)))
policy.equivalent_identifier_threshold = 0.7
- assert (set([edition.primary_identifier]) ==
- set(edition.equivalent_identifiers(policy=policy)))
+ assert set([edition.primary_identifier]) == set(
+ edition.equivalent_identifiers(policy=policy)
+ )
def test_recursive_edition_equivalence(self):
@@ -202,7 +208,8 @@ def test_recursive_edition_equivalence(self):
identifier_type=Identifier.GUTENBERG_ID,
identifier_id="1",
with_open_access_download=True,
- title="Original Gutenberg text")
+ title="Original Gutenberg text",
+ )
# Here's an Edition for an Open Library text.
open_library, open_library_pool = self._edition(
@@ -210,7 +217,8 @@ def test_recursive_edition_equivalence(self):
identifier_type=Identifier.OPEN_LIBRARY_ID,
identifier_id="W1111",
with_open_access_download=True,
- title="Open Library record")
+ title="Open Library record",
+ )
# We've learned from OCLC Classify that the Gutenberg text is
# equivalent to a certain OCLC Number. We've learned from OCLC
@@ -220,24 +228,27 @@ def test_recursive_edition_equivalence(self):
oclc_linked_data = DataSource.lookup(self._db, DataSource.OCLC_LINKED_DATA)
oclc_number, ignore = Identifier.for_foreign_id(
- self._db, Identifier.OCLC_NUMBER, "22")
- gutenberg.primary_identifier.equivalent_to(
- oclc_classify, oclc_number, 1)
- open_library.primary_identifier.equivalent_to(
- oclc_linked_data, oclc_number, 1)
+ self._db, Identifier.OCLC_NUMBER, "22"
+ )
+ gutenberg.primary_identifier.equivalent_to(oclc_classify, oclc_number, 1)
+ open_library.primary_identifier.equivalent_to(oclc_linked_data, oclc_number, 1)
# Here's an Edition for a Recovering the Classics cover.
web_source = DataSource.lookup(self._db, DataSource.WEB)
recovering, ignore = Edition.for_foreign_id(
- self._db, web_source, Identifier.URI,
- "http://recoveringtheclassics.com/pride-and-prejudice.jpg")
+ self._db,
+ web_source,
+ Identifier.URI,
+ "http://recoveringtheclassics.com/pride-and-prejudice.jpg",
+ )
recovering.title = "Recovering the Classics cover"
# We've manually associated that Edition's URI directly
# with the Project Gutenberg text.
manual = DataSource.lookup(self._db, DataSource.MANUAL)
gutenberg.primary_identifier.equivalent_to(
- manual, recovering.primary_identifier, 1)
+ manual, recovering.primary_identifier, 1
+ )
# Finally, here's a completely unrelated Edition, which
# will not be showing up.
@@ -246,7 +257,8 @@ def test_recursive_edition_equivalence(self):
identifier_type=Identifier.GUTENBERG_ID,
identifier_id="2",
with_open_access_download=True,
- title="Unrelated Gutenberg record.")
+ title="Unrelated Gutenberg record.",
+ )
# When we call equivalent_editions on the Project Gutenberg
# Edition, we get three Editions: the Gutenberg record
@@ -300,7 +312,7 @@ def test_calculate_presentation_author(self):
assert "Bob Bitshifter" == wr.author
assert "Bitshifter, Bob" == wr.sort_author
- bob.display_name="Bob A. Bitshifter"
+ bob.display_name = "Bob A. Bitshifter"
wr.calculate_presentation()
assert "Bob A. Bitshifter" == wr.author
assert "Bitshifter, Bob" == wr.sort_author
@@ -317,8 +329,9 @@ def test_set_summary(self):
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
# Set the work's summary.
- l1, new = pool.add_link(Hyperlink.DESCRIPTION, None, overdrive, "text/plain",
- "F")
+ l1, new = pool.add_link(
+ Hyperlink.DESCRIPTION, None, overdrive, "text/plain", "F"
+ )
work.set_summary(l1.resource)
assert l1.resource == work.summary
@@ -336,14 +349,20 @@ def test_calculate_evaluate_summary_quality_with_privileged_data_sources(self):
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
# There's a perfunctory description from Overdrive.
- l1, new = pool.add_link(Hyperlink.SHORT_DESCRIPTION, None, overdrive, "text/plain",
- "F")
+ l1, new = pool.add_link(
+ Hyperlink.SHORT_DESCRIPTION, None, overdrive, "text/plain", "F"
+ )
overdrive_resource = l1.resource
# There's a much better description from OCLC Linked Data.
- l2, new = pool.add_link(Hyperlink.DESCRIPTION, None, oclc, "text/plain",
- """Nothing about working with his former high school crush, Stephanie Stephens, is ideal. Still, if Aaron Caruthers intends to save his grandmother's bakery, he must. Good thing he has a lot of ideas he can't wait to implement. He never imagines Stephanie would have her own ideas for the business. Or that they would clash with his!""")
+ l2, new = pool.add_link(
+ Hyperlink.DESCRIPTION,
+ None,
+ oclc,
+ "text/plain",
+ """Nothing about working with his former high school crush, Stephanie Stephens, is ideal. Still, if Aaron Caruthers intends to save his grandmother's bakery, he must. Good thing he has a lot of ideas he can't wait to implement. He never imagines Stephanie would have her own ideas for the business. Or that they would clash with his!""",
+ )
oclc_resource = l2.resource
# In a head-to-head evaluation, the OCLC Linked Data description wins.
@@ -356,7 +375,8 @@ def test_calculate_evaluate_summary_quality_with_privileged_data_sources(self):
# But if we say that Overdrive is the privileged data source, it wins
# automatically. The other resource isn't even considered.
champ2, resources2 = Identifier.evaluate_summary_quality(
- self._db, ids, [overdrive])
+ self._db, ids, [overdrive]
+ )
assert overdrive_resource == champ2
assert [overdrive_resource] == resources2
@@ -366,14 +386,16 @@ def test_calculate_evaluate_summary_quality_with_privileged_data_sources(self):
# wins.
threem = DataSource.lookup(self._db, DataSource.THREEM)
champ3, resources3 = Identifier.evaluate_summary_quality(
- self._db, ids, [threem])
+ self._db, ids, [threem]
+ )
assert set([overdrive_resource, oclc_resource]) == set(resources3)
assert oclc_resource == champ3
# If there are two privileged data sources and there's no
# description from the first, the second is used.
champ4, resources4 = Identifier.evaluate_summary_quality(
- self._db, ids, [threem, overdrive])
+ self._db, ids, [threem, overdrive]
+ )
assert [overdrive_resource] == resources4
assert overdrive_resource == champ4
@@ -381,17 +403,22 @@ def test_calculate_evaluate_summary_quality_with_privileged_data_sources(self):
# This is not a silly example. The librarian may choose to set the description
# to an empty string in the admin interface, to override a bad overdrive/etc. description.
staff = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
- l3, new = pool.add_link(Hyperlink.SHORT_DESCRIPTION, None, staff, "text/plain", "")
+ l3, new = pool.add_link(
+ Hyperlink.SHORT_DESCRIPTION, None, staff, "text/plain", ""
+ )
staff_resource = l3.resource
champ5, resources5 = Identifier.evaluate_summary_quality(
- self._db, ids, [staff, overdrive])
+ self._db, ids, [staff, overdrive]
+ )
assert [staff_resource] == resources5
assert staff_resource == champ5
def test_calculate_presentation_cover(self):
# Here's a cover image with a thumbnail.
- representation, ignore = get_one_or_create(self._db, Representation, url="http://cover")
+ representation, ignore = get_one_or_create(
+ self._db, Representation, url="http://cover"
+ )
representation.media_type = Representation.JPEG_MEDIA_TYPE
representation.mirrored_at = utc_now()
representation.mirror_url = "http://mirror/cover"
@@ -403,7 +430,9 @@ def test_calculate_presentation_cover(self):
# Verify that a cover for the edition's primary identifier is used.
e, pool = self._edition(with_license_pool=True)
- link, ignore = e.primary_identifier.add_link(Hyperlink.IMAGE, "http://cover", e.data_source)
+ link, ignore = e.primary_identifier.add_link(
+ Hyperlink.IMAGE, "http://cover", e.data_source
+ )
link.resource.representation = representation
e.calculate_presentation()
assert "http://mirror/cover" == e.cover_full_url
@@ -414,10 +443,12 @@ def test_calculate_presentation_cover(self):
e, pool = self._edition(with_license_pool=True)
oclc_classify = DataSource.lookup(self._db, DataSource.OCLC)
oclc_number, ignore = Identifier.for_foreign_id(
- self._db, Identifier.OCLC_NUMBER, "22")
- e.primary_identifier.equivalent_to(
- oclc_classify, oclc_number, 1)
- link, ignore = oclc_number.add_link(Hyperlink.IMAGE, "http://cover", oclc_classify)
+ self._db, Identifier.OCLC_NUMBER, "22"
+ )
+ e.primary_identifier.equivalent_to(oclc_classify, oclc_number, 1)
+ link, ignore = oclc_number.add_link(
+ Hyperlink.IMAGE, "http://cover", oclc_classify
+ )
link.resource.representation = representation
e.calculate_presentation()
assert "http://mirror/cover" == e.cover_full_url
@@ -425,13 +456,19 @@ def test_calculate_presentation_cover(self):
# Verify that a nearby cover takes precedence over a
# faraway cover.
- link, ignore = e.primary_identifier.add_link(Hyperlink.IMAGE, "http://nearby-cover", e.data_source)
- nearby, ignore = get_one_or_create(self._db, Representation, url=link.resource.url)
+ link, ignore = e.primary_identifier.add_link(
+ Hyperlink.IMAGE, "http://nearby-cover", e.data_source
+ )
+ nearby, ignore = get_one_or_create(
+ self._db, Representation, url=link.resource.url
+ )
nearby.media_type = Representation.JPEG_MEDIA_TYPE
nearby.mirrored_at = utc_now()
nearby.mirror_url = "http://mirror/nearby-cover"
link.resource.representation = nearby
- nearby_thumb, ignore = get_one_or_create(self._db, Representation, url="http://nearby-thumb")
+ nearby_thumb, ignore = get_one_or_create(
+ self._db, Representation, url="http://nearby-thumb"
+ )
nearby_thumb.media_type = Representation.JPEG_MEDIA_TYPE
nearby_thumb.mirrored_at = utc_now()
nearby_thumb.mirror_url = "http://mirror/nearby-thumb"
@@ -443,13 +480,14 @@ def test_calculate_presentation_cover(self):
# Verify that a thumbnail is used even if there's
# no full-sized cover.
e, pool = self._edition(with_license_pool=True)
- link, ignore = e.primary_identifier.add_link(Hyperlink.THUMBNAIL_IMAGE, "http://thumb", e.data_source)
+ link, ignore = e.primary_identifier.add_link(
+ Hyperlink.THUMBNAIL_IMAGE, "http://thumb", e.data_source
+ )
link.resource.representation = thumb
e.calculate_presentation()
assert None == e.cover_full_url
assert "http://mirror/thumb" == e.cover_thumbnail_url
-
def test_calculate_presentation_registers_coverage_records(self):
edition = self._edition()
identifier = edition.primary_identifier
@@ -465,9 +503,11 @@ def test_calculate_presentation_registers_coverage_records(self):
# One for setting the Edition metadata and one for choosing
# the Edition's cover.
- expect = set([
- CoverageRecord.SET_EDITION_METADATA_OPERATION,
- CoverageRecord.CHOOSE_COVER_OPERATION]
+ expect = set(
+ [
+ CoverageRecord.SET_EDITION_METADATA_OPERATION,
+ CoverageRecord.CHOOSE_COVER_OPERATION,
+ ]
)
assert expect == set([x.operation for x in records])
@@ -475,9 +515,9 @@ def test_calculate_presentation_registers_coverage_records(self):
# Edition, not just the Identifier, because each
# CoverageRecord's DataSource is set to this Edition's
# DataSource.
- assert (
- [edition.data_source, edition.data_source] ==
- [x.data_source for x in records])
+ assert [edition.data_source, edition.data_source] == [
+ x.data_source for x in records
+ ]
def test_no_permanent_work_id_for_edition_without_title_or_medium(self):
# An edition with no title or medium is not assigned a permanent work
@@ -485,11 +525,11 @@ def test_no_permanent_work_id_for_edition_without_title_or_medium(self):
edition = self._edition()
assert None == edition.permanent_work_id
- edition.title = ''
+ edition.title = ""
edition.calculate_permanent_work_id()
assert None == edition.permanent_work_id
- edition.title = 'something'
+ edition.title = "something"
edition.calculate_permanent_work_id()
assert None != edition.permanent_work_id
@@ -503,12 +543,16 @@ def test_choose_cover_can_choose_full_image_and_thumbnail_separately(self):
# This edition has a full-sized image and a thumbnail image,
# but there is no evidence that they are the _same_ image.
main_image, ignore = edition.primary_identifier.add_link(
- Hyperlink.IMAGE, "http://main/",
- edition.data_source, Representation.PNG_MEDIA_TYPE
+ Hyperlink.IMAGE,
+ "http://main/",
+ edition.data_source,
+ Representation.PNG_MEDIA_TYPE,
)
thumbnail_image, ignore = edition.primary_identifier.add_link(
- Hyperlink.THUMBNAIL_IMAGE, "http://thumbnail/",
- edition.data_source, Representation.PNG_MEDIA_TYPE
+ Hyperlink.THUMBNAIL_IMAGE,
+ "http://thumbnail/",
+ edition.data_source,
+ Representation.PNG_MEDIA_TYPE,
)
# Nonetheless, Edition.choose_cover() will assign the
@@ -522,10 +566,14 @@ def test_choose_cover_can_choose_full_image_and_thumbnail_separately(self):
# associated with the identifier is a thumbnail _of_ the
# full-sized image...
thumbnail_2, ignore = edition.primary_identifier.add_link(
- Hyperlink.THUMBNAIL_IMAGE, "http://thumbnail2/",
- edition.data_source, Representation.PNG_MEDIA_TYPE
+ Hyperlink.THUMBNAIL_IMAGE,
+ "http://thumbnail2/",
+ edition.data_source,
+ Representation.PNG_MEDIA_TYPE,
+ )
+ thumbnail_2.resource.representation.thumbnail_of = (
+ main_image.resource.representation
)
- thumbnail_2.resource.representation.thumbnail_of = main_image.resource.representation
edition.choose_cover()
# ...That thumbnail will be chosen in preference to the
diff --git a/tests/models/test_hasfulltablecache.py b/tests/models/test_hasfulltablecache.py
index 7cb5b5abd..20a5c7ecd 100644
--- a/tests/models/test_hasfulltablecache.py
+++ b/tests/models/test_hasfulltablecache.py
@@ -1,6 +1,7 @@
# encoding: utf-8
-from ...testing import DatabaseTest
from ...model.hasfulltablecache import HasFullTableCache
+from ...testing import DatabaseTest
+
class MockHasTableCache(HasFullTableCache):
@@ -21,8 +22,8 @@ def id(self):
def cache_key(self):
return self.KEY
-class TestHasFullTableCache(DatabaseTest):
+class TestHasFullTableCache(DatabaseTest):
def setup_method(self):
super(TestHasFullTableCache, self).setup_method()
self.mock_class = MockHasTableCache
diff --git a/tests/models/test_identifier.py b/tests/models/test_identifier.py
index d17b60540..3ab7b2c64 100644
--- a/tests/models/test_identifier.py
+++ b/tests/models/test_identifier.py
@@ -1,7 +1,8 @@
# encoding: utf-8
-import pytest
import datetime
+
import feedparser
+import pytest
from lxml import etree
from mock import PropertyMock, create_autospec
from parameterized import parameterized
@@ -15,21 +16,22 @@
from ...util.datetime_helpers import utc_now
from ...util.opds_writer import AtomFeed
+
class TestIdentifier(DatabaseTest):
def test_for_foreign_id(self):
identifier_type = Identifier.ISBN
isbn = "3293000061"
# Getting the data automatically creates a database record.
- identifier, was_new = Identifier.for_foreign_id(
- self._db, identifier_type, isbn)
+ identifier, was_new = Identifier.for_foreign_id(self._db, identifier_type, isbn)
assert Identifier.ISBN == identifier.type
assert isbn == identifier.identifier
assert True == was_new
# If we get it again we get the same data, but it's no longer new.
identifier2, was_new = Identifier.for_foreign_id(
- self._db, identifier_type, isbn)
+ self._db, identifier_type, isbn
+ )
assert identifier == identifier2
assert False == was_new
@@ -37,9 +39,7 @@ def test_for_foreign_id(self):
assert None == Identifier.for_foreign_id(self._db, None, None)
def test_for_foreign_id_by_deprecated_type(self):
- threem_id, is_new = Identifier.for_foreign_id(
- self._db, "3M ID", self._str
- )
+ threem_id, is_new = Identifier.for_foreign_id(self._db, "3M ID", self._str)
assert Identifier.BIBLIOTHECA_ID == threem_id.type
assert Identifier.BIBLIOTHECA_ID != "3M ID"
@@ -65,15 +65,16 @@ def test_for_foreign_id_without_autocreate(self):
# We don't want to auto-create a database record, so we set
# autocreate=False
identifier, was_new = Identifier.for_foreign_id(
- self._db, identifier_type, isbn, autocreate=False)
+ self._db, identifier_type, isbn, autocreate=False
+ )
assert None == identifier
assert False == was_new
def test_from_asin(self):
- isbn10 = '1449358063'
- isbn13 = '9781449358068'
- asin = 'B0088IYM3C'
- isbn13_with_dashes = '978-144-935-8068'
+ isbn10 = "1449358063"
+ isbn13 = "9781449358068"
+ asin = "B0088IYM3C"
+ isbn13_with_dashes = "978-144-935-8068"
i_isbn10, new1 = Identifier.from_asin(self._db, isbn10)
i_isbn13, new2 = Identifier.from_asin(self._db, isbn13)
@@ -96,18 +97,22 @@ def test_from_asin(self):
def test_urn(self):
# ISBN identifiers use the ISBN URN scheme.
identifier, ignore = Identifier.for_foreign_id(
- self._db, Identifier.ISBN, "9781449358068")
+ self._db, Identifier.ISBN, "9781449358068"
+ )
assert "urn:isbn:9781449358068" == identifier.urn
# URI identifiers don't need a URN scheme.
identifier, ignore = Identifier.for_foreign_id(
- self._db, Identifier.URI, "http://example.com/")
+ self._db, Identifier.URI, "http://example.com/"
+ )
assert identifier.identifier == identifier.urn
# Gutenberg identifiers use Gutenberg's URL-based scheme
identifier = self._identifier(Identifier.GUTENBERG_ID)
- assert (Identifier.GUTENBERG_URN_SCHEME_PREFIX + identifier.identifier ==
- identifier.urn)
+ assert (
+ Identifier.GUTENBERG_URN_SCHEME_PREFIX + identifier.identifier
+ == identifier.urn
+ )
# All other identifiers use our custom URN scheme.
identifier = self._identifier(Identifier.OVERDRIVE_ID)
@@ -130,7 +135,7 @@ def test_parse_urns(self):
# Only the existing identifier is included in the results.
assert 1 == len(identifiers_by_urn)
- assert {identifier.urn : identifier} == identifiers_by_urn
+ assert {identifier.urn: identifier} == identifiers_by_urn
# By default, new identifiers are created, too.
results = Identifier.parse_urns(self._db, urns)
@@ -168,9 +173,7 @@ def test_parse_urns(self):
assert isbn_urn in failure
success, failure = Identifier.parse_urns(
- self._db, urns, allowed_types=[
- Identifier.OVERDRIVE_ID, Identifier.ISBN
- ]
+ self._db, urns, allowed_types=[Identifier.OVERDRIVE_ID, Identifier.ISBN]
)
assert new_urn in success
assert isbn_urn in success
@@ -178,9 +181,7 @@ def test_parse_urns(self):
# If the allowed_types is empty, no URNs can be looked up
# -- this is most likely the caller's mistake.
- success, failure = Identifier.parse_urns(
- self._db, urns, allowed_types=[]
- )
+ success, failure = Identifier.parse_urns(self._db, urns, allowed_types=[])
assert new_urn in failure
assert isbn_urn in failure
@@ -195,7 +196,8 @@ def test_parse_urn(self):
# We can parse urn:isbn URNs into ISBN identifiers. ISBN-10s are
# converted to ISBN-13s.
identifier, ignore = Identifier.for_foreign_id(
- self._db, Identifier.ISBN, "9781449358068")
+ self._db, Identifier.ISBN, "9781449358068"
+ )
isbn_urn = "urn:isbn:1449358063"
isbn_identifier, ignore = Identifier.parse_urn(self._db, isbn_urn)
assert Identifier.ISBN == isbn_identifier.type
@@ -207,21 +209,23 @@ def test_parse_urn(self):
# We can parse ordinary http: or https: URLs into URI
# identifiers.
- http_identifier, ignore = Identifier.parse_urn(
- self._db, "http://example.com")
+ http_identifier, ignore = Identifier.parse_urn(self._db, "http://example.com")
assert Identifier.URI == http_identifier.type
assert "http://example.com" == http_identifier.identifier
- https_identifier, ignore = Identifier.parse_urn(
- self._db, "https://example.com")
+ https_identifier, ignore = Identifier.parse_urn(self._db, "https://example.com")
assert Identifier.URI == https_identifier.type
assert "https://example.com" == https_identifier.identifier
# We can parse UUIDs.
uuid_identifier, ignore = Identifier.parse_urn(
- self._db, "urn:uuid:04377e87-ab69-41c8-a2a4-812d55dc0952")
+ self._db, "urn:uuid:04377e87-ab69-41c8-a2a4-812d55dc0952"
+ )
assert Identifier.URI == uuid_identifier.type
- assert "urn:uuid:04377e87-ab69-41c8-a2a4-812d55dc0952" == uuid_identifier.identifier
+ assert (
+ "urn:uuid:04377e87-ab69-41c8-a2a4-812d55dc0952"
+ == uuid_identifier.identifier
+ )
# A URN we can't handle raises an exception.
ftp_urn = "ftp://example.com"
@@ -240,8 +244,11 @@ def parse_urn_must_support_license_pools(self):
isbn_urn = "urn:isbn:1449358063"
pytest.raises(
Identifier.UnresolvableIdentifierException,
- Identifier.parse_urn, self._db, isbn_urn,
- must_support_license_pools=True)
+ Identifier.parse_urn,
+ self._db,
+ isbn_urn,
+ must_support_license_pools=True,
+ )
def test_recursively_equivalent_identifier_ids(self):
identifier = self._identifier()
@@ -266,44 +273,44 @@ def test_recursively_equivalent_identifier_ids(self):
# With a low threshold and enough levels, we find all the identifiers.
high_levels_low_threshold = PresentationCalculationPolicy(
- equivalent_identifier_levels=5,
- equivalent_identifier_threshold=0.1
+ equivalent_identifier_levels=5, equivalent_identifier_threshold=0.1
)
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=high_levels_low_threshold
)
- assert (set([identifier.id,
- strong_equivalent.id,
- weak_equivalent.id,
- level_2_equivalent.id,
- level_3_equivalent.id,
- level_4_equivalent.id]) ==
- set(equivs[identifier.id]))
+ assert (
+ set(
+ [
+ identifier.id,
+ strong_equivalent.id,
+ weak_equivalent.id,
+ level_2_equivalent.id,
+ level_3_equivalent.id,
+ level_4_equivalent.id,
+ ]
+ )
+ == set(equivs[identifier.id])
+ )
# If we only look at one level, we don't find the level 2, 3, or 4 identifiers.
one_level = PresentationCalculationPolicy(
- equivalent_identifier_levels=1,
- equivalent_identifier_threshold=0.1
+ equivalent_identifier_levels=1, equivalent_identifier_threshold=0.1
)
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=one_level
)
- assert (set([identifier.id,
- strong_equivalent.id,
- weak_equivalent.id]) ==
- set(equivs[identifier.id]))
+ assert set([identifier.id, strong_equivalent.id, weak_equivalent.id]) == set(
+ equivs[identifier.id]
+ )
# If we raise the threshold, we don't find the weak identifier.
one_level_high_threshold = PresentationCalculationPolicy(
- equivalent_identifier_levels=1,
- equivalent_identifier_threshold=0.4
- )
+ equivalent_identifier_levels=1, equivalent_identifier_threshold=0.4
+ )
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=one_level_high_threshold
)
- assert (set([identifier.id,
- strong_equivalent.id]) ==
- set(equivs[identifier.id]))
+ assert set([identifier.id, strong_equivalent.id]) == set(equivs[identifier.id])
# For deeper levels, the strength is the product of the strengths
# of all the equivalencies in between the two identifiers.
@@ -315,60 +322,73 @@ def test_recursively_equivalent_identifier_ids(self):
# With a threshold of 0.5, level 2 and all subsequent levels are too weak.
high_levels_high_threshold = PresentationCalculationPolicy(
- equivalent_identifier_levels=5,
- equivalent_identifier_threshold=0.5
- )
+ equivalent_identifier_levels=5, equivalent_identifier_threshold=0.5
+ )
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=high_levels_high_threshold
)
- assert (set([identifier.id,
- strong_equivalent.id]) ==
- set(equivs[identifier.id]))
+ assert set([identifier.id, strong_equivalent.id]) == set(equivs[identifier.id])
# With a threshold of 0.25, level 2 is strong enough, but level
# 4 is too weak.
high_levels_lower_threshold = PresentationCalculationPolicy(
- equivalent_identifier_levels=5,
- equivalent_identifier_threshold=0.25
- )
+ equivalent_identifier_levels=5, equivalent_identifier_threshold=0.25
+ )
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=high_levels_lower_threshold
)
- assert (set([identifier.id,
- strong_equivalent.id,
- level_2_equivalent.id,
- level_3_equivalent.id]) ==
- set(equivs[identifier.id]))
+ assert (
+ set(
+ [
+ identifier.id,
+ strong_equivalent.id,
+ level_2_equivalent.id,
+ level_3_equivalent.id,
+ ]
+ )
+ == set(equivs[identifier.id])
+ )
# It also works if we start from other identifiers.
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [strong_equivalent.id], policy=high_levels_low_threshold
)
- assert (set([identifier.id,
- strong_equivalent.id,
- weak_equivalent.id,
- level_2_equivalent.id,
- level_3_equivalent.id,
- level_4_equivalent.id]) ==
- set(equivs[strong_equivalent.id]))
+ assert (
+ set(
+ [
+ identifier.id,
+ strong_equivalent.id,
+ weak_equivalent.id,
+ level_2_equivalent.id,
+ level_3_equivalent.id,
+ level_4_equivalent.id,
+ ]
+ )
+ == set(equivs[strong_equivalent.id])
+ )
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [level_4_equivalent.id], policy=high_levels_low_threshold
)
- assert (set([identifier.id,
- strong_equivalent.id,
- level_2_equivalent.id,
- level_3_equivalent.id,
- level_4_equivalent.id]) ==
- set(equivs[level_4_equivalent.id]))
+ assert (
+ set(
+ [
+ identifier.id,
+ strong_equivalent.id,
+ level_2_equivalent.id,
+ level_3_equivalent.id,
+ level_4_equivalent.id,
+ ]
+ )
+ == set(equivs[level_4_equivalent.id])
+ )
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [level_4_equivalent.id], policy=high_levels_high_threshold
)
- assert (set([level_2_equivalent.id,
- level_3_equivalent.id,
- level_4_equivalent.id]) ==
- set(equivs[level_4_equivalent.id]))
+ assert set(
+ [level_2_equivalent.id, level_3_equivalent.id, level_4_equivalent.id]
+ ) == set(equivs[level_4_equivalent.id])
# A chain of very strong equivalents can keep a high strength
# even at deep levels. This wouldn't work if we changed the strength
@@ -381,34 +401,28 @@ def test_recursively_equivalent_identifier_ids(self):
l3.equivalent_to(data_source, l2, 1)
l4.equivalent_to(data_source, l3, 0.9)
high_levels_fairly_high_threshold = PresentationCalculationPolicy(
- equivalent_identifier_levels=5,
- equivalent_identifier_threshold=0.89
+ equivalent_identifier_levels=5, equivalent_identifier_threshold=0.89
)
equivs = Identifier.recursively_equivalent_identifier_ids(
- self._db, [another_identifier.id],
- high_levels_fairly_high_threshold
+ self._db, [another_identifier.id], high_levels_fairly_high_threshold
+ )
+ assert set([another_identifier.id, l2.id, l3.id, l4.id]) == set(
+ equivs[another_identifier.id]
)
- assert (set([another_identifier.id,
- l2.id,
- l3.id,
- l4.id]) ==
- set(equivs[another_identifier.id]))
# We can look for multiple identifiers at once.
two_levels_high_threshold = PresentationCalculationPolicy(
- equivalent_identifier_levels=2,
- equivalent_identifier_threshold=0.8
+ equivalent_identifier_levels=2, equivalent_identifier_threshold=0.8
)
equivs = Identifier.recursively_equivalent_identifier_ids(
- self._db, [identifier.id, level_3_equivalent.id],
- policy=two_levels_high_threshold
+ self._db,
+ [identifier.id, level_3_equivalent.id],
+ policy=two_levels_high_threshold,
+ )
+ assert set([identifier.id, strong_equivalent.id]) == set(equivs[identifier.id])
+ assert set([level_2_equivalent.id, level_3_equivalent.id]) == set(
+ equivs[level_3_equivalent.id]
)
- assert (set([identifier.id,
- strong_equivalent.id]) ==
- set(equivs[identifier.id]))
- assert (set([level_2_equivalent.id,
- level_3_equivalent.id]) ==
- set(equivs[level_3_equivalent.id]))
# By setting a cutoff, you can say to look deep in the tree,
# but stop looking as soon as you have a certain number of
@@ -416,19 +430,19 @@ def test_recursively_equivalent_identifier_ids(self):
with_cutoff = PresentationCalculationPolicy(
equivalent_identifier_levels=5,
equivalent_identifier_threshold=0.1,
- equivalent_identifier_cutoff=1,
- )
+ equivalent_identifier_cutoff=1,
+ )
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=with_cutoff
)
-
+
# The cutoff was set to 1, but we always go at least one level
# deep, and that gives us three equivalent identifiers. We
# don't artificially trim it back down to 1.
assert 3 == len(equivs[identifier.id])
# Increase the cutoff, and we get more identifiers.
- with_cutoff.equivalent_identifier_cutoff=5
+ with_cutoff.equivalent_identifier_cutoff = 5
equivs = Identifier.recursively_equivalent_identifier_ids(
self._db, [identifier.id], policy=with_cutoff
)
@@ -440,16 +454,22 @@ def test_recursively_equivalent_identifier_ids(self):
query = Identifier.recursively_equivalent_identifier_ids_query(
Identifier.id, policy=high_levels_low_threshold
)
- query = query.where(Identifier.id==identifier.id)
+ query = query.where(Identifier.id == identifier.id)
results = self._db.execute(query)
equivalent_ids = [r[0] for r in results]
- assert (set([identifier.id,
- strong_equivalent.id,
- weak_equivalent.id,
- level_2_equivalent.id,
- level_3_equivalent.id,
- level_4_equivalent.id]) ==
- set(equivalent_ids))
+ assert (
+ set(
+ [
+ identifier.id,
+ strong_equivalent.id,
+ weak_equivalent.id,
+ level_2_equivalent.id,
+ level_3_equivalent.id,
+ level_4_equivalent.id,
+ ]
+ )
+ == set(equivalent_ids)
+ )
query = Identifier.recursively_equivalent_identifier_ids_query(
Identifier.id, policy=two_levels_high_threshold
@@ -457,11 +477,17 @@ def test_recursively_equivalent_identifier_ids(self):
query = query.where(Identifier.id.in_([identifier.id, level_3_equivalent.id]))
results = self._db.execute(query)
equivalent_ids = [r[0] for r in results]
- assert (set([identifier.id,
- strong_equivalent.id,
- level_2_equivalent.id,
- level_3_equivalent.id]) ==
- set(equivalent_ids))
+ assert (
+ set(
+ [
+ identifier.id,
+ strong_equivalent.id,
+ level_2_equivalent.id,
+ level_3_equivalent.id,
+ ]
+ )
+ == set(equivalent_ids)
+ )
def test_licensed_through_collection(self):
c1 = self._default_collection
@@ -485,10 +511,12 @@ def test_missing_coverage_from(self):
# Here are two Gutenberg records.
g1, ignore = Edition.for_foreign_id(
- self._db, gutenberg, Identifier.GUTENBERG_ID, "1")
+ self._db, gutenberg, Identifier.GUTENBERG_ID, "1"
+ )
g2, ignore = Edition.for_foreign_id(
- self._db, gutenberg, Identifier.GUTENBERG_ID, "2")
+ self._db, gutenberg, Identifier.GUTENBERG_ID, "2"
+ )
# One of them has coverage from OCLC Classify
c1 = self._coverage_record(g1, oclc)
@@ -498,7 +526,8 @@ def test_missing_coverage_from(self):
# Here's a web record, just sitting there.
w, ignore = Edition.for_foreign_id(
- self._db, web, Identifier.URI, "http://www.foo.com/")
+ self._db, web, Identifier.URI, "http://www.foo.com/"
+ )
# If we run missing_coverage_from we pick up the Gutenberg
# record with no generic OCLC coverage. It doesn't pick up the
@@ -506,7 +535,8 @@ def test_missing_coverage_from(self):
# and it doesn't pick up the OCLC coverage for a specific
# operation.
[in_gutenberg_but_not_in_oclc] = Identifier.missing_coverage_from(
- self._db, [Identifier.GUTENBERG_ID], oclc).all()
+ self._db, [Identifier.GUTENBERG_ID], oclc
+ ).all()
assert g2.primary_identifier == in_gutenberg_but_not_in_oclc
@@ -515,20 +545,27 @@ def test_missing_coverage_from(self):
# that has generic OCLC coverage.
[has_generic_coverage_only] = Identifier.missing_coverage_from(
- self._db, [Identifier.GUTENBERG_ID], oclc, "some operation").all()
+ self._db, [Identifier.GUTENBERG_ID], oclc, "some operation"
+ ).all()
assert g1.primary_identifier == has_generic_coverage_only
# We don't put web sites into OCLC, so this will pick up the
# web record (but not the Gutenberg record).
[in_web_but_not_in_oclc] = Identifier.missing_coverage_from(
- self._db, [Identifier.URI], oclc).all()
+ self._db, [Identifier.URI], oclc
+ ).all()
assert w.primary_identifier == in_web_but_not_in_oclc
# We don't use the web as a source of coverage, so this will
# return both Gutenberg records (but not the web record).
assert [g1.primary_identifier.id, g2.primary_identifier.id] == sorted(
- [x.id for x in Identifier.missing_coverage_from(
- self._db, [Identifier.GUTENBERG_ID], web)])
+ [
+ x.id
+ for x in Identifier.missing_coverage_from(
+ self._db, [Identifier.GUTENBERG_ID], web
+ )
+ ]
+ )
def test_missing_coverage_from_with_collection(self):
gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG)
@@ -538,24 +575,24 @@ def test_missing_coverage_from_with_collection(self):
self._coverage_record(identifier, gutenberg, collection=collection1)
# The Identifier has coverage in collection 1.
- assert ([] ==
- Identifier.missing_coverage_from(
+ assert (
+ []
+ == Identifier.missing_coverage_from(
self._db, [identifier.type], gutenberg, collection=collection1
- ).all())
+ ).all()
+ )
# It is missing coverage in collection 2.
- assert (
- [identifier] == Identifier.missing_coverage_from(
- self._db, [identifier.type], gutenberg, collection=collection2
- ).all())
+ assert [identifier] == Identifier.missing_coverage_from(
+ self._db, [identifier.type], gutenberg, collection=collection2
+ ).all()
# If no collection is specified, we look for a CoverageRecord
# that also has no collection specified, and the Identifier is
# not treated as covered.
- assert ([identifier] ==
- Identifier.missing_coverage_from(
- self._db, [identifier.type], gutenberg
- ).all())
+ assert [identifier] == Identifier.missing_coverage_from(
+ self._db, [identifier.type], gutenberg
+ ).all()
def test_missing_coverage_from_with_cutoff_date(self):
gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG)
@@ -564,7 +601,8 @@ def test_missing_coverage_from_with_cutoff_date(self):
# Here's an Edition with a coverage record from OCLC classify.
gutenberg, ignore = Edition.for_foreign_id(
- self._db, gutenberg, Identifier.GUTENBERG_ID, "1")
+ self._db, gutenberg, Identifier.GUTENBERG_ID, "1"
+ )
identifier = gutenberg.primary_identifier
oclc = DataSource.lookup(self._db, DataSource.OCLC)
coverage = self._coverage_record(gutenberg, oclc)
@@ -575,43 +613,50 @@ def test_missing_coverage_from_with_cutoff_date(self):
# If we ask for Identifiers that are missing coverage records
# as of that time, we see nothing.
assert (
- [] ==
- Identifier.missing_coverage_from(
- self._db, [identifier.type], oclc,
- count_as_missing_before=timestamp
- ).all())
+ []
+ == Identifier.missing_coverage_from(
+ self._db, [identifier.type], oclc, count_as_missing_before=timestamp
+ ).all()
+ )
# But if we give a time one second later, the Identifier is
# missing coverage.
- assert (
- [identifier] ==
- Identifier.missing_coverage_from(
- self._db, [identifier.type], oclc,
- count_as_missing_before=timestamp+datetime.timedelta(seconds=1)
- ).all())
+ assert [identifier] == Identifier.missing_coverage_from(
+ self._db,
+ [identifier.type],
+ oclc,
+ count_as_missing_before=timestamp + datetime.timedelta(seconds=1),
+ ).all()
def test_opds_entry(self):
identifier = self._identifier()
source = DataSource.lookup(self._db, DataSource.CONTENT_CAFE)
summary = identifier.add_link(
- Hyperlink.DESCRIPTION, 'http://description', source,
- media_type=Representation.TEXT_PLAIN, content='a book'
+ Hyperlink.DESCRIPTION,
+ "http://description",
+ source,
+ media_type=Representation.TEXT_PLAIN,
+ content="a book",
)[0]
cover = identifier.add_link(
- Hyperlink.IMAGE, 'http://cover', source,
- media_type=Representation.JPEG_MEDIA_TYPE
+ Hyperlink.IMAGE,
+ "http://cover",
+ source,
+ media_type=Representation.JPEG_MEDIA_TYPE,
)[0]
def get_entry_dict(entry):
- return feedparser.parse(etree.tostring(entry, encoding="unicode")).entries[0]
+ return feedparser.parse(etree.tostring(entry, encoding="unicode")).entries[
+ 0
+ ]
# The entry includes the urn, description, and cover link.
entry = get_entry_dict(identifier.opds_entry())
assert identifier.urn == entry.id
- assert 'a book' == entry.summary
+ assert "a book" == entry.summary
[cover_link] = entry.links
- assert 'http://cover' == cover_link.href
+ assert "http://cover" == cover_link.href
# The 'updated' time is set to the latest timestamp associated
# with the Identifier.
@@ -648,8 +693,10 @@ def get_entry_dict(entry):
# or a representation.
even_later = now + datetime.timedelta(minutes=120)
thumbnail = identifier.add_link(
- Hyperlink.THUMBNAIL_IMAGE, 'http://thumb', source,
- media_type=Representation.JPEG_MEDIA_TYPE
+ Hyperlink.THUMBNAIL_IMAGE,
+ "http://thumb",
+ source,
+ media_type=Representation.JPEG_MEDIA_TYPE,
)[0]
thumb_rep = thumbnail.resource.representation
cover_rep = cover.resource.representation
@@ -660,27 +707,27 @@ def get_entry_dict(entry):
entry = get_entry_dict(identifier.opds_entry())
# The thumbnail has been added to the links.
assert 2 == len(entry.links)
- assert any(filter(lambda l: l.href=='http://thumb', entry.links))
+ assert any(filter(lambda l: l.href == "http://thumb", entry.links))
# And the updated time has been changed accordingly.
expected = thumbnail.resource.representation.mirrored_at
assert AtomFeed._strftime(even_later) == entry.updated
- @parameterized.expand([
- ('ascii_type_ascii_identifier_no_title', 'a', 'a', None),
- ('ascii_type_non_ascii_identifier_no_title', 'a', 'ą', None),
- ('non_ascii_type_ascii_identifier_no_title', 'ą', 'a', None),
- ('non_ascii_type_non_ascii_identifier_no_title', 'ą', 'ą', None),
-
- ('ascii_type_ascii_identifier_ascii_title', 'a', 'a', 'a'),
- ('ascii_type_non_ascii_identifier_ascii_title', 'a', 'ą', 'a'),
- ('non_ascii_type_ascii_identifier_ascii_title', 'ą', 'a', 'a'),
- ('non_ascii_type_non_ascii_identifier_ascii_title', 'ą', 'ą', 'a'),
-
- ('ascii_type_ascii_identifier_non_ascii_title', 'a', 'a', 'ą'),
- ('ascii_type_non_ascii_identifier_non_ascii_title', 'a', 'ą', 'ą'),
- ('non_ascii_type_ascii_identifier_non_ascii_title', 'ą', 'a', 'ą'),
- ('non_ascii_type_non_ascii_identifier_non_ascii_title', 'ą', 'ą', 'ą'),
- ])
+ @parameterized.expand(
+ [
+ ("ascii_type_ascii_identifier_no_title", "a", "a", None),
+ ("ascii_type_non_ascii_identifier_no_title", "a", "ą", None),
+ ("non_ascii_type_ascii_identifier_no_title", "ą", "a", None),
+ ("non_ascii_type_non_ascii_identifier_no_title", "ą", "ą", None),
+ ("ascii_type_ascii_identifier_ascii_title", "a", "a", "a"),
+ ("ascii_type_non_ascii_identifier_ascii_title", "a", "ą", "a"),
+ ("non_ascii_type_ascii_identifier_ascii_title", "ą", "a", "a"),
+ ("non_ascii_type_non_ascii_identifier_ascii_title", "ą", "ą", "a"),
+ ("ascii_type_ascii_identifier_non_ascii_title", "a", "a", "ą"),
+ ("ascii_type_non_ascii_identifier_non_ascii_title", "a", "ą", "ą"),
+ ("non_ascii_type_ascii_identifier_non_ascii_title", "ą", "a", "ą"),
+ ("non_ascii_type_non_ascii_identifier_non_ascii_title", "ą", "ą", "ą"),
+ ]
+ )
def test_repr(self, _, identifier_type, identifier, title):
"""Test that Identifier.__repr__ correctly works with both ASCII and non-ASCII symbols.
diff --git a/tests/models/test_integrationclient.py b/tests/models/test_integrationclient.py
index e8b419040..6f5f00e01 100644
--- a/tests/models/test_integrationclient.py
+++ b/tests/models/test_integrationclient.py
@@ -1,13 +1,14 @@
# encoding: utf-8
import datetime
+
import pytest
-from ...testing import DatabaseTest
from ...model.integrationclient import IntegrationClient
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
-class TestIntegrationClient(DatabaseTest):
+class TestIntegrationClient(DatabaseTest):
def setup_method(self):
super(TestIntegrationClient, self).setup_method()
self.client = self._integration_client()
@@ -51,7 +52,9 @@ def test_register(self):
# It raises an error if the url is already registered and the
# submitted shared_secret is inaccurate.
pytest.raises(ValueError, IntegrationClient.register, self._db, client.url)
- pytest.raises(ValueError, IntegrationClient.register, self._db, client.url, 'wrong')
+ pytest.raises(
+ ValueError, IntegrationClient.register, self._db, client.url, "wrong"
+ )
def test_authenticate(self):
@@ -63,20 +66,20 @@ def test_authenticate(self):
def test_normalize_url(self):
# http/https protocol is removed.
- url = 'https://fake.com'
- assert 'fake.com' == IntegrationClient.normalize_url(url)
+ url = "https://fake.com"
+ assert "fake.com" == IntegrationClient.normalize_url(url)
- url = 'http://really-fake.com'
- assert 'really-fake.com' == IntegrationClient.normalize_url(url)
+ url = "http://really-fake.com"
+ assert "really-fake.com" == IntegrationClient.normalize_url(url)
# www is removed if it exists, along with any trailing /
- url = 'https://www.also-fake.net/'
- assert 'also-fake.net' == IntegrationClient.normalize_url(url)
+ url = "https://www.also-fake.net/"
+ assert "also-fake.net" == IntegrationClient.normalize_url(url)
# Subdomains and paths are retained.
- url = 'https://www.super.fake.org/wow/'
- assert 'super.fake.org/wow' == IntegrationClient.normalize_url(url)
+ url = "https://www.super.fake.org/wow/"
+ assert "super.fake.org/wow" == IntegrationClient.normalize_url(url)
# URL is lowercased.
- url = 'http://OMG.soVeryFake.gov'
- assert 'omg.soveryfake.gov' == IntegrationClient.normalize_url(url)
+ url = "http://OMG.soVeryFake.gov"
+ assert "omg.soveryfake.gov" == IntegrationClient.normalize_url(url)
diff --git a/tests/models/test_library.py b/tests/models/test_library.py
index d8df32c54..d1a39aec2 100644
--- a/tests/models/test_library.py
+++ b/tests/models/test_library.py
@@ -1,12 +1,13 @@
# encoding: utf-8
import pytest
-from ...testing import DatabaseTest
+
from ...model.configuration import ConfigurationSetting
from ...model.hasfulltablecache import HasFullTableCache
from ...model.library import Library
+from ...testing import DatabaseTest
-class TestLibrary(DatabaseTest):
+class TestLibrary(DatabaseTest):
def test_library_registry_short_name(self):
library = self._default_library
@@ -17,6 +18,7 @@ def test_library_registry_short_name(self):
# Short name cannot contain a pipe character.
def set_to_pipe():
library.library_registry_short_name = "foo|bar"
+
pytest.raises(ValueError, set_to_pipe)
# You can set the short name to None. This isn't
@@ -68,8 +70,10 @@ def test_default(self):
assert False == l2.is_default
with pytest.raises(ValueError) as excinfo:
l1.is_default = False
- assert "You cannot stop a library from being the default library; you must designate a different library as the default." \
- in str(excinfo.value)
+ assert (
+ "You cannot stop a library from being the default library; you must designate a different library as the default."
+ in str(excinfo.value)
+ )
def test_has_root_lanes(self):
# A library has root lanes if any of its lanes are the root for any
@@ -86,7 +90,7 @@ def test_has_root_lanes(self):
# (This is because there's a listener that resets
# Library._has_default_lane_cache whenever lane configuration
# changes.)
- lane.root_for_patron_type = ["1","2"]
+ lane.root_for_patron_type = ["1", "2"]
self._db.flush()
assert True == library.has_root_lanes
@@ -101,8 +105,7 @@ def test_all_collections(self):
self._default_collection.parent_id = parent.id
assert [self._default_collection] == library.collections
- assert (set([self._default_collection, parent]) ==
- set(library.all_collections))
+ assert set([self._default_collection, parent]) == set(library.all_collections)
def test_estimated_holdings_by_language(self):
library = self._default_library
@@ -125,15 +128,13 @@ def test_estimated_holdings_by_language(self):
assert dict(eng=1, tgl=1) == estimate
# If we disqualify open-access works, it only counts the Tagalog.
- estimate = library.estimated_holdings_by_language(
- include_open_access=False)
+ estimate = library.estimated_holdings_by_language(include_open_access=False)
assert dict(tgl=1) == estimate
# If we remove the default collection from the default library,
# it loses all its works.
self._default_library.collections = []
- estimate = library.estimated_holdings_by_language(
- include_open_access=False)
+ estimate = library.estimated_holdings_by_language(include_open_access=False)
assert dict() == estimate
def test_explain(self):
@@ -147,9 +148,7 @@ def test_explain(self):
library.library_registry_short_name = "SHORT"
library.library_registry_shared_secret = "secret"
- integration = self._external_integration(
- "protocol", "goal"
- )
+ integration = self._external_integration("protocol", "goal")
integration.url = "http://url/"
integration.username = "someuser"
integration.password = "somepass"
@@ -167,7 +166,8 @@ def test_explain(self):
library.integrations.append(integration)
- expect = """Library UUID: "uuid"
+ expect = (
+ """Library UUID: "uuid"
Name: "The Library"
Short name: "Short"
Short name (for library registry): "SHORT"
@@ -180,7 +180,9 @@ def test_explain(self):
somesetting='somevalue'
url='http://url/'
username='someuser'
-""" % integration.id
+"""
+ % integration.id
+ )
actual = library.explain()
assert expect == "\n".join(actual)
diff --git a/tests/models/test_licensing.py b/tests/models/test_licensing.py
index c035c67f6..06ebef581 100644
--- a/tests/models/test_licensing.py
+++ b/tests/models/test_licensing.py
@@ -1,8 +1,9 @@
# encoding: utf-8
+import datetime
+
+import pytest
from mock import MagicMock, PropertyMock
from parameterized import parameterized
-import pytest
-import datetime
from sqlalchemy.exc import IntegrityError
from ...mock_analytics_provider import MockAnalyticsProvider
@@ -34,14 +35,20 @@ class TestDeliveryMechanism(DatabaseTest):
def setup_method(self):
super(TestDeliveryMechanism, self).setup_method()
self.epub_no_drm, ignore = DeliveryMechanism.lookup(
- self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM)
+ self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM
+ )
self.epub_adobe_drm, ignore = DeliveryMechanism.lookup(
- self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM)
+ self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
+ )
self.overdrive_streaming_text, ignore = DeliveryMechanism.lookup(
- self._db, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE, DeliveryMechanism.OVERDRIVE_DRM)
+ self._db,
+ DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
+ DeliveryMechanism.OVERDRIVE_DRM,
+ )
self.audiobook_drm_scheme, ignore = DeliveryMechanism.lookup(
- self._db, Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE,
- DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM
+ self._db,
+ Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE,
+ DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM,
)
def test_implicit_medium(self):
@@ -52,8 +59,12 @@ def test_implicit_medium(self):
def test_is_media_type(self):
assert False == DeliveryMechanism.is_media_type(None)
assert True == DeliveryMechanism.is_media_type(Representation.EPUB_MEDIA_TYPE)
- assert False == DeliveryMechanism.is_media_type(DeliveryMechanism.KINDLE_CONTENT_TYPE)
- assert False == DeliveryMechanism.is_media_type(DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE)
+ assert False == DeliveryMechanism.is_media_type(
+ DeliveryMechanism.KINDLE_CONTENT_TYPE
+ )
+ assert False == DeliveryMechanism.is_media_type(
+ DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE
+ )
def test_is_streaming(self):
assert False == self.epub_no_drm.is_streaming
@@ -66,12 +77,22 @@ def test_drm_scheme_media_type(self):
assert None == self.overdrive_streaming_text.drm_scheme_media_type
def test_content_type_media_type(self):
- assert Representation.EPUB_MEDIA_TYPE == self.epub_no_drm.content_type_media_type
- assert Representation.EPUB_MEDIA_TYPE == self.epub_adobe_drm.content_type_media_type
- assert (Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE ==
- self.overdrive_streaming_text.content_type_media_type)
- assert (Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE + DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_PROFILE ==
- self.audiobook_drm_scheme.content_type_media_type)
+ assert (
+ Representation.EPUB_MEDIA_TYPE == self.epub_no_drm.content_type_media_type
+ )
+ assert (
+ Representation.EPUB_MEDIA_TYPE
+ == self.epub_adobe_drm.content_type_media_type
+ )
+ assert (
+ Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE
+ == self.overdrive_streaming_text.content_type_media_type
+ )
+ assert (
+ Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE
+ + DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_PROFILE
+ == self.audiobook_drm_scheme.content_type_media_type
+ )
def test_default_fulfillable(self):
# Try some well-known media type/DRM combinations known to be
@@ -95,15 +116,13 @@ def test_default_fulfillable(self):
# It's possible to create new DeliveryMechanisms at runtime,
# but their .default_client_can_fulfill will be False.
mechanism, is_new = DeliveryMechanism.lookup(
- self._db, MediaTypes.EPUB_MEDIA_TYPE,
- DeliveryMechanism.ADOBE_DRM
+ self._db, MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
)
assert False == is_new
assert True == mechanism.default_client_can_fulfill
mechanism, is_new = DeliveryMechanism.lookup(
- self._db, MediaTypes.PDF_MEDIA_TYPE,
- DeliveryMechanism.STREAMING_DRM
+ self._db, MediaTypes.PDF_MEDIA_TYPE, DeliveryMechanism.STREAMING_DRM
)
assert True == is_new
assert False == mechanism.default_client_can_fulfill
@@ -121,28 +140,25 @@ def test_compatible_with(self):
mutually compatible and which are mutually exclusive.
"""
epub_adobe, ignore = DeliveryMechanism.lookup(
- self._db, MediaTypes.EPUB_MEDIA_TYPE,
- DeliveryMechanism.ADOBE_DRM
+ self._db, MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
)
pdf_adobe, ignore = DeliveryMechanism.lookup(
- self._db, MediaTypes.PDF_MEDIA_TYPE,
- DeliveryMechanism.ADOBE_DRM
+ self._db, MediaTypes.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
)
epub_no_drm, ignore = DeliveryMechanism.lookup(
- self._db, MediaTypes.EPUB_MEDIA_TYPE,
- DeliveryMechanism.NO_DRM
+ self._db, MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM
)
pdf_no_drm, ignore = DeliveryMechanism.lookup(
- self._db, MediaTypes.PDF_MEDIA_TYPE,
- DeliveryMechanism.NO_DRM
+ self._db, MediaTypes.PDF_MEDIA_TYPE, DeliveryMechanism.NO_DRM
)
streaming, ignore = DeliveryMechanism.lookup(
- self._db, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM
+ self._db,
+ DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
+ DeliveryMechanism.STREAMING_DRM,
)
# A non-streaming DeliveryMechanism is compatible only with
@@ -185,7 +201,6 @@ def test_uniqueness_constraint(self):
class TestRightsStatus(DatabaseTest):
-
def test_lookup(self):
status = RightsStatus.lookup(self._db, RightsStatus.IN_COPYRIGHT)
assert RightsStatus.IN_COPYRIGHT == status.uri
@@ -212,7 +227,6 @@ def test_unique_uri_constraint(self):
class TestLicense(DatabaseTest):
-
def setup_method(self):
super(TestLicense, self).setup_method()
self.pool = self._licensepool(None)
@@ -222,28 +236,37 @@ def setup_method(self):
yesterday = now - datetime.timedelta(days=1)
self.perpetual = self._license(
- self.pool, expires=None, remaining_checkouts=None,
- concurrent_checkouts=1)
+ self.pool, expires=None, remaining_checkouts=None, concurrent_checkouts=1
+ )
self.time_limited = self._license(
- self.pool, expires=next_year, remaining_checkouts=None,
- concurrent_checkouts=1)
+ self.pool,
+ expires=next_year,
+ remaining_checkouts=None,
+ concurrent_checkouts=1,
+ )
self.loan_limited = self._license(
- self.pool, expires=None, remaining_checkouts=4,
- concurrent_checkouts=2)
+ self.pool, expires=None, remaining_checkouts=4, concurrent_checkouts=2
+ )
self.time_and_loan_limited = self._license(
- self.pool, expires=next_year + datetime.timedelta(days=1),
- remaining_checkouts=52, concurrent_checkouts=1)
+ self.pool,
+ expires=next_year + datetime.timedelta(days=1),
+ remaining_checkouts=52,
+ concurrent_checkouts=1,
+ )
self.expired_time_limited = self._license(
- self.pool, expires=yesterday, remaining_checkouts=None,
- concurrent_checkouts=1)
+ self.pool,
+ expires=yesterday,
+ remaining_checkouts=None,
+ concurrent_checkouts=1,
+ )
self.expired_loan_limited = self._license(
- self.pool, expires=None, remaining_checkouts=0,
- concurrent_checkouts=1)
+ self.pool, expires=None, remaining_checkouts=0, concurrent_checkouts=1
+ )
def test_loan_to(self):
# Verify that loaning a license also loans its pool.
@@ -297,11 +320,14 @@ def test_license_types(self):
def test_best_available_license(self):
next_week = utc_now() + datetime.timedelta(days=7)
time_limited_2 = self._license(
- self.pool, expires=next_week, remaining_checkouts=None,
- concurrent_checkouts=1)
+ self.pool,
+ expires=next_week,
+ remaining_checkouts=None,
+ concurrent_checkouts=1,
+ )
loan_limited_2 = self._license(
- self.pool, expires=None, remaining_checkouts=2,
- concurrent_checkouts=1)
+ self.pool, expires=None, remaining_checkouts=2, concurrent_checkouts=1
+ )
# First, we use the time-limited license that's expiring first.
assert time_limited_2 == self.pool.best_available_license()
@@ -337,14 +363,16 @@ def test_best_available_license(self):
class TestLicensePool(DatabaseTest):
-
def test_for_foreign_id(self):
"""Verify we can get a LicensePool for a data source, an
appropriate work identifier, and a Collection."""
now = utc_now()
pool, was_new = LicensePool.for_foreign_id(
- self._db, DataSource.GUTENBERG, Identifier.GUTENBERG_ID, "541",
- collection=self._collection()
+ self._db,
+ DataSource.GUTENBERG,
+ Identifier.GUTENBERG_ID,
+ "541",
+ collection=self._collection(),
)
assert (pool.availability_time - now).total_seconds() < 2
assert True == was_new
@@ -363,8 +391,11 @@ def test_for_foreign_id_fails_when_no_collection_provided(self):
pytest.raises(
CollectionMissing,
LicensePool.for_foreign_id,
- self._db, DataSource.GUTENBERG, Identifier.GUTENBERG_ID, "541",
- collection=None
+ self._db,
+ DataSource.GUTENBERG,
+ Identifier.GUTENBERG_ID,
+ "541",
+ collection=None,
)
def test_with_no_delivery_mechanisms(self):
@@ -389,10 +420,16 @@ def test_no_license_pool_for_non_primary_identifier(self):
collection = self._collection()
with pytest.raises(ValueError) as excinfo:
LicensePool.for_foreign_id(
- self._db, DataSource.OVERDRIVE, Identifier.ISBN, "{1-2-3}",
- collection = collection)
- assert "License pools for data source 'Overdrive' are keyed to identifier type 'Overdrive ID' (not 'ISBN', which was provided)" \
+ self._db,
+ DataSource.OVERDRIVE,
+ Identifier.ISBN,
+ "{1-2-3}",
+ collection=collection,
+ )
+ assert (
+ "License pools for data source 'Overdrive' are keyed to identifier type 'Overdrive ID' (not 'ISBN', which was provided)"
in str(excinfo.value)
+ )
def test_licensepools_for_same_identifier_have_same_presentation_edition(self):
"""Two LicensePools for the same Identifier will get the same
@@ -400,12 +437,16 @@ def test_licensepools_for_same_identifier_have_same_presentation_edition(self):
"""
identifier = self._identifier()
edition1, pool1 = self._edition(
- with_license_pool=True, data_source_name=DataSource.GUTENBERG,
- identifier_type=identifier.type, identifier_id=identifier.identifier
+ with_license_pool=True,
+ data_source_name=DataSource.GUTENBERG,
+ identifier_type=identifier.type,
+ identifier_id=identifier.identifier,
)
edition2, pool2 = self._edition(
- with_license_pool=True, data_source_name=DataSource.UNGLUE_IT,
- identifier_type=identifier.type, identifier_id=identifier.identifier
+ with_license_pool=True,
+ data_source_name=DataSource.UNGLUE_IT,
+ identifier_type=identifier.type,
+ identifier_id=identifier.identifier,
)
pool1.set_presentation_edition()
pool2.set_presentation_edition()
@@ -423,7 +464,7 @@ def test_collection_datasource_identifier_must_be_unique(self):
LicensePool,
data_source=data_source,
identifier=identifier,
- collection=collection
+ collection=collection,
)
pytest.raises(
@@ -433,18 +474,24 @@ def test_collection_datasource_identifier_must_be_unique(self):
LicensePool,
data_source=data_source,
identifier=identifier,
- collection=collection
+ collection=collection,
)
def test_with_no_work(self):
p1, ignore = LicensePool.for_foreign_id(
- self._db, DataSource.GUTENBERG, Identifier.GUTENBERG_ID, "1",
- collection=self._default_collection
+ self._db,
+ DataSource.GUTENBERG,
+ Identifier.GUTENBERG_ID,
+ "1",
+ collection=self._default_collection,
)
p2, ignore = LicensePool.for_foreign_id(
- self._db, DataSource.OVERDRIVE, Identifier.OVERDRIVE_ID, "2",
- collection=self._default_collection
+ self._db,
+ DataSource.OVERDRIVE,
+ Identifier.OVERDRIVE_ID,
+ "2",
+ collection=self._default_collection,
)
work = self._work(title="Foo")
@@ -520,7 +567,6 @@ def test_update_availability_does_nothing_if_given_no_data(self):
assert 30 == pool.licenses_reserved
assert 40 == pool.patrons_in_hold_queue
-
def test_open_access_links(self):
edition, pool = self._edition(with_open_access_download=True)
source = DataSource.lookup(self._db, DataSource.GUTENBERG)
@@ -532,8 +578,7 @@ def test_open_access_links(self):
url = self._url
media_type = MediaTypes.EPUB_MEDIA_TYPE
link2, new = pool.identifier.add_link(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, url,
- source, media_type
+ Hyperlink.OPEN_ACCESS_DOWNLOAD, url, source, media_type
)
oa2 = link2.resource
@@ -550,20 +595,28 @@ def test_open_access_links(self):
def test_better_open_access_pool_than(self):
gutenberg_1 = self._licensepool(
- None, open_access=True, data_source_name=DataSource.GUTENBERG,
+ None,
+ open_access=True,
+ data_source_name=DataSource.GUTENBERG,
with_open_access_download=True,
)
gutenberg_2 = self._licensepool(
- None, open_access=True, data_source_name=DataSource.GUTENBERG,
+ None,
+ open_access=True,
+ data_source_name=DataSource.GUTENBERG,
with_open_access_download=True,
)
- assert int(gutenberg_1.identifier.identifier) < int(gutenberg_2.identifier.identifier)
+ assert int(gutenberg_1.identifier.identifier) < int(
+ gutenberg_2.identifier.identifier
+ )
standard_ebooks = self._licensepool(
- None, open_access=True, data_source_name=DataSource.STANDARD_EBOOKS,
- with_open_access_download=True
+ None,
+ open_access=True,
+ data_source_name=DataSource.STANDARD_EBOOKS,
+ with_open_access_download=True,
)
# Make sure Feedbooks data source exists -- it's not created
@@ -572,8 +625,10 @@ def test_better_open_access_pool_than(self):
self._db, DataSource.FEEDBOOKS, autocreate=True
)
feedbooks = self._licensepool(
- None, open_access=True, data_source_name=DataSource.FEEDBOOKS,
- with_open_access_download=True
+ None,
+ open_access=True,
+ data_source_name=DataSource.FEEDBOOKS,
+ with_open_access_download=True,
)
overdrive = self._licensepool(
@@ -585,7 +640,7 @@ def test_better_open_access_pool_than(self):
)
suppressed.suppressed = True
- def better(x,y):
+ def better(x, y):
return x.better_open_access_pool_than(y)
# We would rather have nothing at all than a suppressed
@@ -612,7 +667,8 @@ def better(x,y):
# open-access download resource, it will only be considered if
# there is no other alternative.
no_resource = self._licensepool(
- None, open_access=True,
+ None,
+ open_access=True,
data_source_name=DataSource.STANDARD_EBOOKS,
with_open_access_download=False,
)
@@ -631,66 +687,65 @@ def test_with_complaint(self):
"fiction work with complaint",
language="eng",
fiction=True,
- with_open_access_download=True)
+ with_open_access_download=True,
+ )
lp1 = work1.license_pools[0]
lp1_complaint1 = self._complaint(
- lp1,
- type1,
- "lp1 complaint1 source",
- "lp1 complaint1 detail")
+ lp1, type1, "lp1 complaint1 source", "lp1 complaint1 detail"
+ )
lp1_complaint2 = self._complaint(
- lp1,
- type1,
- "lp1 complaint2 source",
- "lp1 complaint2 detail")
+ lp1, type1, "lp1 complaint2 source", "lp1 complaint2 detail"
+ )
lp1_complaint3 = self._complaint(
- lp1,
- type2,
- "work1 complaint3 source",
- "work1 complaint3 detail")
+ lp1, type2, "work1 complaint3 source", "work1 complaint3 detail"
+ )
lp1_resolved_complaint = self._complaint(
lp1,
type3,
"work3 resolved complaint source",
"work3 resolved complaint detail",
- utc_now())
+ utc_now(),
+ )
work2 = self._work(
"nonfiction work with complaint",
language="eng",
fiction=False,
- with_open_access_download=True)
+ with_open_access_download=True,
+ )
lp2 = work2.license_pools[0]
lp2_complaint1 = self._complaint(
- lp2,
- type2,
- "work2 complaint1 source",
- "work2 complaint1 detail")
+ lp2, type2, "work2 complaint1 source", "work2 complaint1 detail"
+ )
lp2_resolved_complaint = self._complaint(
lp2,
type2,
"work2 resolved complaint source",
"work2 resolved complaint detail",
- utc_now())
+ utc_now(),
+ )
work3 = self._work(
"fiction work without complaint",
language="eng",
fiction=True,
- with_open_access_download=True)
+ with_open_access_download=True,
+ )
lp3 = work3.license_pools[0]
lp3_resolved_complaint = self._complaint(
lp3,
type3,
"work3 resolved complaint source",
"work3 resolved complaint detail",
- utc_now())
+ utc_now(),
+ )
work4 = self._work(
"nonfiction work without complaint",
language="eng",
fiction=False,
- with_open_access_download=True)
+ with_open_access_download=True,
+ )
# excludes resolved complaints by default
results = LicensePool.with_complaint(library).all()
@@ -713,8 +768,7 @@ def test_with_complaint(self):
assert 1 == more_results[2][1]
# show only resolved complaints
- resolved_results = LicensePool.with_complaint(
- library, resolved=True).all()
+ resolved_results = LicensePool.with_complaint(library, resolved=True).all()
lp_ids = set([result[0].id for result in resolved_results])
counts = set([result[1] for result in resolved_results])
@@ -740,13 +794,19 @@ def test_set_presentation_edition(self):
# Here's an Overdrive audiobook which also has data from the metadata
# wrangler and from library staff.
- od, pool = self._edition(data_source_name=DataSource.OVERDRIVE, with_license_pool=True)
+ od, pool = self._edition(
+ data_source_name=DataSource.OVERDRIVE, with_license_pool=True
+ )
od.medium = Edition.AUDIO_MEDIUM
- admin = self._edition(data_source_name=DataSource.LIBRARY_STAFF, with_license_pool=False)
+ admin = self._edition(
+ data_source_name=DataSource.LIBRARY_STAFF, with_license_pool=False
+ )
admin.primary_identifier = pool.identifier
- mw = self._edition(data_source_name=DataSource.METADATA_WRANGLER, with_license_pool=False)
+ mw = self._edition(
+ data_source_name=DataSource.METADATA_WRANGLER, with_license_pool=False
+ )
mw.primary_identifier = pool.identifier
# The library staff has no opinion on the book's medium,
@@ -815,24 +875,42 @@ def test_circulation_changelog(self):
# Since all four circulation values changed, the message is as
# long as it could possibly get.
assert (
- 'CHANGED %s "%s" %s (%s/%s) %s: %s=>%s %s: %s=>%s %s: %s=>%s %s: %s=>%s' ==
- msg)
- assert (
- args ==
- (edition.medium, edition.title, edition.author,
- pool.identifier.type, pool.identifier.identifier,
- 'OWN', 1, 10, 'AVAIL', 2, 9, 'RSRV', 3, 8, 'HOLD', 4, 7))
+ 'CHANGED %s "%s" %s (%s/%s) %s: %s=>%s %s: %s=>%s %s: %s=>%s %s: %s=>%s'
+ == msg
+ )
+ assert args == (
+ edition.medium,
+ edition.title,
+ edition.author,
+ pool.identifier.type,
+ pool.identifier.identifier,
+ "OWN",
+ 1,
+ 10,
+ "AVAIL",
+ 2,
+ 9,
+ "RSRV",
+ 3,
+ 8,
+ "HOLD",
+ 4,
+ 7,
+ )
# If only one circulation value changes, the message is a lot shorter.
msg, args = pool.circulation_changelog(10, 9, 8, 15)
- assert (
- 'CHANGED %s "%s" %s (%s/%s) %s: %s=>%s' ==
- msg)
- assert (
- args ==
- (edition.medium, edition.title, edition.author,
- pool.identifier.type, pool.identifier.identifier,
- 'HOLD', 15, 7))
+ assert 'CHANGED %s "%s" %s (%s/%s) %s: %s=>%s' == msg
+ assert args == (
+ edition.medium,
+ edition.title,
+ edition.author,
+ pool.identifier.type,
+ pool.identifier.identifier,
+ "HOLD",
+ 15,
+ 7,
+ )
# This works even if, for whatever reason, the edition's
# bibliographic data is missing.
@@ -926,39 +1004,39 @@ def test_calculate_change_from_one_event(self):
# event that makes no difference. This lets us see what a
# 'status quo' response from the method would look like.
calc = pool._calculate_change_from_one_event
- assert (5,4,0,0) == calc(CE.DISTRIBUTOR_CHECKIN, 0)
+ assert (5, 4, 0, 0) == calc(CE.DISTRIBUTOR_CHECKIN, 0)
# If there ever appear to be more licenses available than
# owned, the number of owned licenses is left alone. It's
# possible that we have more licenses than we thought, but
# it's more likely that a license has expired or otherwise
# been removed.
- assert (5,5,0,0) == calc(CE.DISTRIBUTOR_CHECKIN, 3)
+ assert (5, 5, 0, 0) == calc(CE.DISTRIBUTOR_CHECKIN, 3)
# But we don't bump up the number of available licenses just
# because one becomes available.
- assert (5,5,0,0) == calc(CE.DISTRIBUTOR_CHECKIN, 1)
+ assert (5, 5, 0, 0) == calc(CE.DISTRIBUTOR_CHECKIN, 1)
# When you signal a hold on a book that's available, we assume
# that the book has stopped being available.
- assert (5,0,0,3) == calc(CE.DISTRIBUTOR_HOLD_PLACE, 3)
+ assert (5, 0, 0, 3) == calc(CE.DISTRIBUTOR_HOLD_PLACE, 3)
# If a license stops being owned, it implicitly stops being
# available. (But we don't know if the license that became
# unavailable is one of the ones currently checked out to
# someone, or one of the other ones.)
- assert (3,3,0,0) == calc(CE.DISTRIBUTOR_LICENSE_REMOVE, 2)
+ assert (3, 3, 0, 0) == calc(CE.DISTRIBUTOR_LICENSE_REMOVE, 2)
# If a license stops being available, it doesn't stop
# being owned.
- assert (5,3,0,0) == calc(CE.DISTRIBUTOR_CHECKOUT, 1)
+ assert (5, 3, 0, 0) == calc(CE.DISTRIBUTOR_CHECKOUT, 1)
# None of these numbers will go below zero.
- assert (0,0,0,0) == calc(CE.DISTRIBUTOR_LICENSE_REMOVE, 100)
+ assert (0, 0, 0, 0) == calc(CE.DISTRIBUTOR_LICENSE_REMOVE, 100)
# Newly added licenses start out available if there are no
# patrons in the hold queue.
- assert (6,5,0,0) == calc(CE.DISTRIBUTOR_LICENSE_ADD, 1)
+ assert (6, 5, 0, 0) == calc(CE.DISTRIBUTOR_LICENSE_ADD, 1)
# Now let's run some tests with a LicensePool that has a large holds
# queue.
@@ -966,26 +1044,26 @@ def test_calculate_change_from_one_event(self):
pool.licenses_available = 0
pool.licenses_reserved = 1
pool.patrons_in_hold_queue = 3
- assert (5,0,1,3) == calc(CE.DISTRIBUTOR_HOLD_PLACE, 0)
+ assert (5, 0, 1, 3) == calc(CE.DISTRIBUTOR_HOLD_PLACE, 0)
# When you signal a hold on a book that already has holds, it
# does nothing but increase the number of patrons in the hold
# queue.
- assert (5,0,1,6) == calc(CE.DISTRIBUTOR_HOLD_PLACE, 3)
+ assert (5, 0, 1, 6) == calc(CE.DISTRIBUTOR_HOLD_PLACE, 3)
# A checkin event has no effect...
- assert (5,0,1,3) == calc(CE.DISTRIBUTOR_CHECKIN, 1)
+ assert (5, 0, 1, 3) == calc(CE.DISTRIBUTOR_CHECKIN, 1)
# ...because it's presumed that it will be followed by an
# availability notification event, which takes a patron off
# the hold queue and adds them to the reserved list.
- assert (5,0,2,2) == calc(CE.DISTRIBUTOR_AVAILABILITY_NOTIFY, 1)
+ assert (5, 0, 2, 2) == calc(CE.DISTRIBUTOR_AVAILABILITY_NOTIFY, 1)
# The only exception is if the checkin event wipes out the
# entire holds queue, in which case the number of available
# licenses increases. (But nothing else changes -- we're
# still waiting for the availability notification events.)
- assert (5,3,1,3) == calc(CE.DISTRIBUTOR_CHECKIN, 6)
+ assert (5, 3, 1, 3) == calc(CE.DISTRIBUTOR_CHECKIN, 6)
# Again, note that even though six copies were checked in,
# we're not assuming we own more licenses than we
@@ -994,11 +1072,11 @@ def test_calculate_change_from_one_event(self):
# When there are no licenses available, a checkout event
# draws from the pool of licenses reserved instead.
- assert (5,0,0,3) == calc(CE.DISTRIBUTOR_CHECKOUT, 2)
+ assert (5, 0, 0, 3) == calc(CE.DISTRIBUTOR_CHECKOUT, 2)
# Newly added licenses do not start out available if there are
# patrons in the hold queue.
- assert (6,0,1,3) == calc(CE.DISTRIBUTOR_LICENSE_ADD, 1)
+ assert (6, 0, 1, 3) == calc(CE.DISTRIBUTOR_LICENSE_ADD, 1)
def test_loan_to_patron(self):
# Test our ability to loan LicensePools to Patrons.
@@ -1018,8 +1096,11 @@ def test_loan_to_patron(self):
fulfillment = pool.delivery_mechanisms[0]
external_identifier = self._str
loan, is_new = pool.loan_to(
- patron, start=yesterday, end=tomorrow,
- fulfillment=fulfillment, external_identifier=external_identifier
+ patron,
+ start=yesterday,
+ end=tomorrow,
+ fulfillment=fulfillment,
+ external_identifier=external_identifier,
)
assert True == is_new
@@ -1042,14 +1123,16 @@ def test_loan_to_patron(self):
# uncertainty.
patron.last_loan_activity_sync = now
loan2, is_new = pool.loan_to(
- patron, start=yesterday, end=tomorrow,
- fulfillment=fulfillment, external_identifier=external_identifier
+ patron,
+ start=yesterday,
+ end=tomorrow,
+ fulfillment=fulfillment,
+ external_identifier=external_identifier,
)
assert False == is_new
assert loan == loan2
assert now == patron.last_loan_activity_sync
-
def test_on_hold_to_patron(self):
# Test our ability to put a Patron in the holds queue for a LicensePool.
#
@@ -1068,8 +1151,11 @@ def test_on_hold_to_patron(self):
position = 99
external_identifier = self._str
hold, is_new = pool.on_hold_to(
- patron, start=yesterday, end=tomorrow,
- position=position, external_identifier=external_identifier
+ patron,
+ start=yesterday,
+ end=tomorrow,
+ position=position,
+ external_identifier=external_identifier,
)
assert True == is_new
@@ -1092,13 +1178,16 @@ def test_on_hold_to_patron(self):
# uncertainty.
patron.last_loan_activity_sync = now
hold2, is_new = pool.on_hold_to(
- patron, start=yesterday, end=tomorrow,
- position=position, external_identifier=external_identifier
+ patron,
+ start=yesterday,
+ end=tomorrow,
+ position=position,
+ external_identifier=external_identifier,
)
assert False == is_new
assert hold == hold2
assert now == patron.last_loan_activity_sync
-
+
class TestLicensePoolDeliveryMechanism(DatabaseTest):
def test_lpdm_change_may_change_open_access_status(self):
@@ -1113,8 +1202,7 @@ def test_lpdm_change_may_change_open_access_status(self):
content_type = MediaTypes.EPUB_MEDIA_TYPE
drm_scheme = DeliveryMechanism.NO_DRM
LicensePoolDeliveryMechanism.set(
- data_source, identifier, content_type, drm_scheme,
- RightsStatus.IN_COPYRIGHT
+ data_source, identifier, content_type, drm_scheme, RightsStatus.IN_COPYRIGHT
)
# Now there's a way to get the book, but it's not open access.
@@ -1122,12 +1210,15 @@ def test_lpdm_change_may_change_open_access_status(self):
# Now give it an open-access LPDM.
link, new = pool.identifier.add_link(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url,
- data_source, content_type
+ Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, data_source, content_type
)
oa_lpdm = LicensePoolDeliveryMechanism.set(
- data_source, identifier, content_type, drm_scheme,
- RightsStatus.GENERIC_OPEN_ACCESS, link.resource
+ data_source,
+ identifier,
+ content_type,
+ drm_scheme,
+ RightsStatus.GENERIC_OPEN_ACCESS,
+ link.resource,
)
# Now it's open access.
@@ -1179,8 +1270,11 @@ def test_set_rights_status(self):
# Now add a second delivery mechanism, so the pool has one
# open-access and one commercial delivery mechanism.
lpdm2 = pool.set_delivery_mechanism(
- MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM,
- RightsStatus.CC_BY, None)
+ MediaTypes.EPUB_MEDIA_TYPE,
+ DeliveryMechanism.NO_DRM,
+ RightsStatus.CC_BY,
+ None,
+ )
assert 2 == len(pool.delivery_mechanisms)
# Now the pool is open access again
@@ -1194,15 +1288,15 @@ def test_set_rights_status(self):
def test_uniqueness_constraint(self):
# with_open_access_download will create a LPDM
# for the open-access download.
- edition, pool = self._edition(with_license_pool=True,
- with_open_access_download=True)
+ edition, pool = self._edition(
+ with_license_pool=True, with_open_access_download=True
+ )
[lpdm] = pool.delivery_mechanisms
# We can create a second LPDM with the same data type and DRM status,
# so long as the resource is different.
link, new = pool.identifier.add_link(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url,
- pool.data_source, "text/html"
+ Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, pool.data_source, "text/html"
)
lpdm2 = pool.set_delivery_mechanism(
lpdm.delivery_mechanism.content_type,
@@ -1219,7 +1313,7 @@ def test_uniqueness_constraint(self):
lpdm.delivery_mechanism.content_type,
lpdm.delivery_mechanism.drm_scheme,
lpdm.rights_status.uri,
- None
+ None,
)
assert lpdm3.delivery_mechanism == lpdm.delivery_mechanism
assert None == lpdm3.resource
@@ -1227,12 +1321,14 @@ def test_uniqueness_constraint(self):
# But we can't create a second such LPDM -- it violates a
# constraint of a unique index.
pytest.raises(
- IntegrityError, create, self._db,
+ IntegrityError,
+ create,
+ self._db,
LicensePoolDeliveryMechanism,
delivery_mechanism=lpdm3.delivery_mechanism,
identifier=pool.identifier,
data_source=pool.data_source,
- resource=None
+ resource=None,
)
self._db.rollback()
@@ -1241,8 +1337,9 @@ def test_compatible_with(self):
mutually compatible and which are mutually exclusive.
"""
- edition, pool = self._edition(with_license_pool=True,
- with_open_access_download=True)
+ edition, pool = self._edition(
+ with_license_pool=True, with_open_access_download=True
+ )
[mech] = pool.delivery_mechanisms
# Test the simple cases.
@@ -1274,16 +1371,16 @@ def test_compatible_with(self):
# The underlying delivery mechanisms don't have to be exactly
# the same, but they must be compatible.
pdf_adobe, ignore = DeliveryMechanism.lookup(
- self._db, MediaTypes.PDF_MEDIA_TYPE,
- DeliveryMechanism.ADOBE_DRM
+ self._db, MediaTypes.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
)
mech1.delivery_mechanism = pdf_adobe
self._db.commit()
assert False == mech1.compatible_with(mech2)
streaming, ignore = DeliveryMechanism.lookup(
- self._db, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM
+ self._db,
+ DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
+ DeliveryMechanism.STREAMING_DRM,
)
mech1.delivery_mechanism = streaming
self._db.commit()
@@ -1292,8 +1389,9 @@ def test_compatible_with(self):
def test_compatible_with_calls_compatible_with_on_deliverymechanism(self):
# Create two LicensePoolDeliveryMechanisms with different
# media types.
- edition, pool = self._edition(with_license_pool=True,
- with_open_access_download=True)
+ edition, pool = self._edition(
+ with_license_pool=True, with_open_access_download=True
+ )
[mech1] = pool.delivery_mechanisms
mech2 = self._add_generic_delivery_mechanism(pool)
mech2.delivery_mechanism, ignore = DeliveryMechanism.lookup(
@@ -1311,10 +1409,12 @@ def test_compatible_with_calls_compatible_with_on_deliverymechanism(self):
# open-access?'
class Mock(object):
called_with = None
+
@classmethod
def compatible_with(cls, other, open_access):
cls.called_with = (other, open_access)
return True
+
mech1.delivery_mechanism.compatible_with = Mock.compatible_with
# Call compatible_with, and the mock method is called with the
@@ -1322,9 +1422,7 @@ def compatible_with(cls, other, open_access):
# LicensePoolDeliveryMechanisms is not open-access) the value
# False.
mech1.compatible_with(mech2)
- assert (
- (mech2.delivery_mechanism, False) ==
- Mock.called_with)
+ assert (mech2.delivery_mechanism, False) == Mock.called_with
# If both LicensePoolDeliveryMechanisms are open-access,
# True is passed in instead, so that
@@ -1332,14 +1430,9 @@ def compatible_with(cls, other, open_access):
# compatibility rules for open-access fulfillment.
mech2.set_rights_status(RightsStatus.GENERIC_OPEN_ACCESS)
mech1.compatible_with(mech2)
- assert (
- (mech2.delivery_mechanism, True) ==
- Mock.called_with)
+ assert (mech2.delivery_mechanism, True) == Mock.called_with
- @parameterized.expand([
- ('ascii_sy', 'a', 'a', 'a'),
- ('', 'ą', 'ą', 'ą')
- ])
+ @parameterized.expand([("ascii_sy", "a", "a", "a"), ("", "ą", "ą", "ą")])
def test_repr(self, _, data_source, identifier, delivery_mechanism):
"""Test that LicensePoolDeliveryMechanism.__repr__ correctly works for both ASCII and non-ASCII symbols.
@@ -1366,9 +1459,15 @@ def test_repr(self, _, data_source, identifier, delivery_mechanism):
delivery_mechanism_mock.__repr__ = MagicMock(return_value=delivery_mechanism)
license_delivery_mechanism_mock = LicensePoolDeliveryMechanism()
- license_delivery_mechanism_mock.data_source = PropertyMock(return_value=data_source_mock)
- license_delivery_mechanism_mock.identifier = PropertyMock(return_value=identifier_mock)
- license_delivery_mechanism_mock.delivery_mechanism = PropertyMock(return_value=delivery_mechanism_mock)
+ license_delivery_mechanism_mock.data_source = PropertyMock(
+ return_value=data_source_mock
+ )
+ license_delivery_mechanism_mock.identifier = PropertyMock(
+ return_value=identifier_mock
+ )
+ license_delivery_mechanism_mock.delivery_mechanism = PropertyMock(
+ return_value=delivery_mechanism_mock
+ )
# Act
# NOTE: we are not interested in the result returned by repr,
diff --git a/tests/models/test_listeners.py b/tests/models/test_listeners.py
index 7c45e0cca..64124b000 100644
--- a/tests/models/test_listeners.py
+++ b/tests/models/test_listeners.py
@@ -1,28 +1,28 @@
# encoding: utf-8
import functools
+
from parameterized import parameterized
-from ...testing import DatabaseTest
-from ... import lane
-from ... import model
+from ... import lane, model
from ...config import Configuration
from ...model import (
CachedFeed,
ConfigurationSetting,
- create,
- site_configuration_has_changed,
Timestamp,
WorkCoverageRecord,
+ create,
+ site_configuration_has_changed,
)
+from ...testing import DatabaseTest
from ...util.datetime_helpers import utc_now
class TestSiteConfigurationHasChanged(DatabaseTest):
-
class MockSiteConfigurationHasChanged(object):
"""Keep track of whether site_configuration_has_changed was
ever called.
"""
+
def __init__(self):
self.was_called = False
@@ -42,7 +42,9 @@ def setup_method(self):
super(TestSiteConfigurationHasChanged, self).setup_method()
# Mock model.site_configuration_has_changed
- self.old_site_configuration_has_changed = model.listeners.site_configuration_has_changed
+ self.old_site_configuration_has_changed = (
+ model.listeners.site_configuration_has_changed
+ )
self.mock = self.MockSiteConfigurationHasChanged()
for module in model.listeners, lane:
module.site_configuration_has_changed = self.mock.run
@@ -50,7 +52,9 @@ def setup_method(self):
def teardown_method(self):
super(TestSiteConfigurationHasChanged, self).teardown_method()
for module in model.listeners, lane:
- module.site_configuration_has_changed = self.old_site_configuration_has_changed
+ module.site_configuration_has_changed = (
+ self.old_site_configuration_has_changed
+ )
def test_site_configuration_has_changed(self):
"""Test the site_configuration_has_changed() function and its
@@ -63,9 +67,12 @@ def test_site_configuration_has_changed(self):
def ts():
return Timestamp.value(
- self._db, Configuration.SITE_CONFIGURATION_CHANGED,
- service_type=None, collection=None
+ self._db,
+ Configuration.SITE_CONFIGURATION_CHANGED,
+ service_type=None,
+ collection=None,
)
+
timestamp_value = ts()
assert timestamp_value == last_update
@@ -92,21 +99,22 @@ def ts():
# site_configuration_has_changed() -- they will know about the
# change but we won't be informed.
timestamp = Timestamp.stamp(
- self._db, Configuration.SITE_CONFIGURATION_CHANGED,
- service_type=None, collection=None
+ self._db,
+ Configuration.SITE_CONFIGURATION_CHANGED,
+ service_type=None,
+ collection=None,
)
# Calling Configuration.check_for_site_configuration_update
# with a timeout doesn't detect the change.
- assert (new_last_update_time ==
- Configuration.site_configuration_last_update(self._db, timeout=60))
+ assert new_last_update_time == Configuration.site_configuration_last_update(
+ self._db, timeout=60
+ )
# But the default behavior -- a timeout of zero -- forces
# the method to go to the database and find the correct
# answer.
- newer_update = Configuration.site_configuration_last_update(
- self._db
- )
+ newer_update = Configuration.site_configuration_last_update(self._db)
assert newer_update > last_update
# The Timestamp that tracks the last configuration update has
@@ -126,8 +134,7 @@ def ts():
# Verify that the Timestamp has not changed (how could it,
# with no database connection to modify the Timestamp?)
- assert (newer_update ==
- Configuration.site_configuration_last_update(self._db))
+ assert newer_update == Configuration.site_configuration_last_update(self._db)
# We don't test every event listener, but we do test one of each type.
def test_configuration_relevant_lifecycle_event_updates_configuration(self):
@@ -172,8 +179,9 @@ def test_configuration_relevant_collection_change_updates_configuration(self):
# Associating a CachedFeed with the library does _not_ call
# the method, because nothing changed on the Library object and
# we don't listen for 'append' events on Library.cachedfeeds.
- create(self._db, CachedFeed, type='page', pagination='',
- facets='', library=library)
+ create(
+ self._db, CachedFeed, type="page", pagination="", facets="", library=library
+ )
self._db.commit()
self.mock.assert_was_not_called()
@@ -188,10 +196,18 @@ def _set_property(object, value, property_name):
class TestListeners(DatabaseTest):
- @parameterized.expand([
- ('works_when_open_access_property_changes', functools.partial(_set_property, property_name='open_access')),
- ('works_when_self_hosted_property_changes', functools.partial(_set_property, property_name='self_hosted'))
- ])
+ @parameterized.expand(
+ [
+ (
+ "works_when_open_access_property_changes",
+ functools.partial(_set_property, property_name="open_access"),
+ ),
+ (
+ "works_when_self_hosted_property_changes",
+ functools.partial(_set_property, property_name="self_hosted"),
+ ),
+ ]
+ )
def test_licensepool_storage_status_change(self, name, status_property_setter):
# Arrange
work = self._work(with_license_pool=True)
@@ -211,5 +227,8 @@ def test_licensepool_storage_status_change(self, name, status_property_setter):
# Assert
assert 1 == len(work.coverage_records)
assert work.id == work.coverage_records[0].work_id
- assert WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION == work.coverage_records[0].operation
+ assert (
+ WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
+ == work.coverage_records[0].operation
+ )
assert WorkCoverageRecord.REGISTERED == work.coverage_records[0].status
diff --git a/tests/models/test_measurement.py b/tests/models/test_measurement.py
index 937a503b5..4c9a58f0b 100644
--- a/tests/models/test_measurement.py
+++ b/tests/models/test_measurement.py
@@ -1,35 +1,130 @@
-from ...model import (
- DataSource,
- Measurement,
- get_one_or_create
-)
-from ...testing import (
- DatabaseTest,
-)
+from ...model import DataSource, Measurement, get_one_or_create
+from ...testing import DatabaseTest
from ...util.datetime_helpers import datetime_utc
-class TestMeasurement(DatabaseTest):
+class TestMeasurement(DatabaseTest):
def setup_method(self):
super(TestMeasurement, self).setup_method()
self.SOURCE_NAME = "Test Data Source"
# Create a test DataSource
obj, new = get_one_or_create(
- self._db, DataSource,
- name=self.SOURCE_NAME,
+ self._db,
+ DataSource,
+ name=self.SOURCE_NAME,
)
self.source = obj
Measurement.PERCENTILE_SCALES[Measurement.POPULARITY][self.SOURCE_NAME] = [
- 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 9, 9, 10, 10, 11, 12, 13, 14, 15, 15, 16, 18, 19, 20, 21, 22, 24, 25, 26, 28, 30, 31, 33, 35, 37, 39, 41, 43, 46, 48, 51, 53, 56, 59, 63, 66, 70, 74, 78, 82, 87, 92, 97, 102, 108, 115, 121, 128, 135, 142, 150, 159, 168, 179, 190, 202, 216, 230, 245, 260, 277, 297, 319, 346, 372, 402, 436, 478, 521, 575, 632, 702, 777, 861, 965, 1100, 1248, 1428, 1665, 2020, 2560, 3535, 5805]
+ 1,
+ 1,
+ 1,
+ 2,
+ 2,
+ 2,
+ 3,
+ 3,
+ 4,
+ 4,
+ 5,
+ 5,
+ 6,
+ 6,
+ 7,
+ 7,
+ 8,
+ 9,
+ 9,
+ 10,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 15,
+ 16,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 24,
+ 25,
+ 26,
+ 28,
+ 30,
+ 31,
+ 33,
+ 35,
+ 37,
+ 39,
+ 41,
+ 43,
+ 46,
+ 48,
+ 51,
+ 53,
+ 56,
+ 59,
+ 63,
+ 66,
+ 70,
+ 74,
+ 78,
+ 82,
+ 87,
+ 92,
+ 97,
+ 102,
+ 108,
+ 115,
+ 121,
+ 128,
+ 135,
+ 142,
+ 150,
+ 159,
+ 168,
+ 179,
+ 190,
+ 202,
+ 216,
+ 230,
+ 245,
+ 260,
+ 277,
+ 297,
+ 319,
+ 346,
+ 372,
+ 402,
+ 436,
+ 478,
+ 521,
+ 575,
+ 632,
+ 702,
+ 777,
+ 861,
+ 965,
+ 1100,
+ 1248,
+ 1428,
+ 1665,
+ 2020,
+ 2560,
+ 3535,
+ 5805,
+ ]
Measurement.RATING_SCALES[self.SOURCE_NAME] = [1, 10]
def _measurement(self, quantity, value, source, weight):
source = source or self.source
return Measurement(
- data_source=source, quantity_measured=quantity,
- value=value, weight=weight)
+ data_source=source, quantity_measured=quantity, value=value, weight=weight
+ )
def _popularity(self, value, source=None, weight=1):
return self._measurement(Measurement.POPULARITY, value, source, weight)
@@ -56,19 +151,16 @@ def test_newer_measurement_displaces_earlier_measurement(self):
assert True == m2.is_most_recent
assert True == m3.is_most_recent
-
def test_can_insert_measurement_after_the_fact(self):
old = datetime_utc(2011, 1, 1)
new = datetime_utc(2012, 1, 1)
wi = self._identifier()
- m1 = wi.add_measurement(self.source, Measurement.DOWNLOADS, 10,
- taken_at=new)
+ m1 = wi.add_measurement(self.source, Measurement.DOWNLOADS, 10, taken_at=new)
assert True == m1.is_most_recent
- m2 = wi.add_measurement(self.source, Measurement.DOWNLOADS, 5,
- taken_at=old)
+ m2 = wi.add_measurement(self.source, Measurement.DOWNLOADS, 5, taken_at=old)
assert True == m1.is_most_recent
def test_normalized_popularity(self):
@@ -119,17 +211,14 @@ def test_normalized_rating(self):
# Here's a slightly less good book.
p = self._rating(9)
- assert 8.0/9 == p.normalized_value
+ assert 8.0 / 9 == p.normalized_value
# Here's a very bad book
p = self._rating(1)
assert 0 == p.normalized_value
def test_neglected_source_cannot_be_normalized(self):
- obj, new = get_one_or_create(
- self._db, DataSource,
- name="Neglected source"
- )
+ obj, new = get_one_or_create(self._db, DataSource, name="Neglected source")
neglected_source = obj
p = self._popularity(100, neglected_source)
assert None == p.normalized_value
@@ -144,22 +233,20 @@ def test_overall_quality(self):
pop = popularity.normalized_value
rat = rating.normalized_value
assert 0.5 == pop
- assert 1.0/3 == rat
+ assert 1.0 / 3 == rat
l = [popularity, rating, irrelevant]
quality = Measurement.overall_quality(l)
- assert (0.7*rat)+(0.3*pop) == quality
+ assert (0.7 * rat) + (0.3 * pop) == quality
# Mess with the weights.
- assert (0.5*rat)+(0.5*pop) == Measurement.overall_quality(l, 0.5, 0.5)
+ assert (0.5 * rat) + (0.5 * pop) == Measurement.overall_quality(l, 0.5, 0.5)
# Adding a non-popularity measurement that is _equated_ to
# popularity via a percentile scale modifies the
# normalized value -- we don't care exactly how, only that
# it's taken into account.
oclc = DataSource.lookup(self._db, DataSource.OCLC)
- popularityish = self._measurement(
- Measurement.HOLDINGS, 400, oclc, 10
- )
+ popularityish = self._measurement(Measurement.HOLDINGS, 400, oclc, 10)
new_quality = Measurement.overall_quality(l + [popularityish])
assert quality != new_quality
@@ -184,7 +271,7 @@ def test_overall_quality_with_popularity_and_quality_but_not_rating(self):
# We would expect the final quality score to be 1/2 of the quality
# score we got from the metadata wrangler, and 1/2 of the normalized
# value of the 4-star rating.
- expect = (pop.normalized_value / 2) + (0.5/2)
+ expect = (pop.normalized_value / 2) + (0.5 / 2)
assert expect == Measurement.overall_quality([pop, qual], 0.5, 0.5)
def test_overall_quality_with_popularity_quality_and_rating(self):
@@ -195,17 +282,17 @@ def test_overall_quality_with_popularity_quality_and_rating(self):
# The popularity and rating are scaled appropriately and
# added together.
- expect_1 = (pop.normalized_value * 0.75) + (rat.normalized_value*0.25)
+ expect_1 = (pop.normalized_value * 0.75) + (rat.normalized_value * 0.25)
# Then the whole thing is divided in half and added to half of the
# quality score
- expect_total = (expect_1/2 + (quality_score/2))
+ expect_total = expect_1 / 2 + (quality_score / 2)
assert expect_total == Measurement.overall_quality([pop, rat, qual], 0.75, 0.25)
def test_overall_quality_takes_weights_into_account(self):
rating1 = self._rating(10, weight=10)
rating2 = self._rating(1, weight=1)
- assert 0.91 == round(Measurement.overall_quality([rating1, rating2]),2)
+ assert 0.91 == round(Measurement.overall_quality([rating1, rating2]), 2)
def test_overall_quality_is_zero_if_no_relevant_measurements(self):
irrelevant = self._measurement("Some other quantity", 42, self.source, 1)
@@ -217,17 +304,18 @@ def test_calculate_quality(self):
# This book used to be incredibly popular.
identifier = w.presentation_edition.primary_identifier
old_popularity = identifier.add_measurement(
- self.source, Measurement.POPULARITY, 6000)
+ self.source, Measurement.POPULARITY, 6000
+ )
# Now it's just so-so.
- popularity = identifier.add_measurement(
- self.source, Measurement.POPULARITY, 59)
+ popularity = identifier.add_measurement(self.source, Measurement.POPULARITY, 59)
# This measurement is irrelevant because "Test Data Source"
# doesn't have a mapping from number of editions to a
# percentile range.
irrelevant = identifier.add_measurement(
- self.source, Measurement.PUBLISHED_EDITIONS, 42)
+ self.source, Measurement.PUBLISHED_EDITIONS, 42
+ )
# If we calculate the quality based solely on the primary
# identifier, only the most recent popularity is considered,
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index f298fb433..23627d6a3 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -1,28 +1,28 @@
# encoding: utf-8
-import pytest
import datetime
+
+import pytest
from psycopg2.extras import NumericRange
from sqlalchemy import not_
from sqlalchemy.orm.exc import MultipleResultsFound
-from ...testing import DatabaseTest
from ... import classifier
-from ...external_search import mock_search_index
from ...config import Configuration
+from ...external_search import mock_search_index
from ...model import (
DataSource,
Edition,
Genre,
- get_one,
SessionManager,
Timestamp,
+ get_one,
numericrange_to_tuple,
tuple_to_numericrange,
)
+from ...testing import DatabaseTest
class TestDatabaseInterface(DatabaseTest):
-
def test_get_one(self):
# When a matching object isn't found, None is returned.
@@ -39,14 +39,13 @@ def test_get_one(self):
pytest.raises(MultipleResultsFound, get_one, self._db, Edition)
# Unless they're interchangeable.
- result = get_one(self._db, Edition, on_multiple='interchangeable')
+ result = get_one(self._db, Edition, on_multiple="interchangeable")
assert result in self._db.query(Edition)
# Or specific attributes are passed that limit the results to one.
result = get_one(
- self._db, Edition,
- title=other_edition.title,
- author=other_edition.author)
+ self._db, Edition, title=other_edition.title, author=other_edition.author
+ )
assert other_edition == result
# A particular constraint clause can also be passed in.
@@ -60,9 +59,12 @@ def test_initialize_data_does_not_reset_timestamp(self):
# initialized and the 'site configuration changed' Timestamp has
# been set. Calling initialize_data() again won't change the
# date on the timestamp.
- timestamp = get_one(self._db, Timestamp,
- collection=None,
- service=Configuration.SITE_CONFIGURATION_CHANGED)
+ timestamp = get_one(
+ self._db,
+ Timestamp,
+ collection=None,
+ service=Configuration.SITE_CONFIGURATION_CHANGED,
+ )
old_timestamp = timestamp.finish
SessionManager.initialize_data(self._db)
assert old_timestamp == timestamp.finish
@@ -77,7 +79,7 @@ def test_tuple_to_numericrange(self):
f = tuple_to_numericrange
assert None == f(None)
- one_to_ten = f((1,10))
+ one_to_ten = f((1, 10))
assert isinstance(one_to_ten, NumericRange)
assert 1 == one_to_ten.lower
assert 10 == one_to_ten.upper
@@ -89,7 +91,7 @@ def test_tuple_to_numericrange(self):
assert 10 == up_to_ten.upper
assert True == up_to_ten.upper_inc
- ten_and_up = f((10,None))
+ ten_and_up = f((10, None))
assert isinstance(ten_and_up, NumericRange)
assert 10 == ten_and_up.lower
assert None == ten_and_up.upper
@@ -97,7 +99,7 @@ def test_tuple_to_numericrange(self):
def test_numericrange_to_tuple(self):
m = numericrange_to_tuple
- two_to_six_inclusive = NumericRange(2,6, '[]')
- assert (2,6) == m(two_to_six_inclusive)
- two_to_six_exclusive = NumericRange(2,6, '()')
- assert (3,5) == m(two_to_six_exclusive)
+ two_to_six_inclusive = NumericRange(2, 6, "[]")
+ assert (2, 6) == m(two_to_six_inclusive)
+ two_to_six_exclusive = NumericRange(2, 6, "()")
+ assert (3, 5) == m(two_to_six_exclusive)
diff --git a/tests/models/test_patron.py b/tests/models/test_patron.py
index c918acc9b..d33e4f0bd 100644
--- a/tests/models/test_patron.py
+++ b/tests/models/test_patron.py
@@ -1,35 +1,26 @@
# encoding: utf-8
-import pytest
import datetime
-from mock import (
- call,
- MagicMock,
-)
-from ...testing import DatabaseTest
+import pytest
+from mock import MagicMock, call
+
from ...classifier import Classifier
-from ...model import (
- create,
- tuple_to_numericrange,
-)
+from ...model import create, tuple_to_numericrange
from ...model.credential import Credential
from ...model.datasource import DataSource
from ...model.library import Library
from ...model.licensing import PolicyException
-from ...model.patron import (
- Annotation,
- Hold,
- Loan,
- Patron,
- PatronProfileStorage,
-)
+from ...model.patron import Annotation, Hold, Loan, Patron, PatronProfileStorage
+from ...testing import DatabaseTest
from ...util.datetime_helpers import datetime_utc, utc_now
+
class TestAnnotation(DatabaseTest):
def test_set_inactive(self):
pool = self._licensepool(None)
annotation, ignore = create(
- self._db, Annotation,
+ self._db,
+ Annotation,
patron=self._patron(),
identifier=pool.identifier,
motivation=Annotation.IDLING,
@@ -49,7 +40,8 @@ def test_patron_annotations_are_descending(self):
pool2 = self._licensepool(None)
patron = self._patron()
annotation1, ignore = create(
- self._db, Annotation,
+ self._db,
+ Annotation,
patron=patron,
identifier=pool2.identifier,
motivation=Annotation.IDLING,
@@ -57,7 +49,8 @@ def test_patron_annotations_are_descending(self):
active=True,
)
annotation2, ignore = create(
- self._db, Annotation,
+ self._db,
+ Annotation,
patron=patron,
identifier=pool2.identifier,
motivation=Annotation.IDLING,
@@ -74,8 +67,8 @@ def test_patron_annotations_are_descending(self):
assert annotation2 == patron.annotations[0]
assert annotation1 == patron.annotations[1]
-class TestHold(DatabaseTest):
+class TestHold(DatabaseTest):
def test_on_hold_to(self):
now = utc_now()
later = now + datetime.timedelta(days=1)
@@ -172,15 +165,21 @@ def _mock__calculate_until(self, *args):
"""Track the arguments passed into _calculate_until."""
self.called_with = args
return "mock until"
+
old__calculate_until = hold._calculate_until
Hold._calculate_until = _mock__calculate_until
assert "mock until" == m(one_day, two_days)
- (calculate_from, position, licenses_available, default_loan_period,
- default_reservation_period) = hold.called_with
+ (
+ calculate_from,
+ position,
+ licenses_available,
+ default_loan_period,
+ default_reservation_period,
+ ) = hold.called_with
- assert (calculate_from-now).total_seconds() < 5
+ assert (calculate_from - now).total_seconds() < 5
assert hold.position == position
assert pool.licenses_available == licenses_available
assert one_day == default_loan_period
@@ -190,8 +189,13 @@ def _mock__calculate_until(self, *args):
# assume they're at the end.
hold.position = None
assert "mock until" == m(one_day, two_days)
- (calculate_from, position, licenses_available, default_loan_period,
- default_reservation_period) = hold.called_with
+ (
+ calculate_from,
+ position,
+ licenses_available,
+ default_loan_period,
+ default_reservation_period,
+ ) = hold.called_with
assert pool.patrons_in_hold_queue == position
Hold._calculate_until = old__calculate_until
@@ -210,45 +214,37 @@ def test_calculate_until(self):
# After 21 days, those copies are released and I am 8th in line.
# After 28 days, those copies are released and I am 4th in line.
# After 35 days, those copies are released and get my notification.
- a = Hold._calculate_until(
- start, 20, 4, default_loan, default_reservation)
- assert a == start + datetime.timedelta(days=(7*5))
+ a = Hold._calculate_until(start, 20, 4, default_loan, default_reservation)
+ assert a == start + datetime.timedelta(days=(7 * 5))
# If I am 21st in line, I need to wait six weeks.
- b = Hold._calculate_until(
- start, 21, 4, default_loan, default_reservation)
- assert b == start + datetime.timedelta(days=(7*6))
+ b = Hold._calculate_until(start, 21, 4, default_loan, default_reservation)
+ assert b == start + datetime.timedelta(days=(7 * 6))
# If I am 3rd in line, I only need to wait seven days--that's when
# I'll get the notification message.
- b = Hold._calculate_until(
- start, 3, 4, default_loan, default_reservation)
+ b = Hold._calculate_until(start, 3, 4, default_loan, default_reservation)
assert b == start + datetime.timedelta(days=7)
# A new person gets the book every week. Someone has the book now
# and there are 3 people ahead of me in the queue. I will get
# the book in 7 days + 3 weeks
- c = Hold._calculate_until(
- start, 3, 1, default_loan, default_reservation)
- assert c == start + datetime.timedelta(days=(7*4))
+ c = Hold._calculate_until(start, 3, 1, default_loan, default_reservation)
+ assert c == start + datetime.timedelta(days=(7 * 4))
# I'm first in line for 1 book. After 7 days, one copy is
# released and I'll get my notification.
- a = Hold._calculate_until(
- start, 1, 1, default_loan, default_reservation)
+ a = Hold._calculate_until(start, 1, 1, default_loan, default_reservation)
assert a == start + datetime.timedelta(days=7)
# The book is reserved to me. I need to hurry up and check it out.
- d = Hold._calculate_until(
- start, 0, 1, default_loan, default_reservation)
+ d = Hold._calculate_until(start, 0, 1, default_loan, default_reservation)
assert d == start + datetime.timedelta(days=1)
# If there are no licenses, I will never get the book.
- e = Hold._calculate_until(
- start, 10, 0, default_loan, default_reservation)
+ e = Hold._calculate_until(start, 10, 0, default_loan, default_reservation)
assert e == None
-
def test_vendor_hold_end_value_takes_precedence_over_calculated_value(self):
"""If the vendor has provided an estimated availability time,
that is used in preference to the availability time we
@@ -268,15 +264,19 @@ def test_vendor_hold_end_value_takes_precedence_over_calculated_value(self):
assert tomorrow == hold.until(default_loan, default_reservation)
calculated_value = hold._calculate_until(
- now, hold.position, pool.licenses_available,
- default_loan, default_reservation
+ now,
+ hold.position,
+ pool.licenses_available,
+ default_loan,
+ default_reservation,
)
# If the vendor value is not in the future, it's ignored
# and the calculated value is used instead.
def assert_calculated_value_used():
result = hold.until(default_loan, default_reservation)
- assert (result-calculated_value).seconds < 5
+ assert (result - calculated_value).seconds < 5
+
hold.end = now
assert_calculated_value_used()
@@ -285,8 +285,8 @@ def assert_calculated_value_used():
hold.end = None
assert_calculated_value_used()
-class TestLoans(DatabaseTest):
+class TestLoans(DatabaseTest):
def test_open_access_loan(self):
patron = self._patron()
work = self._work(with_license_pool=True)
@@ -378,8 +378,8 @@ def test_library(self):
loan.patron = patron
assert patron.library == loan.library
-class TestPatron(DatabaseTest):
+class TestPatron(DatabaseTest):
def test_repr(self):
patron = self._patron(external_identifier="a patron")
@@ -387,8 +387,9 @@ def test_repr(self):
patron.authorization_expires = datetime_utc(2018, 1, 2, 3, 4, 5)
patron.last_external_sync = None
assert (
- "" ==
- repr(patron))
+ ""
+ == repr(patron)
+ )
def test_identifier_to_remote_service(self):
@@ -419,17 +420,16 @@ def test_identifier_to_remote_service(self):
# patron identifier.
def fake_generator():
return "fake string"
+
bib = DataSource.BIBLIOTHECA
- assert ("fake string" ==
- patron.identifier_to_remote_service(bib, fake_generator))
+ assert "fake string" == patron.identifier_to_remote_service(bib, fake_generator)
# Once the identifier is created, specifying a different generator
# does nothing.
- assert ("fake string" ==
- patron.identifier_to_remote_service(bib))
- assert (
- axis_identifier ==
- patron.identifier_to_remote_service(axis, fake_generator))
+ assert "fake string" == patron.identifier_to_remote_service(bib)
+ assert axis_identifier == patron.identifier_to_remote_service(
+ axis, fake_generator
+ )
def test_set_synchronize_annotations(self):
# Two patrons.
@@ -451,7 +451,7 @@ def test_set_synchronize_annotations(self):
identifier=identifier,
motivation=Annotation.IDLING,
)
- annotation.content="The content for %s" % patron.id,
+ annotation.content = ("The content for %s" % patron.id,)
assert 1 == len(patron.annotations)
@@ -464,8 +464,11 @@ def test_set_synchronize_annotations(self):
# Patron #1 can no longer use Annotation.get_one_or_create.
pytest.raises(
- ValueError, Annotation.get_one_or_create,
- self._db, patron=p1, identifier=identifier,
+ ValueError,
+ Annotation.get_one_or_create,
+ self._db,
+ patron=p1,
+ identifier=identifier,
motivation=Annotation.IDLING,
)
@@ -474,7 +477,9 @@ def test_set_synchronize_annotations(self):
# But patron #2 can use Annotation.get_one_or_create.
i2, is_new = Annotation.get_one_or_create(
- self._db, patron=p2, identifier=self._identifier(),
+ self._db,
+ patron=p2,
+ identifier=self._identifier(),
motivation=Annotation.IDLING,
)
assert True == is_new
@@ -483,6 +488,7 @@ def test_set_synchronize_annotations(self):
# can't go back to not having made the decision.
def try_to_set_none(patron):
patron.synchronize_annotations = None
+
pytest.raises(ValueError, try_to_set_none, p2)
def test_cascade_delete(self):
@@ -527,7 +533,7 @@ def test_cascade_delete(self):
def test_loan_activity_max_age(self):
# Currently, patron.loan_activity_max_age is a constant
# and cannot be changed.
- assert 15*60 == self._patron().loan_activity_max_age
+ assert 15 * 60 == self._patron().loan_activity_max_age
def test_last_loan_activity_sync(self):
# Verify that last_loan_activity_sync is cleared out
@@ -535,8 +541,8 @@ def test_last_loan_activity_sync(self):
patron = self._patron()
now = utc_now()
max_age = patron.loan_activity_max_age
- recently = now - datetime.timedelta(seconds=max_age/2)
- long_ago = now - datetime.timedelta(seconds=max_age*2)
+ recently = now - datetime.timedelta(seconds=max_age / 2)
+ long_ago = now - datetime.timedelta(seconds=max_age * 2)
# So long as last_loan_activity_sync is relatively recent,
# it's treated as a normal piece of data.
@@ -595,14 +601,13 @@ def test_work_is_age_appropriate(self):
# The target audience and age of a patron's root lane controls
# whether a given book is 'age-appropriate' for them.
lane = self._lane()
- lane.audiences = [Classifier.AUDIENCE_CHILDREN,
- Classifier.AUDIENCE_YOUNG_ADULT]
- lane.target_age = (9,14)
+ lane.audiences = [Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT]
+ lane.target_age = (9, 14)
lane.root_for_patron_type = ["1"]
self._db.flush()
- def mock_age_appropriate(work_audience, work_target_age,
- reader_audience, reader_target_age
+ def mock_age_appropriate(
+ work_audience, work_target_age, reader_audience, reader_target_age
):
"""Returns True only if reader_audience is the preconfigured
expected value.
@@ -631,12 +636,22 @@ def mock_age_appropriate(work_audience, work_target_age,
# age_appropriate_match method was called on
# each audience associated with the patron's root lane.
- mock.assert_has_calls([
- call(work_audience, work_target_age,
- Classifier.AUDIENCE_CHILDREN, lane.target_age),
- call(work_audience, work_target_age,
- Classifier.AUDIENCE_YOUNG_ADULT, lane.target_age)
- ])
+ mock.assert_has_calls(
+ [
+ call(
+ work_audience,
+ work_target_age,
+ Classifier.AUDIENCE_CHILDREN,
+ lane.target_age,
+ ),
+ call(
+ work_audience,
+ work_target_age,
+ Classifier.AUDIENCE_YOUNG_ADULT,
+ lane.target_age,
+ ),
+ ]
+ )
# work_is_age_appropriate() will only return True if at least
# one of the age_appropriate_match() calls returns True.
@@ -686,7 +701,7 @@ def test_age_appropriate_match(self):
# a value that would allow for this (as can happen when
# the patron's root lane is set up to show both children's
# and YA titles).
- assert False == m(work_audience, object(), children, (14,18))
+ assert False == m(work_audience, object(), children, (14, 18))
# YA readers can see any children's title.
assert True == m(children, object(), ya, object())
@@ -694,9 +709,7 @@ def test_age_appropriate_match(self):
# A YA reader is treated as an adult (with no reading
# restrictions) if they have no associated age range, or their
# age range includes ADULT_AGE_CUTOFF.
- for reader_age in [
- None, 18, (14, 18), tuple_to_numericrange((14, 18))
- ]:
+ for reader_age in [None, 18, (14, 18), tuple_to_numericrange((14, 18))]:
assert True == m(adult, object(), ya, reader_age)
# Otherwise, YA readers cannot see books for adults.
@@ -711,26 +724,31 @@ def test_age_appropriate_match(self):
# we don't have the information necessary to say it's not
# fine).
work_target_age = None
- assert True == m(work_audience, work_target_age,
- reader_audience, object())
+ assert True == m(work_audience, work_target_age, reader_audience, object())
# Now give the work a specific target age range.
- for work_target_age in [(5, 7), tuple_to_numericrange((5,7))]:
+ for work_target_age in [(5, 7), tuple_to_numericrange((5, 7))]:
# The lower end of the age range is old enough.
- for age in range(5,9):
+ for age in range(5, 9):
for reader_age in (
- age, (age-1, age), tuple_to_numericrange((age-1, age))
+ age,
+ (age - 1, age),
+ tuple_to_numericrange((age - 1, age)),
):
- assert True == m(work_audience, work_target_age,
- reader_audience, reader_age)
+ assert True == m(
+ work_audience, work_target_age, reader_audience, reader_age
+ )
# Anything lower than that is not.
- for age in range(2,5):
+ for age in range(2, 5):
for reader_age in (
- age, (age-1, age), tuple_to_numericrange((age-1, age))
+ age,
+ (age - 1, age),
+ tuple_to_numericrange((age - 1, age)),
):
- assert False == m(work_audience, work_target_age,
- reader_audience, reader_age)
+ assert False == m(
+ work_audience, work_target_age, reader_audience, reader_age
+ )
# Similar rules apply for a YA reader who wants to read a YA
# book.
@@ -740,30 +758,34 @@ def test_age_appropriate_match(self):
# If there's no target age, it's fine (or at least we don't
# have the information necessary to say it's not fine).
work_target_age = None
- assert True == m(work_audience, work_target_age,
- reader_audience, object())
+ assert True == m(work_audience, work_target_age, reader_audience, object())
# Now give the work a specific target age range.
for work_target_age in ((14, 16), tuple_to_numericrange((14, 16))):
# The lower end of the age range is old enough
for age in range(14, 20):
for reader_age in (
- age, (age-1, age), tuple_to_numericrange((age-1, age))
+ age,
+ (age - 1, age),
+ tuple_to_numericrange((age - 1, age)),
):
- assert True == m(work_audience, work_target_age,
- reader_audience, reader_age)
+ assert True == m(
+ work_audience, work_target_age, reader_audience, reader_age
+ )
# Anything lower than that is not.
for age in range(7, 14):
for reader_age in (
- age, (age-1, age), tuple_to_numericrange((age-1, age))
+ age,
+ (age - 1, age),
+ tuple_to_numericrange((age - 1, age)),
):
- assert False == m(work_audience, work_target_age,
- reader_audience, reader_age)
+ assert False == m(
+ work_audience, work_target_age, reader_audience, reader_age
+ )
class TestPatronProfileStorage(DatabaseTest):
-
def setup_method(self):
super(TestPatronProfileStorage, self).setup_method()
self.patron = self._patron()
@@ -771,8 +793,10 @@ def setup_method(self):
def test_writable_setting_names(self):
"""Only one setting is currently writable."""
- assert (set([self.store.SYNCHRONIZE_ANNOTATIONS]) ==
- self.store.writable_setting_names)
+ assert (
+ set([self.store.SYNCHRONIZE_ANNOTATIONS])
+ == self.store.writable_setting_names
+ )
def test_profile_document(self):
# synchronize_annotations always shows up as settable, even if
@@ -780,25 +804,19 @@ def test_profile_document(self):
self.patron.authorization_identifier = "abcd"
assert None == self.patron.synchronize_annotations
rep = self.store.profile_document
- assert (
- {
- 'simplified:authorization_identifier': 'abcd',
- 'settings': {'simplified:synchronize_annotations': None}
- } ==
- rep)
+ assert {
+ "simplified:authorization_identifier": "abcd",
+ "settings": {"simplified:synchronize_annotations": None},
+ } == rep
self.patron.synchronize_annotations = True
- self.patron.authorization_expires = datetime_utc(
- 2016, 1, 1, 10, 20, 30
- )
+ self.patron.authorization_expires = datetime_utc(2016, 1, 1, 10, 20, 30)
rep = self.store.profile_document
- assert (
- {
- 'simplified:authorization_expires': '2016-01-01T10:20:30Z',
- 'simplified:authorization_identifier': 'abcd',
- 'settings': {'simplified:synchronize_annotations': True}
- } ==
- rep)
+ assert {
+ "simplified:authorization_expires": "2016-01-01T10:20:30Z",
+ "simplified:authorization_identifier": "abcd",
+ "settings": {"simplified:synchronize_annotations": True},
+ } == rep
def test_update(self):
# This is a no-op.
@@ -806,5 +824,5 @@ def test_update(self):
assert None == self.patron.synchronize_annotations
# This is not.
- self.store.update({self.store.SYNCHRONIZE_ANNOTATIONS : True}, {})
+ self.store.update({self.store.SYNCHRONIZE_ANNOTATIONS: True}, {})
assert True == self.patron.synchronize_annotations
diff --git a/tests/models/test_resource.py b/tests/models/test_resource.py
index 7c305864e..074e8f0bd 100644
--- a/tests/models/test_resource.py
+++ b/tests/models/test_resource.py
@@ -1,34 +1,35 @@
# encoding: utf-8
-import pytest
import os
-from ...testing import (
- DatabaseTest,
- DummyHTTPClient,
-)
+
+import pytest
+
from ...model import create
from ...model.datasource import DataSource
from ...model.edition import Edition
from ...model.identifier import Identifier
from ...model.licensing import RightsStatus
-from ...model.resource import (
- Hyperlink,
- Representation,
- Resource,
-)
-from ...testing import MockRequestsResponse
+from ...model.resource import Hyperlink, Representation, Resource
+from ...testing import DatabaseTest, DummyHTTPClient, MockRequestsResponse
-class TestHyperlink(DatabaseTest):
+class TestHyperlink(DatabaseTest):
def test_add_link(self):
edition, pool = self._edition(with_license_pool=True)
identifier = edition.primary_identifier
data_source = pool.data_source
original, ignore = create(self._db, Resource, url="http://bar.com")
hyperlink, is_new = pool.add_link(
- Hyperlink.DESCRIPTION, "http://foo.com/", data_source,
- "text/plain", "The content", None, RightsStatus.CC_BY,
- "The rights explanation", original,
- transformation_settings=dict(setting="a setting"))
+ Hyperlink.DESCRIPTION,
+ "http://foo.com/",
+ data_source,
+ "text/plain",
+ "The content",
+ None,
+ RightsStatus.CC_BY,
+ "The rights explanation",
+ original,
+ transformation_settings=dict(setting="a setting"),
+ )
assert True == is_new
rep = hyperlink.resource.representation
assert "text/plain" == rep.media_type
@@ -73,16 +74,12 @@ def m():
assert [] == m()
# Hyperlink rel is not mirrorable.
- wrong_type, ignore = i1.add_link(
- "not mirrorable", self._url, ds, "text/plain"
- )
+ wrong_type, ignore = i1.add_link("not mirrorable", self._url, ds, "text/plain")
assert [] == m()
# Hyperlink has no associated representation -- it needs to be
# mirrored, which will create one!
- hyperlink, ignore = i1.add_link(
- Hyperlink.IMAGE, self._url, ds, "image/png"
- )
+ hyperlink, ignore = i1.add_link(Hyperlink.IMAGE, self._url, ds, "image/png")
assert [hyperlink] == m()
# Representation is already mirrored, so does not show up
@@ -104,7 +101,6 @@ def m():
class TestResource(DatabaseTest):
-
def test_as_delivery_mechanism_for(self):
# Calling as_delivery_mechanism_for on a Resource that is used
@@ -124,16 +120,18 @@ def test_as_delivery_mechanism_for(self):
class TestRepresentation(DatabaseTest):
-
def test_normalized_content_path(self):
assert "baz" == Representation.normalize_content_path(
- "/foo/bar/baz", "/foo/bar")
+ "/foo/bar/baz", "/foo/bar"
+ )
assert "baz" == Representation.normalize_content_path(
- "/foo/bar/baz", "/foo/bar/")
+ "/foo/bar/baz", "/foo/bar/"
+ )
assert "/foo/bar/baz" == Representation.normalize_content_path(
- "/foo/bar/baz", "/blah/blah/")
+ "/foo/bar/baz", "/blah/blah/"
+ )
def test_best_media_type(self):
"""Test our ability to determine whether the Content-Type
@@ -152,25 +150,23 @@ def test_best_media_type(self):
# Except when the content-type header is so generic as to be uselses.
assert "text/plain" == m(
- None,
- {"content-type": "application/octet-stream;profile=foo"},
- "text/plain")
+ None, {"content-type": "application/octet-stream;profile=foo"}, "text/plain"
+ )
# If no default media type is specified, but one can be derived from
# the URL, that one is used as the default.
assert "image/jpeg" == m(
"http://images-galore/cover.jpeg",
{"content-type": "application/octet-stream;profile=foo"},
- None)
+ None,
+ )
# But a default media type doesn't override a specific
# Content-Type from the server, even if it superficially makes
# more sense.
assert "image/png" == m(
- "http://images-galore/cover.jpeg",
- {"content-type": "image/png"},
- None)
-
+ "http://images-galore/cover.jpeg", {"content-type": "image/png"}, None
+ )
def test_mirrorable_media_type(self):
representation, ignore = self._representation(self._url)
@@ -198,7 +194,9 @@ def test_guess_media_type(self):
assert Representation.JPEG_MEDIA_TYPE == m_file(jpg_file)
assert Representation.ZIP_MEDIA_TYPE == m_file(zip_file)
- for extension, media_type in list(Representation.MEDIA_TYPE_FOR_EXTENSION.items()):
+ for extension, media_type in list(
+ Representation.MEDIA_TYPE_FOR_EXTENSION.items()
+ ):
filename = "file" + extension
assert media_type == m_file(filename)
@@ -211,10 +209,16 @@ def test_guess_media_type(self):
zip_url = "https://some_url/path/file.zip"
assert Representation.ZIP_MEDIA_TYPE == m_file(zip_url)
# ... but will get these wrong.
- zip_url_with_query = "https://some_url/path/file.zip?Policy=xyz123&Key-Pair-Id=xxx"
- zip_url_misleading = "https://some_url/path/file.zip?Policy=xyz123&associated_cover=image.jpg"
+ zip_url_with_query = (
+ "https://some_url/path/file.zip?Policy=xyz123&Key-Pair-Id=xxx"
+ )
+ zip_url_misleading = (
+ "https://some_url/path/file.zip?Policy=xyz123&associated_cover=image.jpg"
+ )
assert None == m_file(zip_url_with_query) # We get None, but want Zip
- assert Representation.JPEG_MEDIA_TYPE == m_file(zip_url_misleading) # We get JPEG, but want Zip
+ assert Representation.JPEG_MEDIA_TYPE == m_file(
+ zip_url_misleading
+ ) # We get JPEG, but want Zip
# Taking URL structure into account should get them all right.
assert Representation.ZIP_MEDIA_TYPE == m_url(zip_url)
@@ -227,8 +231,6 @@ def test_guess_media_type(self):
assert Representation.ZIP_MEDIA_TYPE == m_url(zip_file_rel_path)
assert Representation.ZIP_MEDIA_TYPE == m_url(zip_file_abs_path)
-
-
def test_external_media_type_and_extension(self):
"""Test the various transformations that might happen to media type
and extension when we mirror a representation.
@@ -237,17 +239,15 @@ def test_external_media_type_and_extension(self):
# An unknown file at /foo
representation, ignore = self._representation(self._url, "text/unknown")
assert "text/unknown" == representation.external_media_type
- assert '' == representation.extension()
+ assert "" == representation.extension()
# A text file at /foo
representation, ignore = self._representation(self._url, "text/plain")
assert "text/plain" == representation.external_media_type
- assert '.txt' == representation.extension()
+ assert ".txt" == representation.extension()
# A JPEG at /foo.jpg
- representation, ignore = self._representation(
- self._url + ".jpg", "image/jpeg"
- )
+ representation, ignore = self._representation(self._url + ".jpg", "image/jpeg")
assert "image/jpeg" == representation.external_media_type
assert ".jpg" == representation.extension()
@@ -264,12 +264,14 @@ def test_external_media_type_and_extension(self):
# An EPUB at /foo.epub.images -- information present in the URL
# is preserved.
representation, ignore = self._representation(
- self._url + '.epub.images', Representation.EPUB_MEDIA_TYPE
+ self._url + ".epub.images", Representation.EPUB_MEDIA_TYPE
)
assert Representation.EPUB_MEDIA_TYPE == representation.external_media_type
assert ".epub.images" == representation.extension()
- representation, ignore = self._representation(self._url + ".svg", "image/svg+xml")
+ representation, ignore = self._representation(
+ self._url + ".svg", "image/svg+xml"
+ )
assert "image/svg+xml" == representation.external_media_type
assert ".svg" == representation.extension()
@@ -327,44 +329,49 @@ def test_presumed_media_type(self):
# In the absence of a content-type header, the presumed_media_type
# takes over.
- h.queue_response(200, None, content='content')
+ h.queue_response(200, None, content="content")
representation, cached = Representation.get(
- self._db, 'http://url', do_get=h.do_get, max_age=0,
- presumed_media_type="text/xml"
+ self._db,
+ "http://url",
+ do_get=h.do_get,
+ max_age=0,
+ presumed_media_type="text/xml",
)
- assert 'text/xml' == representation.media_type
+ assert "text/xml" == representation.media_type
# In the presence of a generic content-type header, the
# presumed_media_type takes over.
- h.queue_response(200, 'application/octet-stream',
- content='content')
+ h.queue_response(200, "application/octet-stream", content="content")
representation, cached = Representation.get(
- self._db, 'http://url', do_get=h.do_get, max_age=0,
- presumed_media_type="text/xml"
+ self._db,
+ "http://url",
+ do_get=h.do_get,
+ max_age=0,
+ presumed_media_type="text/xml",
)
- assert 'text/xml' == representation.media_type
+ assert "text/xml" == representation.media_type
# A non-generic content-type header takes precedence over
# presumed_media_type.
- h.queue_response(200, 'text/plain', content='content')
+ h.queue_response(200, "text/plain", content="content")
representation, cached = Representation.get(
- self._db, 'http://url', do_get=h.do_get, max_age=0,
- presumed_media_type="text/xml"
+ self._db,
+ "http://url",
+ do_get=h.do_get,
+ max_age=0,
+ presumed_media_type="text/xml",
)
- assert 'text/plain' == representation.media_type
-
+ assert "text/plain" == representation.media_type
def test_404_creates_cachable_representation(self):
h = DummyHTTPClient()
h.queue_response(404)
url = self._url
- representation, cached = Representation.get(
- self._db, url, do_get=h.do_get)
+ representation, cached = Representation.get(self._db, url, do_get=h.do_get)
assert False == cached
- representation2, cached = Representation.get(
- self._db, url, do_get=h.do_get)
+ representation2, cached = Representation.get(self._db, url, do_get=h.do_get)
assert True == cached
assert representation == representation2
@@ -373,12 +380,10 @@ def test_302_creates_cachable_representation(self):
h.queue_response(302)
url = self._url
- representation, cached = Representation.get(
- self._db, url, do_get=h.do_get)
+ representation, cached = Representation.get(self._db, url, do_get=h.do_get)
assert False == cached
- representation2, cached = Representation.get(
- self._db, url, do_get=h.do_get)
+ representation2, cached = Representation.get(self._db, url, do_get=h.do_get)
assert True == cached
assert representation == representation2
@@ -386,22 +391,20 @@ def test_500_creates_uncachable_representation(self):
h = DummyHTTPClient()
h.queue_response(500)
url = self._url
- representation, cached = Representation.get(
- self._db, url, do_get=h.do_get)
+ representation, cached = Representation.get(self._db, url, do_get=h.do_get)
assert False == cached
h.queue_response(500)
- representation, cached = Representation.get(
- self._db, url, do_get=h.do_get)
+ representation, cached = Representation.get(self._db, url, do_get=h.do_get)
assert False == cached
def test_response_reviewer_impacts_representation(self):
h = DummyHTTPClient()
- h.queue_response(200, media_type='text/html')
+ h.queue_response(200, media_type="text/html")
def reviewer(response):
status, headers, content = response
- if 'html' in headers['content-type']:
+ if "html" in headers["content-type"]:
raise Exception("No. Just no.")
representation, cached = Representation.get(
@@ -417,11 +420,11 @@ def oops(*args, **kwargs):
# By default exceptions raised during get() are
# recorded along with the (empty) Representation objects
representation, cached = Representation.get(
- self._db, self._url, do_get=oops,
- )
- assert representation.fetch_exception.strip().endswith(
- "Exception: oops!"
+ self._db,
+ self._url,
+ do_get=oops,
)
+ assert representation.fetch_exception.strip().endswith("Exception: oops!")
assert None == representation.content
assert None == representation.status_code
@@ -429,8 +432,11 @@ def oops(*args, **kwargs):
# being handled.
with pytest.raises(Exception) as excinfo:
Representation.get(
- self._db, self._url, do_get = oops,
- exception_handler = Representation.reraise_exception)
+ self._db,
+ self._url,
+ do_get=oops,
+ exception_handler=Representation.reraise_exception,
+ )
assert "oops!" in str(excinfo.value)
def test_url_extension(self):
@@ -458,8 +464,9 @@ def test_url_extension(self):
def test_clean_media_type(self):
m = Representation._clean_media_type
assert "image/jpeg" == m("image/jpeg")
- assert ("application/atom+xml" ==
- m("application/atom+xml;profile=opds-catalog;kind=acquisition"))
+ assert "application/atom+xml" == m(
+ "application/atom+xml;profile=opds-catalog;kind=acquisition"
+ )
def test_extension(self):
m = Representation._extension
@@ -500,17 +507,23 @@ def test_default_filename(self):
# This URL has no path component, so we can't even come up with a
# decent default filename. We have to go with 'resource'.
- representation, ignore = self._representation("http://example.com/", "text/unknown")
- assert 'resource' == representation.default_filename()
- assert 'resource.png' == representation.default_filename(destination_type="image/png")
+ representation, ignore = self._representation(
+ "http://example.com/", "text/unknown"
+ )
+ assert "resource" == representation.default_filename()
+ assert "resource.png" == representation.default_filename(
+ destination_type="image/png"
+ )
# But if we know what type of thing we're linking to, we can
# do a little better.
link = Hyperlink(rel=Hyperlink.IMAGE)
filename = representation.default_filename(link=link)
- assert 'cover' == filename
- filename = representation.default_filename(link=link, destination_type="image/png")
- assert 'cover.png' == filename
+ assert "cover" == filename
+ filename = representation.default_filename(
+ link=link, destination_type="image/png"
+ )
+ assert "cover.png" == filename
def test_cautious_http_get(self):
@@ -521,8 +534,11 @@ def test_cautious_http_get(self):
# with no HEAD request being made.
m = Representation.cautious_http_get
status, headers, content = m(
- "http://safe.org/", {}, do_not_access=['unsafe.org'],
- do_get=h.do_get, cautious_head_client=object()
+ "http://safe.org/",
+ {},
+ do_not_access=["unsafe.org"],
+ do_get=h.do_get,
+ cautious_head_client=object(),
)
assert 200 == status
assert b"yay" == content
@@ -530,44 +546,56 @@ def test_cautious_http_get(self):
# If the domain is obviously unsafe, no GET request or HEAD
# request is made.
status, headers, content = m(
- "http://unsafe.org/", {}, do_not_access=['unsafe.org'],
- do_get=object(), cautious_head_client=object()
+ "http://unsafe.org/",
+ {},
+ do_not_access=["unsafe.org"],
+ do_get=object(),
+ cautious_head_client=object(),
)
assert 417 == status
- assert ("Cautiously decided not to make a GET request to http://unsafe.org/" ==
- content)
+ assert (
+ "Cautiously decided not to make a GET request to http://unsafe.org/"
+ == content
+ )
# If the domain is potentially unsafe, a HEAD request is made,
# and the answer depends on its outcome.
# Here, the HEAD request redirects to a prohibited site.
def mock_redirect(*args, **kwargs):
- return MockRequestsResponse(
- 301, dict(location="http://unsafe.org/")
- )
+ return MockRequestsResponse(301, dict(location="http://unsafe.org/"))
+
status, headers, content = m(
- "http://caution.org/", {},
- do_not_access=['unsafe.org'],
- check_for_redirect=['caution.org'],
- do_get=object(), cautious_head_client=mock_redirect
+ "http://caution.org/",
+ {},
+ do_not_access=["unsafe.org"],
+ check_for_redirect=["caution.org"],
+ do_get=object(),
+ cautious_head_client=mock_redirect,
)
assert 417 == status
- assert ("application/vnd.librarysimplified-did-not-make-request" ==
- headers['content-type'])
- assert ("Cautiously decided not to make a GET request to http://caution.org/" ==
- content)
+ assert (
+ "application/vnd.librarysimplified-did-not-make-request"
+ == headers["content-type"]
+ )
+ assert (
+ "Cautiously decided not to make a GET request to http://caution.org/"
+ == content
+ )
# Here, the HEAD request redirects to an allowed site.
h.queue_response(200, content="good content")
+
def mock_redirect(*args, **kwargs):
- return MockRequestsResponse(
- 301, dict(location="http://safe.org/")
- )
+ return MockRequestsResponse(301, dict(location="http://safe.org/"))
+
status, headers, content = m(
- "http://caution.org/", {},
- do_not_access=['unsafe.org'],
- check_for_redirect=['caution.org'],
- do_get=h.do_get, cautious_head_client=mock_redirect
+ "http://caution.org/",
+ {},
+ do_not_access=["unsafe.org"],
+ check_for_redirect=["caution.org"],
+ do_get=h.do_get,
+ cautious_head_client=mock_redirect,
)
assert 200 == status
assert b"good content" == content
@@ -583,31 +611,38 @@ def test_get_would_be_useful(self):
fake_head = object()
# Most sites are safe with no HEAD request necessary.
- assert True == safe("http://www.safe-site.org/book.epub", {},
- head_client=fake_head)
+ assert True == safe(
+ "http://www.safe-site.org/book.epub", {}, head_client=fake_head
+ )
# gutenberg.org is problematic, no HEAD request necessary.
- assert False == safe("http://www.gutenberg.org/book.epub", {},
- head_client=fake_head)
+ assert False == safe(
+ "http://www.gutenberg.org/book.epub", {}, head_client=fake_head
+ )
# do_not_access controls which domains should always be
# considered unsafe.
- assert (
- False == safe(
- "http://www.safe-site.org/book.epub", {},
- do_not_access=['safe-site.org'], head_client=fake_head
- ))
- assert (
- True == safe(
- "http://www.gutenberg.org/book.epub", {},
- do_not_access=['safe-site.org'], head_client=fake_head
- ))
+ assert False == safe(
+ "http://www.safe-site.org/book.epub",
+ {},
+ do_not_access=["safe-site.org"],
+ head_client=fake_head,
+ )
+ assert True == safe(
+ "http://www.gutenberg.org/book.epub",
+ {},
+ do_not_access=["safe-site.org"],
+ head_client=fake_head,
+ )
# Domain match is based on a subdomain match, not a substring
# match.
- assert True == safe("http://www.not-unsafe-site.org/book.epub", {},
- do_not_access=['unsafe-site.org'],
- head_client=fake_head)
+ assert True == safe(
+ "http://www.not-unsafe-site.org/book.epub",
+ {},
+ do_not_access=["unsafe-site.org"],
+ head_client=fake_head,
+ )
# Some domains (unglue.it) are known to make surprise
# redirects to unsafe domains. For these, we must make a HEAD
@@ -615,33 +650,32 @@ def test_get_would_be_useful(self):
def bad_redirect(*args, **kwargs):
return MockRequestsResponse(
- 301, dict(
- location="http://www.gutenberg.org/a-book.html"
- )
+ 301, dict(location="http://www.gutenberg.org/a-book.html")
)
- assert False == safe("http://www.unglue.it/book", {},
- head_client=bad_redirect)
+
+ assert False == safe("http://www.unglue.it/book", {}, head_client=bad_redirect)
def good_redirect(*args, **kwargs):
return MockRequestsResponse(
- 301,
- dict(location="http://www.some-other-site.org/a-book.epub")
+ 301, dict(location="http://www.some-other-site.org/a-book.epub")
)
- assert (
- True ==
- safe("http://www.unglue.it/book", {}, head_client=good_redirect))
+
+ assert True == safe("http://www.unglue.it/book", {}, head_client=good_redirect)
def not_a_redirect(*args, **kwargs):
return MockRequestsResponse(200)
- assert True == safe("http://www.unglue.it/book", {},
- head_client=not_a_redirect)
+
+ assert True == safe("http://www.unglue.it/book", {}, head_client=not_a_redirect)
# The `check_for_redirect` argument controls which domains are
# checked using HEAD requests. Here, we customise it to check
# a site other than unglue.it.
- assert False == safe("http://www.questionable-site.org/book.epub", {},
- check_for_redirect=['questionable-site.org'],
- head_client=bad_redirect)
+ assert False == safe(
+ "http://www.questionable-site.org/book.epub",
+ {},
+ check_for_redirect=["questionable-site.org"],
+ head_client=bad_redirect,
+ )
def test_get_with_url_normalizer(self):
# Verify our ability to store a Resource under a URL other than
@@ -649,6 +683,7 @@ def test_get_with_url_normalizer(self):
class Normalizer(object):
called_with = None
+
def normalize(self, url):
# Strip off a session ID from an outgoing URL.
self.called_with = url
@@ -661,8 +696,7 @@ def normalize(self, url):
original_url = "http://url/?sid=12345"
representation, from_cache = Representation.get(
- self._db, original_url, do_get=h.do_get,
- url_normalizer=normalizer.normalize
+ self._db, original_url, do_get=h.do_get, url_normalizer=normalizer.normalize
)
# The original URL was used to make the actual request.
@@ -684,8 +718,7 @@ def normalize(self, url):
# Replace do_get with a dud object to prove that no second
# request goes out 'over the wire'.
representation2, from_cache = Representation.get(
- self._db, original_url, do_get=object(),
- url_normalizer=normalizer.normalize
+ self._db, original_url, do_get=object(), url_normalizer=normalizer.normalize
)
assert True == from_cache
assert representation2 == representation
@@ -711,8 +744,8 @@ def test_best_thumbnail(self):
t2.set_as_mirrored(self._url)
assert t2 == representation.best_thumbnail
-class TestCoverResource(DatabaseTest):
+class TestCoverResource(DatabaseTest):
def test_set_cover(self):
edition, pool = self._edition(with_license_pool=True)
original = self._url
@@ -720,8 +753,11 @@ def test_set_cover(self):
thumbnail_mirror = self._url
sample_cover_path = self.sample_cover_path("test-book-cover.png")
hyperlink, ignore = pool.add_link(
- Hyperlink.IMAGE, original, edition.data_source, "image/png",
- content=open(sample_cover_path, 'rb').read()
+ Hyperlink.IMAGE,
+ original,
+ edition.data_source,
+ "image/png",
+ content=open(sample_cover_path, "rb").read(),
)
full_rep = hyperlink.resource.representation
full_rep.set_as_mirrored(mirror)
@@ -744,8 +780,11 @@ def test_set_cover_for_very_small_image(self):
mirror = self._url
sample_cover_path = self.sample_cover_path("tiny-image-cover.png")
hyperlink, ignore = pool.add_link(
- Hyperlink.IMAGE, original, edition.data_source, "image/png",
- open(sample_cover_path, 'rb').read()
+ Hyperlink.IMAGE,
+ original,
+ edition.data_source,
+ "image/png",
+ open(sample_cover_path, "rb").read(),
)
full_rep = hyperlink.resource.representation
full_rep.set_as_mirrored(mirror)
@@ -760,15 +799,20 @@ def test_set_cover_for_smallish_image_uses_full_sized_image_as_thumbnail(self):
mirror = self._url
sample_cover_path = self.sample_cover_path("tiny-image-cover.png")
hyperlink, ignore = pool.add_link(
- Hyperlink.IMAGE, original, edition.data_source, "image/png",
- open(sample_cover_path, 'rb').read()
+ Hyperlink.IMAGE,
+ original,
+ edition.data_source,
+ "image/png",
+ open(sample_cover_path, "rb").read(),
)
full_rep = hyperlink.resource.representation
full_rep.set_as_mirrored(mirror)
# For purposes of this test, pretend that the full-sized image is
# larger than a thumbnail, but not terribly large.
- hyperlink.resource.representation.image_height = Edition.MAX_FALLBACK_THUMBNAIL_HEIGHT
+ hyperlink.resource.representation.image_height = (
+ Edition.MAX_FALLBACK_THUMBNAIL_HEIGHT
+ )
edition.set_cover(hyperlink.resource)
assert mirror == edition.cover_full_url
@@ -776,16 +820,19 @@ def test_set_cover_for_smallish_image_uses_full_sized_image_as_thumbnail(self):
# If the full-sized image had been slightly larger, we would have
# decided not to use a thumbnail at all.
- hyperlink.resource.representation.image_height = Edition.MAX_FALLBACK_THUMBNAIL_HEIGHT + 1
+ hyperlink.resource.representation.image_height = (
+ Edition.MAX_FALLBACK_THUMBNAIL_HEIGHT + 1
+ )
edition.cover_thumbnail_url = None
edition.set_cover(hyperlink.resource)
assert None == edition.cover_thumbnail_url
-
def test_attempt_to_scale_non_image_sets_scale_exception(self):
rep, ignore = self._representation(media_type="text/plain", content="foo")
scaled, ignore = rep.scale(300, 600, self._url, "image/png")
- expect = "ValueError: Cannot load non-image representation as image: type text/plain"
+ expect = (
+ "ValueError: Cannot load non-image representation as image: type text/plain"
+ )
assert scaled == rep
assert expect in rep.scale_exception
@@ -822,8 +869,7 @@ def test_success(self):
# With the force argument we can forcibly re-scale an image,
# changing its size.
assert [thumbnail] == cover.thumbnails
- thumbnail2, is_new = cover.scale(
- 400, 700, url, "image/png", force=True)
+ thumbnail2, is_new = cover.scale(400, 700, url, "image/png", force=True)
assert True == is_new
assert [thumbnail2] == cover.thumbnails
assert cover == thumbnail2.thumbnail_of
@@ -897,8 +943,8 @@ def test_best_covers_among(self):
# Here's an abysmally bad cover.
lousy_cover = self.sample_cover_representation("tiny-image-cover.png")
- lousy_cover.image_height=1
- lousy_cover.image_width=10000
+ lousy_cover.image_height = 1
+ lousy_cover.image_width = 10000
link2, ignore = pool.add_link(
Hyperlink.THUMBNAIL_IMAGE, self._url, pool.data_source
)
@@ -919,9 +965,9 @@ def test_best_covers_among(self):
# This cover is at least good enough to pass muster if there
# is no other option.
- assert (
- [resource_with_decent_cover] ==
- Resource.best_covers_among([resource_with_decent_cover]))
+ assert [resource_with_decent_cover] == Resource.best_covers_among(
+ [resource_with_decent_cover]
+ )
# Let's create another cover image with identical
# characteristics.
@@ -940,13 +986,13 @@ def test_best_covers_among(self):
# All else being equal, if one cover is an PNG and the other
# is a JPEG, we prefer the PNG.
- resource_with_decent_cover.representation.media_type = Representation.JPEG_MEDIA_TYPE
+ resource_with_decent_cover.representation.media_type = (
+ Representation.JPEG_MEDIA_TYPE
+ )
assert [resource_with_decent_cover_2] == Resource.best_covers_among(l)
# But if the metadata wrangler said to use the JPEG, we use the JPEG.
- metadata_wrangler = DataSource.lookup(
- self._db, DataSource.METADATA_WRANGLER
- )
+ metadata_wrangler = DataSource.lookup(self._db, DataSource.METADATA_WRANGLER)
resource_with_decent_cover.data_source = metadata_wrangler
# ...the decision becomes easy.
@@ -1010,7 +1056,7 @@ def test_rejection_and_approval(self):
assert -last_voted_quality == cover.voted_quality
assert True == (cover.quality < 0)
- assert last_votes_for_quality+1 == cover.votes_for_quality
+ assert last_votes_for_quality + 1 == cover.votes_for_quality
def test_quality_as_thumbnail_image(self):
@@ -1021,9 +1067,7 @@ def test_quality_as_thumbnail_image(self):
self._db, DataSource.GUTENBERG_COVER_GENERATOR
)
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
- metadata_wrangler = DataSource.lookup(
- self._db, DataSource.METADATA_WRANGLER
- )
+ metadata_wrangler = DataSource.lookup(self._db, DataSource.METADATA_WRANGLER)
# Here's a book with a thumbnail image.
edition, pool = self._edition(with_license_pool=True)
@@ -1085,16 +1129,16 @@ def f(width, height):
# An image that is the perfect aspect ratio, but too large,
# has no penalty.
- assert 1 == f(ideal_width*2, ideal_height*2)
+ assert 1 == f(ideal_width * 2, ideal_height * 2)
# An image that is the perfect aspect ratio, but is too small,
# is penalised.
- assert 1/4.0 == f(ideal_width*0.5, ideal_height*0.5)
- assert 1/16.0 == f(ideal_width*0.25, ideal_height*0.25)
+ assert 1 / 4.0 == f(ideal_width * 0.5, ideal_height * 0.5)
+ assert 1 / 16.0 == f(ideal_width * 0.25, ideal_height * 0.25)
# An image that deviates from the perfect aspect ratio is
# penalized in proportion.
- assert 1/2.0 == f(ideal_width*2, ideal_height)
- assert 1/2.0 == f(ideal_width, ideal_height*2)
- assert 1/4.0 == f(ideal_width*4, ideal_height)
- assert 1/4.0 == f(ideal_width, ideal_height*4)
+ assert 1 / 2.0 == f(ideal_width * 2, ideal_height)
+ assert 1 / 2.0 == f(ideal_width, ideal_height * 2)
+ assert 1 / 4.0 == f(ideal_width * 4, ideal_height)
+ assert 1 / 4.0 == f(ideal_width, ideal_height * 4)
diff --git a/tests/models/test_work.py b/tests/models/test_work.py
index e342b7fb5..9d04ce08a 100644
--- a/tests/models/test_work.py
+++ b/tests/models/test_work.py
@@ -1,45 +1,27 @@
# encoding: utf-8
-import pytest
import datetime
-from mock import MagicMock
import os
-from psycopg2.extras import NumericRange
import random
+import pytest
+from mock import MagicMock
+from psycopg2.extras import NumericRange
+
+from ...classifier import Classifier, Fantasy, Romance, Science_Fiction
from ...external_search import MockExternalSearchIndex
-from ...testing import DatabaseTest
-from ...classifier import (
- Classifier,
- Fantasy,
- Romance,
- Science_Fiction,
-)
-from ...model import (
- get_one_or_create,
- tuple_to_numericrange,
-)
-from ...model.coverage import WorkCoverageRecord
-from ...model.classification import (
- Genre,
- Subject,
-)
+from ...model import get_one_or_create, tuple_to_numericrange
+from ...model.classification import Genre, Subject
from ...model.complaint import Complaint
from ...model.contributor import Contributor
+from ...model.coverage import WorkCoverageRecord
from ...model.datasource import DataSource
from ...model.edition import Edition
from ...model.identifier import Identifier
from ...model.licensing import LicensePool
-from ...model.resource import (
- Hyperlink,
- Representation,
- Resource,
-)
-from ...model.work import (
- Work,
- WorkGenre,
-)
-from ...util.datetime_helpers import from_timestamp
-from ...util.datetime_helpers import datetime_utc, utc_now
+from ...model.resource import Hyperlink, Representation, Resource
+from ...model.work import Work, WorkGenre
+from ...testing import DatabaseTest
+from ...util.datetime_helpers import datetime_utc, from_timestamp, utc_now
class TestWork(DatabaseTest):
@@ -48,28 +30,22 @@ def test_complaints(self):
[lp1] = work.license_pools
lp2 = self._licensepool(
- edition=work.presentation_edition,
- data_source_name=DataSource.OVERDRIVE
+ edition=work.presentation_edition, data_source_name=DataSource.OVERDRIVE
)
lp2.work = work
complaint_type = random.choice(list(Complaint.VALID_TYPES))
- complaint1, ignore = Complaint.register(
- lp1, complaint_type, "blah", "blah"
- )
- complaint2, ignore = Complaint.register(
- lp2, complaint_type, "blah", "blah"
- )
+ complaint1, ignore = Complaint.register(lp1, complaint_type, "blah", "blah")
+ complaint2, ignore = Complaint.register(lp2, complaint_type, "blah", "blah")
# Create a complaint with no association with the work.
_edition, lp3 = self._edition(with_license_pool=True)
- complaint3, ignore = Complaint.register(
- lp3, complaint_type, "blah", "blah"
- )
+ complaint3, ignore = Complaint.register(lp3, complaint_type, "blah", "blah")
# Only the first two complaints show up in work.complaints.
- assert (sorted([complaint1.id, complaint2.id]) ==
- sorted([x.id for x in work.complaints]))
+ assert sorted([complaint1.id, complaint2.id]) == sorted(
+ [x.id for x in work.complaints]
+ )
def test_all_identifier_ids(self):
work = self._work(with_license_pool=True)
@@ -86,9 +62,7 @@ def test_all_identifier_ids(self):
all_identifier_ids = work.all_identifier_ids()
assert 3 == len(all_identifier_ids)
- expect_all_ids = set(
- [lp.identifier.id, lp2.identifier.id, identifier.id]
- )
+ expect_all_ids = set([lp.identifier.id, lp2.identifier.id, identifier.id])
assert expect_all_ids == all_identifier_ids
@@ -96,7 +70,9 @@ def test_from_identifiers(self):
# Prep a work to be identified and a work to be ignored.
work = self._work(with_license_pool=True, with_open_access_download=True)
lp = work.license_pools[0]
- ignored_work = self._work(with_license_pool=True, with_open_access_download=True)
+ ignored_work = self._work(
+ with_license_pool=True, with_open_access_download=True
+ )
# No identifiers returns None.
result = Work.from_identifiers(self._db, [])
@@ -133,8 +109,12 @@ def test_from_identifiers(self):
assert [work] == result
# It accepts a base query.
- qu = self._db.query(Work).join(LicensePool).join(Identifier).\
- filter(LicensePool.suppressed)
+ qu = (
+ self._db.query(Work)
+ .join(LicensePool)
+ .join(Identifier)
+ .filter(LicensePool.suppressed)
+ )
identifiers = [lp.identifier]
result = Work.from_identifiers(self._db, identifiers, base_query=qu).all()
# Because the work's license_pool isn't suppressed, it isn't returned.
@@ -157,14 +137,24 @@ def test_calculate_presentation(self):
[bob], ignore = Contributor.lookup(self._db, "Bitshifter, Bob")
bob.family_name, bob.display_name = bob.default_names()
- edition1, pool1 = self._edition(gitenberg_source, Identifier.GUTENBERG_ID,
- with_license_pool=True, with_open_access_download=True, authors=[])
+ edition1, pool1 = self._edition(
+ gitenberg_source,
+ Identifier.GUTENBERG_ID,
+ with_license_pool=True,
+ with_open_access_download=True,
+ authors=[],
+ )
edition1.title = "The 1st Title"
edition1.subtitle = "The 1st Subtitle"
edition1.add_contributor(bob, Contributor.AUTHOR_ROLE)
- edition2, pool2 = self._edition(gitenberg_source, Identifier.GUTENBERG_ID,
- with_license_pool=True, with_open_access_download=True, authors=[])
+ edition2, pool2 = self._edition(
+ gitenberg_source,
+ Identifier.GUTENBERG_ID,
+ with_license_pool=True,
+ with_open_access_download=True,
+ authors=[],
+ )
edition2.title = "The 2nd Title"
edition2.subtitle = "The 2nd Subtitle"
edition2.add_contributor(bob, Contributor.AUTHOR_ROLE)
@@ -172,8 +162,13 @@ def test_calculate_presentation(self):
alice.family_name, alice.display_name = alice.default_names()
edition2.add_contributor(alice, Contributor.AUTHOR_ROLE)
- edition3, pool3 = self._edition(gutenberg_source, Identifier.GUTENBERG_ID,
- with_license_pool=True, with_open_access_download=True, authors=[])
+ edition3, pool3 = self._edition(
+ gutenberg_source,
+ Identifier.GUTENBERG_ID,
+ with_license_pool=True,
+ with_open_access_download=True,
+ authors=[],
+ )
edition3.title = "The 2nd Title"
edition3.subtitle = "The 2nd Subtitle"
edition3.add_contributor(bob, Contributor.AUTHOR_ROLE)
@@ -194,12 +189,12 @@ def test_calculate_presentation(self):
# This summary is associated with one of the work's
# LicensePools, but it comes from a less reliable source, so
# it won't be chosen.
- less_reliable_summary_source = DataSource.lookup(
- self._db, DataSource.OCLC
- )
+ less_reliable_summary_source = DataSource.lookup(self._db, DataSource.OCLC)
pool2.identifier.add_link(
- Hyperlink.DESCRIPTION, None, less_reliable_summary_source,
- content="less reliable summary"
+ Hyperlink.DESCRIPTION,
+ None,
+ less_reliable_summary_source,
+ content="less reliable summary",
)
# This summary looks really nice, and it's associated with the
@@ -211,8 +206,10 @@ def test_calculate_presentation(self):
pool3.data_source, related_identifier, strength=1
)
related_identifier.add_link(
- Hyperlink.DESCRIPTION, None, pool3.data_source,
- content="This is an indirect summary. It's much longer, and looks more 'real', so you'd think it would be prefered, but it won't be."
+ Hyperlink.DESCRIPTION,
+ None,
+ pool3.data_source,
+ content="This is an indirect summary. It's much longer, and looks more 'real', so you'd think it would be prefered, but it won't be.",
)
work = self._slow_work(presentation_edition=edition2)
@@ -236,9 +233,11 @@ def test_calculate_presentation(self):
# it adds choose-edition as a primary edition is set. The
# search index CoverageRecord is a marker for work that must
# be done in the future, and is not tested here.
- [choose_edition, generate_opds, update_search_index] = sorted(work.coverage_records, key=lambda x: x.operation)
- assert (generate_opds.operation == WorkCoverageRecord.GENERATE_OPDS_OPERATION)
- assert (choose_edition.operation == WorkCoverageRecord.CHOOSE_EDITION_OPERATION)
+ [choose_edition, generate_opds, update_search_index] = sorted(
+ work.coverage_records, key=lambda x: x.operation
+ )
+ assert generate_opds.operation == WorkCoverageRecord.GENERATE_OPDS_OPERATION
+ assert choose_edition.operation == WorkCoverageRecord.CHOOSE_EDITION_OPERATION
# pools aren't yet aware of each other
assert pool1.superceded == False
@@ -297,15 +296,17 @@ def test_calculate_presentation(self):
wcr = WorkCoverageRecord
success = wcr.SUCCESS
- expect = set([
- (wcr.CHOOSE_EDITION_OPERATION, success),
- (wcr.CLASSIFY_OPERATION, success),
- (wcr.SUMMARY_OPERATION, success),
- (wcr.QUALITY_OPERATION, success),
- (wcr.GENERATE_OPDS_OPERATION, success),
- (wcr.GENERATE_MARC_OPERATION, success),
- (wcr.UPDATE_SEARCH_INDEX_OPERATION, wcr.REGISTERED),
- ])
+ expect = set(
+ [
+ (wcr.CHOOSE_EDITION_OPERATION, success),
+ (wcr.CLASSIFY_OPERATION, success),
+ (wcr.SUMMARY_OPERATION, success),
+ (wcr.QUALITY_OPERATION, success),
+ (wcr.GENERATE_OPDS_OPERATION, success),
+ (wcr.GENERATE_MARC_OPERATION, success),
+ (wcr.UPDATE_SEARCH_INDEX_OPERATION, wcr.REGISTERED),
+ ]
+ )
assert expect == set([(x.operation, x.status) for x in records])
# Now mark the pool with the presentation edition as suppressed.
@@ -343,8 +344,11 @@ def test_calculate_presentation(self):
# except when it has no contributors, and they do.
pool2.suppressed = False
- staff_edition = self._edition(data_source_name=DataSource.LIBRARY_STAFF,
- with_license_pool=False, authors=[])
+ staff_edition = self._edition(
+ data_source_name=DataSource.LIBRARY_STAFF,
+ with_license_pool=False,
+ authors=[],
+ )
staff_edition.title = "The Staff Title"
staff_edition.primary_identifier = pool2.identifier
# set edition's authorship to "nope", and make sure the lower-priority
@@ -378,8 +382,7 @@ def test_calculate_presentation_with_no_presentation_edition(self):
# Work was done to choose the presentation edition, but since no
# presentation edition was found, no other work was done.
[choose_edition] = work.coverage_records
- assert (WorkCoverageRecord.CHOOSE_EDITION_OPERATION ==
- choose_edition.operation)
+ assert WorkCoverageRecord.CHOOSE_EDITION_OPERATION == choose_edition.operation
def test_calculate_presentation_sets_presentation_ready_based_on_content(self):
@@ -398,7 +401,9 @@ def test_calculate_presentation_sets_presentation_ready_based_on_content(self):
work.calculate_presentation()
assert True == work.presentation_ready
- def test_calculate_presentation_uses_default_audience_set_as_collection_setting(self):
+ def test_calculate_presentation_uses_default_audience_set_as_collection_setting(
+ self,
+ ):
default_audience = Classifier.AUDIENCE_ADULT
collection = self._default_collection
collection.default_audience = default_audience
@@ -407,7 +412,7 @@ def test_calculate_presentation_uses_default_audience_set_as_collection_setting(
Identifier.GUTENBERG_ID,
collection=collection,
with_license_pool=True,
- with_open_access_download=True
+ with_open_access_download=True,
)
work = self._slow_work(presentation_edition=edition)
work.last_update_time = None
@@ -443,20 +448,13 @@ def set_summary(self, summary):
i1 = self._identifier()
l1, ignore = i1.add_link(
- Hyperlink.DESCRIPTION, None, source1,
- content="ok summary"
+ Hyperlink.DESCRIPTION, None, source1, content="ok summary"
)
good_summary = "This summary is great! It's more than one sentence long and features some noun phrases."
- i1.add_link(
- Hyperlink.DESCRIPTION, None, source2,
- content=good_summary
- )
+ i1.add_link(Hyperlink.DESCRIPTION, None, source2, content=good_summary)
i2 = self._identifier()
- i2.add_link(
- Hyperlink.DESCRIPTION, None, source2,
- content="not too bad"
- )
+ i2.add_link(Hyperlink.DESCRIPTION, None, source2, content="not too bad")
# Now we can test out the rules for choosing summaries.
@@ -487,9 +485,7 @@ def set_summary(self, summary):
# LIBRARY_STAFF is always considered a good source of
# descriptions.
- l1.data_source = DataSource.lookup(
- self._db, DataSource.LIBRARY_STAFF
- )
+ l1.data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
m([i1.id, i2.id], [], [])
assert l1.resource.representation.content.decode("utf-8") == w.summary_text
@@ -500,9 +496,11 @@ def test_set_presentation_ready_based_on_content(self):
search = MockExternalSearchIndex()
# This is how the work will be represented in the dummy search
# index.
- index_key = (search.works_index,
- MockExternalSearchIndex.work_document_type,
- work.id)
+ index_key = (
+ search.works_index,
+ MockExternalSearchIndex.work_document_type,
+ work.id,
+ )
presentation = work.presentation_edition
work.set_presentation_ready_based_on_content(search_index_client=search)
@@ -517,10 +515,12 @@ def assert_record():
# Verify the search index WorkCoverageRecord for this work
# is in the REGISTERED state.
[record] = [
- x for x in work.coverage_records
- if x.operation==WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
+ x
+ for x in work.coverage_records
+ if x.operation == WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
]
assert WorkCoverageRecord.REGISTERED == record.status
+
assert_record()
# This work is presentation ready because it has a title.
@@ -555,7 +555,7 @@ def assert_record():
work.set_presentation_ready_based_on_content(search_index_client=search)
assert False == work.presentation_ready
- presentation.language = 'eng'
+ presentation.language = "eng"
work.set_presentation_ready_based_on_content(search_index_client=search)
assert True == work.presentation_ready
@@ -570,16 +570,16 @@ def test_assign_genres_from_weights(self):
work = self._work()
# This work was once classified under Fantasy and Romance.
- work.assign_genres_from_weights({Romance : 1000, Fantasy : 1000})
+ work.assign_genres_from_weights({Romance: 1000, Fantasy: 1000})
self._db.commit()
before = sorted((x.genre.name, x.affinity) for x in work.work_genres)
- assert [('Fantasy', 0.5), ('Romance', 0.5)] == before
+ assert [("Fantasy", 0.5), ("Romance", 0.5)] == before
# But now it's classified under Science Fiction and Romance.
- work.assign_genres_from_weights({Romance : 100, Science_Fiction : 300})
+ work.assign_genres_from_weights({Romance: 100, Science_Fiction: 300})
self._db.commit()
after = sorted((x.genre.name, x.affinity) for x in work.work_genres)
- assert [('Romance', 0.25), ('Science Fiction', 0.75)] == after
+ assert [("Romance", 0.25), ("Science Fiction", 0.75)] == after
def test_classifications_with_genre(self):
work = self._work(with_open_access_download=True)
@@ -593,14 +593,14 @@ def test_classifications_with_genre(self):
subject3.genre = None
source = DataSource.lookup(self._db, DataSource.AXIS_360)
classification1 = self._classification(
- identifier=identifier, subject=subject1,
- data_source=source, weight=1)
+ identifier=identifier, subject=subject1, data_source=source, weight=1
+ )
classification2 = self._classification(
- identifier=identifier, subject=subject2,
- data_source=source, weight=2)
+ identifier=identifier, subject=subject2, data_source=source, weight=2
+ )
classification3 = self._classification(
- identifier=identifier, subject=subject3,
- data_source=source, weight=2)
+ identifier=identifier, subject=subject3, data_source=source, weight=2
+ )
results = work.classifications_with_genre().all()
@@ -609,9 +609,7 @@ def test_classifications_with_genre(self):
def test_mark_licensepools_as_superceded(self):
# A commercial LP that somehow got superceded will be
# un-superceded.
- commercial = self._licensepool(
- None, data_source_name=DataSource.OVERDRIVE
- )
+ commercial = self._licensepool(None, data_source_name=DataSource.OVERDRIVE)
work, is_new = commercial.calculate_work()
commercial.superceded = True
work.mark_licensepools_as_superceded()
@@ -620,8 +618,10 @@ def test_mark_licensepools_as_superceded(self):
# An open-access LP that was superceded will be un-superceded if
# chosen.
gutenberg = self._licensepool(
- None, data_source_name=DataSource.GUTENBERG,
- open_access=True, with_open_access_download=True
+ None,
+ data_source_name=DataSource.GUTENBERG,
+ open_access=True,
+ with_open_access_download=True,
)
work, is_new = gutenberg.calculate_work()
gutenberg.superceded = True
@@ -632,8 +632,10 @@ def test_mark_licensepools_as_superceded(self):
# source will be un-superceded, and the one from the
# lower-quality data source will be superceded.
standard_ebooks = self._licensepool(
- None, data_source_name=DataSource.STANDARD_EBOOKS,
- open_access=True, with_open_access_download=True
+ None,
+ data_source_name=DataSource.STANDARD_EBOOKS,
+ open_access=True,
+ with_open_access_download=True,
)
work.license_pools.append(standard_ebooks)
gutenberg.superceded = False
@@ -643,16 +645,25 @@ def test_mark_licensepools_as_superceded(self):
assert False == standard_ebooks.superceded
# Of three open-access pools, 1 and only 1 will be chosen as non-superceded.
- gitenberg1 = self._licensepool(edition=None, open_access=True,
- data_source_name=DataSource.PROJECT_GITENBERG, with_open_access_download=True
+ gitenberg1 = self._licensepool(
+ edition=None,
+ open_access=True,
+ data_source_name=DataSource.PROJECT_GITENBERG,
+ with_open_access_download=True,
)
- gitenberg2 = self._licensepool(edition=None, open_access=True,
- data_source_name=DataSource.PROJECT_GITENBERG, with_open_access_download=True
+ gitenberg2 = self._licensepool(
+ edition=None,
+ open_access=True,
+ data_source_name=DataSource.PROJECT_GITENBERG,
+ with_open_access_download=True,
)
- gutenberg1 = self._licensepool(edition=None, open_access=True,
- data_source_name=DataSource.GUTENBERG, with_open_access_download=True
+ gutenberg1 = self._licensepool(
+ edition=None,
+ open_access=True,
+ data_source_name=DataSource.GUTENBERG,
+ with_open_access_download=True,
)
work_multipool = self._work(presentation_edition=None)
@@ -674,7 +685,7 @@ def test_mark_licensepools_as_superceded(self):
chosen_count = 0
for chosen_pool in gutenberg1, gitenberg1, gitenberg2:
if chosen_pool.superceded is False:
- chosen_count += 1;
+ chosen_count += 1
assert chosen_count == 1
# throw wrench in
@@ -696,11 +707,20 @@ def test_mark_licensepools_as_superceded(self):
assert False == only_pool.superceded
def test_work_remains_viable_on_pools_suppressed(self):
- """ If a work has all of its pools suppressed, the work's author, title,
+ """If a work has all of its pools suppressed, the work's author, title,
and subtitle still have the last best-known info in them.
"""
- (work, pool_std_ebooks, pool_git, pool_gut,
- edition_std_ebooks, edition_git, edition_gut, alice, bob) = self._sample_ecosystem()
+ (
+ work,
+ pool_std_ebooks,
+ pool_git,
+ pool_gut,
+ edition_std_ebooks,
+ edition_git,
+ edition_gut,
+ alice,
+ bob,
+ ) = self._sample_ecosystem()
# make sure the setup is what we expect
assert pool_std_ebooks.suppressed == False
@@ -750,12 +770,21 @@ def test_work_remains_viable_on_pools_suppressed(self):
assert "Adder, Alice" == work.sort_author
def test_work_updates_info_on_pool_suppressed(self):
- """ If the provider of the work's presentation edition gets suppressed,
+ """If the provider of the work's presentation edition gets suppressed,
the work will choose another child license pool's presentation edition as
its presentation edition.
"""
- (work, pool_std_ebooks, pool_git, pool_gut,
- edition_std_ebooks, edition_git, edition_gut, alice, bob) = self._sample_ecosystem()
+ (
+ work,
+ pool_std_ebooks,
+ pool_git,
+ pool_gut,
+ edition_std_ebooks,
+ edition_git,
+ edition_gut,
+ alice,
+ bob,
+ ) = self._sample_ecosystem()
# make sure the setup is what we expect
assert pool_std_ebooks.suppressed == False
@@ -808,16 +837,22 @@ def test_different_language_means_different_work(self):
same, so the books have the same permanent work ID, but since
they are in different languages they become separate works.
"""
- title = 'Siddhartha'
- author = ['Herman Hesse']
+ title = "Siddhartha"
+ author = ["Herman Hesse"]
edition1, lp1 = self._edition(
- title=title, authors=author, language='eng', with_license_pool=True,
- with_open_access_download=True
+ title=title,
+ authors=author,
+ language="eng",
+ with_license_pool=True,
+ with_open_access_download=True,
)
w1 = lp1.calculate_work()
edition2, lp2 = self._edition(
- title=title, authors=author, language='ger', with_license_pool=True,
- with_open_access_download=True
+ title=title,
+ authors=author,
+ language="ger",
+ with_license_pool=True,
+ with_open_access_download=True,
)
w2 = lp2.calculate_work()
for l in (lp1, lp2):
@@ -830,20 +865,22 @@ def test_reject_covers(self):
# Create a cover and thumbnail for the edition.
current_folder = os.path.split(__file__)[0]
base_path = os.path.dirname(current_folder)
- sample_cover_path = base_path + '/files/covers/test-book-cover.png'
- cover_href = 'http://cover.png'
+ sample_cover_path = base_path + "/files/covers/test-book-cover.png"
+ cover_href = "http://cover.png"
cover_link = lp.add_link(
- Hyperlink.IMAGE, cover_href, lp.data_source,
+ Hyperlink.IMAGE,
+ cover_href,
+ lp.data_source,
media_type=Representation.PNG_MEDIA_TYPE,
- content=open(sample_cover_path, 'rb').read()
+ content=open(sample_cover_path, "rb").read(),
)[0]
- thumbnail_href = 'http://thumbnail.png'
+ thumbnail_href = "http://thumbnail.png"
thumbnail_rep = self._representation(
url=thumbnail_href,
media_type=Representation.PNG_MEDIA_TYPE,
- content=open(sample_cover_path, 'rb').read(),
- mirrored=True
+ content=open(sample_cover_path, "rb").read(),
+ mirrored=True,
)[0]
cover_rep = cover_link.resource.representation
@@ -914,7 +951,7 @@ def reset_cover():
assert has_no_cover(other_work)
def test_missing_coverage_from(self):
- operation = 'the_operation'
+ operation = "the_operation"
# Here's a work with a coverage record.
work = self._work(with_license_pool=True)
@@ -932,10 +969,9 @@ def test_missing_coverage_from(self):
# certain time, it might need coverage again.
cutoff = record.timestamp + datetime.timedelta(seconds=1)
- assert (
- [work] == Work.missing_coverage_from(
- self._db, operation, count_as_missing_before=cutoff
- ).all())
+ assert [work] == Work.missing_coverage_from(
+ self._db, operation, count_as_missing_before=cutoff
+ ).all()
def test_top_genre(self):
work = self._work()
@@ -961,7 +997,9 @@ def test_top_genre(self):
def test_to_search_document(self):
# Set up an edition and work.
- edition, pool1 = self._edition(authors=[self._str, self._str], with_license_pool=True)
+ edition, pool1 = self._edition(
+ authors=[self._str, self._str], with_license_pool=True
+ )
work = self._work(presentation_edition=edition)
# Create a second Collection that has a different LicensePool
@@ -980,12 +1018,20 @@ def test_to_search_document(self):
collection3 = self._collection()
# These are the edition's authors.
- [contributor1] = [c.contributor for c in edition.contributions if c.role == Contributor.PRIMARY_AUTHOR_ROLE]
+ [contributor1] = [
+ c.contributor
+ for c in edition.contributions
+ if c.role == Contributor.PRIMARY_AUTHOR_ROLE
+ ]
contributor1.display_name = self._str
contributor1.family_name = self._str
contributor1.viaf = self._str
contributor1.lc = self._str
- [contributor2] = [c.contributor for c in edition.contributions if c.role == Contributor.AUTHOR_ROLE]
+ [contributor2] = [
+ c.contributor
+ for c in edition.contributions
+ if c.role == Contributor.AUTHOR_ROLE
+ ]
data_source = DataSource.lookup(self._db, DataSource.THREEM)
@@ -1006,13 +1052,19 @@ def test_to_search_document(self):
# Add some classifications.
# This classification has no subject name, so the search document will use the subject identifier.
- edition.primary_identifier.classify(data_source, Subject.BISAC, "FICTION/Science Fiction/Time Travel", None, 6)
+ edition.primary_identifier.classify(
+ data_source, Subject.BISAC, "FICTION/Science Fiction/Time Travel", None, 6
+ )
# This one has the same subject type and identifier, so their weights will be combined.
- identifier1.classify(data_source, Subject.BISAC, "FICTION/Science Fiction/Time Travel", None, 1)
+ identifier1.classify(
+ data_source, Subject.BISAC, "FICTION/Science Fiction/Time Travel", None, 1
+ )
# Here's another classification with a different subject type.
- edition.primary_identifier.classify(data_source, Subject.OVERDRIVE, "Romance", None, 2)
+ edition.primary_identifier.classify(
+ data_source, Subject.OVERDRIVE, "Romance", None, 2
+ )
# This classification has a subject name, so the search document will use that instead of the identifier.
identifier1.classify(data_source, Subject.FAST, self._str, "Sea Stories", 7)
@@ -1035,14 +1087,22 @@ def test_to_search_document(self):
appeared_1 = datetime_utc(2010, 1, 1)
appeared_2 = datetime_utc(2011, 1, 1)
l1, ignore = self._customlist(num_entries=0)
- l1.add_entry(work, featured=False, update_external_index=False,
- first_appearance=appeared_1)
+ l1.add_entry(
+ work,
+ featured=False,
+ update_external_index=False,
+ first_appearance=appeared_1,
+ )
l2, ignore = self._customlist(num_entries=0)
- l2.add_entry(work, featured=True, update_external_index=False,
- first_appearance=appeared_2)
+ l2.add_entry(
+ work,
+ featured=True,
+ update_external_index=False,
+ first_appearance=appeared_2,
+ )
# Add the other fields used in the search document.
- work.target_age = NumericRange(7, 8, '[]')
+ work.target_age = NumericRange(7, 8, "[]")
edition.subtitle = self._str
edition.series = self._str
edition.series_position = 99
@@ -1070,73 +1130,70 @@ def assert_time_match(python, postgres):
:param python: A datetime from the Python part of this test.
:param postgres: A float from the Postgres part.
"""
- expect = (
- python - from_timestamp(0)
- ).total_seconds()
+ expect = (python - from_timestamp(0)).total_seconds()
assert int(expect) == int(postgres)
search_doc = work.to_search_document()
- assert work.id == search_doc['_id']
- assert work.id == search_doc['work_id']
- assert work.title == search_doc['title']
- assert edition.subtitle == search_doc['subtitle']
- assert edition.series == search_doc['series']
- assert edition.series_position == search_doc['series_position']
- assert edition.language == search_doc['language']
- assert work.sort_title == search_doc['sort_title']
- assert work.author == search_doc['author']
- assert work.sort_author == search_doc['sort_author']
- assert edition.publisher == search_doc['publisher']
- assert edition.imprint == search_doc['imprint']
- assert edition.permanent_work_id == search_doc['permanent_work_id']
- assert "Nonfiction" == search_doc['fiction']
- assert "YoungAdult" == search_doc['audience']
- assert work.summary_text == search_doc['summary']
- assert work.quality == search_doc['quality']
- assert work.rating == search_doc['rating']
- assert work.popularity == search_doc['popularity']
- assert work.presentation_ready == search_doc['presentation_ready']
- assert_time_match(work.last_update_time, search_doc['last_update_time'])
- assert dict(lower=7, upper=8) == search_doc['target_age']
+ assert work.id == search_doc["_id"]
+ assert work.id == search_doc["work_id"]
+ assert work.title == search_doc["title"]
+ assert edition.subtitle == search_doc["subtitle"]
+ assert edition.series == search_doc["series"]
+ assert edition.series_position == search_doc["series_position"]
+ assert edition.language == search_doc["language"]
+ assert work.sort_title == search_doc["sort_title"]
+ assert work.author == search_doc["author"]
+ assert work.sort_author == search_doc["sort_author"]
+ assert edition.publisher == search_doc["publisher"]
+ assert edition.imprint == search_doc["imprint"]
+ assert edition.permanent_work_id == search_doc["permanent_work_id"]
+ assert "Nonfiction" == search_doc["fiction"]
+ assert "YoungAdult" == search_doc["audience"]
+ assert work.summary_text == search_doc["summary"]
+ assert work.quality == search_doc["quality"]
+ assert work.rating == search_doc["rating"]
+ assert work.popularity == search_doc["popularity"]
+ assert work.presentation_ready == search_doc["presentation_ready"]
+ assert_time_match(work.last_update_time, search_doc["last_update_time"])
+ assert dict(lower=7, upper=8) == search_doc["target_age"]
# Each LicensePool for the Work is listed in
# the 'licensepools' section.
- licensepools = search_doc['licensepools']
+ licensepools = search_doc["licensepools"]
assert 2 == len(licensepools)
- assert (set([x.id for x in work.license_pools]) ==
- set([x['licensepool_id'] for x in licensepools]))
+ assert set([x.id for x in work.license_pools]) == set(
+ [x["licensepool_id"] for x in licensepools]
+ )
# Each item in the 'licensepools' section has a variety of useful information
# about the corresponding LicensePool.
for pool in work.license_pools:
- [match] = [x for x in licensepools if x['licensepool_id'] == pool.id]
- assert pool.open_access == match['open_access']
- assert pool.collection_id == match['collection_id']
- assert pool.suppressed == match['suppressed']
- assert pool.data_source_id == match['data_source_id']
+ [match] = [x for x in licensepools if x["licensepool_id"] == pool.id]
+ assert pool.open_access == match["open_access"]
+ assert pool.collection_id == match["collection_id"]
+ assert pool.suppressed == match["suppressed"]
+ assert pool.data_source_id == match["data_source_id"]
- assert isinstance(match['available'], bool)
- assert (pool.licenses_available > 0) == match['available']
- assert isinstance(match['licensed'], bool)
- assert (pool.licenses_owned > 0) == match['licensed']
+ assert isinstance(match["available"], bool)
+ assert (pool.licenses_available > 0) == match["available"]
+ assert isinstance(match["licensed"], bool)
+ assert (pool.licenses_owned > 0) == match["licensed"]
# The work quality is stored in the main document, but
# it's also stored in the license pool subdocument so that
# we can apply a nested filter that includes quality +
# information from the subdocument.
- assert work.quality == match['quality']
+ assert work.quality == match["quality"]
- assert_time_match(
- pool.availability_time, match['availability_time']
- )
+ assert_time_match(pool.availability_time, match["availability_time"])
# The medium of the work's presentation edition is stored
# in the main document, but it's also stored in the
# license poolsubdocument, so that we can filter out
# license pools that represent audiobooks from unsupported
# sources.
- assert edition.medium == search_doc['medium']
- assert edition.medium == match['medium']
+ assert edition.medium == search_doc["medium"]
+ assert edition.medium == match["medium"]
# Each identifier that could, with high confidence, be
# associated with the work, is in the 'identifiers' section.
@@ -1148,71 +1205,88 @@ def assert_time_match(python, postgres):
# identifiers not tied to a LicensePool.
expect = [
dict(identifier=identifier1.identifier, type=identifier1.type),
- dict(identifier=pool1.identifier.identifier,
- type=pool1.identifier.type),
+ dict(identifier=pool1.identifier.identifier, type=pool1.identifier.type),
]
+
def s(x):
# Sort an identifier dictionary by its identifier value.
- return sorted(x, key = lambda b: b['identifier'])
- assert s(expect) == s(search_doc['identifiers'])
+ return sorted(x, key=lambda b: b["identifier"])
+
+ assert s(expect) == s(search_doc["identifiers"])
# Each custom list entry for the work is in the 'customlists'
# section.
not_featured, featured = sorted(
- search_doc['customlists'], key = lambda x: x['featured']
+ search_doc["customlists"], key=lambda x: x["featured"]
)
- assert_time_match(appeared_1, not_featured.pop('first_appearance'))
+ assert_time_match(appeared_1, not_featured.pop("first_appearance"))
assert dict(featured=False, list_id=l1.id) == not_featured
- assert_time_match(appeared_2, featured.pop('first_appearance'))
+ assert_time_match(appeared_2, featured.pop("first_appearance"))
assert dict(featured=True, list_id=l2.id) == featured
- contributors = search_doc['contributors']
+ contributors = search_doc["contributors"]
assert 2 == len(contributors)
- [contributor1_doc] = [c for c in contributors if c['sort_name'] == contributor1.sort_name]
- [contributor2_doc] = [c for c in contributors if c['sort_name'] == contributor2.sort_name]
+ [contributor1_doc] = [
+ c for c in contributors if c["sort_name"] == contributor1.sort_name
+ ]
+ [contributor2_doc] = [
+ c for c in contributors if c["sort_name"] == contributor2.sort_name
+ ]
- assert contributor1.display_name == contributor1_doc['display_name']
- assert None == contributor2_doc['display_name']
+ assert contributor1.display_name == contributor1_doc["display_name"]
+ assert None == contributor2_doc["display_name"]
- assert contributor1.family_name == contributor1_doc['family_name']
- assert None == contributor2_doc['family_name']
+ assert contributor1.family_name == contributor1_doc["family_name"]
+ assert None == contributor2_doc["family_name"]
- assert contributor1.viaf == contributor1_doc['viaf']
- assert None == contributor2_doc['viaf']
+ assert contributor1.viaf == contributor1_doc["viaf"]
+ assert None == contributor2_doc["viaf"]
- assert contributor1.lc == contributor1_doc['lc']
- assert None == contributor2_doc['lc']
+ assert contributor1.lc == contributor1_doc["lc"]
+ assert None == contributor2_doc["lc"]
- assert Contributor.PRIMARY_AUTHOR_ROLE == contributor1_doc['role']
- assert Contributor.AUTHOR_ROLE == contributor2_doc['role']
+ assert Contributor.PRIMARY_AUTHOR_ROLE == contributor1_doc["role"]
+ assert Contributor.AUTHOR_ROLE == contributor2_doc["role"]
- classifications = search_doc['classifications']
+ classifications = search_doc["classifications"]
assert 3 == len(classifications)
- [classification1_doc] = [c for c in classifications if c['scheme'] == Subject.uri_lookup[Subject.BISAC]]
- [classification2_doc] = [c for c in classifications if c['scheme'] == Subject.uri_lookup[Subject.OVERDRIVE]]
- [classification3_doc] = [c for c in classifications if c['scheme'] == Subject.uri_lookup[Subject.FAST]]
- assert "FICTION Science Fiction Time Travel" == classification1_doc['term']
- assert float(6 + 1)/(6 + 1 + 2 + 7) == classification1_doc['weight']
- assert "Romance" == classification2_doc['term']
- assert float(2)/(6 + 1 + 2 + 7) == classification2_doc['weight']
- assert "Sea Stories" == classification3_doc['term']
- assert float(7)/(6 + 1 + 2 + 7) == classification3_doc['weight']
-
- genres = search_doc['genres']
+ [classification1_doc] = [
+ c
+ for c in classifications
+ if c["scheme"] == Subject.uri_lookup[Subject.BISAC]
+ ]
+ [classification2_doc] = [
+ c
+ for c in classifications
+ if c["scheme"] == Subject.uri_lookup[Subject.OVERDRIVE]
+ ]
+ [classification3_doc] = [
+ c
+ for c in classifications
+ if c["scheme"] == Subject.uri_lookup[Subject.FAST]
+ ]
+ assert "FICTION Science Fiction Time Travel" == classification1_doc["term"]
+ assert float(6 + 1) / (6 + 1 + 2 + 7) == classification1_doc["weight"]
+ assert "Romance" == classification2_doc["term"]
+ assert float(2) / (6 + 1 + 2 + 7) == classification2_doc["weight"]
+ assert "Sea Stories" == classification3_doc["term"]
+ assert float(7) / (6 + 1 + 2 + 7) == classification3_doc["weight"]
+
+ genres = search_doc["genres"]
assert 2 == len(genres)
- [genre1_doc] = [g for g in genres if g['name'] == genre1.name]
- [genre2_doc] = [g for g in genres if g['name'] == genre2.name]
- assert Subject.SIMPLIFIED_GENRE == genre1_doc['scheme']
- assert genre1.id == genre1_doc['term']
- assert 1 == genre1_doc['weight']
- assert Subject.SIMPLIFIED_GENRE == genre2_doc['scheme']
- assert genre2.id == genre2_doc['term']
- assert 0 == genre2_doc['weight']
-
- target_age_doc = search_doc['target_age']
- assert work.target_age.lower == target_age_doc['lower']
- assert work.target_age.upper == target_age_doc['upper']
+ [genre1_doc] = [g for g in genres if g["name"] == genre1.name]
+ [genre2_doc] = [g for g in genres if g["name"] == genre2.name]
+ assert Subject.SIMPLIFIED_GENRE == genre1_doc["scheme"]
+ assert genre1.id == genre1_doc["term"]
+ assert 1 == genre1_doc["weight"]
+ assert Subject.SIMPLIFIED_GENRE == genre2_doc["scheme"]
+ assert genre2.id == genre2_doc["term"]
+ assert 0 == genre2_doc["weight"]
+
+ target_age_doc = search_doc["target_age"]
+ assert work.target_age.lower == target_age_doc["lower"]
+ assert work.target_age.upper == target_age_doc["upper"]
# If a book stops being available through a collection
# (because its LicensePool loses all its licenses or stops
@@ -1222,16 +1296,18 @@ def s(x):
pool.licenses_owned = 0
self._db.commit()
search_doc = work.to_search_document()
- assert ([collection2.id] ==
- [x['collection_id'] for x in search_doc['licensepools']])
+ assert [collection2.id] == [
+ x["collection_id"] for x in search_doc["licensepools"]
+ ]
# If the book becomes available again, the collection will
# start showing up again.
pool.open_access = True
self._db.commit()
search_doc = work.to_search_document()
- assert (set([collection1.id, collection2.id]) ==
- set([x['collection_id'] for x in search_doc['licensepools']]))
+ assert set([collection1.id, collection2.id]) == set(
+ [x["collection_id"] for x in search_doc["licensepools"]]
+ )
def test_age_appropriate_for_patron(self):
work = self._work()
@@ -1272,10 +1348,8 @@ def test_age_appropriate_for_patron_end_to_end(self):
# NOTE: setting target_age sets .audiences to appropriate values,
# so setting .audiences here is purely demonstrative.
- lane.audiences = [
- Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT
- ]
- lane.target_age = (9,14)
+ lane.audiences = [Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT]
+ lane.target_age = (9, 14)
# This work is a YA title within the age range.
work = self._work()
@@ -1289,7 +1363,7 @@ def test_age_appropriate_for_patron_end_to_end(self):
assert False == work.age_appropriate_for_patron(patron)
# Bump up the lane to match, and it's age-appropriate again.
- lane.target_age = (9,16)
+ lane.target_age = (9, 16)
assert True == work.age_appropriate_for_patron(patron)
# Change the audience to AUDIENCE_ADULT, and the work stops being
@@ -1299,7 +1373,9 @@ def test_age_appropriate_for_patron_end_to_end(self):
def test_unlimited_access_books_are_available_by_default(self):
# Set up an edition and work.
- edition, pool = self._edition(authors=[self._str, self._str], with_license_pool=True)
+ edition, pool = self._edition(
+ authors=[self._str, self._str], with_license_pool=True
+ )
work = self._work(presentation_edition=edition)
pool.open_access = False
@@ -1313,14 +1389,16 @@ def test_unlimited_access_books_are_available_by_default(self):
# Each LicensePool for the Work is listed in
# the 'licensepools' section.
- licensepools = search_doc['licensepools']
+ licensepools = search_doc["licensepools"]
assert 1 == len(licensepools)
- assert licensepools[0]['open_access'] == False
- assert licensepools[0]['available'] == True
+ assert licensepools[0]["open_access"] == False
+ assert licensepools[0]["available"] == True
def test_self_hosted_books_are_available_by_default(self):
# Set up an edition and work.
- edition, pool = self._edition(authors=[self._str, self._str], with_license_pool=True)
+ edition, pool = self._edition(
+ authors=[self._str, self._str], with_license_pool=True
+ )
work = self._work(presentation_edition=edition)
pool.licenses_owned = 0
@@ -1334,44 +1412,44 @@ def test_self_hosted_books_are_available_by_default(self):
# Each LicensePool for the Work is listed in
# the 'licensepools' section.
- licensepools = search_doc['licensepools']
+ licensepools = search_doc["licensepools"]
assert 1 == len(licensepools)
- assert licensepools[0]['open_access'] == False
- assert licensepools[0]['available'] == True
+ assert licensepools[0]["open_access"] == False
+ assert licensepools[0]["available"] == True
def test_target_age_string(self):
work = self._work()
- work.target_age = NumericRange(7, 8, '[]')
+ work.target_age = NumericRange(7, 8, "[]")
assert "7-8" == work.target_age_string
- work.target_age = NumericRange(0, 8, '[]')
+ work.target_age = NumericRange(0, 8, "[]")
assert "0-8" == work.target_age_string
- work.target_age = NumericRange(8, None, '[]')
+ work.target_age = NumericRange(8, None, "[]")
assert "8" == work.target_age_string
- work.target_age = NumericRange(None, 8, '[]')
+ work.target_age = NumericRange(None, 8, "[]")
assert "8" == work.target_age_string
- work.target_age = NumericRange(7, 8, '[)')
+ work.target_age = NumericRange(7, 8, "[)")
assert "7" == work.target_age_string
- work.target_age = NumericRange(0, 8, '[)')
+ work.target_age = NumericRange(0, 8, "[)")
assert "0-7" == work.target_age_string
- work.target_age = NumericRange(7, 8, '(]')
+ work.target_age = NumericRange(7, 8, "(]")
assert "8" == work.target_age_string
- work.target_age = NumericRange(0, 8, '(]')
+ work.target_age = NumericRange(0, 8, "(]")
assert "1-8" == work.target_age_string
- work.target_age = NumericRange(7, 9, '()')
+ work.target_age = NumericRange(7, 9, "()")
assert "8" == work.target_age_string
- work.target_age = NumericRange(0, 8, '()')
+ work.target_age = NumericRange(0, 8, "()")
assert "1-7" == work.target_age_string
- work.target_age = NumericRange(None, None, '()')
+ work.target_age = NumericRange(None, None, "()")
assert "" == work.target_age_string
work.target_age = None
@@ -1386,14 +1464,16 @@ def find_record(work):
WorkCoverageRecord.
"""
records = [
- x for x in work.coverage_records
+ x
+ for x in work.coverage_records
if x.operation.startswith(
- WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
+ WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
)
]
if records:
return records[0]
return None
+
registered = WorkCoverageRecord.REGISTERED
success = WorkCoverageRecord.SUCCESS
@@ -1441,7 +1521,7 @@ def find_record(work):
# its former Work needs to be reindexed.
record.status = success
self._db.delete(pool)
- work = self._db.query(Work).filter(Work.id==work.id).one()
+ work = self._db.query(Work).filter(Work.id == work.id).one()
record = find_record(work)
assert registered == record.status
@@ -1481,15 +1561,13 @@ def test_reset_coverage(self):
# for some specific operation.
def mock_reset_coverage(operation):
work.coverage_reset_for = operation
+
work._reset_coverage = mock_reset_coverage
for method, operation in (
- (work.needs_full_presentation_recalculation,
- WCR.CLASSIFY_OPERATION),
- (work.needs_new_presentation_edition,
- WCR.CHOOSE_EDITION_OPERATION),
- (work.external_index_needs_updating,
- WCR.UPDATE_SEARCH_INDEX_OPERATION)
+ (work.needs_full_presentation_recalculation, WCR.CLASSIFY_OPERATION),
+ (work.needs_new_presentation_edition, WCR.CHOOSE_EDITION_OPERATION),
+ (work.external_index_needs_updating, WCR.UPDATE_SEARCH_INDEX_OPERATION),
):
method()
assert operation == work.coverage_reset_for
@@ -1535,7 +1613,7 @@ def test_calculate_opds_entries(self):
work.calculate_opds_entries(verbose=False)
simple_entry = work.simple_opds_entry
- assert simple_entry.startswith(' len(simple_entry)
def test_calculate_marc_record(self):
@@ -1557,8 +1635,7 @@ def test_calculate_marc_record(self):
assert "online resource" in work.marc_record
def test_active_licensepool_ignores_superceded_licensepools(self):
- work = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work = self._work(with_license_pool=True, with_open_access_download=True)
[pool1] = work.license_pools
edition, pool2 = self._edition(with_license_pool=True)
work.license_pools.append(pool2)
@@ -1622,22 +1699,22 @@ def test_active_licensepool_ignores_superceded_licensepools(self):
def test_delete_work(self):
# Search mock
- class MockSearchIndex():
+ class MockSearchIndex:
removed = []
+
def remove_work(self, work):
self.removed.append(work)
- s = MockSearchIndex();
+ s = MockSearchIndex()
work = self._work(with_license_pool=True)
work.delete(search_index=s)
- assert [] == self._db.query(Work).filter(Work.id==work.id).all()
+ assert [] == self._db.query(Work).filter(Work.id == work.id).all()
assert 1 == len(s.removed)
assert s.removed == [work]
class TestWorkConsolidation(DatabaseTest):
-
def test_calculate_work_success(self):
e, p = self._edition(with_license_pool=True)
work, new = p.calculate_work()
@@ -1646,7 +1723,7 @@ def test_calculate_work_success(self):
def test_calculate_work_bails_out_if_no_title(self):
e, p = self._edition(with_license_pool=True)
- e.title=None
+ e.title = None
work, new = p.calculate_work()
assert None == work
assert False == new
@@ -1670,8 +1747,7 @@ def test_calculate_work_matches_based_on_permanent_work_id(self):
# since they have the same title/author.
edition1, ignore = self._edition(with_license_pool=True)
edition2, ignore = self._edition(
- title=edition1.title, authors=edition1.author,
- with_license_pool=True
+ title=edition1.title, authors=edition1.author, with_license_pool=True
)
# For purposes of this test, let's pretend all these books are
@@ -1694,18 +1770,27 @@ def test_calculate_work_matches_based_on_permanent_work_id(self):
expect = edition1.license_pools + edition2.license_pools
assert set(expect) == set(work1.license_pools)
-
def test_calculate_work_for_licensepool_creates_new_work(self):
- edition1, ignore = self._edition(data_source_name=DataSource.GUTENBERG, identifier_type=Identifier.GUTENBERG_ID,
- title=self._str, authors=[self._str], with_license_pool=True)
+ edition1, ignore = self._edition(
+ data_source_name=DataSource.GUTENBERG,
+ identifier_type=Identifier.GUTENBERG_ID,
+ title=self._str,
+ authors=[self._str],
+ with_license_pool=True,
+ )
# This edition is unique to the existing work.
preexisting_work = Work()
preexisting_work.set_presentation_edition(edition1)
# This edition is unique to the new LicensePool
- edition2, pool = self._edition(data_source_name=DataSource.GUTENBERG, identifier_type=Identifier.GUTENBERG_ID,
- title=self._str, authors=[self._str], with_license_pool=True)
+ edition2, pool = self._edition(
+ data_source_name=DataSource.GUTENBERG,
+ identifier_type=Identifier.GUTENBERG_ID,
+ title=self._str,
+ authors=[self._str],
+ with_license_pool=True,
+ )
# Call calculate_work(), and a new Work is created.
work, created = pool.calculate_work()
@@ -1713,13 +1798,19 @@ def test_calculate_work_for_licensepool_creates_new_work(self):
assert work != preexisting_work
def test_calculate_work_does_nothing_unless_edition_has_title(self):
- collection=self._collection()
+ collection = self._collection()
edition, ignore = Edition.for_foreign_id(
- self._db, DataSource.GUTENBERG, Identifier.GUTENBERG_ID, "1",
+ self._db,
+ DataSource.GUTENBERG,
+ Identifier.GUTENBERG_ID,
+ "1",
)
pool, ignore = LicensePool.for_foreign_id(
- self._db, DataSource.GUTENBERG, Identifier.GUTENBERG_ID, "1",
- collection=collection
+ self._db,
+ DataSource.GUTENBERG,
+ Identifier.GUTENBERG_ID,
+ "1",
+ collection=collection,
)
work, created = pool.calculate_work()
assert None == work
@@ -1735,7 +1826,9 @@ def test_calculate_work_does_nothing_unless_edition_has_title(self):
assert "foo" == work.title
assert "[Unknown]" == work.author
- def test_calculate_work_fails_when_presentation_edition_identifier_does_not_match_license_pool(self):
+ def test_calculate_work_fails_when_presentation_edition_identifier_does_not_match_license_pool(
+ self,
+ ):
# Here's a LicensePool with an Edition.
edition1, pool = self._edition(
@@ -1762,9 +1855,11 @@ def test_calculate_work_fails_when_presentation_edition_identifier_does_not_matc
# edition for a LicensePool with a totally different Identifier.
for edition in (edition2, edition3):
with pytest.raises(ValueError) as excinfo:
- pool.calculate_work(known_edition = edition)
- assert "Alleged presentation edition is not the presentation edition for the license pool for which work is being calculated!" \
+ pool.calculate_work(known_edition=edition)
+ assert (
+ "Alleged presentation edition is not the presentation edition for the license pool for which work is being calculated!"
in str(excinfo.value)
+ )
def test_open_access_pools_grouped_together(self):
@@ -1777,11 +1872,17 @@ def test_open_access_pools_grouped_together(self):
open1.open_access = True
open2.open_access = True
ed3, restricted3 = self._edition(
- title=title, authors=author, data_source_name=DataSource.OVERDRIVE,
- with_license_pool=True)
+ title=title,
+ authors=author,
+ data_source_name=DataSource.OVERDRIVE,
+ with_license_pool=True,
+ )
ed4, restricted4 = self._edition(
- title=title, authors=author, data_source_name=DataSource.OVERDRIVE,
- with_license_pool=True)
+ title=title,
+ authors=author,
+ data_source_name=DataSource.OVERDRIVE,
+ with_license_pool=True,
+ )
restricted3.open_access = False
restricted4.open_access = False
@@ -1824,7 +1925,7 @@ def test_all_licensepools_with_same_identifier_get_same_work(self):
with_license_pool=True,
identifier_type=identifier.type,
identifier_id=identifier.identifier,
- collection=collection2
+ collection=collection2,
)
assert pool1.identifier == pool2.identifier
@@ -1876,6 +1977,7 @@ def test_calculate_work_fixes_work_in_invalid_state(self):
# permanent work ID.
def mock_pwid(debug=False):
return "abcd"
+
for lp in [abcd_commercial, abcd_commercial_2, abcd_open_access]:
lp.presentation_edition.calculate_permanent_work_id = mock_pwid
@@ -1898,10 +2000,11 @@ def mock_pwid(debug=False):
# used for all open-access LicensePools for that book going
# forward.
- expect_open_access_work, open_access_work_is_new = (
- Work.open_access_for_permanent_work_id(
- self._db, "abcd", Edition.BOOK_MEDIUM, 'eng'
- )
+ (
+ expect_open_access_work,
+ open_access_work_is_new,
+ ) = Work.open_access_for_permanent_work_id(
+ self._db, "abcd", Edition.BOOK_MEDIUM, "eng"
)
assert expect_open_access_work == abcd_open_access.work
@@ -1945,7 +2048,7 @@ def test_calculate_work_fixes_incorrectly_grouped_books(self):
# open-access _audiobook_ of "abcd".
edition, audiobook = self._edition(with_license_pool=True)
audiobook.open_access = True
- audiobook.presentation_edition.medium=Edition.AUDIO_MEDIUM
+ audiobook.presentation_edition.medium = Edition.AUDIO_MEDIUM
audiobook.presentation_edition.permanent_work_id = "abcd"
work.license_pools.append(audiobook)
@@ -1953,12 +2056,13 @@ def test_calculate_work_fixes_incorrectly_grouped_books(self):
# in a different language.
edition, spanish = self._edition(with_license_pool=True)
spanish.open_access = True
- spanish.presentation_edition.language='spa'
+ spanish.presentation_edition.language = "spa"
spanish.presentation_edition.permanent_work_id = "abcd"
work.license_pools.append(spanish)
def mock_pwid(debug=False):
return "abcd"
+
for lp in [book, audiobook, spanish]:
lp.presentation_edition.calculate_permanent_work_id = mock_pwid
@@ -1979,33 +2083,32 @@ def mock_pwid(debug=False):
# The book has been given the Work that will be used for all
# book-type LicensePools for that title going forward.
- expect_book_work, book_work_is_new = (
- Work.open_access_for_permanent_work_id(
- self._db, "abcd", Edition.BOOK_MEDIUM, 'eng'
- )
+ expect_book_work, book_work_is_new = Work.open_access_for_permanent_work_id(
+ self._db, "abcd", Edition.BOOK_MEDIUM, "eng"
)
assert expect_book_work == book.work
# The audiobook has been given the Work that will be used for
# all audiobook-type LicensePools for that title going
# forward.
- expect_audiobook_work, audiobook_work_is_new = (
- Work.open_access_for_permanent_work_id(
- self._db, "abcd", Edition.AUDIO_MEDIUM, 'eng'
- )
+ (
+ expect_audiobook_work,
+ audiobook_work_is_new,
+ ) = Work.open_access_for_permanent_work_id(
+ self._db, "abcd", Edition.AUDIO_MEDIUM, "eng"
)
assert expect_audiobook_work == audiobook.work
# The Spanish book has been given the Work that will be used
# for all Spanish LicensePools for that title going forward.
- expect_spanish_work, spanish_work_is_new = (
- Work.open_access_for_permanent_work_id(
- self._db, "abcd", Edition.BOOK_MEDIUM, 'spa'
- )
+ (
+ expect_spanish_work,
+ spanish_work_is_new,
+ ) = Work.open_access_for_permanent_work_id(
+ self._db, "abcd", Edition.BOOK_MEDIUM, "spa"
)
assert expect_spanish_work == spanish.work
- assert 'spa' == expect_spanish_work.language
-
+ assert "spa" == expect_spanish_work.language
def test_calculate_work_detaches_licensepool_with_no_title(self):
# Here's a Work with an open-access edition of "abcd".
@@ -2034,8 +2137,8 @@ def test_calculate_work_detaches_licensepool_with_no_pwid(self):
# with no title or author, and thus no permanent work ID.
edition, no_title = self._edition(with_license_pool=True)
- no_title.presentation_edition.title=None
- no_title.presentation_edition.author=None
+ no_title.presentation_edition.title = None
+ no_title.presentation_edition.author = None
no_title.presentation_edition.permanent_work_id = None
work.license_pools.append(no_title)
@@ -2064,7 +2167,6 @@ def test_calculate_work_detaches_licensepool_with_no_pwid(self):
work_after, is_new = no_title.calculate_work()
assert [book] == work.license_pools
-
def test_pwids(self):
"""Test the property that finds all permanent work IDs
associated with a Work.
@@ -2073,55 +2175,49 @@ def test_pwids(self):
# with two different PWIDs are associated with the same work.
work = self._work(with_license_pool=True)
[lp1] = work.license_pools
- assert (set([lp1.presentation_edition.permanent_work_id]) ==
- work.pwids)
+ assert set([lp1.presentation_edition.permanent_work_id]) == work.pwids
edition, lp2 = self._edition(with_license_pool=True)
work.license_pools.append(lp2)
# Work.pwids finds both PWIDs.
- assert (set([lp1.presentation_edition.permanent_work_id,
- lp2.presentation_edition.permanent_work_id]) ==
- work.pwids)
+ assert (
+ set(
+ [
+ lp1.presentation_edition.permanent_work_id,
+ lp2.presentation_edition.permanent_work_id,
+ ]
+ )
+ == work.pwids
+ )
def test_open_access_for_permanent_work_id_no_licensepools(self):
# There are no LicensePools, which short-circuits
# open_access_for_permanent_work_id.
- assert (
- (None, False) == Work.open_access_for_permanent_work_id(
- self._db, "No such permanent work ID", Edition.BOOK_MEDIUM,
- "eng"
- ))
+ assert (None, False) == Work.open_access_for_permanent_work_id(
+ self._db, "No such permanent work ID", Edition.BOOK_MEDIUM, "eng"
+ )
# Now it works.
w = self._work(
- language="eng", with_license_pool=True,
- with_open_access_download=True
+ language="eng", with_license_pool=True, with_open_access_download=True
)
w.presentation_edition.permanent_work_id = "permid"
- assert (
- (w, False) == Work.open_access_for_permanent_work_id(
- self._db, "permid", Edition.BOOK_MEDIUM,
- "eng"
- ))
+ assert (w, False) == Work.open_access_for_permanent_work_id(
+ self._db, "permid", Edition.BOOK_MEDIUM, "eng"
+ )
# But the language, medium, and permanent ID must all match.
- assert (
- (None, False) == Work.open_access_for_permanent_work_id(
- self._db, "permid", Edition.BOOK_MEDIUM,
- "spa"
- ))
+ assert (None, False) == Work.open_access_for_permanent_work_id(
+ self._db, "permid", Edition.BOOK_MEDIUM, "spa"
+ )
- assert (
- (None, False) == Work.open_access_for_permanent_work_id(
- self._db, "differentid", Edition.BOOK_MEDIUM,
- "eng"
- ))
+ assert (None, False) == Work.open_access_for_permanent_work_id(
+ self._db, "differentid", Edition.BOOK_MEDIUM, "eng"
+ )
- assert (
- (None, False) == Work.open_access_for_permanent_work_id(
- self._db, "differentid", Edition.AUDIO_MEDIUM,
- "eng"
- ))
+ assert (None, False) == Work.open_access_for_permanent_work_id(
+ self._db, "differentid", Edition.AUDIO_MEDIUM, "eng"
+ )
def test_open_access_for_permanent_work_id(self):
# Two different works full of open-access license pools.
@@ -2143,15 +2239,16 @@ def test_open_access_for_permanent_work_id(self):
# exact same book.
def mock_pwid(debug=False):
return "abcd"
+
for lp in [lp1, lp2, lp3]:
- lp.presentation_edition.permanent_work_id="abcd"
+ lp.presentation_edition.permanent_work_id = "abcd"
lp.presentation_edition.calculate_permanent_work_id = mock_pwid
# We've also got Work #3, which provides a commercial license
# for that book.
w3 = self._work(with_license_pool=True)
w3_pool = w3.license_pools[0]
- w3_pool.presentation_edition.permanent_work_id="abcd"
+ w3_pool.presentation_edition.permanent_work_id = "abcd"
w3_pool.open_access = False
# Work.open_access_for_permanent_work_id can resolve this problem.
@@ -2160,7 +2257,7 @@ def mock_pwid(debug=False):
)
# Work #3 still exists and its license pool was not affected.
- assert [w3] == self._db.query(Work).filter(Work.id==w3.id).all()
+ assert [w3] == self._db.query(Work).filter(Work.id == w3.id).all()
assert w3 == w3_pool.work
# But the other three license pools now have the same work.
@@ -2175,7 +2272,7 @@ def mock_pwid(debug=False):
assert False == is_new
# Work #1 no longer exists.
- assert [] == self._db.query(Work).filter(Work.id==w1.id).all()
+ assert [] == self._db.query(Work).filter(Work.id == w1.id).all()
# Calling Work.open_access_for_permanent_work_id again returns the same
# result.
@@ -2190,7 +2287,7 @@ def test_open_access_for_permanent_work_id_can_create_work(self):
# Here's a LicensePool with no corresponding Work.
edition, lp = self._edition(with_license_pool=True)
lp.open_access = True
- edition.permanent_work_id="abcd"
+ edition.permanent_work_id = "abcd"
# open_access_for_permanent_work_id creates the Work.
work, is_new = Work.open_access_for_permanent_work_id(
@@ -2204,17 +2301,23 @@ def test_potential_open_access_works_for_permanent_work_id(self):
# helper method.
# Here are two editions of the same book with the same PWID.
- title = 'Siddhartha'
- author = ['Herman Hesse']
+ title = "Siddhartha"
+ author = ["Herman Hesse"]
e1, lp1 = self._edition(
data_source_name=DataSource.STANDARD_EBOOKS,
- title=title, authors=author, language='eng', with_license_pool=True,
+ title=title,
+ authors=author,
+ language="eng",
+ with_license_pool=True,
)
e1.permanent_work_id = "pwid"
e2, lp2 = self._edition(
data_source_name=DataSource.GUTENBERG,
- title=title, authors=author, language='eng', with_license_pool=True,
+ title=title,
+ authors=author,
+ language="eng",
+ with_license_pool=True,
)
e2.permanent_work_id = "pwid"
@@ -2228,6 +2331,7 @@ def m():
return Work._potential_open_access_works_for_permanent_work_id(
self._db, "pwid", Edition.BOOK_MEDIUM, "eng"
)
+
pools, counts = m()
# Both LicensePools show up in the list of LicensePools that
@@ -2235,7 +2339,7 @@ def m():
# associated with the same Work.
poolset = set([lp1, lp2])
assert poolset == pools
- assert {w1 : 2} == counts
+ assert {w1: 2} == counts
# Since the work was just created, it has no presentation
# edition and thus no language. If the presentation edition
@@ -2243,7 +2347,7 @@ def m():
w1.presentation_edition = e1
pools, counts = m()
assert poolset == pools
- assert {w1 : 2} == counts
+ assert {w1: 2} == counts
# If the Work's presentation edition has information that
# _conflicts_ with the information passed in to
@@ -2251,14 +2355,14 @@ def m():
# does not show up in `counts`, indicating that a new Work
# should be created to hold those books.
bad_pe = self._edition()
- bad_pe.permanent_work_id='pwid'
+ bad_pe.permanent_work_id = "pwid"
w1.presentation_edition = bad_pe
- bad_pe.language = 'fin'
+ bad_pe.language = "fin"
pools, counts = m()
assert poolset == pools
assert {} == counts
- bad_pe.language = 'eng'
+ bad_pe.language = "eng"
bad_pe.medium = Edition.AUDIO_MEDIUM
pools, counts = m()
@@ -2282,7 +2386,7 @@ def assert_lp1_missing():
# LicensePools for its Work.
pools, counts = m()
assert set([lp2]) == pools
- assert {w1 : 1} == counts
+ assert {w1: 1} == counts
# It has to be open-access.
lp1.open_access = False
@@ -2303,7 +2407,7 @@ def assert_lp1_missing():
# The language must also match.
e1.language = "another language"
assert_lp1_missing()
- e1.language = 'eng'
+ e1.language = "eng"
# Finally, let's see what happens when there are two Works where
# there should be one.
@@ -2326,10 +2430,9 @@ def assert_lp1_missing():
def test_make_exclusive_open_access_for_permanent_work_id(self):
# Here's a work containing an open-access LicensePool for
# literary work "abcd".
- work1 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work1 = self._work(with_license_pool=True, with_open_access_download=True)
[abcd_oa] = work1.license_pools
- abcd_oa.presentation_edition.permanent_work_id="abcd"
+ abcd_oa.presentation_edition.permanent_work_id = "abcd"
# Unfortunately, a commercial LicensePool for the literary
# work "abcd" has gotten associated with the same work.
@@ -2337,15 +2440,14 @@ def test_make_exclusive_open_access_for_permanent_work_id(self):
with_license_pool=True, with_open_access_download=True
)
abcd_commercial.open_access = False
- abcd_commercial.presentation_edition.permanent_work_id="abcd"
+ abcd_commercial.presentation_edition.permanent_work_id = "abcd"
abcd_commercial.work = work1
# Here's another Work containing an open-access LicensePool
# for literary work "efgh".
- work2 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work2 = self._work(with_license_pool=True, with_open_access_download=True)
[efgh_1] = work2.license_pools
- efgh_1.presentation_edition.permanent_work_id="efgh"
+ efgh_1.presentation_edition.permanent_work_id = "efgh"
# Unfortunately, there's another open-access LicensePool for
# "efgh", and it's incorrectly associated with the "abcd"
@@ -2358,7 +2460,9 @@ def test_make_exclusive_open_access_for_permanent_work_id(self):
# Let's fix these problems.
work1.make_exclusive_open_access_for_permanent_work_id(
- "abcd", Edition.BOOK_MEDIUM, "eng",
+ "abcd",
+ Edition.BOOK_MEDIUM,
+ "eng",
)
# The open-access "abcd" book is now the only LicensePool
@@ -2374,8 +2478,7 @@ def test_make_exclusive_open_access_for_permanent_work_id(self):
def test_make_exclusive_open_access_for_null_permanent_work_id(self):
# Here's a LicensePool that, due to a previous error, has
# a null PWID in its presentation edition.
- work = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work = self._work(with_license_pool=True, with_open_access_download=True)
[null1] = work.license_pools
null1.presentation_edition.title = None
null1.presentation_edition.sort_author = None
@@ -2406,40 +2509,41 @@ def test_make_exclusive_open_access_for_null_permanent_work_id(self):
def test_merge_into_success(self):
# Here's a work with an open-access LicensePool.
- work1 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work1 = self._work(with_license_pool=True, with_open_access_download=True)
[lp1] = work1.license_pools
- lp1.presentation_edition.permanent_work_id="abcd"
+ lp1.presentation_edition.permanent_work_id = "abcd"
# Let's give it a WorkGenre and a WorkCoverageRecord.
genre, ignore = Genre.lookup(self._db, "Fantasy")
- wg, wg_is_new = get_one_or_create(
- self._db, WorkGenre, work=work1, genre=genre
- )
+ wg, wg_is_new = get_one_or_create(self._db, WorkGenre, work=work1, genre=genre)
wcr, wcr_is_new = WorkCoverageRecord.add_for(work1, "test")
# Here's another work with an open-access LicensePool for the
# same book.
- work2 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work2 = self._work(with_license_pool=True, with_open_access_download=True)
[lp2] = work2.license_pools
- lp2.presentation_edition.permanent_work_id="abcd"
+ lp2.presentation_edition.permanent_work_id = "abcd"
# Let's merge the first work into the second.
work1.merge_into(work2)
# The first work has been deleted, as have its WorkGenre and
# WorkCoverageRecord.
- assert [] == self._db.query(Work).filter(Work.id==work1.id).all()
+ assert [] == self._db.query(Work).filter(Work.id == work1.id).all()
assert [] == self._db.query(WorkGenre).all()
- assert [] == self._db.query(WorkCoverageRecord).filter(
- WorkCoverageRecord.work_id==work1.id).all()
+ assert (
+ []
+ == self._db.query(WorkCoverageRecord)
+ .filter(WorkCoverageRecord.work_id == work1.id)
+ .all()
+ )
- def test_open_access_for_permanent_work_id_fixes_mismatched_works_incidentally(self):
+ def test_open_access_for_permanent_work_id_fixes_mismatched_works_incidentally(
+ self,
+ ):
# Here's a work with two open-access LicensePools for the book "abcd".
- work1 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work1 = self._work(with_license_pool=True, with_open_access_download=True)
[abcd_1] = work1.license_pools
edition, abcd_2 = self._edition(
with_license_pool=True, with_open_access_download=True
@@ -2456,8 +2560,7 @@ def test_open_access_for_permanent_work_id_fixes_mismatched_works_incidentally(s
# Here's another work with an open-access LicensePool for the
# book "abcd".
- work2 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work2 = self._work(with_license_pool=True, with_open_access_download=True)
[abcd_3] = work2.license_pools
# Unfortunately, this work also contains an open-access Licensepool
@@ -2480,13 +2583,13 @@ def mock_pwid_ijkl(debug=False):
for lp in abcd_1, abcd_2, abcd_3:
lp.presentation_edition.calculate_permanent_work_id = mock_pwid_abcd
- lp.presentation_edition.permanent_work_id = 'abcd'
+ lp.presentation_edition.permanent_work_id = "abcd"
efgh.presentation_edition.calculate_permanent_work_id = mock_pwid_efgh
- efgh.presentation_edition.permanent_work_id = 'efgh'
+ efgh.presentation_edition.permanent_work_id = "efgh"
ijkl.presentation_edition.calculate_permanent_work_id = mock_pwid_ijkl
- ijkl.presentation_edition.permanent_work_id = 'ijkl'
+ ijkl.presentation_edition.permanent_work_id = "ijkl"
# Calling Work.open_access_for_permanent_work_id()
# automatically kicks the 'efgh' and 'ijkl' LicensePools into
@@ -2526,12 +2629,10 @@ def mock_pwid_ijkl(debug=False):
def test_open_access_for_permanent_work_untangles_tangled_works(self):
# Here are three works for the books "abcd", "efgh", and "ijkl".
- abcd_work = self._work(with_license_pool=True,
- with_open_access_download=True)
+ abcd_work = self._work(with_license_pool=True, with_open_access_download=True)
[abcd_1] = abcd_work.license_pools
- efgh_work = self._work(with_license_pool=True,
- with_open_access_download=True)
+ efgh_work = self._work(with_license_pool=True, with_open_access_download=True)
[efgh_1] = efgh_work.license_pools
# Unfortunately, due to an earlier error, the 'abcd' work
@@ -2560,14 +2661,14 @@ def mock_pwid_abcd(debug=False):
for lp in abcd_1, abcd_2:
lp.presentation_edition.calculate_permanent_work_id = mock_pwid_abcd
- lp.presentation_edition.permanent_work_id = 'abcd'
+ lp.presentation_edition.permanent_work_id = "abcd"
def mock_pwid_efgh(debug=False):
return "efgh"
for lp in efgh_1, efgh_2:
lp.presentation_edition.calculate_permanent_work_id = mock_pwid_efgh
- lp.presentation_edition.permanent_work_id = 'efgh'
+ lp.presentation_edition.permanent_work_id = "efgh"
# Calling Work.open_access_for_permanent_work_id() creates a
# new work that contains both 'abcd' LicensePools.
@@ -2603,41 +2704,42 @@ def mock_pwid_efgh(debug=False):
def test_merge_into_raises_exception_if_grouping_rules_violated(self):
# Here's a work with an open-access LicensePool.
- work1 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work1 = self._work(with_license_pool=True, with_open_access_download=True)
[lp1] = work1.license_pools
- lp1.presentation_edition.permanent_work_id="abcd"
+ lp1.presentation_edition.permanent_work_id = "abcd"
# Here's another work with a commercial LicensePool for the
# same book.
- work2 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work2 = self._work(with_license_pool=True, with_open_access_download=True)
[lp2] = work2.license_pools
lp2.open_access = False
- lp2.presentation_edition.permanent_work_id="abcd"
+ lp2.presentation_edition.permanent_work_id = "abcd"
# The works cannot be merged.
with pytest.raises(ValueError) as excinfo:
work1.merge_into(work2)
- assert "Refusing to merge {} into {} because it would put an open-access LicensePool into the same work as a non-open-access LicensePool.".format(work1, work2) \
- in str(excinfo.value)
-
+ assert "Refusing to merge {} into {} because it would put an open-access LicensePool into the same work as a non-open-access LicensePool.".format(
+ work1, work2
+ ) in str(
+ excinfo.value
+ )
def test_merge_into_raises_exception_if_pwids_differ(self):
- work1 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work1 = self._work(with_license_pool=True, with_open_access_download=True)
[abcd_oa] = work1.license_pools
- abcd_oa.presentation_edition.permanent_work_id="abcd"
+ abcd_oa.presentation_edition.permanent_work_id = "abcd"
- work2 = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work2 = self._work(with_license_pool=True, with_open_access_download=True)
[efgh_oa] = work2.license_pools
- efgh_oa.presentation_edition.permanent_work_id="efgh"
+ efgh_oa.presentation_edition.permanent_work_id = "efgh"
with pytest.raises(ValueError) as excinfo:
work1.merge_into(work2)
- assert "Refusing to merge {} into {} because permanent work IDs don't match: abcd vs. efgh".format(work1, work2) \
- in str(excinfo.value)
+ assert "Refusing to merge {} into {} because permanent work IDs don't match: abcd vs. efgh".format(
+ work1, work2
+ ) in str(
+ excinfo.value
+ )
def test_licensepool_without_identifier_gets_no_work(self):
work = self._work(with_license_pool=True)
diff --git a/tests/python_expression_dsl/test_evaluator.py b/tests/python_expression_dsl/test_evaluator.py
index ad98ff005..42d15fa24 100644
--- a/tests/python_expression_dsl/test_evaluator.py
+++ b/tests/python_expression_dsl/test_evaluator.py
@@ -8,6 +8,7 @@
)
from ...python_expression_dsl.parser import DSLParseError, DSLParser
+
class Subject(object):
"""Dummy object designed for testing DSLEvaluator."""
@@ -44,10 +45,8 @@ class TestDSLEvaluator(object):
@parameterized.expand(
[
("incorrect_expression", "?", None, None, None, DSLParseError),
-
("numeric_literal", "9", 9),
("numeric_float_literal", "9.5", 9.5),
-
("unknown_identifier", "foo", None, None, None, DSLEvaluationError),
("known_identifier", "foo", 9, {"foo": 9}),
(
@@ -60,33 +59,31 @@ class TestDSLEvaluator(object):
),
("known_nested_identifier", "foo.bar", 9, {"foo": {"bar": 9}}),
("known_nested_identifier", "foo.bar.baz", 9, {"foo": {"bar": {"baz": 9}}}),
- ("known_nested_identifier", "foo.bar[0].baz", 9, {"foo": {"bar": [{"baz": 9}]}}),
+ (
+ "known_nested_identifier",
+ "foo.bar[0].baz",
+ 9,
+ {"foo": {"bar": [{"baz": 9}]}},
+ ),
(
"identifier_pointing_to_the_object",
"'eresources' in subject.attributes",
True,
{"subject": Subject(["eresources"])},
),
-
("simple_negation", "-9", -9),
("simple_parenthesized_expression_negation", "-(9)", -(9)),
("parenthesized_expression_negation", "-(9 + 3)", -(9 + 3)),
("slice_expression_negation", "-(arr[1])", -12, {"arr": [1, 12, 3]}),
-
("addition_with_two_operands", "9 + 3", 9 + 3),
("addition_with_three_operands", "9 + 3 + 3", 9 + 3 + 3),
("addition_with_four_operands", "9 + 3 + 3 + 3", 9 + 3 + 3 + 3),
-
("subtraction_with_two_operands", "9 - 3", 9 - 3),
-
("multiplication_with_two_operands", "9 * 3", 9 * 3),
-
("division_with_two_operands", "9 / 3", 9 / 3),
("division_with_two_operands_and_remainder", "9 / 4", 9.0 / 4.0),
-
("exponentiation_with_two_operands", "9 ** 3", 9 ** 3),
("exponentiation_with_three_operands", "2 ** 3 ** 3", 2 ** 3 ** 3),
-
(
"associative_law_for_addition",
"(a + b) + c == a + (b + c)",
@@ -99,47 +96,33 @@ class TestDSLEvaluator(object):
True,
{"a": 9, "b": 3, "c": 3},
),
-
- (
- "commutative_law_for_addition",
- "a + b == b + a",
- True,
- {"a": 9, "b": 3}
- ),
+ ("commutative_law_for_addition", "a + b == b + a", True, {"a": 9, "b": 3}),
(
"commutative_law_for_multiplication",
"a * b == b * a",
True,
{"a": 9, "b": 3},
),
-
(
"distributive_law",
"a * (b + c) == a * b + a * c",
True,
{"a": 9, "b": 3, "c": 3},
),
-
("less_comparison", "9 < 3", 9 < 3),
("less_or_equal_comparison", "3 <= 3", 3 <= 3),
("greater_comparison", "9 > 3", 9 > 3),
("greater_or_equal_comparison", "3 >= 2", 3 >= 2),
-
("in_operator", "3 in list", True, {"list": [1, 2, 3]}),
-
("inversion", "not 9 < 3", not 9 < 3),
("double_inversion", "not not 9 < 3", not not 9 < 3),
("triple_inversion", "not not not 9 < 3", not not not 9 < 3),
-
("conjunction", "9 == 9 and 3 == 3", 9 == 9 and 3 == 3),
("disjunction", "9 == 3 or 3 == 3", 9 == 3 or 3 == 3),
-
("simple_parenthesized_expression", "(9 + 3)", (9 + 3)),
("arithmetic_parenthesized_expression", "2 * (9 + 3) * 2", 2 * (9 + 3) * 2),
-
("slice_expression", "arr[1] == 12", True, {"arr": [1, 12, 3]}),
("complex_slice_expression", "arr[1] + arr[2]", 15, {"arr": [1, 12, 3]}),
-
("method_call", "string.upper()", "HELLO WORLD", {"string": "Hello World"}),
("builtin_function_call", "min(1, 2)", min(1, 2)),
(
diff --git a/tests/test_analytics.py b/tests/test_analytics.py
index d32d5a2ef..ddf3fec62 100644
--- a/tests/test_analytics.py
+++ b/tests/test_analytics.py
@@ -1,19 +1,11 @@
-from ..config import (
- Configuration,
- temp_config,
-)
+import json
+
from ..analytics import Analytics
-from ..mock_analytics_provider import MockAnalyticsProvider
+from ..config import Configuration, temp_config
from ..local_analytics_provider import LocalAnalyticsProvider
+from ..mock_analytics_provider import MockAnalyticsProvider
+from ..model import CirculationEvent, ExternalIntegration, Library, create, get_one
from ..testing import DatabaseTest
-from ..model import (
- CirculationEvent,
- ExternalIntegration,
- Library,
- create,
- get_one
-)
-import json
# We can't import mock_analytics_provider from within a test,
# and we can't tell Analytics to do so either. We need to tell
@@ -21,29 +13,32 @@
# class is in.
MOCK_PROTOCOL = "..mock_analytics_provider"
-class TestAnalytics(DatabaseTest):
+class TestAnalytics(DatabaseTest):
def test_initialize(self):
# supports multiple analytics providers, site-wide or with libraries
# Two site-wide integrations
mock_integration, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol=MOCK_PROTOCOL
+ protocol=MOCK_PROTOCOL,
)
mock_integration.url = self._str
local_integration, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol="..local_analytics_provider"
+ protocol="..local_analytics_provider",
)
# A broken integration
missing_integration, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol="missing_provider"
+ protocol="missing_provider",
)
# Two library-specific integrations
@@ -51,17 +46,19 @@ def test_initialize(self):
l2, ignore = create(self._db, Library, short_name="L2")
library_integration1, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol=MOCK_PROTOCOL
- )
+ protocol=MOCK_PROTOCOL,
+ )
library_integration1.libraries += [l1, l2]
library_integration2, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol=MOCK_PROTOCOL
- )
+ protocol=MOCK_PROTOCOL,
+ )
library_integration2.libraries += [l2]
analytics = Analytics(self._db)
@@ -123,14 +120,16 @@ def test_is_configured(self):
def test_collect_event(self):
sitewide_integration, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol=MOCK_PROTOCOL
+ protocol=MOCK_PROTOCOL,
)
library, ignore = create(self._db, Library, short_name="library")
library_integration, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
protocol=MOCK_PROTOCOL,
)
@@ -142,7 +141,9 @@ def test_collect_event(self):
sitewide_provider = analytics.sitewide_providers[0]
library_provider = analytics.library_providers[library.id][0]
- analytics.collect_event(self._default_library, lp, CirculationEvent.DISTRIBUTOR_CHECKIN, None)
+ analytics.collect_event(
+ self._default_library, lp, CirculationEvent.DISTRIBUTOR_CHECKIN, None
+ )
# The sitewide provider was called.
assert 1 == sitewide_provider.count
@@ -169,9 +170,10 @@ def test_collect_event(self):
def test_initialize(self):
local_analytics = get_one(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
protocol=LocalAnalyticsProvider.__module__,
- goal=ExternalIntegration.ANALYTICS_GOAL
+ goal=ExternalIntegration.ANALYTICS_GOAL,
)
# There shouldn't exist a local analytics service.
@@ -189,10 +191,11 @@ def test_initialize(self):
local_analytics = LocalAnalyticsProvider.initialize(self._db)
local_analytics_2 = get_one(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
protocol=LocalAnalyticsProvider.__module__,
- goal=ExternalIntegration.ANALYTICS_GOAL
+ goal=ExternalIntegration.ANALYTICS_GOAL,
)
assert local_analytics_2.id == local_analytics.id
- assert local_analytics_2.name == local_analytics.name
\ No newline at end of file
+ assert local_analytics_2.name == local_analytics.name
diff --git a/tests/test_app_server.py b/tests/test_app_server.py
index d6ffae78a..ec39658fd 100644
--- a/tests/test_app_server.py
+++ b/tests/test_app_server.py
@@ -1,101 +1,69 @@
import gzip
-from io import BytesIO
-import os
import json
+import os
+from io import BytesIO
import flask
from flask import Flask
-from flask_babel import (
- Babel,
- lazy_gettext as _
-)
-
-from ..testing import (
- DatabaseTest,
-)
-
-from ..opds import TestAnnotator
-
-from ..model import (
- Identifier,
- ConfigurationSetting,
-)
-
-from ..lane import (
- Facets,
- Pagination,
- SearchFacets,
- WorkList,
-)
+from flask_babel import Babel
+from flask_babel import lazy_gettext as _
from ..app_server import (
+ ComplaintController,
+ ErrorHandler,
HeartbeatController,
URNLookupController,
URNLookupHandler,
- ErrorHandler,
- ComplaintController,
compressible,
load_facets_from_request,
load_pagination_from_request,
)
-
from ..config import Configuration
-
+from ..entrypoint import AudiobooksEntryPoint, EbooksEntryPoint, EntryPoint
+from ..lane import Facets, Pagination, SearchFacets, WorkList
from ..log import LogConfiguration
-
-from ..entrypoint import (
- AudiobooksEntryPoint,
- EbooksEntryPoint,
- EntryPoint,
-)
-
-from ..problem_details import (
- INVALID_INPUT,
- INVALID_URN,
-)
-
-from ..util.opds_writer import (
- OPDSFeed,
- OPDSMessage,
-)
+from ..model import ConfigurationSetting, Identifier
+from ..opds import TestAnnotator
+from ..problem_details import INVALID_INPUT, INVALID_URN
+from ..testing import DatabaseTest
+from ..util.opds_writer import OPDSFeed, OPDSMessage
class TestHeartbeatController(object):
-
def test_heartbeat(self):
app = Flask(__name__)
controller = HeartbeatController()
- with app.test_request_context('/'):
+ with app.test_request_context("/"):
response = controller.heartbeat()
assert 200 == response.status_code
- assert controller.HEALTH_CHECK_TYPE == response.headers.get('Content-Type')
+ assert controller.HEALTH_CHECK_TYPE == response.headers.get("Content-Type")
data = json.loads(response.data.decode("utf8"))
- assert 'pass' == data['status']
+ assert "pass" == data["status"]
# Create a .version file.
root_dir = os.path.join(os.path.split(__file__)[0], "..", "..")
version_filename = os.path.join(root_dir, controller.VERSION_FILENAME)
- with open(version_filename, 'w') as f:
- f.write('ba.na.na-10-ssssssssss')
+ with open(version_filename, "w") as f:
+ f.write("ba.na.na-10-ssssssssss")
# Create a mock configuration object to test with.
class MockConfiguration(Configuration):
instance = dict()
- with app.test_request_context('/'):
+ with app.test_request_context("/"):
response = controller.heartbeat(conf_class=MockConfiguration)
if os.path.exists(version_filename):
os.remove(version_filename)
assert 200 == response.status_code
- content_type = response.headers.get('Content-Type')
+ content_type = response.headers.get("Content-Type")
assert controller.HEALTH_CHECK_TYPE == content_type
data = json.loads(response.data.decode("utf8"))
- assert 'pass' == data['status']
- assert 'ba.na.na' == data['version']
- assert 'ba.na.na-10-ssssssssss' == data['releaseID']
+ assert "pass" == data["status"]
+ assert "ba.na.na" == data["version"]
+ assert "ba.na.na-10-ssssssssss" == data["releaseID"]
class TestURNLookupHandler(DatabaseTest):
@@ -121,6 +89,7 @@ def test_process_urns_hook_method(self):
class Mock(URNLookupHandler):
def post_lookup_hook(self):
self.called = True
+
handler = Mock(self._db)
handler.process_urns([])
assert True == handler.called
@@ -133,13 +102,11 @@ def test_process_urns_invalid_urn(self):
def test_process_urns_unrecognized_identifier(self):
# Give the handler a URN that, although valid, doesn't
# correspond to any Identifier in the database.
- urn = Identifier.GUTENBERG_URN_SCHEME_PREFIX + 'Gutenberg%20ID/000'
+ urn = Identifier.GUTENBERG_URN_SCHEME_PREFIX + "Gutenberg%20ID/000"
self.handler.process_urns([urn])
# The result is a 404 message.
- self.assert_one_message(
- urn, 404, self.handler.UNRECOGNIZED_IDENTIFIER
- )
+ self.assert_one_message(urn, 404, self.handler.UNRECOGNIZED_IDENTIFIER)
def test_process_identifier_no_license_pool(self):
# Give the handler a URN that corresponds to an Identifier
@@ -156,9 +123,7 @@ def test_process_identifier_license_pool_but_no_work(self):
edition, pool = self._edition(with_license_pool=True)
identifier = edition.primary_identifier
self.handler.process_identifier(identifier, identifier.urn)
- self.assert_one_message(
- identifier.urn, 202, self.handler.WORK_NOT_CREATED
- )
+ self.assert_one_message(identifier.urn, 202, self.handler.WORK_NOT_CREATED)
def test_process_identifier_work_not_presentation_ready(self):
work = self._work(with_license_pool=True)
@@ -175,21 +140,24 @@ def test_process_identifier_work_is_presentation_ready(self):
identifier = work.license_pools[0].identifier
self.handler.process_identifier(identifier, identifier.urn)
assert [] == self.handler.precomposed_entries
- assert ([(work.presentation_edition.primary_identifier, work)] ==
- self.handler.works)
+ assert [
+ (work.presentation_edition.primary_identifier, work)
+ ] == self.handler.works
-class TestURNLookupController(DatabaseTest):
+class TestURNLookupController(DatabaseTest):
def setup_method(self):
super(TestURNLookupController, self).setup_method()
self.controller = URNLookupController(self._db)
# Set up a mock Flask app for testing the controller methods.
app = Flask(__name__)
- @app.route('/lookup')
+
+ @app.route("/lookup")
def lookup(self, urn):
pass
- @app.route('/work')
+
+ @app.route("/work")
def work(self, urn):
pass
@@ -207,8 +175,9 @@ def test_work_lookup(self):
# We got an OPDS feed that includes an entry for the work.
assert 200 == response.status_code
- assert (OPDSFeed.ACQUISITION_FEED_TYPE ==
- response.headers['Content-Type'])
+ assert (
+ OPDSFeed.ACQUISITION_FEED_TYPE == response.headers["Content-Type"]
+ )
response_data = response.data.decode("utf8")
assert identifier.urn in response_data
assert 1 == response_data.count(work.title)
@@ -219,6 +188,7 @@ def test_process_urns_problem_detail(self):
class Mock(URNLookupController):
def process_urns(self, urns, **kwargs):
return INVALID_INPUT
+
controller = Mock(self._db)
with self.app.test_request_context("/?urn=foobar"):
response = controller.work_lookup(annotator=object())
@@ -234,15 +204,13 @@ def test_permalink(self):
# We got an OPDS feed that includes an entry for the work.
assert 200 == response.status_code
- assert (OPDSFeed.ACQUISITION_FEED_TYPE ==
- response.headers['Content-Type'])
+ assert OPDSFeed.ACQUISITION_FEED_TYPE == response.headers["Content-Type"]
response_data = response.data.decode("utf8")
assert identifier.urn in response_data
assert work.title in response_data
class TestComplaintController(DatabaseTest):
-
def setup_method(self):
super(TestComplaintController, self).setup_method()
self.controller = ComplaintController()
@@ -253,25 +221,26 @@ def setup_method(self):
def test_no_license_pool(self):
with self.app.test_request_context("/"):
response = self.controller.register(None, "{}")
- assert response.status.startswith('400')
+ assert response.status.startswith("400")
body = json.loads(response.data.decode("utf8"))
- assert "No license pool specified" == body['title']
+ assert "No license pool specified" == body["title"]
def test_invalid_document(self):
with self.app.test_request_context("/"):
response = self.controller.register(self.pool, "not {a} valid document")
- assert response.status.startswith('400')
+ assert response.status.startswith("400")
body = json.loads(response.data.decode("utf8"))
- assert "Invalid problem detail document" == body['title']
+ assert "Invalid problem detail document" == body["title"]
def test_invalid_type(self):
data = json.dumps({"type": "http://not-a-recognized-type/"})
with self.app.test_request_context("/"):
response = self.controller.register(self.pool, data)
- assert response.status.startswith('400')
+ assert response.status.startswith("400")
body = json.loads(response.data.decode("utf8"))
- assert ("Unrecognized problem type: http://not-a-recognized-type/" ==
- body['title'])
+ assert (
+ "Unrecognized problem type: http://not-a-recognized-type/" == body["title"]
+ )
def test_success(self):
data = json.dumps(
@@ -283,14 +252,13 @@ def test_success(self):
)
with self.app.test_request_context("/"):
response = self.controller.register(self.pool, data)
- assert response.status.startswith('201')
+ assert response.status.startswith("201")
[complaint] = self.pool.complaints
assert "foo" == complaint.source
assert "bar" == complaint.detail
class TestLoadMethods(DatabaseTest):
-
def setup_method(self):
super(TestLoadMethods, self).setup_method()
self.app = Flask(__name__)
@@ -298,12 +266,11 @@ def setup_method(self):
def test_load_facets_from_request(self):
# The library has two EntryPoints enabled.
- self._default_library.setting(EntryPoint.ENABLED_SETTING).value = (
- json.dumps([EbooksEntryPoint.INTERNAL_NAME,
- AudiobooksEntryPoint.INTERNAL_NAME])
+ self._default_library.setting(EntryPoint.ENABLED_SETTING).value = json.dumps(
+ [EbooksEntryPoint.INTERNAL_NAME, AudiobooksEntryPoint.INTERNAL_NAME]
)
- with self.app.test_request_context('/?order=%s' % Facets.ORDER_TITLE):
+ with self.app.test_request_context("/?order=%s" % Facets.ORDER_TITLE):
flask.request.library = self._default_library
facets = load_facets_from_request()
assert Facets.ORDER_TITLE == facets.order
@@ -311,7 +278,7 @@ def test_load_facets_from_request(self):
# in case the load method received a custom config.
assert facets.facets_enabled_at_init != None
- with self.app.test_request_context('/?order=bad_facet'):
+ with self.app.test_request_context("/?order=bad_facet"):
flask.request.library = self._default_library
problemdetail = load_facets_from_request()
assert INVALID_INPUT.uri == problemdetail.uri
@@ -321,7 +288,7 @@ def test_load_facets_from_request(self):
# configured on the present library.
worklist = WorkList()
worklist.initialize(self._default_library)
- with self.app.test_request_context('/?entrypoint=Audio'):
+ with self.app.test_request_context("/?entrypoint=Audio"):
flask.request.library = self._default_library
facets = load_facets_from_request(worklist=worklist)
assert AudiobooksEntryPoint == facets.entrypoint
@@ -329,9 +296,9 @@ def test_load_facets_from_request(self):
# If the requested EntryPoint not configured, the default
# EntryPoint is used.
- with self.app.test_request_context('/?entrypoint=NoSuchEntryPoint'):
+ with self.app.test_request_context("/?entrypoint=NoSuchEntryPoint"):
flask.request.library = self._default_library
- default_entrypoint=object()
+ default_entrypoint = object()
facets = load_facets_from_request(
worklist=worklist, default_entrypoint=default_entrypoint
)
@@ -340,32 +307,31 @@ def test_load_facets_from_request(self):
# Load a SearchFacets object that pulls information from an
# HTTP header.
- with self.app.test_request_context(
- '/', headers = {'Accept-Language' : 'ja' }
- ):
+ with self.app.test_request_context("/", headers={"Accept-Language": "ja"}):
flask.request.library = self._default_library
facets = load_facets_from_request(base_class=SearchFacets)
- assert ['jpn'] == facets.languages
+ assert ["jpn"] == facets.languages
def test_load_facets_from_request_class_instantiation(self):
"""The caller of load_facets_from_request() can specify a class other
than Facets to call from_request() on.
"""
+
class MockFacets(object):
@classmethod
def from_request(*args, **kwargs):
facets = MockFacets()
facets.called_with = kwargs
return facets
- kwargs = dict(some_arg='some value')
- with self.app.test_request_context(''):
+
+ kwargs = dict(some_arg="some value")
+ with self.app.test_request_context(""):
flask.request.library = self._default_library
facets = load_facets_from_request(
- None, None, base_class=MockFacets,
- base_class_constructor_kwargs=kwargs
+ None, None, base_class=MockFacets, base_class_constructor_kwargs=kwargs
)
assert isinstance(facets, MockFacets)
- assert 'some value' == facets.called_with['some_arg']
+ assert "some value" == facets.called_with["some_arg"]
def test_load_pagination_from_request(self):
# Verify that load_pagination_from_request insantiates a
@@ -379,28 +345,27 @@ def from_request(cls, get_arg, default_size, **kwargs):
cls.called_with = (get_arg, default_size, kwargs)
return "I'm a pagination object!"
- with self.app.test_request_context('/'):
+ with self.app.test_request_context("/"):
# Call load_pagination_from_request and verify that
# Mock.from_request was called with the arguments we expect.
- extra_kwargs = dict(extra='kwarg')
+ extra_kwargs = dict(extra="kwarg")
pagination = load_pagination_from_request(
- base_class=Mock, base_class_constructor_kwargs=extra_kwargs,
- default_size=44
+ base_class=Mock,
+ base_class_constructor_kwargs=extra_kwargs,
+ default_size=44,
)
assert "I'm a pagination object!" == pagination
- assert ((flask.request.args.get, 44, extra_kwargs) ==
- Mock.called_with)
+ assert (flask.request.args.get, 44, extra_kwargs) == Mock.called_with
# If no default size is specified, we trust from_request to
# use the class default.
- with self.app.test_request_context('/'):
+ with self.app.test_request_context("/"):
pagination = load_pagination_from_request(base_class=Mock)
- assert ((flask.request.args.get, None, {}) ==
- Mock.called_with)
+ assert (flask.request.args.get, None, {}) == Mock.called_with
# Now try a real case using the default pagination class,
# Pagination
- with self.app.test_request_context('/?size=50&after=10'):
+ with self.app.test_request_context("/?size=50&after=10"):
pagination = load_pagination_from_request()
assert isinstance(pagination, Pagination)
assert 50 == pagination.size
@@ -418,12 +383,11 @@ class CanBeProblemDetailDocument(Exception):
def as_problem_detail_document(self, debug):
return INVALID_URN.detailed(
_("detail info"),
- debug_message="A debug_message which should only appear in debug mode."
+ debug_message="A debug_message which should only appear in debug mode.",
)
class TestErrorHandler(DatabaseTest):
-
def setup_method(self):
super(TestErrorHandler, self).setup_method()
@@ -433,6 +397,7 @@ class MockManager(object):
This gives ErrorHandler access to a database connection.
"""
+
_db = self._db
self.app = Flask(__name__)
@@ -453,7 +418,7 @@ def raise_exception(self, cls=Exception):
def test_unhandled_error(self):
handler = ErrorHandler(self.app)
- with self.app.test_request_context('/'):
+ with self.app.test_request_context("/"):
response = None
try:
self.raise_exception()
@@ -462,26 +427,24 @@ def test_unhandled_error(self):
assert 500 == response.status_code
assert "An internal error occured" == response.data.decode("utf8")
-
def test_unhandled_error_debug(self):
# Set the sitewide log level to DEBUG to get a stack trace
# instead of a generic error message.
handler = ErrorHandler(self.app)
self.activate_debug_mode()
- with self.app.test_request_context('/'):
+ with self.app.test_request_context("/"):
response = None
try:
self.raise_exception()
except Exception as exception:
response = handler.handle(exception)
assert 500 == response.status_code
- assert response.data.startswith(b'Traceback (most recent call last)')
-
+ assert response.data.startswith(b"Traceback (most recent call last)")
def test_handle_error_as_problem_detail_document(self):
handler = ErrorHandler(self.app)
- with self.app.test_request_context('/'):
+ with self.app.test_request_context("/"):
try:
self.raise_exception(CanBeProblemDetailDocument)
except Exception as exception:
@@ -489,18 +452,18 @@ def test_handle_error_as_problem_detail_document(self):
assert 400 == response.status_code
data = json.loads(response.data.decode("utf8"))
- assert INVALID_URN.title == data['title']
+ assert INVALID_URN.title == data["title"]
# Since we are not in debug mode, the debug_message is
# destroyed.
- assert 'debug_message' not in data
+ assert "debug_message" not in data
def test_handle_error_as_problem_detail_document_debug(self):
# When in debug mode, the debug_message is preserved and a
# stack trace is appended to it.
handler = ErrorHandler(self.app)
self.activate_debug_mode()
- with self.app.test_request_context('/'):
+ with self.app.test_request_context("/"):
try:
self.raise_exception(CanBeProblemDetailDocument)
except Exception as exception:
@@ -508,10 +471,10 @@ def test_handle_error_as_problem_detail_document_debug(self):
assert 400 == response.status_code
data = json.loads(response.data.decode("utf8"))
- assert INVALID_URN.title == data['title']
- assert data['debug_message'].startswith(
+ assert INVALID_URN.title == data["title"]
+ assert data["debug_message"].startswith(
"A debug_message which should only appear in debug mode.\n\n"
- 'Traceback (most recent call last)'
+ "Traceback (most recent call last)"
)
@@ -528,13 +491,13 @@ def test_compressible(self):
value = b"Compress me! (Or not.)"
buffer = BytesIO()
- gzipped = gzip.GzipFile(mode='wb', fileobj=buffer)
+ gzipped = gzip.GzipFile(mode="wb", fileobj=buffer)
gzipped.write(value)
gzipped.close()
compressed = buffer.getvalue()
# Spot-check the compressed value
- assert b'-(J-.V' in compressed
+ assert b"-(J-.V" in compressed
# This compressible controller function always returns the
# same value.
@@ -542,7 +505,7 @@ def test_compressible(self):
def function():
return value
- def ask_for_compression(compression, header='Accept-Encoding'):
+ def ask_for_compression(compression, header="Accept-Encoding"):
"""This context manager simulates the entire Flask
request-response cycle, including a call to
process_response(), which triggers the @after_this_request
@@ -562,22 +525,22 @@ def ask_for_compression(compression, header='Accept-Encoding'):
# representation is compressed.
response = ask_for_compression("gzip")
assert compressed == response.data
- assert "gzip" == response.headers['Content-Encoding']
+ assert "gzip" == response.headers["Content-Encoding"]
# If the client doesn't ask for compression, the value is
# passed through unchanged.
response = ask_for_compression(None)
assert value == response.data
- assert 'Content-Encoding' not in response.headers
+ assert "Content-Encoding" not in response.headers
# Similarly if the client asks for an unsupported compression
# mechanism.
- response = ask_for_compression('compress')
+ response = ask_for_compression("compress")
assert value == response.data
- assert 'Content-Encoding' not in response.headers
+ assert "Content-Encoding" not in response.headers
# Or if the client asks for a compression mechanism through
# Accept-Transfer-Encoding, which is currently unsupported.
response = ask_for_compression("gzip", "Accept-Transfer-Encoding")
assert value == response.data
- assert 'Content-Encoding' not in response.headers
+ assert "Content-Encoding" not in response.headers
diff --git a/tests/test_authentication_for_opds.py b/tests/test_authentication_for_opds.py
index 7de0dc6ed..1bdafcb50 100644
--- a/tests/test_authentication_for_opds.py
+++ b/tests/test_authentication_for_opds.py
@@ -1,22 +1,26 @@
import pytest
-from ..util.authentication_for_opds import (
- AuthenticationForOPDSDocument as Doc,
- OPDSAuthenticationFlow as Flow,
-)
+
+from ..util.authentication_for_opds import AuthenticationForOPDSDocument as Doc
+from ..util.authentication_for_opds import OPDSAuthenticationFlow as Flow
+
class MockFlow(Flow):
"""A mock OPDSAuthenticationFlow that sets `type` in to_dict()"""
+
def __init__(self, description):
- self.description=description
+ self.description = description
def _authentication_flow_document(self, argument):
- return { "description": self.description,
- "arg": argument,
- "type" : "http://mock1/"}
+ return {
+ "description": self.description,
+ "arg": argument,
+ "type": "http://mock1/",
+ }
class MockFlowWithURI(Flow):
"""A mock OPDSAuthenticationFlow that sets URI."""
+
FLOW_TYPE = "http://mock2/"
def _authentication_flow_document(self, argument):
@@ -28,22 +32,23 @@ class MockFlowWithoutType(Flow):
Calling authentication_flow_document() on this object will fail.
"""
+
def _authentication_flow_document(self, argument):
return {}
class TestOPDSAuthenticationFlow(object):
-
def test_flow_sets_type_at_runtime(self):
"""An OPDSAuthenticationFlow object can set `type` during
to_dict().
"""
flow = MockFlow("description")
doc = flow.authentication_flow_document("argument")
- assert (
- {'type': 'http://mock1/', 'description': 'description',
- 'arg': 'argument'} ==
- doc)
+ assert {
+ "type": "http://mock1/",
+ "description": "description",
+ "arg": "argument",
+ } == doc
def test_flow_gets_type_from_uri(self):
"""An OPDSAuthenticationFlow object can define the class variableURI
@@ -51,49 +56,41 @@ def test_flow_gets_type_from_uri(self):
"""
flow = MockFlowWithURI()
doc = flow.authentication_flow_document("argument")
- assert {'type': 'http://mock2/'} == doc
+ assert {"type": "http://mock2/"} == doc
def test_flow_must_define_type(self):
"""An OPDSAuthenticationFlow object must get a value for `type`
_somehow_, or authentication_flow_document() will fail.
"""
flow = MockFlowWithoutType()
- pytest.raises(
- ValueError, flow.authentication_flow_document, 'argument'
- )
+ pytest.raises(ValueError, flow.authentication_flow_document, "argument")
class TestAuthenticationForOPDSDocument(object):
-
def test_good_document(self):
- """Verify that to_dict() works when all the data is in place.
- """
+ """Verify that to_dict() works when all the data is in place."""
doc_obj = Doc(
id="id",
title="title",
authentication_flows=[MockFlow("hello")],
- links=[
- dict(rel="register", href="http://registration/")
- ]
+ links=[dict(rel="register", href="http://registration/")],
)
doc = doc_obj.to_dict("argument")
- assert (
- {'id': 'id',
- 'title': 'title',
- 'authentication': [
- {'arg': 'argument',
- 'description': 'hello',
- 'type': 'http://mock1/'}
- ],
- 'links': [{'href': 'http://registration/', 'rel': 'register'}],
- } ==
- doc)
+ assert {
+ "id": "id",
+ "title": "title",
+ "authentication": [
+ {"arg": "argument", "description": "hello", "type": "http://mock1/"}
+ ],
+ "links": [{"href": "http://registration/", "rel": "register"}],
+ } == doc
def test_bad_document(self):
"""Test that to_dict() raises ValueError when something is
wrong with the data.
"""
+
def cannot_make(document):
pytest.raises(ValueError, document.to_dict, object())
@@ -102,23 +99,35 @@ def cannot_make(document):
cannot_make(Doc(id="no title", title=None))
# authentication_flows and links must both be lists.
- cannot_make(Doc(id="id", title="title",
- authentication_flows="not a list"))
- cannot_make(Doc(id="id", title="title",
- authentication_flows=["a list"],
- links="not a list"))
+ cannot_make(Doc(id="id", title="title", authentication_flows="not a list"))
+ cannot_make(
+ Doc(
+ id="id",
+ title="title",
+ authentication_flows=["a list"],
+ links="not a list",
+ )
+ )
# A link must be a dict.
- cannot_make(Doc(id="id", title="title",
- authentication_flows=[],
- links=["not a dict"]))
+ cannot_make(
+ Doc(id="id", title="title", authentication_flows=[], links=["not a dict"])
+ )
# A link must have a rel and an href.
- cannot_make(Doc(id="id", title="title",
- authentication_flows=[],
- links=[{"rel": "no href"}]))
- cannot_make(Doc(id="id", title="title",
- authentication_flows=[],
- links=[{"href": "no rel"}]))
-
-
+ cannot_make(
+ Doc(
+ id="id",
+ title="title",
+ authentication_flows=[],
+ links=[{"rel": "no href"}],
+ )
+ )
+ cannot_make(
+ Doc(
+ id="id",
+ title="title",
+ authentication_flows=[],
+ links=[{"href": "no rel"}],
+ )
+ )
diff --git a/tests/test_cdn.py b/tests/test_cdn.py
index 968294f91..2ce5f3d44 100644
--- a/tests/test_cdn.py
+++ b/tests/test_cdn.py
@@ -1,13 +1,11 @@
# encoding: utf-8
-from ..testing import DatabaseTest
-
+from ..cdn import cdnify
from ..config import Configuration, temp_config
from ..model import ExternalIntegration
-from ..cdn import cdnify
+from ..testing import DatabaseTest
class TestCDN(DatabaseTest):
-
def unchanged(self, url, cdns):
self.ceq(url, url, cdns)
@@ -24,21 +22,21 @@ def test_no_cdns(self):
def test_non_matching_cdn(self):
url = "http://foo.com/bar"
- self.unchanged(url, {"bar.com" : "cdn.com"})
+ self.unchanged(url, {"bar.com": "cdn.com"})
def test_matching_cdn(self):
url = "http://foo.com/bar#baz"
- self.ceq("https://cdn.org/bar#baz", url,
- {"foo.com" : "https://cdn.org",
- "bar.com" : "http://cdn2.net/"}
+ self.ceq(
+ "https://cdn.org/bar#baz",
+ url,
+ {"foo.com": "https://cdn.org", "bar.com": "http://cdn2.net/"},
)
def test_relative_url(self):
# By default, relative URLs are untouched.
url = "/groups/"
- self.unchanged(url, {"bar.com" : "cdn.com"})
+ self.unchanged(url, {"bar.com": "cdn.com"})
# But if the CDN list has an entry for the empty string, that
# URL is used for relative URLs.
- self.ceq("https://cdn.org/groups/", url,
- {"" : "https://cdn.org/"})
+ self.ceq("https://cdn.org/groups/", url, {"": "https://cdn.org/"})
diff --git a/tests/test_circulation_data.py b/tests/test_circulation_data.py
index 5d18a9884..5a220fec7 100644
--- a/tests/test_circulation_data.py
+++ b/tests/test_circulation_data.py
@@ -1,7 +1,7 @@
-import pytest
-
-from copy import deepcopy
import datetime
+from copy import deepcopy
+
+import pytest
from ..metadata_layer import (
CirculationData,
@@ -14,7 +14,6 @@
ReplacementPolicy,
SubjectData,
)
-
from ..model import (
Collection,
DataSource,
@@ -27,15 +26,12 @@
Subject,
)
from ..model.configuration import ExternalIntegrationLink
-from ..testing import (
- DatabaseTest,
- DummyHTTPClient,
-)
from ..s3 import MockS3Uploader
+from ..testing import DatabaseTest, DummyHTTPClient
from ..util.datetime_helpers import utc_now
-class TestCirculationData(DatabaseTest):
+class TestCirculationData(DatabaseTest):
def test_circulationdata_may_require_collection(self):
"""Depending on the information provided in a CirculationData
object, it might or might not be possible to call apply()
@@ -44,13 +40,12 @@ def test_circulationdata_may_require_collection(self):
identifier = IdentifierData(Identifier.OVERDRIVE_ID, "1")
format = FormatData(
- Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM,
- rights_uri=RightsStatus.IN_COPYRIGHT
+ Representation.EPUB_MEDIA_TYPE,
+ DeliveryMechanism.NO_DRM,
+ rights_uri=RightsStatus.IN_COPYRIGHT,
)
circdata = CirculationData(
- DataSource.OVERDRIVE,
- primary_identifier=identifier,
- formats=[format]
+ DataSource.OVERDRIVE, primary_identifier=identifier, formats=[format]
)
circdata.apply(self._db, collection=None)
@@ -71,7 +66,10 @@ def test_circulationdata_may_require_collection(self):
circdata.licenses_owned = 0
with pytest.raises(ValueError) as excinfo:
circdata.apply(self._db, collection=None)
- assert 'Cannot store circulation information because no Collection was provided.' in str(excinfo.value)
+ assert (
+ "Cannot store circulation information because no Collection was provided."
+ in str(excinfo.value)
+ )
def test_circulationdata_can_be_deepcopied(self):
# Check that we didn't put something in the CirculationData that
@@ -101,18 +99,20 @@ def test_circulationdata_can_be_deepcopied(self):
# If deepcopy didn't throw an exception we're ok.
assert circulation_data_copy is not None
-
def test_links_filtered(self):
# Tests that passed-in links filter down to only the relevant ones.
link1 = LinkData(Hyperlink.OPEN_ACCESS_DOWNLOAD, "example.epub")
link2 = LinkData(rel=Hyperlink.IMAGE, href="http://example.com/")
link3 = LinkData(rel=Hyperlink.DESCRIPTION, content="foo")
link4 = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href="http://thumbnail.com/",
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href="http://thumbnail.com/",
media_type=Representation.JPEG_MEDIA_TYPE,
)
link5 = LinkData(
- rel=Hyperlink.IMAGE, href="http://example.com/", thumbnail=link4,
+ rel=Hyperlink.IMAGE,
+ href="http://example.com/",
+ thumbnail=link4,
media_type=Representation.JPEG_MEDIA_TYPE,
)
links = [link1, link2, link3, link4, link5]
@@ -124,11 +124,10 @@ def test_links_filtered(self):
links=links,
)
- filtered_links = sorted(circulation_data.links, key=lambda x:x.rel)
+ filtered_links = sorted(circulation_data.links, key=lambda x: x.rel)
assert [link1] == filtered_links
-
def test_explicit_formatdata(self):
# Creating an edition with an open-access download will
# automatically create a delivery mechanism.
@@ -147,8 +146,9 @@ def test_explicit_formatdata(self):
)
circulation_data.apply(self._db, pool.collection)
- [epub, pdf] = sorted(pool.delivery_mechanisms,
- key=lambda x: x.delivery_mechanism.content_type)
+ [epub, pdf] = sorted(
+ pool.delivery_mechanisms, key=lambda x: x.delivery_mechanism.content_type
+ )
assert epub.resource == pool.best_open_access_resource
assert Representation.PDF_MEDIA_TYPE == pdf.delivery_mechanism.content_type
@@ -157,8 +157,8 @@ def test_explicit_formatdata(self):
# If we tell Metadata to replace the list of formats, we only
# have the one format we manually created.
replace = ReplacementPolicy(
- formats=True,
- )
+ formats=True,
+ )
circulation_data.apply(self._db, pool.collection, replace=replace)
[pdf] = pool.delivery_mechanisms
assert Representation.PDF_MEDIA_TYPE == pdf.delivery_mechanism.content_type
@@ -171,8 +171,11 @@ def test_apply_removes_old_formats_based_on_replacement_policy(self):
self._db.delete(lpdm)
old_lpdm = pool.set_delivery_mechanism(
- Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM,
- RightsStatus.IN_COPYRIGHT, None)
+ Representation.PDF_MEDIA_TYPE,
+ DeliveryMechanism.ADOBE_DRM,
+ RightsStatus.IN_COPYRIGHT,
+ None,
+ )
# And it has been loaned.
patron = self._patron()
@@ -196,8 +199,11 @@ def test_apply_removes_old_formats_based_on_replacement_policy(self):
circulation_data.apply(self._db, pool.collection, replacement_policy)
assert 2 == len(pool.delivery_mechanisms)
- assert (set([Representation.PDF_MEDIA_TYPE, Representation.EPUB_MEDIA_TYPE]) ==
- set([lpdm.delivery_mechanism.content_type for lpdm in pool.delivery_mechanisms]))
+ assert set(
+ [Representation.PDF_MEDIA_TYPE, Representation.EPUB_MEDIA_TYPE]
+ ) == set(
+ [lpdm.delivery_mechanism.content_type for lpdm in pool.delivery_mechanisms]
+ )
assert old_lpdm == loan.fulfillment
# But if we make formats true in the policy, we'll delete the old format
@@ -206,7 +212,10 @@ def test_apply_removes_old_formats_based_on_replacement_policy(self):
circulation_data.apply(self._db, pool.collection, replacement_policy)
assert 1 == len(pool.delivery_mechanisms)
- assert Representation.EPUB_MEDIA_TYPE == pool.delivery_mechanisms[0].delivery_mechanism.content_type
+ assert (
+ Representation.EPUB_MEDIA_TYPE
+ == pool.delivery_mechanisms[0].delivery_mechanism.content_type
+ )
assert None == loan.fulfillment
def test_apply_adds_new_licenses(self):
@@ -214,7 +223,9 @@ def test_apply_adds_new_licenses(self):
# Start with one license for this pool.
old_license = self._license(
- pool, expires=None, remaining_checkouts=2,
+ pool,
+ expires=None,
+ remaining_checkouts=2,
concurrent_checkouts=3,
)
@@ -229,7 +240,9 @@ def test_apply_adds_new_licenses(self):
checkout_url="https://borrow2",
status_url="https://status2",
expires=(utc_now() + datetime.timedelta(days=7)),
- remaining_checkouts=None, concurrent_checkouts=1)
+ remaining_checkouts=None,
+ concurrent_checkouts=1,
+ )
circulation_data = CirculationData(
licenses=[license_data],
@@ -243,8 +256,9 @@ def test_apply_adds_new_licenses(self):
self._db.commit()
assert 2 == len(pool.licenses)
- assert (set([old_license.identifier, license_data.identifier]) ==
- set([license.identifier for license in pool.licenses]))
+ assert set([old_license.identifier, license_data.identifier]) == set(
+ [license.identifier for license in pool.licenses]
+ )
assert old_license == loan.license
def test_apply_creates_work_and_presentation_edition_if_needed(self):
@@ -291,9 +305,7 @@ def test_license_pool_sets_default_license_values(self):
formats=[drm_format],
)
collection = self._default_collection
- pool, is_new = circulation.license_pool(
- self._db, collection
- )
+ pool, is_new = circulation.license_pool(self._db, collection)
assert True == is_new
assert collection == pool.collection
@@ -317,11 +329,10 @@ def test_implicit_format_for_open_access_link(self):
assert Representation.EPUB_MEDIA_TYPE == epub.delivery_mechanism.content_type
assert DeliveryMechanism.ADOBE_DRM == epub.delivery_mechanism.drm_scheme
-
link = LinkData(
rel=Hyperlink.OPEN_ACCESS_DOWNLOAD,
media_type=Representation.PDF_MEDIA_TYPE,
- href=self._url
+ href=self._url,
)
circulation_data = CirculationData(
data_source=DataSource.GUTENBERG,
@@ -330,8 +341,8 @@ def test_implicit_format_for_open_access_link(self):
)
replace = ReplacementPolicy(
- formats=True,
- )
+ formats=True,
+ )
circulation_data.apply(self._db, pool.collection, replace)
# We destroyed the default delivery format and added a new,
@@ -343,12 +354,12 @@ def test_implicit_format_for_open_access_link(self):
circulation_data = CirculationData(
data_source=DataSource.GUTENBERG,
primary_identifier=edition.primary_identifier,
- links=[]
+ links=[],
)
replace = ReplacementPolicy(
- formats=True,
- links=True,
- )
+ formats=True,
+ links=True,
+ )
circulation_data.apply(self._db, pool.collection, replace)
# Now we have no formats at all.
@@ -362,13 +373,13 @@ def test_rights_status_default_rights_passed_in(self):
link = LinkData(
rel=Hyperlink.DRM_ENCRYPTED_DOWNLOAD,
media_type=Representation.EPUB_MEDIA_TYPE,
- href=self._url
+ href=self._url,
)
circulation_data = CirculationData(
data_source=DataSource.OA_CONTENT_SERVER,
primary_identifier=identifier,
- default_rights_uri = RightsStatus.CC_BY,
+ default_rights_uri=RightsStatus.CC_BY,
links=[link],
)
@@ -376,9 +387,7 @@ def test_rights_status_default_rights_passed_in(self):
formats=True,
)
- pool, ignore = circulation_data.license_pool(
- self._db, self._default_collection
- )
+ pool, ignore = circulation_data.license_pool(self._db, self._default_collection)
circulation_data.apply(self._db, pool.collection, replace)
assert True == pool.open_access
assert 1 == len(pool.delivery_mechanisms)
@@ -393,7 +402,7 @@ def test_rights_status_default_rights_from_data_source(self):
link = LinkData(
rel=Hyperlink.DRM_ENCRYPTED_DOWNLOAD,
media_type=Representation.EPUB_MEDIA_TYPE,
- href=self._url
+ href=self._url,
)
circulation_data = CirculationData(
@@ -407,9 +416,7 @@ def test_rights_status_default_rights_from_data_source(self):
)
# This pool starts off as not being open-access.
- pool, ignore = circulation_data.license_pool(
- self._db, self._default_collection
- )
+ pool, ignore = circulation_data.license_pool(self._db, self._default_collection)
assert False == pool.open_access
circulation_data.apply(self._db, pool.collection, replace)
@@ -419,7 +426,10 @@ def test_rights_status_default_rights_from_data_source(self):
assert True == pool.open_access
assert 1 == len(pool.delivery_mechanisms)
# The rights status is the default for the OA content server.
- assert RightsStatus.GENERIC_OPEN_ACCESS == pool.delivery_mechanisms[0].rights_status.uri
+ assert (
+ RightsStatus.GENERIC_OPEN_ACCESS
+ == pool.delivery_mechanisms[0].rights_status.uri
+ )
def test_rights_status_open_access_link_no_rights_uses_data_source_default(self):
identifier = IdentifierData(
@@ -432,7 +442,7 @@ def test_rights_status_open_access_link_no_rights_uses_data_source_default(self)
link = LinkData(
rel=Hyperlink.OPEN_ACCESS_DOWNLOAD,
media_type=Representation.EPUB_MEDIA_TYPE,
- href=self._url
+ href=self._url,
)
circulation_data = CirculationData(
data_source=DataSource.GUTENBERG,
@@ -443,9 +453,7 @@ def test_rights_status_open_access_link_no_rights_uses_data_source_default(self)
formats=True,
)
- pool, ignore = circulation_data.license_pool(
- self._db, self._default_collection
- )
+ pool, ignore = circulation_data.license_pool(self._db, self._default_collection)
pool.open_access = False
# Applying this CirculationData to a LicensePool makes it
@@ -456,7 +464,10 @@ def test_rights_status_open_access_link_no_rights_uses_data_source_default(self)
# The delivery mechanism's rights status is the default for
# the data source.
- assert RightsStatus.PUBLIC_DOMAIN_USA == pool.delivery_mechanisms[0].rights_status.uri
+ assert (
+ RightsStatus.PUBLIC_DOMAIN_USA
+ == pool.delivery_mechanisms[0].rights_status.uri
+ )
# Even if a commercial source like Overdrive should offer a
# link with rel="open access", unless we know it's an
@@ -469,7 +480,7 @@ def test_rights_status_open_access_link_no_rights_uses_data_source_default(self)
link = LinkData(
rel=Hyperlink.OPEN_ACCESS_DOWNLOAD,
media_type=Representation.EPUB_MEDIA_TYPE,
- href=self._url
+ href=self._url,
)
circulation_data = CirculationData(
@@ -478,13 +489,12 @@ def test_rights_status_open_access_link_no_rights_uses_data_source_default(self)
links=[link],
)
- pool, ignore = circulation_data.license_pool(
- self._db, self._default_collection
- )
+ pool, ignore = circulation_data.license_pool(self._db, self._default_collection)
pool.open_access = False
circulation_data.apply(self._db, pool.collection, replace_formats)
- assert (RightsStatus.IN_COPYRIGHT ==
- pool.delivery_mechanisms[0].rights_status.uri)
+ assert (
+ RightsStatus.IN_COPYRIGHT == pool.delivery_mechanisms[0].rights_status.uri
+ )
assert False == pool.open_access
@@ -509,9 +519,7 @@ def test_rights_status_open_access_link_with_rights(self):
formats=True,
)
- pool, ignore = circulation_data.license_pool(
- self._db, self._default_collection
- )
+ pool, ignore = circulation_data.license_pool(self._db, self._default_collection)
circulation_data.apply(self._db, pool.collection, replace)
assert True == pool.open_access
assert 1 == len(pool.delivery_mechanisms)
@@ -546,13 +554,13 @@ def test_rights_status_commercial_link_with_rights(self):
formats=True,
)
- pool, ignore = circulation_data.license_pool(
- self._db, self._default_collection
- )
+ pool, ignore = circulation_data.license_pool(self._db, self._default_collection)
circulation_data.apply(self._db, pool.collection, replace)
assert False == pool.open_access
assert 1 == len(pool.delivery_mechanisms)
- assert RightsStatus.IN_COPYRIGHT == pool.delivery_mechanisms[0].rights_status.uri
+ assert (
+ RightsStatus.IN_COPYRIGHT == pool.delivery_mechanisms[0].rights_status.uri
+ )
def test_format_change_may_change_open_access_status(self):
@@ -579,24 +587,21 @@ def test_format_change_may_change_open_access_status(self):
)
# Applying this information turns the pool into an open-access pool.
- circulation_data.apply(
- self._db, pool.collection, replace=replace_formats
- )
+ circulation_data.apply(self._db, pool.collection, replace=replace_formats)
assert True == pool.open_access
# Then we find out it was a mistake -- the book is in copyright.
format = FormatData(
- Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM,
- rights_uri=RightsStatus.IN_COPYRIGHT
+ Representation.EPUB_MEDIA_TYPE,
+ DeliveryMechanism.NO_DRM,
+ rights_uri=RightsStatus.IN_COPYRIGHT,
)
circulation_data = CirculationData(
data_source=pool.data_source,
primary_identifier=pool.identifier,
- formats=[format]
- )
- circulation_data.apply(
- self._db, pool.collection, replace=replace_formats
+ formats=[format],
)
+ circulation_data.apply(self._db, pool.collection, replace=replace_formats)
# The original LPDM has been removed and only the new one remains.
assert False == pool.open_access
@@ -604,43 +609,44 @@ def test_format_change_may_change_open_access_status(self):
class TestMetaToModelUtility(DatabaseTest):
-
def test_open_access_content_mirrored(self):
# Make sure that open access material links are translated to our S3 buckets, and that
# commercial material links are left as is.
# Note: Mirroring tests passing does not guarantee that all code now
# correctly calls on CirculationData, as well as Metadata. This is a risk.
- mirrors = dict(books_mirror=MockS3Uploader(),covers_mirror=None)
+ mirrors = dict(books_mirror=MockS3Uploader(), covers_mirror=None)
mirror_type = ExternalIntegrationLink.OPEN_ACCESS_BOOKS
# Here's a book.
edition, pool = self._edition(with_license_pool=True)
# Here's a link to the content of the book, which will be mirrored.
link_mirrored = LinkData(
- rel=Hyperlink.OPEN_ACCESS_DOWNLOAD, href="http://example.com/",
+ rel=Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ href="http://example.com/",
media_type=Representation.EPUB_MEDIA_TYPE,
- content="i am a tiny book"
+ content="i am a tiny book",
)
# This link will not be mirrored.
link_unmirrored = LinkData(
- rel=Hyperlink.DRM_ENCRYPTED_DOWNLOAD, href="http://example.com/2",
+ rel=Hyperlink.DRM_ENCRYPTED_DOWNLOAD,
+ href="http://example.com/2",
media_type=Representation.EPUB_MEDIA_TYPE,
- content="i am a pricy book"
+ content="i am a pricy book",
)
# Apply the metadata.
policy = ReplacementPolicy(mirrors=mirrors)
- metadata = Metadata(data_source=edition.data_source,
- links=[link_mirrored, link_unmirrored],
- )
+ metadata = Metadata(
+ data_source=edition.data_source,
+ links=[link_mirrored, link_unmirrored],
+ )
metadata.apply(edition, pool.collection, replace=policy)
# make sure the refactor is done right, and metadata does not upload
assert 0 == len(mirrors[mirror_type].uploaded)
-
circulation_data = CirculationData(
data_source=edition.data_source,
primary_identifier=edition.primary_identifier,
@@ -655,30 +661,29 @@ def test_open_access_content_mirrored(self):
[book] = mirrors[mirror_type].uploaded
# It's remained an open-access link.
- assert (
- [Hyperlink.OPEN_ACCESS_DOWNLOAD] ==
- [x.rel for x in book.resource.links])
-
+ assert [Hyperlink.OPEN_ACCESS_DOWNLOAD] == [x.rel for x in book.resource.links]
# It's been 'mirrored' to the appropriate S3 bucket.
- assert book.mirror_url.startswith('https://test-content-bucket.s3.amazonaws.com/')
- expect = '/%s/%s.epub' % (
- edition.primary_identifier.identifier,
- edition.title
+ assert book.mirror_url.startswith(
+ "https://test-content-bucket.s3.amazonaws.com/"
)
+ expect = "/%s/%s.epub" % (edition.primary_identifier.identifier, edition.title)
assert book.mirror_url.endswith(expect)
# make sure the mirrored link is safely on edition
sorted_edition_links = sorted(pool.identifier.links, key=lambda x: x.rel)
- unmirrored_representation, mirrored_representation = [edlink.resource.representation for edlink in sorted_edition_links]
- assert mirrored_representation.mirror_url.startswith('https://test-content-bucket.s3.amazonaws.com/')
+ unmirrored_representation, mirrored_representation = [
+ edlink.resource.representation for edlink in sorted_edition_links
+ ]
+ assert mirrored_representation.mirror_url.startswith(
+ "https://test-content-bucket.s3.amazonaws.com/"
+ )
# make sure the unmirrored link is safely on edition
- assert 'http://example.com/2' == unmirrored_representation.url
+ assert "http://example.com/2" == unmirrored_representation.url
# make sure the unmirrored link has not been translated to an S3 URL
assert None == unmirrored_representation.mirror_url
-
def test_mirror_open_access_link_fetch_failure(self):
mirrors = dict(books_mirror=MockS3Uploader())
h = DummyHTTPClient()
@@ -699,8 +704,11 @@ def test_mirror_open_access_link_fetch_failure(self):
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content,
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
h.queue_response(403)
@@ -721,9 +729,8 @@ def test_mirror_open_access_link_fetch_failure(self):
assert True == pool.suppressed
assert representation.fetch_exception in pool.license_exception
-
def test_mirror_open_access_link_mirror_failure(self):
- mirrors = dict(books_mirror=MockS3Uploader(fail=True),covers_mirror=None)
+ mirrors = dict(books_mirror=MockS3Uploader(fail=True), covers_mirror=None)
h = DummyHTTPClient()
edition, pool = self._edition(with_license_pool=True)
@@ -743,8 +750,11 @@ def test_mirror_open_access_link_mirror_failure(self):
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
h.queue_response(200, media_type=Representation.EPUB_MEDIA_TYPE)
@@ -827,4 +837,3 @@ def test_availability_needs_update(self):
pool.last_checked = now
assert True == recent_data._availability_needs_update(pool)
assert False == old_data._availability_needs_update(pool)
-
diff --git a/tests/test_config.py b/tests/test_config.py
index 9a3b0b3cd..b6c272140 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,19 +1,18 @@
import os
+
from sqlalchemy.orm.session import Session
+from ..config import Configuration as BaseConfiguration
+from ..model import ConfigurationSetting, ExternalIntegration
from ..testing import DatabaseTest
-from ..config import Configuration as BaseConfiguration
-from ..model import (
- ConfigurationSetting,
- ExternalIntegration,
-)
# Create a configuration object that the tests can run against without
# impacting the real configuration object.
class MockConfiguration(BaseConfiguration):
instance = None
+
class TestConfiguration(DatabaseTest):
Conf = MockConfiguration
@@ -31,7 +30,7 @@ def teardown_method(self):
super(TestConfiguration, self).teardown_method()
def create_version_file(self, content):
- with open(self.VERSION_FILENAME, 'w') as f:
+ with open(self.VERSION_FILENAME, "w") as f:
f.write(content)
def test_app_version(self):
@@ -41,40 +40,37 @@ def test_app_version(self):
result = self.Conf.app_version()
assert self.Conf.APP_VERSION in self.Conf.instance
assert self.Conf.NO_APP_VERSION_FOUND == result
- assert (
- self.Conf.NO_APP_VERSION_FOUND ==
- self.Conf.get(self.Conf.APP_VERSION))
+ assert self.Conf.NO_APP_VERSION_FOUND == self.Conf.get(self.Conf.APP_VERSION)
# An empty .version file yields the same results.
self.Conf.instance = dict()
- self.create_version_file(' \n')
+ self.create_version_file(" \n")
result = self.Conf.app_version()
assert self.Conf.NO_APP_VERSION_FOUND == result
- assert (
- self.Conf.NO_APP_VERSION_FOUND ==
- self.Conf.get(self.Conf.APP_VERSION))
+ assert self.Conf.NO_APP_VERSION_FOUND == self.Conf.get(self.Conf.APP_VERSION)
# A .version file with content loads the content.
self.Conf.instance = dict()
- self.create_version_file('ba.na.na')
+ self.create_version_file("ba.na.na")
result = self.Conf.app_version()
- assert 'ba.na.na' == result
- assert 'ba.na.na' == self.Conf.get(self.Conf.APP_VERSION)
+ assert "ba.na.na" == result
+ assert "ba.na.na" == self.Conf.get(self.Conf.APP_VERSION)
def test_load_cdns(self):
- """Test our ability to load CDN configuration from the database.
- """
+ """Test our ability to load CDN configuration from the database."""
self._external_integration(
protocol=ExternalIntegration.CDN,
goal=ExternalIntegration.CDN_GOAL,
- settings = { self.Conf.CDN_MIRRORED_DOMAIN_KEY : "site.com",
- ExternalIntegration.URL : "http://cdn/" }
+ settings={
+ self.Conf.CDN_MIRRORED_DOMAIN_KEY: "site.com",
+ ExternalIntegration.URL: "http://cdn/",
+ },
)
self.Conf.load_cdns(self._db)
integrations = self.Conf.instance[self.Conf.INTEGRATIONS]
- assert {'site.com' : 'http://cdn/'} == integrations[ExternalIntegration.CDN]
+ assert {"site.com": "http://cdn/"} == integrations[ExternalIntegration.CDN]
assert True == self.Conf.instance[self.Conf.CDNS_LOADED_FROM_DATABASE]
def test_cdns_loaded_dynamically(self):
diff --git a/tests/test_coverage.py b/tests/test_coverage.py
index 93e20e31a..9201eb665 100644
--- a/tests/test_coverage.py
+++ b/tests/test_coverage.py
@@ -1,20 +1,30 @@
import datetime
+
import pytest
-from ..testing import (
- DatabaseTest
+
+from ..coverage import (
+ BaseCoverageProvider,
+ BibliographicCoverageProvider,
+ CatalogCoverageProvider,
+ CollectionCoverageProvider,
+ CoverageFailure,
+ CoverageProviderProgress,
+ IdentifierCoverageProvider,
+ MARCRecordWorkCoverageProvider,
+ OPDSEntryWorkCoverageProvider,
+ PresentationReadyWorkCoverageProvider,
+ WorkClassificationCoverageProvider,
+ WorkPresentationEditionCoverageProvider,
)
-from ..testing import (
- AlwaysSuccessfulBibliographicCoverageProvider,
- AlwaysSuccessfulCollectionCoverageProvider,
- AlwaysSuccessfulCoverageProvider,
- AlwaysSuccessfulWorkCoverageProvider,
- DummyHTTPClient,
- TaskIgnoringCoverageProvider,
- NeverSuccessfulBibliographicCoverageProvider,
- NeverSuccessfulWorkCoverageProvider,
- NeverSuccessfulCoverageProvider,
- TransientFailureCoverageProvider,
- TransientFailureWorkCoverageProvider,
+from ..metadata_layer import (
+ CirculationData,
+ ContributorData,
+ FormatData,
+ IdentifierData,
+ LinkData,
+ Metadata,
+ ReplacementPolicy,
+ SubjectData,
)
from ..model import (
Collection,
@@ -36,33 +46,24 @@
WorkCoverageRecord,
)
from ..model.configuration import ExternalIntegrationLink
-from ..metadata_layer import (
- Metadata,
- CirculationData,
- FormatData,
- IdentifierData,
- ContributorData,
- LinkData,
- ReplacementPolicy,
- SubjectData,
-)
from ..s3 import MockS3Uploader
-from ..coverage import (
- BaseCoverageProvider,
- BibliographicCoverageProvider,
- CatalogCoverageProvider,
- CollectionCoverageProvider,
- CoverageFailure,
- CoverageProviderProgress,
- IdentifierCoverageProvider,
- OPDSEntryWorkCoverageProvider,
- MARCRecordWorkCoverageProvider,
- PresentationReadyWorkCoverageProvider,
- WorkClassificationCoverageProvider,
- WorkPresentationEditionCoverageProvider,
+from ..testing import (
+ AlwaysSuccessfulBibliographicCoverageProvider,
+ AlwaysSuccessfulCollectionCoverageProvider,
+ AlwaysSuccessfulCoverageProvider,
+ AlwaysSuccessfulWorkCoverageProvider,
+ DatabaseTest,
+ DummyHTTPClient,
+ NeverSuccessfulBibliographicCoverageProvider,
+ NeverSuccessfulCoverageProvider,
+ NeverSuccessfulWorkCoverageProvider,
+ TaskIgnoringCoverageProvider,
+ TransientFailureCoverageProvider,
+ TransientFailureWorkCoverageProvider,
)
from ..util.datetime_helpers import datetime_utc, utc_now
+
class TestCoverageFailure(DatabaseTest):
"""Test the CoverageFailure class."""
@@ -91,9 +92,7 @@ def test_to_coverage_record(self):
def test_to_work_coverage_record(self):
work = self._work()
- transient_failure = CoverageFailure(
- work, "Bah!", transient=True
- )
+ transient_failure = CoverageFailure(work, "Bah!", transient=True)
rec = transient_failure.to_work_coverage_record("the_operation")
assert isinstance(rec, WorkCoverageRecord)
assert work == rec.work
@@ -101,18 +100,13 @@ def test_to_work_coverage_record(self):
assert CoverageRecord.TRANSIENT_FAILURE == rec.status
assert "Bah!" == rec.exception
- persistent_failure = CoverageFailure(
- work, "Bah forever!", transient=False
- )
- rec = persistent_failure.to_work_coverage_record(
- operation="the_operation"
- )
+ persistent_failure = CoverageFailure(work, "Bah forever!", transient=False)
+ rec = persistent_failure.to_work_coverage_record(operation="the_operation")
assert CoverageRecord.PERSISTENT_FAILURE == rec.status
assert "Bah forever!" == rec.exception
class TestCoverageProviderProgress(object):
-
def test_achievements(self):
progress = CoverageProviderProgress()
progress.successes = 1
@@ -132,39 +126,35 @@ class CoverageProviderTest(DatabaseTest):
def bibliographic_data(self):
return Metadata(
DataSource.OVERDRIVE,
- publisher='Perfection Learning',
- language='eng',
- title='A Girl Named Disaster',
+ publisher="Perfection Learning",
+ language="eng",
+ title="A Girl Named Disaster",
published=datetime_utc(1998, 3, 1, 0, 0),
primary_identifier=IdentifierData(
type=Identifier.OVERDRIVE_ID,
- identifier='ba9b3419-b0bd-4ca7-a24f-26c4246b6b44'
+ identifier="ba9b3419-b0bd-4ca7-a24f-26c4246b6b44",
),
- identifiers = [
+ identifiers=[
IdentifierData(
- type=Identifier.OVERDRIVE_ID,
- identifier='ba9b3419-b0bd-4ca7-a24f-26c4246b6b44'
- ),
- IdentifierData(type=Identifier.ISBN, identifier='9781402550805')
+ type=Identifier.OVERDRIVE_ID,
+ identifier="ba9b3419-b0bd-4ca7-a24f-26c4246b6b44",
+ ),
+ IdentifierData(type=Identifier.ISBN, identifier="9781402550805"),
],
- contributors = [
- ContributorData(sort_name="Nancy Farmer",
- roles=[Contributor.PRIMARY_AUTHOR_ROLE])
+ contributors=[
+ ContributorData(
+ sort_name="Nancy Farmer", roles=[Contributor.PRIMARY_AUTHOR_ROLE]
+ )
],
- subjects = [
- SubjectData(type=Subject.TOPIC,
- identifier='Action & Adventure'),
- SubjectData(type=Subject.FREEFORM_AUDIENCE,
- identifier='Young Adult'),
- SubjectData(type=Subject.PLACE, identifier='Africa')
+ subjects=[
+ SubjectData(type=Subject.TOPIC, identifier="Action & Adventure"),
+ SubjectData(type=Subject.FREEFORM_AUDIENCE, identifier="Young Adult"),
+ SubjectData(type=Subject.PLACE, identifier="Africa"),
],
)
-
-
class TestBaseCoverageProvider(CoverageProviderTest):
-
def test_instantiation(self):
"""Verify variable initialization."""
@@ -173,7 +163,7 @@ class ValidMock(BaseCoverageProvider):
OPERATION = "An Operation"
DEFAULT_BATCH_SIZE = 50
- now = cutoff_time=utc_now()
+ now = cutoff_time = utc_now()
provider = ValidMock(self._db, cutoff_time=now)
# Class variables defined in subclasses become appropriate
@@ -197,6 +187,7 @@ class NoServiceName(BaseCoverageProvider):
def test_run(self):
"""Verify that run() calls run_once_and_update_timestamp()."""
+
class MockProvider(BaseCoverageProvider):
SERVICE_NAME = "I do nothing"
was_run = False
@@ -231,12 +222,14 @@ class MockProvider(BaseCoverageProvider):
"""A BaseCoverageProvider that returns a strange
CoverageProviderProgress representing the work it did.
"""
+
SERVICE_NAME = "I do nothing"
was_run = False
custom_timestamp_data = CoverageProviderProgress(
start=start, finish=finish, counter=counter
)
+
def run_once_and_update_timestamp(self):
return self.custom_timestamp_data
@@ -257,6 +250,7 @@ def test_run_once_and_update_timestamp(self):
"""Test that run_once_and_update_timestamp calls run_once until all
the work is done, and then updates a Timestamp.
"""
+
class MockProvider(BaseCoverageProvider):
SERVICE_NAME = "I do nothing"
run_once_calls = []
@@ -300,8 +294,9 @@ def run_once(self, progress, count_as_covered=None):
# We start with no Timestamp.
service_name = "I do nothing"
service_type = Timestamp.COVERAGE_PROVIDER_TYPE
- timestamp = Timestamp.value(self._db, service_name, service_type,
- collection=None)
+ timestamp = Timestamp.value(
+ self._db, service_name, service_type, collection=None
+ )
assert None == timestamp
# Instantiate the Provider, and call
@@ -401,22 +396,18 @@ def test_run_once(self):
# We previously tried to cover one of them, but got a
# transient failure.
self._coverage_record(
- transient, data_source,
- status=CoverageRecord.TRANSIENT_FAILURE
+ transient, data_source, status=CoverageRecord.TRANSIENT_FAILURE
)
# Another of the four has a persistent failure.
self._coverage_record(
- persistent, data_source,
- status=CoverageRecord.PERSISTENT_FAILURE
+ persistent, data_source, status=CoverageRecord.PERSISTENT_FAILURE
)
# The third one has no coverage record at all.
# And the fourth one has been successfully covered.
- self._coverage_record(
- covered, data_source, status=CoverageRecord.SUCCESS
- )
+ self._coverage_record(covered, data_source, status=CoverageRecord.SUCCESS)
# Now let's run the coverage provider. Every Identifier
# that's covered will succeed, so the question is which ones
@@ -451,19 +442,17 @@ def test_run_once(self):
# Nothing happened to the identifier that had a persistent
# failure or the identifier that was successfully covered.
- assert ([CoverageRecord.PERSISTENT_FAILURE] ==
- [x.status for x in persistent.coverage_records])
- assert ([CoverageRecord.SUCCESS] ==
- [x.status for x in covered.coverage_records])
+ assert [CoverageRecord.PERSISTENT_FAILURE] == [
+ x.status for x in persistent.coverage_records
+ ]
+ assert [CoverageRecord.SUCCESS] == [x.status for x in covered.coverage_records]
assert persistent not in provider.attempts
assert covered not in provider.attempts
# We can change which identifiers get processed by changing
# what counts as 'coverage'.
- result = provider.run_once(
- progress, count_as_covered=[CoverageRecord.SUCCESS]
- )
+ result = provider.run_once(progress, count_as_covered=[CoverageRecord.SUCCESS])
assert progress == result
assert 0 == progress.offset
@@ -488,7 +477,6 @@ def test_run_once(self):
assert 4 == progress.offset
def test_run_once_records_successes_and_failures(self):
-
class Mock(AlwaysSuccessfulCoverageProvider):
def process_batch_and_handle_results(self, batch):
# Simulate 1 success, 2 transient failures,
@@ -515,8 +503,9 @@ def process_batch_and_handle_results(self, batch):
assert 3 == progress.persistent_failures
assert (
- "Items processed: 6. Successes: 1, transient failures: 2, persistent failures: 3" ==
- progress.achievements)
+ "Items processed: 6. Successes: 1, transient failures: 2, persistent failures: 3"
+ == progress.achievements
+ )
def test_process_batch_and_handle_results(self):
"""Test that process_batch_and_handle_results passes the identifiers
@@ -530,7 +519,7 @@ def test_process_batch_and_handle_results(self):
i2 = e2.primary_identifier
class MockProvider(AlwaysSuccessfulCoverageProvider):
- OPERATION = 'i succeed'
+ OPERATION = "i succeed"
def finalize_batch(self):
self.finalized = True
@@ -554,7 +543,7 @@ def finalize_batch(self):
assert set([i1, i2]) == set([x.identifier for x in successes])
# ...and with the coverage provider's operation.
- assert ['i succeed'] * 2 == [x.operation for x in successes]
+ assert ["i succeed"] * 2 == [x.operation for x in successes]
# Now try a different CoverageProvider which creates transient
# failures.
@@ -562,47 +551,53 @@ class MockProvider(TransientFailureCoverageProvider):
OPERATION = "i fail transiently"
transient_failure_provider = MockProvider(self._db)
- counts, failures = transient_failure_provider.process_batch_and_handle_results(batch)
+ counts, failures = transient_failure_provider.process_batch_and_handle_results(
+ batch
+ )
# Two transient failures.
assert (0, 2, 0) == counts
# New coverage records were added to track the transient
# failures.
- assert ([CoverageRecord.TRANSIENT_FAILURE] * 2 ==
- [x.status for x in failures])
+ assert [CoverageRecord.TRANSIENT_FAILURE] * 2 == [x.status for x in failures]
assert ["i fail transiently"] * 2 == [x.operation for x in failures]
# Another way of getting transient failures is to just ignore every
# item you're told to process.
class MockProvider(TaskIgnoringCoverageProvider):
OPERATION = "i ignore"
+
task_ignoring_provider = MockProvider(self._db)
counts, records = task_ignoring_provider.process_batch_and_handle_results(batch)
assert (0, 2, 0) == counts
- assert ([CoverageRecord.TRANSIENT_FAILURE] * 2 ==
- [x.status for x in records])
+ assert [CoverageRecord.TRANSIENT_FAILURE] * 2 == [x.status for x in records]
assert ["i ignore"] * 2 == [x.operation for x in records]
# If a transient failure becomes a success, the it won't have
# an exception anymore.
- assert ['Was ignored by CoverageProvider.'] * 2 == [x.exception for x in records]
+ assert ["Was ignored by CoverageProvider."] * 2 == [
+ x.exception for x in records
+ ]
records = success_provider.process_batch_and_handle_results(batch)[1]
assert [None, None] == [x.exception for x in records]
# Or you can go really bad and have persistent failures.
class MockProvider(NeverSuccessfulCoverageProvider):
OPERATION = "i will always fail"
+
persistent_failure_provider = MockProvider(self._db)
- counts, results = persistent_failure_provider.process_batch_and_handle_results(batch)
+ counts, results = persistent_failure_provider.process_batch_and_handle_results(
+ batch
+ )
# Two persistent failures.
assert (0, 0, 2) == counts
assert all([isinstance(x, CoverageRecord) for x in results])
- assert (["What did you expect?", "What did you expect?"] ==
- [x.exception for x in results])
- assert ([CoverageRecord.PERSISTENT_FAILURE] * 2 ==
- [x.status for x in results])
+ assert ["What did you expect?", "What did you expect?"] == [
+ x.exception for x in results
+ ]
+ assert [CoverageRecord.PERSISTENT_FAILURE] * 2 == [x.status for x in results]
assert ["i will always fail"] * 2 == [x.operation for x in results]
def test_process_batch(self):
@@ -649,18 +644,14 @@ def test_should_update(self):
ask if a CoverageRecord needs to be updated.
"""
cutoff = datetime_utc(2016, 1, 1)
- provider = AlwaysSuccessfulCoverageProvider(
- self._db, cutoff_time = cutoff
- )
+ provider = AlwaysSuccessfulCoverageProvider(self._db, cutoff_time=cutoff)
identifier = self._identifier()
# If coverage is missing, we should update.
assert True == provider.should_update(None)
# If coverage is outdated, we should update.
- record, ignore = CoverageRecord.add_for(
- identifier, provider.data_source
- )
+ record, ignore = CoverageRecord.add_for(identifier, provider.data_source)
record.timestamp = datetime_utc(2015, 1, 1)
assert True == provider.should_update(record)
@@ -674,7 +665,6 @@ def test_should_update(self):
class TestIdentifierCoverageProvider(CoverageProviderTest):
-
def setup_method(self):
super(TestIdentifierCoverageProvider, self).setup_method()
self.identifier = self._identifier()
@@ -691,28 +681,37 @@ class Base(IdentifierCoverageProvider):
class MockProvider(Base):
INPUT_IDENTIFIER_TYPES = None
+
provider = MockProvider(self._db)
assert None == provider.input_identifier_types
# It's okay to set a single value.
class MockProvider(Base):
INPUT_IDENTIFIER_TYPES = Identifier.ISBN
+
provider = MockProvider(self._db)
assert [Identifier.ISBN] == provider.input_identifier_types
# It's okay to set a list of values.
class MockProvider(Base):
INPUT_IDENTIFIER_TYPES = [Identifier.ISBN, Identifier.OVERDRIVE_ID]
+
provider = MockProvider(self._db)
- assert ([Identifier.ISBN, Identifier.OVERDRIVE_ID] ==
- provider.input_identifier_types)
+ assert [
+ Identifier.ISBN,
+ Identifier.OVERDRIVE_ID,
+ ] == provider.input_identifier_types
# It's not okay to do nothing.
class MockProvider(Base):
pass
+
with pytest.raises(ValueError) as excinfo:
MockProvider(self._db)
- assert "MockProvider must define INPUT_IDENTIFIER_TYPES, even if the value is None." in str(excinfo.value)
+ assert (
+ "MockProvider must define INPUT_IDENTIFIER_TYPES, even if the value is None."
+ in str(excinfo.value)
+ )
def test_can_cover(self):
"""Verify that can_cover gives the correct answer when
@@ -727,9 +726,7 @@ def test_can_cover(self):
assert True == m(identifier)
# This provider handles ISBNs.
- provider.input_identifier_types = [
- Identifier.OVERDRIVE_ID, Identifier.ISBN
- ]
+ provider.input_identifier_types = [Identifier.OVERDRIVE_ID, Identifier.ISBN]
assert True == m(identifier)
# This provider doesn't.
@@ -745,9 +742,7 @@ def test_replacement_policy(self):
assert False == provider.replacement_policy.formats
policy = ReplacementPolicy.from_license_source(self._db)
- provider = AlwaysSuccessfulCoverageProvider(
- self._db, replacement_policy=policy
- )
+ provider = AlwaysSuccessfulCoverageProvider(self._db, replacement_policy=policy)
assert policy == provider.replacement_policy
def test_register(self):
@@ -782,9 +777,7 @@ def test_bulk_register(self):
i1 = self._identifier()
covered = self._identifier()
- existing = self._coverage_record(
- covered, source, operation=provider.OPERATION
- )
+ existing = self._coverage_record(covered, source, operation=provider.OPERATION)
new_records, ignored_identifiers = provider.bulk_register([i1, covered])
@@ -820,8 +813,10 @@ def test_bulk_register_with_collection(self):
# If a DataSource or data source name is provided and
# autocreate is set True, the record is created with that source.
provider.bulk_register(
- [self.identifier], data_source=collection.name,
- collection=collection, autocreate=True
+ [self.identifier],
+ data_source=collection.name,
+ collection=collection,
+ autocreate=True,
)
[record] = self.identifier.coverage_records
@@ -844,7 +839,7 @@ def test_bulk_register_with_collection(self):
)
records = self.identifier.coverage_records
assert 2 == len(records)
- assert [r for r in records if r.collection==collection]
+ assert [r for r in records if r.collection == collection]
finally:
# Return the mock class to its original state for other tests.
provider.COVERAGE_COUNTS_FOR_EVERY_COLLECTION = True
@@ -853,9 +848,7 @@ def test_ensure_coverage(self):
"""Verify that ensure_coverage creates a CoverageRecord for an
Identifier, assuming that the CoverageProvider succeeds.
"""
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
provider.OPERATION = self._str
record = provider.ensure_coverage(self.identifier)
assert isinstance(record, CoverageRecord)
@@ -877,11 +870,12 @@ def test_ensure_coverage(self):
# The coverage provider's timestamp was not updated, because
# we're using ensure_coverage on a single record.
- assert (None ==
- Timestamp.value(
- self._db, provider.service_name,
- Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
- ))
+ assert None == Timestamp.value(
+ self._db,
+ provider.service_name,
+ Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
+ )
# Now let's try a CollectionCoverageProvider that needs to
# grant coverage separately for every collection.
@@ -907,10 +901,12 @@ def test_ensure_coverage_respects_operation(self):
# Two providers with the same output source but different operations.
class Mock1(AlwaysSuccessfulCoverageProvider):
OPERATION = "foo"
+
provider1 = Mock1(self._db)
class Mock2(NeverSuccessfulCoverageProvider):
OPERATION = "bar"
+
provider2 = Mock2(self._db)
# Ensure coverage from both providers.
@@ -918,7 +914,7 @@ class Mock2(NeverSuccessfulCoverageProvider):
assert "foo" == coverage1.operation
old_timestamp = coverage1.timestamp
- coverage2 = provider2.ensure_coverage(self.identifier)
+ coverage2 = provider2.ensure_coverage(self.identifier)
assert "bar" == coverage2.operation
# There are now two CoverageRecords, one for each operation.
@@ -948,11 +944,12 @@ def test_ensure_coverage_persistent_coverage_failure(self):
# we're using ensure_coverage.
# The coverage provider's timestamp was not updated, because
# we're using ensure_coverage on a single record.
- assert (None ==
- Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
- ))
+ assert None == Timestamp.value(
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
+ )
def test_ensure_coverage_transient_coverage_failure(self):
@@ -963,11 +960,12 @@ def test_ensure_coverage_transient_coverage_failure(self):
assert "Oops!" == failure.exception
# Timestamp was not updated.
- assert (None ==
- Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
- ))
+ assert None == Timestamp.value(
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
+ )
def test_ensure_coverage_changes_status(self):
"""Verify that processing an item that has a preexisting
@@ -1039,7 +1037,7 @@ def test_set_metadata(self, bibliographic_data):
# It can't set circulation data, because it's not a
# CollectionCoverageProvider.
- assert not hasattr(provider, 'set_metadata_and_circulationdata')
+ assert not hasattr(provider, "set_metadata_and_circulationdata")
# But it can set metadata.
identifier = self._identifier(
@@ -1070,9 +1068,7 @@ def test_set_metadata(self, bibliographic_data):
assert "ValueError" in result.exception
def test_items_that_need_coverage_respects_registration_reqs(self):
- provider = AlwaysSuccessfulCoverageProvider(
- self._db, registered_only=True
- )
+ provider = AlwaysSuccessfulCoverageProvider(self._db, registered_only=True)
items = provider.items_that_need_coverage()
assert self.identifier not in items
@@ -1084,14 +1080,15 @@ def test_items_that_need_coverage_respects_registration_reqs(self):
# With a failing CoverageRecord, the item shows up.
[record] = self.identifier.coverage_records
record.status = CoverageRecord.TRANSIENT_FAILURE
- record.exception = 'Oh no!'
+ record.exception = "Oh no!"
assert self.identifier in items
def test_items_that_need_coverage_respects_operation(self):
# Here's a provider that carries out the 'foo' operation.
class Mock1(AlwaysSuccessfulCoverageProvider):
- OPERATION = 'foo'
+ OPERATION = "foo"
+
provider = Mock1(self._db)
# Here's a generic CoverageRecord for an identifier.
@@ -1118,7 +1115,7 @@ def test_run_on_specific_identifiers(self):
counts, records = provider.run_on_specific_identifiers(to_be_tested)
# Six identifiers were covered in two batches.
- assert (6,0,0) == counts
+ assert (6, 0, 0) == counts
assert 6 == len(records)
# Only the identifiers in to_be_tested were covered.
@@ -1138,15 +1135,15 @@ def test_run_on_specific_identifiers_respects_cutoff_time(self):
# ever run the coverage provider again we will get a
# persistent failure.
provider = NeverSuccessfulCoverageProvider(self._db)
- record, ignore = CoverageRecord.add_for(
- self.identifier, provider.data_source
- )
+ record, ignore = CoverageRecord.add_for(self.identifier, provider.data_source)
record.timestamp = last_run
# You might think this would result in a persistent failure...
- (success, transient_failure, persistent_failure), records = (
- provider.run_on_specific_identifiers([self.identifier])
- )
+ (
+ success,
+ transient_failure,
+ persistent_failure,
+ ), records = provider.run_on_specific_identifiers([self.identifier])
# ...but we get an automatic success. We didn't even try to
# run the coverage provider on self.identifier because the
@@ -1158,9 +1155,11 @@ def test_run_on_specific_identifiers_respects_cutoff_time(self):
# But if we move the cutoff time forward, the provider will run
# on self.identifier and fail.
provider.cutoff_time = datetime_utc(2016, 2, 1)
- (success, transient_failure, persistent_failure), records = (
- provider.run_on_specific_identifiers([self.identifier])
- )
+ (
+ success,
+ transient_failure,
+ persistent_failure,
+ ), records = provider.run_on_specific_identifiers([self.identifier])
assert 0 == success
assert 1 == persistent_failure
@@ -1178,11 +1177,12 @@ def test_run_never_successful(self):
# We start with no CoverageRecords and no Timestamp.
assert [] == self._db.query(CoverageRecord).all()
- assert (None ==
- Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
- ))
+ assert None == Timestamp.value(
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
+ )
provider.run()
@@ -1195,8 +1195,10 @@ def test_run_never_successful(self):
# But the coverage provider did run, and the timestamp is now set to
# a recent value.
value = Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
)
assert (utc_now() - value).total_seconds() < 1
@@ -1209,11 +1211,12 @@ def test_run_transient_failure(self):
# We start with no CoverageRecords and no Timestamp.
assert [] == self._db.query(CoverageRecord).all()
- assert (None ==
- Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
- ))
+ assert None == Timestamp.value(
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
+ )
now = utc_now()
provider.run()
@@ -1224,28 +1227,29 @@ def test_run_transient_failure(self):
# The timestamp was set.
timestamp = Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
)
- assert (timestamp-now).total_seconds() < 1
+ assert (timestamp - now).total_seconds() < 1
def test_add_coverage_record_for(self):
"""Calling CollectionCoverageProvider.add_coverage_record is the same
as calling CoverageRecord.add_for with the relevant
information.
"""
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
identifier = self._identifier()
record = provider.add_coverage_record_for(identifier)
# This is the same as calling CoverageRecord.add_for with
# appropriate arguments.
record2, is_new = CoverageRecord.add_for(
- identifier, data_source=provider.data_source,
+ identifier,
+ data_source=provider.data_source,
operation=provider.operation,
- collection=provider.collection_or_not
+ collection=provider.collection_or_not,
)
assert False == is_new
assert record == record2
@@ -1262,25 +1266,21 @@ def test_add_coverage_record_for(self):
assert self._default_collection == record.collection
record2, is_new = CoverageRecord.add_for(
- identifier, data_source=provider.data_source,
+ identifier,
+ data_source=provider.data_source,
operation=provider.operation,
- collection=provider.collection_or_not
+ collection=provider.collection_or_not,
)
assert False == is_new
assert record == record2
-
def test_record_failure_as_coverage_record(self):
"""TODO: We need test coverage here."""
def test_failure(self):
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
identifier = self._identifier()
- failure = provider.failure(
- identifier, error="an error", transient=False
- )
+ failure = provider.failure(identifier, error="an error", transient=False)
assert provider.data_source == failure.data_source
assert "an error" == failure.exception
assert False == failure.transient
@@ -1293,9 +1293,7 @@ def test_failure(self):
# will change that -- a failure will only count for the
# collection associated with the CoverageProvider.
provider.COVERAGE_COUNTS_FOR_EVERY_COLLECTION = False
- failure = provider.failure(
- identifier, error="an error", transient=False
- )
+ failure = provider.failure(identifier, error="an error", transient=False)
assert self._default_collection == failure.collection
def test_failure_for_ignored_item(self):
@@ -1312,7 +1310,6 @@ def test_failure_for_ignored_item(self):
class TestCollectionCoverageProvider(CoverageProviderTest):
-
@pytest.fixture
def circulation_data(self, bibliographic_data):
# This data is used to test the insertion of circulation data
@@ -1320,13 +1317,13 @@ def circulation_data(self, bibliographic_data):
return CirculationData(
DataSource.OVERDRIVE,
primary_identifier=bibliographic_data.primary_identifier,
- formats = [
+ formats=[
FormatData(
content_type=Representation.EPUB_MEDIA_TYPE,
drm_scheme=DeliveryMechanism.NO_DRM,
rights_uri=RightsStatus.IN_COPYRIGHT,
)
- ]
+ ],
)
def test_class_variables(self):
@@ -1340,24 +1337,30 @@ def test_class_variables(self):
def test_must_have_collection(self):
with pytest.raises(CollectionMissing) as excinfo:
AlwaysSuccessfulCollectionCoverageProvider(None)
- assert "AlwaysSuccessfulCollectionCoverageProvider must be instantiated with a Collection." in str(excinfo.value)
+ assert (
+ "AlwaysSuccessfulCollectionCoverageProvider must be instantiated with a Collection."
+ in str(excinfo.value)
+ )
def test_collection_protocol_must_match_class_protocol(self):
collection = self._collection(protocol=ExternalIntegration.OVERDRIVE)
with pytest.raises(ValueError) as excinfo:
AlwaysSuccessfulCollectionCoverageProvider(collection)
- assert "Collection protocol (Overdrive) does not match CoverageProvider protocol (OPDS Import)" in str(excinfo.value)
+ assert (
+ "Collection protocol (Overdrive) does not match CoverageProvider protocol (OPDS Import)"
+ in str(excinfo.value)
+ )
- def test_items_that_need_coverage_ignores_collection_when_collection_is_irrelevant(self):
+ def test_items_that_need_coverage_ignores_collection_when_collection_is_irrelevant(
+ self,
+ ):
# Two providers that do the same work, but one is associated
# with a collection and the other is not.
collection_provider = AlwaysSuccessfulCollectionCoverageProvider(
self._default_collection
)
- no_collection_provider = AlwaysSuccessfulCoverageProvider(
- self._db
- )
+ no_collection_provider = AlwaysSuccessfulCoverageProvider(self._db)
# This distinction is irrelevant because they both consider an
# Identifier covered when it has a CoverageRecord not
@@ -1365,8 +1368,7 @@ def test_items_that_need_coverage_ignores_collection_when_collection_is_irreleva
assert True == collection_provider.COVERAGE_COUNTS_FOR_EVERY_COLLECTION
assert True == no_collection_provider.COVERAGE_COUNTS_FOR_EVERY_COLLECTION
- assert (collection_provider.data_source ==
- no_collection_provider.data_source)
+ assert collection_provider.data_source == no_collection_provider.data_source
data_source = collection_provider.data_source
# Create a license pool belonging to the default collection.
@@ -1378,8 +1380,8 @@ def needs():
CoverageProviders.
"""
return tuple(
- p.items_that_need_coverage().all() for p in
- (collection_provider, no_collection_provider)
+ p.items_that_need_coverage().all()
+ for p in (collection_provider, no_collection_provider)
)
# We start out in the state where the identifier appears to need
@@ -1397,12 +1399,12 @@ def needs():
# Add coverage not associated with any collection, and both
# CoverageProviders consider it covered.
- self._coverage_record(
- identifier, data_source, collection=None
- )
+ self._coverage_record(identifier, data_source, collection=None)
assert ([], []) == needs()
- def test_items_that_need_coverage_respects_collection_when_collection_is_relevant(self):
+ def test_items_that_need_coverage_respects_collection_when_collection_is_relevant(
+ self,
+ ):
# Two providers that do the same work, but are associated
# with different collections.
@@ -1410,9 +1412,7 @@ def test_items_that_need_coverage_respects_collection_when_collection_is_relevan
self._default_collection
)
collection_2 = self._collection()
- collection_2_provider = AlwaysSuccessfulCollectionCoverageProvider(
- collection_2
- )
+ collection_2_provider = AlwaysSuccessfulCollectionCoverageProvider(collection_2)
# And one that does the same work but is not associated with
# any collection.
@@ -1425,8 +1425,7 @@ def test_items_that_need_coverage_respects_collection_when_collection_is_relevan
collection_2_provider.COVERAGE_COUNTS_FOR_EVERY_COLLECTION = False
no_collection_provider.COVERAGE_COUNTS_FOR_EVERY_COLLECTION = False
- assert (collection_1_provider.data_source ==
- collection_2_provider.data_source)
+ assert collection_1_provider.data_source == collection_2_provider.data_source
data_source = collection_1_provider.data_source
# Create a license pool belonging to the default collection so
@@ -1440,8 +1439,8 @@ def needs():
CoverageProviders.
"""
return tuple(
- p.items_that_need_coverage().all() for p in
- (collection_1_provider, no_collection_provider)
+ p.items_that_need_coverage().all()
+ for p in (collection_1_provider, no_collection_provider)
)
# We start out in the state where the identifier needs
@@ -1457,9 +1456,7 @@ def needs():
assert [] == collection_2_provider.items_that_need_coverage().all()
# Add coverage for an irrelevant collection, and nothing happens.
- self._coverage_record(
- identifier, data_source, collection=self._collection()
- )
+ self._coverage_record(identifier, data_source, collection=self._collection())
assert ([identifier], [identifier]) == needs()
# Add coverage for a relevant collection, and it's treated as
@@ -1479,9 +1476,7 @@ def test_replacement_policy(self):
"""Unless a different replacement policy is passed in, the
replacement policy is ReplacementPolicy.from_license_source().
"""
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
assert True == provider.replacement_policy.identifiers
assert True == provider.replacement_policy.formats
@@ -1519,33 +1514,34 @@ def test_set_circulationdata_errors(self):
"""Verify that errors when setting circulation data
are turned into CoverageFailure objects.
"""
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
identifier = self._identifier()
# No data.
failure = provider._set_circulationdata(identifier, None)
- assert ("Did not receive circulationdata from input source" ==
- failure.exception)
+ assert "Did not receive circulationdata from input source" == failure.exception
# No identifier in CirculationData.
empty = CirculationData(provider.data_source, primary_identifier=None)
failure = provider._set_circulationdata(identifier, empty)
- assert ("Identifier did not match CirculationData's primary identifier." ==
- failure.exception)
+ assert (
+ "Identifier did not match CirculationData's primary identifier."
+ == failure.exception
+ )
# Mismatched identifier in CirculationData.
- wrong = CirculationData(provider.data_source,
- primary_identifier=self._identifier())
+ wrong = CirculationData(
+ provider.data_source, primary_identifier=self._identifier()
+ )
failure = provider._set_circulationdata(identifier, empty)
- assert ("Identifier did not match CirculationData's primary identifier." ==
- failure.exception)
+ assert (
+ "Identifier did not match CirculationData's primary identifier."
+ == failure.exception
+ )
# Here, the data is okay, but the ReplacementPolicy is
# going to cause an error the first time we try to use it.
- correct = CirculationData(provider.data_source,
- identifier)
+ correct = CirculationData(provider.data_source, identifier)
provider.replacement_policy = object()
failure = provider._set_circulationdata(identifier, correct)
assert isinstance(failure, CoverageFailure)
@@ -1572,7 +1568,8 @@ def test_set_metadata_incorporates_replacement_policy(self):
# 'HTTP client'...
http = DummyHTTPClient()
http.queue_response(
- 200, content='I am an epub.',
+ 200,
+ content="I am an epub.",
media_type=Representation.EPUB_MEDIA_TYPE,
)
@@ -1588,7 +1585,7 @@ def __init__(self, *args, **kwargs):
def __getattr__(self, name):
self.tripped = True
- if name.startswith('equivalent_identifier_'):
+ if name.startswith("equivalent_identifier_"):
# These need to be numbers rather than booleans,
# but the exact number doesn't matter.
return 100
@@ -1598,7 +1595,7 @@ def __getattr__(self, name):
replacement_policy = ReplacementPolicy(
mirrors=mirrors,
http_get=http.do_get,
- presentation_calculation_policy=presentation_calculation_policy
+ presentation_calculation_policy=presentation_calculation_policy,
)
provider = AlwaysSuccessfulCollectionCoverageProvider(
@@ -1612,21 +1609,21 @@ def __getattr__(self, name):
# We get an error if the CirculationData's identifier is
# doesn't match what we pass in.
circulationdata = CirculationData(
- provider.data_source,
- primary_identifier=self._identifier(),
- links=[link]
+ provider.data_source, primary_identifier=self._identifier(), links=[link]
)
failure = provider.set_metadata_and_circulation_data(
identifier, metadata, circulationdata
)
- assert ("Identifier did not match CirculationData's primary identifier." ==
- failure.exception)
+ assert (
+ "Identifier did not match CirculationData's primary identifier."
+ == failure.exception
+ )
# Otherwise, the data is applied.
circulationdata = CirculationData(
provider.data_source,
primary_identifier=metadata.primary_identifier,
- links=[link]
+ links=[link],
)
provider.set_metadata_and_circulation_data(
@@ -1665,9 +1662,7 @@ def test_items_that_need_coverage(self):
# If we set the CoverageProvider's cutoff_time to the time of
# coverage, the Identifier is still treated as covered.
- provider = AlwaysSuccessfulCoverageProvider(
- self._db, cutoff_time=cutoff_time
- )
+ provider = AlwaysSuccessfulCoverageProvider(self._db, cutoff_time=cutoff_time)
assert [] == provider.items_that_need_coverage().all()
# But if we set the cutoff time to immediately after the time
@@ -1678,8 +1673,7 @@ def test_items_that_need_coverage(self):
)
# The identifier is treated as lacking coverage.
- assert ([identifier] ==
- provider.items_that_need_coverage().all())
+ assert [identifier] == provider.items_that_need_coverage().all()
def test_work(self):
"""Verify that a CollectionCoverageProvider can create a Work."""
@@ -1688,9 +1682,7 @@ def test_work(self):
# Here's a CollectionCoverageProvider that is associated
# with an OPDS import-style Collection.
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
# This CoverageProvider cannot create a Work for the given
# Identifier, because that would require creating a
@@ -1731,8 +1723,9 @@ def test_work(self):
identifier2 = self._identifier()
identifier.licensed_through = []
collection2 = self._collection()
- edition2 = self._edition(identifier_type=identifier2.type,
- identifier_id=identifier2.identifier)
+ edition2 = self._edition(
+ identifier_type=identifier2.type, identifier_id=identifier2.identifier
+ )
pool2 = self._licensepool(edition=edition2, collection=collection2)
work2 = provider.work(identifier, pool2)
assert work2 != work
@@ -1756,10 +1749,10 @@ def test_work(self):
# If a work exists but is not presentation-ready,
# CollectionCoverageProvider.work() will call calculate_work()
# in an attempt to fix it.
- edition.title = 'Finally a title'
+ edition.title = "Finally a title"
work2 = provider.work(pool.identifier, pool)
assert work2 == work
- assert 'Finally a title' == work.title
+ assert "Finally a title" == work.title
assert True == work.presentation_ready
# Once the work is presentation_ready, calling
@@ -1767,11 +1760,14 @@ def test_work(self):
# calculate_work() -- it will just return the work.
def explode():
raise Exception("don't call me!")
+
pool.calculate_work = explode
work2 = provider.work(pool.identifier, pool)
assert work2 == work
- def test_set_metadata_and_circulationdata(self, bibliographic_data, circulation_data):
+ def test_set_metadata_and_circulationdata(
+ self, bibliographic_data, circulation_data
+ ):
"""Verify that a CollectionCoverageProvider can set both
metadata (on an Edition) and circulation data (on a LicensePool).
"""
@@ -1789,17 +1785,17 @@ class OverdriveProvider(AlwaysSuccessfulCollectionCoverageProvider):
DATA_SOURCE_NAME = DataSource.OVERDRIVE
PROTOCOL = ExternalIntegration.OVERDRIVE
IDENTIFIER_TYPES = Identifier.OVERDRIVE_ID
+
collection = self._collection(protocol=ExternalIntegration.OVERDRIVE)
provider = OverdriveProvider(collection)
# We get a CoverageFailure if we don't pass in any data at all.
- result = provider.set_metadata_and_circulation_data(
- identifier, None, None
- )
+ result = provider.set_metadata_and_circulation_data(identifier, None, None)
assert isinstance(result, CoverageFailure)
assert (
- "Received neither metadata nor circulation data from input source" ==
- result.exception)
+ "Received neither metadata nor circulation data from input source"
+ == result.exception
+ )
# We get a CoverageFailure if no work can be created. In this
# case, that happens because the metadata doesn't provide a
@@ -1858,9 +1854,7 @@ def test_autocreate_licensepool(self):
"""
identifier = self._identifier()
assert [] == identifier.licensed_through
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
pool = provider.license_pool(identifier)
assert [pool] == identifier.licensed_through
assert pool.data_source == provider.data_source
@@ -1882,17 +1876,13 @@ def test_autocreate_licensepool(self):
# If a working pool already exists, it's returned and no new
# pool is created.
- same_pool = provider.license_pool(
- identifier, DataSource.INTERNAL_PROCESSING
- )
+ same_pool = provider.license_pool(identifier, DataSource.INTERNAL_PROCESSING)
assert same_pool == pool2
assert provider.data_source == same_pool.data_source
# A new pool is only created if no working pool can be found.
identifier2 = self._identifier()
- new_pool = provider.license_pool(
- identifier2, DataSource.INTERNAL_PROCESSING
- )
+ new_pool = provider.license_pool(identifier2, DataSource.INTERNAL_PROCESSING)
assert new_pool.data_source.name == DataSource.INTERNAL_PROCESSING
assert new_pool.identifier == identifier2
assert new_pool.collection == provider.collection
@@ -1902,9 +1892,7 @@ def test_set_presentation_ready(self):
as presentation-ready.
"""
identifier = self._identifier()
- provider = AlwaysSuccessfulCollectionCoverageProvider(
- self._default_collection
- )
+ provider = AlwaysSuccessfulCollectionCoverageProvider(self._default_collection)
# If there is no LicensePool for the Identifier,
# set_presentation_ready will not try to create one,
@@ -1918,14 +1906,13 @@ def test_set_presentation_ready(self):
# mark it presentation ready.
pool = provider.license_pool(identifier)
edition = provider.edition(identifier)
- edition.title = 'A title'
+ edition.title = "A title"
result = provider.set_presentation_ready(identifier)
assert result == identifier
assert True == pool.work.presentation_ready
class TestCatalogCoverageProvider(CoverageProviderTest):
-
def test_items_that_need_coverage(self):
c1 = self._collection()
@@ -1942,8 +1929,7 @@ def test_items_that_need_coverage(self):
# This Identifier is licensed through the Collection c1, but
# it's not in the catalog--catalogs are used for different
# things.
- edition, lp = self._edition(with_license_pool=True,
- collection=c1)
+ edition, lp = self._edition(with_license_pool=True, collection=c1)
# We have four identifiers, but only i1 shows up, because
# it's the only one in c1's catalog.
@@ -1961,9 +1947,7 @@ class TestBibliographicCoverageProvider(CoverageProviderTest):
def setup_method(self):
super(TestBibliographicCoverageProvider, self).setup_method()
- self.work = self._work(
- with_license_pool=True, with_open_access_download=True
- )
+ self.work = self._work(with_license_pool=True, with_open_access_download=True)
self.work.presentation_ready = False
[self.pool] = self.work.license_pools
self.identifier = self.pool.identifier
@@ -1971,9 +1955,7 @@ def setup_method(self):
def test_work_set_presentation_ready_on_success(self):
# When a Work is successfully run through a
# BibliographicCoverageProvider, it's set as presentation-ready.
- provider = AlwaysSuccessfulBibliographicCoverageProvider(
- self.pool.collection
- )
+ provider = AlwaysSuccessfulBibliographicCoverageProvider(self.pool.collection)
[result] = provider.process_batch([self.identifier])
assert result == self.identifier
assert True == self.work.presentation_ready
@@ -1986,19 +1968,15 @@ def test_work_set_presentation_ready_on_success(self):
assert True == self.work.presentation_ready
def test_failure_does_not_set_work_presentation_ready(self):
- """A Work is not set as presentation-ready except on success.
- """
+ """A Work is not set as presentation-ready except on success."""
- provider = NeverSuccessfulBibliographicCoverageProvider(
- self.pool.collection
- )
+ provider = NeverSuccessfulBibliographicCoverageProvider(self.pool.collection)
result = provider.ensure_coverage(self.identifier)
assert CoverageRecord.TRANSIENT_FAILURE == result.status
assert False == self.work.presentation_ready
class TestWorkCoverageProvider(DatabaseTest):
-
def setup_method(self):
super(TestWorkCoverageProvider, self).setup_method()
self.work = self._work()
@@ -2008,17 +1986,18 @@ class MockProvider(AlwaysSuccessfulWorkCoverageProvider):
OPERATION = "the_operation"
qu = self._db.query(WorkCoverageRecord).filter(
- WorkCoverageRecord.operation==MockProvider.OPERATION
+ WorkCoverageRecord.operation == MockProvider.OPERATION
)
provider = MockProvider(self._db)
# We start with no relevant WorkCoverageRecord and no Timestamp.
assert [] == qu.all()
- assert (None ==
- Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
- ))
+ assert None == Timestamp.value(
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
+ )
now = utc_now()
provider.run()
@@ -2030,45 +2009,52 @@ class MockProvider(AlwaysSuccessfulWorkCoverageProvider):
# The timestamp is now set.
timestamp = Timestamp.value(
- self._db, provider.service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
+ self._db,
+ provider.service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
)
- assert (timestamp-now).total_seconds() < 1
+ assert (timestamp - now).total_seconds() < 1
def test_transient_failure(self):
class MockProvider(TransientFailureWorkCoverageProvider):
OPERATION = "the_operation"
+
provider = MockProvider(self._db)
# We start with no relevant WorkCoverageRecords.
qu = self._db.query(WorkCoverageRecord).filter(
- WorkCoverageRecord.operation==provider.operation
+ WorkCoverageRecord.operation == provider.operation
)
assert [] == qu.all()
provider.run()
# We now have a CoverageRecord for the transient failure.
- [failure] = [x for x in self.work.coverage_records if
- x.operation==provider.operation]
+ [failure] = [
+ x for x in self.work.coverage_records if x.operation == provider.operation
+ ]
assert CoverageRecord.TRANSIENT_FAILURE == failure.status
# The timestamp is now set to a recent value.
service_name = "Never successful (transient, works) (the_operation)"
value = Timestamp.value(
- self._db, service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
+ self._db,
+ service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
)
- assert (utc_now()-value).total_seconds() < 2
+ assert (utc_now() - value).total_seconds() < 2
def test_persistent_failure(self):
class MockProvider(NeverSuccessfulWorkCoverageProvider):
OPERATION = "the_operation"
+
provider = MockProvider(self._db)
# We start with no relevant WorkCoverageRecords.
qu = self._db.query(WorkCoverageRecord).filter(
- WorkCoverageRecord.operation==provider.operation
+ WorkCoverageRecord.operation == provider.operation
)
assert [] == qu.all()
@@ -2082,10 +2068,12 @@ class MockProvider(NeverSuccessfulWorkCoverageProvider):
# The timestamp is now set to a recent value.
service_name = "Never successful (works) (the_operation)"
value = Timestamp.value(
- self._db, service_name,
- service_type=Timestamp.COVERAGE_PROVIDER_TYPE, collection=None
+ self._db,
+ service_name,
+ service_type=Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=None,
)
- assert (utc_now()-value).total_seconds() < 2
+ assert (utc_now() - value).total_seconds() < 2
def test_items_that_need_coverage(self):
# Here's a WorkCoverageProvider.
@@ -2114,8 +2102,7 @@ def test_items_that_need_coverage(self):
# WorkCoverageRecord was created, then that work starts
# showing up again as needing coverage.
provider.cutoff_time = record.timestamp + datetime.timedelta(seconds=1)
- assert (set([w2, w3]) ==
- set(provider.items_that_need_coverage([i2, i3]).all()))
+ assert set([w2, w3]) == set(provider.items_that_need_coverage([i2, i3]).all())
def test_failure_for_ignored_item(self):
class MockProvider(NeverSuccessfulWorkCoverageProvider):
@@ -2140,11 +2127,9 @@ def test_record_failure_as_coverage_record(self):
class TestPresentationReadyWorkCoverageProvider(DatabaseTest):
-
def test_items_that_need_coverage(self):
-
class Mock(PresentationReadyWorkCoverageProvider):
- SERVICE_NAME = 'mock'
+ SERVICE_NAME = "mock"
provider = Mock(self._db)
work = self._work()
@@ -2163,12 +2148,12 @@ class MockWork(object):
"""A Work-like object that keeps track of the policy that was used
to recalculate its presentation.
"""
+
def calculate_presentation(self, policy):
self.calculate_presentation_called_with = policy
class TestWorkPresentationEditionCoverageProvider(DatabaseTest):
-
def test_process_item(self):
work = MockWork()
provider = WorkPresentationEditionCoverageProvider(self._db)
@@ -2179,18 +2164,20 @@ def test_process_item(self):
# Verify that the policy is configured correctly. It does
# all the work that's not expensive.
assert all(
- [policy.choose_edition, policy.set_edition_metadata,
- policy.choose_cover, policy.regenerate_opds_entries,
- policy.update_search_index]
+ [
+ policy.choose_edition,
+ policy.set_edition_metadata,
+ policy.choose_cover,
+ policy.regenerate_opds_entries,
+ policy.update_search_index,
+ ]
)
assert not any(
- [policy.classify, policy.choose_summary,
- policy.calculate_quality]
+ [policy.classify, policy.choose_summary, policy.calculate_quality]
)
class TestWorkClassificationCoverageProvider(DatabaseTest):
-
def test_process_item(self):
work = MockWork()
provider = WorkClassificationCoverageProvider(self._db)
@@ -2200,46 +2187,51 @@ def test_process_item(self):
# work.
policy = work.calculate_presentation_called_with
assert all(
- [policy.choose_edition, policy.set_edition_metadata,
- policy.choose_cover, policy.regenerate_opds_entries,
- policy.update_search_index, policy.classify,
- policy.choose_summary, policy.calculate_quality]
+ [
+ policy.choose_edition,
+ policy.set_edition_metadata,
+ policy.choose_cover,
+ policy.regenerate_opds_entries,
+ policy.update_search_index,
+ policy.classify,
+ policy.choose_summary,
+ policy.calculate_quality,
+ ]
)
class TestOPDSEntryWorkCoverageProvider(DatabaseTest):
-
def test_run(self):
provider = OPDSEntryWorkCoverageProvider(self._db)
work = self._work()
- work.simple_opds_entry = 'old junk'
- work.verbose_opds_entry = 'old long junk'
+ work.simple_opds_entry = "old junk"
+ work.verbose_opds_entry = "old long junk"
# The work is not presentation-ready, so nothing happens.
provider.run()
- assert 'old junk' == work.simple_opds_entry
- assert 'old long junk' == work.verbose_opds_entry
+ assert "old junk" == work.simple_opds_entry
+ assert "old long junk" == work.verbose_opds_entry
# The work is presentation-ready, so its OPDS entries are
# regenerated.
work.presentation_ready = True
provider.run()
- assert work.simple_opds_entry.startswith(' \1
replace = replace.replace("$", "\\")
@@ -429,7 +414,9 @@ def populate_works(self):
_work = self.default_work
self.moby_dick = _work(
- title="Moby Dick", authors="Herman Melville", fiction=True,
+ title="Moby Dick",
+ authors="Herman Melville",
+ fiction=True,
)
self.moby_dick.presentation_edition.subtitle = "Or, the Whale"
self.moby_dick.presentation_edition.series = "Classics"
@@ -438,7 +425,9 @@ def populate_works(self):
self.moby_dick.last_update_time = datetime_utc(2019, 1, 1)
self.moby_duck = _work(title="Moby Duck", authors="Donovan Hohn", fiction=False)
- self.moby_duck.presentation_edition.subtitle = "The True Story of 28,800 Bath Toys Lost at Sea"
+ self.moby_duck.presentation_edition.subtitle = (
+ "The True Story of 28,800 Bath Toys Lost at Sea"
+ )
self.moby_duck.summary_text = "A compulsively readable narrative"
self.moby_duck.presentation_edition.publisher = "Penguin"
self.moby_duck.last_update_time = datetime_utc(2019, 1, 2)
@@ -470,24 +459,33 @@ def populate_works(self):
self.washington = _work(genre="Biography", title="George Washington")
- self.lincoln_vampire = _work(title="Abraham Lincoln: Vampire Hunter", genre="Fantasy")
+ self.lincoln_vampire = _work(
+ title="Abraham Lincoln: Vampire Hunter", genre="Fantasy"
+ )
- self.children_work = _work(title="Alice in Wonderland", audience=Classifier.AUDIENCE_CHILDREN)
+ self.children_work = _work(
+ title="Alice in Wonderland", audience=Classifier.AUDIENCE_CHILDREN
+ )
- self.all_ages_work = _work(title="The Annotated Alice", audience=Classifier.AUDIENCE_ALL_AGES)
+ self.all_ages_work = _work(
+ title="The Annotated Alice", audience=Classifier.AUDIENCE_ALL_AGES
+ )
- self.ya_work = _work(title="Go Ask Alice", audience=Classifier.AUDIENCE_YOUNG_ADULT)
+ self.ya_work = _work(
+ title="Go Ask Alice", audience=Classifier.AUDIENCE_YOUNG_ADULT
+ )
self.adult_work = _work(title="Still Alice", audience=Classifier.AUDIENCE_ADULT)
self.research_work = _work(
title="Curiouser and Curiouser: Surrealism and Repression in 'Alice in Wonderland'",
- audience=Classifier.AUDIENCE_RESEARCH
+ audience=Classifier.AUDIENCE_RESEARCH,
)
self.ya_romance = _work(
title="Gumby In Love",
- audience=Classifier.AUDIENCE_YOUNG_ADULT, genre="Romance"
+ audience=Classifier.AUDIENCE_YOUNG_ADULT,
+ genre="Romance",
)
self.ya_romance.presentation_edition.subtitle = (
"Modern Fairytale Series, Volume 7"
@@ -495,35 +493,43 @@ def populate_works(self):
self.ya_romance.presentation_edition.series = "Modern Fairytales"
self.no_age = _work()
- self.no_age.summary_text = "President Barack Obama's election in 2008 energized the United States"
+ self.no_age.summary_text = (
+ "President Barack Obama's election in 2008 energized the United States"
+ )
# Set the series to the empty string rather than None -- this isn't counted
# as the book belonging to a series.
self.no_age.presentation_edition.series = ""
self.age_4_5 = _work()
- self.age_4_5.target_age = NumericRange(4, 5, '[]')
- self.age_4_5.summary_text = "President Barack Obama's election in 2008 energized the United States"
+ self.age_4_5.target_age = NumericRange(4, 5, "[]")
+ self.age_4_5.summary_text = (
+ "President Barack Obama's election in 2008 energized the United States"
+ )
self.age_5_6 = _work(fiction=False)
- self.age_5_6.target_age = NumericRange(5, 6, '[]')
+ self.age_5_6.target_age = NumericRange(5, 6, "[]")
- self.obama = _work(
- title="Barack Obama", genre="Biography & Memoir"
+ self.obama = _work(title="Barack Obama", genre="Biography & Memoir")
+ self.obama.target_age = NumericRange(8, 8, "[]")
+ self.obama.summary_text = (
+ "President Barack Obama's election in 2008 energized the United States"
)
- self.obama.target_age = NumericRange(8, 8, '[]')
- self.obama.summary_text = "President Barack Obama's election in 2008 energized the United States"
self.dodger = _work()
- self.dodger.target_age = NumericRange(8, 8, '[]')
- self.dodger.summary_text = "Willie finds himself running for student council president"
+ self.dodger.target_age = NumericRange(8, 8, "[]")
+ self.dodger.summary_text = (
+ "Willie finds himself running for student council president"
+ )
self.age_9_10 = _work()
- self.age_9_10.target_age = NumericRange(9, 10, '[]')
- self.age_9_10.summary_text = "President Barack Obama's election in 2008 energized the United States"
+ self.age_9_10.target_age = NumericRange(9, 10, "[]")
+ self.age_9_10.summary_text = (
+ "President Barack Obama's election in 2008 energized the United States"
+ )
self.age_2_10 = _work()
- self.age_2_10.target_age = NumericRange(2, 10, '[]')
+ self.age_2_10.target_age = NumericRange(2, 10, "[]")
self.pride = _work(title="Pride and Prejudice (E)")
self.pride.presentation_edition.medium = Edition.BOOK_MEDIUM
@@ -532,8 +538,7 @@ def populate_works(self):
self.pride_audio.presentation_edition.medium = Edition.AUDIO_MEDIUM
self.sherlock = _work(
- title="The Adventures of Sherlock Holmes",
- with_open_access_download=True
+ title="The Adventures of Sherlock Holmes", with_open_access_download=True
)
self.sherlock.presentation_edition.language = "eng"
@@ -550,8 +555,7 @@ def populate_works(self):
# Create a second collection that only contains a few books.
self.tiny_collection = self._collection("A Tiny Collection")
self.tiny_book = self._work(
- title="A Tiny Book", with_license_pool=True,
- collection=self.tiny_collection
+ title="A Tiny Book", with_license_pool=True, collection=self.tiny_collection
)
self.tiny_book.license_pools[0].self_hosted = True
@@ -559,8 +563,7 @@ def populate_works(self):
# Holmes", but each collection licenses the book through a
# different mechanism.
self.sherlock_pool_2 = self._licensepool(
- edition=self.sherlock.presentation_edition,
- collection=self.tiny_collection
+ edition=self.sherlock.presentation_edition, collection=self.tiny_collection
)
sherlock_2, is_new = self.sherlock_pool_2.calculate_work()
@@ -595,9 +598,11 @@ def test_query_works(self):
# document query doesn't contain over-zealous joins. This test
# class is the main place where we make a large number of
# works and generate search documents for them.
- assert 1 == len(self.moby_dick.to_search_document()['licensepools'])
- assert ("Audio" ==
- self.pride_audio.to_search_document()['licensepools'][0]['medium'])
+ assert 1 == len(self.moby_dick.to_search_document()["licensepools"])
+ assert (
+ "Audio"
+ == self.pride_audio.to_search_document()["licensepools"][0]["medium"]
+ )
# Set up convenient aliases for methods we'll be calling a
# lot.
@@ -612,10 +617,7 @@ def test_query_works(self):
expect(self.moby_duck, "moby dick", None, second_item)
two_per_page = Pagination(size=2, offset=0)
- expect(
- [self.moby_dick, self.moby_duck],
- "moby dick", None, two_per_page
- )
+ expect([self.moby_dick, self.moby_duck], "moby dick", None, two_per_page)
# Now try some different search queries.
@@ -648,10 +650,7 @@ def test_query_works(self):
# A search for a partial title match + a partial author match
# considers only books that match both fields.
- expect(
- [self.moby_dick],
- "moby melville"
- )
+ expect([self.moby_dick], "moby melville")
# Match a quoted phrase
# 'Moby-Dick' is the first result because it's an exact title
@@ -706,8 +705,7 @@ def test_query_works(self):
# This finds books that belong to _some_ series.
some_series = Filter(series=True)
- expect([self.moby_dick, self.ya_romance], "", some_series,
- ordered=False)
+ expect([self.moby_dick, self.ya_romance], "", some_series, ordered=False)
# Find results based on genre.
@@ -719,12 +717,10 @@ def test_query_works(self):
# Find results based on audience.
expect(self.children_work, "children's")
- expect(
- [self.ya_work, self.ya_romance], "young adult", ordered=False
- )
+ expect([self.ya_work, self.ya_romance], "young adult", ordered=False)
# Find results based on grade level or target age.
- for q in ('grade 4', 'grade 4-6', 'age 9'):
+ for q in ("grade 4", "grade 4-6", "age 9"):
# ages 9-10 is a better result because a book targeted
# toward a narrow range is a better match than a book
# targeted toward a wide range.
@@ -782,10 +778,7 @@ def test_query_works(self):
expect(self.sherlock, "sherlock", english)
expect(self.sherlock_spanish, "sherlock", spanish)
- expect(
- [self.sherlock, self.sherlock_spanish], "sherlock", both,
- ordered=False
- )
+ expect([self.sherlock, self.sherlock_spanish], "sherlock", both, ordered=False)
# Filters on fiction status
fiction = Filter(fiction=True)
@@ -805,8 +798,7 @@ def test_query_works(self):
ya = Filter(audiences=Classifier.AUDIENCE_YOUNG_ADULT)
children = Filter(audiences=Classifier.AUDIENCE_CHILDREN)
ya_and_children = Filter(
- audiences=[Classifier.AUDIENCE_CHILDREN,
- Classifier.AUDIENCE_YOUNG_ADULT]
+ audiences=[Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT]
)
research = Filter(audiences=[Classifier.AUDIENCE_RESEARCH])
@@ -816,8 +808,9 @@ def expect_alice(expect_works, filter):
expect_alice([self.adult_work, self.all_ages_work], adult)
expect_alice([self.ya_work, self.all_ages_work], ya)
expect_alice([self.children_work, self.all_ages_work], children)
- expect_alice([self.children_work, self.ya_work, self.all_ages_work],
- ya_and_children)
+ expect_alice(
+ [self.children_work, self.ya_work, self.all_ages_work], ya_and_children
+ )
# The 'all ages' work appears except when the audience would make
# that inappropriate...
@@ -828,53 +821,55 @@ def expect_alice(expect_works, filter):
# to have the necessary reading fluency.
expect_alice(
[self.children_work],
- Filter(audiences=Classifier.AUDIENCE_CHILDREN, target_age=(2,3))
+ Filter(audiences=Classifier.AUDIENCE_CHILDREN, target_age=(2, 3)),
)
# If there is no filter, the research work is excluded by
# default, but everything else is included.
default_filter = Filter()
expect_alice(
- [self.children_work, self.ya_work, self.adult_work,
- self.all_ages_work],
- default_filter
+ [self.children_work, self.ya_work, self.adult_work, self.all_ages_work],
+ default_filter,
)
# Filters on age range
age_8 = Filter(target_age=8)
- age_5_8 = Filter(target_age=(5,8))
- age_5_10 = Filter(target_age=(5,10))
- age_8_10 = Filter(target_age=(8,10))
+ age_5_8 = Filter(target_age=(5, 8))
+ age_5_10 = Filter(target_age=(5, 10))
+ age_8_10 = Filter(target_age=(8, 10))
# As the age filter changes, different books appear and
# disappear. no_age is always present since it has no age
# restrictions.
expect(
- [self.no_age, self.obama, self.dodger],
- "president", age_8, ordered=False
+ [self.no_age, self.obama, self.dodger], "president", age_8, ordered=False
)
expect(
[self.no_age, self.age_4_5, self.obama, self.dodger],
- "president", age_5_8, ordered=False
+ "president",
+ age_5_8,
+ ordered=False,
)
expect(
- [self.no_age, self.age_4_5, self.obama, self.dodger,
- self.age_9_10],
- "president", age_5_10, ordered=False
+ [self.no_age, self.age_4_5, self.obama, self.dodger, self.age_9_10],
+ "president",
+ age_5_10,
+ ordered=False,
)
expect(
[self.no_age, self.obama, self.dodger, self.age_9_10],
- "president", age_8_10, ordered=False
+ "president",
+ age_8_10,
+ ordered=False,
)
# Filters on license source.
gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG)
gutenberg_only = Filter(license_datasource=gutenberg)
- expect([self.moby_dick, self.moby_duck], "moby", gutenberg_only,
- ordered=False)
+ expect([self.moby_dick, self.moby_duck], "moby", gutenberg_only, ordered=False)
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
overdrive_only = Filter(license_datasource=overdrive)
@@ -899,16 +894,13 @@ def expect_alice(expect_works, filter):
expect(self.lincoln, "lincoln", biography_filter)
expect(self.lincoln_vampire, "lincoln", fantasy_filter)
- expect([self.lincoln, self.lincoln_vampire], "lincoln", both,
- ordered=False)
+ expect([self.lincoln, self.lincoln_vampire], "lincoln", both, ordered=False)
# Filters on list membership.
# This ignores 'Abraham Lincoln, Vampire Hunter' because that
# book isn't on the self.presidential list.
- on_presidential_list = Filter(
- customlist_restriction_sets=[[self.presidential]]
- )
+ on_presidential_list = Filter(customlist_restriction_sets=[[self.presidential]])
expect(self.lincoln, "lincoln", on_presidential_list)
# This filters everything, since the query is restricted to
@@ -929,15 +921,12 @@ def expect_alice(expect_works, filter):
# being searched, it only shows up in search results once.
f = Filter(
collections=[self._default_collection, self.tiny_collection],
- languages="eng"
+ languages="eng",
)
expect(self.sherlock, "sherlock holmes", f)
# Filter on identifier -- one or many.
- for results in [
- [self.lincoln],
- [self.sherlock, self.pride_audio]
- ]:
+ for results in [[self.lincoln], [self.sherlock, self.pride_audio]]:
identifiers = [w.license_pools[0].identifier for w in results]
f = Filter(identifiers=identifiers)
expect(results, None, f, ordered=False)
@@ -964,10 +953,7 @@ def expect_alice(expect_works, filter):
DataSource.lookup(self._db, DataSource.BIBLIOTHECA)
]
)
- expect(
- [self.pride, self.pride_audio], "pride and prejudice", f,
- ordered=False
- )
+ expect([self.pride, self.pride_audio], "pride and prejudice", f, ordered=False)
# "Moby Duck" is not currently available, so it won't show up in
# search results if allow_holds is False.
@@ -984,14 +970,10 @@ def pages(worklist):
pages.
"""
pagination = SortKeyPagination(size=2)
- facets = Facets(
- self._default_library, None, None, order=Facets.ORDER_TITLE
- )
+ facets = Facets(self._default_library, None, None, order=Facets.ORDER_TITLE)
pages = []
while pagination:
- pages.append(worklist.works(
- self._db, facets, pagination, self.search
- ))
+ pages.append(worklist.works(self._db, facets, pagination, self.search))
pagination = pagination.next_page
# The last page should always be empty -- that's how we
@@ -1003,16 +985,14 @@ def pages(worklist):
# Test a WorkList based on a custom list.
presidential = WorkList()
- presidential.initialize(
- self._default_library, customlists=[self.presidential]
- )
+ presidential.initialize(self._default_library, customlists=[self.presidential])
p1, p2 = pages(presidential)
assert [self.lincoln, self.obama] == p1
assert [self.washington] == p2
# Test a WorkList based on a language.
spanish = WorkList()
- spanish.initialize(self._default_library, languages=['spa'])
+ spanish.initialize(self._default_library, languages=["spa"])
assert [[self.sherlock_spanish]] == pages(spanish)
# Test a WorkList based on a genre.
@@ -1024,14 +1004,18 @@ def pages(worklist):
# quality.
f = SearchFacets
by_author = f(
- library=self._default_library, collection=f.COLLECTION_FULL,
- availability=f.AVAILABLE_ALL, order=f.ORDER_AUTHOR
+ library=self._default_library,
+ collection=f.COLLECTION_FULL,
+ availability=f.AVAILABLE_ALL,
+ order=f.ORDER_AUTHOR,
)
by_author = Filter(facets=by_author)
by_title = f(
- library=self._default_library, collection=f.COLLECTION_FULL,
- availability=f.AVAILABLE_ALL, order=f.ORDER_TITLE
+ library=self._default_library,
+ collection=f.COLLECTION_FULL,
+ availability=f.AVAILABLE_ALL,
+ order=f.ORDER_TITLE,
)
by_title = Filter(facets=by_title)
@@ -1056,10 +1040,9 @@ def pages(worklist):
# Lower it even more and we can start picking up search results
# that only match because of words in the description.
- by_title.min_score=10
- by_author.min_score=10
- results = [self.no_age, self.age_4_5, self.dodger,
- self.age_9_10, self.obama]
+ by_title.min_score = 10
+ by_author.min_score = 10
+ results = [self.no_age, self.age_4_5, self.dodger, self.age_9_10, self.obama]
expect(results, "president", by_title)
# Reverse the sort order to demonstrate that these works are being
@@ -1073,15 +1056,13 @@ def pages(worklist):
# Different query strings.
self._expect_results_multi(
[[self.moby_dick], [self.moby_duck]],
- [("moby dick", None, first_item),
- ("moby duck", None, first_item)]
+ [("moby dick", None, first_item), ("moby duck", None, first_item)],
)
# Same query string, different pagination settings.
self._expect_results_multi(
[[self.moby_dick], [self.moby_duck]],
- [("moby dick", None, first_item),
- ("moby dick", None, second_item)]
+ [("moby dick", None, first_item), ("moby dick", None, second_item)],
)
# Same query string, same pagination settings, different
@@ -1092,13 +1073,14 @@ def pages(worklist):
match_nothing = Filter(match_nothing=True)
self._expect_results_multi(
[[self.moby_duck], []],
- [("moby dick", Filter(fiction=False), first_item),
- (None, match_nothing, first_item)]
+ [
+ ("moby dick", Filter(fiction=False), first_item),
+ (None, match_nothing, first_item),
+ ],
)
class TestFacetFilters(EndToEndSearchTest):
-
def populate_works(self):
_work = self.default_work
@@ -1109,18 +1091,16 @@ def populate_works(self):
self.horse.quality = 0.2
# A high-quality open-access work.
- self.moby = _work(
- title="Moby Dick", with_open_access_download=True
- )
+ self.moby = _work(title="Moby Dick", with_open_access_download=True)
self.moby.quality = 0.8
# A currently available commercially-licensed work.
- self.duck = _work(title='Moby Duck')
+ self.duck = _work(title="Moby Duck")
self.duck.license_pools[0].licenses_available = 1
self.duck.quality = 0.5
# A currently unavailable commercially-licensed work.
- self.becoming = _work(title='Becoming')
+ self.becoming = _work(title="Becoming")
self.becoming.license_pools[0].licenses_available = 0
self.becoming.quality = 0.9
@@ -1135,35 +1115,44 @@ def test_facet_filtering(self):
def expect(availability, collection, works):
facets = Facets(
- self._default_library, availability, collection,
- order=Facets.ORDER_TITLE
- )
- self._expect_results(
- works, None, Filter(facets=facets), ordered=False
+ self._default_library,
+ availability,
+ collection,
+ order=Facets.ORDER_TITLE,
)
+ self._expect_results(works, None, Filter(facets=facets), ordered=False)
# Get all the books in alphabetical order by title.
- expect(Facets.COLLECTION_FULL, Facets.AVAILABLE_ALL,
- [self.becoming, self.horse, self.moby, self.duck])
+ expect(
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_ALL,
+ [self.becoming, self.horse, self.moby, self.duck],
+ )
# Show only works that can be borrowed right now.
- expect(Facets.COLLECTION_FULL, Facets.AVAILABLE_NOW,
- [self.horse, self.moby, self.duck])
+ expect(
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_NOW,
+ [self.horse, self.moby, self.duck],
+ )
# Show only works that can *not* be borrowed right now.
expect(Facets.COLLECTION_FULL, Facets.AVAILABLE_NOT_NOW, [self.becoming])
# Show only open-access works.
- expect(Facets.COLLECTION_FULL, Facets.AVAILABLE_OPEN_ACCESS,
- [self.horse, self.moby])
+ expect(
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_OPEN_ACCESS,
+ [self.horse, self.moby],
+ )
# Show only featured-quality works.
- expect(Facets.COLLECTION_FEATURED, Facets.AVAILABLE_ALL,
- [self.becoming, self.moby])
+ expect(
+ Facets.COLLECTION_FEATURED, Facets.AVAILABLE_ALL, [self.becoming, self.moby]
+ )
class TestSearchOrder(EndToEndSearchTest):
-
def populate_works(self):
_work = self.default_work
@@ -1202,7 +1191,9 @@ def populate_works(self):
# the Filter.
# c, a - when two sets of custom list restrictions [1], [3]
# are associated with the filter.
- self.moby_dick = _work(title="moby dick", authors="Herman Melville", fiction=True)
+ self.moby_dick = _work(
+ title="moby dick", authors="Herman Melville", fiction=True
+ )
self.moby_dick.presentation_edition.subtitle = "Or, the Whale"
self.moby_dick.presentation_edition.series = "Classics"
self.moby_dick.presentation_edition.series_position = 10
@@ -1211,7 +1202,9 @@ def populate_works(self):
self.moby_dick.random = 0.1
self.moby_duck = _work(title="Moby Duck", authors="donovan hohn", fiction=False)
- self.moby_duck.presentation_edition.subtitle = "The True Story of 28,800 Bath Toys Lost at Sea"
+ self.moby_duck.presentation_edition.subtitle = (
+ "The True Story of 28,800 Bath Toys Lost at Sea"
+ )
self.moby_duck.summary_text = "A compulsively readable narrative"
self.moby_duck.presentation_edition.series_position = 1
self.moby_duck.presentation_edition.publisher = "Penguin"
@@ -1246,21 +1239,21 @@ def populate_works(self):
# order.
self.collection2 = self._collection(name="Collection 2 - BAC")
self.a2 = self._licensepool(
- edition=self.a.presentation_edition, collection=self.collection2,
- with_open_access_download=True
-
+ edition=self.a.presentation_edition,
+ collection=self.collection2,
+ with_open_access_download=True,
)
self.a.license_pools.append(self.a2)
self.b2 = self._licensepool(
- edition=self.b.presentation_edition, collection=self.collection2,
- with_open_access_download=True
-
+ edition=self.b.presentation_edition,
+ collection=self.collection2,
+ with_open_access_download=True,
)
self.b.license_pools.append(self.b2)
self.c2 = self._licensepool(
- edition=self.c.presentation_edition, collection=self.collection2,
- with_open_access_download=True
-
+ edition=self.c.presentation_edition,
+ collection=self.collection2,
+ with_open_access_download=True,
)
self.c.license_pools.append(self.c2)
self.b2.availability_time = datetime_utc(2020, 1, 1)
@@ -1269,48 +1262,25 @@ def populate_works(self):
# Here are three custom lists which contain the same books but
# with different first appearances.
- self.list1, ignore = self._customlist(
- name="Custom list 1 - BCA", num_entries=0
- )
- self.list1.add_entry(
- self.b, first_appearance=datetime_utc(2030, 1, 1)
- )
- self.list1.add_entry(
- self.c, first_appearance=datetime_utc(2031, 1, 1)
- )
- self.list1.add_entry(
- self.a, first_appearance=datetime_utc(2032, 1, 1)
- )
+ self.list1, ignore = self._customlist(name="Custom list 1 - BCA", num_entries=0)
+ self.list1.add_entry(self.b, first_appearance=datetime_utc(2030, 1, 1))
+ self.list1.add_entry(self.c, first_appearance=datetime_utc(2031, 1, 1))
+ self.list1.add_entry(self.a, first_appearance=datetime_utc(2032, 1, 1))
- self.list2, ignore = self._customlist(
- name="Custom list 2 - CAB", num_entries=0
- )
- self.list2.add_entry(
- self.c, first_appearance=datetime_utc(2001, 1, 1)
- )
- self.list2.add_entry(
- self.a, first_appearance=datetime_utc(2014, 1, 1)
- )
- self.list2.add_entry(
- self.b, first_appearance=datetime_utc(2015, 1, 1)
- )
+ self.list2, ignore = self._customlist(name="Custom list 2 - CAB", num_entries=0)
+ self.list2.add_entry(self.c, first_appearance=datetime_utc(2001, 1, 1))
+ self.list2.add_entry(self.a, first_appearance=datetime_utc(2014, 1, 1))
+ self.list2.add_entry(self.b, first_appearance=datetime_utc(2015, 1, 1))
- self.list3, ignore = self._customlist(
- name="Custom list 3 -- CA", num_entries=0
- )
- self.list3.add_entry(
- self.a, first_appearance=datetime_utc(2032, 1, 1)
- )
- self.list3.add_entry(
- self.c, first_appearance=datetime_utc(1999, 1, 1)
- )
+ self.list3, ignore = self._customlist(name="Custom list 3 -- CA", num_entries=0)
+ self.list3.add_entry(self.a, first_appearance=datetime_utc(2032, 1, 1))
+ self.list3.add_entry(self.c, first_appearance=datetime_utc(1999, 1, 1))
# Create two custom lists which contain some of the same books,
# but with different first appearances.
self.by_publication_date, ignore = self._customlist(
- name="First appearance on list is publication date",
- num_entries=0
+ name="First appearance on list is publication date", num_entries=0
)
self.by_publication_date.add_entry(
self.moby_duck, first_appearance=datetime_utc(2011, 3, 1)
@@ -1320,8 +1290,7 @@ def populate_works(self):
)
self.staff_picks, ignore = self._customlist(
- name="First appearance is date book was made a staff pick",
- num_entries=0
+ name="First appearance is date book was made a staff pick", num_entries=0
)
self.staff_picks.add_entry(
self.moby_dick, first_appearance=datetime_utc(2015, 5, 2)
@@ -1342,18 +1311,13 @@ def populate_works(self):
self.e.license_pools[0].availability_time = datetime_utc(2011, 1, 1)
self.extra_list, ignore = self._customlist(num_entries=0)
- self.extra_list.add_entry(
- self.d, first_appearance=datetime_utc(2020, 1, 1)
- )
- self.extra_list.add_entry(
- self.e, first_appearance=datetime_utc(2021, 1, 1)
- )
+ self.extra_list.add_entry(self.d, first_appearance=datetime_utc(2020, 1, 1))
+ self.extra_list.add_entry(self.e, first_appearance=datetime_utc(2021, 1, 1))
self.e.last_update_time = datetime_utc(2090, 1, 1)
self.d.last_update_time = datetime_utc(2091, 1, 1)
def test_ordering(self):
-
def assert_order(sort_field, order, **filter_kwargs):
"""Verify that when the books created during test setup are ordered by
the given `sort_field`, they show up in the given `order`.
@@ -1370,8 +1334,11 @@ def assert_order(sort_field, order, **filter_kwargs):
"""
expect = self._expect_results
facets = Facets(
- self._default_library, Facets.COLLECTION_FULL,
- Facets.AVAILABLE_ALL, order=sort_field, order_ascending=True
+ self._default_library,
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_ALL,
+ order=sort_field,
+ order_ascending=True,
)
expect(order, None, Filter(facets=facets, **filter_kwargs))
@@ -1382,9 +1349,7 @@ def assert_order(sort_field, order, **filter_kwargs):
# proves that pagination works for this sort order for
# both Pagination and SortKeyPagination.
facets.order_ascending = True
- for pagination_class in (
- Pagination, SortKeyPagination
- ):
+ for pagination_class in (Pagination, SortKeyPagination):
pagination = pagination_class(size=1)
to_process = list(order) + [[]]
while to_process:
@@ -1399,9 +1364,7 @@ def assert_order(sort_field, order, **filter_kwargs):
# Now try the same tests but in reverse order.
facets.order_ascending = False
- for pagination_class in (
- Pagination, SortKeyPagination
- ):
+ for pagination_class in (Pagination, SortKeyPagination):
pagination = pagination_class(size=1)
to_process = list(reversed(order)) + [[]]
results = []
@@ -1417,15 +1380,17 @@ def assert_order(sort_field, order, **filter_kwargs):
# We can sort by title.
assert_order(
- Facets.ORDER_TITLE, [self.untitled, self.moby_dick, self.moby_duck],
- collections=[self._default_collection]
+ Facets.ORDER_TITLE,
+ [self.untitled, self.moby_dick, self.moby_duck],
+ collections=[self._default_collection],
)
# We can sort by author; 'Hohn' sorts before 'Melville' sorts
# before "[Unknown]"
assert_order(
- Facets.ORDER_AUTHOR, [self.moby_duck, self.moby_dick, self.untitled],
- collections=[self._default_collection]
+ Facets.ORDER_AUTHOR,
+ [self.moby_duck, self.moby_dick, self.untitled],
+ collections=[self._default_collection],
)
# We can sort by series position. Here, the books aren't in
@@ -1434,14 +1399,14 @@ def assert_order(sort_field, order, **filter_kwargs):
assert_order(
Facets.ORDER_SERIES_POSITION,
[self.moby_duck, self.untitled, self.moby_dick],
- collections=[self._default_collection]
+ collections=[self._default_collection],
)
# We can sort by internal work ID, which isn't very useful.
assert_order(
Facets.ORDER_WORK_ID,
[self.moby_dick, self.moby_duck, self.untitled],
- collections=[self._default_collection]
+ collections=[self._default_collection],
)
# We can sort by the time the Work's LicensePools were first
@@ -1452,12 +1417,14 @@ def assert_order(sort_field, order, **filter_kwargs):
# results.
assert_order(
Facets.ORDER_ADDED_TO_COLLECTION,
- [self.a, self.c, self.b], collections=[self.collection1]
+ [self.a, self.c, self.b],
+ collections=[self.collection1],
)
assert_order(
Facets.ORDER_ADDED_TO_COLLECTION,
- [self.b, self.a, self.c], collections=[self.collection2]
+ [self.b, self.a, self.c],
+ collections=[self.collection2],
)
# If a work shows up with multiple availability times through
@@ -1468,44 +1435,49 @@ def assert_order(sort_field, order, **filter_kwargs):
assert_order(
Facets.ORDER_ADDED_TO_COLLECTION,
[self.a, self.c, self.b],
- collections=[self.collection1, self.collection2]
+ collections=[self.collection1, self.collection2],
)
-
# Finally, here are the tests of ORDER_LAST_UPDATE, as described
# above in setup().
assert_order(Facets.ORDER_LAST_UPDATE, [self.a, self.b, self.c, self.e, self.d])
assert_order(
- Facets.ORDER_LAST_UPDATE, [self.a, self.c, self.b],
- collections=[self.collection1]
+ Facets.ORDER_LAST_UPDATE,
+ [self.a, self.c, self.b],
+ collections=[self.collection1],
)
assert_order(
- Facets.ORDER_LAST_UPDATE, [self.b, self.a, self.c],
- collections=[self.collection1, self.collection2]
+ Facets.ORDER_LAST_UPDATE,
+ [self.b, self.a, self.c],
+ collections=[self.collection1, self.collection2],
)
assert_order(
- Facets.ORDER_LAST_UPDATE, [self.b, self.c, self.a],
- customlist_restriction_sets=[[self.list1]]
+ Facets.ORDER_LAST_UPDATE,
+ [self.b, self.c, self.a],
+ customlist_restriction_sets=[[self.list1]],
)
assert_order(
- Facets.ORDER_LAST_UPDATE, [self.c, self.a, self.b],
+ Facets.ORDER_LAST_UPDATE,
+ [self.c, self.a, self.b],
collections=[self.collection1],
- customlist_restriction_sets=[[self.list2]]
+ customlist_restriction_sets=[[self.list2]],
)
assert_order(
- Facets.ORDER_LAST_UPDATE, [self.c, self.a],
- customlist_restriction_sets=[[self.list1], [self.list3]]
+ Facets.ORDER_LAST_UPDATE,
+ [self.c, self.a],
+ customlist_restriction_sets=[[self.list1], [self.list3]],
)
assert_order(
- Facets.ORDER_LAST_UPDATE, [self.e, self.d],
+ Facets.ORDER_LAST_UPDATE,
+ [self.e, self.d],
collections=[self.collection3],
- customlist_restriction_sets=[[self.extra_list]]
+ customlist_restriction_sets=[[self.extra_list]],
)
@@ -1519,39 +1491,41 @@ def populate_works(self):
# Create a number of Contributor objects--some fragmentary--
# representing the same person.
self.full = Contributor(
- display_name='Ann Leckie', sort_name='Leckie, Ann', viaf="73520345",
- lc="n2013008575"
+ display_name="Ann Leckie",
+ sort_name="Leckie, Ann",
+ viaf="73520345",
+ lc="n2013008575",
)
self.display_name = Contributor(
- sort_name=Edition.UNKNOWN_AUTHOR, display_name='ann leckie'
- )
- self.sort_name = Contributor(sort_name='LECKIE, ANN')
- self.viaf = Contributor(
- sort_name=Edition.UNKNOWN_AUTHOR, viaf="73520345"
- )
- self.lc = Contributor(
- sort_name=Edition.UNKNOWN_AUTHOR, lc="n2013008575"
+ sort_name=Edition.UNKNOWN_AUTHOR, display_name="ann leckie"
)
+ self.sort_name = Contributor(sort_name="LECKIE, ANN")
+ self.viaf = Contributor(sort_name=Edition.UNKNOWN_AUTHOR, viaf="73520345")
+ self.lc = Contributor(sort_name=Edition.UNKNOWN_AUTHOR, lc="n2013008575")
# Create a different Work for every Contributor object.
# Alternate among the various 'author match' roles.
self.works = []
roles = list(Filter.AUTHOR_MATCH_ROLES)
for i, (contributor, title, attribute) in enumerate(
- [(self.full, "Ancillary Justice", 'justice'),
- (self.display_name, "Ancillary Sword", 'sword'),
- (self.sort_name, "Ancillary Mercy", 'mercy'),
- (self.viaf, "Provenance", 'provenance'),
- (self.lc, "Raven Tower", 'raven'),
- ]):
+ [
+ (self.full, "Ancillary Justice", "justice"),
+ (self.display_name, "Ancillary Sword", "sword"),
+ (self.sort_name, "Ancillary Mercy", "mercy"),
+ (self.viaf, "Provenance", "provenance"),
+ (self.lc, "Raven Tower", "raven"),
+ ]
+ ):
self._db.add(contributor)
edition, ignore = self._edition(
title=title, authors=[], with_license_pool=True
)
contribution, was_new = get_one_or_create(
- self._db, Contribution, edition=edition,
+ self._db,
+ Contribution,
+ edition=edition,
contributor=contributor,
- role=roles[i % len(roles)]
+ role=roles[i % len(roles)],
)
work = self.default_work(
presentation_edition=edition,
@@ -1564,30 +1538,28 @@ def populate_works(self):
# always be filtered out.
edition, ignore = self._edition(
title="Science Fiction: The Best of the Year (2007 Edition)",
- authors=[], with_license_pool=True
+ authors=[],
+ with_license_pool=True,
)
contribution, is_new = get_one_or_create(
- self._db, Contribution, edition=edition, contributor=self.full,
- role=Contributor.CONTRIBUTOR_ROLE
- )
- self.literary_wonderlands = self.default_work(
- presentation_edition=edition
+ self._db,
+ Contribution,
+ edition=edition,
+ contributor=self.full,
+ role=Contributor.CONTRIBUTOR_ROLE,
)
+ self.literary_wonderlands = self.default_work(presentation_edition=edition)
# Another decoy. This work is by a different person and will
# always be filtered out.
- self.ubik = self.default_work(
- title="Ubik", authors=["Phillip K. Dick"]
- )
+ self.ubik = self.default_work(title="Ubik", authors=["Phillip K. Dick"])
def test_author_match(self):
# By providing a Contributor object with all the identifiers,
# we get every work with an author-type contribution from
# someone who can be identified with that Contributor.
- self._expect_results(
- self.works, None, Filter(author=self.full), ordered=False
- )
+ self._expect_results(self.works, None, Filter(author=self.full), ordered=False)
# If we provide a Contributor object with partial information,
# we can only get works that are identifiable with that
@@ -1604,9 +1576,7 @@ def test_author_match(self):
(Filter(author=self.viaf), self.provenance),
(Filter(author=self.lc), self.raven),
]:
- self._expect_results(
- [self.justice, extra], None, filter, ordered=False
- )
+ self._expect_results([self.justice, extra], None, filter, ordered=False)
# ContributorData also works here.
@@ -1616,8 +1586,10 @@ def test_author_match(self):
# that knows both.
author = ContributorData(sort_name="Leckie, Ann", viaf="73520345")
self._expect_results(
- [self.justice, self.mercy, self.provenance], None,
- Filter(author=author), ordered=False
+ [self.justice, self.mercy, self.provenance],
+ None,
+ Filter(author=author),
+ ordered=False,
)
# The filter can also accommodate very minor variants in names
@@ -1626,8 +1598,7 @@ def test_author_match(self):
for variant in ("ann leckie", "Àñn Léckiê"):
author = ContributorData(display_name=variant)
self._expect_results(
- [self.justice, self.sword], None,
- Filter(author=author), ordered=False
+ [self.justice, self.sword], None, Filter(author=author), ordered=False
)
# It cannot accommodate misspellings, no matter how minor.
@@ -1636,12 +1607,12 @@ def test_author_match(self):
# If the information in the ContributorData is inconsistent,
# the results may also be inconsistent.
- author = ContributorData(
- sort_name="Dick, Phillip K.", lc="n2013008575"
- )
+ author = ContributorData(sort_name="Dick, Phillip K.", lc="n2013008575")
self._expect_results(
[self.justice, self.raven, self.ubik],
- None, Filter(author=author), ordered=False
+ None,
+ Filter(author=author),
+ ordered=False,
)
@@ -1663,16 +1634,17 @@ def populate_works(self):
self.ya_romance = _work(
title="Gumby In Love",
authors="Pokey",
- audience=Classifier.AUDIENCE_YOUNG_ADULT, genre="Romance"
+ audience=Classifier.AUDIENCE_YOUNG_ADULT,
+ genre="Romance",
)
self.ya_romance.presentation_edition.subtitle = (
"Modern Fairytale Series, Book 3"
)
self.parent_book = _work(
- title="Our Son Aziz",
- authors=["Fatima Ansari", "Shoukath Ansari"],
- genre="Biography & Memoir",
+ title="Our Son Aziz",
+ authors=["Fatima Ansari", "Shoukath Ansari"],
+ genre="Biography & Memoir",
)
self.behind_the_scenes = _work(
@@ -1693,9 +1665,7 @@ def populate_works(self):
)
self.book_by_someone_else = _work(
- title="The Deadly Graves",
- authors="Peter Ansari",
- genre="Mystery"
+ title="The Deadly Graves", authors="Peter Ansari", genre="Mystery"
)
def test_exact_matches(self):
@@ -1706,10 +1676,10 @@ def test_exact_matches(self):
# split across genre and subtitle.
expect(
[
- self.modern_romance, # "modern romance" in title
- self.ya_romance # "modern" in subtitle, genre "romance"
+ self.modern_romance, # "modern romance" in title
+ self.ya_romance, # "modern" in subtitle, genre "romance"
],
- "modern romance"
+ "modern romance",
)
# A full author match takes precedence over a partial author
@@ -1717,10 +1687,10 @@ def test_exact_matches(self):
# all all because it can't match two words.
expect(
[
- self.modern_romance, # "Aziz Ansari" in author
- self.parent_book, # "Aziz" in title, "Ansari" in author
+ self.modern_romance, # "Aziz Ansari" in author
+ self.parent_book, # "Aziz" in title, "Ansari" in author
],
- "aziz ansari"
+ "aziz ansari",
)
# 'peter graves' is a string that has exact matches in both
@@ -1751,17 +1721,16 @@ def test_exact_matches(self):
# if there are more than two search terms, only two must match.
order = [
- self.behind_the_scenes, # all words match in title
- self.biography_of_peter_graves, # title + genre 'biography'
- self.book_by_peter_graves, # author (no 'biography')
+ self.behind_the_scenes, # all words match in title
+ self.biography_of_peter_graves, # title + genre 'biography'
+ self.book_by_peter_graves, # author (no 'biography')
]
expect(order, "peter graves biography")
class TestFeaturedFacets(EndToEndSearchTest):
- """Test how a FeaturedFacets object affects search ordering.
- """
+ """Test how a FeaturedFacets object affects search ordering."""
def populate_works(self):
_work = self.default_work
@@ -1804,17 +1773,20 @@ def test_scoring_functions(self):
source = filter.FEATURABLE_SCRIPT % dict(
cutoff=f.minimum_featured_quality ** 2, exponent=2
)
- assert source == featurable.script['source']
+ assert source == featurable.script["source"]
# It can be currently available.
- availability_filter = available_now['filter']
+ availability_filter = available_now["filter"]
assert (
- dict(nested=dict(
- path='licensepools',
- query=dict(term={'licensepools.available': True})
- )) ==
- availability_filter.to_dict())
- assert 5 == available_now['weight']
+ dict(
+ nested=dict(
+ path="licensepools",
+ query=dict(term={"licensepools.available": True}),
+ )
+ )
+ == availability_filter.to_dict()
+ )
+ assert 5 == available_now["weight"]
# It can get lucky.
assert isinstance(random, RandomScore)
@@ -1830,30 +1802,35 @@ def test_scoring_functions(self):
# If custom lists are in play, it can also be featured on one
# of its custom lists.
- filter.customlist_restriction_sets = [[1,2], [3]]
- [featurable_2, available_now_2,
- featured_on_list] = f.scoring_functions(filter)
+ filter.customlist_restriction_sets = [[1, 2], [3]]
+ [featurable_2, available_now_2, featured_on_list] = f.scoring_functions(filter)
assert featurable_2 == featurable
assert available_now_2 == available_now
# Any list will do -- the customlist restriction sets aren't
# relevant here.
- featured_filter = featured_on_list['filter']
- assert (dict(
- nested=dict(
- path='customlists',
- query=dict(bool=dict(
- must=[{'term': {'customlists.featured': True}},
- {'terms': {'customlists.list_id': [1, 2, 3]}}])))) ==
- featured_filter.to_dict())
- assert 11 == featured_on_list['weight']
+ featured_filter = featured_on_list["filter"]
+ assert (
+ dict(
+ nested=dict(
+ path="customlists",
+ query=dict(
+ bool=dict(
+ must=[
+ {"term": {"customlists.featured": True}},
+ {"terms": {"customlists.list_id": [1, 2, 3]}},
+ ]
+ )
+ ),
+ )
+ )
+ == featured_filter.to_dict()
+ )
+ assert 11 == featured_on_list["weight"]
def test_run(self):
-
def works(worklist, facets):
- return worklist.works(
- self._db, facets, None, self.search, debug=True
- )
+ return worklist.works(self._db, facets, None, self.search, debug=True)
def assert_featured(description, worklist, facets, expect):
# Generate a list of featured works for the given `worklist`
@@ -1869,15 +1846,11 @@ def assert_featured(description, worklist, facets, expect):
# not_featured_on_list, not_featured_on_list shows up first because
# it's available right now.
w = works(worklist, facets)
- assert w.index(self.not_featured_on_list) < w.index(
- self.hq_not_available
- )
+ assert w.index(self.not_featured_on_list) < w.index(self.hq_not_available)
# not_featured_on_list shows up before featured_on_list because
# it's higher-quality and list membership isn't relevant.
- assert w.index(self.not_featured_on_list) < w.index(
- self.featured_on_list
- )
+ assert w.index(self.not_featured_on_list) < w.index(self.featured_on_list)
# Create a WorkList that's restricted to best-sellers.
best_sellers = WorkList()
@@ -1887,7 +1860,9 @@ def assert_featured(description, worklist, facets, expect):
# The featured work appears above the non-featured work,
# even though it's lower quality and is not available.
assert_featured(
- "Works from WorkList based on CustomList", best_sellers, facets,
+ "Works from WorkList based on CustomList",
+ best_sellers,
+ facets,
[self.featured_on_list, self.not_featured_on_list],
)
@@ -1899,9 +1874,7 @@ def assert_featured(description, worklist, facets, expect):
# to 0, which makes all works 'featured' and stops quality
# from being considered altogether. Basically all that matters
# is availability.
- all_featured_facets = FeaturedFacets(
- 0, random_seed=Filter.DETERMINISTIC
- )
+ all_featured_facets = FeaturedFacets(0, random_seed=Filter.DETERMINISTIC)
# We don't know exactly what order the books will be in,
# because even without the random element Elasticsearch is
# slightly nondeterministic, but we do expect that all of the
@@ -1924,15 +1897,19 @@ def assert_featured(description, worklist, facets, expect):
random_facets = FeaturedFacets(1, random_seed=43)
assert_featured(
"Works permuted by a random seed",
- worklist, random_facets,
- [self.hq_available_2, self.hq_available,
- self.not_featured_on_list, self.hq_not_available,
- self.featured_on_list],
+ worklist,
+ random_facets,
+ [
+ self.hq_available_2,
+ self.hq_available,
+ self.not_featured_on_list,
+ self.hq_not_available,
+ self.featured_on_list,
+ ],
)
class TestSearchBase(object):
-
def test__boost(self):
# Verify that _boost() converts a regular query (or list of queries)
# into a boosted query.
@@ -1962,10 +1939,9 @@ def test__boost(self):
def test__nest(self):
# Test the _nest method, which turns a normal query into a
# nested query.
- query = Term(**{"nested_field" : "value"})
+ query = Term(**{"nested_field": "value"})
nested = SearchBase._nest("subdocument", query)
- assert (Nested(path='subdocument', query=query) ==
- nested)
+ assert Nested(path="subdocument", query=query) == nested
def test_nestable(self):
# Test the _nestable helper method, which turns a normal
@@ -1975,35 +1951,29 @@ def test_nestable(self):
# A query on a field that's not in a subdocument is
# unaffected.
field = "name.minimal"
- normal_query = Term(**{field : "name"})
+ normal_query = Term(**{field: "name"})
assert normal_query == m(field, normal_query)
# A query on a subdocument field becomes a nested query on
# that subdocument.
field = "contributors.sort_name.minimal"
- subdocument_query = Term(**{field : "name"})
+ subdocument_query = Term(**{field: "name"})
nested = m(field, subdocument_query)
- assert (
- Nested(path='contributors', query=subdocument_query) ==
- nested)
+ assert Nested(path="contributors", query=subdocument_query) == nested
def test__match_term(self):
# _match_term creates a Match Elasticsearch object which does a
# match against a specific field.
m = SearchBase._match_term
qu = m("author", "flannery o'connor")
- assert (
- Term(author="flannery o'connor") ==
- qu)
+ assert Term(author="flannery o'connor") == qu
# If the field name references a subdocument, the query is
# embedded in a Nested object that describes how to match it
# against that subdocument.
field = "genres.name"
qu = m(field, "Biography")
- assert (
- Nested(path='genres', query=Term(**{field: "Biography"})) ==
- qu)
+ assert Nested(path="genres", query=Term(**{field: "Biography"})) == qu
def test__match_range(self):
# Test the _match_range helper method.
@@ -2012,7 +1982,7 @@ def test__match_range(self):
# This only matches if field.name has a value >= 5.
r = SearchBase._match_range("field.name", "gte", 5)
- assert r == {'range': {'field.name': {'gte': 5}}}
+ assert r == {"range": {"field.name": {"gte": 5}}}
def test__combine_hypotheses(self):
# Verify that _combine_hypotheses creates a DisMax query object
@@ -2036,28 +2006,26 @@ def test_make_target_age_query(self):
#
# This gives us two similar queries: one to use as a filter
# and one to use as a boost query.
- as_filter, as_query = Query.make_target_age_query((5,10))
+ as_filter, as_query = Query.make_target_age_query((5, 10))
# Here's the filter part: a book's age range must be include the
# 5-10 range, or it gets filtered out.
filter_clauses = [
- Range(**{"target_age.upper":dict(gte=5)}),
- Range(**{"target_age.lower":dict(lte=10)}),
+ Range(**{"target_age.upper": dict(gte=5)}),
+ Range(**{"target_age.lower": dict(lte=10)}),
]
assert Bool(must=filter_clauses) == as_filter
# Here's the query part: a book gets boosted if its
# age range fits _entirely_ within the target age range.
query_clauses = [
- Range(**{"target_age.upper":dict(lte=10)}),
- Range(**{"target_age.lower":dict(gte=5)}),
+ Range(**{"target_age.upper": dict(lte=10)}),
+ Range(**{"target_age.lower": dict(gte=5)}),
]
- assert (Bool(boost=1.1, must=filter_clauses, should=query_clauses) ==
- as_query)
+ assert Bool(boost=1.1, must=filter_clauses, should=query_clauses) == as_query
class TestQuery(DatabaseTest):
-
def test_constructor(self):
# Verify that the Query constructor sets members with
# no processing.
@@ -2085,7 +2053,6 @@ def test_constructor(self):
assert None == query.contains_stopwords
assert 0 == query.fuzzy_coefficient
-
def test_build(self):
# Verify that the build() method combines the 'query' part of
# a Query and the 'filter' part to create a single
@@ -2101,9 +2068,14 @@ class MockSearch(object):
created by to get to a certain point by following the
.parent relation.
"""
+
def __init__(
- self, parent=None, query=None, nested_filter_calls=None,
- order=None, script_fields=None
+ self,
+ parent=None,
+ query=None,
+ nested_filter_calls=None,
+ order=None,
+ script_fields=None,
):
self.parent = parent
self._query = query
@@ -2118,8 +2090,7 @@ def filter(self, **kwargs):
"""
new_filters = self.nested_filter_calls + [kwargs]
return MockSearch(
- self, self._query, new_filters, self.order,
- self._script_fields
+ self, self._query, new_filters, self.order, self._script_fields
)
def query(self, query):
@@ -2129,22 +2100,27 @@ def query(self, query):
:return: A New MockSearch object.
"""
return MockSearch(
- self, query, self.nested_filter_calls, self.order,
- self._script_fields
+ self,
+ query,
+ self.nested_filter_calls,
+ self.order,
+ self._script_fields,
)
def sort(self, *order_fields):
"""Simulate the application of a sort order."""
return MockSearch(
- self, self._query, self.nested_filter_calls, order_fields,
- self._script_fields
+ self,
+ self._query,
+ self.nested_filter_calls,
+ order_fields,
+ self._script_fields,
)
def script_fields(self, **kwargs):
"""Simulate the addition of script fields."""
return MockSearch(
- self, self._query, self.nested_filter_calls, self.order,
- kwargs
+ self, self._query, self.nested_filter_calls, self.order, kwargs
)
class MockQuery(Query):
@@ -2165,13 +2141,13 @@ def modify_search_query(self, search):
# replace them with simpler versions.
class MockFilter(object):
- universal_base_term = Q('term', universal_base_called=True)
- universal_nested_term = Q('term', universal_nested_called=True)
+ universal_base_term = Q("term", universal_base_called=True)
+ universal_nested_term = Q("term", universal_nested_called=True)
universal_nested_filter = dict(nested_called=[universal_nested_term])
@classmethod
def universal_base_filter(cls):
- cls.universal_called=True
+ cls.universal_called = True
return cls.universal_base_term
@classmethod
@@ -2223,18 +2199,19 @@ def validate_universal_calls(cls):
# The pagination filter was the last one to be applied.
pagination = built.nested_filter_calls.pop()
- assert dict(name_or_query='pagination modified') == pagination
+ assert dict(name_or_query="pagination modified") == pagination
# The mocked universal nested filter was applied
# just before that.
universal_nested = built.nested_filter_calls.pop()
assert (
dict(
- name_or_query='nested',
- path='nested_called',
- query=Bool(filter=[MockFilter.universal_nested_term])
- ) ==
- universal_nested)
+ name_or_query="nested",
+ path="nested_called",
+ query=Bool(filter=[MockFilter.universal_nested_term]),
+ )
+ == universal_nested
+ )
# The result of Query.elasticsearch_query is used as the basis
# for the Search object.
@@ -2263,20 +2240,19 @@ def validate_universal_calls(cls):
# The filter we passed in was combined with the universal
# base filter into a boolean query, with its own 'must'.
main_filter.must = main_filter.must + [MockFilter.universal_base_term]
- assert (
- underlying_query.filter ==
- [main_filter])
+ assert underlying_query.filter == [main_filter]
# There are no nested filters, apart from the universal one.
assert {} == nested_filters
universal_nested = built.nested_filter_calls.pop()
assert (
dict(
- name_or_query='nested',
- path='nested_called',
- query=Bool(filter=[MockFilter.universal_nested_term])
- ) ==
- universal_nested)
+ name_or_query="nested",
+ path="nested_called",
+ query=Bool(filter=[MockFilter.universal_nested_term]),
+ )
+ == universal_nested
+ )
assert [] == built.nested_filter_calls
# At this point the universal filters are more trouble than they're
@@ -2285,10 +2261,7 @@ def validate_universal_calls(cls):
MockFilter.universal_nested_filter = None
# Now let's try a combination of regular filters and nested filters.
- filter = Filter(
- fiction=True,
- collections=[self._default_collection]
- )
+ filter = Filter(fiction=True, collections=[self._default_collection])
qu = MockQuery("query string", filter=filter)
built = qu.build(search)
underlying_query = built._query
@@ -2296,7 +2269,7 @@ def validate_universal_calls(cls):
# We get a main filter (for the fiction restriction) and one
# nested filter.
main_filter, nested_filters = filter.build()
- [nested_licensepool_filter] = nested_filters.pop('licensepools')
+ [nested_licensepool_filter] = nested_filters.pop("licensepools")
assert {} == nested_filters
# As before, the main filter has been applied to the underlying
@@ -2307,9 +2280,9 @@ def validate_universal_calls(cls):
# into Search.filter(). This applied an additional filter on the
# 'licensepools' subdocument.
[filter_call] = built.nested_filter_calls
- assert 'nested' == filter_call['name_or_query']
- assert 'licensepools' == filter_call['path']
- filter_as_query = filter_call['query']
+ assert "nested" == filter_call["name_or_query"]
+ assert "licensepools" == filter_call["path"]
+ filter_as_query = filter_call["query"]
assert Bool(filter=nested_licensepool_filter) == filter_as_query
# Now we're going to test how queries are built to accommodate
@@ -2335,65 +2308,73 @@ def from_facets(*args, **kwargs):
# A non-nested filter is applied on the 'quality' field.
[quality_filter] = built._query.filter
quality_range = Filter._match_range(
- 'quality', 'gte', self._default_library.minimum_featured_quality
+ "quality", "gte", self._default_library.minimum_featured_quality
)
- assert Q('bool', must=[quality_range], must_not=[RESEARCH]) == quality_filter
+ assert Q("bool", must=[quality_range], must_not=[RESEARCH]) == quality_filter
# When using the AVAILABLE_OPEN_ACCESS availability restriction...
- built = from_facets(Facets.COLLECTION_FULL,
- Facets.AVAILABLE_OPEN_ACCESS, None)
+ built = from_facets(Facets.COLLECTION_FULL, Facets.AVAILABLE_OPEN_ACCESS, None)
# An additional nested filter is applied.
[available_now] = built.nested_filter_calls
- assert 'nested' == available_now['name_or_query']
- assert 'licensepools' == available_now['path']
+ assert "nested" == available_now["name_or_query"]
+ assert "licensepools" == available_now["path"]
# It finds only license pools that are open access.
- nested_filter = available_now['query']
- open_access = dict(term={'licensepools.open_access': True})
- assert (
- nested_filter.to_dict() ==
- {'bool': {'filter': [open_access]}})
+ nested_filter = available_now["query"]
+ open_access = dict(term={"licensepools.open_access": True})
+ assert nested_filter.to_dict() == {"bool": {"filter": [open_access]}}
# When using the AVAILABLE_NOW restriction...
built = from_facets(Facets.COLLECTION_FULL, Facets.AVAILABLE_NOW, None)
# An additional nested filter is applied.
[available_now] = built.nested_filter_calls
- assert 'nested' == available_now['name_or_query']
- assert 'licensepools' == available_now['path']
+ assert "nested" == available_now["name_or_query"]
+ assert "licensepools" == available_now["path"]
# It finds only license pools that are open access *or* that have
# active licenses.
- nested_filter = available_now['query']
- available = {'term': {'licensepools.available': True}}
- assert (
- nested_filter.to_dict() ==
- {'bool': {'filter': [{'bool': {'should': [open_access, available],
- 'minimum_should_match': 1}}]}})
+ nested_filter = available_now["query"]
+ available = {"term": {"licensepools.available": True}}
+ assert nested_filter.to_dict() == {
+ "bool": {
+ "filter": [
+ {
+ "bool": {
+ "should": [open_access, available],
+ "minimum_should_match": 1,
+ }
+ }
+ ]
+ }
+ }
# When using the AVAILABLE_NOT_NOW restriction...
built = from_facets(Facets.COLLECTION_FULL, Facets.AVAILABLE_NOT_NOW, None)
# An additional nested filter is applied.
[not_available_now] = built.nested_filter_calls
- assert 'nested' == available_now['name_or_query']
- assert 'licensepools' == available_now['path']
+ assert "nested" == not_available_now["name_or_query"]
+ assert "licensepools" == not_available_now["path"]
# It finds only license pools that are licensed, but not
# currently available or open access.
- nested_filter = not_available_now['query']
- not_available = {'term': {'licensepools.available': False}}
- licensed = {'term': {'licensepools.licensed': True}}
- not_open_access = {'term': {'licensepools.open_access': False}}
- assert (
- nested_filter.to_dict() ==
- {'bool': {'filter': [{'bool': {'must': [not_open_access, licensed, not_available]}}]}})
+ nested_filter = not_available_now["query"]
+ not_available = {"term": {"licensepools.available": False}}
+ licensed = {"term": {"licensepools.licensed": True}}
+ not_open_access = {"term": {"licensepools.open_access": False}}
+ assert nested_filter.to_dict() == {
+ "bool": {
+ "filter": [
+ {"bool": {"must": [not_open_access, licensed, not_available]}}
+ ]
+ }
+ }
# If the Filter specifies script fields, those fields are
# added to the Query through a call to script_fields()
- script_fields = dict(field1="Definition1",
- field2="Definition2")
+ script_fields = dict(field1="Definition1", field2="Definition2")
filter = Filter(script_fields=script_fields)
qu = MockQuery("query string", filter=filter)
built = qu.build(search)
@@ -2413,7 +2394,7 @@ def from_facets(*args, **kwargs):
# But a number of other sort fields are also employed to act
# as tiebreakers.
- for tiebreaker_field in ('sort_author', 'sort_title', 'work_id'):
+ for tiebreaker_field in ("sort_author", "sort_title", "work_id"):
assert {tiebreaker_field: "asc"} == order.pop(0)
assert [] == order
@@ -2426,9 +2407,7 @@ def test_build_match_nothing(self):
# is set, it gets built into a simple filter that matches
# nothing, with no nested subfilters.
filter = Filter(
- fiction=True,
- collections=[self._default_collection],
- match_nothing = True
+ fiction=True, collections=[self._default_collection], match_nothing=True
)
main, nested = filter.build()
assert MatchNone() == main
@@ -2467,13 +2446,18 @@ def title_multi_match_for(self, other_field):
"hypothesis based on substring",
"another such hypothesis",
)
+
@property
def parsed_query_matches(self):
return self.SUBSTRING_HYPOTHESES, "only valid with this filter"
def _hypothesize(
- self, hypotheses, new_hypothesis, boost="default",
- filters=None, **kwargs
+ self,
+ hypotheses,
+ new_hypothesis,
+ boost="default",
+ filters=None,
+ **kwargs
):
self._boosts[new_hypothesis] = boost
if kwargs:
@@ -2504,33 +2488,28 @@ def _combine_hypotheses(self, hypotheses):
assert result == query._combine_hypotheses_called_with
# We ended up with a number of hypothesis:
- assert (result ==
- [
- # Several hypotheses checking whether the search query is an attempt to
- # match a single field -- the results of calling match_one_field()
- # many times.
- 'match title',
- 'match subtitle',
- 'match series',
- 'match publisher',
- 'match imprint',
-
- # The results of calling match_author_queries() once.
- 'author query 1',
- 'author query 2',
-
- # The results of calling match_topic_queries() once.
- 'topic query',
-
- # The results of calling multi_match() for three fields.
- 'multi match title+subtitle',
- 'multi match title+series',
- 'multi match title+author',
-
- # The 'query' part of the return value of
- # parsed_query_matches()
- Mock.SUBSTRING_HYPOTHESES
- ])
+ assert result == [
+ # Several hypotheses checking whether the search query is an attempt to
+ # match a single field -- the results of calling match_one_field()
+ # many times.
+ "match title",
+ "match subtitle",
+ "match series",
+ "match publisher",
+ "match imprint",
+ # The results of calling match_author_queries() once.
+ "author query 1",
+ "author query 2",
+ # The results of calling match_topic_queries() once.
+ "topic query",
+ # The results of calling multi_match() for three fields.
+ "multi match title+subtitle",
+ "multi match title+series",
+ "multi match title+author",
+ # The 'query' part of the return value of
+ # parsed_query_matches()
+ Mock.SUBSTRING_HYPOTHESES,
+ ]
# That's not the whole story, though. parsed_query_matches()
# said it was okay to test certain hypotheses, but only
@@ -2540,9 +2519,9 @@ def _combine_hypotheses(self, hypotheses):
# of _hypothesize added it to the 'filters' dict to indicate
# we know that those filters go with the substring
# hypotheses. That's the only time 'filters' was touched.
- assert (
- {Mock.SUBSTRING_HYPOTHESES: 'only valid with this filter'} ==
- query._filters)
+ assert {
+ Mock.SUBSTRING_HYPOTHESES: "only valid with this filter"
+ } == query._filters
# Each call to _hypothesize included a boost factor indicating
# how heavily to weight that hypothesis. Rather than do
@@ -2550,29 +2529,28 @@ def _combine_hypotheses(self, hypotheses):
# anyway -- we just stored it in _boosts.
boosts = sorted(list(query._boosts.items()), key=lambda x: str(x[0]))
boosts = sorted(boosts, key=lambda x: x[1])
- assert (boosts ==
- [
- ('match imprint', 1),
- ('match publisher', 1),
- ('match series', 1),
- ('match subtitle', 1),
- ('match title', 1),
- # The only non-mocked value here is this one. The
- # substring hypotheses have their own weights, which
- # we don't see in this test. This is saying that if a
- # book matches those sub-hypotheses and _also_ matches
- # the filter, then whatever weight it got from the
- # sub-hypotheses should be boosted slighty. This gives
- # works that match the filter an edge over works that
- # don't.
- (Mock.SUBSTRING_HYPOTHESES, 1.1),
- ('author query 1', 2),
- ('author query 2', 3),
- ('topic query', 4),
- ('multi match title+author', 5),
- ('multi match title+series', 5),
- ('multi match title+subtitle', 5),
- ])
+ assert boosts == [
+ ("match imprint", 1),
+ ("match publisher", 1),
+ ("match series", 1),
+ ("match subtitle", 1),
+ ("match title", 1),
+ # The only non-mocked value here is this one. The
+ # substring hypotheses have their own weights, which
+ # we don't see in this test. This is saying that if a
+ # book matches those sub-hypotheses and _also_ matches
+ # the filter, then whatever weight it got from the
+ # sub-hypotheses should be boosted slightly. This gives
+ # works that match the filter an edge over works that
+ # don't.
+ (Mock.SUBSTRING_HYPOTHESES, 1.1),
+ ("author query 1", 2),
+ ("author query 2", 3),
+ ("topic query", 4),
+ ("multi match title+author", 5),
+ ("multi match title+series", 5),
+ ("multi match title+subtitle", 5),
+ ]
def test_match_one_field_hypotheses(self):
# Test our ability to generate hypotheses that a search string
@@ -2583,8 +2561,8 @@ class Mock(Query):
stopword_field=3,
stemmable_field=4,
)
- STOPWORD_FIELDS = ['stopword_field']
- STEMMABLE_FIELDS = ['stemmable_field']
+ STOPWORD_FIELDS = ["stopword_field"]
+ STEMMABLE_FIELDS = ["stemmable_field"]
def __init__(self, *args, **kwargs):
super(Mock, self).__init__(*args, **kwargs)
@@ -2603,7 +2581,7 @@ def _fuzzy_matches(self, field_name, **kwargs):
m = query.match_one_field_hypotheses
# We'll get a Term query and a MatchPhrase query.
- term, phrase = list(m('regular_field'))
+ term, phrase = list(m("regular_field"))
# The Term hypothesis tries to find an exact match for 'book'
# in this field. It is boosted 1000x relative to the baseline
@@ -2612,6 +2590,7 @@ def validate_keyword(field, hypothesis, expect_weight):
hypothesis, weight = hypothesis
assert Term(**{"%s.keyword" % field: "book"}) == hypothesis
assert expect_weight == weight
+
validate_keyword("regular_field", term, 2000)
# The MatchPhrase hypothesis tries to find a partial phrase
@@ -2621,6 +2600,7 @@ def validate_minimal(field, hypothesis, expect_weight):
hypothesis, weight = hypothesis
assert MatchPhrase(**{"%s.minimal" % field: "book"}) == hypothesis
assert expect_weight == weight
+
validate_minimal("regular_field", phrase, 2)
# Now let's try the same query, but with fuzzy searching
@@ -2638,15 +2618,18 @@ def validate_minimal(field, hypothesis, expect_weight):
def validate_fuzzy(field, hypothesis, phrase_weight):
minimal_field = field + ".minimal"
hypothesis, weight = fuzzy
- assert 'fuzzy match for %s' % minimal_field == hypothesis
- assert phrase_weight*0.66 == weight
+ assert "fuzzy match for %s" % minimal_field == hypothesis
+ assert phrase_weight * 0.66 == weight
# Validate standard arguments passed into _fuzzy_matches.
# Since a fuzzy match is kind of loose, we don't allow a
# match on a single word of a multi-word query. At least
# two of the words have to be involved.
- assert (dict(minimum_should_match=2, query='book') ==
- query.fuzzy_calls[minimal_field])
+ assert (
+ dict(minimum_should_match=2, query="book")
+ == query.fuzzy_calls[minimal_field]
+ )
+
validate_fuzzy("regular_field", fuzzy, 2)
# Now try a field where stopwords might be relevant.
@@ -2668,8 +2651,7 @@ def validate_fuzzy(field, hypothesis, phrase_weight):
# stopword_field that leaves the stopwords in place. This
# hypothesis is boosted just above the baseline hypothesis.
hypothesis, weight = stopword
- assert (hypothesis ==
- MatchPhrase(**{"stopword_field.with_stopwords": "book"}))
+ assert hypothesis == MatchPhrase(**{"stopword_field.with_stopwords": "book"})
assert weight == 3 * Mock.SLIGHTLY_ABOVE_BASELINE
# Finally, let's try a stemmable field.
@@ -2683,13 +2665,9 @@ def validate_fuzzy(field, hypothesis, phrase_weight):
# minimum_should_match=2 here for the same reason we do it for
# the fuzzy search -- a normal Match query is kind of loose.
hypothesis, weight = stemmable
- assert (hypothesis ==
- Match(
- stemmable_field=dict(
- minimum_should_match=2,
- query="book"
- )
- ))
+ assert hypothesis == Match(
+ stemmable_field=dict(minimum_should_match=2, query="book")
+ )
assert weight == 4 * 0.75
def test_match_author_hypotheses(self):
@@ -2708,24 +2686,20 @@ def _author_field_must_match(self, base_field, query_string=None):
# display name, it's the author's sort name, or it matches the
# author's sort name when automatically converted to a sort
# name.
- assert (
- [
- 'display_name must match ursula le guin',
- 'sort_name must match le guin, ursula'
- ] ==
- hypotheses)
+ assert [
+ "display_name must match ursula le guin",
+ "sort_name must match le guin, ursula",
+ ] == hypotheses
# If the string passed in already looks like a sort name, we
# don't try to convert it -- but someone's name may contain a
# comma, so we do check both fields.
query = Mock("le guin, ursula")
hypotheses = list(query.match_author_hypotheses)
- assert (
- [
- 'display_name must match le guin, ursula',
- 'sort_name must match le guin, ursula',
- ] ==
- hypotheses)
+ assert [
+ "display_name must match le guin, ursula",
+ "sort_name must match le guin, ursula",
+ ] == hypotheses
def test__author_field_must_match(self):
class Mock(Query):
@@ -2743,20 +2717,20 @@ def _role_must_also_match(self, hypothesis):
# run the result through _role_must_also_match() to ensure we
# only get works where this author made a major contribution.
[(hypothesis, weight)] = list(m("display_name"))
- assert (
- ['maybe contributors.display_name matches ursula le guin',
- '(but the role must be appropriate)'] ==
- hypothesis)
+ assert [
+ "maybe contributors.display_name matches ursula le guin",
+ "(but the role must be appropriate)",
+ ] == hypothesis
assert 6 == weight
# We can pass in a different query string to override
# .query_string. This is how we test a match against our guess
# at an author's sort name.
[(hypothesis, weight)] = list(m("sort_name", "le guin, ursula"))
- assert (
- ['maybe contributors.sort_name matches le guin, ursula',
- '(but the role must be appropriate)'] ==
- hypothesis)
+ assert [
+ "maybe contributors.sort_name matches le guin, ursula",
+ "(but the role must be appropriate)",
+ ] == hypothesis
assert 6 == weight
def test__role_must_also_match(self):
@@ -2768,7 +2742,7 @@ def _nest(cls, subdocument, base):
# Verify that _role_must_also_match() puts an appropriate
# restriction on a match against a field in the 'contributors'
# sub-document.
- original_query = Term(**{'contributors.sort_name': 'ursula le guin'})
+ original_query = Term(**{"contributors.sort_name": "ursula le guin"})
modified = Mock._role_must_also_match(original_query)
# The resulting query was run through Mock._nest. In a real
@@ -2781,7 +2755,7 @@ def _nest(cls, subdocument, base):
# The original query was combined with an extra clause, which
# only matches people if their contribution to a book was of
# the type that library patrons are likely to search for.
- extra = Terms(**{"contributors.role": ['Primary Author', 'Author', 'Narrator']})
+ extra = Terms(**{"contributors.role": ["Primary Author", "Author", "Narrator"]})
assert Bool(must=[original_query, extra]) == modified_base
def test_match_topic_hypotheses(self):
@@ -2797,11 +2771,12 @@ def test_match_topic_hypotheses(self):
query="whales",
fields=["summary", "classifications.term"],
type="best_fields",
- ) ==
- hypothesis)
+ )
+ == hypothesis
+ )
# The weight of the hypothesis is the base weight associated
# with the 'summary' field.
- assert Query.WEIGHT_FOR_FIELD['summary'] == weight
+ assert Query.WEIGHT_FOR_FIELD["summary"] == weight
def test_title_multi_match_for(self):
# Test our ability to hypothesize that a query string might
@@ -2810,16 +2785,14 @@ def test_title_multi_match_for(self):
# If there's only one word in the query, then we don't bother
# making this hypothesis at all.
- assert (
- [] ==
- list(Query("grasslands").title_multi_match_for("other field")))
+ assert [] == list(Query("grasslands").title_multi_match_for("other field"))
query = Query("grass lands")
[(hypothesis, weight)] = list(query.title_multi_match_for("author"))
expect = MultiMatch(
query="grass lands",
- fields = ['title.minimal', 'author.minimal'],
+ fields=["title.minimal", "author.minimal"],
type="cross_fields",
operator="and",
minimum_should_match="100%",
@@ -2828,9 +2801,9 @@ def test_title_multi_match_for(self):
# The weight of this hypothesis is between the weight of a
# pure title match and the weight of a pure author match.
- title_weight = Query.WEIGHT_FOR_FIELD['title']
- author_weight = Query.WEIGHT_FOR_FIELD['author']
- assert weight == author_weight * (author_weight/title_weight)
+ title_weight = Query.WEIGHT_FOR_FIELD["title"]
+ author_weight = Query.WEIGHT_FOR_FIELD["author"]
+ assert weight == author_weight * (author_weight / title_weight)
def test_parsed_query_matches(self):
# Test our ability to take a query like "asteroids
@@ -2851,6 +2824,7 @@ def test_hypothesize(self):
# boosting it if necessary.
class Mock(Query):
boost_extras = []
+
@classmethod
def _boost(cls, boost, queries, filters=None, **kwargs):
if filters or kwargs:
@@ -2872,15 +2846,22 @@ def _boost(cls, boost, queries, filters=None, **kwargs):
assert [] == Mock.boost_extras
Mock._hypothesize(hypotheses, "another query object", 1)
- assert (["query object boosted by 10", "another query object boosted by 1"] ==
- hypotheses)
+ assert [
+ "query object boosted by 10",
+ "another query object boosted by 1",
+ ] == hypotheses
assert [] == Mock.boost_extras
# If a filter or any other arguments are passed in, those arguments
# are propagated to _boost().
hypotheses = []
- Mock._hypothesize(hypotheses, "query with filter", 2, filters="some filters",
- extra="extra kwarg")
+ Mock._hypothesize(
+ hypotheses,
+ "query with filter",
+ 2,
+ filters="some filters",
+ extra="extra kwarg",
+ )
assert ["query with filter boosted by 2"] == hypotheses
assert [("some filters", dict(extra="extra kwarg"))] == Mock.boost_extras
@@ -2899,13 +2880,18 @@ class MockQuery(Query):
"""Create 'query' objects that are easier to test than
the ones the Query class makes.
"""
+
@classmethod
def _match_term(cls, field, query):
return (field, query)
@classmethod
def make_target_age_query(cls, query, boost="default boost"):
- return ("target age (filter)", query), ("target age (query)", query, boost)
+ return ("target age (filter)", query), (
+ "target age (query)",
+ query,
+ boost,
+ )
@property
def elasticsearch_query(self):
@@ -2931,7 +2917,7 @@ def elasticsearch_query(self):
# parser.filters contains the filters that we think we were
# able to derive from the query string.
- assert [('genres.name', 'Science Fiction')] == parser.filters
+ assert [("genres.name", "Science Fiction")] == parser.filters
# parser.match_queries contains the result of putting the rest
# of the query string into a Query object (or, here, our
@@ -2963,41 +2949,33 @@ def assert_parses_as(query_string, filters, remainder, extra_queries=None):
assert_parses_as(
"science fiction about dogs",
("genres.name", "Science Fiction"),
- "about dogs"
+ "about dogs",
)
# Test audiences.
assert_parses_as(
- "children's picture books",
- ("audience", "children"),
- "picture books"
+ "children's picture books", ("audience", "children"), "picture books"
)
# (It's possible for the entire query string to be eaten up,
# such that there is no remainder match at all.)
assert_parses_as(
"young adult romance",
- [("genres.name", "Romance"),
- ("audience", "youngadult")],
- ''
+ [("genres.name", "Romance"), ("audience", "youngadult")],
+ "",
)
# Test fiction/nonfiction status.
- assert_parses_as(
- "fiction dinosaurs",
- ("fiction", "fiction"),
- "dinosaurs"
- )
+ assert_parses_as("fiction dinosaurs", ("fiction", "fiction"), "dinosaurs")
# (Genres are parsed before fiction/nonfiction; otherwise
# "science fiction" would be chomped by a search for "fiction"
# and "nonfiction" would not be picked up.)
assert_parses_as(
"science fiction or nonfiction dinosaurs",
- [("genres.name", "Science Fiction"),
- ("fiction", "nonfiction")],
- "or dinosaurs"
+ [("genres.name", "Science Fiction"), ("fiction", "nonfiction")],
+ "or dinosaurs",
)
# Test target age.
@@ -3010,17 +2988,16 @@ def assert_parses_as(query_string, filters, remainder, extra_queries=None):
# age range).
assert_parses_as(
"grade 5 science",
- [("genres.name", "Science"),
- ("target age (filter)", (10, 10))],
- '',
- ("target age (query)", (10, 10), 'default boost')
+ [("genres.name", "Science"), ("target age (filter)", (10, 10))],
+ "",
+ ("target age (query)", (10, 10), "default boost"),
)
assert_parses_as(
- 'divorce ages 10 and up',
+ "divorce ages 10 and up",
("target age (filter)", (10, 14)),
- 'divorce and up', # TODO: not ideal
- ("target age (query)", (10, 14), 'default boost'),
+ "divorce and up", # TODO: not ideal
+ ("target age (query)", (10, 14), "default boost"),
)
# Nothing can be parsed out from this query--it's an author's name
@@ -3062,19 +3039,20 @@ def test_add_target_age_filter(self):
# Here's the filter part: a book's age range must be include the
# 10-11 range, or it gets filtered out.
filter_clauses = [
- Range(**{"target_age.upper":dict(gte=10)}),
- Range(**{"target_age.lower":dict(lte=11)}),
+ Range(**{"target_age.upper": dict(gte=10)}),
+ Range(**{"target_age.lower": dict(lte=11)}),
]
assert [Bool(must=filter_clauses)] == parser.filters
# Here's the query part: a book gets boosted if its
# age range fits _entirely_ within the target age range.
query_clauses = [
- Range(**{"target_age.upper":dict(lte=11)}),
- Range(**{"target_age.lower":dict(gte=10)}),
+ Range(**{"target_age.upper": dict(lte=11)}),
+ Range(**{"target_age.lower": dict(gte=10)}),
]
- assert ([Bool(boost=1.1, must=filter_clauses, should=query_clauses)] ==
- parser.match_queries)
+ assert [
+ Bool(boost=1.1, must=filter_clauses, should=query_clauses)
+ ] == parser.match_queries
def test__without_match(self):
# Test our ability to remove matched text from a string.
@@ -3089,14 +3067,11 @@ def test__without_match(self):
class TestFilter(DatabaseTest):
-
def setup_method(self):
super(TestFilter, self).setup_method()
# Look up three Genre objects which can be used to make filters.
- self.literary_fiction, ignore = Genre.lookup(
- self._db, "Literary Fiction"
- )
+ self.literary_fiction, ignore = Genre.lookup(self._db, "Literary Fiction")
self.fantasy, ignore = Genre.lookup(self._db, "Fantasy")
self.horror, ignore = Genre.lookup(self._db, "Horror")
@@ -3121,9 +3096,13 @@ def test_constructor(self):
# the Filter object. If necessary, they'll be cleaned up
# later, during build().
filter = Filter(
- media=media, languages=languages,
- fiction=fiction, audiences=audiences, author=author,
- match_nothing=match_nothing, min_score=min_score
+ media=media,
+ languages=languages,
+ fiction=fiction,
+ audiences=audiences,
+ author=author,
+ match_nothing=match_nothing,
+ min_score=min_score,
)
assert media == filter.media
assert languages == filter.languages
@@ -3166,12 +3145,12 @@ def test_constructor(self):
assert None == empty_filter.target_age
one_year = Filter(target_age=8)
- assert (8,8) == one_year.target_age
+ assert (8, 8) == one_year.target_age
- year_range = Filter(target_age=(8,10))
- assert (8,10) == year_range.target_age
+ year_range = Filter(target_age=(8, 10))
+ assert (8, 10) == year_range.target_age
- year_range = Filter(target_age=NumericRange(3, 6, '()'))
+ year_range = Filter(target_age=NumericRange(3, 6, "()"))
assert (4, 5) == year_range.target_age
# Test genre_restriction_sets
@@ -3184,15 +3163,15 @@ def test_constructor(self):
# Restrict to books that are literary fiction AND (horror OR
# fantasy).
restricted = Filter(
- genre_restriction_sets = [
+ genre_restriction_sets=[
[self.horror, self.fantasy],
[self.literary_fiction],
]
)
- assert (
- [[self.horror.id, self.fantasy.id],
- [self.literary_fiction.id]] ==
- restricted.genre_restriction_sets)
+ assert [
+ [self.horror.id, self.fantasy.id],
+ [self.literary_fiction.id],
+ ] == restricted.genre_restriction_sets
# This is a restriction: 'only books that have no genre'
assert [[]] == Filter(genre_restriction_sets=[[]]).genre_restriction_sets
@@ -3201,26 +3180,28 @@ def test_constructor(self):
# In these three cases, there are no restrictions.
assert [] == empty_filter.customlist_restriction_sets
- assert [] == Filter(customlist_restriction_sets=None).customlist_restriction_sets
+ assert (
+ [] == Filter(customlist_restriction_sets=None).customlist_restriction_sets
+ )
assert [] == Filter(customlist_restriction_sets=[]).customlist_restriction_sets
# Restrict to books that are on *both* the best sellers list and the
# staff picks list.
restricted = Filter(
- customlist_restriction_sets = [
+ customlist_restriction_sets=[
[self.best_sellers],
[self.staff_picks],
]
)
- assert (
- [[self.best_sellers.id],
- [self.staff_picks.id]] ==
- restricted.customlist_restriction_sets)
+ assert [
+ [self.best_sellers.id],
+ [self.staff_picks.id],
+ ] == restricted.customlist_restriction_sets
# This is a restriction -- 'only books that are not on any lists'.
- assert (
- [[]] ==
- Filter(customlist_restriction_sets=[[]]).customlist_restriction_sets)
+ assert [[]] == Filter(
+ customlist_restriction_sets=[[]]
+ ).customlist_restriction_sets
# Test the license_datasource argument
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
@@ -3268,32 +3249,28 @@ def test_from_worklist(self):
library = self._default_library
assert True == library.allow_holds
- parent = self._lane(
- display_name="Parent Lane", library=library
- )
+ parent = self._lane(display_name="Parent Lane", library=library)
parent.media = Edition.AUDIO_MEDIUM
parent.languages = ["eng", "fra"]
parent.fiction = True
parent.audiences = set([Classifier.AUDIENCE_CHILDREN])
- parent.target_age = NumericRange(10, 11, '[]')
+ parent.target_age = NumericRange(10, 11, "[]")
parent.genres = [self.horror, self.fantasy]
parent.customlists = [self.best_sellers]
- parent.license_datasource = DataSource.lookup(
- self._db, DataSource.GUTENBERG
- )
+ parent.license_datasource = DataSource.lookup(self._db, DataSource.GUTENBERG)
# This lane inherits most of its configuration from its parent.
- inherits = self._lane(
- display_name="Child who inherits", parent=parent
- )
+ inherits = self._lane(display_name="Child who inherits", parent=parent)
inherits.genres = [self.literary_fiction]
inherits.customlists = [self.staff_picks]
class Mock(object):
def modify_search_filter(self, filter):
self.called_with = filter
+
def scoring_functions(self, filter):
return []
+
facets = Mock()
filter = Filter.from_worklist(self._db, inherits, facets)
@@ -3303,8 +3280,7 @@ def scoring_functions(self, filter):
assert parent.fiction == filter.fiction
assert parent.audiences + [Classifier.AUDIENCE_ALL_AGES] == filter.audiences
assert [parent.license_datasource_id] == filter.license_datasources
- assert ((parent.target_age.lower, parent.target_age.upper) ==
- filter.target_age)
+ assert (parent.target_age.lower, parent.target_age.upper) == filter.target_age
assert True == filter.allow_holds
# Filter.from_worklist passed the mock Facets object in to
@@ -3314,11 +3290,14 @@ def scoring_functions(self, filter):
# For genre and custom list restrictions, the child values are
# appended to the parent's rather than replacing it.
- assert ([parent.genre_ids, inherits.genre_ids] ==
- [set(x) for x in filter.genre_restriction_sets])
+ assert [parent.genre_ids, inherits.genre_ids] == [
+ set(x) for x in filter.genre_restriction_sets
+ ]
- assert ([parent.customlist_ids, inherits.customlist_ids] ==
- filter.customlist_restriction_sets)
+ assert [
+ parent.customlist_ids,
+ inherits.customlist_ids,
+ ] == filter.customlist_restriction_sets
# If any other value is set on the child lane, the parent value
# is overridden.
@@ -3344,9 +3323,10 @@ def scoring_functions(self, filter):
# filter; rather it's in a subfilter that will be applied to the
# 'licensepools' subdocument, where the collection ID lives.
- [subfilter] = subfilters.pop('licensepools')
- assert ({'terms': {'licensepools.collection_id': [self._default_collection.id]}} ==
- subfilter.to_dict())
+ [subfilter] = subfilters.pop("licensepools")
+ assert {
+ "terms": {"licensepools.collection_id": [self._default_collection.id]}
+ } == subfilter.to_dict()
# No other subfilters were specified.
assert {} == subfilters
@@ -3405,10 +3385,10 @@ def assert_filter_builds_to(self, expect, filter, _chain_filters=None):
"""Helper method for the most common case, where a
Filter.build() returns a main filter and no nested filters.
"""
- final_query = {'bool': {'must_not': [RESEARCH.to_dict()]}}
+ final_query = {"bool": {"must_not": [RESEARCH.to_dict()]}}
if expect:
- final_query['bool']['must'] = expect
+ final_query["bool"]["must"] = expect
main, nested = filter.build(_chain_filters)
assert final_query == main.to_dict()
@@ -3431,30 +3411,33 @@ def test_audiences(self):
# "all ages" should always be an audience if the audience is
# young adult or adult.
filter = Filter(audiences=Classifier.AUDIENCE_YOUNG_ADULT)
- assert filter.audiences == [Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_ALL_AGES]
+ assert filter.audiences == [
+ Classifier.AUDIENCE_YOUNG_ADULT,
+ Classifier.AUDIENCE_ALL_AGES,
+ ]
filter = Filter(audiences=Classifier.AUDIENCE_ADULT)
- assert filter.audiences == [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_ALL_AGES]
- filter = Filter(audiences=[Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_YOUNG_ADULT])
- assert (
- filter.audiences ==
- [Classifier.AUDIENCE_ADULT,
+ assert filter.audiences == [
+ Classifier.AUDIENCE_ADULT,
+ Classifier.AUDIENCE_ALL_AGES,
+ ]
+ filter = Filter(
+ audiences=[Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_YOUNG_ADULT]
+ )
+ assert filter.audiences == [
+ Classifier.AUDIENCE_ADULT,
Classifier.AUDIENCE_YOUNG_ADULT,
- Classifier.AUDIENCE_ALL_AGES])
+ Classifier.AUDIENCE_ALL_AGES,
+ ]
# If the audience is meant for adults, then "all ages" should not
# be included
- for audience in (
- Classifier.AUDIENCE_ADULTS_ONLY, Classifier.AUDIENCE_RESEARCH
- ):
+ for audience in (Classifier.AUDIENCE_ADULTS_ONLY, Classifier.AUDIENCE_RESEARCH):
filter = Filter(audiences=audience)
- assert(Classifier.AUDIENCE_ALL_AGES not in filter.audiences)
+ assert Classifier.AUDIENCE_ALL_AGES not in filter.audiences
# If the audience and target age is meant for children, then the
# audience should only be for children
- filter = Filter(
- audiences=Classifier.AUDIENCE_CHILDREN,
- target_age=5
- )
+ filter = Filter(audiences=Classifier.AUDIENCE_CHILDREN, target_age=5)
assert filter.audiences == [Classifier.AUDIENCE_CHILDREN]
# If the children's target age includes children older than
@@ -3463,9 +3446,10 @@ def test_audiences(self):
all_children = Filter(audiences=Classifier.AUDIENCE_CHILDREN)
nine_years = Filter(audiences=Classifier.AUDIENCE_CHILDREN, target_age=9)
for filter in (all_children, nine_years):
- assert (
- filter.audiences ==
- [Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_ALL_AGES])
+ assert filter.audiences == [
+ Classifier.AUDIENCE_CHILDREN,
+ Classifier.AUDIENCE_ALL_AGES,
+ ]
def test_build(self):
# Test the ability to turn a Filter into an ElasticSearch
@@ -3488,18 +3472,17 @@ def test_build(self):
# Add a medium clause to the filter.
filter.media = "a medium"
- medium_built = {'terms': {'medium': ['amedium']}}
+ medium_built = {"terms": {"medium": ["amedium"]}}
built_filters, subfilters = self.assert_filter_builds_to([medium_built], filter)
assert {} == subfilters
# Add a language clause to the filter.
filter.languages = ["lang1", "LANG2"]
- language_built = {'terms': {'language': ['lang1', 'lang2']}}
+ language_built = {"terms": {"language": ["lang1", "lang2"]}}
# Now both the medium clause and the language clause must match.
built_filters, subfilters = self.assert_filter_builds_to(
- [medium_built, language_built],
- filter
+ [medium_built, language_built], filter
)
assert {} == subfilters
@@ -3507,8 +3490,8 @@ def test_build(self):
filter.collection_ids = [self._default_collection]
filter.fiction = True
- filter._audiences = 'CHILDREN'
- filter.target_age = (2,3)
+ filter._audiences = "CHILDREN"
+ filter.target_age = (2, 3)
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
filter.excluded_audiobook_data_sources = [overdrive.id]
filter.allow_holds = False
@@ -3527,13 +3510,12 @@ def test_build(self):
# We want books that are literary fiction, *and* either
# fantasy or horror.
filter.genre_restriction_sets = [
- [self.literary_fiction], [self.fantasy, self.horror]
+ [self.literary_fiction],
+ [self.fantasy, self.horror],
]
# We want books that are on _both_ of the custom lists.
- filter.customlist_restriction_sets = [
- [self.best_sellers], [self.staff_picks]
- ]
+ filter.customlist_restriction_sets = [[self.best_sellers], [self.staff_picks]]
# At this point every item on this Filter that can be set, has been
# set. When we run build, we'll end up with the output of our mocked
@@ -3546,43 +3528,50 @@ def test_build(self):
# restrictions is kept in the nested 'licensepools' document,
# so those restrictions must be described in terms of nested
# filters on that document.
- [licensepool_filter, datasource_filter, excluded_audiobooks_filter,
- no_holds_filter] = nested.pop('licensepools')
+ [
+ licensepool_filter,
+ datasource_filter,
+ excluded_audiobooks_filter,
+ no_holds_filter,
+ ] = nested.pop("licensepools")
# The 'current collection' filter.
- assert (
- {'terms': {'licensepools.collection_id': [self._default_collection.id]}} ==
- licensepool_filter.to_dict())
+ assert {
+ "terms": {"licensepools.collection_id": [self._default_collection.id]}
+ } == licensepool_filter.to_dict()
# The 'only certain data sources' filter.
- assert ({'terms': {'licensepools.data_source_id': [overdrive.id]}} ==
- datasource_filter.to_dict())
+ assert {
+ "terms": {"licensepools.data_source_id": [overdrive.id]}
+ } == datasource_filter.to_dict()
# The 'excluded audiobooks' filter.
- audio = Q('term', **{'licensepools.medium': Edition.AUDIO_MEDIUM})
+ audio = Q("term", **{"licensepools.medium": Edition.AUDIO_MEDIUM})
excluded_audio_source = Q(
- 'terms', **{'licensepools.data_source_id' : [overdrive.id]}
+ "terms", **{"licensepools.data_source_id": [overdrive.id]}
)
excluded_audio = Bool(must=[audio, excluded_audio_source])
not_excluded_audio = Bool(must_not=excluded_audio)
assert not_excluded_audio == excluded_audiobooks_filter
# The 'no holds' filter.
- open_access = Q('term', **{'licensepools.open_access' : True})
- licenses_available = Q('term', **{'licensepools.available' : True})
+ open_access = Q("term", **{"licensepools.open_access": True})
+ licenses_available = Q("term", **{"licensepools.available": True})
currently_available = Bool(should=[licenses_available, open_access])
assert currently_available == no_holds_filter
# The best-seller list and staff picks restrictions are also
# expressed as nested filters.
- [best_sellers_filter, staff_picks_filter] = nested.pop('customlists')
- assert ({'terms': {'customlists.list_id': [self.best_sellers.id]}} ==
- best_sellers_filter.to_dict())
- assert ({'terms': {'customlists.list_id': [self.staff_picks.id]}} ==
- staff_picks_filter.to_dict())
+ [best_sellers_filter, staff_picks_filter] = nested.pop("customlists")
+ assert {
+ "terms": {"customlists.list_id": [self.best_sellers.id]}
+ } == best_sellers_filter.to_dict()
+ assert {
+ "terms": {"customlists.list_id": [self.staff_picks.id]}
+ } == staff_picks_filter.to_dict()
# The author restriction is also expressed as a nested filter.
- [contributor_filter] = nested.pop('contributors')
+ [contributor_filter] = nested.pop("contributors")
# It's value is the value of .author_filter, which is tested
# separately in test_author_filter.
@@ -3590,39 +3579,41 @@ def test_build(self):
assert filter.author_filter == contributor_filter
# The genre restrictions are also expressed as nested filters.
- literary_fiction_filter, fantasy_or_horror_filter = nested.pop(
- 'genres'
- )
+ literary_fiction_filter, fantasy_or_horror_filter = nested.pop("genres")
# There are two different restrictions on genre, because
# genre_restriction_sets was set to two lists of genres.
- assert ({'terms': {'genres.term': [self.literary_fiction.id]}} ==
- literary_fiction_filter.to_dict())
- assert ({'terms': {'genres.term': [self.fantasy.id, self.horror.id]}} ==
- fantasy_or_horror_filter.to_dict())
+ assert {
+ "terms": {"genres.term": [self.literary_fiction.id]}
+ } == literary_fiction_filter.to_dict()
+ assert {
+ "terms": {"genres.term": [self.fantasy.id, self.horror.id]}
+ } == fantasy_or_horror_filter.to_dict()
# There's a restriction on the identifier.
- [identifier_restriction] = nested.pop('identifiers')
+ [identifier_restriction] = nested.pop("identifiers")
# The restriction includes subclases, each of which matches
# the identifier and type of one of the Identifier objects.
subclauses = [
- Bool(must=[Term(identifiers__identifier=x.identifier),
- Term(identifiers__type=x.type)])
+ Bool(
+ must=[
+ Term(identifiers__identifier=x.identifier),
+ Term(identifiers__type=x.type),
+ ]
+ )
for x in [i1, i2]
]
# Any identifier will work, but at least one must match.
- assert (Bool(minimum_should_match=1, should=subclauses) ==
- identifier_restriction)
+ assert Bool(minimum_should_match=1, should=subclauses) == identifier_restriction
# There are no other nested filters.
assert {} == nested
# Every other restriction imposed on the Filter object becomes an
# Elasticsearch filter object in this list.
- (medium, language, fiction, audience, target_age,
- updated_after) = built
+ (medium, language, fiction, audience, target_age, updated_after) = built
# Test them one at a time.
#
@@ -3640,8 +3631,8 @@ def test_build(self):
assert medium_built == medium.to_dict()
assert language_built == language.to_dict()
- assert {'term': {'fiction': 'fiction'}} == fiction.to_dict()
- assert {'terms': {'audience': ['children']}} == audience.to_dict()
+ assert {"term": {"fiction": "fiction"}} == fiction.to_dict()
+ assert {"terms": {"audience": ["children"]}} == audience.to_dict()
# The contents of target_age_filter are tested below -- this
# just tests that the target_age_filter is included.
@@ -3650,19 +3641,17 @@ def test_build(self):
# There's a restriction on the last updated time for bibliographic
# metadata. The datetime is converted to a number of seconds since
# the epoch, since that's how we index times.
- expect = (
- last_update_time - from_timestamp(0)
- ).total_seconds()
- assert (
- {'bool': {'must': [
- {'range': {'last_update_time': {'gte': expect}}}
- ]}} ==
- updated_after.to_dict())
+ expect = (last_update_time - from_timestamp(0)).total_seconds()
+ assert {
+ "bool": {"must": [{"range": {"last_update_time": {"gte": expect}}}]}
+ } == updated_after.to_dict()
# We tried fiction; now try nonfiction.
filter = Filter()
filter.fiction = False
- built_filters, subfilters = self.assert_filter_builds_to([{'term': {'fiction': 'nonfiction'}}], filter)
+ built_filters, subfilters = self.assert_filter_builds_to(
+ [{"term": {"fiction": "nonfiction"}}], filter
+ )
assert {} == subfilters
def test_build_series(self):
@@ -3673,9 +3662,9 @@ def test_build_series(self):
# A match against a keyword field only matches on an exact
# string match.
- assert (
- built.to_dict()['bool']['must'] ==
- [{'term': {'series.keyword': 'Talking Hedgehog Mysteries'}}])
+ assert built.to_dict()["bool"]["must"] == [
+ {"term": {"series.keyword": "Talking Hedgehog Mysteries"}}
+ ]
# Find books that are in _some_ series--which one doesn't
# matter.
@@ -3684,12 +3673,10 @@ def test_build_series(self):
assert {} == nested
# The book must have an indexed series.
- assert (
- built.to_dict()['bool']['must'] ==
- [{'exists': {'field': 'series'}}])
+ assert built.to_dict()["bool"]["must"] == [{"exists": {"field": "series"}}]
# But the 'series' that got indexed must not be the empty string.
- assert {'term': {'series.keyword': ''}} in built.to_dict()['bool']['must_not']
+ assert {"term": {"series.keyword": ""}} in built.to_dict()["bool"]["must_not"]
def test_sort_order(self):
# Test the Filter.sort_order property.
@@ -3712,34 +3699,33 @@ def validate_sort_order(filter, main_field):
# it's removed from the list -- there's no need to sort on
# that field a second time.
default_sort_fields = [
- {x: "asc"} for x in ['sort_author', 'sort_title', 'work_id']
+ {x: "asc"}
+ for x in ["sort_author", "sort_title", "work_id"]
if x != main_field
]
assert default_sort_fields == filter.sort_order[1:]
return filter.sort_order[0]
# A simple field, either ascending or descending.
- f.order='field'
+ f.order = "field"
assert False == f.order_ascending
- first_field = validate_sort_order(f, 'field')
- assert dict(field='desc') == first_field
+ first_field = validate_sort_order(f, "field")
+ assert dict(field="desc") == first_field
f.order_ascending = True
- first_field = validate_sort_order(f, 'field')
- assert dict(field='asc') == first_field
+ first_field = validate_sort_order(f, "field")
+ assert dict(field="asc") == first_field
# When multiple fields are given, they are put at the
# beginning and any remaining tiebreaker fields are added.
- f.order=['series_position', 'work_id', 'some_other_field']
- assert (
- [
- dict(series_position='asc'),
- dict(work_id='asc'),
- dict(some_other_field='asc'),
- dict(sort_author='asc'),
- dict(sort_title='asc'),
- ] ==
- f.sort_order)
+ f.order = ["series_position", "work_id", "some_other_field"]
+ assert [
+ dict(series_position="asc"),
+ dict(work_id="asc"),
+ dict(some_other_field="asc"),
+ dict(sort_author="asc"),
+ dict(sort_title="asc"),
+ ] == f.sort_order
# You can't sort by some random subdocument field, because there's
# not enough information to know how to aggregate multiple values.
@@ -3747,7 +3733,7 @@ def validate_sort_order(filter, main_field):
# You _can_ sort by license pool availability time and first
# appearance on custom list -- those are tested below -- but it's
# complicated.
- f.order = 'subdocument.field'
+ f.order = "subdocument.field"
with pytest.raises(ValueError) as excinfo:
f.sort_order()
assert "I don't know how to sort by subdocument.field" in str(excinfo.value)
@@ -3759,25 +3745,21 @@ def validate_sort_order(filter, main_field):
series_position = used_orders[Facets.ORDER_SERIES_POSITION]
last_update = used_orders[Facets.ORDER_LAST_UPDATE]
for sort_field in list(used_orders.values()):
- if sort_field in (added_to_collection, series_position,
- last_update):
+ if sort_field in (added_to_collection, series_position, last_update):
# These are complicated cases, tested below.
continue
f.order = sort_field
first_field = validate_sort_order(f, sort_field)
- assert {sort_field: 'asc'} == first_field
+ assert {sort_field: "asc"} == first_field
# A slightly more complicated case is when a feed is ordered by
# series position -- there the second field is title rather than
# author.
f.order = series_position
- assert (
- [
- {x:'asc'} for x in [
- 'series_position', 'sort_title', 'sort_author', 'work_id'
- ]
- ] ==
- f.sort_order)
+ assert [
+ {x: "asc"}
+ for x in ["series_position", "sort_title", "sort_author", "work_id"]
+ ] == f.sort_order
# A more complicated case is when a feed is ordered by date
# added to the collection. This requires an aggregate function
@@ -3789,13 +3771,13 @@ def validate_sort_order(filter, main_field):
# function. If a book is available through multiple
# collections, we sort by the _earliest_ availability time.
simple_nested_configuration = {
- 'licensepools.availability_time': {'mode': 'min', 'order': 'asc'}
+ "licensepools.availability_time": {"mode": "min", "order": "asc"}
}
assert simple_nested_configuration == first_field
# Setting a collection ID restriction will add a nested filter.
f.collection_ids = [self._default_collection]
- first_field = validate_sort_order(f, 'licensepools.availability_time')
+ first_field = validate_sort_order(f, "licensepools.availability_time")
# The nested filter ensures that when sorting the results, we
# only consider availability times from license pools that
@@ -3807,16 +3789,13 @@ def validate_sort_order(filter, main_field):
#
# This just makes sure that the books show up in the right _order_
# for any given set of collections.
- nested_filter = first_field['licensepools.availability_time'].pop('nested')
- assert (
- {'path': 'licensepools',
- 'filter': {
- 'terms': {
- 'licensepools.collection_id': [self._default_collection.id]
- }
- }
- } ==
- nested_filter)
+ nested_filter = first_field["licensepools.availability_time"].pop("nested")
+ assert {
+ "path": "licensepools",
+ "filter": {
+ "terms": {"licensepools.collection_id": [self._default_collection.id]}
+ },
+ } == nested_filter
# Apart from the nested filter, this is the same ordering
# configuration as before.
@@ -3827,41 +3806,40 @@ def validate_sort_order(filter, main_field):
f.order = last_update
f.collection_ids = []
first_field = validate_sort_order(f, last_update)
- assert dict(last_update_time='asc') == first_field
+ assert dict(last_update_time="asc") == first_field
# Or it can be *incredibly complicated*, if there _are_
# collections or lists associated with the filter. Which,
# unfortunately, is almost all the time.
f.collection_ids = [self._default_collection.id]
- f.customlist_restriction_sets = [[1], [1,2]]
+ f.customlist_restriction_sets = [[1], [1, 2]]
first_field = validate_sort_order(f, last_update)
# Here, the ordering is done by a script that runs on the
# ElasticSearch server.
- sort = first_field.pop('_script')
+ sort = first_field.pop("_script")
assert {} == first_field
# The script returns a numeric value and we want to sort those
# values in ascending order.
- assert 'asc' == sort.pop('order')
- assert 'number' == sort.pop('type')
+ assert "asc" == sort.pop("order")
+ assert "number" == sort.pop("type")
- script = sort.pop('script')
+ script = sort.pop("script")
assert {} == sort
# The script is the 'simplified.work_last_update' stored script.
- assert (CurrentMapping.script_name("work_last_update") ==
- script.pop('stored'))
+ assert CurrentMapping.script_name("work_last_update") == script.pop("stored")
# Two parameters are passed into the script -- the IDs of the
# collections and the lists relevant to the query. This is so
# the query knows which updates should actually be considered
# for purposes of this query.
- params = script.pop('params')
+ params = script.pop("params")
assert {} == script
- assert [self._default_collection.id] == params.pop('collection_ids')
- assert [1,2] == params.pop('list_ids')
+ assert [self._default_collection.id] == params.pop("collection_ids")
+ assert [1, 2] == params.pop("list_ids")
assert {} == params
def test_author_filter(self):
@@ -3878,45 +3856,39 @@ def check_filter(contributor, *shoulds):
# We only count contributions that were in one of the
# matching roles.
- role_match = Terms(
- **{"contributors.role": Filter.AUTHOR_MATCH_ROLES}
- )
+ role_match = Terms(**{"contributors.role": Filter.AUTHOR_MATCH_ROLES})
# Among the other restrictions on fields in the
# 'contributors' subdocument (sort name, VIAF, etc.), at
# least one must also be met.
author_match = [Term(**should) for should in shoulds]
expect = Bool(
- must=[
- role_match,
- Bool(minimum_should_match=1, should=author_match)
- ]
+ must=[role_match, Bool(minimum_should_match=1, should=author_match)]
)
assert expect == actual
# You can apply the filter on any one of these four fields,
# using a Contributor or a ContributorData
- for contributor_field in ('sort_name', 'display_name', 'viaf', 'lc'):
+ for contributor_field in ("sort_name", "display_name", "viaf", "lc"):
for cls in Contributor, ContributorData:
- contributor = cls(**{contributor_field:"value"})
+ contributor = cls(**{contributor_field: "value"})
index_field = contributor_field
- if contributor_field in ('sort_name', 'display_name'):
+ if contributor_field in ("sort_name", "display_name"):
# Sort name and display name are indexed both as
# searchable text fields and filterable keywords.
# We're filtering, so we want to use the keyword
# version.
- index_field += '.keyword'
- check_filter(
- contributor,
- {"contributors.%s" % index_field: "value"}
- )
+ index_field += ".keyword"
+ check_filter(contributor, {"contributors.%s" % index_field: "value"})
# You can also apply the filter using a combination of these
# fields. At least one of the provided fields must match.
for cls in Contributor, ContributorData:
contributor = cls(
- display_name='Ann Leckie', sort_name='Leckie, Ann',
- viaf="73520345", lc="n2013008575"
+ display_name="Ann Leckie",
+ sort_name="Leckie, Ann",
+ viaf="73520345",
+ lc="n2013008575",
)
check_filter(
contributor,
@@ -3932,7 +3904,7 @@ def check_filter(contributor, *shoulds):
unknown_viaf = ContributorData(
sort_name=Edition.UNKNOWN_AUTHOR,
display_name=Edition.UNKNOWN_AUTHOR,
- viaf="123"
+ viaf="123",
)
check_filter(unknown_viaf, {"contributors.viaf": "123"})
@@ -3956,7 +3928,7 @@ def test_target_age_filter(self):
# a number of inputs.
# First, let's create a filter that matches "ages 2 to 5".
- two_to_five = Filter(target_age=(2,5))
+ two_to_five = Filter(target_age=(2, 5))
filter = two_to_five.target_age_filter
# The result is the combination of two filters -- both must
@@ -3977,35 +3949,30 @@ def dichotomy(filter):
assert "bool" == filter.name
assert 1 == filter.minimum_should_match
return filter.should
- more_than_two, no_upper_limit = dichotomy(upper_match)
+ more_than_two, no_upper_limit = dichotomy(upper_match)
# Either the upper age limit must be greater than two...
- assert (
- {'range': {'target_age.upper': {'gte': 2}}} ==
- more_than_two.to_dict())
+ assert {"range": {"target_age.upper": {"gte": 2}}} == more_than_two.to_dict()
# ...or the upper age limit must be missing entirely.
def assert_matches_nonexistent_field(f, field):
"""Verify that a filter only matches when there is
no value for the given field.
"""
- assert (
- f.to_dict() ==
- {'bool': {'must_not': [{'exists': {'field': field}}]}})
- assert_matches_nonexistent_field(no_upper_limit, 'target_age.upper')
+ assert f.to_dict() == {"bool": {"must_not": [{"exists": {"field": field}}]}}
+
+ assert_matches_nonexistent_field(no_upper_limit, "target_age.upper")
# We must also establish that five-year-olds are not too young
# for the book. Again, there are two ways of doing this.
less_than_five, no_lower_limit = dichotomy(lower_match)
# Either the lower age limit must be less than five...
- assert (
- {'range': {'target_age.lower': {'lte': 5}}} ==
- less_than_five.to_dict())
+ assert {"range": {"target_age.lower": {"lte": 5}}} == less_than_five.to_dict()
# ...or the lower age limit must be missing entirely.
- assert_matches_nonexistent_field(no_lower_limit, 'target_age.lower')
+ assert_matches_nonexistent_field(no_lower_limit, "target_age.lower")
# Now let's try a filter that matches "ten and under"
ten_and_under = Filter(target_age=(None, 10))
@@ -4017,9 +3984,8 @@ def assert_matches_nonexistent_field(f, field):
# Either the lower part of the age range must be <= ten, or
# there must be no lower age limit. If neither of these are
# true, then ten-year-olds are too young for the book.
- assert ({'range': {'target_age.lower': {'lte': 10}}} ==
- less_than_ten.to_dict())
- assert_matches_nonexistent_field(no_lower_limit, 'target_age.lower')
+ assert {"range": {"target_age.lower": {"lte": 10}}} == less_than_ten.to_dict()
+ assert_matches_nonexistent_field(no_lower_limit, "target_age.lower")
# Next, let's try a filter that matches "twelve and up".
twelve_and_up = Filter(target_age=(12, None))
@@ -4031,9 +3997,10 @@ def assert_matches_nonexistent_field(f, field):
# Either the upper part of the age range must be >= twelve, or
# there must be no upper age limit. If neither of these are true,
# then twelve-year-olds are too old for the book.
- assert ({'range': {'target_age.upper': {'gte': 12}}} ==
- more_than_twelve.to_dict())
- assert_matches_nonexistent_field(no_upper_limit, 'target_age.upper')
+ assert {
+ "range": {"target_age.upper": {"gte": 12}}
+ } == more_than_twelve.to_dict()
+ assert_matches_nonexistent_field(no_upper_limit, "target_age.upper")
# Finally, test filters that put no restriction on target age.
no_target_age = Filter()
@@ -4065,7 +4032,7 @@ def test__filter_ids(self):
m = Filter._filter_ids
assert None == m(None)
assert [] == m([])
- assert [1,2,3] == m([1,2,3])
+ assert [1, 2, 3] == m([1, 2, 3])
library = self._default_library
assert [library.id] == m([library])
@@ -4087,8 +4054,8 @@ def test__scrub_identifiers(self):
def test__chain_filters(self):
# Test the _chain_filters method, which combines
# two Elasticsearch filter objects.
- f1 = Q('term', key="value")
- f2 = Q('term', key2="value2")
+ f1 = Q("term", key="value")
+ f2 = Q("term", key2="value2")
m = Filter._chain_filters
@@ -4116,15 +4083,14 @@ def test_universal_nested_filters(self):
# Currently all nested filters operate on the 'licensepools'
# subdocument.
- [not_suppressed, currently_owned] = nested.pop('licensepools')
+ [not_suppressed, currently_owned] = nested.pop("licensepools")
assert {} == nested
# Let's look at those filters.
# The first one is simple -- the license pool must not be
# suppressed.
- assert (Term(**{"licensepools.suppressed": False}) ==
- not_suppressed)
+ assert Term(**{"licensepools.suppressed": False}) == not_suppressed
# The second one is a little more complex
owned = Term(**{"licensepools.licensed": True})
@@ -4178,9 +4144,7 @@ def test_from_request(self):
pagination_key = json.dumps(["field 1", 2])
- pagination = SortKeyPagination.from_request(
- dict(key=pagination_key).get
- )
+ pagination = SortKeyPagination.from_request(dict(key=pagination_key).get)
assert isinstance(pagination, SortKeyPagination)
assert SortKeyPagination.DEFAULT_SIZE == pagination.size
assert pagination_key == pagination.pagination_key
@@ -4208,15 +4172,13 @@ def test_items(self):
assert [("size", 20)] == list(pagination.items())
key = ["the last", "item"]
pagination.last_item_on_previous_page = key
- assert (
- [("key", json.dumps(key)), ("size", 20)] ==
- list(pagination.items()))
+ assert [("key", json.dumps(key)), ("size", 20)] == list(pagination.items())
def test_pagination_key(self):
# SortKeyPagination has no pagination key until it knows
# about the last item on the previous page.
pagination = SortKeyPagination()
- assert None == pagination. pagination_key
+ assert None == pagination.pagination_key
key = ["the last", "item"]
pagination.last_item_on_previous_page = key
@@ -4245,12 +4207,15 @@ def test_unimplemented_features(self):
with pytest.raises(NotImplementedError) as excinfo:
pagination.modify_database_query(object())
- assert "SortKeyPagination does not work with database queries." in str(excinfo.value)
+ assert "SortKeyPagination does not work with database queries." in str(
+ excinfo.value
+ )
def test_modify_search_query(self):
class MockSearch(object):
update_from_dict_called_with = "not called"
getitem_called_with = "not called"
+
def update_from_dict(self, dict):
self.update_from_dict_called_with = dict
return self
@@ -4316,9 +4281,7 @@ def __init__(self, sort_key):
self.meta = MockMeta(sort_key)
# Make a page of results, each with a unique sort key.
- hits = [
- MockItem(['sort', 'key', num]) for num in range(5)
- ]
+ hits = [MockItem(["sort", "key", num]) for num in range(5)]
last_hit = hits[-1]
# Tell the page about the results.
@@ -4372,7 +4335,6 @@ def test_next_page(self):
class TestBulkUpdate(DatabaseTest):
-
def test_works_not_presentation_ready_kept_in_index(self):
w1 = self._work()
w1.set_presentation_ready()
@@ -4396,23 +4358,30 @@ def test_works_not_presentation_ready_kept_in_index(self):
# index.
w2.presentation_ready = False
successes, failures = index.bulk_update([w1, w2, w3])
- assert set([w1.id, w2.id, w3.id]) == set([x[-1] for x in list(index.docs.keys())])
+ assert set([w1.id, w2.id, w3.id]) == set(
+ [x[-1] for x in list(index.docs.keys())]
+ )
assert set([w1, w2, w3]) == set(successes)
assert [] == failures
-class TestSearchErrors(ExternalSearchTest):
+class TestSearchErrors(ExternalSearchTest):
def test_search_connection_timeout(self):
attempts = []
def bulk_with_timeout(docs, raise_on_error=False, raise_on_exception=False):
attempts.append(docs)
+
def error(doc):
- return dict(index=dict(status='TIMEOUT',
- exception='ConnectionTimeout',
- error='Connection Timeout!',
- _id=doc['_id'],
- data=doc))
+ return dict(
+ index=dict(
+ status="TIMEOUT",
+ exception="ConnectionTimeout",
+ error="Connection Timeout!",
+ _id=doc["_id"],
+ data=doc,
+ )
+ )
errors = list(map(error, docs))
return 0, errors
@@ -4428,8 +4397,7 @@ def error(doc):
assert "Connection Timeout!" == failures[0][1]
# When all the documents fail, it tries again once with the same arguments.
- assert ([work.id, work.id] ==
- [docs[0]['_id'] for docs in attempts])
+ assert [work.id, work.id] == [docs[0]["_id"] for docs in attempts]
def test_search_single_document_error(self):
successful_work = self._work()
@@ -4438,9 +4406,13 @@ def test_search_single_document_error(self):
failing_work.set_presentation_ready()
def bulk_with_error(docs, raise_on_error=False, raise_on_exception=False):
- failures = [dict(data=dict(_id=failing_work.id),
- error="There was an error!",
- exception="Exception")]
+ failures = [
+ dict(
+ data=dict(_id=failing_work.id),
+ error="There was an error!",
+ exception="Exception",
+ )
+ ]
success_count = 1
return success_count, failures
@@ -4474,22 +4446,16 @@ def test_constructor(self):
class TestSearchIndexCoverageProvider(DatabaseTest):
-
def test_operation(self):
index = MockExternalSearchIndex()
- provider = SearchIndexCoverageProvider(
- self._db, search_index_client=index
- )
- assert (WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION ==
- provider.operation)
+ provider = SearchIndexCoverageProvider(self._db, search_index_client=index)
+ assert WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION == provider.operation
def test_success(self):
work = self._work()
work.set_presentation_ready()
index = MockExternalSearchIndex()
- provider = SearchIndexCoverageProvider(
- self._db, search_index_client=index
- )
+ provider = SearchIndexCoverageProvider(self._db, search_index_client=index)
results = provider.process_batch([work])
# We got one success and no failures.
@@ -4501,24 +4467,25 @@ def test_success(self):
def test_failure(self):
class DoomedExternalSearchIndex(MockExternalSearchIndex):
"""All documents sent to this index will fail."""
+
def bulk(self, docs, **kwargs):
return 0, [
- dict(data=dict(_id=failing_work['_id']),
- error="There was an error!",
- exception="Exception")
+ dict(
+ data=dict(_id=failing_work["_id"]),
+ error="There was an error!",
+ exception="Exception",
+ )
for failing_work in docs
]
work = self._work()
work.set_presentation_ready()
index = DoomedExternalSearchIndex()
- provider = SearchIndexCoverageProvider(
- self._db, search_index_client=index
- )
+ provider = SearchIndexCoverageProvider(self._db, search_index_client=index)
results = provider.process_batch([work])
# We have one transient failure.
[record] = results
assert work == record.obj
assert True == record.transient
- assert 'There was an error!' == record.exception
+ assert "There was an error!" == record.exception
diff --git a/tests/test_facets.py b/tests/test_facets.py
index 062e47f63..76fa812c7 100644
--- a/tests/test_facets.py
+++ b/tests/test_facets.py
@@ -1,13 +1,9 @@
+from ..facets import FacetConfig
+from ..facets import FacetConstants as Facets
from ..testing import DatabaseTest
-from ..facets import (
- FacetConstants as Facets,
- FacetConfig,
-)
-
class TestFacetConfig(DatabaseTest):
-
def test_from_library(self):
library = self._default_library
order_by = Facets.ORDER_FACET_GROUP_NAME
@@ -18,10 +14,8 @@ def test_from_library(self):
config = FacetConfig.from_library(library)
assert Facets.ORDER_RANDOM not in config.enabled_facets(order_by)
for group in list(Facets.DEFAULT_FACET.keys()):
- assert (config.enabled_facets(group) ==
- library.enabled_facets(group))
- assert (config.default_facet(group) ==
- library.default_facet(group))
+ assert config.enabled_facets(group) == library.enabled_facets(group)
+ assert config.default_facet(group) == library.default_facet(group)
# If you then modify the FacetConfig, it deviates from what
# the Library would do.
diff --git a/tests/test_lane.py b/tests/test_lane.py
index 66c68e86e..374e68c8c 100644
--- a/tests/test_lane.py
+++ b/tests/test_lane.py
@@ -1,30 +1,21 @@
import datetime
import json
import logging
-import pytest
-from mock import (
- call,
- MagicMock,
-)
import random
-from sqlalchemy.sql.elements import Case
-from sqlalchemy import (
- and_,
- func,
- text,
-)
+
+import pytest
from elasticsearch.exceptions import ElasticsearchException
+from mock import MagicMock, call
+from sqlalchemy import and_, func, text
+from sqlalchemy.sql.elements import Case
-from ..testing import (
- DatabaseTest,
-)
from ..classifier import Classifier
from ..config import Configuration
from ..entrypoint import (
AudiobooksEntryPoint,
EbooksEntryPoint,
- EverythingEntryPoint,
EntryPoint,
+ EverythingEntryPoint,
)
from ..external_search import (
Filter,
@@ -40,16 +31,13 @@
Facets,
FacetsWithEntryPoint,
FeaturedFacets,
+ Lane,
Pagination,
SearchFacets,
TopLevelWorkList,
WorkList,
- Lane,
)
from ..model import (
- dump_query,
- get_one_or_create,
- tuple_to_numericrange,
CachedFeed,
CustomListEntry,
DataSource,
@@ -61,18 +49,22 @@
SessionManager,
Work,
WorkGenre,
+ dump_query,
+ get_one_or_create,
+ tuple_to_numericrange,
)
from ..problem_details import INVALID_INPUT
-from ..testing import EndToEndSearchTest, LogCaptureHandler
-from ..util.opds_writer import OPDSFeed
+from ..testing import DatabaseTest, EndToEndSearchTest, LogCaptureHandler
from ..util.datetime_helpers import utc_now
+from ..util.opds_writer import OPDSFeed
-class TestFacetsWithEntryPoint(DatabaseTest):
+class TestFacetsWithEntryPoint(DatabaseTest):
class MockFacetConfig(object):
"""Pass this in when you call FacetsWithEntryPoint.from_request
but you don't care which EntryPoints are configured.
"""
+
entrypoints = []
def test_items(self):
@@ -89,7 +81,6 @@ def test_items(self):
]
assert expect_items == list(f.items())
-
def test_modify_database_query(self):
class MockEntryPoint(object):
def modify_database_query(self, _db, qu):
@@ -108,8 +99,7 @@ def test_navigate(self):
old_entrypoint = object()
kwargs = dict(extra_key="extra_value")
facets = FacetsWithEntryPoint(
- old_entrypoint, entrypoint_is_default=True,
- max_cache_age=123, **kwargs
+ old_entrypoint, entrypoint_is_default=True, max_cache_age=123, **kwargs
)
new_entrypoint = object()
new_facets = facets.navigate(new_entrypoint)
@@ -134,21 +124,32 @@ def test_navigate(self):
def test_from_request(self):
# from_request just calls the _from_request class method
expect = object()
+
class Mock(FacetsWithEntryPoint):
@classmethod
def _from_request(cls, *args, **kwargs):
cls.called_with = (args, kwargs)
return expect
+
result = Mock.from_request(
- "library", "facet config", "get_argument",
- "get_header", "worklist", "default entrypoint",
- extra="extra argument"
+ "library",
+ "facet config",
+ "get_argument",
+ "get_header",
+ "worklist",
+ "default entrypoint",
+ extra="extra argument",
)
# The arguments given to from_request were propagated to _from_request.
args, kwargs = Mock.called_with
- assert ("facet config", "get_argument",
- "get_header", "worklist", "default entrypoint") == args
+ assert (
+ "facet config",
+ "get_argument",
+ "get_header",
+ "worklist",
+ "default entrypoint",
+ ) == args
assert dict(extra="extra argument") == kwargs
# The return value of _from_request was propagated through
@@ -171,7 +172,11 @@ def selectable_entrypoints(cls, facet_config):
@classmethod
def load_entrypoint(cls, entrypoint_name, entrypoints, default=None):
- cls.load_entrypoint_called_with = (entrypoint_name, entrypoints, default)
+ cls.load_entrypoint_called_with = (
+ entrypoint_name,
+ entrypoints,
+ default,
+ )
return cls.expect_load_entrypoint
@classmethod
@@ -202,8 +207,12 @@ def get_header(name):
def m():
return MockFacetsWithEntryPoint._from_request(
- config, get_argument, get_header, mock_worklist,
- default_entrypoint=default_entrypoint, extra="extra kwarg"
+ config,
+ get_argument,
+ get_header,
+ mock_worklist,
+ default_entrypoint=default_entrypoint,
+ extra="extra kwarg",
)
# First, test failure. If load_entrypoint() returns a
@@ -216,7 +225,10 @@ def m():
# returns a ProblemDetail.
expect_entrypoint = object()
expect_is_default = object()
- MockFacetsWithEntryPoint.expect_load_entrypoint = (expect_entrypoint, expect_is_default)
+ MockFacetsWithEntryPoint.expect_load_entrypoint = (
+ expect_entrypoint,
+ expect_is_default,
+ )
MockFacetsWithEntryPoint.expect_max_cache_age = INVALID_INPUT
assert INVALID_INPUT == m()
@@ -233,12 +245,17 @@ def m():
assert expect_entrypoint == facets.entrypoint
assert expect_is_default == facets.entrypoint_is_default
assert (
- ("entrypoint name from request", ["Selectable entrypoints"], default_entrypoint) ==
- MockFacetsWithEntryPoint.load_entrypoint_called_with)
+ "entrypoint name from request",
+ ["Selectable entrypoints"],
+ default_entrypoint,
+ ) == MockFacetsWithEntryPoint.load_entrypoint_called_with
assert 345 == facets.max_cache_age
assert dict(extra="extra kwarg") == facets.constructor_kwargs
assert MockFacetsWithEntryPoint.selectable_entrypoints_called_with == config
- assert MockFacetsWithEntryPoint.load_max_cache_age_called_with == "max cache age from request"
+ assert (
+ MockFacetsWithEntryPoint.load_max_cache_age_called_with
+ == "max cache age from request"
+ )
def test_load_entrypoint(self):
audio = AudiobooksEntryPoint
@@ -302,6 +319,7 @@ def test_selectable_entrypoints(self):
"""The default implementation of selectable_entrypoints just returns
the worklist's entrypoints.
"""
+
class MockWorkList(object):
def __init__(self, entrypoints):
self.entrypoints = entrypoints
@@ -330,7 +348,6 @@ def test_modify_search_filter(self):
class TestFacets(DatabaseTest):
-
def _configure_facets(self, library, enabled, default):
"""Set facet configuration for the given Library."""
for key, values in list(enabled.items()):
@@ -344,7 +361,9 @@ def test_max_cache_age(self):
# should be.
facets = Facets(
self._default_library,
- Facets.COLLECTION_FULL, Facets.AVAILABLE_ALL, Facets.ORDER_TITLE
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_ALL,
+ Facets.ORDER_TITLE,
)
assert None == facets.max_cache_age
@@ -352,7 +371,9 @@ def test_facet_groups(self):
facets = Facets(
self._default_library,
- Facets.COLLECTION_FULL, Facets.AVAILABLE_ALL, Facets.ORDER_TITLE
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_ALL,
+ Facets.ORDER_TITLE,
)
all_groups = list(facets.facet_groups)
@@ -363,36 +384,33 @@ def test_facet_groups(self):
# available=all, collection=full, and order=title are the selected
# facets.
selected = sorted([x[:2] for x in all_groups if x[-1] == True])
- assert (
- [('available', 'all'), ('collection', 'full'), ('order', 'title')] ==
- selected)
+ assert [
+ ("available", "all"),
+ ("collection", "full"),
+ ("order", "title"),
+ ] == selected
test_enabled_facets = {
- Facets.ORDER_FACET_GROUP_NAME : [
- Facets.ORDER_WORK_ID, Facets.ORDER_TITLE
- ],
- Facets.COLLECTION_FACET_GROUP_NAME : [Facets.COLLECTION_FEATURED],
- Facets.AVAILABILITY_FACET_GROUP_NAME : [Facets.AVAILABLE_ALL],
+ Facets.ORDER_FACET_GROUP_NAME: [Facets.ORDER_WORK_ID, Facets.ORDER_TITLE],
+ Facets.COLLECTION_FACET_GROUP_NAME: [Facets.COLLECTION_FEATURED],
+ Facets.AVAILABILITY_FACET_GROUP_NAME: [Facets.AVAILABLE_ALL],
}
test_default_facets = {
- Facets.ORDER_FACET_GROUP_NAME : Facets.ORDER_TITLE,
- Facets.COLLECTION_FACET_GROUP_NAME : Facets.COLLECTION_FEATURED,
- Facets.AVAILABILITY_FACET_GROUP_NAME : Facets.AVAILABLE_ALL,
+ Facets.ORDER_FACET_GROUP_NAME: Facets.ORDER_TITLE,
+ Facets.COLLECTION_FACET_GROUP_NAME: Facets.COLLECTION_FEATURED,
+ Facets.AVAILABILITY_FACET_GROUP_NAME: Facets.AVAILABLE_ALL,
}
library = self._default_library
- self._configure_facets(
- library, test_enabled_facets, test_default_facets
- )
+ self._configure_facets(library, test_enabled_facets, test_default_facets)
- facets = Facets(self._default_library,
- None, None, Facets.ORDER_TITLE)
+ facets = Facets(self._default_library, None, None, Facets.ORDER_TITLE)
all_groups = list(facets.facet_groups)
# We have disabled almost all the facets, so the list of
# facet transitions includes only two items.
#
# 'Sort by title' was selected, and it shows up as the selected
# item in this facet group.
- expect = [['order', 'title', True], ['order', 'work_id', False]]
+ expect = [["order", "title", True], ["order", "work_id", False]]
assert expect == sorted([list(x[:2]) + [x[-1]] for x in all_groups])
def test_default(self):
@@ -402,11 +420,13 @@ class Mock(Facets):
def __init__(self, library, **kwargs):
self.library = library
self.kwargs = kwargs
+
facets = Mock.default(self._default_library)
assert self._default_library == facets.library
- assert (dict(collection=None, availability=None, order=None,
- entrypoint=None) ==
- facets.kwargs)
+ assert (
+ dict(collection=None, availability=None, order=None, entrypoint=None)
+ == facets.kwargs
+ )
def test_default_facet_is_always_available(self):
# By definition, the default facet must be enabled. So if the
@@ -445,6 +465,7 @@ class MockFacets(Facets):
def default_facet(cls, config, facet_group_name):
cls.called_with = (config, facet_group_name)
return "facet2"
+
available = MockFacets.available_facets(config, "some facet group")
assert ["facet1", "facet2"] == available
@@ -453,19 +474,20 @@ def test_default_availability(self):
# Normally, the availability will be the library's default availability
# facet.
test_enabled_facets = {
- Facets.ORDER_FACET_GROUP_NAME : [Facets.ORDER_WORK_ID],
- Facets.COLLECTION_FACET_GROUP_NAME : [Facets.COLLECTION_FULL],
- Facets.AVAILABILITY_FACET_GROUP_NAME : [Facets.AVAILABLE_ALL, Facets.AVAILABLE_NOW],
+ Facets.ORDER_FACET_GROUP_NAME: [Facets.ORDER_WORK_ID],
+ Facets.COLLECTION_FACET_GROUP_NAME: [Facets.COLLECTION_FULL],
+ Facets.AVAILABILITY_FACET_GROUP_NAME: [
+ Facets.AVAILABLE_ALL,
+ Facets.AVAILABLE_NOW,
+ ],
}
test_default_facets = {
- Facets.ORDER_FACET_GROUP_NAME : Facets.ORDER_TITLE,
- Facets.COLLECTION_FACET_GROUP_NAME : Facets.COLLECTION_FULL,
- Facets.AVAILABILITY_FACET_GROUP_NAME : Facets.AVAILABLE_ALL,
+ Facets.ORDER_FACET_GROUP_NAME: Facets.ORDER_TITLE,
+ Facets.COLLECTION_FACET_GROUP_NAME: Facets.COLLECTION_FULL,
+ Facets.AVAILABILITY_FACET_GROUP_NAME: Facets.AVAILABLE_ALL,
}
library = self._default_library
- self._configure_facets(
- library, test_enabled_facets, test_default_facets
- )
+ self._configure_facets(library, test_enabled_facets, test_default_facets)
facets = Facets(library, None, None, None)
assert Facets.AVAILABLE_ALL == facets.availability
@@ -477,20 +499,21 @@ def test_default_availability(self):
# Unless 'now' is not one of the enabled facets - then we keep
# using the library's default.
- test_enabled_facets[Facets.AVAILABILITY_FACET_GROUP_NAME] = [Facets.AVAILABLE_ALL]
- self._configure_facets(
- library, test_enabled_facets, test_default_facets
- )
+ test_enabled_facets[Facets.AVAILABILITY_FACET_GROUP_NAME] = [
+ Facets.AVAILABLE_ALL
+ ]
+ self._configure_facets(library, test_enabled_facets, test_default_facets)
facets = Facets(library, None, None, None)
assert Facets.AVAILABLE_ALL == facets.availability
def test_facets_can_be_enabled_at_initialization(self):
enabled_facets = {
- Facets.ORDER_FACET_GROUP_NAME : [
- Facets.ORDER_TITLE, Facets.ORDER_AUTHOR,
+ Facets.ORDER_FACET_GROUP_NAME: [
+ Facets.ORDER_TITLE,
+ Facets.ORDER_AUTHOR,
],
- Facets.COLLECTION_FACET_GROUP_NAME : [Facets.COLLECTION_FULL],
- Facets.AVAILABILITY_FACET_GROUP_NAME : [Facets.AVAILABLE_OPEN_ACCESS]
+ Facets.COLLECTION_FACET_GROUP_NAME: [Facets.COLLECTION_FULL],
+ Facets.AVAILABILITY_FACET_GROUP_NAME: [Facets.AVAILABLE_OPEN_ACCESS],
}
library = self._default_library
self._configure_facets(library, enabled_facets, {})
@@ -499,29 +522,34 @@ def test_facets_can_be_enabled_at_initialization(self):
# no matter the Configuration.
facets = Facets(
self._default_library,
- Facets.COLLECTION_FULL, Facets.AVAILABLE_OPEN_ACCESS,
- Facets.ORDER_TITLE, enabled_facets=enabled_facets
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_OPEN_ACCESS,
+ Facets.ORDER_TITLE,
+ enabled_facets=enabled_facets,
)
all_groups = list(facets.facet_groups)
- expect = [['order', 'author', False], ['order', 'title', True]]
+ expect = [["order", "author", False], ["order", "title", True]]
assert expect == sorted([list(x[:2]) + [x[-1]] for x in all_groups])
def test_facets_dont_need_a_library(self):
enabled_facets = {
- Facets.ORDER_FACET_GROUP_NAME : [
- Facets.ORDER_TITLE, Facets.ORDER_AUTHOR,
+ Facets.ORDER_FACET_GROUP_NAME: [
+ Facets.ORDER_TITLE,
+ Facets.ORDER_AUTHOR,
],
- Facets.COLLECTION_FACET_GROUP_NAME : [Facets.COLLECTION_FULL],
- Facets.AVAILABILITY_FACET_GROUP_NAME : [Facets.AVAILABLE_OPEN_ACCESS]
+ Facets.COLLECTION_FACET_GROUP_NAME: [Facets.COLLECTION_FULL],
+ Facets.AVAILABILITY_FACET_GROUP_NAME: [Facets.AVAILABLE_OPEN_ACCESS],
}
facets = Facets(
None,
- Facets.COLLECTION_FULL, Facets.AVAILABLE_OPEN_ACCESS,
- Facets.ORDER_TITLE, enabled_facets=enabled_facets
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_OPEN_ACCESS,
+ Facets.ORDER_TITLE,
+ enabled_facets=enabled_facets,
)
all_groups = list(facets.facet_groups)
- expect = [['order', 'author', False], ['order', 'title', True]]
+ expect = [["order", "author", False], ["order", "title", True]]
assert expect == sorted([list(x[:2]) + [x[-1]] for x in all_groups])
def test_items(self):
@@ -530,15 +558,17 @@ def test_items(self):
"""
facets = Facets(
self._default_library,
- Facets.COLLECTION_FULL, Facets.AVAILABLE_ALL, Facets.ORDER_TITLE,
- entrypoint=AudiobooksEntryPoint
+ Facets.COLLECTION_FULL,
+ Facets.AVAILABLE_ALL,
+ Facets.ORDER_TITLE,
+ entrypoint=AudiobooksEntryPoint,
)
- assert ([
- ('available', Facets.AVAILABLE_ALL),
- ('collection', Facets.COLLECTION_FULL),
- ('entrypoint', AudiobooksEntryPoint.INTERNAL_NAME),
- ('order', Facets.ORDER_TITLE)] ==
- sorted(facets.items()))
+ assert [
+ ("available", Facets.AVAILABLE_ALL),
+ ("collection", Facets.COLLECTION_FULL),
+ ("entrypoint", AudiobooksEntryPoint.INTERNAL_NAME),
+ ("order", Facets.ORDER_TITLE),
+ ] == sorted(facets.items())
def test_default_order_ascending(self):
@@ -548,20 +578,21 @@ def test_default_order_ascending(self):
self._default_library,
collection=Facets.COLLECTION_FULL,
availability=Facets.AVAILABLE_ALL,
- order=order
+ order=order,
)
assert True == f.order_ascending
# But the time-based facets are ordered descending by default
# (newest->oldest)
- assert (set([Facets.ORDER_ADDED_TO_COLLECTION, Facets.ORDER_LAST_UPDATE]) ==
- set(Facets.ORDER_DESCENDING_BY_DEFAULT))
+ assert set([Facets.ORDER_ADDED_TO_COLLECTION, Facets.ORDER_LAST_UPDATE]) == set(
+ Facets.ORDER_DESCENDING_BY_DEFAULT
+ )
for order in Facets.ORDER_DESCENDING_BY_DEFAULT:
f = Facets(
self._default_library,
collection=Facets.COLLECTION_FULL,
availability=Facets.AVAILABLE_ALL,
- order=order
+ order=order,
)
assert False == f.order_ascending
@@ -572,8 +603,13 @@ def test_navigate(self):
F = Facets
ebooks = EbooksEntryPoint
- f = Facets(self._default_library, F.COLLECTION_FULL, F.AVAILABLE_ALL,
- F.ORDER_TITLE, entrypoint=ebooks)
+ f = Facets(
+ self._default_library,
+ F.COLLECTION_FULL,
+ F.AVAILABLE_ALL,
+ F.ORDER_TITLE,
+ entrypoint=ebooks,
+ )
different_collection = f.navigate(collection=F.COLLECTION_FEATURED)
assert F.COLLECTION_FEATURED == different_collection.collection
@@ -615,9 +651,7 @@ def test_from_request(self):
# Valid object using the default settings.
default_order = config.default_facet(Facets.ORDER_FACET_GROUP_NAME)
- default_collection = config.default_facet(
- Facets.COLLECTION_FACET_GROUP_NAME
- )
+ default_collection = config.default_facet(Facets.COLLECTION_FACET_GROUP_NAME)
default_availability = config.default_facet(
Facets.AVAILABILITY_FACET_GROUP_NAME
)
@@ -651,22 +685,28 @@ def test_from_request(self):
args = dict(order="no such order")
invalid_order = m(library, library, args.get, headers.get, None)
assert INVALID_INPUT.uri == invalid_order.uri
- assert ("I don't know how to order a feed by 'no such order'" ==
- invalid_order.detail)
+ assert (
+ "I don't know how to order a feed by 'no such order'"
+ == invalid_order.detail
+ )
# Invalid availability
args = dict(available="no such availability")
invalid_availability = m(library, library, args.get, headers.get, None)
assert INVALID_INPUT.uri == invalid_availability.uri
- assert ("I don't understand the availability term 'no such availability'" ==
- invalid_availability.detail)
+ assert (
+ "I don't understand the availability term 'no such availability'"
+ == invalid_availability.detail
+ )
# Invalid collection
args = dict(collection="no such collection")
invalid_collection = m(library, library, args.get, headers.get, None)
assert INVALID_INPUT.uri == invalid_collection.uri
- assert ("I don't understand what 'no such collection' refers to." ==
- invalid_collection.detail)
+ assert (
+ "I don't understand what 'no such collection' refers to."
+ == invalid_collection.detail
+ )
def test_from_request_gets_available_facets_through_hook_methods(self):
# Available and default facets are determined by calling the
@@ -678,9 +718,11 @@ class Mock(Facets):
# For whatever reason, this faceting object allows only a
# single setting for each facet group.
- mock_enabled = dict(order=[Facets.ORDER_TITLE],
- available=[Facets.AVAILABLE_OPEN_ACCESS],
- collection=[Facets.COLLECTION_FULL])
+ mock_enabled = dict(
+ order=[Facets.ORDER_TITLE],
+ available=[Facets.AVAILABLE_OPEN_ACCESS],
+ collection=[Facets.COLLECTION_FULL],
+ )
@classmethod
def available_facets(cls, config, facet_group_name):
@@ -724,8 +766,7 @@ def test_modify_search_filter(self):
# Test superclass behavior -- filter is modified by entrypoint.
facets = Facets(
- self._default_library, None, None, None,
- entrypoint=AudiobooksEntryPoint
+ self._default_library, None, None, None, entrypoint=AudiobooksEntryPoint
)
filter = Filter()
facets.modify_search_filter(filter)
@@ -733,14 +774,19 @@ def test_modify_search_filter(self):
# Now test the subclass behavior.
facets = Facets(
- self._default_library, "some collection", "some availability",
- order=Facets.ORDER_ADDED_TO_COLLECTION, order_ascending="yep"
+ self._default_library,
+ "some collection",
+ "some availability",
+ order=Facets.ORDER_ADDED_TO_COLLECTION,
+ order_ascending="yep",
)
facets.modify_search_filter(filter)
# The library's minimum featured quality is passed in.
- assert (self._default_library.minimum_featured_quality ==
- filter.minimum_featured_quality)
+ assert (
+ self._default_library.minimum_featured_quality
+ == filter.minimum_featured_quality
+ )
# Availability and collection are propagated with no
# validation.
@@ -765,15 +811,13 @@ def test_modify_search_filter(self):
def test_modify_database_query(self):
# Make sure that modify_database_query handles the various
# reasons why a book might or might not be 'available'.
- open_access = self._work(with_open_access_download=True,
- title="open access")
+ open_access = self._work(with_open_access_download=True, title="open access")
open_access.quality = 1
self_hosted = self._work(
with_license_pool=True, self_hosted=True, title="self hosted"
)
unlimited_access = self._work(
- with_license_pool=True, unlimited_access=True,
- title="unlimited access"
+ with_license_pool=True, unlimited_access=True, title="unlimited access"
)
available = self._work(with_license_pool=True, title="available")
@@ -781,9 +825,7 @@ def test_modify_database_query(self):
pool.licenses_owned = 1
pool.licenses_available = 1
- not_available = self._work(
- with_license_pool=True, title="not available"
- )
+ not_available = self._work(with_license_pool=True, title="not available")
[pool] = not_available.license_pools
pool.licenses_owned = 1
pool.licenses_available = 0
@@ -792,40 +834,51 @@ def test_modify_database_query(self):
[pool] = not_licensed.license_pools
pool.licenses_owned = 0
pool.licenses_available = 0
- qu = self._db.query(Work).join(Work.license_pools).join(
- LicensePool.presentation_edition
+ qu = (
+ self._db.query(Work)
+ .join(Work.license_pools)
+ .join(LicensePool.presentation_edition)
)
for availability, expect in [
- (Facets.AVAILABLE_NOW,
- [open_access, available, self_hosted, unlimited_access]),
- (Facets.AVAILABLE_ALL,
- [open_access, available, not_available, self_hosted, unlimited_access]),
+ (
+ Facets.AVAILABLE_NOW,
+ [open_access, available, self_hosted, unlimited_access],
+ ),
+ (
+ Facets.AVAILABLE_ALL,
+ [open_access, available, not_available, self_hosted, unlimited_access],
+ ),
(Facets.AVAILABLE_NOT_NOW, [not_available]),
]:
facets = Facets(self._default_library, None, availability, None)
modified = facets.modify_database_query(self._db, qu)
- assert ((availability, sorted([x.title for x in modified])) ==
- (availability, sorted([x.title for x in expect])))
+ assert (availability, sorted([x.title for x in modified])) == (
+ availability,
+ sorted([x.title for x in expect]),
+ )
# Setting the 'featured' collection includes only known
# high-quality works.
for collection, expect in [
- (Facets.COLLECTION_FULL,
- [open_access, available, self_hosted, unlimited_access]),
- (Facets.COLLECTION_FEATURED,
- [open_access]),
+ (
+ Facets.COLLECTION_FULL,
+ [open_access, available, self_hosted, unlimited_access],
+ ),
+ (Facets.COLLECTION_FEATURED, [open_access]),
]:
- facets = Facets(self._default_library, collection,
- Facets.AVAILABLE_NOW, None)
+ facets = Facets(
+ self._default_library, collection, Facets.AVAILABLE_NOW, None
+ )
modified = facets.modify_database_query(self._db, qu)
- assert ((collection, sorted([x.title for x in modified])) ==
- (collection, sorted([x.title for x in expect])))
+ assert (collection, sorted([x.title for x in modified])) == (
+ collection,
+ sorted([x.title for x in expect]),
+ )
class TestDefaultSortOrderFacets(DatabaseTest):
-
def setup_method(self):
super(TestDefaultSortOrderFacets, self).setup_method()
self.config = self._default_library
@@ -833,12 +886,16 @@ def setup_method(self):
def _check_other_groups_not_changed(self, cls):
# Verify that nothing has changed for the collection or
# availability facet groups.
- for group_name in (Facets.COLLECTION_FACET_GROUP_NAME,
- Facets.AVAILABILITY_FACET_GROUP_NAME):
- assert (Facets.available_facets(self.config, group_name) ==
- cls.available_facets(self.config, group_name))
- assert (Facets.default_facet(self.config, group_name) ==
- cls.default_facet(self.config, group_name))
+ for group_name in (
+ Facets.COLLECTION_FACET_GROUP_NAME,
+ Facets.AVAILABILITY_FACET_GROUP_NAME,
+ ):
+ assert Facets.available_facets(
+ self.config, group_name
+ ) == cls.available_facets(self.config, group_name)
+ assert Facets.default_facet(self.config, group_name) == cls.default_facet(
+ self.config, group_name
+ )
def test_sort_order_rearrangement(self):
# Test the case where a DefaultSortOrderFacets does nothing but
@@ -853,11 +910,10 @@ class TitleFirst(DefaultSortOrderFacets):
# But the default sort order for TitleFirst is ORDER_TITLE.
order = Facets.ORDER_FACET_GROUP_NAME
- assert (TitleFirst.DEFAULT_SORT_ORDER ==
- TitleFirst.default_facet(self.config, order))
- assert Facets.default_facet(
+ assert TitleFirst.DEFAULT_SORT_ORDER == TitleFirst.default_facet(
self.config, order
- ) != TitleFirst.DEFAULT_SORT_ORDER
+ )
+ assert Facets.default_facet(self.config, order) != TitleFirst.DEFAULT_SORT_ORDER
# TitleFirst has the same sort orders as Facets, but ORDER_TITLE
# comes first in the list.
@@ -879,11 +935,12 @@ class SeriesFirst(DefaultSortOrderFacets):
# But its default sort order is ORDER_SERIES.
order = Facets.ORDER_FACET_GROUP_NAME
- assert (SeriesFirst.DEFAULT_SORT_ORDER ==
- SeriesFirst.default_facet(self.config, order))
- assert Facets.default_facet(
+ assert SeriesFirst.DEFAULT_SORT_ORDER == SeriesFirst.default_facet(
self.config, order
- ) != SeriesFirst.DEFAULT_SORT_ORDER
+ )
+ assert (
+ Facets.default_facet(self.config, order) != SeriesFirst.DEFAULT_SORT_ORDER
+ )
# Its list of sort orders is the same as Facets, except Series
# has been added to the front of the list.
@@ -893,7 +950,6 @@ class SeriesFirst(DefaultSortOrderFacets):
class TestDatabaseBackedFacets(DatabaseTest):
-
def test_available_facets(self):
# The only available sort orders are the ones that map
# directly onto a database field.
@@ -913,17 +969,16 @@ def test_available_facets(self):
)
assert len(f2_orders) < len(f1_orders)
for order in f2_orders:
- assert (
- order in f1_orders and order in f2.ORDER_FACET_TO_DATABASE_FIELD
- )
+ assert order in f1_orders and order in f2.ORDER_FACET_TO_DATABASE_FIELD
# The rules for collection and availability are the same.
for group in (
FacetConstants.COLLECTION_FACET_GROUP_NAME,
FacetConstants.AVAILABILITY_FACET_GROUP_NAME,
):
- assert (f1.available_facets(self._default_library, group) ==
- f2.available_facets(self._default_library, group))
+ assert f1.available_facets(
+ self._default_library, group
+ ) == f2.available_facets(self._default_library, group)
def test_default_facets(self):
# If the configured default sort order is not available,
@@ -936,8 +991,9 @@ def test_default_facets(self):
FacetConstants.COLLECTION_FACET_GROUP_NAME,
FacetConstants.AVAILABILITY_FACET_GROUP_NAME,
):
- assert (f1.default_facet(self._default_library, group) ==
- f2.default_facet(self._default_library, group))
+ assert f1.default_facet(self._default_library, group) == f2.default_facet(
+ self._default_library, group
+ )
# In this bizarre library, the default sort order is 'time
# added to collection' -- an order not supported by
@@ -945,8 +1001,10 @@ def test_default_facets(self):
class Mock(object):
enabled = [
FacetConstants.ORDER_ADDED_TO_COLLECTION,
- FacetConstants.ORDER_TITLE, FacetConstants.ORDER_AUTHOR
+ FacetConstants.ORDER_TITLE,
+ FacetConstants.ORDER_AUTHOR,
]
+
def enabled_facets(self, group_name):
return self.enabled
@@ -956,23 +1014,23 @@ def default_facet(self, group_name):
# A Facets object uses the 'time added to collection' order by
# default.
config = Mock()
- assert (f1.ORDER_ADDED_TO_COLLECTION ==
- f1.default_facet(config, f1.ORDER_FACET_GROUP_NAME))
+ assert f1.ORDER_ADDED_TO_COLLECTION == f1.default_facet(
+ config, f1.ORDER_FACET_GROUP_NAME
+ )
# A DatabaseBacked Facets can't do that. It finds the first
# enabled sort order that it can support, and uses it instead.
- assert (f2.ORDER_TITLE ==
- f2.default_facet(config, f2.ORDER_FACET_GROUP_NAME))
+ assert f2.ORDER_TITLE == f2.default_facet(config, f2.ORDER_FACET_GROUP_NAME)
# If no enabled sort orders are supported, it just sorts
# by Work ID, so that there is always _some_ sort order.
config.enabled = [FacetConstants.ORDER_ADDED_TO_COLLECTION]
- assert (f2.ORDER_WORK_ID ==
- f2.default_facet(config, f2.ORDER_FACET_GROUP_NAME))
+ assert f2.ORDER_WORK_ID == f2.default_facet(config, f2.ORDER_FACET_GROUP_NAME)
def test_order_by(self):
E = Edition
W = Work
+
def order(facet, ascending=None):
f = DatabaseBackedFacets(
self._default_library,
@@ -984,9 +1042,9 @@ def order(facet, ascending=None):
return f.order_by()[0]
def compare(a, b):
- assert(len(a) == len(b))
+ assert len(a) == len(b)
for i in range(0, len(a)):
- assert(a[i].compare(b[i]))
+ assert a[i].compare(b[i])
expect = [E.sort_author.asc(), E.sort_title.asc(), W.id.asc()]
actual = order(Facets.ORDER_AUTHOR, True)
@@ -1000,8 +1058,12 @@ def compare(a, b):
actual = order(Facets.ORDER_TITLE, True)
compare(expect, actual)
- expect = [W.last_update_time.asc(), E.sort_author.asc(),
- E.sort_title.asc(), W.id.asc()]
+ expect = [
+ W.last_update_time.asc(),
+ E.sort_author.asc(),
+ E.sort_title.asc(),
+ W.id.asc(),
+ ]
actual = order(Facets.ORDER_LAST_UPDATE, True)
compare(expect, actual)
@@ -1010,7 +1072,6 @@ def compare(a, b):
actual = order(Facets.ORDER_ADDED_TO_COLLECTION, True)
compare(expect, actual)
-
def test_modify_database_query(self):
# Set up works that are matched by different types of collections.
@@ -1024,8 +1085,8 @@ def test_modify_database_query(self):
# A high-quality licensed work which is not currently available.
(licensed_e1, licensed_p1) = self._edition(
- data_source_name=DataSource.OVERDRIVE,
- with_license_pool=True)
+ data_source_name=DataSource.OVERDRIVE, with_license_pool=True
+ )
licensed_high = self._work(presentation_edition=licensed_e1)
licensed_high.license_pools.append(licensed_p1)
licensed_high.quality = 0.8
@@ -1035,8 +1096,8 @@ def test_modify_database_query(self):
# A low-quality licensed work which is currently available.
(licensed_e2, licensed_p2) = self._edition(
- data_source_name=DataSource.OVERDRIVE,
- with_license_pool=True)
+ data_source_name=DataSource.OVERDRIVE, with_license_pool=True
+ )
licensed_p2.open_access = False
licensed_low = self._work(presentation_edition=licensed_e2)
licensed_low.license_pools.append(licensed_p2)
@@ -1045,13 +1106,17 @@ def test_modify_database_query(self):
licensed_p2.licenses_available = 1
# A high-quality work with unlimited access.
- unlimited_access_high = self._work(with_license_pool=True, unlimited_access=True)
+ unlimited_access_high = self._work(
+ with_license_pool=True, unlimited_access=True
+ )
unlimited_access_high.quality = 0.8
qu = DatabaseBackedWorkList.base_query(self._db)
- def facetify(collection=Facets.COLLECTION_FULL,
- available=Facets.AVAILABLE_ALL,
- order=Facets.ORDER_TITLE
+
+ def facetify(
+ collection=Facets.COLLECTION_FULL,
+ available=Facets.AVAILABLE_ALL,
+ order=Facets.ORDER_TITLE,
):
f = DatabaseBackedFacets(
self._default_library, collection, available, order
@@ -1097,25 +1162,32 @@ def facetify(collection=Facets.COLLECTION_FULL,
# Try some different orderings to verify that order_by()
# is called and used properly.
title_order = facetify(order=Facets.ORDER_TITLE)
- assert ([open_access_high.id, open_access_low.id, licensed_high.id,
- licensed_low.id, unlimited_access_high.id] ==
- [x.id for x in title_order])
- assert (
- ['sort_title', 'sort_author', 'id'] ==
- [x.name for x in title_order._distinct])
+ assert [
+ open_access_high.id,
+ open_access_low.id,
+ licensed_high.id,
+ licensed_low.id,
+ unlimited_access_high.id,
+ ] == [x.id for x in title_order]
+ assert ["sort_title", "sort_author", "id"] == [
+ x.name for x in title_order._distinct
+ ]
# This sort order is not supported, so the default is used.
unsupported_order = facetify(order=Facets.ORDER_ADDED_TO_COLLECTION)
- assert ([unlimited_access_high.id, licensed_low.id, licensed_high.id, open_access_low.id,
- open_access_high.id] ==
- [x.id for x in unsupported_order])
- assert (
- ['sort_author', 'sort_title', 'id'] ==
- [x.name for x in unsupported_order._distinct])
+ assert [
+ unlimited_access_high.id,
+ licensed_low.id,
+ licensed_high.id,
+ open_access_low.id,
+ open_access_high.id,
+ ] == [x.id for x in unsupported_order]
+ assert ["sort_author", "sort_title", "id"] == [
+ x.name for x in unsupported_order._distinct
+ ]
class TestFeaturedFacets(DatabaseTest):
-
def test_constructor(self):
# Verify that constructor arguments are stored.
entrypoint = object()
@@ -1149,8 +1221,10 @@ def test_default(self):
# Or with nothing -- in which case the default value is used.
facets = FeaturedFacets.default(None)
- assert (Configuration.DEFAULT_MINIMUM_FEATURED_QUALITY ==
- facets.minimum_featured_quality)
+ assert (
+ Configuration.DEFAULT_MINIMUM_FEATURED_QUALITY
+ == facets.minimum_featured_quality
+ )
def test_navigate(self):
# Test the ability of navigate() to move between slight
@@ -1168,7 +1242,6 @@ def test_navigate(self):
class TestSearchFacets(DatabaseTest):
-
def test_constructor(self):
# The SearchFacets constructor allows you to specify
# a medium and language (or a list of them) as well
@@ -1190,8 +1263,7 @@ def test_constructor(self):
# If you pass in a single value for medium or language
# they are turned into a list.
with_single_value = m(
- entrypoint=mock_entrypoint, media=Edition.BOOK_MEDIUM,
- languages="eng"
+ entrypoint=mock_entrypoint, media=Edition.BOOK_MEDIUM, languages="eng"
)
assert mock_entrypoint == with_single_value.entrypoint
assert [Edition.BOOK_MEDIUM] == with_single_value.media
@@ -1200,9 +1272,7 @@ def test_constructor(self):
# If you pass in a list of values, it's left alone.
media = [Edition.BOOK_MEDIUM, Edition.AUDIO_MEDIUM]
languages = ["eng", "spa"]
- with_multiple_values = m(
- media=media, languages=languages
- )
+ with_multiple_values = m(media=media, languages=languages)
assert media == with_multiple_values.media
assert languages == with_multiple_values.languages
@@ -1229,9 +1299,12 @@ def test_from_request(self):
# These variables mock the query string arguments and
# HTTP headers of an HTTP request.
- arguments = dict(entrypoint=EbooksEntryPoint.INTERNAL_NAME,
- media=Edition.AUDIO_MEDIUM, min_score="123")
- headers = {"Accept-Language" : "da, en-gb;q=0.8"}
+ arguments = dict(
+ entrypoint=EbooksEntryPoint.INTERNAL_NAME,
+ media=Edition.AUDIO_MEDIUM,
+ min_score="123",
+ )
+ headers = {"Accept-Language": "da, en-gb;q=0.8"}
get_argument = arguments.get
get_header = headers.get
@@ -1244,8 +1317,12 @@ def test_from_request(self):
def from_request(**extra):
return SearchFacets.from_request(
- self._default_library, self._default_library, get_argument,
- get_header, unused, **extra
+ self._default_library,
+ self._default_library,
+ get_argument,
+ get_header,
+ unused,
+ **extra
)
facets = from_request(extra="value")
@@ -1268,12 +1345,12 @@ def from_request(**extra):
# The SearchFacets implementation turned the 'Accept-Language'
# header into a set of language codes.
- assert ['dan', 'eng'] == facets.languages
+ assert ["dan", "eng"] == facets.languages
# Try again with bogus media, languages, and minimum score.
- arguments['media'] = 'Unknown Media'
- arguments['min_score'] = 'not a number'
- headers['Accept-Language'] = "xx, ql"
+ arguments["media"] = "Unknown Media"
+ arguments["min_score"] = "not a number"
+ headers["Accept-Language"] = "xx, ql"
# None of the bogus information was used.
facets = from_request()
@@ -1283,15 +1360,15 @@ def from_request(**extra):
# Reading the language query with acceptable Accept-Language header
# but not passing that value through.
- arguments['language'] = 'all'
- headers['Accept-Language'] = "da, en-gb;q=0.8"
+ arguments["language"] = "all"
+ headers["Accept-Language"] = "da, en-gb;q=0.8"
facets = from_request()
assert None == facets.languages
# Try again with no information.
- del arguments['media']
- del headers['Accept-Language']
+ del arguments["media"]
+ del headers["Accept-Language"]
facets = from_request()
assert None == facets.media
@@ -1301,9 +1378,14 @@ def test_from_request_from_admin_search(self):
# If the SearchFacets object is being created by a search run from the admin interface,
# there might be order and language arguments which should be used to filter search results.
- arguments = dict(order="author", language="fre", entrypoint=EbooksEntryPoint.INTERNAL_NAME,
- media=Edition.AUDIO_MEDIUM, min_score="123")
- headers = {"Accept-Language" : "da, en-gb;q=0.8"}
+ arguments = dict(
+ order="author",
+ language="fre",
+ entrypoint=EbooksEntryPoint.INTERNAL_NAME,
+ media=Edition.AUDIO_MEDIUM,
+ min_score="123",
+ )
+ headers = {"Accept-Language": "da, en-gb;q=0.8"}
get_argument = arguments.get
get_header = headers.get
@@ -1316,20 +1398,24 @@ def test_from_request_from_admin_search(self):
def from_request(**extra):
return SearchFacets.from_request(
- self._default_library, self._default_library, get_argument,
- get_header, unused, **extra
+ self._default_library,
+ self._default_library,
+ get_argument,
+ get_header,
+ unused,
+ **extra
)
facets = from_request(extra="value")
# The SearchFacets implementation uses the order and language values submitted by the admin.
assert "author" == facets.order
- assert ['fre'] == facets.languages
-
+ assert ["fre"] == facets.languages
def test_selectable_entrypoints(self):
"""If the WorkList has more than one facet, an 'everything' facet
is added for search purposes.
"""
+
class MockWorkList(object):
def __init__(self):
self.entrypoints = None
@@ -1359,8 +1445,9 @@ def __init__(self):
def test_items(self):
facets = SearchFacets(
entrypoint=EverythingEntryPoint,
- media=Edition.BOOK_MEDIUM, languages=['eng'],
- min_score=123
+ media=Edition.BOOK_MEDIUM,
+ languages=["eng"],
+ min_score=123,
)
# When we call items(), e.g. to create a query string that
@@ -1370,28 +1457,23 @@ def test_items(self):
# language is not propagated, because it's set through
# the Accept-Language header rather than through a query
# string.
- assert (
- [('entrypoint', EverythingEntryPoint.INTERNAL_NAME),
- (Facets.ORDER_FACET_GROUP_NAME, SearchFacets.ORDER_BY_RELEVANCE),
- (Facets.AVAILABILITY_FACET_GROUP_NAME, Facets.AVAILABLE_ALL),
- (Facets.COLLECTION_FACET_GROUP_NAME, Facets.COLLECTION_FULL),
- ('media', Edition.BOOK_MEDIUM),
- ('min_score', '123'),
- ] ==
- list(facets.items()))
+ assert [
+ ("entrypoint", EverythingEntryPoint.INTERNAL_NAME),
+ (Facets.ORDER_FACET_GROUP_NAME, SearchFacets.ORDER_BY_RELEVANCE),
+ (Facets.AVAILABILITY_FACET_GROUP_NAME, Facets.AVAILABLE_ALL),
+ (Facets.COLLECTION_FACET_GROUP_NAME, Facets.COLLECTION_FULL),
+ ("media", Edition.BOOK_MEDIUM),
+ ("min_score", "123"),
+ ] == list(facets.items())
def test_navigation(self):
"""Navigating from one SearchFacets to another gives a new
SearchFacets object. A number of fields can be changed,
including min_score, which is SearchFacets-specific.
"""
- facets = SearchFacets(
- entrypoint=object(), order="field1", min_score=100
- )
+ facets = SearchFacets(entrypoint=object(), order="field1", min_score=100)
new_ep = object()
- new_facets = facets.navigate(
- entrypoint=new_ep, order="field2", min_score=120
- )
+ new_facets = facets.navigate(entrypoint=new_ep, order="field2", min_score=120)
assert isinstance(new_facets, SearchFacets)
assert new_ep == new_facets.entrypoint
assert "field2" == new_facets.order
@@ -1439,7 +1521,6 @@ def test_modify_search_filter(self):
facets.modify_search_filter(filter)
assert ["eng"] == filter.languages
-
# If no languages are specified in the SearchFacets, the value
# set by the filter is used by itself.
facets = SearchFacets(languages=None)
@@ -1492,7 +1573,6 @@ def test_modify_search_filter_accepts_relevance_order(self):
class TestPagination(DatabaseTest):
-
def test_from_request(self):
# No arguments -> Class defaults.
@@ -1607,7 +1687,7 @@ def test_page_loaded(self):
pagination = Pagination()
assert None == pagination.this_page_size
assert False == pagination.page_has_loaded
- pagination.page_loaded([1,2,3])
+ pagination.page_loaded([1, 2, 3])
assert 3 == pagination.this_page_size
assert True == pagination.page_has_loaded
@@ -1615,17 +1695,19 @@ def test_modify_search_query(self):
# The default implementation of modify_search_query is to slice
# a set of search results like a list.
pagination = Pagination(offset=2, size=3)
- o = [1,2,3,4,5,6]
- assert o[2:2+3] == pagination.modify_search_query(o)
+ o = [1, 2, 3, 4, 5, 6]
+ assert o[2 : 2 + 3] == pagination.modify_search_query(o)
class MockWork(object):
"""Acts enough like a Work to trick code that doesn't need to make
database requests.
"""
+
def __init__(self, id):
self.id = id
+
class MockWorks(WorkList):
"""A WorkList that mocks works_from_database()."""
@@ -1657,7 +1739,6 @@ def random_sample(self, query, target_size):
class TestWorkList(DatabaseTest):
-
def test_initialize(self):
wl = WorkList()
child = WorkList()
@@ -1667,28 +1748,32 @@ def test_initialize(self):
# Create a WorkList that's associated with a Library, two genres,
# and a child WorkList.
- wl.initialize(self._default_library, children=[child],
- genres=[sf, romance], entrypoints=[1,2,3])
+ wl.initialize(
+ self._default_library,
+ children=[child],
+ genres=[sf, romance],
+ entrypoints=[1, 2, 3],
+ )
# Access the Library.
assert self._default_library == wl.get_library(self._db)
# The Collections associated with the WorkList are those associated
# with the Library.
- assert (set(wl.collection_ids) ==
- set([x.id for x in self._default_library.collections]))
+ assert set(wl.collection_ids) == set(
+ [x.id for x in self._default_library.collections]
+ )
# The Genres associated with the WorkList are the ones passed
# in on the constructor.
- assert (set(wl.genre_ids) ==
- set([x.id for x in [sf, romance]]))
+ assert set(wl.genre_ids) == set([x.id for x in [sf, romance]])
# The WorkList's child is the WorkList passed in to the constructor.
assert [child] == wl.visible_children
# The Worklist's .entrypoints is whatever was passed in
# to the constructor.
- assert [1,2,3] == wl.entrypoints
+ assert [1, 2, 3] == wl.entrypoints
def test_initialize_without_library(self):
# It's possible to initialize a WorkList with no Library.
@@ -1715,16 +1800,16 @@ def test_initialize_with_customlists(self):
# Make a WorkList based on specific CustomLists.
worklist = WorkList()
- worklist.initialize(self._default_library,
- customlists=[customlist1, customlist3])
+ worklist.initialize(
+ self._default_library, customlists=[customlist1, customlist3]
+ )
assert [customlist1.id, customlist3.id] == worklist.customlist_ids
assert None == worklist.list_datasource_id
# Make a WorkList based on a DataSource, as a shorthand for
# 'all the CustomLists from that DataSource'.
worklist = WorkList()
- worklist.initialize(self._default_library,
- list_datasource=gutenberg)
+ worklist.initialize(self._default_library, list_datasource=gutenberg)
assert [customlist1.id, customlist2.id] == worklist.customlist_ids
assert gutenberg.id == worklist.list_datasource_id
@@ -1742,8 +1827,7 @@ def test_initialize_without_library(self):
# The Genres associated with the WorkList are the ones passed
# in on the constructor.
- assert (set(wl.genre_ids) ==
- set([x.id for x in [sf, romance]]))
+ assert set(wl.genre_ids) == set([x.id for x in [sf, romance]])
def test_initialize_uses_append_child_hook_method(self):
# When a WorkList is initialized with children, the children
@@ -1751,6 +1835,7 @@ def test_initialize_uses_append_child_hook_method(self):
# method, not simply set to WorkList.children.
class Mock(WorkList):
append_child_calls = []
+
def append_child(self, child):
self.append_child_calls.append(child)
return super(Mock, self).append_child(child)
@@ -1781,9 +1866,7 @@ def test_top_level_for_library(self):
lane1.sublanes.append(sublane)
# This lane belongs to a different library.
- other_library = self._library(
- name="Other Library", short_name="Other"
- )
+ other_library = self._library(name="Other Library", short_name="Other")
other_library_lane = self._lane(
display_name="Other Library Lane", library=other_library
)
@@ -1807,22 +1890,20 @@ def test_top_level_for_library(self):
assert [] == wl.children
assert Edition.FULFILLABLE_MEDIA == wl.media
-
def test_audience_key(self):
wl = WorkList()
wl.initialize(library=self._default_library)
# No audience.
- assert '' == wl.audience_key
+ assert "" == wl.audience_key
# All audiences.
wl.audiences = Classifier.AUDIENCES
- assert '' == wl.audience_key
+ assert "" == wl.audience_key
# Specific audiences.
- wl.audiences = [Classifier.AUDIENCE_CHILDREN,
- Classifier.AUDIENCE_YOUNG_ADULT]
- assert 'Children,Young+Adult' == wl.audience_key
+ wl.audiences = [Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_YOUNG_ADULT]
+ assert "Children,Young+Adult" == wl.audience_key
def test_parent(self):
# A WorkList has no parent.
@@ -1851,9 +1932,7 @@ def test_visible_children(self):
invisible.visible = False
child_wl = WorkList()
child_wl.initialize(self._default_library)
- wl.initialize(
- self._default_library, children=[visible, invisible, child_wl]
- )
+ wl.initialize(self._default_library, children=[visible, invisible, child_wl])
assert set([child_wl, visible]) == set(wl.visible_children)
def test_visible_children_sorted(self):
@@ -1861,16 +1940,14 @@ def test_visible_children_sorted(self):
wl = WorkList()
lane_child = self._lane()
- lane_child.display_name='ZZ'
+ lane_child.display_name = "ZZ"
lane_child.priority = 0
wl_child = WorkList()
wl_child.priority = 1
- wl_child.display_name='AA'
+ wl_child.display_name = "AA"
- wl.initialize(
- self._default_library, children=[lane_child, wl_child]
- )
+ wl.initialize(self._default_library, children=[lane_child, wl_child])
# lane_child has a higher priority so it shows up first even
# though its display name starts with a Z.
@@ -1964,7 +2041,7 @@ def test_accessible_to(self):
# Give it some audience restrictions.
wl.audiences = [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_CHILDREN]
- wl.target_age = tuple_to_numericrange((4,5))
+ wl.target_age = tuple_to_numericrange((4, 5))
# Now it depends on the return value of Patron.work_is_age_appropriate.
# Mock that method.
@@ -1991,10 +2068,12 @@ def test_accessible_to(self):
# restriction in our WorkList. Only if _every_ call returns
# True is the WorkList considered age-appropriate for the
# patron.
- patron.work_is_age_appropriate.assert_has_calls([
- call(wl.audiences[0], wl.target_age),
- call(wl.audiences[1], wl.target_age),
- ])
+ patron.work_is_age_appropriate.assert_has_calls(
+ [
+ call(wl.audiences[0], wl.target_age),
+ call(wl.audiences[1], wl.target_age),
+ ]
+ )
def test_uses_customlists(self):
"""A WorkList is said to use CustomLists if either ._customlist_ids
@@ -2018,7 +2097,6 @@ def test_max_cache_age(self):
wl = WorkList()
assert OPDSFeed.DEFAULT_MAX_AGE == wl.max_cache_age(object())
-
def test_filter(self):
# Verify that filter() calls modify_search_filter_hook()
# and can handle either a new Filter being returned or a Filter
@@ -2097,8 +2175,13 @@ def overview_facets(self, _db, facets):
return "A new faceting object"
def _groups_for_lanes(
- self, _db, relevant_children, relevant_lanes, pagination,
- facets, **kwargs
+ self,
+ _db,
+ relevant_children,
+ relevant_lanes,
+ pagination,
+ facets,
+ **kwargs
):
self._groups_for_lanes_called_with = (pagination, facets)
return []
@@ -2125,8 +2208,7 @@ def _groups_for_lanes(
# Now try the case where we want to use a pagination object to
# restrict the number of results per lane.
pagination = object()
- [x for x in mock.groups(self._db, pagination=pagination,
- facets=facets)]
+ [x for x in mock.groups(self._db, pagination=pagination, facets=facets)]
# The pagination object is propagated to _groups_for_lanes.
assert (pagination, facets) == mock._groups_for_lanes_called_with
mock._groups_for_lanes_called_with = None
@@ -2151,7 +2233,9 @@ def test_works(self):
class MockSearchClient(object):
"""Respond to search requests with some fake work IDs."""
+
fake_work_ids = [1, 10, 100, 1000]
+
def query_works(self, **kwargs):
self.called_with = kwargs
return self.fake_work_ids
@@ -2159,7 +2243,9 @@ def query_works(self, **kwargs):
class MockWorkList(WorkList):
"""Mock the process of turning work IDs into WorkSearchResult
objects."""
+
fake_work_list = "a list of works"
+
def works_for_hits(self, _db, work_ids, facets=None):
self.called_with = (_db, work_ids)
return self.fake_work_list
@@ -2167,18 +2253,14 @@ def works_for_hits(self, _db, work_ids, facets=None):
# Here's a WorkList.
wl = MockWorkList()
wl.initialize(self._default_library, languages=["eng"])
- facets = Facets(
- self._default_library, None, None, order=Facets.ORDER_TITLE
- )
+ facets = Facets(self._default_library, None, None, order=Facets.ORDER_TITLE)
mock_pagination = object()
mock_debug = object()
search_client = MockSearchClient()
# Ask the WorkList for a page of works, using the search index
# to drive the query instead of the database.
- result = wl.works(
- self._db, facets, mock_pagination, search_client, mock_debug
- )
+ result = wl.works(self._db, facets, mock_pagination, search_client, mock_debug)
# MockSearchClient.query_works was used to grab a list of work
# IDs.
@@ -2187,22 +2269,19 @@ def works_for_hits(self, _db, work_ids, facets=None):
# Our facets and the requirements of the WorkList were used to
# make a Filter object, which was passed as the 'filter'
# keyword argument.
- filter = query_works_kwargs.pop('filter')
- assert (Filter.from_worklist(self._db, wl, facets).build() ==
- filter.build())
+ filter = query_works_kwargs.pop("filter")
+ assert Filter.from_worklist(self._db, wl, facets).build() == filter.build()
# The other arguments to query_works are either constants or
# our mock objects.
- assert (dict(query_string=None,
- pagination=mock_pagination,
- debug=mock_debug) ==
- query_works_kwargs)
+ assert (
+ dict(query_string=None, pagination=mock_pagination, debug=mock_debug)
+ == query_works_kwargs
+ )
# The fake work IDs returned from query_works() were passed into
# works_for_hits().
- assert (
- (self._db, search_client.fake_work_ids) ==
- wl.called_with)
+ assert (self._db, search_client.fake_work_ids) == wl.called_with
# And the fake return value of works_for_hits() was used as
# the return value of works(), the method we're testing.
@@ -2215,14 +2294,13 @@ class Mock(WorkList):
def works_for_resultsets(self, _db, resultsets, facets=None):
self.called_with = (_db, resultsets)
return [["some", "results"]]
+
wl = Mock()
results = wl.works_for_hits(self._db, ["hit1", "hit2"])
# The list of hits was itself wrapped in a list, and passed
# into works_for_resultsets().
- assert (
- (self._db, [["hit1", "hit2"]]) ==
- wl.called_with)
+ assert (self._db, [["hit1", "hit2"]]) == wl.called_with
# The return value -- a list of lists of results, which
# contained a single item -- was unrolled and used as the
@@ -2246,15 +2324,15 @@ def test_works_for_resultsets(self):
class MockHit(object):
def __init__(self, work_id, has_last_update=False):
if isinstance(work_id, Work):
- self.work_id=work_id.id
+ self.work_id = work_id.id
else:
- self.work_id=work_id
+ self.work_id = work_id
self.has_last_update = has_last_update
def __contains__(self, k):
# Pretend to have the 'last_update' script field,
# if necessary.
- return (k == 'last_update' and self.has_last_update)
+ return k == "last_update" and self.has_last_update
hit1 = MockHit(w1)
hit2 = MockHit(w2)
@@ -2263,8 +2341,7 @@ def __contains__(self, k):
# Works is returned.
assert [[w2]] == m(self._db, [[hit2]])
assert [[w2], [w1]] == m(self._db, [[hit2], [hit1]])
- assert ([[w1, w1], [w2, w2], []] ==
- m(self._db, [[hit1, hit1], [hit2, hit2], []]))
+ assert [[w1, w1], [w2, w2], []] == m(self._db, [[hit1, hit1], [hit2, hit2], []])
# Works are returned in the order we ask for.
for ordering in ([hit1, hit2], [hit2, hit1]):
@@ -2311,16 +2388,12 @@ def works_for_hits(self, _db, work_ids):
return "A bunch of Works"
wl = MockWorkList()
- wl.initialize(
- self._default_library, audiences=[Classifier.AUDIENCE_CHILDREN]
- )
+ wl.initialize(self._default_library, audiences=[Classifier.AUDIENCE_CHILDREN])
query = "a query"
class MockSearchClient(object):
def query_works(self, query, filter, pagination, debug):
- self.query_works_called_with = (
- query, filter, pagination, debug
- )
+ self.query_works_called_with = (query, filter, pagination, debug)
return "A bunch of work IDs"
# Search with the default arguments.
@@ -2329,9 +2402,7 @@ def query_works(self, query, filter, pagination, debug):
# The results of query_works were passed into
# MockWorkList.works_for_hits.
- assert (
- (self._db, "A bunch of work IDs") ==
- wl.works_for_hits_called_with)
+ assert (self._db, "A bunch of work IDs") == wl.works_for_hits_called_with
# The return value of MockWorkList.works_for_hits is
# used as the return value of query_works().
@@ -2350,8 +2421,10 @@ def query_works(self, query, filter, pagination, debug):
# A Filter object was created to match only works that belong
# in the MockWorkList.
- assert ([Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_ALL_AGES] ==
- filter.audiences)
+ assert [
+ Classifier.AUDIENCE_CHILDREN,
+ Classifier.AUDIENCE_ALL_AGES,
+ ] == filter.audiences
# A default Pagination object was created.
assert 0 == pagination.offset
@@ -2361,8 +2434,7 @@ def query_works(self, query, filter, pagination, debug):
# objects.
facets = SearchFacets(languages=["chi"])
pagination = object()
- results = wl.search(self._db, query, client, pagination, facets,
- debug=True)
+ results = wl.search(self._db, query, client, pagination, facets, debug=True)
qu, filter, pag, debug = client.query_works_called_with
assert query == qu
@@ -2371,8 +2443,10 @@ def query_works(self, query, filter, pagination, debug):
# The Filter incorporates restrictions imposed by both the
# MockWorkList and the Facets.
- assert ([Classifier.AUDIENCE_CHILDREN, Classifier.AUDIENCE_ALL_AGES] ==
- filter.audiences)
+ assert [
+ Classifier.AUDIENCE_CHILDREN,
+ Classifier.AUDIENCE_ALL_AGES,
+ ] == filter.audiences
assert ["chi"] == filter.languages
def test_search_failures(self):
@@ -2388,6 +2462,7 @@ def test_search_failures(self):
class NoResults(object):
def query_works(self, *args, **kwargs):
return None
+
assert [] == wl.search(self._db, query, NoResults())
# If there's an ElasticSearch exception during the query,
@@ -2395,11 +2470,11 @@ def query_works(self, *args, **kwargs):
class RaisesException(object):
def query_works(self, *args, **kwargs):
raise ElasticsearchException("oh no")
+
assert [] == wl.search(self._db, query, RaisesException())
class TestDatabaseBackedWorkList(DatabaseTest):
-
def test_works_from_database(self):
# Verify that the works_from_database() method calls the
# methods we expect, in the right order.
@@ -2423,13 +2498,14 @@ def distinct(self, fields):
def __repr__(self):
return "" % (
- len(self.clauses), self.clauses[-1]
+ len(self.clauses),
+ self.clauses[-1],
)
class MockWorkList(DatabaseBackedWorkList):
def __init__(self, _db):
super(MockWorkList, self).__init__()
- self._db = _db # We'll be using this in assertions.
+ self._db = _db # We'll be using this in assertions.
self.stages = []
def _stage(self, method_name, _db, qu, qu_is_previous_stage=True):
@@ -2456,12 +2532,12 @@ def base_query(self, _db):
# This kicks off the process -- most future calls will
# use _stage().
assert _db == self._db
- query = MockQuery(['base_query'])
+ query = MockQuery(["base_query"])
self.stages.append(query)
return query
def only_show_ready_deliverable_works(self, _db, qu):
- return self._stage('only_show_ready_deliverable_works', _db, qu)
+ return self._stage("only_show_ready_deliverable_works", _db, qu)
def bibliographic_filter_clauses(self, _db, qu):
# This method is a little different, so we can't use
@@ -2474,7 +2550,7 @@ def bibliographic_filter_clauses(self, _db, qu):
return qu, []
def modify_database_query_hook(self, _db, qu):
- return self._stage('modify_database_query_hook', _db, qu)
+ return self._stage("modify_database_query_hook", _db, qu)
def active_bibliographic_filter_clauses(self, _db, qu):
# This alternate implementation of
@@ -2485,10 +2561,7 @@ def active_bibliographic_filter_clauses(self, _db, qu):
["new query made inside active_bibliographic_filter_clauses"]
)
self.stages.append(new_query)
- return (
- new_query,
- [text("clause 1"), text("clause 2")]
- )
+ return (new_query, [text("clause 1"), text("clause 2")])
# The simplest case: no facets or pagination,
# and bibliographic_filter_clauses does nothing.
@@ -2502,17 +2575,20 @@ def active_bibliographic_filter_clauses(self, _db, qu):
# MockQuery is constructed by chaining method calls. Now we
# just need to verify that all the methods were called and in
# the order we expect.
- assert (['base_query', 'only_show_ready_deliverable_works',
- 'modify_database_query_hook'] ==
- result.clauses)
+ assert [
+ "base_query",
+ "only_show_ready_deliverable_works",
+ "modify_database_query_hook",
+ ] == result.clauses
# bibliographic_filter_clauses used a different mechanism, but
# since it stored the MockQuery it was called with, we can see
# when it was called -- just after
# only_show_ready_deliverable_works.
- assert (
- ['base_query', 'only_show_ready_deliverable_works'] ==
- wl.bibliographic_filter_clauses_called_with.clauses)
+ assert [
+ "base_query",
+ "only_show_ready_deliverable_works",
+ ] == wl.bibliographic_filter_clauses_called_with.clauses
wl.bibliographic_filter_clauses_called_with = None
# Since nobody made the query distinct, it was set distinct on
@@ -2537,9 +2613,7 @@ def modify_database_query(self, _db, qu):
# Normally, _stage() will assert that `qu` is the
# return value from the previous call, but this time
# we want to assert the opposite.
- result = self.wl._stage(
- "facets", _db, qu, qu_is_previous_stage=False
- )
+ result = self.wl._stage("facets", _db, qu, qu_is_previous_stage=False)
distinct = result.distinct("some other field")
self.wl.stages.append(distinct)
@@ -2557,27 +2631,32 @@ def modify_database_query(self, _db, qu):
)
# Here are the methods called before bibliographic_filter_clauses.
- assert (['base_query', 'only_show_ready_deliverable_works'] ==
- wl.pre_bibliographic_filter.clauses)
+ assert [
+ "base_query",
+ "only_show_ready_deliverable_works",
+ ] == wl.pre_bibliographic_filter.clauses
# bibliographic_filter_clauses created a brand new object,
# which ended up as our result after some more methods were
# called on it.
- assert ('new query made inside active_bibliographic_filter_clauses' ==
- result.clauses.pop(0))
+ assert (
+ "new query made inside active_bibliographic_filter_clauses"
+ == result.clauses.pop(0)
+ )
# bibliographic_filter_clauses() returned two clauses which were
# combined with and_().
bibliographic_filter_clauses = result.clauses.pop(0)
- assert (str(and_(text('clause 1'), text('clause 2'))) ==
- str(bibliographic_filter_clauses))
+ assert str(and_(text("clause 1"), text("clause 2"))) == str(
+ bibliographic_filter_clauses
+ )
# The rest of the calls are easy to trace.
- assert (['facets',
- 'modify_database_query_hook',
- 'pagination',
- ] ==
- result.clauses)
+ assert [
+ "facets",
+ "modify_database_query_hook",
+ "pagination",
+ ] == result.clauses
# The query was made distinct on some other field, so the
# default behavior (making it distinct on Work.id) wasn't
@@ -2595,10 +2674,10 @@ def test_works_from_database_end_to_end(self):
# Create two books.
oliver_twist = self._work(
- title='Oliver Twist', with_license_pool=True, language="eng"
+ title="Oliver Twist", with_license_pool=True, language="eng"
)
barnaby_rudge = self._work(
- title='Barnaby Rudge', with_license_pool=True, language="spa"
+ title="Barnaby Rudge", with_license_pool=True, language="spa"
)
# A standard DatabaseBackedWorkList will find both books.
@@ -2608,7 +2687,7 @@ def test_works_from_database_end_to_end(self):
# A work list with a language restriction will only find books
# in that language.
- wl.initialize(self._default_library, languages=['eng'])
+ wl.initialize(self._default_library, languages=["eng"])
assert [oliver_twist] == [x for x in wl.works_from_database(self._db)]
# A DatabaseBackedWorkList will only find books licensed
@@ -2638,13 +2717,17 @@ def test_works_from_database_end_to_end(self):
self._default_library,
collection=Facets.COLLECTION_FULL,
availability=Facets.AVAILABLE_ALL,
- order=Facets.ORDER_TITLE
+ order=Facets.ORDER_TITLE,
)
pagination = Pagination(offset=1, size=1)
- assert [oliver_twist] == wl.works_from_database(self._db, facets, pagination).all()
+ assert [oliver_twist] == wl.works_from_database(
+ self._db, facets, pagination
+ ).all()
facets.order_ascending = False
- assert [barnaby_rudge] == wl.works_from_database(self._db, facets, pagination).all()
+ assert [barnaby_rudge] == wl.works_from_database(
+ self._db, facets, pagination
+ ).all()
# Ensure that availability facets are handled properly
# We still have two works:
@@ -2690,17 +2773,16 @@ def _modify_loading(cls, qu):
@classmethod
def _defer_unused_fields(cls, qu):
- return qu + ['_defer_unused_fields']
+ return qu + ["_defer_unused_fields"]
result = Mock.base_query(self._db)
[base_query, m, d] = result
- expect = self._db.query(Work).join(
- Work.license_pools
- ).join(
- Work.presentation_edition
- ).filter(
- LicensePool.superceded==False
+ expect = (
+ self._db.query(Work)
+ .join(Work.license_pools)
+ .join(Work.presentation_edition)
+ .filter(LicensePool.superceded == False)
)
assert str(expect) == str(base_query)
assert "_modify_loading" == m
@@ -2715,25 +2797,26 @@ class MockWorkList(DatabaseBackedWorkList):
The hook methods themselves are tested separately.
"""
+
def __init__(self, parent):
super(MockWorkList, self).__init__()
self._parent = parent
self._inherit_parent_restrictions = False
def audience_filter_clauses(self, _db, qu):
- called['audience_filter_clauses'] = (_db, qu)
+ called["audience_filter_clauses"] = (_db, qu)
return []
def customlist_filter_clauses(self, qu):
- called['customlist_filter_clauses'] = qu
+ called["customlist_filter_clauses"] = qu
return qu, []
def age_range_filter_clauses(self):
- called['age_range_filter_clauses'] = True
+ called["age_range_filter_clauses"] = True
return []
def genre_filter_clause(self, qu):
- called['genre_filter_clause'] = qu
+ called["genre_filter_clause"] = qu
return qu, None
@property
@@ -2746,6 +2829,7 @@ def inherit_parent_restrictions(self):
class MockParent(object):
bibliographic_filter_clauses_called_with = None
+
def bibliographic_filter_clauses(self, _db, qu):
self.bibliographic_filter_clauses_called_with = (_db, qu)
return qu, []
@@ -2760,26 +2844,24 @@ def bibliographic_filter_clauses(self, _db, qu):
# If no languages or genre IDs are specified, and the hook
# methods do nothing, then bibliographic_filter_clauses() has
# no effect.
- final_qu, clauses = wl.bibliographic_filter_clauses(
- self._db, original_qu
- )
+ final_qu, clauses = wl.bibliographic_filter_clauses(self._db, original_qu)
assert original_qu == final_qu
assert [] == clauses
# But at least audience_filter_clauses was called with the correct
# arguments.
- _db, qu = called['audience_filter_clauses']
+ _db, qu = called["audience_filter_clauses"]
assert self._db == _db
assert original_qu == qu
# age_range_filter_clauses was also called.
- assert True == called['age_range_filter_clauses']
+ assert True == called["age_range_filter_clauses"]
# customlist_filter_clauses and genre_filter_clause were not
# called because the WorkList doesn't do anything relating to
# custom lists or genres.
- assert 'customlist_filter_clauses' not in called
- assert 'genre_filter_clause' not in called
+ assert "customlist_filter_clauses" not in called
+ assert "genre_filter_clause" not in called
# The parent's bibliographic_filter_clauses() implementation
# was not called, because wl.inherit_parent_restrictions is
@@ -2789,18 +2871,17 @@ def bibliographic_filter_clauses(self, _db, qu):
# Set things up so that those other methods will be called.
empty_list, ignore = self._customlist(num_entries=0)
sf, ignore = Genre.lookup(self._db, "Science Fiction")
- wl.initialize(self._default_library, customlists=[empty_list],
- genres=[sf])
+ wl.initialize(self._default_library, customlists=[empty_list], genres=[sf])
wl._inherit_parent_restrictions = True
- final_qu, clauses = wl.bibliographic_filter_clauses(
- self._db, original_qu
- )
+ final_qu, clauses = wl.bibliographic_filter_clauses(self._db, original_qu)
- assert ((self._db, original_qu) ==
- parent.bibliographic_filter_clauses_called_with)
- assert original_qu == called['genre_filter_clause']
- assert original_qu == called['customlist_filter_clauses']
+ assert (
+ self._db,
+ original_qu,
+ ) == parent.bibliographic_filter_clauses_called_with
+ assert original_qu == called["genre_filter_clause"]
+ assert original_qu == called["customlist_filter_clauses"]
# But none of those methods changed anything, because their
# implementations didn't return anything.
@@ -2810,14 +2891,14 @@ def bibliographic_filter_clauses(self, _db, qu):
# bibliographic_filter_clauses.
overdrive = DataSource.lookup(self._db, DataSource.OVERDRIVE)
wl.initialize(
- self._default_library, languages=['eng'],
+ self._default_library,
+ languages=["eng"],
media=[Edition.BOOK_MEDIUM],
- fiction=True, license_datasource=overdrive
+ fiction=True,
+ license_datasource=overdrive,
)
- final_qu, clauses = wl.bibliographic_filter_clauses(
- self._db, original_qu
- )
+ final_qu, clauses = wl.bibliographic_filter_clauses(self._db, original_qu)
assert original_qu == final_qu
language, medium, fiction, datasource = clauses
@@ -2825,8 +2906,8 @@ def bibliographic_filter_clauses(self, _db, qu):
# that the constraints are similar.
assert str(language) == str(Edition.language.in_(wl.languages))
assert str(medium) == str(Edition.medium.in_(wl.media))
- assert str(fiction) == str(Work.fiction==True)
- assert str(datasource) == str(LicensePool.data_source_id==overdrive.id)
+ assert str(fiction) == str(Work.fiction == True)
+ assert str(datasource) == str(LicensePool.data_source_id == overdrive.id)
def test_bibliographic_filter_clauses_end_to_end(self):
# Verify that bibliographic_filter_clauses generates
@@ -2838,30 +2919,31 @@ def test_bibliographic_filter_clauses_end_to_end(self):
# DatabaseBackedWorkLists.
sf, ignore = Genre.lookup(self._db, "Science Fiction")
english_sf = self._work(
- title="English SF", language="eng", with_license_pool=True,
- audience=Classifier.AUDIENCE_YOUNG_ADULT
+ title="English SF",
+ language="eng",
+ with_license_pool=True,
+ audience=Classifier.AUDIENCE_YOUNG_ADULT,
)
italian_sf = self._work(
- title="Italian SF", language="ita", with_license_pool=True,
- audience=Classifier.AUDIENCE_YOUNG_ADULT
+ title="Italian SF",
+ language="ita",
+ with_license_pool=True,
+ audience=Classifier.AUDIENCE_YOUNG_ADULT,
)
- english_sf.target_age = tuple_to_numericrange((12,14))
+ english_sf.target_age = tuple_to_numericrange((12, 14))
gutenberg = english_sf.license_pools[0].data_source
english_sf.presentation_edition.medium = Edition.BOOK_MEDIUM
english_sf.genres.append(sf)
italian_sf.genres.append(sf)
- def worklist_has_books(expect_books, worklist=None,
- **initialize_kwargs):
+ def worklist_has_books(expect_books, worklist=None, **initialize_kwargs):
"""Apply bibliographic filters to a query and verify
that it finds only the given books.
"""
if worklist is None:
worklist = DatabaseBackedWorkList()
worklist.initialize(self._default_library, **initialize_kwargs)
- qu, clauses = worklist.bibliographic_filter_clauses(
- self._db, original_qu
- )
+ qu, clauses = worklist.bibliographic_filter_clauses(self._db, original_qu)
qu = qu.filter(and_(*clauses))
expect_titles = sorted([x.sort_title for x in expect_books])
actual_titles = sorted([x.sort_title for x in qu])
@@ -2877,7 +2959,7 @@ def worklist_has_books(expect_books, worklist=None,
fiction=True,
license_datasource=gutenberg,
audiences=[Classifier.AUDIENCE_YOUNG_ADULT],
- target_age=tuple_to_numericrange((13,13))
+ target_age=tuple_to_numericrange((13, 13)),
)
# This might be because there _are_ no restrictions.
@@ -2890,13 +2972,11 @@ def worklist_has_books(expect_books, worklist=None,
romance, ignore = Genre.lookup(self._db, "Romance")
worklist_has_books([], languages=["eng"], genres=[romance])
worklist_has_books(
- [],
- languages=["eng"], genres=[sf], media=[Edition.AUDIO_MEDIUM]
+ [], languages=["eng"], genres=[sf], media=[Edition.AUDIO_MEDIUM]
)
worklist_has_books([], fiction=False)
worklist_has_books(
- [],
- license_datasource=DataSource.lookup(self._db, DataSource.OVERDRIVE)
+ [], license_datasource=DataSource.lookup(self._db, DataSource.OVERDRIVE)
)
# If the WorkList has custom list IDs, then works will only show up if
@@ -2923,9 +3003,7 @@ def worklist_has_books(expect_books, worklist=None,
english_lane.customlists.append(english_list)
# This child of that lane has books from the list of SF books.
- sf_lane = self._lane(
- parent=english_lane, inherit_parent_restrictions=False
- )
+ sf_lane = self._lane(parent=english_lane, inherit_parent_restrictions=False)
sf_lane.customlists.append(sf_list)
# When the child lane does not inherit its parent restrictions,
@@ -2955,8 +3033,7 @@ def worklist_has_books(expect_books, worklist=None,
# Here's a child of that lane, which contains science fiction.
sf_shorts = self._lane(
- genres=[sf], parent=short_stories_lane,
- inherit_parent_restrictions=False
+ genres=[sf], parent=short_stories_lane, inherit_parent_restrictions=False
)
self._db.flush()
@@ -2995,24 +3072,23 @@ def worklist_has_books(expect, **wl_args):
audience=Classifier.AUDIENCE_YOUNG_ADULT,
with_license_pool=True,
)
- fourteen_or_fifteen.target_age = tuple_to_numericrange((14,15))
+ fourteen_or_fifteen.target_age = tuple_to_numericrange((14, 15))
# This DatabaseBackedWorkList contains the YA book because its
# age range overlaps the age range of the book.
- worklist_has_books(
- [fourteen_or_fifteen], target_age=(12, 14)
- )
+ worklist_has_books([fourteen_or_fifteen], target_age=(12, 14))
worklist_has_books(
[adult, fourteen_or_fifteen],
- audiences=[Classifier.AUDIENCE_ADULT], target_age=(12, 14)
+ audiences=[Classifier.AUDIENCE_ADULT],
+ target_age=(12, 14),
)
# This lane contains no books because it skews too old for the YA
# book, but books for adults are not allowed.
older_ya = self._lane()
- older_ya.target_age = (16,17)
- worklist_has_books([], target_age=(16,17))
+ older_ya.target_age = (16, 17)
+ worklist_has_books([], target_age=(16, 17))
# Expand it to include books for adults, and the adult book
# shows up despite having no target age at all.
@@ -3025,14 +3101,16 @@ def test_audience_filter_clauses(self):
# Create a children's book and a book for adults.
adult = self._work(
title="Diseases of the Horse",
- with_license_pool=True, with_open_access_download=True,
- audience=Classifier.AUDIENCE_ADULT
+ with_license_pool=True,
+ with_open_access_download=True,
+ audience=Classifier.AUDIENCE_ADULT,
)
children = self._work(
title="Wholesome Nursery Rhymes For All Children",
- with_license_pool=True, with_open_access_download=True,
- audience=Classifier.AUDIENCE_CHILDREN
+ with_license_pool=True,
+ with_open_access_download=True,
+ audience=Classifier.AUDIENCE_CHILDREN,
)
def for_audiences(*audiences):
@@ -3075,9 +3153,7 @@ def test_customlist_filter_clauses(self):
# This DatabaseBackedWorkList gets every work on a specific list.
works_on_list = DatabaseBackedWorkList()
- works_on_list.initialize(
- self._default_library, customlists=[gutenberg_list]
- )
+ works_on_list.initialize(self._default_library, customlists=[gutenberg_list])
# This lane gets every work on every list associated with Project
# Gutenberg.
@@ -3138,7 +3214,7 @@ def results(wl=works_on_gutenberg_lists, must_be_featured=False):
# method.
gutenberg_list_2_wl = DatabaseBackedWorkList()
gutenberg_list_2_wl.initialize(
- self._default_library, customlists = [gutenberg_list_2]
+ self._default_library, customlists=[gutenberg_list_2]
)
# These two lines won't work, because these are
@@ -3199,8 +3275,7 @@ def test_works_from_database_with_superceded_pool(self):
class TestHierarchyWorkList(DatabaseTest):
- """Test HierarchyWorkList in terms of its two subclasses, Lane and TopLevelWorkList.
- """
+ """Test HierarchyWorkList in terms of its two subclasses, Lane and TopLevelWorkList."""
def test_accessible_to(self):
# In addition to the general tests imposed by WorkList, a Lane
@@ -3247,7 +3322,6 @@ def test_accessible_to(self):
class TestLane(DatabaseTest):
-
def test_get_library(self):
lane = self._lane()
assert self._default_library == lane.get_library(self._db)
@@ -3285,7 +3359,6 @@ def test_set_audiences(self):
assert [Classifier.AUDIENCE_ADULT] == lane.audiences
def test_update_size(self):
-
class Mock(object):
# Mock the ExternalSearchIndex.count_works() method to
# return specific values without consulting an actual
@@ -3301,6 +3374,7 @@ def count_works(self, filter):
else:
medium = None
return values_by_medium[medium]
+
search_engine = Mock()
# Enable the 'ebooks' and 'audiobooks' entry points.
@@ -3319,10 +3393,11 @@ def count_works(self, filter):
# The lane size is also calculated individually for every
# enabled entry point. EverythingEntryPoint is used for the
# total size of the lane.
- assert ({AudiobooksEntryPoint.URI: 3,
- EbooksEntryPoint.URI: 99,
- EverythingEntryPoint.URI: 102} ==
- fiction.size_by_entrypoint)
+ assert {
+ AudiobooksEntryPoint.URI: 3,
+ EbooksEntryPoint.URI: 99,
+ EverythingEntryPoint.URI: 102,
+ } == fiction.size_by_entrypoint
assert 102 == fiction.size
def test_visibility(self):
@@ -3359,15 +3434,21 @@ def test_parentage(self):
# this.
assert [] == list(lane.parentage)
assert [lane] == list(child_lane.parentage)
- assert ("%s / %s" % (lane.library.short_name, lane.display_name) ==
- lane.full_identifier)
+ assert (
+ "%s / %s" % (lane.library.short_name, lane.display_name)
+ == lane.full_identifier
+ )
assert (
- "%s / %s / %s / %s" % (
- lane.library.short_name, lane.display_name,
- child_lane.display_name, grandchild_lane.display_name
- ) ==
- grandchild_lane.full_identifier)
+ "%s / %s / %s / %s"
+ % (
+ lane.library.short_name,
+ lane.display_name,
+ child_lane.display_name,
+ grandchild_lane.display_name,
+ )
+ == grandchild_lane.full_identifier
+ )
assert [lane, child_lane, grandchild_lane] == grandchild_lane.hierarchy
@@ -3427,8 +3508,7 @@ def test_affected_by_customlist(self):
# Two lists.
l1, ignore = self._customlist(
- data_source_name=DataSource.GUTENBERG,
- num_entries=0
+ data_source_name=DataSource.GUTENBERG, num_entries=0
)
l2, ignore = self._customlist(
data_source_name=DataSource.OVERDRIVE, num_entries=0
@@ -3512,22 +3592,23 @@ def test_inherited_values(self):
# anything in particular. This lane contains books that
# are on the staff picks list.
staff_picks_lane.inherit_parent_restrictions = False
- assert [[staff_picks]] == staff_picks_lane.inherited_values('customlists')
+ assert [[staff_picks]] == staff_picks_lane.inherited_values("customlists")
# If inherit_parent_restrictions is True, then the lane
# has *two* sets of restrictions: a book must be on both
# the staff picks list *and* the best sellers list.
staff_picks_lane.inherit_parent_restrictions = True
- x = staff_picks_lane.inherited_values('customlists')
- assert (sorted([[staff_picks], [best_sellers]]) ==
- sorted(staff_picks_lane.inherited_values('customlists')))
+ x = staff_picks_lane.inherited_values("customlists")
+ assert sorted([[staff_picks], [best_sellers]]) == sorted(
+ staff_picks_lane.inherited_values("customlists")
+ )
def test_setting_target_age_locks_audiences(self):
lane = self._lane()
lane.target_age = (16, 18)
- assert (
- sorted([Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_ADULT]) ==
- sorted(lane.audiences))
+ assert sorted(
+ [Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_ADULT]
+ ) == sorted(lane.audiences)
lane.target_age = (0, 2)
assert [Classifier.AUDIENCE_CHILDREN] == lane.audiences
lane.target_age = 14
@@ -3536,11 +3617,15 @@ def test_setting_target_age_locks_audiences(self):
# It's not possible to modify .audiences to a value that's
# incompatible with .target_age.
lane.audiences = lane.audiences
+
def doomed():
lane.audiences = [Classifier.AUDIENCE_CHILDREN]
+
with pytest.raises(ValueError) as excinfo:
doomed()
- assert "Cannot modify Lane.audiences when Lane.target_age is set" in str(excinfo.value)
+ assert "Cannot modify Lane.audiences when Lane.target_age is set" in str(
+ excinfo.value
+ )
# Setting target_age to None leaves preexisting .audiences in place.
lane.target_age = None
@@ -3550,10 +3635,9 @@ def doomed():
lane.audiences = [Classifier.AUDIENCE_CHILDREN]
def test_target_age_treats_all_adults_equally(self):
- """We don't distinguish between different age groups for adults.
- """
+ """We don't distinguish between different age groups for adults."""
lane = self._lane()
- lane.target_age = (35,40)
+ lane.target_age = (35, 40)
assert tuple_to_numericrange((18, 18)) == lane.target_age
def test_uses_customlists(self):
@@ -3590,9 +3674,12 @@ def test_genre_ids(self):
# At this point the lane picks up Fantasy and all of its
# subgenres.
expect = [
- Genre.lookup(self._db, genre)[0].id for genre in [
- "Fantasy", "Epic Fantasy","Historical Fantasy",
- "Urban Fantasy"
+ Genre.lookup(self._db, genre)[0].id
+ for genre in [
+ "Fantasy",
+ "Epic Fantasy",
+ "Historical Fantasy",
+ "Urban Fantasy",
]
]
assert set(expect) == fantasy.genre_ids
@@ -3639,12 +3726,8 @@ def test_customlist_ids(self):
# When you add a CustomList to a Lane, you are saying that works
# from that CustomList can appear in the Lane.
- nyt1, ignore = self._customlist(
- num_entries=0, data_source_name=DataSource.NYT
- )
- nyt2, ignore = self._customlist(
- num_entries=0, data_source_name=DataSource.NYT
- )
+ nyt1, ignore = self._customlist(num_entries=0, data_source_name=DataSource.NYT)
+ nyt2, ignore = self._customlist(num_entries=0, data_source_name=DataSource.NYT)
no_lists = self._lane()
assert None == no_lists.customlist_ids
@@ -3657,17 +3740,13 @@ def test_customlist_ids(self):
# works appear in the Lane if they are on _any_ CustomList from
# that data source.
has_list_source = self._lane()
- has_list_source.list_datasource = DataSource.lookup(
- self._db, DataSource.NYT
- )
+ has_list_source.list_datasource = DataSource.lookup(self._db, DataSource.NYT)
assert set([nyt1.id, nyt2.id]) == set(has_list_source.customlist_ids)
# If there are no CustomLists from that data source, an empty
# list is returned.
has_no_lists = self._lane()
- has_no_lists.list_datasource = DataSource.lookup(
- self._db, DataSource.OVERDRIVE
- )
+ has_no_lists.list_datasource = DataSource.lookup(self._db, DataSource.OVERDRIVE)
assert [] == has_no_lists.customlist_ids
def test_search_target(self):
@@ -3742,16 +3821,27 @@ def test_search_target(self):
target = lane.search_target
assert "English Adult and Young Adult" == target.display_name
assert ["eng"] == target.languages
- assert [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_YOUNG_ADULT] == target.audiences
+ assert [
+ Classifier.AUDIENCE_ADULT,
+ Classifier.AUDIENCE_YOUNG_ADULT,
+ ] == target.audiences
assert [Edition.BOOK_MEDIUM] == target.media
# If there are too many audiences, they're left
# out of the display name.
- lane.audiences = [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_CHILDREN]
+ lane.audiences = [
+ Classifier.AUDIENCE_ADULT,
+ Classifier.AUDIENCE_YOUNG_ADULT,
+ Classifier.AUDIENCE_CHILDREN,
+ ]
target = lane.search_target
assert "English" == target.display_name
assert ["eng"] == target.languages
- assert [Classifier.AUDIENCE_ADULT, Classifier.AUDIENCE_YOUNG_ADULT, Classifier.AUDIENCE_CHILDREN] == target.audiences
+ assert [
+ Classifier.AUDIENCE_ADULT,
+ Classifier.AUDIENCE_YOUNG_ADULT,
+ Classifier.AUDIENCE_CHILDREN,
+ ] == target.audiences
assert [Edition.BOOK_MEDIUM] == target.media
def test_search(self):
@@ -3794,9 +3884,11 @@ def test_search_propagates_facets(self):
"""Lane.search propagates facets when calling search() on
its search target.
"""
+
class Mock(object):
def search(self, *args, **kwargs):
- self.called_with = kwargs['facets']
+ self.called_with = kwargs["facets"]
+
mock = Mock()
lane = self._lane()
@@ -3825,21 +3917,21 @@ def test_explain(self):
child = self._lane(parent=parent, display_name="Child")
child.priority = 2
data = parent.explain()
- assert (['ID: %s' % parent.id,
- 'Library: %s' % self._default_library.short_name,
- 'Priority: 1',
- 'Display name: Parent',
- ] ==
- data)
+ assert [
+ "ID: %s" % parent.id,
+ "Library: %s" % self._default_library.short_name,
+ "Priority: 1",
+ "Display name: Parent",
+ ] == data
data = child.explain()
- assert (['ID: %s' % child.id,
- 'Library: %s' % self._default_library.short_name,
- 'Parent ID: %s (Parent)' % parent.id,
- 'Priority: 2',
- 'Display name: Child',
- ] ==
- data)
+ assert [
+ "ID: %s" % child.id,
+ "Library: %s" % self._default_library.short_name,
+ "Parent ID: %s (Parent)" % parent.id,
+ "Priority: 2",
+ "Display name: Child",
+ ] == data
def test_groups_propagates_facets(self):
# Lane.groups propagates a received Facets object into
@@ -3847,6 +3939,7 @@ def test_groups_propagates_facets(self):
def mock(self, _db, relevant_lanes, queryable_lanes, facets, *args, **kwargs):
self.called_with = facets
return []
+
old_value = Lane._groups_for_lanes
Lane._groups_for_lanes = mock
lane = self._lane()
@@ -3873,15 +3966,15 @@ def _w(**kwargs):
library.setting(library.FEATURED_LANE_SIZE).value = "2"
# Create eight works.
- self.hq_litfic = _w(title="HQ LitFic", fiction=True, genre='Literary Fiction')
+ self.hq_litfic = _w(title="HQ LitFic", fiction=True, genre="Literary Fiction")
self.hq_litfic.quality = 0.8
- self.lq_litfic = _w(title="LQ LitFic", fiction=True, genre='Literary Fiction')
+ self.lq_litfic = _w(title="LQ LitFic", fiction=True, genre="Literary Fiction")
self.lq_litfic.quality = 0
self.hq_sf = _w(title="HQ SF", genre="Science Fiction", fiction=True)
# Add a lot of irrelevant genres to one of the works. This
# won't affect the results.
- for genre in ['Westerns', 'Horror', 'Erotica']:
+ for genre in ["Westerns", "Horror", "Erotica"]:
genre_obj, is_new = Genre.lookup(self._db, genre)
get_one_or_create(self._db, WorkGenre, work=self.hq_sf, genre=genre_obj)
@@ -3896,7 +3989,9 @@ def _w(**kwargs):
self.mq_ro.quality = 0.6
# This work is in a different language -- necessary to run the
# LQRomanceEntryPoint test below.
- self.lq_ro = _w(title="LQ Romance", genre="Romance", fiction=True, language='lan')
+ self.lq_ro = _w(
+ title="LQ Romance", genre="Romance", fiction=True, language="lan"
+ )
self.lq_ro.quality = 0.1
self.nonfiction = _w(title="Nonfiction", fiction=False)
@@ -3917,15 +4012,11 @@ def test_groups(self):
fiction.fiction = True
# "Best Sellers", which will contain one book.
- best_sellers = self._lane(
- "Best Sellers", parent=fiction
- )
+ best_sellers = self._lane("Best Sellers", parent=fiction)
best_sellers.customlists.append(self.best_seller_list)
# "Staff Picks", which will contain the same book.
- staff_picks = self._lane(
- "Staff Picks", parent=fiction
- )
+ staff_picks = self._lane("Staff Picks", parent=fiction)
staff_picks.customlists.append(self.staff_picks_list)
# "Science Fiction", which will contain two books (but
@@ -3935,15 +4026,12 @@ def test_groups(self):
)
# "Romance", which will contain two books.
- romance_lane = self._lane(
- "Romance", parent=fiction, genres=["Romance"]
- )
+ romance_lane = self._lane("Romance", parent=fiction, genres=["Romance"])
# "Discredited Nonfiction", which contains a book that would
# not normally appear in 'Fiction'.
discredited_nonfiction = self._lane(
- "Discredited Nonfiction", fiction=False,
- parent=fiction
+ "Discredited Nonfiction", fiction=False, parent=fiction
)
discredited_nonfiction.inherit_parent_restrictions = False
@@ -3955,10 +4043,16 @@ def test_groups(self):
self._default_library,
collection=Facets.COLLECTION_FULL,
availability=Facets.AVAILABLE_ALL,
- order=Facets.ORDER_TITLE
+ order=Facets.ORDER_TITLE,
)
- for lane in [fiction, best_sellers, staff_picks, sf_lane, romance_lane,
- discredited_nonfiction]:
+ for lane in [
+ fiction,
+ best_sellers,
+ staff_picks,
+ sf_lane,
+ romance_lane,
+ discredited_nonfiction,
+ ]:
t1 = [x.id for x in lane.works(self._db, facets)]
t2 = [x.id for x in lane.works_from_database(self._db, facets)]
assert t1 == t2
@@ -3968,36 +4062,31 @@ def assert_contents(g, expect):
(Work, lane) 2-tuples.
"""
results = list(g)
- expect = [
- (x[0].sort_title, x[1].display_name) for x in expect
- ]
- actual = [
- (x[0].sort_title, x[1].display_name) for x in results
- ]
+ expect = [(x[0].sort_title, x[1].display_name) for x in expect]
+ actual = [(x[0].sort_title, x[1].display_name) for x in results]
for i, expect_item in enumerate(expect):
if i >= len(actual):
actual_item = None
else:
actual_item = actual[i]
- assert (expect_item == actual_item), \
- "Mismatch in position %d: Expected %r, got %r.\nOverall, expected:\n%r\nGot:\n%r:" % \
- (i, expect_item, actual_item, expect, actual)
- assert len(expect) == len(actual), \
- "Expect matches actual, but actual has extra members.\nOverall, expected:\n%r\nGot:\n%r:" % \
- (expect, actual)
+ assert expect_item == actual_item, (
+ "Mismatch in position %d: Expected %r, got %r.\nOverall, expected:\n%r\nGot:\n%r:"
+ % (i, expect_item, actual_item, expect, actual)
+ )
+ assert len(expect) == len(actual), (
+ "Expect matches actual, but actual has extra members.\nOverall, expected:\n%r\nGot:\n%r:"
+ % (expect, actual)
+ )
def make_groups(lane, facets=None, **kwargs):
# Run the `WorkList.groups` method in a way that's
# instrumented for this unit test.
# Most of the time, we want a simple deterministic query.
- facets = facets or FeaturedFacets(
- 1, random_seed=Filter.DETERMINISTIC
- )
+ facets = facets or FeaturedFacets(1, random_seed=Filter.DETERMINISTIC)
return lane.groups(
- self._db, facets=facets, search_engine=self.search, debug=True,
- **kwargs
+ self._db, facets=facets, search_engine=self.search, debug=True, **kwargs
)
assert_contents(
@@ -4008,13 +4097,11 @@ def make_groups(lane, facets=None, **kwargs):
# FEATURED_LANE_SIZE, but nothing else belongs in the
# lane.
(self.mq_sf, best_sellers),
-
# In fact, both lanes feature the same title -- this
# generally won't happen but it can happen when
# multiple lanes are based on lists that feature the
# same title.
(self.mq_sf, staff_picks),
-
# The genre-based lanes contain FEATURED_LANE_SIZE
# (two) titles each. The 'Science Fiction' lane
# features a low-quality work because the
@@ -4023,12 +4110,10 @@ def make_groups(lane, facets=None, **kwargs):
(self.lq_sf, sf_lane),
(self.hq_ro, romance_lane),
(self.mq_ro, romance_lane),
-
# The 'Discredited Nonfiction' lane contains a single
# book. There just weren't enough matching books to fill
# out the lane to FEATURED_LANE_SIZE.
(self.nonfiction, discredited_nonfiction),
-
# The 'Fiction' lane contains a title that fits in the
# fiction lane but was not classified under any other
# lane. It also contains a title that was previously
@@ -4042,7 +4127,7 @@ def make_groups(lane, facets=None, **kwargs):
# low-quality one to show up.
(self.hq_litfic, fiction),
(self.hq_sf, fiction),
- ]
+ ],
)
# If we ask only about 'Fiction', not including its sublanes,
@@ -4050,7 +4135,7 @@ def make_groups(lane, facets=None, **kwargs):
# 'fiction'.
assert_contents(
make_groups(fiction, include_sublanes=False),
- [(self.hq_litfic, fiction), (self.hq_sf, fiction)]
+ [(self.hq_litfic, fiction), (self.hq_sf, fiction)],
)
# If we exclude 'Fiction' from its own grouped feed, we get
@@ -4067,7 +4152,7 @@ def make_groups(lane, facets=None, **kwargs):
(self.hq_ro, romance_lane),
(self.mq_ro, romance_lane),
(self.nonfiction, discredited_nonfiction),
- ]
+ ],
)
fiction.include_self_in_grouped_feed = True
@@ -4078,7 +4163,7 @@ def make_groups(lane, facets=None, **kwargs):
discredited_nonfiction.groups(
self._db, include_sublanes=include_sublanes
),
- [(self.nonfiction, discredited_nonfiction)]
+ [(self.nonfiction, discredited_nonfiction)],
)
# If we make the lanes thirstier for content, we see slightly
@@ -4091,7 +4176,6 @@ def make_groups(lane, facets=None, **kwargs):
# The list-based lanes are the same as before.
(self.mq_sf, best_sellers),
(self.mq_sf, staff_picks),
-
# After using every single science fiction work that
# wasn't previously used, we reuse self.mq_sf to pad the
# "Science Fiction" lane up to three items. It's
@@ -4101,17 +4185,14 @@ def make_groups(lane, facets=None, **kwargs):
(self.hq_sf, sf_lane),
(self.lq_sf, sf_lane),
(self.mq_sf, sf_lane),
-
# The 'Romance' lane now contains all three Romance
# titles, with the higher-quality titles first.
(self.hq_ro, romance_lane),
(self.mq_ro, romance_lane),
(self.lq_ro, romance_lane),
-
# The 'Discredited Nonfiction' lane is the same as
# before.
(self.nonfiction, discredited_nonfiction),
-
# After using every single fiction work that wasn't
# previously used, we reuse high-quality works to pad
# the "Fiction" lane to three items. The
@@ -4122,7 +4203,7 @@ def make_groups(lane, facets=None, **kwargs):
(self.hq_litfic, fiction),
(self.hq_sf, fiction),
(self.hq_ro, fiction),
- ]
+ ],
)
# Let's see how entry points affect the feeds.
@@ -4138,12 +4219,13 @@ def make_groups(lane, facets=None, **kwargs):
# that only finds one book.
class LQRomanceEntryPoint(EntryPoint):
URI = ""
+
@classmethod
def modify_search_filter(cls, filter):
- filter.languages = ['lan']
+ filter.languages = ["lan"]
+
facets = FeaturedFacets(
- 1, entrypoint=LQRomanceEntryPoint,
- random_seed=Filter.DETERMINISTIC
+ 1, entrypoint=LQRomanceEntryPoint, random_seed=Filter.DETERMINISTIC
)
assert_contents(
make_groups(fiction, facets=facets),
@@ -4152,7 +4234,7 @@ def modify_search_filter(cls, filter):
# that can show it.
(self.lq_ro, romance_lane),
(self.lq_ro, fiction),
- ]
+ ],
)
# Now, instead of relying on the 'Fiction' lane, make a
@@ -4171,9 +4253,7 @@ def groups(slf, _db, include_sublanes, pagination=None, facets=None):
mock = MockWorkList()
wl = WorkList()
- wl.initialize(
- self._default_library, children=[best_sellers, staff_picks, mock]
- )
+ wl.initialize(self._default_library, children=[best_sellers, staff_picks, mock])
# We get results from the two lanes and from the MockWorkList.
# Since the MockWorkList wasn't a lane, its results were obtained
@@ -4184,12 +4264,11 @@ def groups(slf, _db, include_sublanes, pagination=None, facets=None):
(self.mq_sf, best_sellers),
(self.mq_sf, staff_picks),
(self.lq_litfic, mock),
- ]
+ ],
)
class TestWorkListGroups(DatabaseTest):
-
def setup_method(self):
super(TestWorkListGroups, self).setup_method()
@@ -4203,11 +4282,14 @@ def test_groups_for_lanes_adapts_facets(self):
# FeaturedFacets objects to its own needs.
class MockParent(WorkList):
-
def _featured_works_with_lanes(
- self, _db, lanes, pagination, facets, *args, **kwargs
+ self, _db, lanes, pagination, facets, *args, **kwargs
):
- self._featured_works_with_lanes_called_with = (lanes, pagination, facets)
+ self._featured_works_with_lanes_called_with = (
+ lanes,
+ pagination,
+ facets,
+ )
return super(MockParent, self)._featured_works_with_lanes(
_db, lanes, pagination, facets, *args, **kwargs
)
@@ -4233,8 +4315,7 @@ def works(self, _db, pagination, facets, *args, **kwargs):
for wl in children:
wl.initialize(library=self._default_library)
- parent.initialize(library=self._default_library,
- children=[child1, child2])
+ parent.initialize(library=self._default_library, children=[child1, child2])
# We're going to make a grouped feed in which both children
# are relevant, but neither one is queryable.
@@ -4271,8 +4352,7 @@ def works(self, _db, pagination, facets, *args, **kwargs):
# a new Pagination object is created based on the featured lane
# size for the library.
groups = list(
- parent._groups_for_lanes(self._db, relevant, queryable, None,
- facets)
+ parent._groups_for_lanes(self._db, relevant, queryable, None, facets)
)
(ignore1, pagination, ignore2) = parent._featured_works_with_lanes_called_with
@@ -4281,8 +4361,7 @@ def works(self, _db, pagination, facets, *args, **kwargs):
# For each sublane, we ask for 10% more items than we need to
# reduce the chance that we'll need to put the same item in
# multiple lanes.
- assert (int(self._default_library.featured_lane_size * 1.10) ==
- pagination.size)
+ assert int(self._default_library.featured_lane_size * 1.10) == pagination.size
def test_featured_works_with_lanes(self):
# _featured_works_with_lanes builds a list of queries and
@@ -4295,6 +4374,7 @@ class MockWorkList(WorkList):
searched, and works_for_resultsets() for the parent that's
doing the searching.
"""
+
def __init__(self, *args, **kwargs):
# Track all the times overview_facets is called (it
# should be called twice), plus works_for_resultsets
@@ -4321,6 +4401,7 @@ def works_for_resultsets(self, _db, resultsets, facets=None):
class MockSearchEngine(object):
"""Mock a multi-query call to an Elasticsearch server."""
+
def __init__(self):
self.called_with = None
@@ -4336,8 +4417,9 @@ def query_works_multi(self, queries):
child1 = MockWorkList()
child2 = MockWorkList()
parent.initialize(
- library=self._default_library, children=[child1, child2],
- display_name = "Parent lane -- call my _featured_works_with_lanes()!"
+ library=self._default_library,
+ children=[child1, child2],
+ display_name="Parent lane -- call my _featured_works_with_lanes()!",
)
child1.initialize(library=self._default_library, display_name="Child 1")
child2.initialize(library=self._default_library, display_name="Child 2")
@@ -4405,22 +4487,20 @@ def query_works_multi(self, queries):
#
# This was passed into parent.works_for_resultsets():
call = parent.works_for_resultsets_calls.pop()
- assert call == (self._db, [['some'], ['search'], ['results']])
+ assert call == (self._db, [["some"], ["search"], ["results"]])
assert [] == parent.works_for_resultsets_calls
# The return value of works_for_resultsets -- another list of
# lists -- was then turned into a sequence of ('work', Lane)
# 2-tuples.
- assert (
- [
- ("Here is", child1),
- ("one lane", child1),
- ("of works", child1),
- ("Here is", child2),
- ("one lane", child2),
- ("of works", child2),
- ] ==
- results)
+ assert [
+ ("Here is", child1),
+ ("one lane", child1),
+ ("of works", child1),
+ ("Here is", child2),
+ ("one lane", child2),
+ ("of works", child2),
+ ] == results
# And that's how we got a sequence of 2-tuples mapping out a
# grouped OPDS feed.
@@ -4432,8 +4512,10 @@ def test__size_for_facets(self):
ebooks, audio, everything, nothing = [
FeaturedFacets(minimum_featured_quality=0.5, entrypoint=x)
for x in (
- EbooksEntryPoint, AudiobooksEntryPoint, EverythingEntryPoint,
- None
+ EbooksEntryPoint,
+ AudiobooksEntryPoint,
+ EverythingEntryPoint,
+ None,
)
]
@@ -4445,9 +4527,9 @@ def test__size_for_facets(self):
# Once Lane.size_by_entrypoint is set, it's used when possible.
lane.size_by_entrypoint = {
- EverythingEntryPoint.URI : 99,
- EbooksEntryPoint.URI : 1,
- AudiobooksEntryPoint.URI : 2
+ EverythingEntryPoint.URI: 99,
+ EbooksEntryPoint.URI: 1,
+ AudiobooksEntryPoint.URI: 2,
}
assert 99 == m(None)
assert 99 == m(nothing)
diff --git a/tests/test_local_analytics_provider.py b/tests/test_local_analytics_provider.py
index a4970bb47..aa1d7451e 100644
--- a/tests/test_local_analytics_provider.py
+++ b/tests/test_local_analytics_provider.py
@@ -1,39 +1,43 @@
import pytest
-from ..testing import DatabaseTest
from ..local_analytics_provider import LocalAnalyticsProvider
-from ..model import (
- CirculationEvent,
- ExternalIntegration,
- create,
-)
+from ..model import CirculationEvent, ExternalIntegration, create
+from ..testing import DatabaseTest
from ..util.datetime_helpers import to_utc, utc_now
-class TestLocalAnalyticsProvider(DatabaseTest):
+class TestLocalAnalyticsProvider(DatabaseTest):
def setup_method(self):
super(TestLocalAnalyticsProvider, self).setup_method()
self.integration, ignore = create(
- self._db, ExternalIntegration,
+ self._db,
+ ExternalIntegration,
goal=ExternalIntegration.ANALYTICS_GOAL,
- protocol="core.local_analytics_provider")
- self.la = LocalAnalyticsProvider(
- self.integration, self._default_library
+ protocol="core.local_analytics_provider",
)
+ self.la = LocalAnalyticsProvider(self.integration, self._default_library)
def test_collect_event(self):
library2 = self._library()
work = self._work(
- title="title", authors="author", fiction=True,
- audience="audience", language="lang",
- with_license_pool=True
+ title="title",
+ authors="author",
+ fiction=True,
+ audience="audience",
+ language="lang",
+ with_license_pool=True,
)
[lp] = work.license_pools
now = utc_now()
self.la.collect_event(
- self._default_library, lp, CirculationEvent.DISTRIBUTOR_CHECKIN, now,
- old_value=None, new_value=None)
+ self._default_library,
+ lp,
+ CirculationEvent.DISTRIBUTOR_CHECKIN,
+ now,
+ old_value=None,
+ new_value=None,
+ )
qu = self._db.query(CirculationEvent).filter(
CirculationEvent.type == CirculationEvent.DISTRIBUTOR_CHECKIN
@@ -50,8 +54,13 @@ def test_collect_event(self):
# for a different library.
now = utc_now()
self.la.collect_event(
- library2, lp, CirculationEvent.DISTRIBUTOR_CHECKIN, now,
- old_value=None, new_value=None)
+ library2,
+ lp,
+ CirculationEvent.DISTRIBUTOR_CHECKIN,
+ now,
+ old_value=None,
+ new_value=None,
+ )
assert 1 == qu.count()
# It's possible to instantiate the LocalAnalyticsProvider
@@ -61,9 +70,13 @@ def test_collect_event(self):
# In that case, it will process events for any library.
for library in [self._default_library, library2]:
now = utc_now()
- la.collect_event(library, lp,
- CirculationEvent.DISTRIBUTOR_CHECKIN, now,
- old_value=None, new_value=None
+ la.collect_event(
+ library,
+ lp,
+ CirculationEvent.DISTRIBUTOR_CHECKIN,
+ now,
+ old_value=None,
+ new_value=None,
)
assert 3 == qu.count()
@@ -88,8 +101,7 @@ def test_neighborhood_is_location(self):
# The default LocalAnalytics object doesn't have a location
# gathering policy, and the default is to ignore location.
event, is_new = self.la.collect_event(
- self._default_library, None, "event", utc_now(),
- neighborhood="Gormenghast"
+ self._default_library, None, "event", utc_now(), neighborhood="Gormenghast"
)
assert True == is_new
assert None == event.location
@@ -98,14 +110,13 @@ def test_neighborhood_is_location(self):
# neighborhood as the event location.
p = LocalAnalyticsProvider
- self.integration.setting(p.LOCATION_SOURCE).value = (
- p.LOCATION_SOURCE_NEIGHBORHOOD
- )
+ self.integration.setting(
+ p.LOCATION_SOURCE
+ ).value = p.LOCATION_SOURCE_NEIGHBORHOOD
la = p(self.integration, self._default_library)
event, is_new = la.collect_event(
- self._default_library, None, "event", utc_now(),
- neighborhood="Gormenghast"
+ self._default_library, None, "event", utc_now(), neighborhood="Gormenghast"
)
assert True == is_new
assert "Gormenghast" == event.location
@@ -113,7 +124,10 @@ def test_neighborhood_is_location(self):
# If no neighborhood is available, the event ends up with no location
# anyway.
event2, is_new = la.collect_event(
- self._default_library, None, "event", utc_now(),
+ self._default_library,
+ None,
+ "event",
+ utc_now(),
)
assert event2 != event
assert True == is_new
diff --git a/tests/test_log.py b/tests/test_log.py
index 40eddc10a..aa2d9740e 100644
--- a/tests/test_log.py
+++ b/tests/test_log.py
@@ -5,27 +5,24 @@
import pytest
-from ..testing import DatabaseTest
+from ..config import Configuration
from ..log import (
- StringFormatter,
- JSONFormatter,
- LogglyHandler,
+ CannotLoadConfiguration,
CloudWatchLogHandler,
- LogConfiguration,
- SysLogger,
- Loggly,
CloudwatchLogs,
+ JSONFormatter,
+ LogConfiguration,
Logger,
- CannotLoadConfiguration,
-)
-from ..model import (
- ExternalIntegration,
- ConfigurationSetting
+ Loggly,
+ LogglyHandler,
+ StringFormatter,
+ SysLogger,
)
-from ..config import Configuration
+from ..model import ConfigurationSetting, ExternalIntegration
+from ..testing import DatabaseTest
-class TestJSONFormatter(object):
+class TestJSONFormatter(object):
def test_format(self):
formatter = JSONFormatter("some app")
assert "some app" == formatter.app_name
@@ -38,16 +35,22 @@ def test_format(self):
exc_info = sys.exc_info()
record = logging.LogRecord(
- "some logger", logging.DEBUG, "pathname",
- 104, "A message", {}, exc_info, None
+ "some logger",
+ logging.DEBUG,
+ "pathname",
+ 104,
+ "A message",
+ {},
+ exc_info,
+ None,
)
data = json.loads(formatter.format(record))
- assert "some logger" == data['name']
- assert "some app" == data['app']
- assert "DEBUG" == data['level']
- assert "A message" == data['message']
- assert "pathname" == data['filename']
- assert 'ValueError: fake exception' in data['traceback']
+ assert "some logger" == data["name"]
+ assert "some app" == data["app"]
+ assert "DEBUG" == data["level"]
+ assert "A message" == data["message"]
+ assert "pathname" == data["filename"]
+ assert "ValueError: fake exception" in data["traceback"]
def test_format_with_different_types_of_strings(self):
# As long as all data is either Unicode or UTF-8, any combination
@@ -69,16 +72,14 @@ def test_format_with_different_types_of_strings(self):
(byte_message, unicode_snowman),
):
record = logging.LogRecord(
- "some logger", logging.DEBUG, "pathname",
- 104, msg, (args,), None, None
+ "some logger", logging.DEBUG, "pathname", 104, msg, (args,), None, None
)
data = json.loads(formatter.format(record))
# The resulting data is always a Unicode string.
- assert "An important snowman: ☃" == data['message']
+ assert "An important snowman: ☃" == data["message"]
class TestLogConfiguration(DatabaseTest):
-
def test_configuration(self):
"""Loggly.NAME must equal ExternalIntegration.LOGGLY.
Enforcing this with code would create an import loop,
@@ -89,8 +90,7 @@ def test_configuration(self):
def loggly_integration(self):
"""Create an ExternalIntegration for a Loggly account."""
integration = self._external_integration(
- protocol=ExternalIntegration.LOGGLY,
- goal=ExternalIntegration.LOGGING_GOAL
+ protocol=ExternalIntegration.LOGGLY, goal=ExternalIntegration.LOGGING_GOAL
)
integration.url = "http://example.com/%s/"
integration.password = "a_token"
@@ -100,7 +100,7 @@ def cloudwatch_integration(self):
"""Create an ExternalIntegration for a Cloudwatch account."""
integration = self._external_integration(
protocol=ExternalIntegration.CLOUDWATCH,
- goal=ExternalIntegration.LOGGING_GOAL
+ goal=ExternalIntegration.LOGGING_GOAL,
)
integration.set_setting(CloudwatchLogs.CREATE_GROUP, "FALSE")
@@ -136,11 +136,13 @@ def test_from_configuration(self):
self.cloudwatch_integration()
internal = self._external_integration(
protocol=ExternalIntegration.INTERNAL_LOGGING,
- goal=ExternalIntegration.LOGGING_GOAL
+ goal=ExternalIntegration.LOGGING_GOAL,
)
ConfigurationSetting.sitewide(self._db, config.LOG_LEVEL).value = config.ERROR
internal.setting(SysLogger.LOG_FORMAT).value = SysLogger.TEXT_LOG_FORMAT
- ConfigurationSetting.sitewide(self._db, config.DATABASE_LOG_LEVEL).value = config.DEBUG
+ ConfigurationSetting.sitewide(
+ self._db, config.DATABASE_LOG_LEVEL
+ ).value = config.DEBUG
ConfigurationSetting.sitewide(self._db, config.LOG_APP_NAME).value = "test app"
template = "%(filename)s:%(message)s"
internal.setting(SysLogger.LOG_MESSAGE_TEMPLATE).value = template
@@ -153,13 +155,14 @@ def test_from_configuration(self):
assert "http://example.com/a_token/" == loggly_handler.url
assert "test app" == loggly_handler.formatter.app_name
- [cloudwatch_handler] = [x for x in handlers if isinstance(x, CloudWatchLogHandler)]
+ [cloudwatch_handler] = [
+ x for x in handlers if isinstance(x, CloudWatchLogHandler)
+ ]
assert "simplified" == cloudwatch_handler.stream_name
assert "simplified" == cloudwatch_handler.log_group
assert 60 == cloudwatch_handler.send_interval
- [stream_handler] = [x for x in handlers
- if isinstance(x, logging.StreamHandler)]
+ [stream_handler] = [x for x in handlers if isinstance(x, logging.StreamHandler)]
assert isinstance(stream_handler.formatter, StringFormatter)
assert template == stream_handler.formatter._fmt
@@ -178,25 +181,27 @@ def test_syslog_defaults(self):
# Normally log messages are emitted in JSON format.
assert (
- (SysLogger.JSON_LOG_FORMAT, SysLogger.DEFAULT_MESSAGE_TEMPLATE) ==
- cls._defaults(testing=False))
+ SysLogger.JSON_LOG_FORMAT,
+ SysLogger.DEFAULT_MESSAGE_TEMPLATE,
+ ) == cls._defaults(testing=False)
# When we're running unit tests, log messages are emitted in text format.
assert (
- (SysLogger.TEXT_LOG_FORMAT, SysLogger.DEFAULT_MESSAGE_TEMPLATE) ==
- cls._defaults(testing=True))
+ SysLogger.TEXT_LOG_FORMAT,
+ SysLogger.DEFAULT_MESSAGE_TEMPLATE,
+ ) == cls._defaults(testing=True)
def test_set_formatter(self):
# Create a generic handler.
handler = logging.StreamHandler()
# Configure it for text output.
- template = '%(filename)s:%(message)s'
+ template = "%(filename)s:%(message)s"
SysLogger.set_formatter(
handler,
log_format=SysLogger.TEXT_LOG_FORMAT,
message_template=template,
- app_name="some app"
+ app_name="some app",
)
formatter = handler.formatter
assert isinstance(formatter, StringFormatter)
@@ -214,7 +219,7 @@ def test_set_formatter(self):
# In this case the template is irrelevant. The JSONFormatter
# uses the default format template, but it doesn't matter,
# because JSONFormatter overrides the format() method.
- assert '%(message)s' == formatter._fmt
+ assert "%(message)s" == formatter._fmt
# Configure a handler for output to Loggly. In this case
# the format and template are irrelevant.
@@ -236,8 +241,7 @@ def test_loggly_handler(self):
# be used.
integration.url = None
handler = Loggly.loggly_handler(integration)
- assert (Loggly.DEFAULT_LOGGLY_URL % dict(token="a_token") ==
- handler.url)
+ assert Loggly.DEFAULT_LOGGLY_URL % dict(token="a_token") == handler.url
def test_cloudwatch_handler(self):
"""Turn an appropriate ExternalIntegration into a CloudWatchLogHandler."""
@@ -246,7 +250,7 @@ def test_cloudwatch_handler(self):
integration.set_setting(CloudwatchLogs.GROUP, "test_group")
integration.set_setting(CloudwatchLogs.STREAM, "test_stream")
integration.set_setting(CloudwatchLogs.INTERVAL, 120)
- integration.set_setting(CloudwatchLogs.REGION, 'us-east-2')
+ integration.set_setting(CloudwatchLogs.REGION, "us-east-2")
handler = CloudwatchLogs.get_handler(integration, testing=True)
assert isinstance(handler, CloudWatchLogHandler)
assert "test_stream" == handler.stream_name
@@ -254,9 +258,13 @@ def test_cloudwatch_handler(self):
assert 120 == handler.send_interval
integration.setting(CloudwatchLogs.INTERVAL).value = -10
- pytest.raises(CannotLoadConfiguration, CloudwatchLogs.get_handler, integration, True)
+ pytest.raises(
+ CannotLoadConfiguration, CloudwatchLogs.get_handler, integration, True
+ )
integration.setting(CloudwatchLogs.INTERVAL).value = "a string"
- pytest.raises(CannotLoadConfiguration, CloudwatchLogs.get_handler, integration, True)
+ pytest.raises(
+ CannotLoadConfiguration, CloudwatchLogs.get_handler, integration, True
+ )
def test_interpolate_loggly_url(self):
m = Loggly._interpolate_loggly_url
@@ -268,8 +276,7 @@ def test_interpolate_loggly_url(self):
# If the URL contains no string interpolation, we assume the token's
# already in there.
- assert ("http://foo/othertoken/bar/" ==
- m("http://foo/othertoken/bar/", "token"))
+ assert "http://foo/othertoken/bar/" == m("http://foo/othertoken/bar/", "token")
# Anything that doesn't fall under one of these cases will raise an
# exception.
@@ -281,8 +288,14 @@ def test_cloudwatch_initialization_exception(self):
integration = self.cloudwatch_integration()
integration.set_setting(CloudwatchLogs.CREATE_GROUP, "TRUE")
- internal_log_level, database_log_level, [handler], [error] = LogConfiguration.from_configuration(
- self._db, testing=False
- )
+ (
+ internal_log_level,
+ database_log_level,
+ [handler],
+ [error],
+ ) = LogConfiguration.from_configuration(self._db, testing=False)
assert isinstance(handler, logging.StreamHandler)
- assert 'Error creating logger AWS Cloudwatch Logs Unable to locate credentials' == error
+ assert (
+ "Error creating logger AWS Cloudwatch Logs Unable to locate credentials"
+ == error
+ )
diff --git a/tests/test_marc.py b/tests/test_marc.py
index 022e7f6b2..acde5dba7 100644
--- a/tests/test_marc.py
+++ b/tests/test_marc.py
@@ -1,11 +1,15 @@
import datetime
-import pytest
-from pymarc import Record, MARCReader
from io import StringIO
from urllib.parse import quote
+
+import pytest
+from pymarc import MARCReader, Record
from sqlalchemy.orm.session import Session
-from ..testing import DatabaseTest
+from ..config import CannotLoadConfiguration
+from ..external_search import Filter, MockExternalSearchIndex
+from ..lane import WorkList
+from ..marc import Annotator, MARCExporter, MARCExporterFacets
from ..model import (
CachedMARCFile,
Contributor,
@@ -21,32 +25,21 @@
Work,
get_one,
)
-from ..config import CannotLoadConfiguration
-from ..external_search import (
- MockExternalSearchIndex,
- Filter,
-)
-from ..marc import (
- Annotator,
- MARCExporter,
- MARCExporterFacets,
-)
-from ..s3 import (
- MockS3Uploader,
- S3Uploader,
-)
-from ..lane import WorkList
+from ..s3 import MockS3Uploader, S3Uploader
+from ..testing import DatabaseTest
from ..util.datetime_helpers import datetime_utc, utc_now
-class TestAnnotator(DatabaseTest):
+class TestAnnotator(DatabaseTest):
def test_annotate_work_record(self):
# Verify that annotate_work_record adds the distributor and formats.
class MockAnnotator(Annotator):
add_distributor_called_with = None
add_formats_called_with = None
+
def add_distributor(self, record, pool):
self.add_distributor_called_with = [record, pool]
+
def add_formats(self, record, pool):
self.add_formats_called_with = [record, pool]
@@ -96,8 +89,8 @@ def test_add_control_fields(self):
self._check_control_field(record, "006", "m d ")
self._check_control_field(record, "007", "cr cn ---anuuu")
self._check_control_field(
- record, "008",
- now.strftime("%y%m%d") + "s0956 xxu eng ")
+ record, "008", now.strftime("%y%m%d") + "s0956 xxu eng "
+ )
# This French edition has two formats and was published in 2018.
edition2, pool2 = self._edition(with_license_pool=True)
@@ -105,8 +98,12 @@ def test_add_control_fields(self):
edition2.issued = datetime_utc(2018, 2, 3)
edition2.language = "fre"
LicensePoolDeliveryMechanism.set(
- pool2.data_source, identifier2, Representation.PDF_MEDIA_TYPE,
- DeliveryMechanism.ADOBE_DRM, RightsStatus.IN_COPYRIGHT)
+ pool2.data_source,
+ identifier2,
+ Representation.PDF_MEDIA_TYPE,
+ DeliveryMechanism.ADOBE_DRM,
+ RightsStatus.IN_COPYRIGHT,
+ )
record = Record()
Annotator.add_control_fields(record, identifier2, pool2, edition2)
@@ -115,8 +112,8 @@ def test_add_control_fields(self):
self._check_control_field(record, "006", "m d ")
self._check_control_field(record, "007", "cr cn ---mnuuu")
self._check_control_field(
- record, "008",
- now.strftime("%y%m%d") + "s2018 xxu fre ")
+ record, "008", now.strftime("%y%m%d") + "s2018 xxu fre "
+ )
def test_add_marc_organization_code(self):
record = Record()
@@ -154,11 +151,15 @@ def test_add_title(self):
Annotator.add_title(record, edition)
[field] = record.get_fields("245")
self._check_field(
- record, "245", {
+ record,
+ "245",
+ {
"a": edition.title,
"b": edition.subtitle,
"c": edition.author,
- }, ["0", "4"])
+ },
+ ["0", "4"],
+ )
# If there's no subtitle or no author, those subfields are left out.
edition.subtitle = None
@@ -168,9 +169,13 @@ def test_add_title(self):
Annotator.add_title(record, edition)
[field] = record.get_fields("245")
self._check_field(
- record, "245", {
+ record,
+ "245",
+ {
"a": edition.title,
- }, ["0", "4"])
+ },
+ ["0", "4"],
+ )
assert [] == field.get_subfields("b")
assert [] == field.get_subfields("c")
@@ -198,7 +203,9 @@ def test_add_contributors(self):
fields = record.get_fields("700")
for field in fields:
assert ["1", " "] == field.indicators
- [author_field, author2_field, translator_field] = sorted(fields, key=lambda x: x.get_subfields("a")[0])
+ [author_field, author2_field, translator_field] = sorted(
+ fields, key=lambda x: x.get_subfields("a")[0]
+ )
assert author == author_field.get_subfields("a")[0]
assert Contributor.PRIMARY_AUTHOR_ROLE == author_field.get_subfields("e")[0]
assert author2 == author2_field.get_subfields("a")[0]
@@ -214,11 +221,15 @@ def test_add_publisher(self):
record = Record()
Annotator.add_publisher(record, edition)
self._check_field(
- record, "264", {
+ record,
+ "264",
+ {
"a": "[Place of publication not identified]",
"b": edition.publisher,
"c": "1894",
- }, [" ", "1"])
+ },
+ [" ", "1"],
+ )
# If there's no publisher, the field is left out.
record = Record()
@@ -241,55 +252,95 @@ def test_add_physical_description(self):
record = Record()
Annotator.add_physical_description(record, book)
self._check_field(record, "300", {"a": "1 online resource"})
- self._check_field(record, "336", {
- "a": "text",
- "b": "txt",
- "2": "rdacontent",
- })
- self._check_field(record, "337", {
- "a": "computer",
- "b": "c",
- "2": "rdamedia",
- })
- self._check_field(record, "338", {
- "a": "online resource",
- "b": "cr",
- "2": "rdacarrier",
- })
- self._check_field(record, "347", {
- "a": "text file",
- "2": "rda",
- })
- self._check_field(record, "380", {
- "a": "eBook",
- "2": "tlcgt",
- })
+ self._check_field(
+ record,
+ "336",
+ {
+ "a": "text",
+ "b": "txt",
+ "2": "rdacontent",
+ },
+ )
+ self._check_field(
+ record,
+ "337",
+ {
+ "a": "computer",
+ "b": "c",
+ "2": "rdamedia",
+ },
+ )
+ self._check_field(
+ record,
+ "338",
+ {
+ "a": "online resource",
+ "b": "cr",
+ "2": "rdacarrier",
+ },
+ )
+ self._check_field(
+ record,
+ "347",
+ {
+ "a": "text file",
+ "2": "rda",
+ },
+ )
+ self._check_field(
+ record,
+ "380",
+ {
+ "a": "eBook",
+ "2": "tlcgt",
+ },
+ )
record = Record()
Annotator.add_physical_description(record, audio)
- self._check_field(record, "300", {
- "a": "1 sound file",
- "b": "digital",
- })
- self._check_field(record, "336", {
- "a": "spoken word",
- "b": "spw",
- "2": "rdacontent",
- })
- self._check_field(record, "337", {
- "a": "computer",
- "b": "c",
- "2": "rdamedia",
- })
- self._check_field(record, "338", {
- "a": "online resource",
- "b": "cr",
- "2": "rdacarrier",
- })
- self._check_field(record, "347", {
- "a": "audio file",
- "2": "rda",
- })
+ self._check_field(
+ record,
+ "300",
+ {
+ "a": "1 sound file",
+ "b": "digital",
+ },
+ )
+ self._check_field(
+ record,
+ "336",
+ {
+ "a": "spoken word",
+ "b": "spw",
+ "2": "rdacontent",
+ },
+ )
+ self._check_field(
+ record,
+ "337",
+ {
+ "a": "computer",
+ "b": "c",
+ "2": "rdamedia",
+ },
+ )
+ self._check_field(
+ record,
+ "338",
+ {
+ "a": "online resource",
+ "b": "cr",
+ "2": "rdacarrier",
+ },
+ )
+ self._check_field(
+ record,
+ "347",
+ {
+ "a": "audio file",
+ "2": "rda",
+ },
+ )
assert [] == record.get_fields("380")
def test_add_audience(self):
@@ -297,10 +348,14 @@ def test_add_audience(self):
work = self._work(audience=audience)
record = Record()
Annotator.add_audience(record, work)
- self._check_field(record, "385", {
- "a": term,
- "2": "tlctarget",
- })
+ self._check_field(
+ record,
+ "385",
+ {
+ "a": term,
+ "2": "tlctarget",
+ },
+ )
def test_add_series(self):
edition = self._edition()
@@ -308,19 +363,29 @@ def test_add_series(self):
edition.series_position = 5
record = Record()
Annotator.add_series(record, edition)
- self._check_field(record, "490", {
- "a": edition.series,
- "v": str(edition.series_position),
- }, ["0", " "])
+ self._check_field(
+ record,
+ "490",
+ {
+ "a": edition.series,
+ "v": str(edition.series_position),
+ },
+ ["0", " "],
+ )
# If there's no series position, the same field is used without
# the v subfield.
edition.series_position = None
record = Record()
Annotator.add_series(record, edition)
- self._check_field(record, "490", {
- "a": edition.series,
- }, ["0", " "])
+ self._check_field(
+ record,
+ "490",
+ {
+ "a": edition.series,
+ },
+ ["0", " "],
+ )
[field] = record.get_fields("490")
assert [] == field.get_subfields("v")
@@ -338,11 +403,16 @@ def test_add_system_details(self):
def test_add_formats(self):
edition, pool = self._edition(with_license_pool=True)
epub_no_drm, ignore = DeliveryMechanism.lookup(
- self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM)
+ self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM
+ )
pool.delivery_mechanisms[0].delivery_mechanism = epub_no_drm
LicensePoolDeliveryMechanism.set(
- pool.data_source, pool.identifier, Representation.PDF_MEDIA_TYPE,
- DeliveryMechanism.ADOBE_DRM, RightsStatus.IN_COPYRIGHT)
+ pool.data_source,
+ pool.identifier,
+ Representation.PDF_MEDIA_TYPE,
+ DeliveryMechanism.ADOBE_DRM,
+ RightsStatus.IN_COPYRIGHT,
+ )
record = Record()
Annotator.add_formats(record, pool)
@@ -371,7 +441,9 @@ def test_add_simplified_genres(self):
record = Record()
Annotator.add_simplified_genres(record, work)
fields = record.get_fields("650")
- [fantasy_field, romance_field] = sorted(fields, key=lambda x: x.get_subfields("a")[0])
+ [fantasy_field, romance_field] = sorted(
+ fields, key=lambda x: x.get_subfields("a")[0]
+ )
assert ["0", "7"] == fantasy_field.indicators
assert "Fantasy" == fantasy_field.get_subfields("a")[0]
assert "Library Simplified" == fantasy_field.get_subfields("2")[0]
@@ -384,16 +456,19 @@ def test_add_ebooks_subject(self):
Annotator.add_ebooks_subject(record)
self._check_field(record, "655", {"a": "Electronic books."}, [" ", "0"])
-class TestMARCExporter(DatabaseTest):
+class TestMARCExporter(DatabaseTest):
def _integration(self):
return self._external_integration(
ExternalIntegration.MARC_EXPORT,
ExternalIntegration.CATALOG_GOAL,
- libraries=[self._default_library])
+ libraries=[self._default_library],
+ )
def test_from_config(self):
- pytest.raises(CannotLoadConfiguration, MARCExporter.from_config, self._default_library)
+ pytest.raises(
+ CannotLoadConfiguration, MARCExporter.from_config, self._default_library
+ )
integration = self._integration()
exporter = MARCExporter.from_config(self._default_library)
@@ -404,8 +479,12 @@ def test_from_config(self):
pytest.raises(CannotLoadConfiguration, MARCExporter.from_config, other_library)
def test_create_record(self):
- work = self._work(with_license_pool=True, title="old title",
- authors=["old author"], data_source_name=DataSource.OVERDRIVE)
+ work = self._work(
+ with_license_pool=True,
+ title="old title",
+ authors=["old author"],
+ data_source_name=DataSource.OVERDRIVE,
+ )
annotator = Annotator()
# The record isn't cached yet, so a new record is created and cached.
@@ -455,9 +534,13 @@ def test_create_record(self):
# If we pass in an integration, it's passed along to the annotator.
integration = self._integration()
+
class MockAnnotator(Annotator):
integration = None
- def annotate_work_record(self, work, pool, edition, identifier, record, integration):
+
+ def annotate_work_record(
+ self, work, pool, edition, identifier, record, integration
+ ):
self.integration = integration
annotator = MockAnnotator()
@@ -474,9 +557,9 @@ def test_create_record_roundtrip(self):
# Creates a new record and saves it to the database
work = self._work(
- title="Little Mimi\u2019s First Counting Lesson",
- authors=["Lagerlo\xf6f, Selma Ottiliana Lovisa,"],
- with_license_pool=True
+ title="Little Mimi\u2019s First Counting Lesson",
+ authors=["Lagerlo\xf6f, Selma Ottiliana Lovisa,"],
+ with_license_pool=True,
)
record = MARCExporter.create_record(work, annotator)
loaded_record = MARCExporter.create_record(work, annotator)
@@ -506,13 +589,23 @@ def test_records(self):
# If there is a storage integration, the output file is mirrored.
mirror_integration = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL,
- username="username", password="password",
+ ExternalIntegration.S3,
+ ExternalIntegration.STORAGE_GOAL,
+ username="username",
+ password="password",
)
mirror = MockS3Uploader()
- exporter.records(lane, annotator, mirror_integration, mirror=mirror, query_batch_size=1, upload_batch_size=1, search_engine=search_engine)
+ exporter.records(
+ lane,
+ annotator,
+ mirror_integration,
+ mirror=mirror,
+ query_batch_size=1,
+ upload_batch_size=1,
+ search_engine=search_engine,
+ )
# The file was mirrored and a CachedMARCFile was created to track the mirrored file.
assert 1 == len(mirror.uploaded)
@@ -521,11 +614,15 @@ def test_records(self):
assert lane == cache.lane
assert mirror.uploaded[0] == cache.representation
assert None == cache.representation.content
- assert ("https://test-marc-bucket.s3.amazonaws.com/%s/%s/%s.mrc" % (
- self._default_library.short_name,
- quote(str(cache.representation.fetched_at)),
- quote(lane.display_name)) ==
- mirror.uploaded[0].mirror_url)
+ assert (
+ "https://test-marc-bucket.s3.amazonaws.com/%s/%s/%s.mrc"
+ % (
+ self._default_library.short_name,
+ quote(str(cache.representation.fetched_at)),
+ quote(lane.display_name),
+ )
+ == mirror.uploaded[0].mirror_url
+ )
assert None == cache.start_time
assert cache.end_time > now
@@ -550,7 +647,15 @@ def test_records(self):
worklist.initialize(self._default_library, display_name="All Books")
mirror = MockS3Uploader()
- exporter.records(worklist, annotator, mirror_integration, mirror=mirror, query_batch_size=1, upload_batch_size=1, search_engine=search_engine)
+ exporter.records(
+ worklist,
+ annotator,
+ mirror_integration,
+ mirror=mirror,
+ query_batch_size=1,
+ upload_batch_size=1,
+ search_engine=search_engine,
+ )
assert 1 == len(mirror.uploaded)
[cache] = self._db.query(CachedMARCFile).all()
@@ -558,11 +663,15 @@ def test_records(self):
assert None == cache.lane
assert mirror.uploaded[0] == cache.representation
assert None == cache.representation.content
- assert ("https://test-marc-bucket.s3.amazonaws.com/%s/%s/%s.mrc" % (
- self._default_library.short_name,
- quote(str(cache.representation.fetched_at)),
- quote(worklist.display_name)) ==
- mirror.uploaded[0].mirror_url)
+ assert (
+ "https://test-marc-bucket.s3.amazonaws.com/%s/%s/%s.mrc"
+ % (
+ self._default_library.short_name,
+ quote(str(cache.representation.fetched_at)),
+ quote(worklist.display_name),
+ )
+ == mirror.uploaded[0].mirror_url
+ )
assert None == cache.start_time
assert cache.end_time > now
@@ -583,9 +692,14 @@ def test_records(self):
mirror = MockS3Uploader()
exporter.records(
- lane, annotator, mirror_integration, start_time=start_time,
- mirror=mirror, query_batch_size=2,
- upload_batch_size=2, search_engine=search_engine
+ lane,
+ annotator,
+ mirror_integration,
+ start_time=start_time,
+ mirror=mirror,
+ query_batch_size=2,
+ upload_batch_size=2,
+ search_engine=search_engine,
)
[cache] = self._db.query(CachedMARCFile).all()
@@ -593,11 +707,16 @@ def test_records(self):
assert lane == cache.lane
assert mirror.uploaded[0] == cache.representation
assert None == cache.representation.content
- assert ("https://test-marc-bucket.s3.amazonaws.com/%s/%s-%s/%s.mrc" % (
- self._default_library.short_name, quote(str(start_time)),
- quote(str(cache.representation.fetched_at)),
- quote(lane.display_name)) ==
- mirror.uploaded[0].mirror_url)
+ assert (
+ "https://test-marc-bucket.s3.amazonaws.com/%s/%s-%s/%s.mrc"
+ % (
+ self._default_library.short_name,
+ quote(str(start_time)),
+ quote(str(cache.representation.fetched_at)),
+ quote(lane.display_name),
+ )
+ == mirror.uploaded[0].mirror_url
+ )
assert start_time == cache.start_time
assert cache.end_time > now
self._db.delete(cache)
@@ -608,8 +727,13 @@ def test_records(self):
empty_search_engine = MockExternalSearchIndex()
mirror = MockS3Uploader()
- exporter.records(lane, annotator, mirror_integration,
- mirror=mirror, search_engine=empty_search_engine)
+ exporter.records(
+ lane,
+ annotator,
+ mirror_integration,
+ mirror=mirror,
+ search_engine=empty_search_engine,
+ )
assert [] == mirror.content[0]
[cache] = self._db.query(CachedMARCFile).all()
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 692eba551..f20c8e727 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -2,21 +2,16 @@
import datetime
import os
from copy import deepcopy
+
import pytest
from parameterized import parameterized
-from ..testing import (
- DatabaseTest,
- DummyHTTPClient,
- DummyMetadataClient,
-)
from ..analytics import Analytics
-from ..classifier import Classifier
-from ..classifier import NO_VALUE, NO_NUMBER
+from ..classifier import NO_NUMBER, NO_VALUE, Classifier
from ..metadata_layer import (
- CSVMetadataImporter,
CirculationData,
ContributorData,
+ CSVMetadataImporter,
IdentifierData,
LinkData,
MARCExtractor,
@@ -31,9 +26,9 @@
CoverageRecord,
DataSource,
Edition,
+ Hyperlink,
Identifier,
Measurement,
- Hyperlink,
Representation,
RightsStatus,
Subject,
@@ -43,17 +38,12 @@
)
from ..model.configuration import ExternalIntegrationLink
from ..s3 import MockS3Uploader
+from ..testing import DatabaseTest, DummyHTTPClient, DummyMetadataClient
+from ..util.datetime_helpers import datetime_utc, strptime_utc, to_utc, utc_now
from ..util.http import RemoteIntegrationException
-from ..util.datetime_helpers import (
- datetime_utc,
- strptime_utc,
- to_utc,
- utc_now,
-)
class TestIdentifierData(object):
-
def test_constructor(self):
data = IdentifierData(Identifier.ISBN, "foo", 0.5)
assert Identifier.ISBN == data.type
@@ -64,8 +54,7 @@ def test_constructor(self):
class TestMetadataImporter(DatabaseTest):
def test_parse(self):
base_path = os.path.split(__file__)[0]
- path = os.path.join(
- base_path, "files/csv/staff_picks_small.csv")
+ path = os.path.join(base_path, "files/csv/staff_picks_small.csv")
reader = csv.DictReader(open(path))
importer = CSVMetadataImporter(
DataSource.LIBRARY_STAFF,
@@ -82,7 +71,7 @@ def test_parse(self):
# The first book has an Overdrive ID
[overdrive] = m1.identifiers
assert Identifier.OVERDRIVE_ID == overdrive.type
- assert '504BA8F6-FF4E-4B57-896E-F1A50CFFCA0C' == overdrive.identifier
+ assert "504BA8F6-FF4E-4B57-896E-F1A50CFFCA0C" == overdrive.identifier
assert 0.75 == overdrive.weight
# The second book has no ID at all.
@@ -92,24 +81,24 @@ def test_parse(self):
overdrive, threem = sorted(m3.identifiers, key=lambda x: x.identifier)
assert Identifier.OVERDRIVE_ID == overdrive.type
- assert 'eae60d41-e0b8-4f9d-90b5-cbc43d433c2f' == overdrive.identifier
+ assert "eae60d41-e0b8-4f9d-90b5-cbc43d433c2f" == overdrive.identifier
assert 0.75 == overdrive.weight
assert Identifier.THREEM_ID == threem.type
- assert 'eswhyz9' == threem.identifier
+ assert "eswhyz9" == threem.identifier
assert 0.75 == threem.weight
# Now let's check out subjects.
- assert (
- [
- ('schema:typicalAgeRange', 'Adult', 100),
- ('tag', 'Character Driven', 100),
- ('tag', 'Historical', 100),
- ('tag', 'Nail-Biters', 100),
- ('tag', 'Setting Driven', 100)
- ] ==
- [(x.type, x.identifier, x.weight)
- for x in sorted(m2.subjects, key=lambda x: x.identifier)])
+ assert [
+ ("schema:typicalAgeRange", "Adult", 100),
+ ("tag", "Character Driven", 100),
+ ("tag", "Historical", 100),
+ ("tag", "Nail-Biters", 100),
+ ("tag", "Setting Driven", 100),
+ ] == [
+ (x.type, x.identifier, x.weight)
+ for x in sorted(m2.subjects, key=lambda x: x.identifier)
+ ]
def test_classifications_from_another_source_not_updated(self):
@@ -130,19 +119,18 @@ def test_classifications_from_another_source_not_updated(self):
# The old classification from source #2 has been destroyed.
# The old classification from source #1 is still there.
- assert (
- ['i will conquer', 'i will persist'] ==
- sorted([x.subject.identifier for x in identifier.classifications]))
+ assert ["i will conquer", "i will persist"] == sorted(
+ [x.subject.identifier for x in identifier.classifications]
+ )
def test_links(self):
edition = self._edition()
l1 = LinkData(rel=Hyperlink.IMAGE, href="http://example.com/")
l2 = LinkData(rel=Hyperlink.DESCRIPTION, content="foo")
- metadata = Metadata(links=[l1, l2],
- data_source=edition.data_source)
+ metadata = Metadata(links=[l1, l2], data_source=edition.data_source)
metadata.apply(edition, None)
[image, description] = sorted(
- edition.primary_identifier.links, key=lambda x:x.rel
+ edition.primary_identifier.links, key=lambda x: x.rel
)
assert Hyperlink.IMAGE == image.rel
assert "http://example.com/" == image.resource.url
@@ -153,22 +141,24 @@ def test_links(self):
def test_image_with_original_and_rights(self):
edition = self._edition()
data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
- original = LinkData(rel=Hyperlink.IMAGE,
- href="http://example.com/",
- media_type=Representation.PNG_MEDIA_TYPE,
- rights_uri=RightsStatus.PUBLIC_DOMAIN_USA,
- rights_explanation="This image is from 1922",
- )
- image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x03\x00\x00\x00%\xdbV\xca\x00\x00\x00\x06PLTE\xffM\x00\x01\x01\x01\x8e\x1e\xe5\x1b\x00\x00\x00\x01tRNS\xcc\xd24V\xfd\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82'
- derivative = LinkData(rel=Hyperlink.IMAGE,
- href="generic uri",
- content=image_data,
- media_type=Representation.PNG_MEDIA_TYPE,
- rights_uri=RightsStatus.PUBLIC_DOMAIN_USA,
- rights_explanation="This image is from 1922",
- original=original,
- transformation_settings=dict(position='top')
- )
+ original = LinkData(
+ rel=Hyperlink.IMAGE,
+ href="http://example.com/",
+ media_type=Representation.PNG_MEDIA_TYPE,
+ rights_uri=RightsStatus.PUBLIC_DOMAIN_USA,
+ rights_explanation="This image is from 1922",
+ )
+ image_data = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x01\x03\x00\x00\x00%\xdbV\xca\x00\x00\x00\x06PLTE\xffM\x00\x01\x01\x01\x8e\x1e\xe5\x1b\x00\x00\x00\x01tRNS\xcc\xd24V\xfd\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82"
+ derivative = LinkData(
+ rel=Hyperlink.IMAGE,
+ href="generic uri",
+ content=image_data,
+ media_type=Representation.PNG_MEDIA_TYPE,
+ rights_uri=RightsStatus.PUBLIC_DOMAIN_USA,
+ rights_explanation="This image is from 1922",
+ original=original,
+ transformation_settings=dict(position="top"),
+ )
metadata = Metadata(links=[derivative], data_source=data_source)
metadata.apply(edition, None)
@@ -184,48 +174,54 @@ def test_image_with_original_and_rights(self):
assert image.resource == transformation.derivative
assert "http://example.com/" == transformation.original.url
- assert RightsStatus.PUBLIC_DOMAIN_USA == transformation.original.rights_status.uri
+ assert (
+ RightsStatus.PUBLIC_DOMAIN_USA == transformation.original.rights_status.uri
+ )
assert "This image is from 1922" == transformation.original.rights_explanation
assert "top" == transformation.settings.get("position")
def test_image_and_thumbnail(self):
edition = self._edition()
l2 = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href="http://thumbnail.com/",
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href="http://thumbnail.com/",
media_type=Representation.JPEG_MEDIA_TYPE,
)
l1 = LinkData(
- rel=Hyperlink.IMAGE, href="http://example.com/", thumbnail=l2,
+ rel=Hyperlink.IMAGE,
+ href="http://example.com/",
+ thumbnail=l2,
media_type=Representation.JPEG_MEDIA_TYPE,
)
# Even though we're only passing in the primary image link...
- metadata = Metadata(links=[l1],
- data_source=edition.data_source)
+ metadata = Metadata(links=[l1], data_source=edition.data_source)
metadata.apply(edition, None)
# ...a Hyperlink is also created for the thumbnail.
[image, thumbnail] = sorted(
- edition.primary_identifier.links, key=lambda x:x.rel
+ edition.primary_identifier.links, key=lambda x: x.rel
)
assert Hyperlink.IMAGE == image.rel
- assert ([thumbnail.resource.representation] ==
- image.resource.representation.thumbnails)
+ assert [
+ thumbnail.resource.representation
+ ] == image.resource.representation.thumbnails
def test_thumbnail_isnt_a_thumbnail(self):
edition = self._edition()
not_a_thumbnail = LinkData(
- rel=Hyperlink.DESCRIPTION, content="A great book",
+ rel=Hyperlink.DESCRIPTION,
+ content="A great book",
media_type=Representation.TEXT_PLAIN,
)
image = LinkData(
- rel=Hyperlink.IMAGE, href="http://example.com/",
+ rel=Hyperlink.IMAGE,
+ href="http://example.com/",
thumbnail=not_a_thumbnail,
media_type=Representation.JPEG_MEDIA_TYPE,
)
- metadata = Metadata(links=[image],
- data_source=edition.data_source)
+ metadata = Metadata(links=[image], data_source=edition.data_source)
metadata.apply(edition, None)
# Only one Hyperlink was created for the image, because
@@ -236,11 +232,12 @@ def test_thumbnail_isnt_a_thumbnail(self):
# If we pass in the 'thumbnail' separately, a Hyperlink is
# created for it, but it's still not a thumbnail of anything.
- metadata = Metadata(links=[image, not_a_thumbnail],
- data_source=edition.data_source)
+ metadata = Metadata(
+ links=[image, not_a_thumbnail], data_source=edition.data_source
+ )
metadata.apply(edition, None)
[image, description] = sorted(
- edition.primary_identifier.links, key=lambda x:x.rel
+ edition.primary_identifier.links, key=lambda x: x.rel
)
assert Hyperlink.DESCRIPTION == description.rel
assert b"A great book" == description.resource.representation.content
@@ -251,16 +248,18 @@ def test_image_and_thumbnail_are_the_same(self):
edition = self._edition()
url = "http://tinyimage.com/image.jpg"
l2 = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href=url,
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href=url,
)
l1 = LinkData(
- rel=Hyperlink.IMAGE, href=url, thumbnail=l2,
+ rel=Hyperlink.IMAGE,
+ href=url,
+ thumbnail=l2,
)
- metadata = Metadata(links=[l1, l2],
- data_source=edition.data_source)
+ metadata = Metadata(links=[l1, l2], data_source=edition.data_source)
metadata.apply(edition, None)
[image, thumbnail] = sorted(
- edition.primary_identifier.links, key=lambda x:x.rel
+ edition.primary_identifier.links, key=lambda x: x.rel
)
# The image and its thumbnail point to the same resource.
@@ -270,8 +269,9 @@ def test_image_and_thumbnail_are_the_same(self):
assert Hyperlink.THUMBNAIL_IMAGE == thumbnail.rel
# The thumbnail is marked as a thumbnail of the main image.
- assert ([thumbnail.resource.representation] ==
- image.resource.representation.thumbnails)
+ assert [
+ thumbnail.resource.representation
+ ] == image.resource.representation.thumbnails
assert url == edition.cover_full_url
assert url == edition.cover_thumbnail_url
@@ -281,21 +281,23 @@ def test_image_becomes_representation_but_thumbnail_does_not(self):
# The thumbnail link has no media type, and none can be
# derived from the URL.
l2 = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href="http://tinyimage.com/",
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href="http://tinyimage.com/",
)
# The full-sized image link does not have this problem.
l1 = LinkData(
- rel=Hyperlink.IMAGE, href="http://largeimage.com/", thumbnail=l2,
+ rel=Hyperlink.IMAGE,
+ href="http://largeimage.com/",
+ thumbnail=l2,
media_type=Representation.JPEG_MEDIA_TYPE,
)
- metadata = Metadata(links=[l1],
- data_source=edition.data_source)
+ metadata = Metadata(links=[l1], data_source=edition.data_source)
metadata.apply(edition, None)
# Both LinkData objects have been imported as Hyperlinks.
[image, thumbnail] = sorted(
- edition.primary_identifier.links, key=lambda x:x.rel
+ edition.primary_identifier.links, key=lambda x: x.rel
)
# However, since no Representation was created for the thumbnail,
@@ -326,19 +328,23 @@ def test_image_scale_and_mirror(self):
# However, updated tests passing does not guarantee that all code now
# correctly calls on CirculationData, too. This is a risk.
- mirrors = dict(covers_mirror=MockS3Uploader(),books_mirror=None)
+ mirrors = dict(covers_mirror=MockS3Uploader(), books_mirror=None)
edition, pool = self._edition(with_license_pool=True)
content = open(self.sample_cover_path("test-book-cover.png"), "rb").read()
l1 = LinkData(
- rel=Hyperlink.IMAGE, href="http://example.com/",
+ rel=Hyperlink.IMAGE,
+ href="http://example.com/",
media_type=Representation.JPEG_MEDIA_TYPE,
- content=content
+ content=content,
)
- thumbnail_content = open(self.sample_cover_path("tiny-image-cover.png"), "rb").read()
+ thumbnail_content = open(
+ self.sample_cover_path("tiny-image-cover.png"), "rb"
+ ).read()
l2 = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href="http://example.com/thumb.jpg",
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href="http://example.com/thumb.jpg",
media_type=Representation.JPEG_MEDIA_TYPE,
- content=content
+ content=content,
)
# When we call metadata.apply, all image links will be scaled and
@@ -372,23 +378,30 @@ def test_image_scale_and_mirror(self):
assert thumbnail.content != l2.content
# Both images have been 'mirrored' to Amazon S3.
- assert image.mirror_url.startswith('https://test-cover-bucket.s3.amazonaws.com/')
- assert image.mirror_url.endswith('cover.jpg')
+ assert image.mirror_url.startswith(
+ "https://test-cover-bucket.s3.amazonaws.com/"
+ )
+ assert image.mirror_url.endswith("cover.jpg")
# The thumbnail image has been converted to PNG.
- assert thumbnail.mirror_url.startswith('https://test-cover-bucket.s3.amazonaws.com/scaled/300/')
- assert thumbnail.mirror_url.endswith('cover.png')
+ assert thumbnail.mirror_url.startswith(
+ "https://test-cover-bucket.s3.amazonaws.com/scaled/300/"
+ )
+ assert thumbnail.mirror_url.endswith("cover.png")
def test_mirror_thumbnail_only(self):
# Make sure a thumbnail image is mirrored when there's no cover image.
mirrors = dict(covers_mirror=MockS3Uploader())
mirror_type = ExternalIntegrationLink.COVERS
edition, pool = self._edition(with_license_pool=True)
- thumbnail_content = open(self.sample_cover_path("tiny-image-cover.png"), "rb").read()
+ thumbnail_content = open(
+ self.sample_cover_path("tiny-image-cover.png"), "rb"
+ ).read()
l = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href="http://example.com/thumb.png",
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href="http://example.com/thumb.png",
media_type=Representation.PNG_MEDIA_TYPE,
- content=thumbnail_content
+ content=thumbnail_content,
)
policy = ReplacementPolicy(mirrors=mirrors)
@@ -399,8 +412,10 @@ def test_mirror_thumbnail_only(self):
[thumbnail] = mirrors[mirror_type].uploaded
# The image has been 'mirrored' to Amazon S3.
- assert thumbnail.mirror_url.startswith('https://test-cover-bucket.s3.amazonaws.com/')
- assert thumbnail.mirror_url.endswith('thumb.png')
+ assert thumbnail.mirror_url.startswith(
+ "https://test-cover-bucket.s3.amazonaws.com/"
+ )
+ assert thumbnail.mirror_url.endswith("thumb.png")
def test_mirror_open_access_link_fetch_failure(self):
edition, pool = self._edition(with_license_pool=True)
@@ -420,8 +435,11 @@ def test_mirror_open_access_link_fetch_failure(self):
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
h.queue_response(403)
@@ -446,7 +464,7 @@ def test_mirror_open_access_link_fetch_failure(self):
assert None == pool.license_exception
def test_mirror_404_error(self):
- mirrors = dict(covers_mirror=MockS3Uploader(),books_mirror=None)
+ mirrors = dict(covers_mirror=MockS3Uploader(), books_mirror=None)
mirror_type = ExternalIntegrationLink.COVERS
h = DummyHTTPClient()
h.queue_response(404)
@@ -463,8 +481,11 @@ def test_mirror_404_error(self):
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
m = Metadata(data_source=data_source)
@@ -492,12 +513,15 @@ def test_mirror_open_access_link_mirror_failure(self):
rel=Hyperlink.IMAGE,
media_type=Representation.JPEG_MEDIA_TYPE,
href="http://example.com/",
- content=content
+ content=content,
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
h.queue_response(200, media_type=Representation.JPEG_MEDIA_TYPE)
@@ -548,14 +572,18 @@ def test_mirror_link_bad_media_type(self):
rel=Hyperlink.IMAGE,
media_type=Representation.JPEG_MEDIA_TYPE,
href="http://example.com/",
- content=content
+ content=content,
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
)
# The remote server told us a generic media type.
- h.queue_response(200, media_type=Representation.OCTET_STREAM_MEDIA_TYPE, content=content)
+ h.queue_response(
+ 200, media_type=Representation.OCTET_STREAM_MEDIA_TYPE, content=content
+ )
m.mirror_link(edition, data_source, link, link_obj, policy)
representation = link_obj.resource.representation
@@ -569,18 +597,22 @@ def test_mirror_link_bad_media_type(self):
assert Representation.JPEG_MEDIA_TYPE == representation.media_type
assert link.href == representation.url
assert "Gutenberg" in representation.mirror_url
- assert representation.mirror_url.endswith("%s/cover.jpg" % edition.primary_identifier.identifier)
+ assert representation.mirror_url.endswith(
+ "%s/cover.jpg" % edition.primary_identifier.identifier
+ )
# We don't know the media type for this link, but it has a file extension.
link = LinkData(
- rel=Hyperlink.IMAGE,
- href="http://example.com/image.png",
- content=content
+ rel=Hyperlink.IMAGE, href="http://example.com/image.png", content=content
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ )
+ h.queue_response(
+ 200, media_type=Representation.OCTET_STREAM_MEDIA_TYPE, content=content
)
- h.queue_response(200, media_type=Representation.OCTET_STREAM_MEDIA_TYPE, content=content)
m.mirror_link(edition, data_source, link, link_obj, policy)
representation = link_obj.resource.representation
@@ -593,18 +625,22 @@ def test_mirror_link_bad_media_type(self):
assert Representation.PNG_MEDIA_TYPE == representation.media_type
assert link.href == representation.url
assert "Gutenberg" in representation.mirror_url
- assert representation.mirror_url.endswith("%s/image.png" % edition.primary_identifier.identifier)
+ assert representation.mirror_url.endswith(
+ "%s/image.png" % edition.primary_identifier.identifier
+ )
# We don't know the media type of this link, and there's no extension.
link = LinkData(
- rel=Hyperlink.IMAGE,
- href="http://example.com/unknown",
- content=content
+ rel=Hyperlink.IMAGE, href="http://example.com/unknown", content=content
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ )
+ h.queue_response(
+ 200, media_type=Representation.OCTET_STREAM_MEDIA_TYPE, content=content
)
- h.queue_response(200, media_type=Representation.OCTET_STREAM_MEDIA_TYPE, content=content)
m.mirror_link(edition, data_source, link, link_obj, policy)
representation = link_obj.resource.representation
@@ -634,13 +670,16 @@ def test_non_open_access_book_not_mirrored(self):
media_type=Representation.EPUB_MEDIA_TYPE,
href="http://example.com/",
content=content,
- rights_uri=RightsStatus.IN_COPYRIGHT
+ rights_uri=RightsStatus.IN_COPYRIGHT,
)
identifier = self._identifier()
link_obj, is_new = identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content,
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
# The Hyperlink object makes it look like an open-access book,
@@ -662,11 +701,15 @@ def test_mirror_with_content_modifier(self):
mirrors = dict(books_mirror=MockS3Uploader())
mirror_type = ExternalIntegrationLink.OPEN_ACCESS_BOOKS
+
def dummy_content_modifier(representation):
representation.content = "Replaced Content"
+
h = DummyHTTPClient()
- policy = ReplacementPolicy(mirrors=mirrors, content_modifier=dummy_content_modifier, http_get=h.do_get)
+ policy = ReplacementPolicy(
+ mirrors=mirrors, content_modifier=dummy_content_modifier, http_get=h.do_get
+ )
link = LinkData(
rel=Hyperlink.OPEN_ACCESS_DOWNLOAD,
@@ -676,8 +719,11 @@ def dummy_content_modifier(representation):
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
h.queue_response(200, media_type=Representation.EPUB_MEDIA_TYPE)
@@ -692,7 +738,9 @@ def dummy_content_modifier(representation):
# The mirror url is set.
assert "Gutenberg" in representation.mirror_url
- assert representation.mirror_url.endswith("%s/%s.epub" % (edition.primary_identifier.identifier, edition.title))
+ assert representation.mirror_url.endswith(
+ "%s/%s.epub" % (edition.primary_identifier.identifier, edition.title)
+ )
# Content isn't there since it was mirrored.
assert None == representation.content
@@ -712,9 +760,12 @@ def test_mirror_protected_access_book(self):
def dummy_content_modifier(representation):
representation.content = "Replaced Content"
+
h = DummyHTTPClient()
- policy = ReplacementPolicy(mirrors=mirrors, content_modifier=dummy_content_modifier, http_get=h.do_get)
+ policy = ReplacementPolicy(
+ mirrors=mirrors, content_modifier=dummy_content_modifier, http_get=h.do_get
+ )
link = LinkData(
rel=Hyperlink.GENERIC_OPDS_ACQUISITION,
@@ -724,8 +775,11 @@ def dummy_content_modifier(representation):
)
link_obj, ignore = edition.primary_identifier.add_link(
- rel=link.rel, href=link.href, data_source=data_source,
- media_type=link.media_type, content=link.content
+ rel=link.rel,
+ href=link.href,
+ data_source=data_source,
+ media_type=link.media_type,
+ content=link.content,
)
h.queue_response(200, media_type=Representation.EPUB_MEDIA_TYPE)
@@ -740,7 +794,9 @@ def dummy_content_modifier(representation):
# The mirror url is set.
assert "Gutenberg" in representation.mirror_url
- assert representation.mirror_url.endswith("%s/%s.epub" % (edition.primary_identifier.identifier, edition.title))
+ assert representation.mirror_url.endswith(
+ "%s/%s.epub" % (edition.primary_identifier.identifier, edition.title)
+ )
# Content isn't there since it was mirrored.
assert None == representation.content
@@ -751,10 +807,10 @@ def dummy_content_modifier(representation):
def test_measurements(self):
edition = self._edition()
- measurement = MeasurementData(quantity_measured=Measurement.POPULARITY,
- value=100)
- metadata = Metadata(measurements=[measurement],
- data_source=edition.data_source)
+ measurement = MeasurementData(
+ quantity_measured=Measurement.POPULARITY, value=100
+ )
+ metadata = Metadata(measurements=[measurement], data_source=edition.data_source)
metadata.apply(edition, None)
[m] = edition.primary_identifier.measurements
assert Measurement.POPULARITY == m.quantity_measured
@@ -770,8 +826,11 @@ def test_coverage_record(self):
last_update = datetime_utc(2015, 1, 1)
- m = Metadata(data_source=data_source,
- title="New title", data_source_last_updated=last_update)
+ m = Metadata(
+ data_source=data_source,
+ title="New title",
+ data_source_last_updated=last_update,
+ )
m.apply(edition, None)
coverage = CoverageRecord.lookup(edition, data_source)
@@ -779,9 +838,10 @@ def test_coverage_record(self):
assert "New title" == edition.title
older_last_update = datetime_utc(2014, 1, 1)
- m = Metadata(data_source=data_source,
- title="Another new title",
- data_source_last_updated=older_last_update
+ m = Metadata(
+ data_source=data_source,
+ title="Another new title",
+ data_source_last_updated=older_last_update,
)
m.apply(edition, None)
assert "New title" == edition.title
@@ -795,7 +855,6 @@ def test_coverage_record(self):
assert older_last_update == coverage.timestamp
-
class TestContributorData(DatabaseTest):
def test_from_contribution(self):
# Makes sure ContributorData.from_contribution copies all the fields over.
@@ -839,8 +898,10 @@ def m(*args, **kwargs):
# We know a lot about this person.
pkd, ignore = self._contributor(
- sort_name="Dick, Phillip K.", display_name="Phillip K. Dick",
- viaf="27063583", lc="n79018147"
+ sort_name="Dick, Phillip K.",
+ display_name="Phillip K. Dick",
+ viaf="27063583",
+ lc="n79018147",
)
def _match(expect, actual):
@@ -871,16 +932,17 @@ def _match(expect, actual):
# what we know from the database.
_match(
pkd,
- m(display_name="Phillip K. Dick", sort_name="Marenghi, Garth",
- viaf="1234", lc="abcd"
- )
+ m(
+ display_name="Phillip K. Dick",
+ sort_name="Marenghi, Garth",
+ viaf="1234",
+ lc="abcd",
+ ),
)
# If we're able to identify a Contributor, but we don't know some
# of the information, those fields are left blank.
- expect = ContributorData(
- display_name="Ann Leckie", sort_name="Leckie, Ann"
- )
+ expect = ContributorData(display_name="Ann Leckie", sort_name="Leckie, Ann")
_match(expect, m(display_name="Ann Leckie"))
# Now let's test cases where the database lookup finds
@@ -899,35 +961,34 @@ def _match(expect, actual):
# ContributorData that doesn't correspond to any one
# Contributor.
with_viaf, ignore = self._contributor(
- display_name="Ann Leckie", viaf="73520345",
+ display_name="Ann Leckie",
+ viaf="73520345",
)
# _contributor() set sort_name to a random value; remove it.
with_viaf.sort_name = None
expect = ContributorData(
- display_name="Ann Leckie", sort_name="Leckie, Ann",
- viaf="73520345"
- )
- _match(
- expect, m(display_name="Ann Leckie")
+ display_name="Ann Leckie", sort_name="Leckie, Ann", viaf="73520345"
)
+ _match(expect, m(display_name="Ann Leckie"))
# Again, this works even if some of the incoming arguments
# turn out not to be supported by the database data.
_match(
- expect, m(display_name="Ann Leckie", sort_name="Ann Leckie",
- viaf="abcd")
+ expect, m(display_name="Ann Leckie", sort_name="Ann Leckie", viaf="abcd")
)
# If there's a duplicate that provides conflicting information,
# the corresponding field is left blank -- we don't know which
# value is correct.
with_incorrect_viaf, ignore = self._contributor(
- display_name="Ann Leckie", viaf="abcd",
+ display_name="Ann Leckie",
+ viaf="abcd",
)
- with_incorrect_viaf.sort_name=None
+ with_incorrect_viaf.sort_name = None
expect = ContributorData(
- display_name="Ann Leckie", sort_name="Leckie, Ann",
+ display_name="Ann Leckie",
+ sort_name="Leckie, Ann",
)
_match(expect, m(display_name="Ann Leckie"))
@@ -940,22 +1001,23 @@ def _match(expect, actual):
def test_apply(self):
# Makes sure ContributorData.apply copies all the fields over when there's changes to be made.
-
- contributor_old, made_new = self._contributor(sort_name="Doe, John", viaf="viaf12345")
+ contributor_old, made_new = self._contributor(
+ sort_name="Doe, John", viaf="viaf12345"
+ )
kwargs = dict()
- kwargs[Contributor.BIRTH_DATE] = '2001-01-01'
+ kwargs[Contributor.BIRTH_DATE] = "2001-01-01"
contributor_data = ContributorData(
- sort_name = "Doerr, John",
- lc = "1234567",
- viaf = "ABC123",
- aliases = ["Primo"],
- display_name = "Test Author For The Win",
- family_name = "TestAuttie",
- wikipedia_name = "TestWikiAuth",
- biography = "He was born on Main Street.",
- extra = kwargs,
+ sort_name="Doerr, John",
+ lc="1234567",
+ viaf="ABC123",
+ aliases=["Primo"],
+ display_name="Test Author For The Win",
+ family_name="TestAuttie",
+ wikipedia_name="TestWikiAuth",
+ biography="He was born on Main Street.",
+ extra=kwargs,
)
contributor_new, changed = contributor_data.apply(contributor_old)
@@ -971,7 +1033,7 @@ def test_apply(self):
assert contributor_new.biography == "He was born on Main Street."
assert contributor_new.extra[Contributor.BIRTH_DATE] == "2001-01-01"
- #assert_equal(contributor_new.contributions, "Audio")
+ # assert_equal(contributor_new.contributions, "Audio")
contributor_new, changed = contributor_data.apply(contributor_new)
assert changed == False
@@ -979,16 +1041,30 @@ def test_apply(self):
def test_display_name_to_sort_name_from_existing_contributor(self):
# If there's an existing contributor with a matching display name,
# we'll use their sort name.
- existing_contributor, ignore = self._contributor(sort_name="Sort, Name", display_name="John Doe")
- assert "Sort, Name" == ContributorData.display_name_to_sort_name_from_existing_contributor(self._db, "John Doe")
+ existing_contributor, ignore = self._contributor(
+ sort_name="Sort, Name", display_name="John Doe"
+ )
+ assert (
+ "Sort, Name"
+ == ContributorData.display_name_to_sort_name_from_existing_contributor(
+ self._db, "John Doe"
+ )
+ )
# Otherwise, we don't know.
- assert None == ContributorData.display_name_to_sort_name_from_existing_contributor(self._db, "Jane Doe")
+ assert (
+ None
+ == ContributorData.display_name_to_sort_name_from_existing_contributor(
+ self._db, "Jane Doe"
+ )
+ )
def test_find_sort_name(self):
metadata_client = DummyMetadataClient()
metadata_client.lookups["Metadata Client Author"] = "Author, M. C."
- existing_contributor, ignore = self._contributor(sort_name="Author, E.", display_name="Existing Author")
+ existing_contributor, ignore = self._contributor(
+ sort_name="Author, E.", display_name="Existing Author"
+ )
contributor_data = ContributorData()
# If there's already a sort name, keep it.
@@ -1028,7 +1104,6 @@ def test_find_sort_name(self):
assert "Author, New" == contributor_data.sort_name
def test_find_sort_name_survives_metadata_client_exception(self):
-
class Mock(ContributorData):
# Simulate an integration error from the metadata wrangler side.
def display_name_to_sort_name_through_canonicalizer(
@@ -1048,8 +1123,7 @@ def display_name_to_sort_name_through_canonicalizer(
# display_name_to_sort_name_through_canonicalizer was called
# with the arguments we expect.
- assert ((self._db, identifiers, metadata_client) ==
- contributor_data.called_with)
+ assert (self._db, identifiers, metadata_client) == contributor_data.called_with
# Although that method raised an exception, we were able to
# keep going and use the default display name -> sort name
@@ -1058,15 +1132,27 @@ def display_name_to_sort_name_through_canonicalizer(
class TestLinkData(DatabaseTest):
- @parameterized.expand([
- ('image', Hyperlink.IMAGE, ExternalIntegrationLink.COVERS),
- ('thumbnail', Hyperlink.THUMBNAIL_IMAGE, ExternalIntegrationLink.COVERS),
- ('open_access_book', Hyperlink.OPEN_ACCESS_DOWNLOAD, ExternalIntegrationLink.OPEN_ACCESS_BOOKS),
- ('protected_access_book', Hyperlink.GENERIC_OPDS_ACQUISITION, ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS)
- ])
- def test_mirror_type_returns_correct_mirror_type_for(self, name, rel, expected_mirror_type):
+ @parameterized.expand(
+ [
+ ("image", Hyperlink.IMAGE, ExternalIntegrationLink.COVERS),
+ ("thumbnail", Hyperlink.THUMBNAIL_IMAGE, ExternalIntegrationLink.COVERS),
+ (
+ "open_access_book",
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ ExternalIntegrationLink.OPEN_ACCESS_BOOKS,
+ ),
+ (
+ "protected_access_book",
+ Hyperlink.GENERIC_OPDS_ACQUISITION,
+ ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
+ ),
+ ]
+ )
+ def test_mirror_type_returns_correct_mirror_type_for(
+ self, name, rel, expected_mirror_type
+ ):
# Arrange
- link_data = LinkData(rel, href='dummy')
+ link_data = LinkData(rel, href="dummy")
# Act
result = link_data.mirror_type()
@@ -1088,8 +1174,9 @@ def test_guess_media_type(self):
# An explicitly known media type takes precedence over
# something we guess from the file extension.
- png = LinkData(rel, href="http://foo/bar.jpeg",
- media_type=Representation.PNG_MEDIA_TYPE)
+ png = LinkData(
+ rel, href="http://foo/bar.jpeg", media_type=Representation.PNG_MEDIA_TYPE
+ )
assert Representation.PNG_MEDIA_TYPE == png.guessed_media_type
description = LinkData(Hyperlink.DESCRIPTION, content="Some content")
@@ -1109,7 +1196,9 @@ def test_from_edition(self):
edition, pool = self._edition(with_license_pool=True)
edition.series = "Harry Otter and the Mollusk of Infamy"
edition.series_position = "14"
- edition.primary_identifier.add_link(Hyperlink.IMAGE, "image", edition.data_source)
+ edition.primary_identifier.add_link(
+ Hyperlink.IMAGE, "image", edition.data_source
+ )
metadata = Metadata.from_edition(edition)
# make sure the metadata and the originating edition match
@@ -1122,7 +1211,10 @@ def test_from_edition(self):
assert e_contribution.role == m_contributor_data.roles[0]
assert edition.data_source == metadata.data_source(self._db)
- assert edition.primary_identifier.identifier == metadata.primary_identifier.identifier
+ assert (
+ edition.primary_identifier.identifier
+ == metadata.primary_identifier.identifier
+ )
e_link = edition.primary_identifier.links[0]
m_link = metadata.links[0]
@@ -1175,7 +1267,7 @@ def test_apply(self):
publisher="Scholastic Inc",
imprint="Follywood",
published=datetime.date(1987, 5, 4),
- issued=datetime.date(1989, 4, 5)
+ issued=datetime.date(1989, 4, 5),
)
edition_new, changed = metadata.apply(edition_old, pool.collection)
@@ -1248,6 +1340,7 @@ def assert_registered(full):
else:
assert WCR.REGISTERED == x.status
x.status = WCR.SUCCESS
+
assert_registered(full=False)
# We then learn about a subject under which the work
@@ -1262,9 +1355,7 @@ def assert_registered(full):
# We then find a new description for the work.
metadata.subjects = None
- metadata.links = [
- LinkData(rel=Hyperlink.DESCRIPTION, content="a description")
- ]
+ metadata.links = [LinkData(rel=Hyperlink.DESCRIPTION, content="a description")]
metadata.apply(edition, None)
# We need to do a full recalculation again.
@@ -1272,15 +1363,12 @@ def assert_registered(full):
# We then find a new cover image for the work.
metadata.subjects = None
- metadata.links = [
- LinkData(rel=Hyperlink.IMAGE, href="http://image/")
- ]
+ metadata.links = [LinkData(rel=Hyperlink.IMAGE, href="http://image/")]
metadata.apply(edition, None)
# We need to choose a new presentation edition.
assert_registered(full=False)
-
def test_apply_identifier_equivalency(self):
# Set up an Edition.
@@ -1300,7 +1388,7 @@ def test_apply_identifier_equivalency(self):
metadata = Metadata(
data_source=DataSource.OVERDRIVE,
primary_identifier=primary,
- identifiers=[other_data]
+ identifiers=[other_data],
)
# Metadata.identifiers has two elements -- the primary and the
@@ -1314,7 +1402,7 @@ def test_apply_identifier_equivalency(self):
metadata2 = Metadata(
data_source=DataSource.OVERDRIVE,
primary_identifier=primary,
- identifiers=[primary_as_data, other_data]
+ identifiers=[primary_as_data, other_data],
)
assert 3 == len(metadata2.identifiers)
assert primary_as_data in metadata2.identifiers
@@ -1340,7 +1428,7 @@ def test_apply_no_value(self):
data_source=DataSource.PRESENTATION_EDITION,
subtitle=NO_VALUE,
series=NO_VALUE,
- series_position=NO_NUMBER
+ series_position=NO_NUMBER,
)
edition_new, changed = metadata.apply(edition_old, pool.collection)
@@ -1361,68 +1449,56 @@ def test_apply_no_value(self):
def test_apply_creates_coverage_records(self):
edition, pool = self._edition(with_license_pool=True)
- metadata = Metadata(
- data_source=DataSource.OVERDRIVE,
- title=self._str
- )
+ metadata = Metadata(data_source=DataSource.OVERDRIVE, title=self._str)
edition, changed = metadata.apply(edition, pool.collection)
# One success was recorded.
- records = self._db.query(
- CoverageRecord
- ).filter(
- CoverageRecord.identifier_id==edition.primary_identifier.id
- ).filter(
- CoverageRecord.operation==None
+ records = (
+ self._db.query(CoverageRecord)
+ .filter(CoverageRecord.identifier_id == edition.primary_identifier.id)
+ .filter(CoverageRecord.operation == None)
)
assert 1 == records.count()
assert CoverageRecord.SUCCESS == records.all()[0].status
# No metadata upload failure was recorded, because this metadata
# came from Overdrive.
- records = self._db.query(
- CoverageRecord
- ).filter(
- CoverageRecord.identifier_id==edition.primary_identifier.id
- ).filter(
- CoverageRecord.operation==CoverageRecord.METADATA_UPLOAD_OPERATION
+ records = (
+ self._db.query(CoverageRecord)
+ .filter(CoverageRecord.identifier_id == edition.primary_identifier.id)
+ .filter(
+ CoverageRecord.operation == CoverageRecord.METADATA_UPLOAD_OPERATION
+ )
)
assert 0 == records.count()
# Apply metadata from a different source.
- metadata = Metadata(
- data_source=DataSource.GUTENBERG,
- title=self._str
- )
+ metadata = Metadata(data_source=DataSource.GUTENBERG, title=self._str)
edition, changed = metadata.apply(edition, pool.collection)
# Another success record was created.
- records = self._db.query(
- CoverageRecord
- ).filter(
- CoverageRecord.identifier_id==edition.primary_identifier.id
- ).filter(
- CoverageRecord.operation==None
+ records = (
+ self._db.query(CoverageRecord)
+ .filter(CoverageRecord.identifier_id == edition.primary_identifier.id)
+ .filter(CoverageRecord.operation == None)
)
assert 2 == records.count()
for record in records.all():
assert CoverageRecord.SUCCESS == record.status
# But now there's also a metadata upload failure.
- records = self._db.query(
- CoverageRecord
- ).filter(
- CoverageRecord.identifier_id==edition.primary_identifier.id
- ).filter(
- CoverageRecord.operation==CoverageRecord.METADATA_UPLOAD_OPERATION
+ records = (
+ self._db.query(CoverageRecord)
+ .filter(CoverageRecord.identifier_id == edition.primary_identifier.id)
+ .filter(
+ CoverageRecord.operation == CoverageRecord.METADATA_UPLOAD_OPERATION
+ )
)
assert 1 == records.count()
assert CoverageRecord.TRANSIENT_FAILURE == records.all()[0].status
-
-
def test_update_contributions(self):
edition = self._edition()
@@ -1437,7 +1513,7 @@ def test_update_contributions(self):
wikipedia_name="Robert_Jordan",
viaf="79096089",
lc="123",
- roles=[Contributor.PRIMARY_AUTHOR_ROLE]
+ roles=[Contributor.PRIMARY_AUTHOR_ROLE],
)
metadata = Metadata(DataSource.OVERDRIVE, contributors=[contributor])
@@ -1478,7 +1554,6 @@ def test_filter_recommendations(self):
# The genuwine article.
assert known_identifier == result
-
def test_metadata_can_be_deepcopied(self):
# Check that we didn't put something in the metadata that
# will prevent it from being copied. (e.g., self.log)
@@ -1488,12 +1563,14 @@ def test_metadata_can_be_deepcopied(self):
identifier = IdentifierData(Identifier.GUTENBERG_ID, "1")
link = LinkData(Hyperlink.OPEN_ACCESS_DOWNLOAD, "example.epub")
measurement = MeasurementData(Measurement.RATING, 5)
- circulation = CirculationData(data_source=DataSource.GUTENBERG,
+ circulation = CirculationData(
+ data_source=DataSource.GUTENBERG,
primary_identifier=identifier,
licenses_owned=0,
licenses_available=0,
licenses_reserved=0,
- patrons_in_hold_queue=0)
+ patrons_in_hold_queue=0,
+ )
primary_as_data = IdentifierData(
type=identifier.type, identifier=identifier.identifier
)
@@ -1507,7 +1584,6 @@ def test_metadata_can_be_deepcopied(self):
links=[link],
measurements=[measurement],
circulation=circulation,
-
title="Hello Title",
subtitle="Subtle Hello",
sort_title="Sorting Howdy",
@@ -1528,18 +1604,20 @@ def test_metadata_can_be_deepcopied(self):
# If deepcopy didn't throw an exception we're ok.
assert m_copy is not None
-
def test_links_filtered(self):
# test that filter links to only metadata-relevant ones
link1 = LinkData(Hyperlink.OPEN_ACCESS_DOWNLOAD, "example.epub")
link2 = LinkData(rel=Hyperlink.IMAGE, href="http://example.com/")
link3 = LinkData(rel=Hyperlink.DESCRIPTION, content="foo")
link4 = LinkData(
- rel=Hyperlink.THUMBNAIL_IMAGE, href="http://thumbnail.com/",
+ rel=Hyperlink.THUMBNAIL_IMAGE,
+ href="http://thumbnail.com/",
media_type=Representation.JPEG_MEDIA_TYPE,
)
link5 = LinkData(
- rel=Hyperlink.IMAGE, href="http://example.com/", thumbnail=link4,
+ rel=Hyperlink.IMAGE,
+ href="http://example.com/",
+ thumbnail=link4,
media_type=Representation.JPEG_MEDIA_TYPE,
)
links = [link1, link2, link3, link4, link5]
@@ -1551,13 +1629,12 @@ def test_links_filtered(self):
links=links,
)
- filtered_links = sorted(metadata.links, key=lambda x:x.rel)
+ filtered_links = sorted(metadata.links, key=lambda x: x.rel)
assert [link2, link5, link4, link3] == filtered_links
class TestCirculationData(DatabaseTest):
-
def test_apply_propagates_analytics(self):
# Verify that an Analytics object is always passed into
# license_pool() and update_availability(), even if none is
@@ -1575,12 +1652,15 @@ class MockLicensePool(object):
delivery_mechanisms = []
licenses = []
work = None
+
def calculate_work(self):
return None, False
+
def update_availability(self, **kwargs):
self.update_availability_called_with = kwargs
pool = MockLicensePool()
+
class MockCirculationData(CirculationData):
# A CirculationData-like object that always says
# update_availability ought to be called on a
@@ -1605,7 +1685,7 @@ def _availability_needs_update(self, *args):
# Then, the same Analytics object was passed into the
# update_availability() method of the MockLicensePool returned
# by license_pool()
- analytics2 = pool.update_availability_called_with['analytics']
+ analytics2 = pool.update_availability_called_with["analytics"]
assert analytics1 == analytics2
# Now try with a ReplacementPolicy that mentions a specific
@@ -1617,23 +1697,30 @@ def _availability_needs_update(self, *args):
# That object was used instead of a generic Analytics object in
# both cases.
assert analytics == data.license_pool_called_with[-1]
- assert analytics == pool.update_availability_called_with['analytics']
+ assert analytics == pool.update_availability_called_with["analytics"]
class TestTimestampData(DatabaseTest):
-
def test_constructor(self):
# By default, all fields are set to None
d = TimestampData()
- for i in (d.service, d.service_type, d.collection_id,
- d.start, d.finish, d.achievements, d.counter,
- d.exception):
+ for i in (
+ d.service,
+ d.service_type,
+ d.collection_id,
+ d.start,
+ d.finish,
+ d.achievements,
+ d.counter,
+ d.exception,
+ ):
assert i == None
# Some, but not all, of the fields can be set to real values.
- d = TimestampData(start="a", finish="b", achievements="c",
- counter="d", exception="e")
+ d = TimestampData(
+ start="a", finish="b", achievements="c", counter="d", exception="e"
+ )
assert "a" == d.start
assert "b" == d.finish
assert "c" == d.achievements
@@ -1701,9 +1788,13 @@ def test_finalize_full(self):
# You can call finalize() with a complete set of arguments.
d = TimestampData()
d.finalize(
- "service", "service_type", self._default_collection,
- start="start", finish="finish", counter="counter",
- exception="exception"
+ "service",
+ "service_type",
+ self._default_collection,
+ start="start",
+ finish="finish",
+ counter="counter",
+ exception="exception",
)
assert "start" == d.start
assert "finish" == d.finish
@@ -1715,9 +1806,13 @@ def test_finalize_full(self):
# the optional fields will be left alone.
new_collection = self._collection()
d.finalize(
- "service2", "service_type2", new_collection,
- start="start2", finish="finish2", counter="counter2",
- exception="exception2"
+ "service2",
+ "service_type2",
+ new_collection,
+ start="start2",
+ finish="finish2",
+ counter="counter2",
+ exception="exception2",
)
# These have changed.
assert "service2" == d.service
@@ -1741,7 +1836,9 @@ def test_apply(self):
d = TimestampData()
with pytest.raises(ValueError) as excinfo:
d.apply(self._db)
- assert "Not enough information to write TimestampData to the database." in str(excinfo.value)
+ assert "Not enough information to write TimestampData to the database." in str(
+ excinfo.value
+ )
# Set the basic timestamp information. Optional fields will stay
# at None.
@@ -1753,7 +1850,7 @@ def test_apply(self):
timestamp = Timestamp.lookup(
self._db, "service", Timestamp.SCRIPT_TYPE, collection
)
- assert (now-timestamp.start).total_seconds() < 2
+ assert (now - timestamp.start).total_seconds() < 2
assert timestamp.start == timestamp.finish
# Now set the optional fields as well.
@@ -1785,9 +1882,8 @@ def test_apply(self):
class TestAssociateWithIdentifiersBasedOnPermanentWorkID(DatabaseTest):
-
def test_success(self):
- pwid = 'pwid1'
+ pwid = "pwid1"
# Here's a print book.
book = self._edition()
@@ -1797,7 +1893,7 @@ def test_success(self):
# Here's an audio book with the same PWID.
audio = self._edition()
audio.medium = Edition.AUDIO_MEDIUM
- audio.permanent_work_id=pwid
+ audio.permanent_work_id = pwid
# Here's an Metadata object for a second print book with the
# same PWID.
@@ -1807,14 +1903,13 @@ def test_success(self):
)
metadata = Metadata(
DataSource.GUTENBERG,
- primary_identifier=identifierdata, medium=Edition.BOOK_MEDIUM
+ primary_identifier=identifierdata,
+ medium=Edition.BOOK_MEDIUM,
)
- metadata.permanent_work_id=pwid
+ metadata.permanent_work_id = pwid
# Call the method we're testing.
- metadata.associate_with_identifiers_based_on_permanent_work_id(
- self._db
- )
+ metadata.associate_with_identifiers_based_on_permanent_work_id(self._db)
# The identifier of the second print book has been associated
# with the identifier of the first print book, but not
@@ -1824,7 +1919,6 @@ def test_success(self):
class TestMARCExtractor(DatabaseTest):
-
def setup_method(self):
super(TestMARCExtractor, self).setup_method()
base_path = os.path.split(__file__)[0]
@@ -1862,10 +1956,13 @@ def test_parser(self):
assert "Canon" in subjects[0].identifier
assert Edition.BOOK_MEDIUM == record.medium
assert 2015 == record.issued.year
- assert 'eng' == record.language
+ assert "eng" == record.language
assert 1 == len(record.links)
- assert "Utterson and Enfield are worried about their friend" in record.links[0].content
+ assert (
+ "Utterson and Enfield are worried about their friend"
+ in record.links[0].content
+ )
def test_name_cleanup(self):
"""Test basic name cleanup techniques."""
diff --git a/tests/test_mirror_uploader.py b/tests/test_mirror_uploader.py
index a3897d1b7..0f79f7e4a 100644
--- a/tests/test_mirror_uploader.py
+++ b/tests/test_mirror_uploader.py
@@ -1,19 +1,32 @@
import pytest
from parameterized import parameterized
-from ..testing import DatabaseTest
from ..config import CannotLoadConfiguration
from ..mirror import MirrorUploader
from ..model import ExternalIntegration
from ..model.configuration import ExternalIntegrationLink
-from ..s3 import S3Uploader, MinIOUploader, MinIOUploaderConfiguration, S3UploaderConfiguration
+from ..s3 import (
+ MinIOUploader,
+ MinIOUploaderConfiguration,
+ S3Uploader,
+ S3UploaderConfiguration,
+)
+from ..testing import DatabaseTest
from ..util.datetime_helpers import utc_now
+
class DummySuccessUploader(MirrorUploader):
def __init__(self, integration=None):
pass
- def book_url(self, identifier, extension='.epub', open_access=True, data_source=None, title=None):
+ def book_url(
+ self,
+ identifier,
+ extension=".epub",
+ open_access=True,
+ data_source=None,
+ title=None,
+ ):
pass
def cover_image_url(self, data_source, identifier, filename=None, scaled_size=None):
@@ -33,7 +46,14 @@ class DummyFailureUploader(MirrorUploader):
def __init__(self, integration=None):
pass
- def book_url(self, identifier, extension='.epub', open_access=True, data_source=None, title=None):
+ def book_url(
+ self,
+ identifier,
+ extension=".epub",
+ open_access=True,
+ data_source=None,
+ title=None,
+ ):
pass
def cover_image_url(self, data_source, identifier, filename=None, scaled_size=None):
@@ -63,22 +83,26 @@ def _integration(self):
integration.name = storage_name
return integration
- @parameterized.expand([
- ('s3_uploader', ExternalIntegration.S3, S3Uploader),
- (
- 'minio_uploader',
+ @parameterized.expand(
+ [
+ ("s3_uploader", ExternalIntegration.S3, S3Uploader),
+ (
+ "minio_uploader",
ExternalIntegration.MINIO,
MinIOUploader,
- {MinIOUploaderConfiguration.ENDPOINT_URL: 'http://localhost'}
- )
- ])
+ {MinIOUploaderConfiguration.ENDPOINT_URL: "http://localhost"},
+ ),
+ ]
+ )
def test_mirror(self, name, protocol, uploader_class, settings=None):
storage_name = "some storage"
# If there's no integration with goal=STORAGE or name=storage_name,
# MirrorUploader.mirror raises an exception.
with pytest.raises(CannotLoadConfiguration) as excinfo:
MirrorUploader.mirror(self._db, storage_name)
- assert "No storage integration with name 'some storage' is configured" in str(excinfo.value)
+ assert "No storage integration with name 'some storage' is configured" in str(
+ excinfo.value
+ )
# If there's only one, mirror() uses it to initialize a
# MirrorUploader.
@@ -99,7 +123,9 @@ def test_integration_by_name(self):
# No name was passed so nothing is found
with pytest.raises(CannotLoadConfiguration) as excinfo:
MirrorUploader.integration_by_name(self._db)
- assert "No storage integration with name 'None' is configured" in str(excinfo.value)
+ assert "No storage integration with name 'None' is configured" in str(
+ excinfo.value
+ )
# Correct name was passed
integration = MirrorUploader.integration_by_name(self._db, integration.name)
@@ -109,36 +135,41 @@ def test_for_collection(self):
# This collection has no mirror_integration, so
# there is no MirrorUploader for it.
collection = self._collection()
- assert None == MirrorUploader.for_collection(collection, ExternalIntegrationLink.COVERS)
+ assert None == MirrorUploader.for_collection(
+ collection, ExternalIntegrationLink.COVERS
+ )
# This collection has a properly configured mirror_integration,
# so it can have an MirrorUploader.
integration = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL,
- username="username", password="password",
- settings={S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "some-covers"}
+ ExternalIntegration.S3,
+ ExternalIntegration.STORAGE_GOAL,
+ username="username",
+ password="password",
+ settings={S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "some-covers"},
)
integration_link = self._external_integration_link(
integration=collection._external_integration,
other_integration=integration,
- purpose=ExternalIntegrationLink.COVERS
+ purpose=ExternalIntegrationLink.COVERS,
)
- uploader = MirrorUploader.for_collection(collection, ExternalIntegrationLink.COVERS)
+ uploader = MirrorUploader.for_collection(
+ collection, ExternalIntegrationLink.COVERS
+ )
assert isinstance(uploader, MirrorUploader)
- @parameterized.expand([
- (
- 's3_uploader',
- ExternalIntegration.S3, S3Uploader
- ),
- (
- 'minio_uploader',
+ @parameterized.expand(
+ [
+ ("s3_uploader", ExternalIntegration.S3, S3Uploader),
+ (
+ "minio_uploader",
ExternalIntegration.MINIO,
MinIOUploader,
- {MinIOUploaderConfiguration.ENDPOINT_URL: 'http://localhost'}
- )
- ])
+ {MinIOUploaderConfiguration.ENDPOINT_URL: "http://localhost"},
+ ),
+ ]
+ )
def test_constructor(self, name, protocol, uploader_class, settings=None):
# You can't create a MirrorUploader with an integration
# that's not designed for storage.
@@ -179,13 +210,13 @@ def test_mirror_batch(self):
def test_success_and_then_failure(self):
r, ignore = self._representation()
now = utc_now()
- DummySuccessUploader().mirror_one(r, '')
+ DummySuccessUploader().mirror_one(r, "")
assert r.mirrored_at > now
assert None == r.mirror_exception
# Even if the original upload succeeds, a subsequent upload
# may fail in a way that leaves the image in an inconsistent
# state.
- DummyFailureUploader().mirror_one(r, '')
+ DummyFailureUploader().mirror_one(r, "")
assert None == r.mirrored_at
assert "I always fail." == r.mirror_exception
diff --git a/tests/test_monitor.py b/tests/test_monitor.py
index 051990bde..af256d651 100644
--- a/tests/test_monitor.py
+++ b/tests/test_monitor.py
@@ -1,7 +1,7 @@
import datetime
+
import pytest
-from ..testing import DatabaseTest
from ..config import Configuration
from ..metadata_layer import TimestampData
from ..model import (
@@ -55,10 +55,12 @@
)
from ..testing import (
AlwaysSuccessfulCoverageProvider,
+ DatabaseTest,
NeverSuccessfulCoverageProvider,
)
from ..util.datetime_helpers import datetime_utc, utc_now
+
class MockMonitor(Monitor):
SERVICE_NAME = "Dummy monitor for test"
@@ -77,9 +79,7 @@ def cleanup(self):
class TestMonitor(DatabaseTest):
-
def test_must_define_service_name(self):
-
class NoServiceName(MockMonitor):
SERVICE_NAME = None
@@ -117,6 +117,7 @@ def test_monitor_lifecycle(self):
# There is no timestamp for this monitor.
def get_timestamp():
return get_one(self._db, Timestamp, service=monitor.service_name)
+
assert None == get_timestamp()
# Run the monitor.
@@ -154,6 +155,7 @@ class NeverRunMonitor(MockMonitor):
class RunLongAgoMonitor(MockMonitor):
SERVICE_NAME = "Run long ago"
DEFAULT_START_TIME = MockMonitor.ONE_YEAR_AGO
+
# The Timestamp object is created, and its .timestamp is long ago.
m = RunLongAgoMonitor(self._db, self._default_collection)
timestamp = m.timestamp()
@@ -175,6 +177,7 @@ def test_run_once_returning_timestampdata(self):
class Mock(MockMonitor):
def run_once(self, progress):
return TimestampData(start=start, finish=finish, counter=-100)
+
monitor = Mock(self._db, self._default_collection)
monitor.run()
@@ -207,8 +210,10 @@ def assert_run_sets_exception(monitor, check_for):
# Try a monitor that raises an unhandled exception.
class DoomedMonitor(MockMonitor):
SERVICE_NAME = "Doomed"
+
def run_once(self, *args, **kwargs):
raise Exception("I'm doomed")
+
m = DoomedMonitor(self._db, self._default_collection)
assert_run_sets_exception(m, "Exception: I'm doomed")
@@ -216,8 +221,10 @@ def run_once(self, *args, **kwargs):
# returns.
class AlsoDoomed(MockMonitor):
SERVICE_NAME = "Doomed, but in a different way."
+
def run_once(self, progress):
return TimestampData(exception="I'm also doomed")
+
m = AlsoDoomed(self._db, self._default_collection)
assert_run_sets_exception(m, "I'm also doomed")
@@ -269,6 +276,7 @@ def test_protocol_enforcement(self):
"""A CollectionMonitor can require that it be instantiated
with a Collection that implements a certain protocol.
"""
+
class NoProtocolMonitor(CollectionMonitor):
SERVICE_NAME = "Test Monitor 1"
PROTOCOL = None
@@ -291,12 +299,16 @@ class OverdriveMonitor(CollectionMonitor):
OverdriveMonitor(self._db, c1)
with pytest.raises(ValueError) as excinfo:
OverdriveMonitor(self._db, c2)
- assert "Collection protocol (Bibliotheca) does not match Monitor protocol (Overdrive)" in str(excinfo.value)
+ assert (
+ "Collection protocol (Bibliotheca) does not match Monitor protocol (Overdrive)"
+ in str(excinfo.value)
+ )
with pytest.raises(CollectionMissing):
OverdriveMonitor(self._db, None)
def test_all(self):
"""Test that we can create a list of Monitors using all()."""
+
class OPDSCollectionMonitor(CollectionMonitor):
SERVICE_NAME = "Test Monitor"
PROTOCOL = ExternalIntegration.OPDS_IMPORT
@@ -311,23 +323,22 @@ class OPDSCollectionMonitor(CollectionMonitor):
# o1 just had its Monitor run.
Timestamp.stamp(
- self._db, OPDSCollectionMonitor.SERVICE_NAME,
- Timestamp.MONITOR_TYPE, o1
+ self._db, OPDSCollectionMonitor.SERVICE_NAME, Timestamp.MONITOR_TYPE, o1
)
# o2 and b1 have never had their Monitor run, but o2 has had some other Monitor run.
- Timestamp.stamp(
- self._db, "A Different Service", Timestamp.MONITOR_TYPE,
- o2
- )
+ Timestamp.stamp(self._db, "A Different Service", Timestamp.MONITOR_TYPE, o2)
# o3 had its Monitor run an hour ago.
now = utc_now()
an_hour_ago = now - datetime.timedelta(seconds=3600)
Timestamp.stamp(
- self._db, OPDSCollectionMonitor.SERVICE_NAME,
- Timestamp.MONITOR_TYPE, o3, start=an_hour_ago,
- finish=an_hour_ago
+ self._db,
+ OPDSCollectionMonitor.SERVICE_NAME,
+ Timestamp.MONITOR_TYPE,
+ o3,
+ start=an_hour_ago,
+ finish=an_hour_ago,
)
monitors = list(OPDSCollectionMonitor.all(self._db))
@@ -341,7 +352,9 @@ class OPDSCollectionMonitor(CollectionMonitor):
# If `collections` are specified, monitors should be yielded in the same order.
opds_collections = [o3, o1, o2]
- monitors = list(OPDSCollectionMonitor.all(self._db, collections=opds_collections))
+ monitors = list(
+ OPDSCollectionMonitor.all(self._db, collections=opds_collections)
+ )
monitor_collections = [m.collection for m in monitors]
# We should get a monitor for each collection.
assert set(opds_collections) == set(monitor_collections)
@@ -350,7 +363,9 @@ class OPDSCollectionMonitor(CollectionMonitor):
# If `collections` are specified, monitors should be yielded in the same order.
opds_collections = [o3, o1]
- monitors = list(OPDSCollectionMonitor.all(self._db, collections=opds_collections))
+ monitors = list(
+ OPDSCollectionMonitor.all(self._db, collections=opds_collections)
+ )
monitor_collections = [m.collection for m in monitors]
# We should get a monitor for each collection.
assert set(opds_collections) == set(monitor_collections)
@@ -360,16 +375,19 @@ class OPDSCollectionMonitor(CollectionMonitor):
# If collections are specified, they must match the monitor's protocol.
with pytest.raises(ValueError) as excinfo:
monitors = list(OPDSCollectionMonitor.all(self._db, collections=[b1]))
- assert 'Collection protocol (Bibliotheca) does not match Monitor protocol (OPDS Import)' in str(excinfo.value)
- assert 'Only the following collections are available: ' in str(excinfo.value)
+ assert (
+ "Collection protocol (Bibliotheca) does not match Monitor protocol (OPDS Import)"
+ in str(excinfo.value)
+ )
+ assert "Only the following collections are available: " in str(excinfo.value)
class TestTimelineMonitor(DatabaseTest):
-
def test_run_once(self):
class Mock(TimelineMonitor):
SERVICE_NAME = "Just a timeline"
catchups = []
+
def catch_up_from(self, start, cutoff, progress):
self.catchups.append((start, cutoff, progress))
@@ -394,9 +412,11 @@ def test_subclass_cannot_modify_dates(self):
If you want that, you shouldn't subclass TimelineMonitor.
"""
+
class Mock(TimelineMonitor):
DEFAULT_START_TIME = Monitor.NEVER
SERVICE_NAME = "I aim to misbehave"
+
def catch_up_from(self, start, cutoff, progress):
progress.start = 1
progress.finish = 2
@@ -421,9 +441,11 @@ def test_timestamp_not_updated_on_exception(self):
"""If the subclass sets .exception on the TimestampData
passed into it, the dates aren't modified.
"""
+
class Mock(TimelineMonitor):
DEFAULT_START_TIME = datetime_utc(2011, 1, 1)
SERVICE_NAME = "doomed"
+
def catch_up_from(self, start, cutoff, progress):
self.started_at = start
progress.exception = "oops"
@@ -466,6 +488,7 @@ def test_slice_timespan(self):
class MockSweepMonitor(SweepMonitor):
"""A SweepMonitor that does nothing."""
+
MODEL_CLASS = Identifier
SERVICE_NAME = "Sweep Monitor"
DEFAULT_BATCH_SIZE = 2
@@ -491,7 +514,6 @@ def cleanup(self):
class TestSweepMonitor(DatabaseTest):
-
def setup_method(self):
super(TestSweepMonitor, self).setup_method()
self.monitor = MockSweepMonitor(self._db)
@@ -499,6 +521,7 @@ def setup_method(self):
def test_model_class_is_required(self):
class NoModelClass(SweepMonitor):
MODEL_CLASS = None
+
with pytest.raises(ValueError) as excinfo:
NoModelClass(self._db)
assert "NoModelClass must define MODEL_CLASS" in str(excinfo.value)
@@ -551,9 +574,10 @@ def test_run_starts_at_previous_counter(self):
# The monitor was just run, but it was not able to proceed past
# i1.
timestamp = Timestamp.stamp(
- self._db, self.monitor.service_name,
+ self._db,
+ self.monitor.service_name,
Timestamp.MONITOR_TYPE,
- self.monitor.collection
+ self.monitor.collection,
)
timestamp.counter = i1.id
@@ -619,7 +643,6 @@ def process_item(self, item):
class TestIdentifierSweepMonitor(DatabaseTest):
-
def test_scope_to_collection(self):
# Two Collections, each with a LicensePool.
c1 = self._collection()
@@ -644,9 +667,7 @@ class Mock(IdentifierSweepMonitor):
class TestSubjectSweepMonitor(DatabaseTest):
-
def test_item_query(self):
-
class Mock(SubjectSweepMonitor):
SERVICE_NAME = "Mock"
@@ -677,7 +698,6 @@ class Mock(SubjectSweepMonitor):
class TestCustomListEntrySweepMonitor(DatabaseTest):
-
def test_item_query(self):
class Mock(CustomListEntrySweepMonitor):
SERVICE_NAME = "Mock"
@@ -712,7 +732,6 @@ class Mock(CustomListEntrySweepMonitor):
class TestEditionSweepMonitor(DatabaseTest):
-
def test_item_query(self):
class Mock(EditionSweepMonitor):
SERVICE_NAME = "Mock"
@@ -783,6 +802,7 @@ class Mock(WorkSweepMonitor):
# works that are not presentation ready.
class Mock(PresentationReadyWorkSweepMonitor):
SERVICE_NAME = "Mock"
+
assert [w1, w4] == Mock(self._db).item_query().all()
assert [w1] == Mock(self._db, collection=c1).item_query().all()
assert [] == Mock(self._db, collection=c2).item_query().all()
@@ -791,18 +811,19 @@ class Mock(PresentationReadyWorkSweepMonitor):
# includes works that are not presentation ready.
class Mock(NotPresentationReadyWorkSweepMonitor):
SERVICE_NAME = "Mock"
+
assert [w2, w3] == Mock(self._db).item_query().all()
assert [] == Mock(self._db, collection=c1).item_query().all()
assert [w2] == Mock(self._db, collection=c2).item_query().all()
-
class TestOPDSEntryCacheMonitor(DatabaseTest):
-
def test_process_item(self):
"""This Monitor calculates OPDS entries for works."""
+
class Mock(OPDSEntryCacheMonitor):
SERVICE_NAME = "Mock"
+
monitor = Mock(self._db)
work = self._work()
assert None == work.simple_opds_entry
@@ -814,11 +835,12 @@ class Mock(OPDSEntryCacheMonitor):
class TestPermanentWorkIDRefresh(DatabaseTest):
-
def test_process_item(self):
"""This Monitor calculates an Editions' permanent work ID."""
+
class Mock(PermanentWorkIDRefreshMonitor):
SERVICE_NAME = "Mock"
+
edition = self._edition()
assert None == edition.permanent_work_id
Mock(self._db).process_item(edition)
@@ -826,7 +848,6 @@ class Mock(PermanentWorkIDRefreshMonitor):
class TestMakePresentationReadyMonitor(DatabaseTest):
-
def setup_method(self):
super(TestMakePresentationReadyMonitor, self).setup_method()
@@ -845,8 +866,7 @@ class MockProvider2(NeverSuccessfulCoverageProvider):
self.success = MockProvider1(self._db)
self.failure = MockProvider2(self._db)
- self.work = self._work(
- DataSource.GUTENBERG, with_license_pool=True)
+ self.work = self._work(DataSource.GUTENBERG, with_license_pool=True)
# Don't fake that the work is presentation ready, as we usually do,
# because presentation readiness is what we're trying to test.
self.work.presentation_ready = False
@@ -862,19 +882,16 @@ def test_process_item_sets_presentation_ready_on_success(self):
assert True == self.work.presentation_ready
def test_process_item_sets_exception_on_failure(self):
- monitor = MakePresentationReadyMonitor(
- self._db, [self.success, self.failure]
- )
+ monitor = MakePresentationReadyMonitor(self._db, [self.success, self.failure])
monitor.process_item(self.work)
assert (
- "Provider(s) failed: %s" % self.failure.SERVICE_NAME ==
- self.work.presentation_ready_exception)
+ "Provider(s) failed: %s" % self.failure.SERVICE_NAME
+ == self.work.presentation_ready_exception
+ )
assert False == self.work.presentation_ready
def test_prepare_raises_exception_with_failing_providers(self):
- monitor = MakePresentationReadyMonitor(
- self._db, [self.success, self.failure]
- )
+ monitor = MakePresentationReadyMonitor(self._db, [self.success, self.failure])
with pytest.raises(CoverageProvidersFailed) as excinfo:
monitor.prepare(self.work)
assert self.failure.service_name in str(excinfo.value)
@@ -888,8 +905,9 @@ def test_prepare_does_not_call_irrelevant_provider(self):
assert [] == result
# The 'success' monitor ran.
- assert ([self.work.presentation_edition.primary_identifier] ==
- self.success.attempts)
+ assert [
+ self.work.presentation_edition.primary_identifier
+ ] == self.success.attempts
# The 'failure' monitor did not. (If it had, it would have
# failed.)
@@ -901,7 +919,6 @@ def test_prepare_does_not_call_irrelevant_provider(self):
class TestCustomListEntryWorkUpdateMonitor(DatabaseTest):
-
def test_set_item(self):
# Create a CustomListEntry.
@@ -920,11 +937,10 @@ def test_set_item(self):
class MockReaperMonitor(ReaperMonitor):
MODEL_CLASS = Timestamp
- TIMESTAMP_FIELD = 'timestamp'
+ TIMESTAMP_FIELD = "timestamp"
class TestReaperMonitor(DatabaseTest):
-
def test_cutoff(self):
"""Test that cutoff behaves correctly when given different values for
ReaperMonitor.MAX_AGE.
@@ -934,9 +950,7 @@ def test_cutoff(self):
# A number here means a number of days.
for value in [1, 1.5, -1]:
m.MAX_AGE = value
- expect = utc_now() - datetime.timedelta(
- days=value
- )
+ expect = utc_now() - datetime.timedelta(days=value)
self.time_eq(m.cutoff, expect)
# But you can pass in a timedelta instead.
@@ -948,7 +962,9 @@ def test_specific_reapers(self):
assert 30 == CachedFeedReaper.MAX_AGE
assert Credential.expires == CredentialReaper(self._db).timestamp_field
assert 1 == CredentialReaper.MAX_AGE
- assert Patron.authorization_expires == PatronRecordReaper(self._db).timestamp_field
+ assert (
+ Patron.authorization_expires == PatronRecordReaper(self._db).timestamp_field
+ )
assert 60 == PatronRecordReaper.MAX_AGE
def test_where_clause(self):
@@ -960,16 +976,12 @@ def test_run_once(self):
expired1 = self._credential()
expired2 = self._credential()
now = utc_now()
- expiration_date = now - datetime.timedelta(
- days=CredentialReaper.MAX_AGE + 1
- )
+ expiration_date = now - datetime.timedelta(days=CredentialReaper.MAX_AGE + 1)
for e in [expired1, expired2]:
e.expires = expiration_date
active = self._credential()
- active.expires = now - datetime.timedelta(
- days=CredentialReaper.MAX_AGE - 1
- )
+ active.expires = now - datetime.timedelta(days=CredentialReaper.MAX_AGE - 1)
eternal = self._credential()
@@ -997,9 +1009,7 @@ def test_reap_patrons(self):
days=PatronRecordReaper.MAX_AGE + 1
)
active = self._patron()
- active.expires = now - datetime.timedelta(
- days=PatronRecordReaper.MAX_AGE - 1
- )
+ active.expires = now - datetime.timedelta(days=PatronRecordReaper.MAX_AGE - 1)
result = m.run_once()
assert "Items deleted: 1" == result.achievements
remaining = self._db.query(Patron).all()
@@ -1009,10 +1019,9 @@ def test_reap_patrons(self):
class TestWorkReaper(DatabaseTest):
-
def test_end_to_end(self):
# Search mock
- class MockSearchIndex():
+ class MockSearchIndex:
removed = []
def remove_work(self, work):
@@ -1059,15 +1068,13 @@ def remove_work(self, work):
# Each work has a CachedFeed.
for work in works:
feed = CachedFeed(
- work=work, type='page', content="content",
- pagination="", facets=""
+ work=work, type="page", content="content", pagination="", facets=""
)
self._db.add(feed)
# Also create a CachedFeed that has no associated Work.
workless_feed = CachedFeed(
- work=None, type='page', content="content",
- pagination="", facets=""
+ work=None, type="page", content="content", pagination="", facets=""
)
self._db.add(workless_feed)
@@ -1099,7 +1106,7 @@ def remove_work(self, work):
assert [has_license_pool] == genre.works
surviving_records = self._db.query(WorkCoverageRecord)
assert surviving_records.count() > 0
- assert all(x.work==has_license_pool for x in surviving_records)
+ assert all(x.work == has_license_pool for x in surviving_records)
# The CustomListEntries still exist, but two of them have lost
# their work.
@@ -1116,7 +1123,6 @@ def remove_work(self, work):
class TestCollectionReaper(DatabaseTest):
-
def test_query(self):
# This reaper is looking for collections that are marked for
# deletion.
@@ -1134,6 +1140,7 @@ def test_reaper_delete_calls_collection_delete(self):
class MockCollection(object):
def delete(self):
self.was_called = True
+
collection = MockCollection()
reaper = CollectionReaper(self._db)
reaper.delete(collection)
@@ -1154,12 +1161,11 @@ def test_run_once(self):
class TestMeasurementReaper(DatabaseTest):
-
def test_query(self):
# This reaper is looking for measurements that are not current.
measurement, created = get_one_or_create(
- self._db, Measurement,
- is_most_recent=True)
+ self._db, Measurement, is_most_recent=True
+ )
reaper = MeasurementReaper(self._db)
assert [] == reaper.query().all()
measurement.is_most_recent = False
@@ -1168,15 +1174,19 @@ def test_query(self):
def test_run_once(self):
# End-to-end test
measurement1, created = get_one_or_create(
- self._db, Measurement,
+ self._db,
+ Measurement,
quantity_measured="answer",
value=12,
- is_most_recent=True)
+ is_most_recent=True,
+ )
measurement2, created = get_one_or_create(
- self._db, Measurement,
+ self._db,
+ Measurement,
quantity_measured="answer",
value=42,
- is_most_recent=False)
+ is_most_recent=False,
+ )
reaper = MeasurementReaper(self._db)
result = reaper.run_once()
assert [measurement1] == self._db.query(Measurement).all()
@@ -1184,18 +1194,24 @@ def test_run_once(self):
def test_disable(self):
# This reaper can be disabled with a configuration setting
- enabled = ConfigurationSetting.sitewide(self._db, Configuration.MEASUREMENT_REAPER)
+ enabled = ConfigurationSetting.sitewide(
+ self._db, Configuration.MEASUREMENT_REAPER
+ )
enabled.value = False
measurement1, created = get_one_or_create(
- self._db, Measurement,
+ self._db,
+ Measurement,
quantity_measured="answer",
value=12,
- is_most_recent=True)
+ is_most_recent=True,
+ )
measurement2, created = get_one_or_create(
- self._db, Measurement,
+ self._db,
+ Measurement,
quantity_measured="answer",
value=42,
- is_most_recent=False)
+ is_most_recent=False,
+ )
reaper = MeasurementReaper(self._db)
reaper.run()
assert [measurement1, measurement2] == self._db.query(Measurement).all()
@@ -1205,7 +1221,6 @@ def test_disable(self):
class TestScrubberMonitor(DatabaseTest):
-
def test_run_once(self):
# ScrubberMonitor is basically an abstract class, with
# subclasses doing nothing but define missing constants. This
@@ -1218,22 +1233,14 @@ def test_run_once(self):
# CirculationEvents are only scrubbed if they have a location
# *and* are older than MAX_AGE.
now = utc_now()
- not_long_ago = (
- m.cutoff + datetime.timedelta(days=1)
- )
- long_ago = (
- m.cutoff - datetime.timedelta(days=1)
- )
+ not_long_ago = m.cutoff + datetime.timedelta(days=1)
+ long_ago = m.cutoff - datetime.timedelta(days=1)
- new, ignore = create(
- self._db, CirculationEvent, start=now, location="loc"
- )
+ new, ignore = create(self._db, CirculationEvent, start=now, location="loc")
recent, ignore = create(
self._db, CirculationEvent, start=not_long_ago, location="loc"
)
- old, ignore = create(
- self._db, CirculationEvent, start=long_ago, location="loc"
- )
+ old, ignore = create(self._db, CirculationEvent, start=long_ago, location="loc")
already_scrubbed, ignore = create(
self._db, CirculationEvent, start=long_ago, location=None
)
diff --git a/tests/test_opds.py b/tests/test_opds.py
index d29d5d5d6..d880cc4b1 100644
--- a/tests/test_opds.py
+++ b/tests/test_opds.py
@@ -9,9 +9,6 @@
from lxml import etree
from psycopg2.extras import NumericRange
-from ..testing import (
- DatabaseTest,
-)
from ..classifier import (
Classifier,
Contemporary_Romance,
@@ -19,10 +16,7 @@
Fantasy,
History,
)
-from ..config import (
- Configuration,
- temp_config,
-)
+from ..config import Configuration, temp_config
from ..entrypoint import (
AudiobooksEntryPoint,
EbooksEntryPoint,
@@ -31,13 +25,7 @@
)
from ..external_search import MockExternalSearchIndex
from ..facets import FacetConstants
-from ..lane import (
- Facets,
- FeaturedFacets,
- Pagination,
- SearchFacets,
- WorkList,
-)
+from ..lane import Facets, FeaturedFacets, Pagination, SearchFacets, WorkList
from ..lcp.credential import LCPCredentialFactory
from ..model import (
CachedFeed,
@@ -52,8 +40,8 @@
Representation,
Subject,
Work,
- get_one,
create,
+ get_one,
)
from ..opds import (
AcquisitionFeed,
@@ -61,45 +49,31 @@
LookupAcquisitionFeed,
NavigationFacets,
NavigationFeed,
- VerboseAnnotator,
TestAnnotator,
TestAnnotatorWithGroup,
- TestUnfulfillableAnnotator
+ TestUnfulfillableAnnotator,
+ VerboseAnnotator,
)
from ..opds_import import OPDSXMLParser
-from ..util.flask_util import (
- OPDSEntryResponse,
- OPDSFeedResponse,
- Response,
-)
-from ..util.opds_writer import (
- AtomFeed,
- OPDSFeed,
- OPDSMessage,
-)
+from ..testing import DatabaseTest
from ..util.datetime_helpers import datetime_utc, utc_now
+from ..util.flask_util import OPDSEntryResponse, OPDSFeedResponse, Response
+from ..util.opds_writer import AtomFeed, OPDSFeed, OPDSMessage
class TestBaseAnnotator(DatabaseTest):
-
def test_authors(self):
# Create an Edition with an author and a narrator.
edition = self._edition(authors=[])
- edition.add_contributor(
- "Steven King", Contributor.PRIMARY_AUTHOR_ROLE
- )
- edition.add_contributor(
- "Jonathan Frakes", Contributor.NARRATOR_ROLE
- )
+ edition.add_contributor("Steven King", Contributor.PRIMARY_AUTHOR_ROLE)
+ edition.add_contributor("Jonathan Frakes", Contributor.NARRATOR_ROLE)
author, contributor = sorted(
- Annotator.authors(None, edition),
- key=lambda x: x.tag
+ Annotator.authors(None, edition), key=lambda x: x.tag
)
-
# The tag indicates a role of 'author', so there's no
# need for an explicitly specified role property.
- assert 'author' == author.tag
+ assert "author" == author.tag
[name] = author
assert "name" == name.tag
assert "King, Steven" == name.text
@@ -107,17 +81,18 @@ def test_authors(self):
# The tag includes an explicitly specified role
# property to explain the nature of the contribution.
- assert 'contributor' == contributor.tag
+ assert "contributor" == contributor.tag
[name] = contributor
assert "name" == name.tag
assert "Frakes, Jonathan" == name.text
- role_attrib = '{%s}role' % AtomFeed.OPF_NS
- assert (Contributor.MARC_ROLE_CODES[Contributor.NARRATOR_ROLE] ==
- contributor.attrib[role_attrib])
+ role_attrib = "{%s}role" % AtomFeed.OPF_NS
+ assert (
+ Contributor.MARC_ROLE_CODES[Contributor.NARRATOR_ROLE]
+ == contributor.attrib[role_attrib]
+ )
def test_annotate_work_entry_adds_tags(self):
- work = self._work(with_license_pool=True,
- with_open_access_download=True)
+ work = self._work(with_license_pool=True, with_open_access_download=True)
work.last_update_time = datetime_utc(2018, 2, 5, 7, 39, 49, 580651)
[pool] = work.license_pools
pool.availability_time = datetime_utc(2015, 1, 1)
@@ -132,32 +107,37 @@ def test_annotate_work_entry_adds_tags(self):
[id, distributor, published, updated] = entry
id_tag = etree.tounicode(id)
- assert 'id' in id_tag
+ assert "id" in id_tag
assert pool.identifier.urn in id_tag
assert 'ProviderName="Gutenberg"' in etree.tounicode(distributor)
published_tag = etree.tounicode(published)
- assert 'published' in published_tag
- assert '2015-01-01' in published_tag
+ assert "published" in published_tag
+ assert "2015-01-01" in published_tag
updated_tag = etree.tounicode(updated)
- assert 'updated' in updated_tag
- assert '2018-02-05' in updated_tag
+ assert "updated" in updated_tag
+ assert "2018-02-05" in updated_tag
entry = []
# We can pass in a specific update time to override the one
# found in work.last_update_time.
annotator.annotate_work_entry(
- work, pool, None, None, None, entry,
- updated=datetime_utc(2017, 1, 2, 3, 39, 49, 580651)
+ work,
+ pool,
+ None,
+ None,
+ None,
+ entry,
+ updated=datetime_utc(2017, 1, 2, 3, 39, 49, 580651),
)
[id, distributor, published, updated] = entry
- assert 'updated' in etree.tounicode(updated)
- assert '2017-01-02' in etree.tounicode(updated)
+ assert "updated" in etree.tounicode(updated)
+ assert "2017-01-02" in etree.tounicode(updated)
-class TestAnnotators(DatabaseTest):
+class TestAnnotators(DatabaseTest):
def test_all_subjects(self):
self.work = self._work(genre="Fiction", with_open_access_download=True)
edition = self.work.presentation_edition
@@ -170,7 +150,13 @@ def test_all_subjects(self):
(source1, Subject.LCSH, "lcsh1", "name2", 1),
(source2, Subject.LCSH, "lcsh1", "name2", 1),
(source1, Subject.LCSH, "lcsh2", "name3", 3),
- (source1, Subject.DDC, "300", "Social sciences, sociology & anthropology", 1),
+ (
+ source1,
+ Subject.DDC,
+ "300",
+ "Social sciences, sociology & anthropology",
+ 1,
+ ),
]
for source, subject_type, subject, name, weight in subjects:
@@ -183,6 +169,7 @@ def mock_all_identifier_ids(policy=None):
# Do the actual work so that categories() gets the
# correct information.
return self.work.original_all_identifier_ids(policy)
+
self.work.original_all_identifier_ids = self.work.all_identifier_ids
self.work.all_identifier_ids = mock_all_identifier_ids
category_tags = VerboseAnnotator.categories(self.work)
@@ -195,23 +182,30 @@ def mock_all_identifier_ids(policy=None):
assert 100 == self.work.called_with_policy.equivalent_identifier_cutoff
ddc_uri = Subject.uri_lookup[Subject.DDC]
- rating_value = '{http://schema.org/}ratingValue'
- assert ([{'term': '300',
- rating_value: 1,
- 'label': 'Social sciences, sociology & anthropology'}] ==
- category_tags[ddc_uri])
+ rating_value = "{http://schema.org/}ratingValue"
+ assert [
+ {
+ "term": "300",
+ rating_value: 1,
+ "label": "Social sciences, sociology & anthropology",
+ }
+ ] == category_tags[ddc_uri]
fast_uri = Subject.uri_lookup[Subject.FAST]
- assert ([{'term': 'fast1', 'label': 'name1', rating_value: 1}] ==
- category_tags[fast_uri])
+ assert [{"term": "fast1", "label": "name1", rating_value: 1}] == category_tags[
+ fast_uri
+ ]
lcsh_uri = Subject.uri_lookup[Subject.LCSH]
- assert ([{'term': 'lcsh1', 'label': 'name2', rating_value: 2},
- {'term': 'lcsh2', 'label': 'name3', rating_value: 3}] ==
- sorted(category_tags[lcsh_uri], key=lambda x: x[rating_value]))
+ assert [
+ {"term": "lcsh1", "label": "name2", rating_value: 2},
+ {"term": "lcsh2", "label": "name3", rating_value: 3},
+ ] == sorted(category_tags[lcsh_uri], key=lambda x: x[rating_value])
genre_uri = Subject.uri_lookup[Subject.SIMPLIFIED_GENRE]
- assert [dict(label='Fiction', term=Subject.SIMPLIFIED_GENRE+"Fiction")] == category_tags[genre_uri]
+ assert [
+ dict(label="Fiction", term=Subject.SIMPLIFIED_GENRE + "Fiction")
+ ] == category_tags[genre_uri]
def test_appeals(self):
work = self._work(with_open_access_download=True)
@@ -227,10 +221,10 @@ def test_appeals(self):
(Work.APPEALS_URI + Work.LANGUAGE_APPEAL, Work.LANGUAGE_APPEAL, 0.1),
(Work.APPEALS_URI + Work.CHARACTER_APPEAL, Work.CHARACTER_APPEAL, 0.2),
(Work.APPEALS_URI + Work.STORY_APPEAL, Work.STORY_APPEAL, 0.3),
- (Work.APPEALS_URI + Work.SETTING_APPEAL, Work.SETTING_APPEAL, 0.4)
+ (Work.APPEALS_URI + Work.SETTING_APPEAL, Work.SETTING_APPEAL, 0.4),
]
actual = [
- (x['term'], x['label'], x['{http://schema.org/}ratingValue'])
+ (x["term"], x["label"], x["{http://schema.org/}ratingValue"])
for x in appeal_tags
]
assert set(expect) == set(actual)
@@ -248,7 +242,9 @@ def test_detailed_author(self):
tag_string = etree.tounicode(author_tag)
assert "Givenname Familyname" in tag_string
assert "Familyname, Givenname" in tag_string
- assert "Givenname Familyname (Author)" in tag_string
+ assert (
+ "Givenname Familyname (Author)" in tag_string
+ )
assert "http://viaf.org/viaf/100" in tag_string
assert "http://id.loc.gov/authorities/names/n100"
@@ -278,9 +274,7 @@ def test_all_annotators_mention_every_relevant_author(self):
illustrator, ignore = self._contributor()
barrel_washer, ignore = self._contributor()
- edition.add_contributor(
- primary_author, Contributor.PRIMARY_AUTHOR_ROLE
- )
+ edition.add_contributor(primary_author, Contributor.PRIMARY_AUTHOR_ROLE)
edition.add_contributor(author, Contributor.AUTHOR_ROLE)
# This contributor is relevant because we have a MARC Role Code
@@ -291,41 +285,40 @@ def test_all_annotators_mention_every_relevant_author(self):
# Role Code for the role.
edition.add_contributor(barrel_washer, "Barrel Washer")
- role_attrib = '{%s}role' % AtomFeed.OPF_NS
- illustrator_code = Contributor.MARC_ROLE_CODES[
- Contributor.ILLUSTRATOR_ROLE
- ]
+ role_attrib = "{%s}role" % AtomFeed.OPF_NS
+ illustrator_code = Contributor.MARC_ROLE_CODES[Contributor.ILLUSTRATOR_ROLE]
for annotator in Annotator, VerboseAnnotator:
tags = Annotator.authors(work, edition)
# We made two tags and one
# tag, for the illustrator.
- assert (['author', 'author', 'contributor'] ==
- [x.tag for x in tags])
- assert ([None, None, illustrator_code] ==
- [x.attrib.get(role_attrib) for x in tags])
+ assert ["author", "author", "contributor"] == [x.tag for x in tags]
+ assert [None, None, illustrator_code] == [
+ x.attrib.get(role_attrib) for x in tags
+ ]
def test_ratings(self):
- work = self._work(
- with_license_pool=True, with_open_access_download=True)
- work.quality = 1.0/3
+ work = self._work(with_license_pool=True, with_open_access_download=True)
+ work.quality = 1.0 / 3
work.popularity = 0.25
work.rating = 0.6
work.calculate_opds_entries(verbose=True)
- feed = AcquisitionFeed(
- self._db, self._str, self._url, [work], VerboseAnnotator
- )
+ feed = AcquisitionFeed(self._db, self._str, self._url, [work], VerboseAnnotator)
url = self._url
tag = feed.create_entry(work, None)
- nsmap = dict(schema='http://schema.org/')
- ratings = [(rating.get('{http://schema.org/}ratingValue'),
- rating.get('{http://schema.org/}additionalType'))
- for rating in tag.xpath("schema:Rating", namespaces=nsmap)]
+ nsmap = dict(schema="http://schema.org/")
+ ratings = [
+ (
+ rating.get("{http://schema.org/}ratingValue"),
+ rating.get("{http://schema.org/}additionalType"),
+ )
+ for rating in tag.xpath("schema:Rating", namespaces=nsmap)
+ ]
expected = [
- ('0.3333', Measurement.QUALITY),
- ('0.2500', Measurement.POPULARITY),
- ('0.6000', None)
+ ("0.3333", Measurement.QUALITY),
+ ("0.2500", Measurement.POPULARITY),
+ ("0.6000", None),
]
assert set(expected) == set(ratings)
@@ -334,27 +327,27 @@ def test_subtitle(self):
work.presentation_edition.subtitle = "Return of the Jedi"
work.calculate_opds_entries()
- raw_feed = str(AcquisitionFeed(
- self._db, self._str, self._url, [work], Annotator
- ))
+ raw_feed = str(
+ AcquisitionFeed(self._db, self._str, self._url, [work], Annotator)
+ )
assert "schema:alternativeHeadline" in raw_feed
assert work.presentation_edition.subtitle in raw_feed
feed = feedparser.parse(str(raw_feed))
- alternative_headline = feed['entries'][0]['schema_alternativeheadline']
+ alternative_headline = feed["entries"][0]["schema_alternativeheadline"]
assert work.presentation_edition.subtitle == alternative_headline
# If there's no subtitle, the subtitle tag isn't included.
work.presentation_edition.subtitle = None
work.calculate_opds_entries()
- raw_feed = str(AcquisitionFeed(
- self._db, self._str, self._url, [work], Annotator
- ))
+ raw_feed = str(
+ AcquisitionFeed(self._db, self._str, self._url, [work], Annotator)
+ )
assert "schema:alternativeHeadline" not in raw_feed
assert "Return of the Jedi" not in raw_feed
- [entry] = feedparser.parse(str(raw_feed))['entries']
- assert 'schema_alternativeheadline' not in list(entry.items())
+ [entry] = feedparser.parse(str(raw_feed))["entries"]
+ assert "schema_alternativeheadline" not in list(entry.items())
def test_series(self):
work = self._work(with_license_pool=True, with_open_access_download=True)
@@ -362,55 +355,63 @@ def test_series(self):
work.presentation_edition.series_position = 4
work.calculate_opds_entries()
- raw_feed = str(AcquisitionFeed(
- self._db, self._str, self._url, [work], Annotator
- ))
+ raw_feed = str(
+ AcquisitionFeed(self._db, self._str, self._url, [work], Annotator)
+ )
assert "schema:Series" in raw_feed
assert work.presentation_edition.series in raw_feed
feed = feedparser.parse(str(raw_feed))
- schema_entry = feed['entries'][0]['schema_series']
- assert work.presentation_edition.series == schema_entry['name']
- assert str(work.presentation_edition.series_position) == schema_entry['schema:position']
+ schema_entry = feed["entries"][0]["schema_series"]
+ assert work.presentation_edition.series == schema_entry["name"]
+ assert (
+ str(work.presentation_edition.series_position)
+ == schema_entry["schema:position"]
+ )
# The series position can be 0, for a prequel for example.
work.presentation_edition.series_position = 0
work.calculate_opds_entries()
- raw_feed = str(AcquisitionFeed(
- self._db, self._str, self._url, [work], Annotator
- ))
+ raw_feed = str(
+ AcquisitionFeed(self._db, self._str, self._url, [work], Annotator)
+ )
assert "schema:Series" in raw_feed
assert work.presentation_edition.series in raw_feed
feed = feedparser.parse(str(raw_feed))
- schema_entry = feed['entries'][0]['schema_series']
- assert work.presentation_edition.series == schema_entry['name']
- assert str(work.presentation_edition.series_position) == schema_entry['schema:position']
+ schema_entry = feed["entries"][0]["schema_series"]
+ assert work.presentation_edition.series == schema_entry["name"]
+ assert (
+ str(work.presentation_edition.series_position)
+ == schema_entry["schema:position"]
+ )
# If there's no series title, the series tag isn't included.
work.presentation_edition.series = None
work.calculate_opds_entries()
- raw_feed = str(AcquisitionFeed(
- self._db, self._str, self._url, [work], Annotator
- ))
+ raw_feed = str(
+ AcquisitionFeed(self._db, self._str, self._url, [work], Annotator)
+ )
assert "schema:Series" not in raw_feed
assert "Lifetime of Despair" not in raw_feed
- [entry] = feedparser.parse(str(raw_feed))['entries']
- assert 'schema_series' not in list(entry.items())
+ [entry] = feedparser.parse(str(raw_feed))["entries"]
+ assert "schema_series" not in list(entry.items())
class TestOPDS(DatabaseTest):
-
def links(self, entry, rel=None):
- if 'feed' in entry:
- entry = entry['feed']
- links = sorted(entry['links'], key=lambda x: (x['rel'], x.get('title')))
+ if "feed" in entry:
+ entry = entry["feed"]
+ links = sorted(entry["links"], key=lambda x: (x["rel"], x.get("title")))
r = []
for l in links:
- if (not rel or l['rel'] == rel or
- (isinstance(rel, list) and l['rel'] in rel)):
+ if (
+ not rel
+ or l["rel"] == rel
+ or (isinstance(rel, list) and l["rel"] in rel)
+ ):
r.append(l)
return r
@@ -421,34 +422,29 @@ def setup_method(self):
self.fiction.fiction = True
self.fiction.audiences = [Classifier.AUDIENCE_ADULT]
- self.fantasy = self._lane(
- "Fantasy", parent=self.fiction, genres="Fantasy"
- )
- self.history = self._lane(
- "History", genres="History"
- )
+ self.fantasy = self._lane("Fantasy", parent=self.fiction, genres="Fantasy")
+ self.history = self._lane("History", genres="History")
self.ya = self._lane("Young Adult")
self.ya.history = None
self.ya.audiences = [Classifier.AUDIENCE_YOUNG_ADULT]
self.romance = self._lane("Romance", genres="Romance")
self.romance.fiction = True
self.contemporary_romance = self._lane(
- "Contemporary Romance", parent=self.romance,
- genres="Contemporary Romance"
+ "Contemporary Romance", parent=self.romance, genres="Contemporary Romance"
)
self.conf = WorkList()
self.conf.initialize(
self._default_library,
- children=[self.fiction, self.fantasy, self.history, self.ya,
- self.romance]
+ children=[self.fiction, self.fantasy, self.history, self.ya, self.romance],
)
def _assert_xml_equal(self, a, b):
# Compare xml is the same, we use etree to canonicalize the xml
# then compare the canonical versions
- assert etree.tostring(a, method="c14n2") == \
- etree.tostring(etree.fromstring(b), method="c14n2")
+ assert etree.tostring(a, method="c14n2") == etree.tostring(
+ etree.fromstring(b), method="c14n2"
+ )
def test_acquisition_link(self):
m = AcquisitionFeed.acquisition_link
@@ -461,22 +457,30 @@ def test_acquisition_link(self):
a,
'' % href
+ 'type="application/pdf"/>' % href,
)
# A direct acquisition link.
b = m(rel, href, ["application/epub"])
self._assert_xml_equal(
b,
- '' % href,
+ ''
+ % href,
)
# A direct acquisition link to a document with embedded access restriction rules.
- c = m(rel, href, ['application/audiobook+json;profile=http://www.feedbooks.com/audiobooks/access-restriction'])
+ c = m(
+ rel,
+ href,
+ [
+ "application/audiobook+json;profile=http://www.feedbooks.com/audiobooks/access-restriction"
+ ],
+ )
self._assert_xml_equal(
c,
'' % href
+ 'type="application/audiobook+json;profile=http://www.feedbooks.com/audiobooks/access-restriction"/>'
+ % href,
)
def test_group_uri(self):
@@ -484,32 +488,30 @@ def test_group_uri(self):
[lp] = work.license_pools
annotator = TestAnnotatorWithGroup()
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- [work], annotator)
+ feed = AcquisitionFeed(
+ self._db, "test", "http://the-url.com/", [work], annotator
+ )
u = str(feed)
parsed = feedparser.parse(u)
- [group_link] = parsed.entries[0]['links']
- expect_uri, expect_title = annotator.group_uri(
- work, lp, lp.identifier)
- assert OPDSFeed.GROUP_REL == group_link['rel']
- assert expect_uri == group_link['href']
- assert expect_title == group_link['title']
+ [group_link] = parsed.entries[0]["links"]
+ expect_uri, expect_title = annotator.group_uri(work, lp, lp.identifier)
+ assert OPDSFeed.GROUP_REL == group_link["rel"]
+ assert expect_uri == group_link["href"]
+ assert expect_title == group_link["title"]
def test_acquisition_feed(self):
work = self._work(with_open_access_download=True, authors="Alice")
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- [work])
+ feed = AcquisitionFeed(self._db, "test", "http://the-url.com/", [work])
u = str(feed)
assert '' in u
parsed = feedparser.parse(u)
- [with_author] = parsed['entries']
- assert "Alice" == with_author['authors'][0]['name']
+ [with_author] = parsed["entries"]
+ assert "Alice" == with_author["authors"][0]["name"]
def test_acquisition_feed_includes_license_source(self):
work = self._work(with_open_access_download=True)
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- [work])
+ feed = AcquisitionFeed(self._db, "test", "http://the-url.com/", [work])
gutenberg = DataSource.lookup(self._db, DataSource.GUTENBERG)
# The tag containing the license
@@ -526,26 +528,22 @@ def test_acquisition_feed_includes_license_source(self):
# included.
internal = DataSource.lookup(self._db, DataSource.INTERNAL_PROCESSING)
work.license_pools[0].data_source = internal
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- [work])
- assert '" in u
def test_acquisition_feed_includes_permanent_work_id(self):
work = self._work(with_open_access_download=True)
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- [work])
+ feed = AcquisitionFeed(self._db, "test", "http://the-url.com/", [work])
u = str(feed)
parsed = feedparser.parse(u)
- entry = parsed['entries'][0]
- assert (work.presentation_edition.permanent_work_id ==
- entry['simplified_pwid'])
+ entry = parsed["entries"][0]
+ assert work.presentation_edition.permanent_work_id == entry["simplified_pwid"]
def test_lcp_acquisition_link_contains_hashed_passphrase(self):
# Arrange
@@ -553,8 +551,9 @@ def test_lcp_acquisition_link_contains_hashed_passphrase(self):
data_source = DataSource.lookup(self._db, DataSource.LCP, autocreate=True)
data_source_name = data_source.name
license_pool = self._licensepool(
- edition=None, data_source_name=data_source_name, collection=lcp_collection)
- hashed_passphrase = '12345'
+ edition=None, data_source_name=data_source_name, collection=lcp_collection
+ )
+ hashed_passphrase = "12345"
patron = self._patron()
lcp_credential_factory = LCPCredentialFactory()
loan, _ = license_pool.loan_to(patron)
@@ -566,10 +565,13 @@ def test_lcp_acquisition_link_contains_hashed_passphrase(self):
'type="application/vnd.readium.lcp.license.v1.0+json">'
'{1}'
''
- '').format(href, hashed_passphrase)
+ ""
+ ).format(href, hashed_passphrase)
# Act
- lcp_credential_factory.set_hashed_passphrase(self._db, patron, hashed_passphrase)
+ lcp_credential_factory.set_hashed_passphrase(
+ self._db, patron, hashed_passphrase
+ )
acquisition_link = AcquisitionFeed.acquisition_link(rel, href, types, loan)
# Assert
@@ -582,31 +584,26 @@ def test_lane_feed_contains_facet_links(self):
facets = Facets.default(self._default_library)
cached_feed = AcquisitionFeed.page(
- self._db, "title", "http://the-url.com/",
- lane, TestAnnotator, facets=facets
+ self._db, "title", "http://the-url.com/", lane, TestAnnotator, facets=facets
)
u = str(cached_feed)
parsed = feedparser.parse(u)
- by_title = parsed['feed']
+ by_title = parsed["feed"]
- [self_link] = self.links(by_title, 'self')
- assert "http://the-url.com/" == self_link['href']
+ [self_link] = self.links(by_title, "self")
+ assert "http://the-url.com/" == self_link["href"]
facet_links = self.links(by_title, AcquisitionFeed.FACET_REL)
library = self._default_library
- order_facets = library.enabled_facets(
- Facets.ORDER_FACET_GROUP_NAME
- )
+ order_facets = library.enabled_facets(Facets.ORDER_FACET_GROUP_NAME)
availability_facets = library.enabled_facets(
Facets.AVAILABILITY_FACET_GROUP_NAME
)
- collection_facets = library.enabled_facets(
- Facets.COLLECTION_FACET_GROUP_NAME
- )
+ collection_facets = library.enabled_facets(Facets.COLLECTION_FACET_GROUP_NAME)
def link_for_facets(facets):
- return [x for x in facet_links if facets.query_string in x['href']]
+ return [x for x in facet_links if facets.query_string in x["href"]]
facets = Facets(library, None, None, None)
for i1, i2, new_facets, selected in facets.facet_groups:
@@ -623,10 +620,10 @@ def link_for_facets(facets):
# As we'll see below, the feed parser parses facetGroup as
# facetgroup and activeFacet as activefacet. As we see here,
# that's not a problem with the generator code.
- assert 'opds:facetgroup' not in u
- assert 'opds:facetGroup' in u
- assert 'opds:activefacet' not in u
- assert 'opds:activeFacet' in u
+ assert "opds:facetgroup" not in u
+ assert "opds:facetGroup" in u
+ assert "opds:activefacet" not in u
+ assert "opds:activeFacet" in u
def test_acquisition_feed_includes_available_and_issued_tag(self):
today = datetime.date.today()
@@ -668,20 +665,19 @@ def test_acquisition_feed_includes_available_and_issued_tag(self):
self._db.commit()
works = self._db.query(Work)
- with_times = AcquisitionFeed(
- self._db, "test", "url", works, TestAnnotator)
+ with_times = AcquisitionFeed(self._db, "test", "url", works, TestAnnotator)
u = str(with_times)
- assert 'dcterms:issued' in u
+ assert "dcterms:issued" in u
with_times = etree.parse(StringIO(u))
- entries = OPDSXMLParser._xpath(with_times, '/atom:feed/atom:entry')
+ entries = OPDSXMLParser._xpath(with_times, "/atom:feed/atom:entry")
parsed = []
for entry in entries:
- title = OPDSXMLParser._xpath1(entry, 'atom:title').text
- issued = OPDSXMLParser._xpath1(entry, 'dcterms:issued')
+ title = OPDSXMLParser._xpath1(entry, "atom:title").text
+ issued = OPDSXMLParser._xpath1(entry, "dcterms:issued")
if issued != None:
issued = issued.text
- published = OPDSXMLParser._xpath1(entry, 'atom:published')
+ published = OPDSXMLParser._xpath1(entry, "atom:published")
if published != None:
published = published.text
parsed.append(
@@ -691,20 +687,18 @@ def test_acquisition_feed_includes_available_and_issued_tag(self):
published=published,
)
)
- e1, e2, e3, e4 = sorted(
- parsed, key = lambda x: x['title']
- )
- assert today_s == e1['issued']
- assert the_distant_past_s == e1['published']
+ e1, e2, e3, e4 = sorted(parsed, key=lambda x: x["title"])
+ assert today_s == e1["issued"]
+ assert the_distant_past_s == e1["published"]
- assert the_past_s == e2['issued']
- assert the_distant_past_s == e2['published']
+ assert the_past_s == e2["issued"]
+ assert the_distant_past_s == e2["published"]
- assert None == e3['issued']
- assert None == e3['published']
+ assert None == e3["issued"]
+ assert None == e3["published"]
- assert None == e4['issued']
- assert None == e4['published']
+ assert None == e4["issued"]
+ assert None == e4["published"]
def test_acquisition_feed_includes_publisher_and_imprint_tag(self):
work = self._work(with_open_access_download=True)
@@ -718,20 +712,19 @@ def test_acquisition_feed_includes_publisher_and_imprint_tag(self):
w.calculate_opds_entries(verbose=False)
works = self._db.query(Work)
- with_publisher = AcquisitionFeed(
- self._db, "test", "url", works, TestAnnotator)
+ with_publisher = AcquisitionFeed(self._db, "test", "url", works, TestAnnotator)
with_publisher = feedparser.parse(str(with_publisher))
- entries = sorted(with_publisher['entries'], key = lambda x: x['title'])
- assert 'The Publisher' == entries[0]['dcterms_publisher']
- assert 'The Imprint' == entries[0]['bib_publisherimprint']
- assert 'publisher' not in entries[1]
+ entries = sorted(with_publisher["entries"], key=lambda x: x["title"])
+ assert "The Publisher" == entries[0]["dcterms_publisher"]
+ assert "The Imprint" == entries[0]["bib_publisherimprint"]
+ assert "publisher" not in entries[1]
def test_acquisition_feed_includes_audience_as_category(self):
work = self._work(with_open_access_download=True)
work.audience = "Young Adult"
work2 = self._work(with_open_access_download=True)
work2.audience = "Children"
- work2.target_age = NumericRange(7, 9, '[]')
+ work2.target_age = NumericRange(7, 9, "[]")
work3 = self._work(with_open_access_download=True)
work3.audience = None
work4 = self._work(with_open_access_download=True)
@@ -747,35 +740,37 @@ def test_acquisition_feed_includes_audience_as_category(self):
with_audience = AcquisitionFeed(self._db, "test", "url", works)
u = str(with_audience)
with_audience = feedparser.parse(u)
- ya, children, no_audience, adult = sorted(with_audience['entries'], key = lambda x: int(x['title']))
+ ya, children, no_audience, adult = sorted(
+ with_audience["entries"], key=lambda x: int(x["title"])
+ )
scheme = "http://schema.org/audience"
- assert (
- [('Young Adult', 'Young Adult')] ==
- [(x['term'], x['label']) for x in ya['tags']
- if x['scheme'] == scheme])
+ assert [("Young Adult", "Young Adult")] == [
+ (x["term"], x["label"]) for x in ya["tags"] if x["scheme"] == scheme
+ ]
- assert (
- [('Children', 'Children')] ==
- [(x['term'], x['label']) for x in children['tags']
- if x['scheme'] == scheme])
+ assert [("Children", "Children")] == [
+ (x["term"], x["label"]) for x in children["tags"] if x["scheme"] == scheme
+ ]
age_scheme = Subject.uri_lookup[Subject.AGE_RANGE]
- assert (
- [('7-9', '7-9')] ==
- [(x['term'], x['label']) for x in children['tags']
- if x['scheme'] == age_scheme])
+ assert [("7-9", "7-9")] == [
+ (x["term"], x["label"])
+ for x in children["tags"]
+ if x["scheme"] == age_scheme
+ ]
- assert ([] ==
- [(x['term'], x['label']) for x in no_audience['tags']
- if x['scheme'] == scheme])
+ assert [] == [
+ (x["term"], x["label"])
+ for x in no_audience["tags"]
+ if x["scheme"] == scheme
+ ]
# Even though the 'Adult' book has a target age, the target
# age is not shown, because target age is only a relevant
# concept for children's and YA books.
- assert (
- [] ==
- [(x['term'], x['label']) for x in adult['tags']
- if x['scheme'] == age_scheme])
+ assert [] == [
+ (x["term"], x["label"]) for x in adult["tags"] if x["scheme"] == age_scheme
+ ]
def test_acquisition_feed_includes_category_tags_for_appeals(self):
work = self._work(with_open_access_download=True)
@@ -793,20 +788,23 @@ def test_acquisition_feed_includes_category_tags_for_appeals(self):
works = self._db.query(Work)
feed = AcquisitionFeed(self._db, "test", "url", works)
feed = feedparser.parse(str(feed))
- entries = sorted(feed['entries'], key = lambda x: int(x['title']))
-
- tags = entries[0]['tags']
- matches = [(x['term'], x['label']) for x in tags if x['scheme'] == Work.APPEALS_URI]
- assert ([
- (Work.APPEALS_URI + 'Character', 'Character'),
- (Work.APPEALS_URI + 'Language', 'Language'),
- (Work.APPEALS_URI + 'Setting', 'Setting'),
- (Work.APPEALS_URI + 'Story', 'Story'),
- ] ==
- sorted(matches))
-
- tags = entries[1]['tags']
- matches = [(x['term'], x['label']) for x in tags if x['scheme'] == Work.APPEALS_URI]
+ entries = sorted(feed["entries"], key=lambda x: int(x["title"]))
+
+ tags = entries[0]["tags"]
+ matches = [
+ (x["term"], x["label"]) for x in tags if x["scheme"] == Work.APPEALS_URI
+ ]
+ assert [
+ (Work.APPEALS_URI + "Character", "Character"),
+ (Work.APPEALS_URI + "Language", "Language"),
+ (Work.APPEALS_URI + "Setting", "Setting"),
+ (Work.APPEALS_URI + "Story", "Story"),
+ ] == sorted(matches)
+
+ tags = entries[1]["tags"]
+ matches = [
+ (x["term"], x["label"]) for x in tags if x["scheme"] == Work.APPEALS_URI
+ ]
assert [] == matches
def test_acquisition_feed_includes_category_tags_for_fiction_status(self):
@@ -823,17 +821,16 @@ def test_acquisition_feed_includes_category_tags_for_fiction_status(self):
works = self._db.query(Work)
feed = AcquisitionFeed(self._db, "test", "url", works)
feed = feedparser.parse(str(feed))
- entries = sorted(feed['entries'], key = lambda x: int(x['title']))
+ entries = sorted(feed["entries"], key=lambda x: int(x["title"]))
scheme = "http://librarysimplified.org/terms/fiction/"
- assert ([(scheme+'Nonfiction', 'Nonfiction')] ==
- [(x['term'], x['label']) for x in entries[0]['tags']
- if x['scheme'] == scheme])
- assert ([(scheme+'Fiction', 'Fiction')] ==
- [(x['term'], x['label']) for x in entries[1]['tags']
- if x['scheme'] == scheme])
-
+ assert [(scheme + "Nonfiction", "Nonfiction")] == [
+ (x["term"], x["label"]) for x in entries[0]["tags"] if x["scheme"] == scheme
+ ]
+ assert [(scheme + "Fiction", "Fiction")] == [
+ (x["term"], x["label"]) for x in entries[1]["tags"] if x["scheme"] == scheme
+ ]
def test_acquisition_feed_includes_category_tags_for_genres(self):
work = self._work(with_open_access_download=True)
@@ -847,16 +844,19 @@ def test_acquisition_feed_includes_category_tags_for_genres(self):
works = self._db.query(Work)
feed = AcquisitionFeed(self._db, "test", "url", works)
feed = feedparser.parse(str(feed))
- entries = sorted(feed['entries'], key = lambda x: int(x['title']))
+ entries = sorted(feed["entries"], key=lambda x: int(x["title"]))
scheme = Subject.SIMPLIFIED_GENRE
- assert (
- [(scheme+'Romance', 'Romance'),
- (scheme+'Science%20Fiction', 'Science Fiction')] ==
- sorted(
- [(x['term'], x['label']) for x in entries[0]['tags']
- if x['scheme'] == scheme]
- ))
+ assert [
+ (scheme + "Romance", "Romance"),
+ (scheme + "Science%20Fiction", "Science Fiction"),
+ ] == sorted(
+ [
+ (x["term"], x["label"])
+ for x in entries[0]["tags"]
+ if x["scheme"] == scheme
+ ]
+ )
def test_acquisition_feed_omits_works_with_no_active_license_pool(self):
work = self._work(title="open access", with_open_access_download=True)
@@ -876,13 +876,15 @@ def test_acquisition_feed_omits_works_with_no_active_license_pool(self):
by_title = feedparser.parse(by_title_raw)
# We have two entries...
- assert 2 == len(by_title['entries'])
+ assert 2 == len(by_title["entries"])
assert ["not open access", "open access"] == sorted(
- [x['title'] for x in by_title['entries']])
+ [x["title"] for x in by_title["entries"]]
+ )
# ...and two messages.
- assert (2 ==
- by_title_raw.count("I've heard about this work but have no active licenses for it."))
+ assert 2 == by_title_raw.count(
+ "I've heard about this work but have no active licenses for it."
+ )
def test_acquisition_feed_includes_image_links(self):
work = self._work(genre=Fantasy, with_open_access_download=True)
@@ -891,9 +893,10 @@ def test_acquisition_feed_includes_image_links(self):
work.calculate_opds_entries(verbose=False)
feed = feedparser.parse(str(work.simple_opds_entry))
- links = sorted([x['href'] for x in feed['entries'][0]['links'] if
- 'image' in x['rel']])
- assert ['http://full/a', 'http://thumbnail/b'] == links
+ links = sorted(
+ [x["href"] for x in feed["entries"][0]["links"] if "image" in x["rel"]]
+ )
+ assert ["http://full/a", "http://thumbnail/b"] == links
def test_acquisition_feed_image_links_respect_cdn(self):
work = self._work(genre=Fantasy, with_open_access_download=True)
@@ -903,16 +906,17 @@ def test_acquisition_feed_image_links_respect_cdn(self):
# Create some CDNS.
with temp_config() as config:
config[Configuration.INTEGRATIONS][ExternalIntegration.CDN] = {
- 'thumbnail.com' : 'http://foo/',
- 'full.com' : 'http://bar/'
+ "thumbnail.com": "http://foo/",
+ "full.com": "http://bar/",
}
config[Configuration.CDNS_LOADED_FROM_DATABASE] = True
work.calculate_opds_entries(verbose=False)
feed = feedparser.parse(work.simple_opds_entry)
- links = sorted([x['href'] for x in feed['entries'][0]['links'] if
- 'image' in x['rel']])
- assert ['http://bar/a', 'http://foo/b'] == links
+ links = sorted(
+ [x["href"] for x in feed["entries"][0]["links"] if "image" in x["rel"]]
+ )
+ assert ["http://bar/a", "http://foo/b"] == links
def test_messages(self):
"""Test the ability to include OPDSMessage objects for a given URN in
@@ -922,8 +926,9 @@ def test_messages(self):
OPDSMessage("urn:foo", 400, _("msg1")),
OPDSMessage("urn:bar", 500, _("msg2")),
]
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- [], precomposed_entries=messages)
+ feed = AcquisitionFeed(
+ self._db, "test", "http://the-url.com/", [], precomposed_entries=messages
+ )
feed = str(feed)
for m in messages:
assert m.urn in feed
@@ -935,11 +940,16 @@ def test_precomposed_entries(self):
# in a feed.
entry = AcquisitionFeed.E.entry()
- entry.text='foo'
- feed = AcquisitionFeed(self._db, "test", "http://the-url.com/",
- works=[], precomposed_entries=[entry])
+ entry.text = "foo"
+ feed = AcquisitionFeed(
+ self._db,
+ "test",
+ "http://the-url.com/",
+ works=[],
+ precomposed_entries=[entry],
+ )
feed = str(feed)
- assert 'foo' in feed
+ assert "foo" in feed
def test_page_feed(self):
# Test the ability to create a paginated feed of works for a given
@@ -956,34 +966,43 @@ def test_page_feed(self):
def make_page(pagination):
return AcquisitionFeed.page(
- self._db, "test", self._url, lane, TestAnnotator,
- pagination=pagination, search_engine=search_engine
+ self._db,
+ "test",
+ self._url,
+ lane,
+ TestAnnotator,
+ pagination=pagination,
+ search_engine=search_engine,
)
+
cached_works = str(make_page(pagination))
parsed = feedparser.parse(cached_works)
- assert work1.title == parsed['entries'][0]['title']
+ assert work1.title == parsed["entries"][0]["title"]
# Make sure the links are in place.
- [up_link] = self.links(parsed, 'up')
- assert TestAnnotator.groups_url(lane.parent) == up_link['href']
- assert lane.parent.display_name == up_link['title']
+ [up_link] = self.links(parsed, "up")
+ assert TestAnnotator.groups_url(lane.parent) == up_link["href"]
+ assert lane.parent.display_name == up_link["title"]
- [start] = self.links(parsed, 'start')
- assert TestAnnotator.groups_url(None) == start['href']
- assert TestAnnotator.top_level_title() == start['title']
+ [start] = self.links(parsed, "start")
+ assert TestAnnotator.groups_url(None) == start["href"]
+ assert TestAnnotator.top_level_title() == start["title"]
- [next_link] = self.links(parsed, 'next')
- assert TestAnnotator.feed_url(lane, facets, pagination.next_page) == next_link['href']
+ [next_link] = self.links(parsed, "next")
+ assert (
+ TestAnnotator.feed_url(lane, facets, pagination.next_page)
+ == next_link["href"]
+ )
# This was the first page, so no previous link.
- assert [] == self.links(parsed, 'previous')
+ assert [] == self.links(parsed, "previous")
# Now get the second page and make sure it has a 'previous' link.
cached_works = str(make_page(pagination.next_page))
parsed = feedparser.parse(cached_works)
- [previous] = self.links(parsed, 'previous')
- assert TestAnnotator.feed_url(lane, facets, pagination) == previous['href']
- assert work2.title == parsed['entries'][0]['title']
+ [previous] = self.links(parsed, "previous")
+ assert TestAnnotator.feed_url(lane, facets, pagination) == previous["href"]
+ assert work2.title == parsed["entries"][0]["title"]
# The feed has breadcrumb links
parentage = list(lane.parentage)
@@ -997,8 +1016,8 @@ def make_page(pagination):
assert TestAnnotator.top_level_title() == links[0].get("title")
assert TestAnnotator.default_lane_url() == links[0].get("href")
for i, lane in enumerate(parentage):
- assert lane.display_name == links[i+1].get("title")
- assert TestAnnotator.lane_url(lane) == links[i+1].get("href")
+ assert lane.display_name == links[i + 1].get("title")
+ assert TestAnnotator.lane_url(lane) == links[i + 1].get("href")
def test_page_feed_for_worklist(self):
# Test the ability to create a paginated feed of works for a
@@ -1015,33 +1034,42 @@ def test_page_feed_for_worklist(self):
def make_page(pagination):
return AcquisitionFeed.page(
- self._db, "test", self._url, lane, TestAnnotator,
- pagination=pagination, search_engine=search_engine
+ self._db,
+ "test",
+ self._url,
+ lane,
+ TestAnnotator,
+ pagination=pagination,
+ search_engine=search_engine,
)
+
cached_works = make_page(pagination)
parsed = feedparser.parse(str(cached_works))
- assert work1.title == parsed['entries'][0]['title']
+ assert work1.title == parsed["entries"][0]["title"]
# Make sure the links are in place.
# This is the top-level, so no up link.
- assert [] == self.links(parsed, 'up')
+ assert [] == self.links(parsed, "up")
- [start] = self.links(parsed, 'start')
- assert TestAnnotator.groups_url(None) == start['href']
- assert TestAnnotator.top_level_title() == start['title']
+ [start] = self.links(parsed, "start")
+ assert TestAnnotator.groups_url(None) == start["href"]
+ assert TestAnnotator.top_level_title() == start["title"]
- [next_link] = self.links(parsed, 'next')
- assert TestAnnotator.feed_url(lane, facets, pagination.next_page) == next_link['href']
+ [next_link] = self.links(parsed, "next")
+ assert (
+ TestAnnotator.feed_url(lane, facets, pagination.next_page)
+ == next_link["href"]
+ )
# This was the first page, so no previous link.
- assert [] == self.links(parsed, 'previous')
+ assert [] == self.links(parsed, "previous")
# Now get the second page and make sure it has a 'previous' link.
cached_works = str(make_page(pagination.next_page))
parsed = feedparser.parse(cached_works)
- [previous] = self.links(parsed, 'previous')
- assert TestAnnotator.feed_url(lane, facets, pagination) == previous['href']
- assert work2.title == parsed['entries'][0]['title']
+ [previous] = self.links(parsed, "previous")
+ assert TestAnnotator.feed_url(lane, facets, pagination) == previous["href"]
+ assert work2.title == parsed["entries"][0]["title"]
# The feed has no parents, so no breadcrumbs.
root = ET.fromstring(cached_works)
@@ -1049,23 +1077,34 @@ def make_page(pagination):
assert None == breadcrumbs
def test_from_query(self):
- """Test creating a feed for a custom list from a query.
- """
+ """Test creating a feed for a custom list from a query."""
display_name = "custom_list"
staff_data_source = DataSource.lookup(self._db, DataSource.LIBRARY_STAFF)
- list, ignore = create(self._db, CustomList, name=self._str, library=self._default_library, data_source=staff_data_source)
+ list, ignore = create(
+ self._db,
+ CustomList,
+ name=self._str,
+ library=self._default_library,
+ data_source=staff_data_source,
+ )
work = self._work(with_license_pool=True)
work2 = self._work(with_license_pool=True)
list.add_entry(work)
list.add_entry(work2)
# get all the entries from a custom list
- query = self._db.query(Work).join(Work.custom_list_entries).filter(CustomListEntry.list_id==list.id)
+ query = (
+ self._db.query(Work)
+ .join(Work.custom_list_entries)
+ .filter(CustomListEntry.list_id == list.id)
+ )
pagination = Pagination(size=1)
worklist = WorkList()
- worklist.initialize(self._default_library, customlists=[list], display_name=display_name)
+ worklist.initialize(
+ self._default_library, customlists=[list], display_name=display_name
+ )
def url_for_custom_list(library, list):
def url_fn(after):
@@ -1073,34 +1112,46 @@ def url_fn(after):
if after:
base += "?after=%s&size=1" % after
return base
+
return url_fn
url_fn = url_for_custom_list(self._default_library, list)
+
def from_query(pagination):
return AcquisitionFeed.from_query(
- query, self._db, list.name, "url",
- pagination, url_fn, TestAnnotator,
+ query,
+ self._db,
+ list.name,
+ "url",
+ pagination,
+ url_fn,
+ TestAnnotator,
)
works = from_query(pagination)
parsed = feedparser.parse(str(works))
- assert 1 == len(parsed['entries'])
- assert list.name == parsed['feed'].title
+ assert 1 == len(parsed["entries"])
+ assert list.name == parsed["feed"].title
- [next_link] = self.links(parsed, 'next')
- assert TestAnnotator.feed_url(worklist, pagination=pagination.next_page) == next_link['href']
+ [next_link] = self.links(parsed, "next")
+ assert (
+ TestAnnotator.feed_url(worklist, pagination=pagination.next_page)
+ == next_link["href"]
+ )
# This was the first page, so no previous link.
- assert [] == self.links(parsed, 'previous')
+ assert [] == self.links(parsed, "previous")
# Now get the second page and make sure it has a 'previous' link.
works = from_query(pagination.next_page)
parsed = feedparser.parse(str(works))
- [previous_link] = self.links(parsed, 'previous')
- assert TestAnnotator.feed_url(worklist, pagination=pagination.previous_page) == previous_link['href']
- assert 1 == len(parsed['entries'])
- assert [] == self.links(parsed, 'next')
-
+ [previous_link] = self.links(parsed, "previous")
+ assert (
+ TestAnnotator.feed_url(worklist, pagination=pagination.previous_page)
+ == previous_link["href"]
+ )
+ assert 1 == len(parsed["entries"])
+ assert [] == self.links(parsed, "next")
def test_groups_feed(self):
# Test the ability to create a grouped feed of recommended works for
@@ -1130,9 +1181,15 @@ def test_groups_feed(self):
annotator = TestAnnotatorWithGroup()
private = object()
cached_groups = AcquisitionFeed.groups(
- self._db, "test", self._url, self.fantasy, annotator,
- max_age=0, search_engine=search_engine,
- search_debug=True, private=private
+ self._db,
+ "test",
+ self._url,
+ self.fantasy,
+ annotator,
+ max_age=0,
+ search_engine=search_engine,
+ search_debug=True,
+ private=private,
)
# The result is an OPDSFeedResponse object. The 'private'
@@ -1144,32 +1201,32 @@ def test_groups_feed(self):
parsed = feedparser.parse(cached_groups.data)
# There are three entries in three lanes.
- e1, e2, e3 = parsed['entries']
+ e1, e2, e3 = parsed["entries"]
# Each entry has one and only one link.
- [l1], [l2], [l3] = [x['links'] for x in parsed['entries']]
+ [l1], [l2], [l3] = [x["links"] for x in parsed["entries"]]
# Those links are 'collection' links that classify the
# works under their subgenres.
- assert all([l['rel'] == 'collection' for l in (l1, l2)])
+ assert all([l["rel"] == "collection" for l in (l1, l2)])
- assert l1['href'] == 'http://group/Epic Fantasy'
- assert l1['title'] == 'Group Title for Epic Fantasy!'
- assert l2['href'] == 'http://group/Urban Fantasy'
- assert l2['title'] == 'Group Title for Urban Fantasy!'
- assert l3['href'] == 'http://group/Fantasy'
- assert l3['title'] == 'Group Title for Fantasy!'
+ assert l1["href"] == "http://group/Epic Fantasy"
+ assert l1["title"] == "Group Title for Epic Fantasy!"
+ assert l2["href"] == "http://group/Urban Fantasy"
+ assert l2["title"] == "Group Title for Urban Fantasy!"
+ assert l3["href"] == "http://group/Fantasy"
+ assert l3["title"] == "Group Title for Fantasy!"
# The feed itself has an 'up' link which points to the
# groups for Fiction, and a 'start' link which points to
# the top-level groups feed.
- [up_link] = self.links(parsed['feed'], 'up')
- assert "http://groups/%s" % self.fiction.id == up_link['href']
- assert "Fiction" == up_link['title']
+ [up_link] = self.links(parsed["feed"], "up")
+ assert "http://groups/%s" % self.fiction.id == up_link["href"]
+ assert "Fiction" == up_link["title"]
- [start_link] = self.links(parsed['feed'], 'start')
- assert "http://groups/" == start_link['href']
- assert annotator.top_level_title() == start_link['title']
+ [start_link] = self.links(parsed["feed"], "start")
+ assert "http://groups/" == start_link["href"]
+ assert annotator.top_level_title() == start_link["title"]
# The feed has breadcrumb links
ancestors = list(self.fantasy.parentage)
@@ -1180,27 +1237,34 @@ def test_groups_feed(self):
assert annotator.top_level_title() == links[0].get("title")
assert annotator.default_lane_url() == links[0].get("href")
for i, lane in enumerate(reversed(ancestors)):
- assert lane.display_name == links[i+1].get("title")
- assert annotator.lane_url(lane) == links[i+1].get("href")
+ assert lane.display_name == links[i + 1].get("title")
+ assert annotator.lane_url(lane) == links[i + 1].get("href")
def test_empty_groups_feed(self):
# Test the case where a grouped feed turns up nothing.
# A Lane, and a Work not in the Lane.
- test_lane = self._lane("Test Lane", genres=['Mystery'])
+ test_lane = self._lane("Test Lane", genres=["Mystery"])
work1 = self._work(genre=History, with_open_access_download=True)
# Mock search index and Annotator.
search_engine = MockExternalSearchIndex()
+
class Mock(TestAnnotator):
def annotate_feed(self, feed, worklist):
self.called = True
+
annotator = Mock()
# Build a grouped feed for the lane.
feed = AcquisitionFeed.groups(
- self._db, "test", self._url, test_lane, annotator,
- max_age=0, search_engine=search_engine
+ self._db,
+ "test",
+ self._url,
+ test_lane,
+ annotator,
+ max_age=0,
+ search_engine=search_engine,
)
# A grouped feed was cached for the lane, but there were no
@@ -1210,7 +1274,7 @@ def annotate_feed(self, feed, worklist):
# So the feed contains no entries.
parsed = feedparser.parse(str(feed))
- assert [] == parsed['entries']
+ assert [] == parsed["entries"]
# but our mock Annotator got a chance to modify the feed in place.
assert True == annotator.called
@@ -1231,13 +1295,18 @@ def test_search_feed(self):
def make_page(pagination):
return AcquisitionFeed.search(
- self._db, "test", self._url, fantasy_lane, search_client,
+ self._db,
+ "test",
+ self._url,
+ fantasy_lane,
+ search_client,
"fantasy",
pagination=pagination,
facets=facets,
annotator=TestAnnotator,
- private=private
+ private=private,
)
+
response = make_page(pagination)
assert isinstance(response, OPDSFeedResponse)
assert OPDSFeed.DEFAULT_MAX_AGE == response.max_age
@@ -1245,44 +1314,44 @@ def make_page(pagination):
assert private == response.private
parsed = feedparser.parse(response.data)
- assert work1.title == parsed['entries'][0]['title']
+ assert work1.title == parsed["entries"][0]["title"]
# Make sure the links are in place.
- [start] = self.links(parsed, 'start')
- assert TestAnnotator.groups_url(None) == start['href']
- assert TestAnnotator.top_level_title() == start['title']
+ [start] = self.links(parsed, "start")
+ assert TestAnnotator.groups_url(None) == start["href"]
+ assert TestAnnotator.top_level_title() == start["title"]
- [next_link] = self.links(parsed, 'next')
+ [next_link] = self.links(parsed, "next")
expect = TestAnnotator.search_url(
fantasy_lane, "test", pagination.next_page, facets=facets
)
- assert expect == next_link['href']
+ assert expect == next_link["href"]
# This is tested elsewhere, but let's make sure
# SearchFacets-specific fields like order and min_score are
# propagated to the next-page URL.
- assert all(x in expect for x in ('order=author', 'min_score=10'))
+ assert all(x in expect for x in ("order=author", "min_score=10"))
# This was the first page, so no previous link.
- assert [] == self.links(parsed, 'previous')
+ assert [] == self.links(parsed, "previous")
# Make sure there's an "up" link to the lane that was searched
- [up_link] = self.links(parsed, 'up')
+ [up_link] = self.links(parsed, "up")
uplink_url = TestAnnotator.lane_url(fantasy_lane)
- assert uplink_url == up_link['href']
- assert fantasy_lane.display_name == up_link['title']
+ assert uplink_url == up_link["href"]
+ assert fantasy_lane.display_name == up_link["title"]
# Now get the second page and make sure it has a 'previous' link.
feed = str(make_page(pagination.next_page))
parsed = feedparser.parse(feed)
- [previous] = self.links(parsed, 'previous')
+ [previous] = self.links(parsed, "previous")
expect = TestAnnotator.search_url(
fantasy_lane, "test", pagination, facets=facets
)
- assert expect == previous['href']
- assert all(x in expect for x in ('order=author', 'min_score=10'))
+ assert expect == previous["href"]
+ assert all(x in expect for x in ("order=author", "min_score=10"))
- assert work2.title == parsed['entries'][0]['title']
+ assert work2.title == parsed["entries"][0]["title"]
# The feed has no breadcrumb links, since we're not
# searching the lane -- just using some aspects of the lane
@@ -1293,16 +1362,25 @@ def make_page(pagination):
assert None == breadcrumbs
def test_cache(self):
- work1 = self._work(title="The Original Title",
- genre=Epic_Fantasy, with_open_access_download=True)
+ work1 = self._work(
+ title="The Original Title",
+ genre=Epic_Fantasy,
+ with_open_access_download=True,
+ )
fantasy_lane = self.fantasy
search_engine = MockExternalSearchIndex()
search_engine.bulk_update([work1])
+
def make_page():
return AcquisitionFeed.page(
- self._db, "test", self._url, fantasy_lane, TestAnnotator,
- pagination=Pagination.default(), search_engine=search_engine
+ self._db,
+ "test",
+ self._url,
+ fantasy_lane,
+ TestAnnotator,
+ pagination=Pagination.default(),
+ search_engine=search_engine,
)
response1 = make_page()
@@ -1312,7 +1390,8 @@ def make_page():
work2 = self._work(
title="A Brand New Title",
- genre=Epic_Fantasy, with_open_access_download=True
+ genre=Epic_Fantasy,
+ with_open_access_download=True,
)
search_engine.bulk_update([work2])
@@ -1331,7 +1410,6 @@ def make_page():
class TestAcquisitionFeed(DatabaseTest):
-
def test_page(self):
# Verify that AcquisitionFeed.page() returns an appropriate OPDSFeedResponse
@@ -1339,8 +1417,13 @@ def test_page(self):
wl.initialize(self._default_library)
private = object()
response = AcquisitionFeed.page(
- self._db, "feed title", "url", wl, TestAnnotator,
- max_age=10, private=private
+ self._db,
+ "feed title",
+ "url",
+ wl,
+ TestAnnotator,
+ max_age=10,
+ private=private,
)
# The result is an OPDSFeedResponse. The 'private' argument,
@@ -1349,7 +1432,7 @@ def test_page(self):
assert 10 == response.max_age
assert private == response.private
- assert 'feed title' in str(response)
+ assert "feed title" in str(response)
def test_as_response(self):
# Verify the ability to convert an AcquisitionFeed object to an
@@ -1364,7 +1447,7 @@ def test_as_response(self):
# We get an OPDSFeedResponse containing the feed in its
# entity-body.
assert isinstance(response, OPDSFeedResponse)
- assert 'feed title' in str(response)
+ assert "feed title" in str(response)
# The caching expectations are respected.
assert 101 == response.max_age
@@ -1386,7 +1469,7 @@ def test_as_error_response(self):
# The content of the feed is unchanged.
assert 200 == response.status_code
- assert 'feed title' in str(response)
+ assert "feed title" in str(response)
# But the max_age and private settings have been overridden.
assert 0 == response.max_age
@@ -1399,6 +1482,7 @@ def test_add_entrypoint_links(self):
m = AcquisitionFeed.add_entrypoint_links
old_entrypoint_link = AcquisitionFeed._entrypoint_link
+
class Mock(object):
attrs = dict(href="the response")
@@ -1419,19 +1503,28 @@ def __call__(self, *args):
entrypoints = [AudiobooksEntryPoint, EbooksEntryPoint]
url_generator = object()
AcquisitionFeed.add_entrypoint_links(
- feed, url_generator, entrypoints, EbooksEntryPoint,
- "Some entry points"
+ feed, url_generator, entrypoints, EbooksEntryPoint, "Some entry points"
)
# Two different calls were made to the mock method.
c1, c2 = mock.calls
# The first entry point is not selected.
- assert (c1 ==
- (url_generator, AudiobooksEntryPoint, EbooksEntryPoint, True, "Some entry points"))
+ assert c1 == (
+ url_generator,
+ AudiobooksEntryPoint,
+ EbooksEntryPoint,
+ True,
+ "Some entry points",
+ )
# The second one is selected.
- assert (c2 ==
- (url_generator, EbooksEntryPoint, EbooksEntryPoint, False, "Some entry points"))
+ assert c2 == (
+ url_generator,
+ EbooksEntryPoint,
+ EbooksEntryPoint,
+ False,
+ "Some entry points",
+ )
# Two identical tags were added to the tag, one
# for each call to the mock method.
@@ -1448,8 +1541,7 @@ def __call__(self, *args):
mock.calls = []
entrypoints = [EbooksEntryPoint]
AcquisitionFeed.add_entrypoint_links(
- feed, url_generator, entrypoints, EbooksEntryPoint,
- "Some entry points"
+ feed, url_generator, entrypoints, EbooksEntryPoint, "Some entry points"
)
assert [] == mock.calls
@@ -1458,6 +1550,7 @@ def test_entrypoint_link(self):
attributes for tags.
"""
m = AcquisitionFeed._entrypoint_link
+
def g(entrypoint):
"""A mock URL generator."""
return "%s" % (entrypoint.INTERNAL_NAME)
@@ -1470,19 +1563,21 @@ def g(entrypoint):
# The link is identified as belonging to an entry point-type
# facet group.
- assert l['rel'] == AcquisitionFeed.FACET_REL
- assert (l['{http://librarysimplified.org/terms/}facetGroupType'] ==
- FacetConstants.ENTRY_POINT_REL)
- assert 'Grupe' == l['{http://opds-spec.org/2010/catalog}facetGroup']
+ assert l["rel"] == AcquisitionFeed.FACET_REL
+ assert (
+ l["{http://librarysimplified.org/terms/}facetGroupType"]
+ == FacetConstants.ENTRY_POINT_REL
+ )
+ assert "Grupe" == l["{http://opds-spec.org/2010/catalog}facetGroup"]
# This facet is the active one in the group.
- assert 'true' == l['{http://opds-spec.org/2010/catalog}activeFacet']
+ assert "true" == l["{http://opds-spec.org/2010/catalog}activeFacet"]
# The URL generator was invoked to create the href.
- assert l['href'] == g(AudiobooksEntryPoint)
+ assert l["href"] == g(AudiobooksEntryPoint)
# The facet title identifies it as a way to look at audiobooks.
- assert EntryPoint.DISPLAY_TITLES[AudiobooksEntryPoint] == l['title']
+ assert EntryPoint.DISPLAY_TITLES[AudiobooksEntryPoint] == l["title"]
# Now try some variants.
@@ -1490,23 +1585,21 @@ def g(entrypoint):
l = m(g, AudiobooksEntryPoint, AudiobooksEntryPoint, True, "Grupe")
# This may affect the URL generated for the facet link.
- assert l['href'] == g(AudiobooksEntryPoint)
+ assert l["href"] == g(AudiobooksEntryPoint)
# Here, the entry point for which we're generating the link is
# not the selected one -- EbooksEntryPoint is.
l = m(g, AudiobooksEntryPoint, EbooksEntryPoint, True, "Grupe")
# This means the 'activeFacet' attribute is not present.
- assert '{http://opds-spec.org/2010/catalog}activeFacet' not in l
+ assert "{http://opds-spec.org/2010/catalog}activeFacet" not in l
def test_license_tags_no_loan_or_hold(self):
edition, pool = self._edition(with_license_pool=True)
- availability, holds, copies = AcquisitionFeed.license_tags(
- pool, None, None
- )
- assert dict(status='available') == availability.attrib
- assert dict(total='0') == holds.attrib
- assert dict(total='1', available='1') == copies.attrib
+ availability, holds, copies = AcquisitionFeed.license_tags(pool, None, None)
+ assert dict(status="available") == availability.attrib
+ assert dict(total="0") == holds.attrib
+ assert dict(total="1", available="1") == copies.attrib
def test_license_tags_hold_position(self):
# When a book is placed on hold, it typically takes a while
@@ -1522,41 +1615,33 @@ def test_license_tags_hold_position(self):
pool.patrons_in_hold_queue = 3
hold, is_new = pool.on_hold_to(patron, position=1)
- availability, holds, copies = AcquisitionFeed.license_tags(
- pool, None, hold
- )
- assert '1' == holds.attrib['position']
- assert '3' == holds.attrib['total']
+ availability, holds, copies = AcquisitionFeed.license_tags(pool, None, hold)
+ assert "1" == holds.attrib["position"]
+ assert "3" == holds.attrib["total"]
# If the patron's hold position is missing, we assume they
# are last in the list.
hold.position = None
- availability, holds, copies = AcquisitionFeed.license_tags(
- pool, None, hold
- )
- assert '3' == holds.attrib['position']
- assert '3' == holds.attrib['total']
+ availability, holds, copies = AcquisitionFeed.license_tags(pool, None, hold)
+ assert "3" == holds.attrib["position"]
+ assert "3" == holds.attrib["total"]
# If the patron's current hold position is greater than the
# total recorded number of holds+reserves, their position will
# be used as the value of opds:total.
hold.position = 5
- availability, holds, copies = AcquisitionFeed.license_tags(
- pool, None, hold
- )
- assert '5' == holds.attrib['position']
- assert '5' == holds.attrib['total']
+ availability, holds, copies = AcquisitionFeed.license_tags(pool, None, hold)
+ assert "5" == holds.attrib["position"]
+ assert "5" == holds.attrib["total"]
# A patron earlier in the holds queue may see a different
# total number of holds, but that's fine -- it doesn't matter
# very much to that person the precise number of people behind
# them in the queue.
hold.position = 4
- availability, holds, copies = AcquisitionFeed.license_tags(
- pool, None, hold
- )
- assert '4' == holds.attrib['position']
- assert '4' == holds.attrib['total']
+ availability, holds, copies = AcquisitionFeed.license_tags(pool, None, hold)
+ assert "4" == holds.attrib["position"]
+ assert "4" == holds.attrib["total"]
# If the patron's hold position is zero (because the book is
# reserved to them), we do not represent them as having a hold
@@ -1565,11 +1650,9 @@ def test_license_tags_hold_position(self):
# is out of date.
hold.position = 0
pool.patrons_in_hold_queue = 0
- availability, holds, copies = AcquisitionFeed.license_tags(
- pool, None, hold
- )
- assert 'position' not in holds.attrib
- assert '1' == holds.attrib['total']
+ availability, holds, copies = AcquisitionFeed.license_tags(pool, None, hold)
+ assert "position" not in holds.attrib
+ assert "1" == holds.attrib["total"]
def test_license_tags_show_unlimited_access_books(self):
# Arrange
@@ -1579,19 +1662,17 @@ def test_license_tags_show_unlimited_access_books(self):
pool.unlimited_access = True
# Act
- tags = AcquisitionFeed.license_tags(
- pool, None, None
- )
+ tags = AcquisitionFeed.license_tags(pool, None, None)
# Assert
assert 1 == len(tags)
[tag] = tags
- assert ('status' in tag.attrib) == True
- assert 'available' == tag.attrib['status']
- assert ('holds' in tag.attrib) == False
- assert ('copies' in tag.attrib) == False
+ assert ("status" in tag.attrib) == True
+ assert "available" == tag.attrib["status"]
+ assert ("holds" in tag.attrib) == False
+ assert ("copies" in tag.attrib) == False
def test_license_tags_show_self_hosted_books(self):
# Arrange
@@ -1602,14 +1683,12 @@ def test_license_tags_show_self_hosted_books(self):
pool.licenses_owned = 0
# Act
- tags = AcquisitionFeed.license_tags(
- pool, None, None
- )
+ tags = AcquisitionFeed.license_tags(pool, None, None)
# Assert
assert 1 == len(tags)
- assert 'status' in tags[0].attrib
- assert 'available' == tags[0].attrib['status']
+ assert "status" in tags[0].attrib
+ assert "available" == tags[0].attrib["status"]
def test_single_entry(self):
@@ -1646,10 +1725,8 @@ def test_single_entry(self):
# If the edition was issued before 1980, no datetime formatting error
# is raised.
work.simple_opds_entry = work.verbose_opds_entry = None
- five_hundred_years = datetime.timedelta(days=(500*365))
- work.presentation_edition.issued = (
- utc_now() - five_hundred_years
- )
+ five_hundred_years = datetime.timedelta(days=(500 * 365))
+ work.presentation_edition.issued = utc_now() - five_hundred_years
entry = AcquisitionFeed.single_entry(self._db, work, TestAnnotator)
@@ -1679,8 +1756,8 @@ def create_entry(*args, **kwargs):
# We got an OPDS entry containing the message.
assert isinstance(response, OPDSEntryResponse)
assert 200 == response.status_code
- assert '500' in str(response)
- assert 'oops' in str(response)
+ assert "500" in str(response)
+ assert "oops" in str(response)
# Our caching preferences were overridden.
assert True == response.private
@@ -1700,18 +1777,18 @@ def test_entry_cache_adds_missing_drm_namespace(self):
class AddDRMTagAnnotator(TestAnnotator):
@classmethod
def annotate_work_entry(
- cls, work, license_pool, edition, identifier, feed,
- entry):
+ cls, work, license_pool, edition, identifier, feed, entry
+ ):
drm_link = OPDSFeed.makeelement("{%s}licensor" % OPDSFeed.DRM_NS)
entry.extend([drm_link])
# The entry is retrieved from cache and the appropriate
# namespace inserted.
- entry = AcquisitionFeed.single_entry(
- self._db, work, AddDRMTagAnnotator
+ entry = AcquisitionFeed.single_entry(self._db, work, AddDRMTagAnnotator)
+ assert (
+ 'bar'
+ == str(entry)
)
- assert ('bar' ==
- str(entry))
def test_error_when_work_has_no_identifier(self):
# We cannot create an OPDS entry for a Work that cannot be associated
@@ -1719,16 +1796,12 @@ def test_error_when_work_has_no_identifier(self):
work = self._work(title="Hello, World!", with_license_pool=True)
work.license_pools[0].identifier = None
work.presentation_edition.primary_identifier = None
- entry = AcquisitionFeed.single_entry(
- self._db, work, TestAnnotator
- )
+ entry = AcquisitionFeed.single_entry(self._db, work, TestAnnotator)
assert entry == None
def test_error_when_work_has_no_licensepool(self):
work = self._work()
- feed = AcquisitionFeed(
- self._db, self._str, self._url, [], annotator=Annotator
- )
+ feed = AcquisitionFeed(self._db, self._str, self._url, [], annotator=Annotator)
entry = feed.create_entry(work)
expect = AcquisitionFeed.error_message(
work.presentation_edition.primary_identifier,
@@ -1744,20 +1817,16 @@ def test_error_when_work_has_no_presentation_edition(self):
work = self._work(title="Hello, World!", with_license_pool=True)
work.license_pools[0].presentation_edition = None
work.presentation_edition = None
- feed = AcquisitionFeed(
- self._db, self._str, self._url, [], annotator=Annotator
- )
+ feed = AcquisitionFeed(self._db, self._str, self._url, [], annotator=Annotator)
entry = feed.create_entry(work)
assert None == entry
def test_cache_usage(self):
work = self._work(with_open_access_download=True)
- feed = AcquisitionFeed(
- self._db, self._str, self._url, [], annotator=Annotator
- )
+ feed = AcquisitionFeed(self._db, self._str, self._url, [], annotator=Annotator)
# Set the Work's cached OPDS entry to something that's clearly wrong.
- tiny_entry = 'cached entry'
+ tiny_entry = "cached entry"
work.simple_opds_entry = tiny_entry
# If we pass in use_cache=True, the cached value is used as a basis
@@ -1771,8 +1840,7 @@ def test_cache_usage(self):
xml = etree.fromstring(work.simple_opds_entry)
annotator = Annotator()
annotator.annotate_work_entry(
- work, pool, pool.presentation_edition, pool.identifier, feed,
- xml
+ work, pool, pool.presentation_edition, pool.identifier, feed, xml
)
assert etree.tounicode(xml) == etree.tounicode(entry)
@@ -1793,8 +1861,7 @@ def test_cache_usage(self):
# through `Annotator.annotate_work_entry`.
full_entry = etree.fromstring(work.simple_opds_entry)
annotator.annotate_work_entry(
- work, pool, pool.presentation_edition, pool.identifier, feed,
- full_entry
+ work, pool, pool.presentation_edition, pool.identifier, feed, full_entry
)
assert entry_string == etree.tounicode(full_entry)
@@ -1804,9 +1871,8 @@ def test_exception_during_entry_creation_is_not_reraised(self):
class DoomedFeed(AcquisitionFeed):
def _create_entry(self, *args, **kwargs):
raise Exception("I'm doomed!")
- feed = DoomedFeed(
- self._db, self._str, self._url, [], annotator=Annotator
- )
+
+ feed = DoomedFeed(self._db, self._str, self._url, [], annotator=Annotator)
work = self._work(with_open_access_download=True)
# But calling create_entry() doesn't raise an exception, it
@@ -1818,12 +1884,15 @@ def test_unfilfullable_work(self):
work = self._work(with_open_access_download=True)
[pool] = work.license_pools
response = AcquisitionFeed.single_entry(
- self._db, work, TestUnfulfillableAnnotator,
+ self._db,
+ work,
+ TestUnfulfillableAnnotator,
)
assert isinstance(response, Response)
expect = AcquisitionFeed.error_message(
- pool.identifier, 403,
- "I know about this work but can offer no way of fulfilling it."
+ pool.identifier,
+ 403,
+ "I know about this work but can offer no way of fulfilling it.",
)
# The status code equivalent inside the OPDS message has not affected
# the status code of the Response itself.
@@ -1834,39 +1903,46 @@ def test_format_types(self):
m = AcquisitionFeed.format_types
epub_no_drm, ignore = DeliveryMechanism.lookup(
- self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM)
+ self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM
+ )
assert [Representation.EPUB_MEDIA_TYPE] == m(epub_no_drm)
epub_adobe_drm, ignore = DeliveryMechanism.lookup(
- self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM)
- assert ([DeliveryMechanism.ADOBE_DRM, Representation.EPUB_MEDIA_TYPE] ==
- m(epub_adobe_drm))
+ self._db, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM
+ )
+ assert [DeliveryMechanism.ADOBE_DRM, Representation.EPUB_MEDIA_TYPE] == m(
+ epub_adobe_drm
+ )
overdrive_streaming_text, ignore = DeliveryMechanism.lookup(
- self._db, DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
- DeliveryMechanism.OVERDRIVE_DRM
+ self._db,
+ DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
+ DeliveryMechanism.OVERDRIVE_DRM,
)
- assert (
- [OPDSFeed.ENTRY_TYPE,
- Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE] ==
- m(overdrive_streaming_text))
+ assert [
+ OPDSFeed.ENTRY_TYPE,
+ Representation.TEXT_HTML_MEDIA_TYPE + DeliveryMechanism.STREAMING_PROFILE,
+ ] == m(overdrive_streaming_text)
audiobook_drm, ignore = DeliveryMechanism.lookup(
- self._db, Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE,
- DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM
+ self._db,
+ Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE,
+ DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_DRM,
)
- assert (
- [Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE + DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_PROFILE] ==
- m(audiobook_drm))
+ assert [
+ Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE
+ + DeliveryMechanism.FEEDBOOKS_AUDIOBOOK_PROFILE
+ ] == m(audiobook_drm)
# Test a case where there is a DRM scheme but no underlying
# content type.
findaway_manifest, ignore = DeliveryMechanism.lookup(
self._db, DeliveryMechanism.FINDAWAY_DRM, None
)
- assert ([DeliveryMechanism.FINDAWAY_DRM] ==
- AcquisitionFeed.format_types(findaway_manifest))
+ assert [DeliveryMechanism.FINDAWAY_DRM] == AcquisitionFeed.format_types(
+ findaway_manifest
+ )
def test_add_breadcrumbs(self):
_db = self._db
@@ -1886,15 +1962,12 @@ def __init__(self):
lane = self._lane(display_name="lane")
sublane = self._lane(parent=lane, display_name="sublane")
subsublane = self._lane(parent=sublane, display_name="subsublane")
- subsubsublane = self._lane(parent=subsublane,
- display_name="subsubsublane")
+ subsubsublane = self._lane(parent=subsublane, display_name="subsubsublane")
top_level = object()
ep = AudiobooksEntryPoint
- def assert_breadcrumbs(
- expect_breadcrumbs_for, lane, **add_breadcrumbs_kwargs
- ):
+ def assert_breadcrumbs(expect_breadcrumbs_for, lane, **add_breadcrumbs_kwargs):
# Create breadcrumbs leading up to `lane` and verify that
# there is a breadcrumb for everything in
# `expect_breadcrumbs_for` -- Lanes, EntryPoints, and the
@@ -1906,8 +1979,8 @@ def assert_breadcrumbs(
feed = MockFeed()
annotator = TestAnnotator()
- entrypoint = add_breadcrumbs_kwargs.get('entrypoint', None)
- include_lane = add_breadcrumbs_kwargs.get('include_lane', False)
+ entrypoint = add_breadcrumbs_kwargs.get("entrypoint", None)
+ include_lane = add_breadcrumbs_kwargs.get("include_lane", False)
feed.add_breadcrumbs(lane, **add_breadcrumbs_kwargs)
@@ -1934,7 +2007,7 @@ def title(x):
return x.display_name
expect_titles = [title(x) for x in expect_breadcrumbs_for]
- actual_titles = [x.attrib.get('title') for x in crumbs]
+ actual_titles = [x.attrib.get("title") for x in crumbs]
assert expect_titles == actual_titles
# Now, compare the URLs of the breadcrumbs. This is
@@ -1961,9 +2034,7 @@ def title(x):
# The URL for this breadcrumb is the URL for the
# previous breadcrumb with the addition of the
# entrypoint selection query.
- expect_url = (
- previous_breadcrumb_url + entrypoint_query
- )
+ expect_url = previous_breadcrumb_url + entrypoint_query
else:
# Breadcrumb for a lane.
@@ -1978,8 +2049,7 @@ def title(x):
expect_url = lane_url
logging.debug(
- "%s: expect=%s actual=%s", expect_titles[i],
- expect_url, actual_url
+ "%s: expect=%s actual=%s", expect_titles[i], expect_url, actual_url
)
assert expect_url == actual_url
@@ -2008,24 +2078,19 @@ def title(x):
# A lane with an entrypoint selected
assert_breadcrumbs([top_level, ep], lane, entrypoint=ep)
assert_breadcrumbs(
- [top_level, ep, lane],
- lane, entrypoint=ep, include_lane=True
+ [top_level, ep, lane], lane, entrypoint=ep, include_lane=True
)
# One lane level down.
assert_breadcrumbs([top_level, lane], sublane)
assert_breadcrumbs([top_level, ep, lane], sublane, entrypoint=ep)
assert_breadcrumbs(
- [top_level, ep, lane, sublane],
- sublane, entrypoint=ep, include_lane=True
+ [top_level, ep, lane, sublane], sublane, entrypoint=ep, include_lane=True
)
# Two lane levels down.
assert_breadcrumbs([top_level, lane, sublane], subsublane)
- assert_breadcrumbs(
- [top_level, ep, lane, sublane],
- subsublane, entrypoint=ep
- )
+ assert_breadcrumbs([top_level, ep, lane, sublane], subsublane, entrypoint=ep)
# Three lane levels down.
assert_breadcrumbs(
@@ -2034,8 +2099,7 @@ def title(x):
)
assert_breadcrumbs(
- [top_level, ep, lane, sublane, subsublane],
- subsubsublane, entrypoint=ep
+ [top_level, ep, lane, sublane, subsublane], subsubsublane, entrypoint=ep
)
# Make the sublane a root lane for a certain patron type, and
@@ -2044,35 +2108,30 @@ def title(x):
sublane.root_for_patron_type = ["ya"]
assert_breadcrumbs([], sublane)
- assert_breadcrumbs(
- [sublane, subsublane],
- subsubsublane
- )
+ assert_breadcrumbs([sublane, subsublane], subsubsublane)
assert_breadcrumbs(
- [sublane, subsublane, subsubsublane],
- subsubsublane, include_lane=True
+ [sublane, subsublane, subsubsublane], subsubsublane, include_lane=True
)
# However, if an entrypoint is selected we will see a
# breadcrumb for it between the patron root lane and its
# child.
- assert_breadcrumbs(
- [sublane, ep, subsublane],
- subsubsublane, entrypoint=ep
- )
+ assert_breadcrumbs([sublane, ep, subsublane], subsubsublane, entrypoint=ep)
assert_breadcrumbs(
[sublane, ep, subsublane, subsubsublane],
- subsubsublane, entrypoint=ep, include_lane=True
+ subsubsublane,
+ entrypoint=ep,
+ include_lane=True,
)
def test_add_breadcrumb_links(self):
-
class MockFeed(AcquisitionFeed):
add_link_calls = []
add_breadcrumbs_call = None
current_entrypoint = None
+
def add_link_to_feed(self, **kwargs):
self.add_link_calls.append(kwargs)
@@ -2093,11 +2152,11 @@ def show_current_entrypoint(self, entrypoint):
# add_link_to_feed was called twice, to create the 'start' and
# 'up' links.
start, up = feed.add_link_calls
- assert 'start' == start['rel']
- assert annotator.top_level_title() == start['title']
+ assert "start" == start["rel"]
+ assert annotator.top_level_title() == start["title"]
- assert 'up' == up['rel']
- assert lane.display_name == up['title']
+ assert "up" == up["rel"]
+ assert lane.display_name == up["title"]
# The Lane and EntryPoint were passed into add_breadcrumbs.
assert (sublane, ep) == feed.add_breadcrumbs_call
@@ -2171,6 +2230,7 @@ def facet_groups(self):
class MockFeed(AcquisitionFeed):
links = []
+
@classmethod
def facet_link(cls, url, facet_title, group_title, selected):
# Return the passed-in objects as is.
@@ -2185,23 +2245,23 @@ def facet_link(cls, url, facet_title, group_title, selected):
#
# The other three 4-tuples were ignored since we don't know
# how to generate human-readable titles for them.
- [[url, facet, group, selected]] = MockFeed.facet_links(
- annotator, facets
- )
- assert 'url: try the featured collection instead' == url
+ [[url, facet, group, selected]] = MockFeed.facet_links(annotator, facets)
+ assert "url: try the featured collection instead" == url
assert Facets.FACET_DISPLAY_TITLES[Facets.COLLECTION_FULL] == facet
- assert (Facets.GROUP_DISPLAY_TITLES[Facets.COLLECTION_FACET_GROUP_NAME] ==
- group)
+ assert Facets.GROUP_DISPLAY_TITLES[Facets.COLLECTION_FACET_GROUP_NAME] == group
assert True == selected
class TestLookupAcquisitionFeed(DatabaseTest):
-
def feed(self, annotator=VerboseAnnotator, **kwargs):
"""Helper method to create a LookupAcquisitionFeed."""
return LookupAcquisitionFeed(
- self._db, "Feed Title", "http://whatever.io", [],
- annotator=annotator, **kwargs
+ self._db,
+ "Feed Title",
+ "http://whatever.io",
+ [],
+ annotator=annotator,
+ **kwargs
)
def entry(self, identifier, work, annotator=VerboseAnnotator, **kwargs):
@@ -2257,20 +2317,16 @@ def test_error_on_mismatched_identifier(self):
# Identifier and a totally random Work.
expect_error = 'I tried to generate an OPDS entry for the identifier "%s" using a Work not associated with that identifier.'
feed, entry = self.entry(identifier, work)
- assert (
- entry ==OPDSMessage(
- identifier.urn, 500, expect_error % identifier.urn
- ))
+ assert entry == OPDSMessage(identifier.urn, 500, expect_error % identifier.urn)
# Even if the Identifier does have a Work, if the Works don't
# match, we get the same error.
edition, lp = self._edition(with_license_pool=True)
work2 = lp.calculate_work()
feed, entry = self.entry(lp.identifier, work)
- assert (entry ==
- OPDSMessage(
- lp.identifier.urn, 500, expect_error % lp.identifier.urn
- ))
+ assert entry == OPDSMessage(
+ lp.identifier.urn, 500, expect_error % lp.identifier.urn
+ )
def test_error_when_work_has_no_licensepool(self):
"""Under most circumstances, a Work must have at least one
@@ -2290,11 +2346,11 @@ def test_error_when_work_has_no_licensepool(self):
def test_unfilfullable_work(self):
work = self._work(with_open_access_download=True)
[pool] = work.license_pools
- feed, entry = self.entry(pool.identifier, work,
- TestUnfulfillableAnnotator)
+ feed, entry = self.entry(pool.identifier, work, TestUnfulfillableAnnotator)
expect = AcquisitionFeed.error_message(
- pool.identifier, 403,
- "I know about this work but can offer no way of fulfilling it."
+ pool.identifier,
+ 403,
+ "I know about this work but can offer no way of fulfilling it.",
)
assert expect == entry
@@ -2303,6 +2359,7 @@ def test_create_entry_uses_cache_for_all_licensepools_for_work(self):
that Work, even LicensePools associated with different
identifiers.
"""
+
class InstrumentableActiveLicensePool(VerboseAnnotator):
"""A mock class that lets us control the output of
active_license_pool.
@@ -2313,6 +2370,7 @@ class InstrumentableActiveLicensePool(VerboseAnnotator):
@classmethod
def active_licensepool_for(cls, work):
return cls.ACTIVE
+
feed = self.feed(annotator=InstrumentableActiveLicensePool())
# Here are two completely different LicensePools for the same work.
@@ -2344,10 +2402,7 @@ def active_licensepool_for(cls, work):
work.license_pools = [pool1]
result = m((identifier2, work))
assert isinstance(result, OPDSMessage)
- assert (
- 'using a Work not associated with that identifier.'
- in result.message
- )
+ assert "using a Work not associated with that identifier." in result.message
class TestEntrypointLinkInsertion(DatabaseTest):
@@ -2363,13 +2418,12 @@ def setup_method(self):
class Mock(object):
def add_entrypoint_links(self, *args):
self.called_with = args
+
self.mock = Mock()
# A WorkList with no EntryPoints -- should not call the mock method.
self.no_eps = WorkList()
- self.no_eps.initialize(
- library=self._default_library, display_name="no_eps"
- )
+ self.no_eps.initialize(library=self._default_library, display_name="no_eps")
# A WorkList with two EntryPoints -- may call the mock method
# depending on circumstances.
@@ -2378,14 +2432,19 @@ def add_entrypoint_links(self, *args):
# The WorkList must have at least one child, or we won't generate
# a real groups feed for it.
self.lane = self._lane()
- self.wl.initialize(library=self._default_library, display_name="wl",
- entrypoints=self.entrypoints, children=[self.lane])
+ self.wl.initialize(
+ library=self._default_library,
+ display_name="wl",
+ entrypoints=self.entrypoints,
+ children=[self.lane],
+ )
def works(_db, **kwargs):
"""Mock WorkList.works so we don't need any actual works
to run the test.
"""
return []
+
self.no_eps.works = works
self.wl.works = works
@@ -2407,8 +2466,13 @@ def run(wl=None, facets=None):
"""
self.mock.called_with = None
AcquisitionFeed.groups(
- self._db, "title", "url", wl, self.annotator,
- max_age=0, facets=facets,
+ self._db,
+ "title",
+ "url",
+ wl,
+ self.annotator,
+ max_age=0,
+ facets=facets,
)
return self.mock.called_with
@@ -2420,7 +2484,7 @@ def run(wl=None, facets=None):
# to be called.
facets = FeaturedFacets(
minimum_featured_quality=self._default_library.minimum_featured_quality,
- entrypoint=EbooksEntryPoint
+ entrypoint=EbooksEntryPoint,
)
feed, make_link, entrypoints, selected = run(self.wl, facets)
@@ -2445,9 +2509,15 @@ def run(wl=None, facets=None, pagination=None):
self.mock.called_with = None
private = object()
AcquisitionFeed.page(
- self._db, "title", "url", wl, self.annotator,
- max_age=0, facets=facets,
- pagination=pagination, private=private
+ self._db,
+ "title",
+ "url",
+ wl,
+ self.annotator,
+ max_age=0,
+ facets=facets,
+ pagination=pagination,
+ private=private,
)
return self.mock.called_with
@@ -2470,15 +2540,15 @@ def run(wl=None, facets=None, pagination=None):
# The make_link function that was passed in calls
# TestAnnotator.feed_url() when passed an EntryPoint. The
# Facets object's other facet groups are propagated in this URL.
- first_page_url = "http://wl/?available=all&collection=full&entrypoint=Book&order=author"
+ first_page_url = (
+ "http://wl/?available=all&collection=full&entrypoint=Book&order=author"
+ )
assert first_page_url == make_link(EbooksEntryPoint)
# Pagination information is not propagated through entry point links
# -- you always start at the beginning of the list.
pagination = Pagination(offset=100)
- feed, make_link, entrypoints, selected = run(
- self.wl, facets, pagination
- )
+ feed, make_link, entrypoints, selected = run(self.wl, facets, pagination)
assert first_page_url == make_link(EbooksEntryPoint)
def test_search(self):
@@ -2491,15 +2561,22 @@ def run(wl=None, facets=None, pagination=None):
"""
self.mock.called_with = None
AcquisitionFeed.search(
- self._db, "title", "url", wl, None, None,
- annotator=self.annotator, facets=facets,
- pagination=pagination
+ self._db,
+ "title",
+ "url",
+ wl,
+ None,
+ None,
+ annotator=self.annotator,
+ facets=facets,
+ pagination=pagination,
)
return self.mock.called_with
# Mock search() so it never tries to return anything.
def mock_search(self, *args, **kwargs):
return []
+
self.no_eps.search = mock_search
self.wl.search = mock_search
@@ -2516,9 +2593,11 @@ def mock_search(self, *args, **kwargs):
# Since the SearchFacets has more than one entry point,
# the EverythingEntryPoint is prepended to the list of possible
# entry points.
- assert (
- [EverythingEntryPoint, AudiobooksEntryPoint, EbooksEntryPoint] ==
- entrypoints)
+ assert [
+ EverythingEntryPoint,
+ AudiobooksEntryPoint,
+ EbooksEntryPoint,
+ ] == entrypoints
# add_entrypoint_links was passed the three possible entry points
# and the selected entry point.
@@ -2526,20 +2605,19 @@ def mock_search(self, *args, **kwargs):
# The make_link function that was passed in calls
# TestAnnotator.search_url() when passed an EntryPoint.
- first_page_url = 'http://wl/?available=all&collection=full&entrypoint=Book&order=relevance'
+ first_page_url = (
+ "http://wl/?available=all&collection=full&entrypoint=Book&order=relevance"
+ )
assert first_page_url == make_link(EbooksEntryPoint)
# Pagination information is not propagated through entry point links
# -- you always start at the beginning of the list.
pagination = Pagination(offset=100)
- feed, make_link, entrypoints, selected = run(
- self.wl, facets, pagination
- )
+ feed, make_link, entrypoints, selected = run(self.wl, facets, pagination)
assert first_page_url == make_link(EbooksEntryPoint)
class TestNavigationFacets(object):
-
def test_feed_type(self):
# If a navigation feed is built via CachedFeed.fetch, it will be
# filed as a navigation feed.
@@ -2547,16 +2625,14 @@ def test_feed_type(self):
class TestNavigationFeed(DatabaseTest):
-
def setup_method(self):
super(TestNavigationFeed, self).setup_method()
self.fiction = self._lane("Fiction")
- self.fantasy = self._lane(
- "Fantasy", parent=self.fiction)
- self.romance = self._lane(
- "Romance", parent=self.fiction)
+ self.fantasy = self._lane("Fantasy", parent=self.fiction)
+ self.romance = self._lane("Romance", parent=self.fiction)
self.contemporary_romance = self._lane(
- "Contemporary Romance", parent=self.romance)
+ "Contemporary Romance", parent=self.romance
+ )
def test_add_entry(self):
feed = NavigationFeed("title", "http://navigation")
@@ -2572,8 +2648,13 @@ def test_add_entry(self):
def test_navigation_with_sublanes(self):
private = object()
response = NavigationFeed.navigation(
- self._db, "Navigation", "http://navigation",
- self.fiction, TestAnnotator, max_age=42, private=private
+ self._db,
+ "Navigation",
+ "http://navigation",
+ self.fiction,
+ TestAnnotator,
+ max_age=42,
+ private=private,
)
# We got an OPDSFeedResponse back. The values we passed in for
@@ -2616,8 +2697,8 @@ def test_navigation_with_sublanes(self):
def test_navigation_without_sublanes(self):
feed = NavigationFeed.navigation(
- self._db, "Navigation", "http://navigation",
- self.fantasy, TestAnnotator)
+ self._db, "Navigation", "http://navigation", self.fantasy, TestAnnotator
+ )
parsed = feedparser.parse(str(feed))
assert "Navigation" == parsed["feed"]["title"]
[self_link] = parsed["feed"]["links"]
diff --git a/tests/test_opds_import.py b/tests/test_opds_import.py
index 37927e42c..7991d9085 100644
--- a/tests/test_opds_import.py
+++ b/tests/test_opds_import.py
@@ -1,47 +1,28 @@
-import os
import datetime
+import os
+import pkgutil
import random
-from urllib.parse import quote
from io import StringIO
+from urllib.parse import quote
+
import feedparser
import pytest
-
from lxml import etree
-import pkgutil
from psycopg2.extras import NumericRange
-from ..testing import (
- DatabaseTest,
-)
-
-from ..config import (
- CannotLoadConfiguration,
- IntegrationException,
-)
-from ..opds_import import (
- AccessNotAuthenticated,
- MetadataWranglerOPDSLookup,
- OPDSImporter,
- OPDSImportMonitor,
- OPDSXMLParser,
- SimplifiedOPDSLookup,
-)
-from ..metadata_layer import (
- LinkData,
- CirculationData,
- Metadata,
- TimestampData,
-)
+from ..config import CannotLoadConfiguration, IntegrationException
+from ..coverage import CoverageFailure
+from ..metadata_layer import CirculationData, LinkData, Metadata, TimestampData
from ..model import (
Collection,
Contributor,
CoverageRecord,
DataSource,
DeliveryMechanism,
+ Edition,
ExternalIntegration,
Hyperlink,
Identifier,
- Edition,
Measurement,
MediaTypes,
Representation,
@@ -51,40 +32,48 @@
WorkCoverageRecord,
)
from ..model.configuration import ExternalIntegrationLink
-from ..coverage import CoverageFailure
-from ..s3 import (
- S3Uploader,
- MockS3Uploader,
- S3UploaderConfiguration)
+from ..opds_import import (
+ AccessNotAuthenticated,
+ MetadataWranglerOPDSLookup,
+ OPDSImporter,
+ OPDSImportMonitor,
+ OPDSXMLParser,
+ SimplifiedOPDSLookup,
+)
+from ..s3 import MockS3Uploader, S3Uploader, S3UploaderConfiguration
from ..selftest import SelfTestResult
from ..testing import (
+ DatabaseTest,
DummyHTTPClient,
MockRequestsRequest,
MockRequestsResponse,
)
-from ..util.http import BadResponseException
-from ..util.opds_writer import (
- AtomFeed,
- OPDSFeed,
- OPDSMessage,
-)
from ..util.datetime_helpers import datetime_utc, utc_now
+from ..util.http import BadResponseException
+from ..util.opds_writer import AtomFeed, OPDSFeed, OPDSMessage
+
class DoomedOPDSImporter(OPDSImporter):
def import_edition_from_metadata(self, metadata, *args):
if metadata.title == "Johnny Crow's Party":
# This import succeeds.
- return super(DoomedOPDSImporter, self).import_edition_from_metadata(metadata, *args)
+ return super(DoomedOPDSImporter, self).import_edition_from_metadata(
+ metadata, *args
+ )
else:
# Any other import fails.
raise Exception("Utter failure!")
+
class DoomedWorkOPDSImporter(OPDSImporter):
"""An OPDS Importer that imports editions but can't create works."""
+
def update_work_for_edition(self, edition, *args, **kwargs):
if edition.title == "Johnny Crow's Party":
# This import succeeds.
- return super(DoomedWorkOPDSImporter, self).update_work_for_edition(edition, *args, **kwargs)
+ return super(DoomedWorkOPDSImporter, self).update_work_for_edition(
+ edition, *args, **kwargs
+ )
else:
# Any other import fails.
raise Exception("Utter work failure!")
@@ -100,16 +89,16 @@ def sample_opds(self, filename, file_type="r"):
class TestMetadataWranglerOPDSLookup(OPDSTest):
-
def setup_method(self):
super(TestMetadataWranglerOPDSLookup, self).setup_method()
self.integration = self._external_integration(
ExternalIntegration.METADATA_WRANGLER,
goal=ExternalIntegration.METADATA_GOAL,
- password='secret', url="http://metadata.in"
+ password="secret",
+ url="http://metadata.in",
)
self.collection = self._collection(
- protocol=ExternalIntegration.OVERDRIVE, external_account_id='library'
+ protocol=ExternalIntegration.OVERDRIVE, external_account_id="library"
)
def test_authenticates_wrangler_requests(self):
@@ -129,68 +118,69 @@ def test_authenticates_wrangler_requests(self):
def test_add_args(self):
lookup = MetadataWranglerOPDSLookup.from_config(self._db)
- args = 'greeting=hello'
+ args = "greeting=hello"
# If the base url doesn't have any arguments, args are created.
base_url = self._url
- assert base_url + '?' + args == lookup.add_args(base_url, args)
+ assert base_url + "?" + args == lookup.add_args(base_url, args)
# If the base url has an argument already, additional args are appended.
- base_url = self._url + '?data_source=banana'
- assert base_url + '&' + args == lookup.add_args(base_url, args)
+ base_url = self._url + "?data_source=banana"
+ assert base_url + "&" + args == lookup.add_args(base_url, args)
def test_get_collection_url(self):
lookup = MetadataWranglerOPDSLookup.from_config(self._db)
# If the lookup client doesn't have a Collection, an error is
# raised.
- pytest.raises(
- ValueError, lookup.get_collection_url, 'banana'
- )
+ pytest.raises(ValueError, lookup.get_collection_url, "banana")
# If the lookup client isn't authenticated, an error is raised.
lookup.collection = self.collection
lookup.shared_secret = None
- pytest.raises(
- AccessNotAuthenticated, lookup.get_collection_url, 'banana'
- )
+ pytest.raises(AccessNotAuthenticated, lookup.get_collection_url, "banana")
# With both authentication and a specific Collection,
# a URL is returned.
- lookup.shared_secret = 'secret'
- expected = '%s%s/banana' % (lookup.base_url, self.collection.metadata_identifier)
- assert expected == lookup.get_collection_url('banana')
+ lookup.shared_secret = "secret"
+ expected = "%s%s/banana" % (
+ lookup.base_url,
+ self.collection.metadata_identifier,
+ )
+ assert expected == lookup.get_collection_url("banana")
# With an OPDS_IMPORT collection, a data source is included
opds = self._collection(
protocol=ExternalIntegration.OPDS_IMPORT,
external_account_id=self._url,
- data_source_name=DataSource.OA_CONTENT_SERVER
+ data_source_name=DataSource.OA_CONTENT_SERVER,
)
lookup.collection = opds
- data_source_args = '?data_source=%s' % quote(opds.data_source.name)
- assert lookup.get_collection_url('banana').endswith(data_source_args)
+ data_source_args = "?data_source=%s" % quote(opds.data_source.name)
+ assert lookup.get_collection_url("banana").endswith(data_source_args)
def test_lookup_endpoint(self):
# A Collection-specific endpoint is returned if authentication
# and a Collection is available.
- lookup = MetadataWranglerOPDSLookup.from_config(self._db, collection=self.collection)
+ lookup = MetadataWranglerOPDSLookup.from_config(
+ self._db, collection=self.collection
+ )
- expected = self.collection.metadata_identifier + '/lookup'
+ expected = self.collection.metadata_identifier + "/lookup"
assert expected == lookup.lookup_endpoint
# Without a collection, an unspecific endpoint is returned.
lookup.collection = None
- assert 'lookup' == lookup.lookup_endpoint
+ assert "lookup" == lookup.lookup_endpoint
# Without authentication, an unspecific endpoint is returned.
lookup.shared_secret = None
lookup.collection = self.collection
- assert 'lookup' == lookup.lookup_endpoint
+ assert "lookup" == lookup.lookup_endpoint
# With authentication and a collection, a specific endpoint is returned.
- lookup.shared_secret = 'secret'
- expected = '%s/lookup' % self.collection.metadata_identifier
+ lookup.shared_secret = "secret"
+ expected = "%s/lookup" % self.collection.metadata_identifier
assert expected == lookup.lookup_endpoint
# Tests of the self-test framework.
@@ -256,7 +246,8 @@ class Mock(MetadataWranglerOPDSLookup):
def _feed_self_test(self, title, method, *args):
self.feed_self_tests.append((title, method, args))
return "A feed self-test for %s: %s" % (
- self.collection.unique_account_id, title
+ self.collection.unique_account_id,
+ title,
)
# If there is no associated collection, _run_collection_self_tests()
@@ -266,9 +257,7 @@ def _feed_self_test(self, title, method, *args):
# Same if there is an associated collection but it has no
# metadata identifier.
- with_collection = Mock(
- "http://url/", collection=self._default_collection
- )
+ with_collection = Mock("http://url/", collection=self._default_collection)
assert [] == list(with_collection._run_collection_self_tests())
# If there is a metadata identifier, our mocked
@@ -276,12 +265,11 @@ def _feed_self_test(self, title, method, *args):
self._default_collection.external_account_id = "unique-id"
[r1, r2] = with_collection._run_collection_self_tests()
+ assert "A feed self-test for unique-id: Metadata updates in last 24 hours" == r1
assert (
- 'A feed self-test for unique-id: Metadata updates in last 24 hours' ==
- r1)
- assert (
- "A feed self-test for unique-id: Titles where we could (but haven't) provide information to the metadata wrangler" ==
- r2)
+ "A feed self-test for unique-id: Titles where we could (but haven't) provide information to the metadata wrangler"
+ == r2
+ )
# Let's make sure _feed_self_test() was called with the right
# arguments.
@@ -290,7 +278,7 @@ def _feed_self_test(self, title, method, *args):
# The first self-test wants to count updates for the last 24
# hours.
title1, method1, args1 = call1
- assert 'Metadata updates in last 24 hours' == title1
+ assert "Metadata updates in last 24 hours" == title1
assert with_collection.updates == method1
[timestamp] = args1
one_day_ago = utc_now() - datetime.timedelta(hours=24)
@@ -300,8 +288,9 @@ def _feed_self_test(self, title, method, *args):
# wrangler needs done but hasn't been done yet.
title2, method2, args2 = call2
assert (
- "Titles where we could (but haven't) provide information to the metadata wrangler" ==
- title2)
+ "Titles where we could (but haven't) provide information to the metadata wrangler"
+ == title2
+ )
assert with_collection.metadata_needed == method2
assert () == args2
@@ -313,6 +302,7 @@ def test__feed_self_test(self):
class Mock(MetadataWranglerOPDSLookup):
requests = []
annotated_responses = []
+
@classmethod
def _annotate_feed_response(cls, result, response):
cls.annotated_responses.append((result, response))
@@ -346,11 +336,9 @@ def make_some_request(self, *args, **kwargs):
# and a keyword argument indicating that 5xx responses should
# be processed normally and not used as a reason to raise an
# exception.
- assert (
- [((1, 2),
- {'allowed_response_codes': ['1xx', '2xx', '3xx', '4xx', '5xx']})
- ] ==
- lookup.requests)
+ assert [
+ ((1, 2), {"allowed_response_codes": ["1xx", "2xx", "3xx", "4xx", "5xx"]})
+ ] == lookup.requests
# That method returned "A fake response", which was passed
# into _annotate_feed_response, along with the
@@ -362,6 +350,7 @@ def make_some_request(self, *args, **kwargs):
def test__annotate_feed_response(self):
# Test the _annotate_feed_response class helper method.
m = MetadataWranglerOPDSLookup._annotate_feed_response
+
def mock_response(url, authorization, response_code, content):
request = MockRequestsRequest(
url, headers=dict(Authorization=authorization)
@@ -372,36 +361,32 @@ def mock_response(url, authorization, response_code, content):
return response
# First, test success.
- url = "http://metadata-wrangler/",
+ url = ("http://metadata-wrangler/",)
auth = "auth"
test_result = SelfTestResult("success")
response = mock_response(
- url, auth, 200,
- self.sample_opds("metadata_wrangler_overdrive.opds")
+ url, auth, 200, self.sample_opds("metadata_wrangler_overdrive.opds")
)
results = m(test_result, response)
assert [
- 'Request URL: %s' % url,
- 'Request authorization: %s' % auth,
- 'Status code: 200',
- 'Total identifiers registered with this collection: 201',
- 'Entries on this page: 1',
- ' The Green Mouse'
+ "Request URL: %s" % url,
+ "Request authorization: %s" % auth,
+ "Status code: 200",
+ "Total identifiers registered with this collection: 201",
+ "Entries on this page: 1",
+ " The Green Mouse",
] == test_result.result
assert True == test_result.success
# Next, test failure.
- response = mock_response(
- url, auth, 401,
- "An error message."
- )
+ response = mock_response(url, auth, 401, "An error message.")
test_result = SelfTestResult("failure")
assert False == test_result.success
m(test_result, response)
assert [
- 'Request URL: %s' % url,
- 'Request authorization: %s' % auth,
- 'Status code: 401',
+ "Request URL: %s" % url,
+ "Request authorization: %s" % auth,
+ "Status code: 401",
] == test_result.result
def test_external_integration(self):
@@ -409,29 +394,30 @@ def test_external_integration(self):
assert result.protocol == ExternalIntegration.METADATA_WRANGLER
assert result.goal == ExternalIntegration.METADATA_GOAL
-class OPDSImporterTest(OPDSTest):
+class OPDSImporterTest(OPDSTest):
def setup_method(self):
super(OPDSImporterTest, self).setup_method()
self.content_server_feed = self.sample_opds("content_server.opds")
self.content_server_mini_feed = self.sample_opds("content_server_mini.opds")
self.audiobooks_opds = self.sample_opds("audiobooks.opds")
- self.feed_with_id_and_dcterms_identifier = self.sample_opds("feed_with_id_and_dcterms_identifier.opds", "rb")
- self._default_collection.external_integration.setting('data_source').value = (
- DataSource.OA_CONTENT_SERVER
+ self.feed_with_id_and_dcterms_identifier = self.sample_opds(
+ "feed_with_id_and_dcterms_identifier.opds", "rb"
)
+ self._default_collection.external_integration.setting(
+ "data_source"
+ ).value = DataSource.OA_CONTENT_SERVER
# Set an ExternalIntegration for the metadata_client used
# in the OPDSImporter.
self.service = self._external_integration(
ExternalIntegration.METADATA_WRANGLER,
goal=ExternalIntegration.METADATA_GOAL,
- url="http://localhost"
+ url="http://localhost",
)
class TestOPDSImporter(OPDSImporterTest):
-
def test_constructor(self):
# The default way of making HTTP requests is with
# Representation.cautious_http_get.
@@ -445,9 +431,7 @@ def test_constructor(self):
def test_data_source_autocreated(self):
name = "New data source " + self._str
- importer = OPDSImporter(
- self._db, collection=None, data_source_name=name
- )
+ importer = OPDSImporter(self._db, collection=None, data_source_name=name)
source1 = importer.data_source
assert name == source1.name
@@ -455,9 +439,7 @@ def test_extract_next_links(self):
importer = OPDSImporter(
self._db, collection=None, data_source_name=DataSource.NYT
)
- next_links = importer.extract_next_links(
- self.content_server_mini_feed
- )
+ next_links = importer.extract_next_links(self.content_server_mini_feed)
assert 1 == len(next_links)
assert "http://localhost:5000/?after=327&size=100" == next_links[0]
@@ -504,14 +486,12 @@ def test_extract_metadata(self):
importer = OPDSImporter(
self._db, collection=None, data_source_name=data_source_name
)
- metadata, failures = importer.extract_feed_data(
- self.content_server_mini_feed
- )
+ metadata, failures = importer.extract_feed_data(self.content_server_mini_feed)
- m1 = metadata['http://www.gutenberg.org/ebooks/10441']
- m2 = metadata['http://www.gutenberg.org/ebooks/10557']
- c1 = metadata['http://www.gutenberg.org/ebooks/10441']
- c2 = metadata['http://www.gutenberg.org/ebooks/10557']
+ m1 = metadata["http://www.gutenberg.org/ebooks/10441"]
+ m2 = metadata["http://www.gutenberg.org/ebooks/10557"]
+ c1 = metadata["http://www.gutenberg.org/ebooks/10441"]
+ c2 = metadata["http://www.gutenberg.org/ebooks/10557"]
assert "The Green Mouse" == m1.title
assert "A Tale of Mousy Terror" == m1.subtitle
@@ -522,14 +502,21 @@ def test_extract_metadata(self):
assert data_source_name == c2._data_source
[failure] = list(failures.values())
- assert "202: I'm working to locate a source for this identifier." == failure.exception
+ assert (
+ "202: I'm working to locate a source for this identifier."
+ == failure.exception
+ )
def test_use_dcterm_identifier_as_id_with_id_and_dcterms_identifier(self):
data_source_name = "Data source name " + self._str
collection_to_test = self._default_collection
- collection_to_test.primary_identifier_source = ExternalIntegration.DCTERMS_IDENTIFIER
+ collection_to_test.primary_identifier_source = (
+ ExternalIntegration.DCTERMS_IDENTIFIER
+ )
importer = OPDSImporter(
- self._db, collection=collection_to_test, data_source_name=data_source_name,
+ self._db,
+ collection=collection_to_test,
+ data_source_name=data_source_name,
)
metadata, failures = importer.extract_feed_data(
@@ -537,29 +524,29 @@ def test_use_dcterm_identifier_as_id_with_id_and_dcterms_identifier(self):
)
# First book doesn't have , so must be used as identifier
- book_1 = metadata.get('https://root.uri/1')
+ book_1 = metadata.get("https://root.uri/1")
assert book_1 != None
# Second book have and , so must be used as id
- book_2 = metadata.get('urn:isbn:9781468316438')
+ book_2 = metadata.get("urn:isbn:9781468316438")
assert book_2 != None
# Verify if id was add in the end of identifier
book_2_identifiers = book_2.identifiers
found = False
for entry in book_2.identifiers:
- if entry.identifier == 'https://root.uri/2':
+ if entry.identifier == "https://root.uri/2":
found = True
break
assert found == True
# Third book has more than one dcterms:identifers, all of then must be present as metadata identifier
- book_3 = metadata.get('urn:isbn:9781683351993')
+ book_3 = metadata.get("urn:isbn:9781683351993")
assert book_2 != None
# Verify if id was add in the end of identifier
book_3_identifiers = book_3.identifiers
expected_identifier = [
- '9781683351993',
- 'https://root.uri/3',
- '9781683351504',
- '9780312939458',
+ "9781683351993",
+ "https://root.uri/3",
+ "9781683351504",
+ "9780312939458",
]
result_identifier = [entry.identifier for entry in book_3.identifiers]
assert set(expected_identifier) == set(result_identifier)
@@ -569,18 +556,20 @@ def test_use_id_with_existing_dcterms_identifier(self):
collection_to_test = self._default_collection
collection_to_test.primary_identifier_source = None
importer = OPDSImporter(
- self._db, collection=collection_to_test, data_source_name=data_source_name,
+ self._db,
+ collection=collection_to_test,
+ data_source_name=data_source_name,
)
metadata, failures = importer.extract_feed_data(
self.feed_with_id_and_dcterms_identifier
)
- book_1 = metadata.get('https://root.uri/1')
+ book_1 = metadata.get("https://root.uri/1")
assert book_1 != None
- book_2 = metadata.get('https://root.uri/2')
+ book_2 = metadata.get("https://root.uri/2")
assert book_2 != None
- book_3 = metadata.get('https://root.uri/3')
+ book_3 = metadata.get("https://root.uri/3")
assert book_3 != None
def test_extract_link(self):
@@ -601,18 +590,19 @@ def test_extract_link(self):
def test_get_medium_from_links(self):
audio_links = [
- LinkData(href="url", rel="http://opds-spec.org/acquisition/",
- media_type="application/audiobook+json;param=value"
+ LinkData(
+ href="url",
+ rel="http://opds-spec.org/acquisition/",
+ media_type="application/audiobook+json;param=value",
),
LinkData(href="url", rel="http://opds-spec.org/image"),
]
book_links = [
LinkData(href="url", rel="http://opds-spec.org/image"),
LinkData(
- href="url", rel="http://opds-spec.org/acquisition/",
- media_type=random.choice(
- MediaTypes.BOOK_MEDIA_TYPES
- ) + ";param=value"
+ href="url",
+ rel="http://opds-spec.org/acquisition/",
+ media_type=random.choice(MediaTypes.BOOK_MEDIA_TYPES) + ";param=value",
),
]
@@ -627,17 +617,13 @@ def test_extract_link_rights_uri(self):
entry_rights = RightsStatus.PUBLIC_DOMAIN_USA
link_tag = AtomFeed.E.link(href="http://foo", rel="bar")
- link = OPDSImporter.extract_link(
- link_tag, entry_rights_uri=entry_rights
- )
+ link = OPDSImporter.extract_link(link_tag, entry_rights_uri=entry_rights)
assert RightsStatus.PUBLIC_DOMAIN_USA == link.rights_uri
# But a dcterms:rights tag beneath the link can override this.
rights_attr = "{%s}rights" % AtomFeed.DCTERMS_NS
link_tag.attrib[rights_attr] = RightsStatus.IN_COPYRIGHT
- link = OPDSImporter.extract_link(
- link_tag, entry_rights_uri=entry_rights
- )
+ link = OPDSImporter.extract_link(link_tag, entry_rights_uri=entry_rights)
assert RightsStatus.IN_COPYRIGHT == link.rights_uri
def test_extract_data_from_feedparser(self):
@@ -648,14 +634,14 @@ def test_extract_data_from_feedparser(self):
)
# The tag became a Metadata object.
- metadata = values['urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441']
- assert "The Green Mouse" == metadata['title']
- assert "A Tale of Mousy Terror" == metadata['subtitle']
- assert 'en' == metadata['language']
- assert 'Project Gutenberg' == metadata['publisher']
+ metadata = values["urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441"]
+ assert "The Green Mouse" == metadata["title"]
+ assert "A Tale of Mousy Terror" == metadata["subtitle"]
+ assert "en" == metadata["language"]
+ assert "Project Gutenberg" == metadata["publisher"]
- circulation = metadata['circulation']
- assert DataSource.GUTENBERG == circulation['data_source']
+ circulation = metadata["circulation"]
+ assert DataSource.GUTENBERG == circulation["data_source"]
# The tag did not become a
# CoverageFailure -- that's handled by
@@ -665,12 +651,15 @@ def test_extract_data_from_feedparser(self):
def test_extract_data_from_feedparser_handles_exception(self):
class DoomedFeedparserOPDSImporter(OPDSImporter):
"""An importer that can't extract metadata from feedparser."""
+
@classmethod
def _data_detail_for_feedparser_entry(cls, entry, data_source):
raise Exception("Utter failure!")
data_source = DataSource.lookup(self._db, DataSource.OA_CONTENT_SERVER)
- importer = DoomedFeedparserOPDSImporter(self._db, None, data_source_name=data_source.name)
+ importer = DoomedFeedparserOPDSImporter(
+ self._db, None, data_source_name=data_source.name
+ )
values, failures = importer.extract_data_from_feedparser(
self.content_server_mini_feed, data_source
)
@@ -684,13 +673,13 @@ def _data_detail_for_feedparser_entry(cls, entry, data_source):
assert 2 == len(failures)
# The first error message became a CoverageFailure.
- failure = failures['urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441']
+ failure = failures["urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441"]
assert isinstance(failure, CoverageFailure)
assert True == failure.transient
assert "Utter failure!" in failure.exception
# The second error message became a CoverageFailure.
- failure = failures['urn:librarysimplified.org/terms/id/Gutenberg%20ID/10557']
+ failure = failures["urn:librarysimplified.org/terms/id/Gutenberg%20ID/10557"]
assert isinstance(failure, CoverageFailure)
assert True == failure.transient
assert "Utter failure!" in failure.exception
@@ -711,50 +700,58 @@ def test_extract_metadata_from_elementtree(self):
# We're going to do spot checks on a book and a periodical.
# First, the book.
- book_id = 'urn:librarysimplified.org/terms/id/Gutenberg%20ID/1022'
+ book_id = "urn:librarysimplified.org/terms/id/Gutenberg%20ID/1022"
book = data[book_id]
- assert Edition.BOOK_MEDIUM == book['medium']
+ assert Edition.BOOK_MEDIUM == book["medium"]
- [contributor] = book['contributors']
+ [contributor] = book["contributors"]
assert "Thoreau, Henry David" == contributor.sort_name
assert [Contributor.AUTHOR_ROLE] == contributor.roles
- subjects = book['subjects']
- assert ['LCSH', 'LCSH', 'LCSH', 'LCC'] == [x.type for x in subjects]
- assert (
- ['Essays', 'Nature', 'Walking', 'PS'] ==
- [x.identifier for x in subjects])
- assert (
- [None, None, None, 'American Literature'] ==
- [x.name for x in book['subjects']])
- assert (
- [1, 1, 1, 10] ==
- [x.weight for x in book['subjects']])
+ subjects = book["subjects"]
+ assert ["LCSH", "LCSH", "LCSH", "LCC"] == [x.type for x in subjects]
+ assert ["Essays", "Nature", "Walking", "PS"] == [x.identifier for x in subjects]
+ assert [None, None, None, "American Literature"] == [
+ x.name for x in book["subjects"]
+ ]
+ assert [1, 1, 1, 10] == [x.weight for x in book["subjects"]]
- assert [] == book['measurements']
+ assert [] == book["measurements"]
assert datetime_utc(1862, 6, 1) == book["published"]
- [link] = book['links']
+ [link] = book["links"]
assert Hyperlink.OPEN_ACCESS_DOWNLOAD == link.rel
assert "http://www.gutenberg.org/ebooks/1022.epub.noimages" == link.href
assert Representation.EPUB_MEDIA_TYPE == link.media_type
# And now, the periodical.
- periodical_id = 'urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441'
+ periodical_id = "urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441"
periodical = data[periodical_id]
- assert Edition.PERIODICAL_MEDIUM == periodical['medium']
+ assert Edition.PERIODICAL_MEDIUM == periodical["medium"]
- subjects = periodical['subjects']
- assert (
- ['LCSH', 'LCSH', 'LCSH', 'LCSH', 'LCC', 'schema:audience', 'schema:typicalAgeRange'] ==
- [x.type for x in subjects])
- assert (
- ['Courtship -- Fiction', 'New York (N.Y.) -- Fiction', 'Fantasy fiction', 'Magic -- Fiction', 'PZ', 'Children', '7'] ==
- [x.identifier for x in subjects])
+ subjects = periodical["subjects"]
+ assert [
+ "LCSH",
+ "LCSH",
+ "LCSH",
+ "LCSH",
+ "LCC",
+ "schema:audience",
+ "schema:typicalAgeRange",
+ ] == [x.type for x in subjects]
+ assert [
+ "Courtship -- Fiction",
+ "New York (N.Y.) -- Fiction",
+ "Fantasy fiction",
+ "Magic -- Fiction",
+ "PZ",
+ "Children",
+ "7",
+ ] == [x.identifier for x in subjects]
assert [1, 1, 1, 1, 1, 1, 1] == [x.weight for x in subjects]
- r1, r2, r3 = periodical['measurements']
+ r1, r2, r3 = periodical["measurements"]
assert Measurement.QUALITY == r1.quantity_measured
assert 0.3333 == r1.value
@@ -768,8 +765,8 @@ def test_extract_metadata_from_elementtree(self):
assert 0.25 == r3.value
assert 1 == r3.weight
- assert 'Animal Colors' == periodical['series']
- assert '1' == periodical['series_position']
+ assert "Animal Colors" == periodical["series"]
+ assert "1" == periodical["series_position"]
assert datetime_utc(1910, 1, 1) == periodical["published"]
@@ -786,7 +783,7 @@ def test_extract_metadata_from_elementtree_treats_message_as_failure(self):
# The CoverageFailure contains the information that was in a
# tag in unrecognized_identifier.opds.
- key = 'http://www.gutenberg.org/ebooks/100'
+ key = "http://www.gutenberg.org/ebooks/100"
assert [key] == list(failures.keys())
failure = failures[key]
assert "404: I've never heard of this work." == failure.exception
@@ -797,7 +794,7 @@ def test_extract_messages(self):
feed = self.sample_opds("unrecognized_identifier.opds")
root = etree.parse(StringIO(feed))
[message] = OPDSImporter.extract_messages(parser, root)
- assert 'urn:librarysimplified.org/terms/id/Gutenberg ID/100' == message.urn
+ assert "urn:librarysimplified.org/terms/id/Gutenberg ID/100" == message.urn
assert 404 == message.status_code
assert "I've never heard of this work." == message.message
@@ -810,13 +807,13 @@ def test_extract_medium(self):
def medium(additional_type, format, default="Default"):
# Make an tag with the given tags.
# Parse it and call extract_medium on it.
- entry= ''
+ entry += ">"
if format:
- entry += '%s' % format
- entry += ''
+ entry += "%s" % format
+ entry += ""
tag = etree.parse(StringIO(entry))
return m(tag.getroot(), default=default)
@@ -826,12 +823,10 @@ def medium(additional_type, format, default="Default"):
# schema:additionalType is checked first. If present, any
# potentially contradictory information in dcterms:format is
# ignored.
- assert (
- Edition.AUDIO_MEDIUM ==
- medium("http://bib.schema.org/Audiobook", ebook_type))
- assert (
- Edition.BOOK_MEDIUM ==
- medium("http://schema.org/EBook", audio_type))
+ assert Edition.AUDIO_MEDIUM == medium(
+ "http://bib.schema.org/Audiobook", ebook_type
+ )
+ assert Edition.BOOK_MEDIUM == medium("http://schema.org/EBook", audio_type)
# When schema:additionalType is missing or not useful, the
# value of dcterms:format is mapped to a medium using
@@ -847,11 +842,12 @@ def medium(additional_type, format, default="Default"):
def test_handle_failure(self):
axis_id = self._identifier(identifier_type=Identifier.AXIS_360_ID)
axis_isbn = self._identifier(Identifier.ISBN, "9781453219539")
- identifier_mapping = {axis_isbn : axis_id}
+ identifier_mapping = {axis_isbn: axis_id}
importer = OPDSImporter(
- self._db, collection=None,
+ self._db,
+ collection=None,
data_source_name=DataSource.OA_CONTENT_SERVER,
- identifier_mapping = identifier_mapping
+ identifier_mapping=identifier_mapping,
)
# The simplest case -- an identifier associated with a
@@ -861,9 +857,7 @@ def test_handle_failure(self):
urn = "urn:isbn:9781449358068"
expect_identifier, ignore = Identifier.parse_urn(self._db, urn)
- identifier, output_failure = importer.handle_failure(
- urn, input_failure
- )
+ identifier, output_failure = importer.handle_failure(urn, input_failure)
assert expect_identifier == identifier
assert input_failure == output_failure
@@ -896,17 +890,15 @@ def test_handle_failure(self):
assert axis_id == identifier
assert axis_id == not_a_failure
-
def test_coveragefailure_from_message(self):
"""Test all the different ways a tag might
become a CoverageFailure.
"""
data_source = DataSource.lookup(self._db, DataSource.OA_CONTENT_SERVER)
+
def f(*args):
message = OPDSMessage(*args)
- return OPDSImporter.coveragefailure_from_message(
- data_source, message
- )
+ return OPDSImporter.coveragefailure_from_message(data_source, message)
# If the URN is invalid we can't create a CoverageFailure.
invalid_urn = f("urnblah", "500", "description")
@@ -938,6 +930,7 @@ def test_coveragefailure_from_message_with_success_status_codes(self):
"""When an OPDSImporter defines SUCCESS_STATUS_CODES, messages with
those status codes are always treated as successes.
"""
+
class Mock(OPDSImporter):
SUCCESS_STATUS_CODES = [200, 999]
@@ -966,12 +959,16 @@ def f(*args):
assert isinstance(failure, CoverageFailure)
assert "500: hooray???" == failure.exception
- def test_extract_metadata_from_elementtree_handles_messages_that_become_identifiers(self):
+ def test_extract_metadata_from_elementtree_handles_messages_that_become_identifiers(
+ self,
+ ):
not_a_failure = self._identifier()
+
class MockOPDSImporter(OPDSImporter):
@classmethod
def coveragefailures_from_messages(
- cls, data_source, message, success_on_200=False):
+ cls, data_source, message, success_on_200=False
+ ):
"""No matter what input we get, we act as though there were
a single simplified:message tag in the OPDS feed, which we
decided to treat as success rather than failure.
@@ -985,17 +982,20 @@ def coveragefailures_from_messages(
)
assert {not_a_failure.urn: not_a_failure} == failures
-
def test_extract_metadata_from_elementtree_handles_exception(self):
class DoomedElementtreeOPDSImporter(OPDSImporter):
"""An importer that can't extract metadata from elementttree."""
+
@classmethod
def _detail_for_elementtree_entry(cls, *args, **kwargs):
raise Exception("Utter failure!")
data_source = DataSource.lookup(self._db, DataSource.OA_CONTENT_SERVER)
- values, failures = DoomedElementtreeOPDSImporter.extract_metadata_from_elementtree(
+ (
+ values,
+ failures,
+ ) = DoomedElementtreeOPDSImporter.extract_metadata_from_elementtree(
self.content_server_mini_feed, data_source
)
@@ -1009,20 +1009,20 @@ def _detail_for_elementtree_entry(cls, *args, **kwargs):
# The entry with the 202 message became an appropriate
# CoverageFailure because its data was not extracted through
# extract_metadata_from_elementtree.
- failure = failures['http://www.gutenberg.org/ebooks/1984']
+ failure = failures["http://www.gutenberg.org/ebooks/1984"]
assert isinstance(failure, CoverageFailure)
assert True == failure.transient
- assert failure.exception.startswith('202')
- assert 'Utter failure!' not in failure.exception
+ assert failure.exception.startswith("202")
+ assert "Utter failure!" not in failure.exception
# The other entries became generic CoverageFailures due to the failure
# of extract_metadata_from_elementtree.
- failure = failures['urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441']
+ failure = failures["urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441"]
assert isinstance(failure, CoverageFailure)
assert True == failure.transient
assert "Utter failure!" in failure.exception
- failure = failures['urn:librarysimplified.org/terms/id/Gutenberg%20ID/10557']
+ failure = failures["urn:librarysimplified.org/terms/id/Gutenberg%20ID/10557"]
assert isinstance(failure, CoverageFailure)
assert True == failure.transient
assert "Utter failure!" in failure.exception
@@ -1033,13 +1033,12 @@ def test_import_exception_if_unable_to_parse_feed(self):
pytest.raises(etree.XMLSyntaxError, importer.import_from_feed, feed)
-
def test_import(self):
feed = self.content_server_mini_feed
- imported_editions, pools, works, failures = (
- OPDSImporter(self._db, collection=None).import_from_feed(feed)
- )
+ imported_editions, pools, works, failures = OPDSImporter(
+ self._db, collection=None
+ ).import_from_feed(feed)
[crow, mouse] = sorted(imported_editions, key=lambda x: x.title)
@@ -1069,7 +1068,7 @@ def test_import(self):
# A Representation was imported for the image with a media type
# inferred from its URL.
image_rep = image.resource.representation
- assert image_rep.url.endswith('_9.png')
+ assert image_rep.url.endswith("_9.png")
assert Representation.PNG_MEDIA_TYPE == image_rep.media_type
# The thumbnail was imported similarly, and its representation
@@ -1087,16 +1086,15 @@ def test_import(self):
# distinctive extension, and we have not actually retrieved
# the URLs yet, we were not able to determine their media type,
# so they have no associated Representation.
- assert broken_image.resource.url.endswith('/broken-cover-image')
- assert working_image.resource.url.endswith('/working-cover-image')
+ assert broken_image.resource.url.endswith("/broken-cover-image")
+ assert working_image.resource.url.endswith("/working-cover-image")
assert None == broken_image.resource.representation
assert None == working_image.resource.representation
# Three measurements have been added to the 'mouse' edition.
popularity, quality, rating = sorted(
- [x for x in mouse.primary_identifier.measurements
- if x.is_most_recent],
- key=lambda x: x.quantity_measured
+ [x for x in mouse.primary_identifier.measurements if x.is_most_recent],
+ key=lambda x: x.quantity_measured,
)
assert DataSource.METADATA_WRANGLER == popularity.data_source.name
@@ -1112,8 +1110,8 @@ def test_import(self):
assert 0.6 == rating.value
seven, children, courtship, fantasy, pz, magic, new_york = sorted(
- mouse.primary_identifier.classifications,
- key=lambda x: x.subject.name)
+ mouse.primary_identifier.classifications, key=lambda x: x.subject.name
+ )
pz_s = pz.subject
assert "Juvenile Fiction" == pz_s.name
@@ -1123,28 +1121,27 @@ def test_import(self):
assert "New York (N.Y.) -- Fiction" == new_york_s.name
assert "sh2008108377" == new_york_s.identifier
- assert '7' == seven.subject.identifier
+ assert "7" == seven.subject.identifier
assert 100 == seven.weight
assert Subject.AGE_RANGE == seven.subject.type
from ..classifier import Classifier
+
classifier = Classifier.classifiers.get(seven.subject.type, None)
classifier.classify(seven.subject)
# If we import the same file again, we get the same list of Editions.
- imported_editions_2, pools_2, works_2, failures_2 = (
- OPDSImporter(self._db, collection=None).import_from_feed(feed)
- )
+ imported_editions_2, pools_2, works_2, failures_2 = OPDSImporter(
+ self._db, collection=None
+ ).import_from_feed(feed)
assert imported_editions_2 == imported_editions
# importing with a collection and a lendable data source makes
# license pools and works.
- imported_editions, pools, works, failures = (
- OPDSImporter(
- self._db,
- collection=self._default_collection,
- data_source_name=DataSource.OA_CONTENT_SERVER
- ).import_from_feed(feed)
- )
+ imported_editions, pools, works, failures = OPDSImporter(
+ self._db,
+ collection=self._default_collection,
+ data_source_name=DataSource.OA_CONTENT_SERVER,
+ ).import_from_feed(feed)
[crow_pool, mouse_pool] = sorted(
pools, key=lambda x: x.presentation_edition.title
@@ -1163,14 +1160,13 @@ def test_import(self):
work.calculate_presentation()
assert 0.4142 == round(work.quality, 4)
assert Classifier.AUDIENCE_CHILDREN == work.audience
- assert NumericRange(7,7, '[]') == work.target_age
+ assert NumericRange(7, 7, "[]") == work.target_age
# Bonus: make sure that delivery mechanisms are set appropriately.
[mech] = mouse_pool.delivery_mechanisms
assert Representation.EPUB_MEDIA_TYPE == mech.delivery_mechanism.content_type
assert DeliveryMechanism.NO_DRM == mech.delivery_mechanism.drm_scheme
- assert ('http://www.gutenberg.org/ebooks/10441.epub.images' ==
- mech.resource.url)
+ assert "http://www.gutenberg.org/ebooks/10441.epub.images" == mech.resource.url
def test_import_with_lendability(self):
"""Test that OPDS import creates Edition, LicensePool, and Work
@@ -1185,12 +1181,14 @@ def test_import_with_lendability(self):
# This import will create Editions, but not LicensePools or
# Works, because there is no Collection.
importer_mw = OPDSImporter(
- self._db, collection=None,
- data_source_name=DataSource.METADATA_WRANGLER
- )
- imported_editions_mw, pools_mw, works_mw, failures_mw = (
- importer_mw.import_from_feed(feed)
+ self._db, collection=None, data_source_name=DataSource.METADATA_WRANGLER
)
+ (
+ imported_editions_mw,
+ pools_mw,
+ works_mw,
+ failures_mw,
+ ) = importer_mw.import_from_feed(feed)
# Both editions were imported, because they were new.
assert 2 == len(imported_editions_mw)
@@ -1205,10 +1203,11 @@ def test_import_with_lendability(self):
# Try again, with a Collection to contain the LicensePools.
importer_g = OPDSImporter(
- self._db, collection=self._default_collection,
+ self._db,
+ collection=self._default_collection,
)
- imported_editions_g, pools_g, works_g, failures_g = (
- importer_g.import_from_feed(feed)
+ imported_editions_g, pools_g, works_g, failures_g = importer_g.import_from_feed(
+ feed
)
# now pools and works are in, too
@@ -1217,8 +1216,9 @@ def test_import_with_lendability(self):
assert 2 == len(works_g)
# The pools have presentation editions.
- assert (set(["The Green Mouse", "Johnny Crow's Party"]) ==
- set([x.presentation_edition.title for x in pools_g]))
+ assert set(["The Green Mouse", "Johnny Crow's Party"]) == set(
+ [x.presentation_edition.title for x in pools_g]
+ )
# The information used to create the first LicensePool said
# that the licensing authority is Project Gutenberg, so that's used
@@ -1226,8 +1226,9 @@ def test_import_with_lendability(self):
# to create the second LicensePool didn't include a data source,
# so the source of the OPDS feed (the open-access content server)
# was used.
- assert set([DataSource.GUTENBERG, DataSource.OA_CONTENT_SERVER]) == \
- set([pool.data_source.name for pool in pools_g])
+ assert set([DataSource.GUTENBERG, DataSource.OA_CONTENT_SERVER]) == set(
+ [pool.data_source.name for pool in pools_g]
+ )
def test_import_with_unrecognized_distributor_creates_distributor(self):
"""We get a book from a previously unknown data source, with a license
@@ -1235,16 +1236,14 @@ def test_import_with_unrecognized_distributor_creates_distributor(self):
book is imported and both DataSources are created.
"""
feed = self.sample_opds("unrecognized_distributor.opds")
- self._default_collection.external_integration.setting('data_source').value = (
- "some new source"
- )
+ self._default_collection.external_integration.setting(
+ "data_source"
+ ).value = "some new source"
importer = OPDSImporter(
self._db,
collection=self._default_collection,
)
- imported_editions, pools, works, failures = (
- importer.import_from_feed(feed)
- )
+ imported_editions, pools, works, failures = importer.import_from_feed(feed)
assert {} == failures
# We imported an Edition because there was metadata.
@@ -1266,8 +1265,7 @@ def test_import_updates_metadata(self):
feed = self.sample_opds("metadata_wrangler_overdrive.opds")
edition, is_new = self._edition(
- DataSource.OVERDRIVE, Identifier.OVERDRIVE_ID,
- with_license_pool=True
+ DataSource.OVERDRIVE, Identifier.OVERDRIVE_ID, with_license_pool=True
)
[old_license_pool] = edition.license_pools
old_license_pool.calculate_work()
@@ -1275,15 +1273,13 @@ def test_import_updates_metadata(self):
feed = feed.replace("{OVERDRIVE ID}", edition.primary_identifier.identifier)
- self._default_collection.external_integration.setting('data_source').value = (
- DataSource.OVERDRIVE
- )
- imported_editions, imported_pools, imported_works, failures = (
- OPDSImporter(
- self._db,
- collection=self._default_collection,
- ).import_from_feed(feed)
- )
+ self._default_collection.external_integration.setting(
+ "data_source"
+ ).value = DataSource.OVERDRIVE
+ imported_editions, imported_pools, imported_works, failures = OPDSImporter(
+ self._db,
+ collection=self._default_collection,
+ ).import_from_feed(feed)
# The edition we created has had its metadata updated.
[new_edition] = imported_editions
@@ -1305,9 +1301,12 @@ def test_import_from_license_source(self):
collection=self._default_collection,
)
- imported_editions, imported_pools, imported_works, failures = (
- importer.import_from_feed(feed)
- )
+ (
+ imported_editions,
+ imported_pools,
+ imported_works,
+ failures,
+ ) = importer.import_from_feed(feed)
# Two works have been created, because the content server
# actually tells you how to get copies of these books.
@@ -1325,14 +1324,18 @@ def test_import_from_license_source(self):
# But the license pool's presentation edition has a data
# source associated with the Library Simplified open-access
# content server, since that's where the metadata comes from.
- assert (DataSource.OA_CONTENT_SERVER ==
- mouse_pool.presentation_edition.data_source.name)
+ assert (
+ DataSource.OA_CONTENT_SERVER
+ == mouse_pool.presentation_edition.data_source.name
+ )
# Since the 'mouse' book came with an open-access link, the license
# pool delivery mechanism has been marked as open access.
assert True == mouse_pool.open_access
- assert (RightsStatus.GENERIC_OPEN_ACCESS ==
- mouse_pool.delivery_mechanisms[0].rights_status.uri)
+ assert (
+ RightsStatus.GENERIC_OPEN_ACCESS
+ == mouse_pool.delivery_mechanisms[0].rights_status.uri
+ )
# The 'mouse' work was marked presentation-ready immediately.
assert True == mouse_pool.work.presentation_ready
@@ -1346,35 +1349,30 @@ def test_import_from_license_source(self):
def test_import_from_feed_treats_message_as_failure(self):
feed = self.sample_opds("unrecognized_identifier.opds")
- imported_editions, imported_pools, imported_works, failures = (
- OPDSImporter(
- self._db, collection=self._default_collection
- ).import_from_feed(feed)
- )
+ imported_editions, imported_pools, imported_works, failures = OPDSImporter(
+ self._db, collection=self._default_collection
+ ).import_from_feed(feed)
[failure] = list(failures.values())
assert isinstance(failure, CoverageFailure)
assert True == failure.transient
assert "404: I've never heard of this work." == failure.exception
-
def test_import_edition_failure_becomes_coverage_failure(self):
# Make sure that an exception during import generates a
# meaningful error message.
feed = self.content_server_mini_feed
- imported_editions, pools, works, failures = (
- DoomedOPDSImporter(
- self._db,
- collection=self._default_collection,
- ).import_from_feed(feed)
- )
+ imported_editions, pools, works, failures = DoomedOPDSImporter(
+ self._db,
+ collection=self._default_collection,
+ ).import_from_feed(feed)
# Only one book was imported, the other failed.
assert 1 == len(imported_editions)
# The other failed to import, and became a CoverageFailure
- failure = failures['http://www.gutenberg.org/ebooks/10441']
+ failure = failures["http://www.gutenberg.org/ebooks/10441"]
assert isinstance(failure, CoverageFailure)
assert False == failure.transient
assert "Utter failure!" in failure.exception
@@ -1384,23 +1382,18 @@ def test_import_work_failure_becomes_coverage_failure(self):
# imported edition generates a meaningful error message.
feed = self.content_server_mini_feed
- self._default_collection.external_integration.setting('data_source').value = (
- DataSource.OA_CONTENT_SERVER
- )
- importer = DoomedWorkOPDSImporter(
- self._db,
- collection=self._default_collection
- )
+ self._default_collection.external_integration.setting(
+ "data_source"
+ ).value = DataSource.OA_CONTENT_SERVER
+ importer = DoomedWorkOPDSImporter(self._db, collection=self._default_collection)
- imported_editions, pools, works, failures = (
- importer.import_from_feed(feed)
- )
+ imported_editions, pools, works, failures = importer.import_from_feed(feed)
# One work was created, the other failed.
assert 1 == len(works)
# There's an error message for the work that failed.
- failure = failures['http://www.gutenberg.org/ebooks/10441']
+ failure = failures["http://www.gutenberg.org/ebooks/10441"]
assert isinstance(failure, CoverageFailure)
assert False == failure.transient
assert "Utter work failure!" in failure.exception
@@ -1412,51 +1405,59 @@ def test_consolidate_links(self):
links = [None, None]
assert [] == OPDSImporter.consolidate_links(links)
- links = [LinkData(href=self._url, rel=rel, media_type="image/jpeg")
- for rel in [Hyperlink.OPEN_ACCESS_DOWNLOAD,
- Hyperlink.IMAGE,
- Hyperlink.THUMBNAIL_IMAGE,
- Hyperlink.OPEN_ACCESS_DOWNLOAD]
+ links = [
+ LinkData(href=self._url, rel=rel, media_type="image/jpeg")
+ for rel in [
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ Hyperlink.IMAGE,
+ Hyperlink.THUMBNAIL_IMAGE,
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ ]
]
old_link = links[2]
links = OPDSImporter.consolidate_links(links)
- assert [Hyperlink.OPEN_ACCESS_DOWNLOAD,
- Hyperlink.IMAGE,
- Hyperlink.OPEN_ACCESS_DOWNLOAD] == [x.rel for x in links]
+ assert [
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ Hyperlink.IMAGE,
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ ] == [x.rel for x in links]
link = links[1]
assert old_link == link.thumbnail
- links = [LinkData(href=self._url, rel=rel, media_type="image/jpeg")
- for rel in [Hyperlink.THUMBNAIL_IMAGE,
- Hyperlink.IMAGE,
- Hyperlink.THUMBNAIL_IMAGE,
- Hyperlink.IMAGE]
+ links = [
+ LinkData(href=self._url, rel=rel, media_type="image/jpeg")
+ for rel in [
+ Hyperlink.THUMBNAIL_IMAGE,
+ Hyperlink.IMAGE,
+ Hyperlink.THUMBNAIL_IMAGE,
+ Hyperlink.IMAGE,
+ ]
]
t1, i1, t2, i2 = links
links = OPDSImporter.consolidate_links(links)
- assert [Hyperlink.IMAGE,
- Hyperlink.IMAGE] == [x.rel for x in links]
+ assert [Hyperlink.IMAGE, Hyperlink.IMAGE] == [x.rel for x in links]
assert t1 == i1.thumbnail
assert t2 == i2.thumbnail
- links = [LinkData(href=self._url, rel=rel, media_type="image/jpeg")
- for rel in [Hyperlink.THUMBNAIL_IMAGE,
- Hyperlink.IMAGE,
- Hyperlink.IMAGE]
+ links = [
+ LinkData(href=self._url, rel=rel, media_type="image/jpeg")
+ for rel in [Hyperlink.THUMBNAIL_IMAGE, Hyperlink.IMAGE, Hyperlink.IMAGE]
]
t1, i1, i2 = links
links = OPDSImporter.consolidate_links(links)
- assert [Hyperlink.IMAGE,
- Hyperlink.IMAGE] == [x.rel for x in links]
+ assert [Hyperlink.IMAGE, Hyperlink.IMAGE] == [x.rel for x in links]
assert t1 == i1.thumbnail
assert None == i2.thumbnail
def test_import_book_that_offers_no_license(self):
feed = self.sample_opds("book_without_license.opds")
importer = OPDSImporter(self._db, self._default_collection)
- imported_editions, imported_pools, imported_works, failures = (
- importer.import_from_feed(feed)
- )
+ (
+ imported_editions,
+ imported_pools,
+ imported_works,
+ failures,
+ ) = importer.import_from_feed(feed)
# We got an Edition for this book, but no LicensePool and no Work.
[edition] = imported_editions
@@ -1473,17 +1474,12 @@ def test_build_identifier_mapping(self):
collection = self._collection(protocol=ExternalIntegration.AXIS_360)
lp = self._licensepool(
- None, collection=collection,
- data_source_name=DataSource.AXIS_360
+ None, collection=collection, data_source_name=DataSource.AXIS_360
)
# Create a couple of ISBN equivalencies.
- isbn1 = self._identifier(
- identifier_type=Identifier.ISBN, foreign_id=self._isbn
- )
- isbn2 = self._identifier(
- identifier_type=Identifier.ISBN, foreign_id=self._isbn
- )
+ isbn1 = self._identifier(identifier_type=Identifier.ISBN, foreign_id=self._isbn)
+ isbn2 = self._identifier(identifier_type=Identifier.ISBN, foreign_id=self._isbn)
source = DataSource.lookup(self._db, DataSource.AXIS_360)
[lp.identifier.equivalent_to(source, isbn, 1) for isbn in [isbn1, isbn2]]
@@ -1493,12 +1489,12 @@ def test_build_identifier_mapping(self):
# We can build one.
importer.build_identifier_mapping([isbn1.urn])
- expected = { isbn1 : lp.identifier }
+ expected = {isbn1: lp.identifier}
assert expected == importer.identifier_mapping
# If we already have one, it's overwritten.
importer.build_identifier_mapping([isbn2.urn])
- overwrite = { isbn2 : lp.identifier }
+ overwrite = {isbn2: lp.identifier}
assert importer.identifier_mapping == overwrite
# If the importer doesn't have a collection, we can't build
@@ -1531,6 +1527,7 @@ def test_update_work_for_edition_having_no_work(self):
# immediately call LicensePool.calculate_work().
def explode():
raise Exception("boom!")
+
lp.calculate_work = explode
importer.update_work_for_edition(edition)
@@ -1547,14 +1544,13 @@ def test_update_work_for_edition_having_incomplete_work(self):
i = edition.primary_identifier
new_edition = self._edition(
data_source_name=DataSource.METADATA_WRANGLER,
- identifier_type=i.type, identifier_id=i.identifier,
- title="A working title"
+ identifier_type=i.type,
+ identifier_id=i.identifier,
+ title="A working title",
)
importer = OPDSImporter(self._db, None)
- returned_pool, returned_work = importer.update_work_for_edition(
- edition
- )
+ returned_pool, returned_work = importer.update_work_for_edition(edition)
assert returned_pool == pool
assert returned_work == work
@@ -1576,14 +1572,13 @@ def test_update_work_for_edition_having_presentation_ready_work(self):
i = edition.primary_identifier
new_edition = self._edition(
data_source_name=DataSource.LIBRARY_STAFF,
- identifier_type=i.type, identifier_id=i.identifier,
- title="A new title"
+ identifier_type=i.type,
+ identifier_id=i.identifier,
+ title="A new title",
)
importer = OPDSImporter(self._db, None)
- returned_pool, returned_work = importer.update_work_for_edition(
- new_edition
- )
+ returned_pool, returned_work = importer.update_work_for_edition(new_edition)
# The existing LicensePool and Work were returned.
assert returned_pool == pool
@@ -1612,11 +1607,11 @@ def test_update_work_for_edition_having_multiple_license_pools(self):
assert lp2.work == work
def test_assert_importable_content(self):
-
class Mock(OPDSImporter):
"""An importer that may or may not be able to find
real open-access content.
"""
+
# Set this variable to control whether any open-access links
# are "found" in the OPDS feed.
open_access_links = None
@@ -1660,21 +1655,27 @@ class NoLinks(Mock):
assert [] == importer._is_open_access_link_called_with
oa = Hyperlink.OPEN_ACCESS_DOWNLOAD
+
class BadLinks(Mock):
"""Simulate an OPDS feed that contains open-access links that
don't actually work, because _is_open_access always returns False
"""
+
open_access_links = [
LinkData(href="url1", rel=oa, media_type="text/html"),
LinkData(href="url2", rel=oa, media_type="application/json"),
- LinkData(href="I won't be tested", rel=oa,
- media_type="application/json")
+ LinkData(
+ href="I won't be tested", rel=oa, media_type="application/json"
+ ),
]
importer = BadLinks(self._db, None, do_get)
with pytest.raises(IntegrationException) as excinfo:
importer.assert_importable_content("feed", "url", max_get_attempts=2)
- assert "Was unable to GET supposedly open-access content such as url2 (tried 2 times)" in str(excinfo.value)
+ assert (
+ "Was unable to GET supposedly open-access content such as url2 (tried 2 times)"
+ in str(excinfo.value)
+ )
# We called _is_open_access_link on the first and second links
# found in the 'metadata', but failed both times.
@@ -1689,21 +1690,22 @@ class GoodLink(Mock):
"""Simulate an OPDS feed that contains two bad open-access links
and one good one.
"""
+
_is_open_access_link_called_with = []
open_access_links = [
LinkData(href="bad", rel=oa, media_type="text/html"),
LinkData(href="good", rel=oa, media_type="application/json"),
LinkData(href="also bad", rel=oa, media_type="text/html"),
]
+
def _is_open_access_link(self, url, type):
self._is_open_access_link_called_with.append((url, type))
- if url == 'bad':
+ if url == "bad":
return False
return "this is a book"
+
importer = GoodLink(self._db, None, do_get)
- result = importer.assert_importable_content(
- "feed", "url", max_get_attempts=5
- )
+ result = importer.assert_importable_content("feed", "url", max_get_attempts=5)
assert "this is a book" == result
# The first link didn't work, but the second one did,
@@ -1726,32 +1728,27 @@ def test__open_access_links(self):
# This CirculationData has no open-access links, so it will be
# ignored.
circulation = CirculationData(DataSource.GUTENBERG, self._identifier())
- no_open_access_links = Metadata(
- DataSource.GUTENBERG, circulation=circulation
- )
+ no_open_access_links = Metadata(DataSource.GUTENBERG, circulation=circulation)
# This has three links, but only the open-access links
# will be returned.
circulation = CirculationData(DataSource.GUTENBERG, self._identifier())
oa = Hyperlink.OPEN_ACCESS_DOWNLOAD
for rel in [oa, Hyperlink.IMAGE, oa]:
- circulation.links.append(
- LinkData(href=self._url, rel=rel)
- )
- two_open_access_links = Metadata(
- DataSource.GUTENBERG, circulation=circulation
- )
+ circulation.links.append(LinkData(href=self._url, rel=rel))
+ two_open_access_links = Metadata(DataSource.GUTENBERG, circulation=circulation)
- oa_only = [x for x in circulation.links if x.rel==oa]
- assert oa_only == list(m([no_circulation, two_open_access_links,
- no_open_access_links]))
+ oa_only = [x for x in circulation.links if x.rel == oa]
+ assert oa_only == list(
+ m([no_circulation, two_open_access_links, no_open_access_links])
+ )
def test__is_open_access_link(self):
http = DummyHTTPClient()
# We only check that the response entity-body isn't tiny. 11
# kilobytes of data is enough.
- enough_content = "a" * (1024*11)
+ enough_content = "a" * (1024 * 11)
# Set up an HTTP response that looks enough like a book
# to convince _is_open_access_link.
@@ -1760,8 +1757,9 @@ def test__is_open_access_link(self):
url = self._url
type = "text/html"
- assert ("Found a book-like thing at %s" % url ==
- monitor._is_open_access_link(url, type))
+ assert "Found a book-like thing at %s" % url == monitor._is_open_access_link(
+ url, type
+ )
# We made a GET request to the appropriate URL.
assert url == http.requests.pop()
@@ -1779,16 +1777,19 @@ def test__is_open_access_link(self):
def test_import_open_access_audiobook(self):
feed = self.audiobooks_opds
- download_manifest_url = 'https://api.archivelab.org/books/kniga_zitij_svjatyh_na_mesjac_avgust_eu_0811_librivox/opds_audio_manifest'
+ download_manifest_url = "https://api.archivelab.org/books/kniga_zitij_svjatyh_na_mesjac_avgust_eu_0811_librivox/opds_audio_manifest"
importer = OPDSImporter(
self._db,
collection=self._default_collection,
)
- imported_editions, imported_pools, imported_works, failures = (
- importer.import_from_feed(feed)
- )
+ (
+ imported_editions,
+ imported_pools,
+ imported_works,
+ failures,
+ ) = importer.import_from_feed(feed)
assert 1 == len(imported_editions)
@@ -1800,7 +1801,10 @@ def test_import_open_access_audiobook(self):
assert download_manifest_url == august_pool._open_access_download_url
[lpdm] = august_pool.delivery_mechanisms
- assert Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE == lpdm.delivery_mechanism.content_type
+ assert (
+ Representation.AUDIOBOOK_MANIFEST_MEDIA_TYPE
+ == lpdm.delivery_mechanism.content_type
+ )
assert DeliveryMechanism.NO_DRM == lpdm.delivery_mechanism.drm_scheme
@@ -1816,31 +1820,31 @@ def test_combine(self):
d1 = dict(
a_list=[1],
a_scalar="old value",
- a_dict=dict(key1=None, key2=[2], key3="value3")
+ a_dict=dict(key1=None, key2=[2], key3="value3"),
)
d2 = dict(
a_list=[2],
a_scalar="new value",
- a_dict=dict(key1="finally a value", key4="value4", key2=[200])
+ a_dict=dict(key1="finally a value", key4="value4", key2=[200]),
)
combined = OPDSImporter.combine(d1, d2)
# Dictionaries get combined recursively.
- d = combined['a_dict']
+ d = combined["a_dict"]
# Normal scalar values can be overridden once set.
- assert "new value" == combined['a_scalar']
+ assert "new value" == combined["a_scalar"]
# Missing values are filled in.
- assert 'finally a value' == d["key1"]
- assert 'value3' == d['key3']
- assert 'value4' == d['key4']
+ assert "finally a value" == d["key1"]
+ assert "value3" == d["key3"]
+ assert "value4" == d["key4"]
# Lists get extended.
- assert [1, 2] == combined['a_list']
- assert [2, 200] == d['key2']
+ assert [1, 2] == combined["a_list"]
+ assert [2, 200] == d["key2"]
def test_combine_null_cases(self):
"""Test combine()'s ability to handle empty and null dictionaries."""
@@ -1859,7 +1863,7 @@ def test_combine_missing_value_is_replaced(self):
expect = dict(a=None, b=None)
assert expect == c(a_is_missing, a_is_present)
- a_is_present['a'] = True
+ a_is_present["a"] = True
expect = dict(a=True, b=None)
assert expect == c(a_is_missing, a_is_present)
@@ -1875,7 +1879,7 @@ def test_combine_present_value_replaced(self):
a_is_old = dict(a="old value")
a_is_new = dict(a="new value")
- assert "new value" == c(a_is_old, a_is_new)['a']
+ assert "new value" == c(a_is_old, a_is_new)["a"]
def test_combine_present_value_not_replaced_with_none(self):
@@ -1903,11 +1907,12 @@ def test_combine_present_value_extends_dictionary(self):
"""
a_is_true = dict(a=dict(b=[True]))
a_is_false = dict(a=dict(b=[False]))
- assert (dict(a=dict(b=[True, False])) ==
- OPDSImporter.combine(a_is_true, a_is_false))
+ assert dict(a=dict(b=[True, False])) == OPDSImporter.combine(
+ a_is_true, a_is_false
+ )
-class TestMirroring(OPDSImporterTest):
+class TestMirroring(OPDSImporterTest):
@pytest.fixture()
def http(self):
class DummyHashedHttpClient(object):
@@ -1915,7 +1920,14 @@ def __init__(self):
self.responses = {}
self.requests = []
- def queue_response(self, url, response_code, media_type='text_html', other_headers=None, content=''):
+ def queue_response(
+ self,
+ url,
+ response_code,
+ media_type="text_html",
+ other_headers=None,
+ content="",
+ ):
headers = {}
if media_type:
headers["content-type"] = media_type
@@ -1927,6 +1939,7 @@ def queue_response(self, url, response_code, media_type='text_html', other_heade
def do_get(self, url, *args, **kwargs):
self.requests.append(url)
return self.responses.pop(url)
+
return DummyHashedHttpClient()
@pytest.fixture()
@@ -1948,45 +1961,45 @@ def png(self):
@pytest.fixture()
def epub10441(self):
return {
- 'url': 'http://www.gutenberg.org/ebooks/10441.epub.images',
- 'response_code': 200,
- 'content': b'I am 10441.epub.images',
- 'media_type': Representation.EPUB_MEDIA_TYPE
+ "url": "http://www.gutenberg.org/ebooks/10441.epub.images",
+ "response_code": 200,
+ "content": b"I am 10441.epub.images",
+ "media_type": Representation.EPUB_MEDIA_TYPE,
}
@pytest.fixture()
def epub10441_cover(self, svg):
return {
- 'url': 'https://s3.amazonaws.com/book-covers.nypl.org/Gutenberg-Illustrated/10441/cover_10441_9.png',
- 'response_code': 200,
- 'content': svg,
- 'media_type': Representation.SVG_MEDIA_TYPE
+ "url": "https://s3.amazonaws.com/book-covers.nypl.org/Gutenberg-Illustrated/10441/cover_10441_9.png",
+ "response_code": 200,
+ "content": svg,
+ "media_type": Representation.SVG_MEDIA_TYPE,
}
@pytest.fixture()
def epub10557(self):
return {
- 'url': 'http://www.gutenberg.org/ebooks/10557.epub.images',
- 'response_code': 200,
- 'content': b'I am 10557.epub.images',
- 'media_type': Representation.EPUB_MEDIA_TYPE
+ "url": "http://www.gutenberg.org/ebooks/10557.epub.images",
+ "response_code": 200,
+ "content": b"I am 10557.epub.images",
+ "media_type": Representation.EPUB_MEDIA_TYPE,
}
@pytest.fixture()
def epub10557_cover_broken(self):
return {
- 'url': 'http://root/broken-cover-image',
- 'response_code': 404,
- 'media_type': "text/plain"
+ "url": "http://root/broken-cover-image",
+ "response_code": 404,
+ "media_type": "text/plain",
}
@pytest.fixture()
def epub10557_cover_working(self, png):
return {
- 'url': 'http://root/working-cover-image',
- 'response_code': 200,
- 'content': png,
- 'media_type': Representation.PNG_MEDIA_TYPE
+ "url": "http://root/working-cover-image",
+ "response_code": 200,
+ "content": png,
+ "media_type": Representation.PNG_MEDIA_TYPE,
}
def test_importer_gets_appropriate_mirror_for_collection(self):
@@ -2001,16 +2014,18 @@ def test_importer_gets_appropriate_mirror_for_collection(self):
# First set up a storage integration.
integration = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL,
- username="username", password="password",
- settings = {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY : "some-covers"}
+ ExternalIntegration.S3,
+ ExternalIntegration.STORAGE_GOAL,
+ username="username",
+ password="password",
+ settings={S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "some-covers"},
)
# Associate the collection's integration with the storage integration
# for the purpose of 'covers'.
integration_link = self._external_integration_link(
integration=collection._external_integration,
other_integration=integration,
- purpose=ExternalIntegrationLink.COVERS
+ purpose=ExternalIntegrationLink.COVERS,
)
# Now an OPDSImporter created for this collection has an
@@ -2021,35 +2036,50 @@ def test_importer_gets_appropriate_mirror_for_collection(self):
assert isinstance(mirrors[ExternalIntegrationLink.COVERS], S3Uploader)
assert "some-covers" == mirrors[ExternalIntegrationLink.COVERS].get_bucket(
- S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY)
+ S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY
+ )
assert mirrors[ExternalIntegrationLink.OPEN_ACCESS_BOOKS] == None
-
# An OPDSImporter can have two types of mirrors.
integration = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL,
- username="username", password="password",
- settings={S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY : "some-books"}
+ ExternalIntegration.S3,
+ ExternalIntegration.STORAGE_GOAL,
+ username="username",
+ password="password",
+ settings={S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "some-books"},
)
# Associate the collection's integration with the storage integration
# for the purpose of 'covers'.
integration_link = self._external_integration_link(
integration=collection._external_integration,
other_integration=integration,
- purpose=ExternalIntegrationLink.OPEN_ACCESS_BOOKS
+ purpose=ExternalIntegrationLink.OPEN_ACCESS_BOOKS,
)
importer = OPDSImporter(self._db, collection=collection)
mirrors = importer.mirrors
- assert isinstance(mirrors[ExternalIntegrationLink.OPEN_ACCESS_BOOKS], S3Uploader)
- assert "some-books" == mirrors[ExternalIntegrationLink.OPEN_ACCESS_BOOKS].get_bucket(
- S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY)
+ assert isinstance(
+ mirrors[ExternalIntegrationLink.OPEN_ACCESS_BOOKS], S3Uploader
+ )
+ assert "some-books" == mirrors[
+ ExternalIntegrationLink.OPEN_ACCESS_BOOKS
+ ].get_bucket(S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY)
assert "some-covers" == mirrors[ExternalIntegrationLink.COVERS].get_bucket(
- S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY)
-
- def test_resources_are_mirrored_on_import(self, http, png, svg, epub10441, epub10557, epub10441_cover,
- epub10557_cover_broken, epub10557_cover_working):
+ S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY
+ )
+
+ def test_resources_are_mirrored_on_import(
+ self,
+ http,
+ png,
+ svg,
+ epub10441,
+ epub10557,
+ epub10441_cover,
+ epub10557_cover_broken,
+ epub10557_cover_working,
+ ):
http.queue_response(**epub10441)
http.queue_response(**epub10441_cover)
http.queue_response(**epub10557)
@@ -2063,71 +2093,89 @@ def test_resources_are_mirrored_on_import(self, http, png, svg, epub10441, epub1
mirrors = dict(books_mirror=s3_for_books, covers_mirror=s3_for_covers)
importer = OPDSImporter(
- self._db, collection=self._default_collection,
- mirrors=mirrors, http_get=http.do_get
+ self._db,
+ collection=self._default_collection,
+ mirrors=mirrors,
+ http_get=http.do_get,
)
- imported_editions, pools, works, failures = (
- importer.import_from_feed(self.content_server_mini_feed,
- feed_url='http://root')
+ imported_editions, pools, works, failures = importer.import_from_feed(
+ self.content_server_mini_feed, feed_url="http://root"
)
assert 2 == len(pools)
# Both items were requested
- assert epub10441['url'] in http.requests
- assert epub10557['url'] in http.requests
+ assert epub10441["url"] in http.requests
+ assert epub10557["url"] in http.requests
# The import process requested each remote resource in the feed. The thumbnail
# image was not requested, since we never trust foreign thumbnails. The order they
# are requested in is not deterministic, but after requesting the epub the images
# should be requested.
- index = http.requests.index(epub10441['url'])
- assert http.requests[index+1] == epub10441_cover['url']
-
- index = http.requests.index(epub10557['url'])
- assert http.requests[index:index+3] == [
- epub10557['url'],
- epub10557_cover_broken['url'],
- epub10557_cover_working['url']
+ index = http.requests.index(epub10441["url"])
+ assert http.requests[index + 1] == epub10441_cover["url"]
+
+ index = http.requests.index(epub10557["url"])
+ assert http.requests[index : index + 3] == [
+ epub10557["url"],
+ epub10557_cover_broken["url"],
+ epub10557_cover_working["url"],
]
- e_10441 = next(e for e in imported_editions if e.primary_identifier.identifier == '10441')
- e_10557 = next(e for e in imported_editions if e.primary_identifier.identifier == '10557')
-
- [e_10441_oa_link, e_10441_image_link, e_10441_thumbnail_link,
- e_10441_description_link] = sorted(
- e_10441.primary_identifier.links, key=lambda x: x.rel
+ e_10441 = next(
+ e for e in imported_editions if e.primary_identifier.identifier == "10441"
)
- [e_10557_broken_image_link, e_10557_working_image_link, e_10557_oa_link] = sorted(
- e_10557.primary_identifier.links, key=lambda x: x.resource.url
+ e_10557 = next(
+ e for e in imported_editions if e.primary_identifier.identifier == "10557"
)
+ [
+ e_10441_oa_link,
+ e_10441_image_link,
+ e_10441_thumbnail_link,
+ e_10441_description_link,
+ ] = sorted(e_10441.primary_identifier.links, key=lambda x: x.rel)
+ [
+ e_10557_broken_image_link,
+ e_10557_working_image_link,
+ e_10557_oa_link,
+ ] = sorted(e_10557.primary_identifier.links, key=lambda x: x.resource.url)
+
# The thumbnail image is associated with the Identifier, but
# it's not used because it's associated with a representation
# (cover_10441_9.png with media type "image/png") that no
# longer has a resource associated with it.
assert Hyperlink.THUMBNAIL_IMAGE == e_10441_thumbnail_link.rel
- hypothetical_full_representation = e_10441_thumbnail_link.resource.representation.thumbnail_of
+ hypothetical_full_representation = (
+ e_10441_thumbnail_link.resource.representation.thumbnail_of
+ )
assert None == hypothetical_full_representation.resource
- assert (Representation.PNG_MEDIA_TYPE ==
- hypothetical_full_representation.media_type)
+ assert (
+ Representation.PNG_MEDIA_TYPE == hypothetical_full_representation.media_type
+ )
# That's because when we actually got cover_10441_9.png,
# it turned out to be an SVG file, not a PNG, so we created a new
# Representation. TODO: Obviously we could do better here.
- assert (Representation.SVG_MEDIA_TYPE ==
- e_10441_image_link.resource.representation.media_type)
+ assert (
+ Representation.SVG_MEDIA_TYPE
+ == e_10441_image_link.resource.representation.media_type
+ )
# The two open-access links were mirrored to S3, as were the
# original SVG image, the working PNG image, and its thumbnail, which we generated. The
# The broken PNG image was not mirrored because our attempt to download
# it resulted in a 404 error.
- imported_book_representations = {e_10441_oa_link.resource.representation,
- e_10557_oa_link.resource.representation}
- imported_cover_representations = {e_10441_image_link.resource.representation,
- e_10557_working_image_link.resource.representation,
- e_10557_working_image_link.resource.representation.thumbnails[0]}
+ imported_book_representations = {
+ e_10441_oa_link.resource.representation,
+ e_10557_oa_link.resource.representation,
+ }
+ imported_cover_representations = {
+ e_10441_image_link.resource.representation,
+ e_10557_working_image_link.resource.representation,
+ e_10557_working_image_link.resource.representation.thumbnails[0],
+ }
assert imported_book_representations == set(s3_for_books.uploaded)
assert imported_cover_representations == set(s3_for_covers.uploaded)
@@ -2135,8 +2183,8 @@ def test_resources_are_mirrored_on_import(self, http, png, svg, epub10441, epub1
assert 2 == len(s3_for_books.uploaded)
assert 3 == len(s3_for_covers.uploaded)
- assert epub10441['content'] in s3_for_books.content
- assert epub10557['content'] in s3_for_books.content
+ assert epub10441["content"] in s3_for_books.content
+ assert epub10557["content"] in s3_for_books.content
svg_bytes = svg.encode("utf8")
covers_content = s3_for_covers.content[:]
@@ -2146,7 +2194,7 @@ def test_resources_are_mirrored_on_import(self, http, png, svg, epub10441, epub1
covers_content.remove(png)
# We don't know what the thumbnail is, but we know it's smaller than the original cover image.
- assert(len(png) > len(covers_content[0]))
+ assert len(png) > len(covers_content[0])
# Each resource was 'mirrored' to an Amazon S3 bucket.
#
@@ -2160,39 +2208,35 @@ def test_resources_are_mirrored_on_import(self, http, png, svg, epub10441, epub1
# The "crow" book was mirrored to a bucket corresponding to
# the open-access content source, the default data source used
# when no distributor was specified for a book.
- book1_url = 'https://test-content-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/10441/The%20Green%20Mouse.epub.images'
- book1_svg_cover = 'https://test-cover-bucket.s3.amazonaws.com/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10441/cover_10441_9.svg'
- book2_url = 'https://test-content-bucket.s3.amazonaws.com/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10557/Johnny%20Crow%27s%20Party.epub.images'
- book2_png_cover = 'https://test-cover-bucket.s3.amazonaws.com/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10557/working-cover-image.png'
- book2_png_thumbnail = 'https://test-cover-bucket.s3.amazonaws.com/scaled/300/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10557/working-cover-image.png'
+ book1_url = "https://test-content-bucket.s3.amazonaws.com/Gutenberg/Gutenberg%20ID/10441/The%20Green%20Mouse.epub.images"
+ book1_svg_cover = "https://test-cover-bucket.s3.amazonaws.com/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10441/cover_10441_9.svg"
+ book2_url = "https://test-content-bucket.s3.amazonaws.com/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10557/Johnny%20Crow%27s%20Party.epub.images"
+ book2_png_cover = "https://test-cover-bucket.s3.amazonaws.com/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10557/working-cover-image.png"
+ book2_png_thumbnail = "https://test-cover-bucket.s3.amazonaws.com/scaled/300/Library%20Simplified%20Open%20Access%20Content%20Server/Gutenberg%20ID/10557/working-cover-image.png"
uploaded_urls = {x.mirror_url for x in s3_for_covers.uploaded}
uploaded_book_urls = {x.mirror_url for x in s3_for_books.uploaded}
assert {book1_svg_cover, book2_png_cover, book2_png_thumbnail} == uploaded_urls
assert {book1_url, book2_url} == uploaded_book_urls
-
# If we fetch the feed again, and the entries have been updated since the
# cutoff, but the content of the open access links hasn't changed, we won't mirror
# them again.
cutoff = datetime_utc(2013, 1, 2, 16, 56, 40)
http.queue_response(
- epub10441['url'],
- 304, media_type=Representation.EPUB_MEDIA_TYPE
+ epub10441["url"], 304, media_type=Representation.EPUB_MEDIA_TYPE
)
http.queue_response(
- epub10441_cover['url'],
- 304, media_type=Representation.SVG_MEDIA_TYPE
+ epub10441_cover["url"], 304, media_type=Representation.SVG_MEDIA_TYPE
)
http.queue_response(
- epub10557['url'],
- 304, media_type=Representation.EPUB_MEDIA_TYPE
+ epub10557["url"], 304, media_type=Representation.EPUB_MEDIA_TYPE
)
- imported_editions, pools, works, failures = (
- importer.import_from_feed(self.content_server_mini_feed)
+ imported_editions, pools, works, failures = importer.import_from_feed(
+ self.content_server_mini_feed
)
assert {e_10441, e_10557} == set(imported_editions)
@@ -2202,32 +2246,36 @@ def test_resources_are_mirrored_on_import(self, http, png, svg, epub10441, epub1
# If the content has changed, it will be mirrored again.
epub10441_updated = epub10441.copy()
- epub10441_updated['content'] = b"I am a new version of 10441.epub.images"
+ epub10441_updated["content"] = b"I am a new version of 10441.epub.images"
http.queue_response(**epub10441_updated)
http.queue_response(**epub10441_cover)
epub10557_updated = epub10557.copy()
- epub10557_updated['content'] = b"I am a new version of 10557.epub.images"
+ epub10557_updated["content"] = b"I am a new version of 10557.epub.images"
http.queue_response(**epub10557_updated)
- imported_editions, pools, works, failures = (
- importer.import_from_feed(self.content_server_mini_feed)
+ imported_editions, pools, works, failures = importer.import_from_feed(
+ self.content_server_mini_feed
)
assert {e_10441, e_10557} == set(imported_editions)
assert 4 == len(s3_for_books.uploaded)
- assert epub10441_updated['content'] in s3_for_books.content[-2:]
+ assert epub10441_updated["content"] in s3_for_books.content[-2:]
assert svg_bytes == s3_for_covers.content.pop()
- assert epub10557_updated['content'] in s3_for_books.content[-2:]
-
-
- def test_content_resources_not_mirrored_on_import_if_no_collection(self, http, svg, epub10557_cover_broken,
- epub10557_cover_working, epub10441_cover):
+ assert epub10557_updated["content"] in s3_for_books.content[-2:]
+
+ def test_content_resources_not_mirrored_on_import_if_no_collection(
+ self,
+ http,
+ svg,
+ epub10557_cover_broken,
+ epub10557_cover_working,
+ epub10441_cover,
+ ):
# If you don't provide a Collection to the OPDSImporter, no
# LicensePools are created for the book and content resources
# (like EPUB editions of the book) are not mirrored. Only
# metadata resources (like the book cover) are mirrored.
-
# The request to http://root/broken-cover-image
# will result in a 404 error, and the image will not be mirrored.
http.queue_response(**epub10557_cover_broken)
@@ -2238,13 +2286,11 @@ def test_content_resources_not_mirrored_on_import_if_no_collection(self, http, s
mirrors = dict(covers_mirror=s3)
importer = OPDSImporter(
- self._db, collection=None,
- mirrors=mirrors, http_get=http.do_get
+ self._db, collection=None, mirrors=mirrors, http_get=http.do_get
)
- imported_editions, pools, works, failures = (
- importer.import_from_feed(self.content_server_mini_feed,
- feed_url='http://root')
+ imported_editions, pools, works, failures = importer.import_from_feed(
+ self.content_server_mini_feed, feed_url="http://root"
)
# No LicensePools were created, since no Collection was
@@ -2258,37 +2304,53 @@ def test_content_resources_not_mirrored_on_import_if_no_collection(self, http, s
# were going to make our own thumbnail anyway.
assert len(http.requests) == 3
assert set(http.requests) == {
- epub10441_cover['url'],
- epub10557_cover_broken['url'],
- epub10557_cover_working['url']
+ epub10441_cover["url"],
+ epub10557_cover_broken["url"],
+ epub10557_cover_working["url"],
}
class TestOPDSImportMonitor(OPDSImporterTest):
-
def test_constructor(self):
with pytest.raises(ValueError) as excinfo:
OPDSImportMonitor(self._db, None, OPDSImporter)
- assert "OPDSImportMonitor can only be run in the context of a Collection." in str(excinfo.value)
+ assert (
+ "OPDSImportMonitor can only be run in the context of a Collection."
+ in str(excinfo.value)
+ )
- self._default_collection.external_integration.protocol = ExternalIntegration.OVERDRIVE
+ self._default_collection.external_integration.protocol = (
+ ExternalIntegration.OVERDRIVE
+ )
with pytest.raises(ValueError) as excinfo:
OPDSImportMonitor(self._db, self._default_collection, OPDSImporter)
- assert "Collection Default Collection is configured for protocol Overdrive, not OPDS Import." in str(excinfo.value)
+ assert (
+ "Collection Default Collection is configured for protocol Overdrive, not OPDS Import."
+ in str(excinfo.value)
+ )
- self._default_collection.external_integration.protocol = ExternalIntegration.OPDS_IMPORT
- self._default_collection.external_integration.setting('data_source').value = None
+ self._default_collection.external_integration.protocol = (
+ ExternalIntegration.OPDS_IMPORT
+ )
+ self._default_collection.external_integration.setting(
+ "data_source"
+ ).value = None
with pytest.raises(ValueError) as excinfo:
OPDSImportMonitor(self._db, self._default_collection, OPDSImporter)
- assert "Collection Default Collection has no associated data source." in str(excinfo.value)
+ assert "Collection Default Collection has no associated data source." in str(
+ excinfo.value
+ )
def test_external_integration(self):
monitor = OPDSImportMonitor(
- self._db, self._default_collection,
+ self._db,
+ self._default_collection,
import_class=OPDSImporter,
)
- assert (self._default_collection.external_integration ==
- monitor.external_integration(self._db))
+ assert (
+ self._default_collection.external_integration
+ == monitor.external_integration(self._db)
+ )
def test__run_self_tests(self):
"""Verify the self-tests of an OPDS collection."""
@@ -2308,8 +2370,7 @@ def follow_one_link(self, url):
feed_url = self._url
self._default_collection.external_account_id = feed_url
- monitor = Mock(self._db, self._default_collection,
- import_class=MockImporter)
+ monitor = Mock(self._db, self._default_collection, import_class=MockImporter)
[first_page, found_content] = monitor._run_self_tests(self._db)
expect = "Retrieve the first page of the OPDS feed (%s)" % feed_url
assert expect == first_page.name
@@ -2323,8 +2384,10 @@ def follow_one_link(self, url):
# Then, assert_importable_content was called on the importer.
assert "Checking for importable content" == found_content.name
assert True == found_content.success
- assert (("some content", feed_url) ==
- monitor.importer.assert_importable_content_called_with)
+ assert (
+ "some content",
+ feed_url,
+ ) == monitor.importer.assert_importable_content_called_with
assert "looks good" == found_content.result
def test_hook_methods(self):
@@ -2332,14 +2395,17 @@ def test_hook_methods(self):
come from the collection configuration.
"""
monitor = OPDSImportMonitor(
- self._db, self._default_collection,
+ self._db,
+ self._default_collection,
import_class=OPDSImporter,
)
- assert (self._default_collection.external_account_id ==
- monitor.opds_url(self._default_collection))
+ assert self._default_collection.external_account_id == monitor.opds_url(
+ self._default_collection
+ )
- assert (self._default_collection.data_source ==
- monitor.data_source(self._default_collection))
+ assert self._default_collection.data_source == monitor.data_source(
+ self._default_collection
+ )
def test_feed_contains_new_data(self):
feed = self.content_server_mini_feed
@@ -2349,7 +2415,8 @@ def _get(self, url, headers):
return 200, {"content-type": AtomFeed.ATOM_TYPE}, feed
monitor = OPDSImportMonitor(
- self._db, self._default_collection,
+ self._db,
+ self._default_collection,
import_class=OPDSImporter,
)
timestamp = monitor.timestamp()
@@ -2426,8 +2493,7 @@ def http_with_feed(self, feed, content_type=OPDSFeed.ACQUISITION_FEED_TYPE):
def test_follow_one_link(self):
monitor = OPDSImportMonitor(
- self._db, collection=self._default_collection,
- import_class=OPDSImporter
+ self._db, collection=self._default_collection, import_class=OPDSImporter
)
feed = self.content_server_mini_feed
@@ -2435,6 +2501,7 @@ def test_follow_one_link(self):
# If there's new data, follow_one_link extracts the next links.
def follow():
return monitor.follow_one_link("http://url", do_get=http.do_get)
+
http.queue_response(200, OPDSFeed.ACQUISITION_FEED_TYPE, content=feed)
next_links, content = follow()
assert 1 == len(next_links)
@@ -2455,7 +2522,6 @@ def follow():
)
record.timestamp = datetime_utc(2016, 1, 1, 1, 1, 1)
-
# If there's no new data, follow_one_link returns no next
# links and no content.
#
@@ -2488,8 +2554,9 @@ def test_import_one_feed(self):
# Check coverage records are created.
monitor = OPDSImportMonitor(
- self._db, collection=self._default_collection,
- import_class=DoomedOPDSImporter
+ self._db,
+ collection=self._default_collection,
+ import_class=DoomedOPDSImporter,
)
self._default_collection.external_account_id = "http://root-url/index.xml"
data_source = DataSource.lookup(self._db, DataSource.OA_CONTENT_SERVER)
@@ -2510,8 +2577,9 @@ def test_import_one_feed(self):
# That edition has a CoverageRecord.
record = CoverageRecord.lookup(
- editions[0].primary_identifier, data_source,
- operation=CoverageRecord.IMPORT_OPERATION
+ editions[0].primary_identifier,
+ data_source,
+ operation=CoverageRecord.IMPORT_OPERATION,
)
assert CoverageRecord.SUCCESS == record.status
assert None == record.exception
@@ -2519,28 +2587,36 @@ def test_import_one_feed(self):
# The edition's primary identifier has some cover links whose
# relative URL have been resolved relative to the Collection's
# external_account_id.
- covers = set([x.resource.url for x in editions[0].primary_identifier.links
- if x.rel==Hyperlink.IMAGE])
- assert covers == set(["http://root-url/broken-cover-image",
- "http://root-url/working-cover-image"]
- )
+ covers = set(
+ [
+ x.resource.url
+ for x in editions[0].primary_identifier.links
+ if x.rel == Hyperlink.IMAGE
+ ]
+ )
+ assert covers == set(
+ [
+ "http://root-url/broken-cover-image",
+ "http://root-url/working-cover-image",
+ ]
+ )
# The 202 status message in the feed caused a transient failure.
# The exception caused a persistent failure.
coverage_records = self._db.query(CoverageRecord).filter(
- CoverageRecord.operation==CoverageRecord.IMPORT_OPERATION,
- CoverageRecord.status != CoverageRecord.SUCCESS
+ CoverageRecord.operation == CoverageRecord.IMPORT_OPERATION,
+ CoverageRecord.status != CoverageRecord.SUCCESS,
)
- assert (
- sorted([CoverageRecord.TRANSIENT_FAILURE,
- CoverageRecord.PERSISTENT_FAILURE]) ==
- sorted([x.status for x in coverage_records]))
+ assert sorted(
+ [CoverageRecord.TRANSIENT_FAILURE, CoverageRecord.PERSISTENT_FAILURE]
+ ) == sorted([x.status for x in coverage_records])
- identifier, ignore = Identifier.parse_urn(self._db, "urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441")
+ identifier, ignore = Identifier.parse_urn(
+ self._db, "urn:librarysimplified.org/terms/id/Gutenberg%20ID/10441"
+ )
failure = CoverageRecord.lookup(
- identifier, data_source,
- operation=CoverageRecord.IMPORT_OPERATION
+ identifier, data_source, operation=CoverageRecord.IMPORT_OPERATION
)
assert "Utter failure!" in failure.exception
@@ -2564,11 +2640,10 @@ def follow_one_link(self, link, cutoff_date=None, do_get=None):
def import_one_feed(self, feed):
# Simulate two successes and one failure on every page.
self.imports.append(feed)
- return [object(), object()], { "identifier": "Failure" }
+ return [object(), object()], {"identifier": "Failure"}
monitor = MockOPDSImportMonitor(
- self._db, collection=self._default_collection,
- import_class=OPDSImporter
+ self._db, collection=self._default_collection, import_class=OPDSImporter
)
monitor.queue_response([[], "last page"])
@@ -2591,25 +2666,24 @@ def import_one_feed(self, feed):
def test_update_headers(self):
# Test the _update_headers helper method.
monitor = OPDSImportMonitor(
- self._db, collection=self._default_collection,
- import_class=OPDSImporter
+ self._db, collection=self._default_collection, import_class=OPDSImporter
)
# _update_headers return a new dictionary. An Accept header will be setted
# using the value of custom_accept_header. If the value is not set a
# default value will be used.
- headers = {'Some other': 'header'}
+ headers = {"Some other": "header"}
new_headers = monitor._update_headers(headers)
- assert ['Some other'] == list(headers.keys())
- assert ['Accept', 'Some other'] == sorted(list(new_headers.keys()))
+ assert ["Some other"] == list(headers.keys())
+ assert ["Accept", "Some other"] == sorted(list(new_headers.keys()))
# If a custom_accept_header exist, will be used instead a default value
new_headers = monitor._update_headers(headers)
- old_value = new_headers['Accept']
+ old_value = new_headers["Accept"]
target_value = old_value + "more characters"
monitor.custom_accept_header = target_value
new_headers = monitor._update_headers(headers)
- assert new_headers['Accept'] == target_value
+ assert new_headers["Accept"] == target_value
assert old_value != target_value
# If the monitor has a username and password, an Authorization
@@ -2618,7 +2692,7 @@ def test_update_headers(self):
monitor.password = "a password"
headers = {}
new_headers = monitor._update_headers(headers)
- assert new_headers['Authorization'].startswith('Basic')
+ assert new_headers["Authorization"].startswith("Basic")
# However, if the Authorization and/or Accept headers have been
# filled in by some other piece of code, _update_headers does
@@ -2627,4 +2701,3 @@ def test_update_headers(self):
headers = dict(expect)
new_headers = monitor._update_headers(headers)
assert headers == expect
-
diff --git a/tests/test_opensearch.py b/tests/test_opensearch.py
index 25a74988b..fe080b090 100644
--- a/tests/test_opensearch.py
+++ b/tests/test_opensearch.py
@@ -1,11 +1,11 @@
-from ..model import Genre
from ..classifier import Classifier
-from ..opensearch import OpenSearchDocument
from ..lane import Lane
+from ..model import Genre
+from ..opensearch import OpenSearchDocument
from ..testing import DatabaseTest
-class TestOpenSearchDocument(DatabaseTest):
+class TestOpenSearchDocument(DatabaseTest):
def test_search_info(self):
# Searching this lane will use the language
# and audience restrictions from the lane.
@@ -16,14 +16,14 @@ def test_search_info(self):
lane.fiction = True
info = OpenSearchDocument.search_info(lane)
- assert "Search" == info['name']
- assert "Search English/Deutsch Young Adult" == info['description']
- assert "english/deutsch-young-adult" == info['tags']
+ assert "Search" == info["name"]
+ assert "Search English/Deutsch Young Adult" == info["description"]
+ assert "english/deutsch-young-adult" == info["tags"]
# This lane is the root for a patron type, so searching
# it will use all the lane's restrictions.
root_lane = self._lane()
- root_lane.root_for_patron_type = ['A']
+ root_lane.root_for_patron_type = ["A"]
root_lane.display_name = "Science Fiction & Fantasy"
sf, ignore = Genre.lookup(self._db, "Science Fiction")
fantasy, ignore = Genre.lookup(self._db, "Fantasy")
@@ -31,9 +31,9 @@ def test_search_info(self):
root_lane.add_genre(fantasy)
info = OpenSearchDocument.search_info(root_lane)
- assert "Search" == info['name']
- assert "Search Science Fiction & Fantasy" == info['description']
- assert "science-fiction-&-fantasy" == info['tags']
+ assert "Search" == info["name"]
+ assert "Search Science Fiction & Fantasy" == info["description"]
+ assert "science-fiction-&-fantasy" == info["tags"]
def test_escape_entities(self):
"""Verify that escape_entities properly escapes ampersands."""
@@ -48,9 +48,9 @@ def test_url_template(self):
assert "http://url/?key=val&q={searchTerms}" == m("http://url/?key=val")
def test_for_lane(self):
-
class Mock(OpenSearchDocument):
"""Mock methods called by for_lane."""
+
@classmethod
def search_info(cls, lane):
return dict(
@@ -69,6 +69,6 @@ def url_template(cls, base_url):
# It's just the result of calling search_info() and url_template(),
# and using the resulting dict as arguments into TEMPLATE.
expect = Mock.search_info(object())
- expect['url_template'] = Mock.url_template(object())
+ expect["url_template"] = Mock.url_template(object())
expect = Mock.escape_entities(expect)
assert Mock.TEMPLATE % expect == doc
diff --git a/tests/test_overdrive.py b/tests/test_overdrive.py
index 070c1fc2f..8132749bc 100644
--- a/tests/test_overdrive.py
+++ b/tests/test_overdrive.py
@@ -1,53 +1,40 @@
# encoding: utf-8
-import pytest
-import os
import json
+import os
import pkgutil
-from ..overdrive import (
- OverdriveAPI,
- MockOverdriveAPI,
- OverdriveAdvantageAccount,
- OverdriveRepresentationExtractor,
- OverdriveBibliographicCoverageProvider,
-)
-
-from ..coverage import (
- CoverageFailure,
-)
+import pytest
from ..config import CannotLoadConfiguration
-
+from ..coverage import CoverageFailure
from ..metadata_layer import LinkData
-
from ..model import (
Collection,
Contributor,
DeliveryMechanism,
Edition,
ExternalIntegration,
+ Hyperlink,
Identifier,
- Representation,
- Subject,
Measurement,
MediaTypes,
- Hyperlink,
+ Representation,
+ Subject,
)
-from ..scripts import RunCollectionCoverageProviderScript
-
-from ..testing import MockRequestsResponse
-
-from ..util.http import (
- BadResponseException,
- HTTP,
+from ..overdrive import (
+ MockOverdriveAPI,
+ OverdriveAdvantageAccount,
+ OverdriveAPI,
+ OverdriveBibliographicCoverageProvider,
+ OverdriveRepresentationExtractor,
)
+from ..scripts import RunCollectionCoverageProviderScript
+from ..testing import DatabaseTest, MockRequestsResponse
+from ..util.http import HTTP, BadResponseException
from ..util.string_helpers import base64
-from ..testing import DatabaseTest
-
class OverdriveTest(DatabaseTest):
-
def setup_method(self):
super(OverdriveTest, self).setup_method()
self.collection = MockOverdriveAPI.mock_collection(self._db)
@@ -59,6 +46,7 @@ def sample_json(self, filename):
data = open(path).read()
return data, json.loads(data)
+
class OverdriveTestWithAPI(OverdriveTest):
"""Automatically create a MockOverdriveAPI class during setup.
@@ -68,13 +56,13 @@ class OverdriveTestWithAPI(OverdriveTest):
MockOverdriveAPI request created in a test behaves differently
from the first one.
"""
+
def setup_method(self):
super(OverdriveTestWithAPI, self).setup_method()
self.api = MockOverdriveAPI(self._db, self.collection)
class TestOverdriveAPI(OverdriveTestWithAPI):
-
def test_constructor_makes_no_requests(self):
# Invoking the OverdriveAPI constructor does not, by itself,
# make any HTTP requests.
@@ -82,16 +70,19 @@ def test_constructor_makes_no_requests(self):
class NoRequests(OverdriveAPI):
MSG = "This is a unit test, you can't make HTTP requests!"
+
def no_requests(self, *args, **kwargs):
raise Exception(self.MSG)
+
_do_get = no_requests
_do_post = no_requests
_make_request = no_requests
+
api = NoRequests(self._db, collection)
# Attempting to access .token or .collection_token _will_
# try to make an HTTP request.
- for field in 'token', 'collection_token':
+ for field in "token", "collection_token":
with pytest.raises(Exception) as excinfo:
getattr(api, field)
assert api.MSG in str(excinfo.value)
@@ -110,8 +101,9 @@ def test_ils_name(self):
def test_make_link_safe(self):
# Unsafe characters are escaped.
- assert ("http://foo.com?q=%2B%3A%7B%7D" ==
- OverdriveAPI.make_link_safe("http://foo.com?q=+:{}"))
+ assert "http://foo.com?q=%2B%3A%7B%7D" == OverdriveAPI.make_link_safe(
+ "http://foo.com?q=+:{}"
+ )
# Links to version 1 of the availability API are converted
# to links to version 2.
@@ -142,6 +134,7 @@ def api_with_setting(x):
integration = self.collection.external_integration
integration.setting(c.SERVER_NICKNAME).value = x
return c(self._db, self.collection)
+
testing = api_with_setting(c.TESTING_SERVERS)
assert testing.hosts == c.HOSTS[c.TESTING_SERVERS]
@@ -153,32 +146,39 @@ def api_with_setting(x):
def test_endpoint(self):
# The .endpoint() method performs string interpolation, including
# the names of servers.
- template = "%(host)s %(patron_host)s %(oauth_host)s %(oauth_patron_host)s %(extra)s"
+ template = (
+ "%(host)s %(patron_host)s %(oauth_host)s %(oauth_patron_host)s %(extra)s"
+ )
result = self.api.endpoint(template, extra="val")
# The host names and the 'extra' argument have been used to
# fill in the string interpolations.
expect_args = dict(self.api.hosts)
- expect_args['extra'] = 'val'
+ expect_args["extra"] = "val"
assert result == template % expect_args
# The string has been completely interpolated.
- assert '%' not in result
+ assert "%" not in result
# Once interpolation has happened, doing it again has no effect.
assert result == self.api.endpoint(result, extra="something else")
# This is important because an interpolated URL may superficially
# appear to contain extra formatting characters.
- assert (result + "%3A" ==
- self.api.endpoint(result +"%3A", extra="something else"))
+ assert result + "%3A" == self.api.endpoint(
+ result + "%3A", extra="something else"
+ )
def test_token_authorization_header(self):
# Verify that the Authorization header needed to get an access
# token for a given collection is encoded properly.
assert self.api.token_authorization_header == "Basic YTpi"
- assert self.api.token_authorization_header == "Basic " + base64.standard_b64encode(
- b"%s:%s" % (self.api.client_key, self.api.client_secret)
+ assert (
+ self.api.token_authorization_header
+ == "Basic "
+ + base64.standard_b64encode(
+ b"%s:%s" % (self.api.client_key, self.api.client_secret)
+ )
)
def test_token_post_success(self):
@@ -204,8 +204,13 @@ def test_error_getting_library(self):
class MisconfiguredOverdriveAPI(MockOverdriveAPI):
"""This Overdrive client has valid credentials but the library
can't be found -- probably because the library ID is wrong."""
+
def get_library(self):
- return {'errorCode': 'Some error', 'message': 'Some message.', 'token': 'abc-def-ghi'}
+ return {
+ "errorCode": "Some error",
+ "message": "Some message.",
+ "token": "abc-def-ghi",
+ }
# Just instantiating the API doesn't cause this error.
api = MisconfiguredOverdriveAPI(self._db, self.collection)
@@ -213,7 +218,10 @@ def get_library(self):
# But trying to access the collection token will cause it.
with pytest.raises(CannotLoadConfiguration) as excinfo:
api.collection_token()
- assert "Overdrive credentials are valid but could not fetch library: Some message." in str(excinfo.value)
+ assert (
+ "Overdrive credentials are valid but could not fetch library: Some message."
+ in str(excinfo.value)
+ )
def test_401_on_get_refreshes_bearer_token(self):
# We have a token.
@@ -241,8 +249,7 @@ def test_401_on_get_refreshes_bearer_token(self):
assert "new bearer token" == self.api.token
def test_credential_refresh_success(self):
- """Verify the process of refreshing the Overdrive bearer token.
- """
+ """Verify the process of refreshing the Overdrive bearer token."""
# Perform the initial credential check.
self.api.check_creds()
credential = self.api.credential_object(lambda x: x)
@@ -281,7 +288,9 @@ def test_401_after_token_refresh_raises_error(self):
with pytest.raises(BadResponseException) as excinfo:
self.api.get_library()
assert "Bad response from" in str(excinfo.value)
- assert "Something's wrong with the Overdrive OAuth Bearer Token!" in str(excinfo.value)
+ assert "Something's wrong with the Overdrive OAuth Bearer Token!" in str(
+ excinfo.value
+ )
def test_401_during_refresh_raises_error(self):
"""If we fail to refresh the OAuth bearer token, an exception is
@@ -299,19 +308,22 @@ def test_advantage_differences(self):
# Here's a regular Overdrive collection.
main = self._collection(
- protocol=ExternalIntegration.OVERDRIVE, external_account_id="1",
+ protocol=ExternalIntegration.OVERDRIVE,
+ external_account_id="1",
)
main.external_integration.username = "user"
main.external_integration.password = "password"
- main.external_integration.setting('website_id').value = '100'
- main.external_integration.setting('ils_name').value = 'default'
+ main.external_integration.setting("website_id").value = "100"
+ main.external_integration.setting("ils_name").value = "default"
# Here's an Overdrive API client for that collection.
overdrive_main = MockOverdriveAPI(self._db, main)
# Note the "library" endpoint.
- assert ("https://api.overdrive.com/v1/libraries/1" ==
- overdrive_main._library_endpoint)
+ assert (
+ "https://api.overdrive.com/v1/libraries/1"
+ == overdrive_main._library_endpoint
+ )
# The advantage_library_id of a non-Advantage Overdrive account
# is always -1.
@@ -321,7 +333,8 @@ def test_advantage_differences(self):
# Here's an Overdrive Advantage collection associated with the
# main Overdrive collection.
child = self._collection(
- protocol=ExternalIntegration.OVERDRIVE, external_account_id="2",
+ protocol=ExternalIntegration.OVERDRIVE,
+ external_account_id="2",
)
child.parent = main
overdrive_child = MockOverdriveAPI(self._db, child)
@@ -330,8 +343,9 @@ def test_advantage_differences(self):
# collection is beneath the the parent collection's "library"
# endpoint.
assert (
- 'https://api.overdrive.com/v1/libraries/1/advantageAccounts/2' ==
- overdrive_child._library_endpoint)
+ "https://api.overdrive.com/v1/libraries/1/advantageAccounts/2"
+ == overdrive_child._library_endpoint
+ )
# The advantage_library_id of an Advantage collection is the
# numeric value of its external_account_id.
@@ -352,9 +366,11 @@ def availability_link_list(self, content):
return ["an availability queue"]
original_data = {"key": "value"}
- for content in (original_data,
- json.dumps(original_data),
- json.dumps(original_data).encode("utf8")):
+ for content in (
+ original_data,
+ json.dumps(original_data),
+ json.dumps(original_data).encode("utf8"),
+ ):
extractor = MockExtractor()
self.api.queue_response(200, content=content)
result = self.api._get_book_list_page(
@@ -384,36 +400,33 @@ def availability_link_list(self, content):
class TestOverdriveRepresentationExtractor(OverdriveTestWithAPI):
-
def test_availability_info(self):
data, raw = self.sample_json("overdrive_book_list.json")
- availability = OverdriveRepresentationExtractor.availability_link_list(
- raw)
+ availability = OverdriveRepresentationExtractor.availability_link_list(raw)
# Every item in the list has a few important values.
for item in availability:
- for key in 'availability_link', 'author_name', 'id', 'title', 'date_added':
+ for key in "availability_link", "author_name", "id", "title", "date_added":
assert key in item
# Also run a spot check on the actual values.
spot = availability[0]
- assert '210bdcad-29b7-445f-8d05-cdbb40abc03a' == spot['id']
- assert 'King and Maxwell' == spot['title']
- assert 'David Baldacci' == spot['author_name']
- assert '2013-11-12T14:13:00-05:00' == spot['date_added']
+ assert "210bdcad-29b7-445f-8d05-cdbb40abc03a" == spot["id"]
+ assert "King and Maxwell" == spot["title"]
+ assert "David Baldacci" == spot["author_name"]
+ assert "2013-11-12T14:13:00-05:00" == spot["date_added"]
def test_availability_info_missing_data(self):
# overdrive_book_list_missing_data.json has two products. One
# only has a title, the other only has an ID.
data, raw = self.sample_json("overdrive_book_list_missing_data.json")
- [item] = OverdriveRepresentationExtractor.availability_link_list(
- raw)
+ [item] = OverdriveRepresentationExtractor.availability_link_list(raw)
# We got a data structure -- full of missing data -- for the
# item that has an ID.
- assert 'i only have an id' == item['id']
- assert None == item['title']
- assert None == item['author_name']
- assert None == item['date_added']
+ assert "i only have an id" == item["id"]
+ assert None == item["title"]
+ assert None == item["author_name"]
+ assert None == item["date_added"]
# We did not get a data structure for the item that only has a
# title, because an ID is required -- otherwise we don't know
@@ -421,10 +434,11 @@ def test_availability_info_missing_data(self):
def test_link(self):
data, raw = self.sample_json("overdrive_book_list.json")
- expect = OverdriveAPI.make_link_safe("http://api.overdrive.com/v1/collections/collection-id/products?limit=300&offset=0&lastupdatetime=2014-04-28%2009:25:09&sort=popularity:desc&formats=ebook-epub-open,ebook-epub-adobe,ebook-pdf-adobe,ebook-pdf-open")
+ expect = OverdriveAPI.make_link_safe(
+ "http://api.overdrive.com/v1/collections/collection-id/products?limit=300&offset=0&lastupdatetime=2014-04-28%2009:25:09&sort=popularity:desc&formats=ebook-epub-open,ebook-epub-adobe,ebook-pdf-adobe,ebook-pdf-open"
+ )
assert expect == OverdriveRepresentationExtractor.link(raw, "first")
-
def test_book_info_to_circulation(self):
# Tests that can convert an overdrive json block into a CirculationData object.
@@ -442,8 +456,10 @@ def test_book_info_to_circulation(self):
# Related IDs.
identifier = circulationdata.primary_identifier(self._db)
- assert ((Identifier.OVERDRIVE_ID, '2a005d55-a417-4053-b90d-7a38ca6d2065') ==
- (identifier.type, identifier.identifier))
+ assert (Identifier.OVERDRIVE_ID, "2a005d55-a417-4053-b90d-7a38ca6d2065") == (
+ identifier.type,
+ identifier.identifier,
+ )
def test_book_info_to_circulation_advantage(self):
# Overdrive Advantage accounts derive different information
@@ -482,13 +498,13 @@ class MockAPI(object):
# Pretend to be an API for an Overdrive Advantage collection with
# library ID 62.
advantage_library_id = 62
+
extractor = OverdriveRepresentationExtractor(MockAPI())
advantage_data = extractor.book_info_to_circulation(info)
assert None == advantage_data.licenses_owned
assert None == advantage_data.licenses_available
assert 0 == consortial_data.patrons_in_hold_queue
-
def test_not_found_error_to_circulationdata(self):
raw, info = self.sample_json("overdrive_availability_not_found.json")
@@ -503,7 +519,7 @@ def test_not_found_error_to_circulationdata(self):
# circulation code does), we do know, and we can create a
# CirculationData.
identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID)
- info['id'] = identifier.identifier
+ info["id"] = identifier.identifier
data = m(info)
assert identifier == data.primary_identifier(self._db)
assert 0 == data.licenses_owned
@@ -517,8 +533,14 @@ def test_book_info_with_metadata(self):
metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info)
assert "Agile Documentation" == metadata.title
- assert "Agile Documentation A Pattern Guide to Producing Lightweight Documents for Software Projects" == metadata.sort_title
- assert "A Pattern Guide to Producing Lightweight Documents for Software Projects" == metadata.subtitle
+ assert (
+ "Agile Documentation A Pattern Guide to Producing Lightweight Documents for Software Projects"
+ == metadata.sort_title
+ )
+ assert (
+ "A Pattern Guide to Producing Lightweight Documents for Software Projects"
+ == metadata.subtitle
+ )
assert Edition.BOOK_MEDIUM == metadata.medium
assert "Wiley Software Patterns" == metadata.series
assert "eng" == metadata.language
@@ -535,15 +557,17 @@ def test_book_info_with_metadata(self):
subjects = sorted(metadata.subjects, key=lambda x: x.identifier)
- assert ([("Computer Technology", Subject.OVERDRIVE, 100),
- ("Nonfiction", Subject.OVERDRIVE, 100),
- ('Object Technologies - Miscellaneous', 'tag', 1),
- ] ==
- [(x.identifier, x.type, x.weight) for x in subjects])
+ assert [
+ ("Computer Technology", Subject.OVERDRIVE, 100),
+ ("Nonfiction", Subject.OVERDRIVE, 100),
+ ("Object Technologies - Miscellaneous", "tag", 1),
+ ] == [(x.identifier, x.type, x.weight) for x in subjects]
# Related IDs.
- assert ((Identifier.OVERDRIVE_ID, '3896665d-9d81-4cac-bd43-ffc5066de1f5') ==
- (metadata.primary_identifier.type, metadata.primary_identifier.identifier))
+ assert (Identifier.OVERDRIVE_ID, "3896665d-9d81-4cac-bd43-ffc5066de1f5") == (
+ metadata.primary_identifier.type,
+ metadata.primary_identifier.identifier,
+ )
ids = [(x.type, x.identifier) for x in metadata.identifiers]
@@ -552,16 +576,16 @@ def test_book_info_with_metadata(self):
# text, one which is mis-typed and has a bad check digit, and one
# which has an invalid character; the bad identifiers do not show
# up here.
- assert (
- [
- (Identifier.ASIN, "B000VI88N2"),
- (Identifier.ISBN, "9780470856246"),
- (Identifier.OVERDRIVE_ID, '3896665d-9d81-4cac-bd43-ffc5066de1f5'),
- ] ==
- sorted(ids))
+ assert [
+ (Identifier.ASIN, "B000VI88N2"),
+ (Identifier.ISBN, "9780470856246"),
+ (Identifier.OVERDRIVE_ID, "3896665d-9d81-4cac-bd43-ffc5066de1f5"),
+ ] == sorted(ids)
# Available formats.
- [kindle, pdf] = sorted(metadata.circulation.formats, key=lambda x: x.content_type)
+ [kindle, pdf] = sorted(
+ metadata.circulation.formats, key=lambda x: x.content_type
+ )
assert DeliveryMechanism.KINDLE_CONTENT_TYPE == kindle.content_type
assert DeliveryMechanism.KINDLE_DRM == kindle.drm_scheme
@@ -569,9 +593,7 @@ def test_book_info_with_metadata(self):
assert DeliveryMechanism.ADOBE_DRM == pdf.drm_scheme
# Links to various resources.
- shortd, image, longd = sorted(
- metadata.links, key=lambda x:x.rel
- )
+ shortd, image, longd = sorted(metadata.links, key=lambda x: x.rel)
assert Hyperlink.DESCRIPTION == longd.rel
assert longd.content.startswith("
Software documentation")
@@ -581,36 +603,50 @@ def test_book_info_with_metadata(self):
assert len(shortd.content) < len(longd.content)
assert Hyperlink.IMAGE == image.rel
- assert 'http://images.contentreserve.com/ImageType-100/0128-1/%7B3896665D-9D81-4CAC-BD43-FFC5066DE1F5%7DImg100.jpg' == image.href
+ assert (
+ "http://images.contentreserve.com/ImageType-100/0128-1/%7B3896665D-9D81-4CAC-BD43-FFC5066DE1F5%7DImg100.jpg"
+ == image.href
+ )
thumbnail = image.thumbnail
assert Hyperlink.THUMBNAIL_IMAGE == thumbnail.rel
- assert 'http://images.contentreserve.com/ImageType-200/0128-1/%7B3896665D-9D81-4CAC-BD43-FFC5066DE1F5%7DImg200.jpg' == thumbnail.href
+ assert (
+ "http://images.contentreserve.com/ImageType-200/0128-1/%7B3896665D-9D81-4CAC-BD43-FFC5066DE1F5%7DImg200.jpg"
+ == thumbnail.href
+ )
# Measurements associated with the book.
measurements = metadata.measurements
- popularity = [x for x in measurements
- if x.quantity_measured==Measurement.POPULARITY][0]
+ popularity = [
+ x for x in measurements if x.quantity_measured == Measurement.POPULARITY
+ ][0]
assert 2 == popularity.value
- rating = [x for x in measurements
- if x.quantity_measured==Measurement.RATING][0]
+ rating = [x for x in measurements if x.quantity_measured == Measurement.RATING][
+ 0
+ ]
assert 1 == rating.value
# Request only the bibliographic information.
- metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info, include_bibliographic=True, include_formats=False)
+ metadata = OverdriveRepresentationExtractor.book_info_to_metadata(
+ info, include_bibliographic=True, include_formats=False
+ )
assert "Agile Documentation" == metadata.title
assert None == metadata.circulation
# Request only the format information.
- metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info, include_bibliographic=False, include_formats=True)
+ metadata = OverdriveRepresentationExtractor.book_info_to_metadata(
+ info, include_bibliographic=False, include_formats=True
+ )
assert None == metadata.title
- [kindle, pdf] = sorted(metadata.circulation.formats, key=lambda x: x.content_type)
+ [kindle, pdf] = sorted(
+ metadata.circulation.formats, key=lambda x: x.content_type
+ )
assert DeliveryMechanism.KINDLE_CONTENT_TYPE == kindle.content_type
assert DeliveryMechanism.KINDLE_DRM == kindle.drm_scheme
@@ -627,10 +663,10 @@ def test_audiobook_info(self):
streaming, manifest, legacy = sorted(
metadata.circulation.formats, key=lambda x: x.content_type
)
- assert (DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE ==
- streaming.content_type)
- assert (MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE ==
- manifest.content_type)
+ assert DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE == streaming.content_type
+ assert (
+ MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE == manifest.content_type
+ )
assert "application/x-od-media" == legacy.content_type
def test_book_info_with_sample(self):
@@ -639,47 +675,48 @@ def test_book_info_with_sample(self):
raw, info = self.sample_json("has_sample.json")
metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info)
samples = [x for x in metadata.links if x.rel == Hyperlink.SAMPLE]
- epub_sample, manifest_sample = sorted(
- samples, key=lambda x: x.media_type
- )
+ epub_sample, manifest_sample = sorted(samples, key=lambda x: x.media_type)
# Here's the direct download.
- assert ("http://excerpts.contentreserve.com/FormatType-410/1071-1/9BD/24F/82/BridesofConvenienceBundle9781426803697.epub" ==
- epub_sample.href)
+ assert (
+ "http://excerpts.contentreserve.com/FormatType-410/1071-1/9BD/24F/82/BridesofConvenienceBundle9781426803697.epub"
+ == epub_sample.href
+ )
assert MediaTypes.EPUB_MEDIA_TYPE == epub_sample.media_type
# Here's the manifest.
- assert ("https://samples.overdrive.com/?crid=9BD24F82-35C0-4E0A-B5E7-BCFED07835CF&.epub-sample.overdrive.com" ==
- manifest_sample.href)
- assert (MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE ==
- manifest_sample.media_type)
+ assert (
+ "https://samples.overdrive.com/?crid=9BD24F82-35C0-4E0A-B5E7-BCFED07835CF&.epub-sample.overdrive.com"
+ == manifest_sample.href
+ )
+ assert (
+ MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE == manifest_sample.media_type
+ )
def test_book_info_with_grade_levels(self):
raw, info = self.sample_json("has_grade_levels.json")
metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info)
grade_levels = sorted(
- [x.identifier for x in metadata.subjects
- if x.type==Subject.GRADE_LEVEL]
+ [x.identifier for x in metadata.subjects if x.type == Subject.GRADE_LEVEL]
)
- assert (['Grade 4', 'Grade 5', 'Grade 6', 'Grade 7', 'Grade 8'] ==
- grade_levels)
+ assert ["Grade 4", "Grade 5", "Grade 6", "Grade 7", "Grade 8"] == grade_levels
def test_book_info_with_awards(self):
raw, info = self.sample_json("has_awards.json")
metadata = OverdriveRepresentationExtractor.book_info_to_metadata(info)
- [awards] = [x for x in metadata.measurements
- if Measurement.AWARDS == x.quantity_measured
+ [awards] = [
+ x
+ for x in metadata.measurements
+ if Measurement.AWARDS == x.quantity_measured
]
assert 1 == awards.value
assert 1 == awards.weight
def test_image_link_to_linkdata(self):
def m(link):
- return OverdriveRepresentationExtractor.image_link_to_linkdata(
- link, "rel"
- )
+ return OverdriveRepresentationExtractor.image_link_to_linkdata(link, "rel")
# Test missing data.
assert None == m(None)
@@ -702,53 +739,61 @@ def m(link):
assert "http://api.overdrive.com/v1/foo%3Abar" == data.href
# Stand-in cover images are detected and filtered out.
- data = m(dict(href="https://img1.od-cdn.com/ImageType-100/0293-1/{00000000-0000-0000-0000-000000000002}Img100.jpg"))
+ data = m(
+ dict(
+ href="https://img1.od-cdn.com/ImageType-100/0293-1/{00000000-0000-0000-0000-000000000002}Img100.jpg"
+ )
+ )
assert None == data
def test_internal_formats(self):
# Overdrive's internal format names may correspond to one or more
# delivery mechanisms.
def assert_formats(overdrive_name, *expect):
- actual = OverdriveRepresentationExtractor.internal_formats(
- overdrive_name
- )
+ actual = OverdriveRepresentationExtractor.internal_formats(overdrive_name)
assert list(expect) == list(actual)
# Most formats correspond to one delivery mechanism.
assert_formats(
- 'ebook-pdf-adobe',
- (MediaTypes.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM)
+ "ebook-pdf-adobe", (MediaTypes.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM)
)
assert_formats(
- 'ebook-epub-open',
- (MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM)
+ "ebook-epub-open", (MediaTypes.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM)
)
# ebook-overdrive and audiobook-overdrive each correspond to
# two delivery mechanisms.
assert_formats(
- 'ebook-overdrive',
- (MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE,
- DeliveryMechanism.LIBBY_DRM),
- (DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM),
+ "ebook-overdrive",
+ (
+ MediaTypes.OVERDRIVE_EBOOK_MANIFEST_MEDIA_TYPE,
+ DeliveryMechanism.LIBBY_DRM,
+ ),
+ (
+ DeliveryMechanism.STREAMING_TEXT_CONTENT_TYPE,
+ DeliveryMechanism.STREAMING_DRM,
+ ),
)
assert_formats(
- 'audiobook-overdrive',
- (MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE,
- DeliveryMechanism.LIBBY_DRM),
- (DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE,
- DeliveryMechanism.STREAMING_DRM),
+ "audiobook-overdrive",
+ (
+ MediaTypes.OVERDRIVE_AUDIOBOOK_MANIFEST_MEDIA_TYPE,
+ DeliveryMechanism.LIBBY_DRM,
+ ),
+ (
+ DeliveryMechanism.STREAMING_AUDIO_CONTENT_TYPE,
+ DeliveryMechanism.STREAMING_DRM,
+ ),
)
# An unrecognized format does not correspond to any delivery
# mechanisms.
- assert_formats('no-such-format')
+ assert_formats("no-such-format")
-class TestOverdriveAdvantageAccount(OverdriveTestWithAPI):
+class TestOverdriveAdvantageAccount(OverdriveTestWithAPI):
def test_no_advantage_accounts(self):
"""When there are no Advantage accounts, get_advantage_accounts()
returns an empty list.
@@ -779,19 +824,24 @@ def test_to_collection(self):
# a Collection object.
account = OverdriveAdvantageAccount(
- "parent_id", "child_id", "Library Name",
+ "parent_id",
+ "child_id",
+ "Library Name",
)
# We can't just create a Collection object for this object because
# the parent doesn't exist.
with pytest.raises(ValueError) as excinfo:
account.to_collection(self._db)
- assert "Cannot create a Collection whose parent does not already exist." in str(excinfo.value)
+ assert "Cannot create a Collection whose parent does not already exist." in str(
+ excinfo.value
+ )
# So, create a Collection to be the parent.
parent = self._collection(
- name="Parent", protocol=ExternalIntegration.OVERDRIVE,
- external_account_id="parent_id"
+ name="Parent",
+ protocol=ExternalIntegration.OVERDRIVE,
+ external_account_id="parent_id",
)
# Now it works.
@@ -799,10 +849,8 @@ def test_to_collection(self):
assert p == parent
assert parent == collection.parent
assert collection.external_account_id == account.library_id
- assert (ExternalIntegration.LICENSE_GOAL ==
- collection.external_integration.goal)
- assert (ExternalIntegration.OVERDRIVE ==
- collection.protocol)
+ assert ExternalIntegration.LICENSE_GOAL == collection.external_integration.goal
+ assert ExternalIntegration.OVERDRIVE == collection.protocol
# To ensure uniqueness, the collection was named after its
# parent.
@@ -824,19 +872,17 @@ def test_script_instantiation(self):
the coverage provider.
"""
script = RunCollectionCoverageProviderScript(
- OverdriveBibliographicCoverageProvider, self._db,
- api_class=MockOverdriveAPI
+ OverdriveBibliographicCoverageProvider, self._db, api_class=MockOverdriveAPI
)
[provider] = script.providers
- assert isinstance(provider,
- OverdriveBibliographicCoverageProvider)
+ assert isinstance(provider, OverdriveBibliographicCoverageProvider)
assert isinstance(provider.api, MockOverdriveAPI)
assert self.collection == provider.collection
def test_invalid_or_unrecognized_guid(self):
"""A bad or malformed GUID can't get coverage."""
identifier = self._identifier()
- identifier.identifier = 'bad guid'
+ identifier.identifier = "bad guid"
self.api.queue_collection_token()
error = '{"errorCode": "InvalidGuid", "message": "An invalid guid was given.", "token": "7aebce0e-2e88-41b3-b6d3-82bf15f8e1a2"}'
@@ -865,7 +911,7 @@ def test_process_item_creates_presentation_ready_work(self):
# Here's the book mentioned in overdrive_metadata.json.
identifier = self._identifier(identifier_type=Identifier.OVERDRIVE_ID)
- identifier.identifier = '3896665d-9d81-4cac-bd43-ffc5066de1f5'
+ identifier.identifier = "3896665d-9d81-4cac-bd43-ffc5066de1f5"
# This book has no LicensePool.
assert [] == identifier.licensed_through
@@ -885,10 +931,16 @@ def test_process_item_creates_presentation_ready_work(self):
assert 0 == pool.licenses_owned
[lpdm1, lpdm2] = pool.delivery_mechanisms
names = [x.delivery_mechanism.name for x in pool.delivery_mechanisms]
- assert sorted(['application/pdf (application/vnd.adobe.adept+xml)',
- 'Kindle via Amazon (Kindle DRM)']) == sorted(names)
+ assert (
+ sorted(
+ [
+ "application/pdf (application/vnd.adobe.adept+xml)",
+ "Kindle via Amazon (Kindle DRM)",
+ ]
+ )
+ == sorted(names)
+ )
# A Work was created and made presentation ready.
assert "Agile Documentation" == pool.work.title
assert True == pool.work.presentation_ready
-
diff --git a/tests/test_personal_names.py b/tests/test_personal_names.py
index 354896269..d89a5080c 100644
--- a/tests/test_personal_names.py
+++ b/tests/test_personal_names.py
@@ -1,37 +1,28 @@
# encoding: utf-8
-from io import StringIO
import datetime
import os
-import sys
-import site
import re
+import site
+import sys
import tempfile
+from io import StringIO
+from ..mock_analytics_provider import MockAnalyticsProvider
from ..model import (
Contributor,
DataSource,
- Work,
- Identifier,
Edition,
+ Identifier,
+ Work,
create,
get_one,
get_one_or_create,
)
-
-from ..testing import (
- DatabaseTest,
- DummyHTTPClient,
-)
-
-from ..util.personal_names import (
- display_name_to_sort_name,
-)
-from ..mock_analytics_provider import MockAnalyticsProvider
-
+from ..testing import DatabaseTest, DummyHTTPClient
+from ..util.personal_names import display_name_to_sort_name
class TestNameConversions(DatabaseTest):
-
def test_display_name_to_sort_name(self):
# Make sure the sort name algorithm processes the messy reality of contributor
# names in a way we expect.
@@ -69,8 +60,9 @@ def unchanged(x):
sort_name = m("Bob Bitshifter, III")
assert "Bitshifter, Bob III" == sort_name
- assert ("Beck, James M. (James Montgomery)" ==
- m("James M. (James Montgomery) Beck"))
+ assert "Beck, James M. (James Montgomery)" == m(
+ "James M. (James Montgomery) Beck"
+ )
# all forms of PhD are recognized
sort_name = m("John Doe, PhD")
@@ -107,8 +99,3 @@ def test_name_tidy(self):
# retain proper period
sort_name = display_name_to_sort_name("Bitshifter, B.")
assert "Bitshifter, B." == sort_name
-
-
-
-
-
diff --git a/tests/test_problem_detail.py b/tests/test_problem_detail.py
index 7afb22af7..dd32fdf1a 100644
--- a/tests/test_problem_detail.py
+++ b/tests/test_problem_detail.py
@@ -1,12 +1,10 @@
# encoding: utf-8
import json
-from ..util.problem_detail import (
- ProblemDetail,
-)
+from ..util.problem_detail import ProblemDetail
-class TestProblemDetail(object):
+class TestProblemDetail(object):
def test_with_debug(self):
detail = ProblemDetail("http://uri/", title="Title", detail="Detail")
with_debug = detail.with_debug("Debug Message")
@@ -15,5 +13,5 @@ def test_with_debug(self):
assert "Title" == with_debug.title
json_data, status, headers = with_debug.response
data = json.loads(json_data)
- assert "Debug Message" == data['debug_message']
- assert "Detail" == data['detail']
+ assert "Debug Message" == data["debug_message"]
+ assert "Detail" == data["detail"]
diff --git a/tests/test_s3.py b/tests/test_s3.py
index ded8bcb36..13829c082 100644
--- a/tests/test_s3.py
+++ b/tests/test_s3.py
@@ -2,50 +2,44 @@
import functools
import os
from urllib.parse import urlsplit
+
import boto3
import botocore
import pytest
-from botocore.exceptions import (
- BotoCoreError,
- ClientError,
-)
+from botocore.exceptions import BotoCoreError, ClientError
from mock import MagicMock
-import pytest
from parameterized import parameterized
-from ..testing import (
- DatabaseTest
-)
from ..mirror import MirrorUploader
from ..model import (
- Identifier,
DataSource,
ExternalIntegration,
Hyperlink,
+ Identifier,
Representation,
create,
)
from ..s3 import (
- S3Uploader,
+ MinIOUploader,
+ MinIOUploaderConfiguration,
MockS3Client,
MultipartS3Upload,
S3AddressingStyle,
- MinIOUploader,
+ S3Uploader,
S3UploaderConfiguration,
- MinIOUploaderConfiguration
)
+from ..testing import DatabaseTest
from ..util.datetime_helpers import datetime_utc, utc_now
-class S3UploaderTest(DatabaseTest):
+class S3UploaderTest(DatabaseTest):
def _integration(self, **settings):
"""Create and configure a simple S3 integration."""
integration = self._external_integration(
- ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL,
- settings=settings
+ ExternalIntegration.S3, ExternalIntegration.STORAGE_GOAL, settings=settings
)
- integration.username = settings.get('username', 'username')
- integration.password = settings.get('password', 'password')
+ integration.username = settings.get("username", "username")
+ integration.password = settings.get("password", "password")
return integration
def _add_settings_value(self, settings, key, value):
@@ -68,19 +62,18 @@ def _add_settings_value(self, settings, key, value):
settings[key] = value
else:
- settings = {
- key: value
- }
+ settings = {key: value}
return settings
def _create_s3_uploader(
- self,
- client_class=None,
- uploader_class=None,
- region=None,
- addressing_style=None,
- **settings):
+ self,
+ client_class=None,
+ uploader_class=None,
+ region=None,
+ addressing_style=None,
+ **settings
+ ):
"""Creates a new instance of S3 uploader
:param client_class: (Optional) Custom class to be used instead of boto3's client class
@@ -101,8 +94,12 @@ def _create_s3_uploader(
:return: New intance of S3 uploader
:rtype: S3Uploader
"""
- settings = self._add_settings_value(settings, S3UploaderConfiguration.S3_REGION, region)
- settings = self._add_settings_value(settings, S3UploaderConfiguration.S3_ADDRESSING_STYLE, addressing_style)
+ settings = self._add_settings_value(
+ settings, S3UploaderConfiguration.S3_REGION, region
+ )
+ settings = self._add_settings_value(
+ settings, S3UploaderConfiguration.S3_ADDRESSING_STYLE, addressing_style
+ )
integration = self._integration(**settings)
uploader_class = uploader_class or S3Uploader
@@ -110,10 +107,18 @@ def _create_s3_uploader(
class S3UploaderIntegrationTest(S3UploaderTest):
- SIMPLIFIED_TEST_MINIO_ENDPOINT_URL = os.environ.get('SIMPLIFIED_TEST_MINIO_ENDPOINT_URL', 'http://localhost:9000')
- SIMPLIFIED_TEST_MINIO_USER = os.environ.get('SIMPLIFIED_TEST_MINIO_USER', 'minioadmin')
- SIMPLIFIED_TEST_MINIO_PASSWORD = os.environ.get('SIMPLIFIED_TEST_MINIO_PASSWORD', 'minioadmin')
- _, SIMPLIFIED_TEST_MINIO_HOST, _, _, _ = urlsplit(SIMPLIFIED_TEST_MINIO_ENDPOINT_URL)
+ SIMPLIFIED_TEST_MINIO_ENDPOINT_URL = os.environ.get(
+ "SIMPLIFIED_TEST_MINIO_ENDPOINT_URL", "http://localhost:9000"
+ )
+ SIMPLIFIED_TEST_MINIO_USER = os.environ.get(
+ "SIMPLIFIED_TEST_MINIO_USER", "minioadmin"
+ )
+ SIMPLIFIED_TEST_MINIO_PASSWORD = os.environ.get(
+ "SIMPLIFIED_TEST_MINIO_PASSWORD", "minioadmin"
+ )
+ _, SIMPLIFIED_TEST_MINIO_HOST, _, _, _ = urlsplit(
+ SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
+ )
minio_s3_client = None
"""boto3 client connected to locally running MinIO instance"""
@@ -127,14 +132,14 @@ def setup_class(cls):
super(S3UploaderIntegrationTest, cls).setup_class()
cls.minio_s3_client = boto3.client(
- 's3',
+ "s3",
aws_access_key_id=TestS3UploaderIntegration.SIMPLIFIED_TEST_MINIO_USER,
aws_secret_access_key=TestS3UploaderIntegration.SIMPLIFIED_TEST_MINIO_PASSWORD,
- endpoint_url=TestS3UploaderIntegration.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
+ endpoint_url=TestS3UploaderIntegration.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL,
)
cls.s3_client_class = functools.partial(
boto3.client,
- endpoint_url=TestS3UploaderIntegration.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
+ endpoint_url=TestS3UploaderIntegration.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL,
)
def teardown_method(self):
@@ -143,25 +148,26 @@ def teardown_method(self):
response = self.minio_s3_client.list_buckets()
- for bucket in response['Buckets']:
- bucket_name = bucket['Name']
+ for bucket in response["Buckets"]:
+ bucket_name = bucket["Name"]
response = self.minio_s3_client.list_objects(Bucket=bucket_name)
- for object in response.get('Contents', []):
- object_key = object['Key']
+ for object in response.get("Contents", []):
+ object_key = object["Key"]
self.minio_s3_client.delete_object(Bucket=bucket_name, Key=object_key)
self.minio_s3_client.delete_bucket(Bucket=bucket_name)
def _create_s3_uploader(
- self,
- client_class=None,
- uploader_class=None,
- region=None,
- addressing_style=None,
- **settings):
+ self,
+ client_class=None,
+ uploader_class=None,
+ region=None,
+ addressing_style=None,
+ **settings
+ ):
"""Creates a new instance of S3 uploader
:param client_class: (Optional) Custom class to be used instead of boto3's client class
@@ -182,19 +188,19 @@ def _create_s3_uploader(
:return: New intance of S3 uploader
:rtype: S3Uploader
"""
- if settings and 'username' not in settings:
- self._add_settings_value(settings, 'username', self.SIMPLIFIED_TEST_MINIO_USER)
- if settings and 'password' not in settings:
- self._add_settings_value(settings, 'password', self.SIMPLIFIED_TEST_MINIO_PASSWORD)
+ if settings and "username" not in settings:
+ self._add_settings_value(
+ settings, "username", self.SIMPLIFIED_TEST_MINIO_USER
+ )
+ if settings and "password" not in settings:
+ self._add_settings_value(
+ settings, "password", self.SIMPLIFIED_TEST_MINIO_PASSWORD
+ )
if not client_class:
client_class = self.s3_client_class
return super(S3UploaderIntegrationTest, self)._create_s3_uploader(
- client_class,
- uploader_class,
- region,
- addressing_style,
- **settings
+ client_class, uploader_class, region, addressing_style, **settings
)
@@ -205,45 +211,40 @@ def test_names(self):
# better if it's the same as the name of the external
# integration.
assert S3Uploader.NAME == ExternalIntegration.S3
- assert (S3Uploader ==
- MirrorUploader.IMPLEMENTATION_REGISTRY[ExternalIntegration.S3])
+ assert (
+ S3Uploader == MirrorUploader.IMPLEMENTATION_REGISTRY[ExternalIntegration.S3]
+ )
def test_instantiation(self):
integration = self._external_integration(
ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL
)
- integration.username = 'your-access-key'
- integration.password = 'your-secret-key'
- integration.setting(S3UploaderConfiguration.URL_TEMPLATE_KEY).value = 'a transform'
+ integration.username = "your-access-key"
+ integration.password = "your-secret-key"
+ integration.setting(
+ S3UploaderConfiguration.URL_TEMPLATE_KEY
+ ).value = "a transform"
uploader = MirrorUploader.implementation(integration)
assert True == isinstance(uploader, S3Uploader)
# The URL_TEMPLATE_KEY setting becomes the .url_transform
# attribute on the S3Uploader object.
- assert 'a transform' == uploader.url_transform
-
- @parameterized.expand([
- (
- 'empty_credentials',
- None,
- None
- ),
- (
- 'empty_string_credentials',
- '',
- ''
- ),
- (
- 'non_empty_string_credentials',
- 'username',
- 'password'
- )
- ])
+ assert "a transform" == uploader.url_transform
+
+ @parameterized.expand(
+ [
+ ("empty_credentials", None, None),
+ ("empty_string_credentials", "", ""),
+ ("non_empty_string_credentials", "username", "password"),
+ ]
+ )
def test_initialization(self, name, username, password):
# Arrange
- settings = {'username': username, 'password': password}
+ settings = {"username": username, "password": password}
integration = self._external_integration(
- ExternalIntegration.S3, goal=ExternalIntegration.STORAGE_GOAL, settings=settings
+ ExternalIntegration.S3,
+ goal=ExternalIntegration.STORAGE_GOAL,
+ settings=settings,
)
client_class = MagicMock()
@@ -254,26 +255,33 @@ def test_initialization(self, name, username, password):
assert client_class.call_count == 2
service_name = client_class.call_args_list[0].args[0]
- region_name = client_class.call_args_list[0].kwargs['region_name']
- aws_access_key_id = client_class.call_args_list[0].kwargs['aws_access_key_id']
- aws_secret_access_key = client_class.call_args_list[0].kwargs['aws_secret_access_key']
- config = client_class.call_args_list[0].kwargs['config']
- assert service_name == 's3'
+ region_name = client_class.call_args_list[0].kwargs["region_name"]
+ aws_access_key_id = client_class.call_args_list[0].kwargs["aws_access_key_id"]
+ aws_secret_access_key = client_class.call_args_list[0].kwargs[
+ "aws_secret_access_key"
+ ]
+ config = client_class.call_args_list[0].kwargs["config"]
+ assert service_name == "s3"
assert region_name == S3UploaderConfiguration.S3_DEFAULT_REGION
assert aws_access_key_id == None
assert aws_secret_access_key == None
assert config.signature_version == botocore.UNSIGNED
- assert config.s3['addressing_style'] == S3UploaderConfiguration.S3_DEFAULT_ADDRESSING_STYLE
+ assert (
+ config.s3["addressing_style"]
+ == S3UploaderConfiguration.S3_DEFAULT_ADDRESSING_STYLE
+ )
service_name = client_class.call_args_list[1].args[0]
- region_name = client_class.call_args_list[1].kwargs['region_name']
- aws_access_key_id = client_class.call_args_list[1].kwargs['aws_access_key_id']
- aws_secret_access_key = client_class.call_args_list[1].kwargs['aws_secret_access_key']
- assert service_name == 's3'
+ region_name = client_class.call_args_list[1].kwargs["region_name"]
+ aws_access_key_id = client_class.call_args_list[1].kwargs["aws_access_key_id"]
+ aws_secret_access_key = client_class.call_args_list[1].kwargs[
+ "aws_secret_access_key"
+ ]
+ assert service_name == "s3"
assert region_name == S3UploaderConfiguration.S3_DEFAULT_REGION
- assert aws_access_key_id == (username if username != '' else None)
- assert aws_secret_access_key == (password if password != '' else None)
- assert 'config' not in client_class.call_args_list[1].kwargs
+ assert aws_access_key_id == (username if username != "" else None)
+ assert aws_secret_access_key == (password if password != "" else None)
+ assert "config" not in client_class.call_args_list[1].kwargs
def test_custom_client_class(self):
"""You can specify a client class to use instead of boto3.client."""
@@ -283,11 +291,11 @@ def test_custom_client_class(self):
def test_get_bucket(self):
buckets = {
- S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'banana',
- S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'bucket'
+ S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "banana",
+ S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "bucket",
}
buckets_plus_irrelevant_setting = dict(buckets)
- buckets_plus_irrelevant_setting['not-a-bucket-at-all'] = "value"
+ buckets_plus_irrelevant_setting["not-a-bucket-at-all"] = "value"
uploader = self._create_s3_uploader(**buckets_plus_irrelevant_setting)
# This S3Uploader knows about the configured buckets. It
@@ -296,118 +304,124 @@ def test_get_bucket(self):
assert buckets == uploader.buckets
# get_bucket just does a lookup in .buckets
- uploader.buckets['foo'] = object()
- result = uploader.get_bucket('foo')
- assert uploader.buckets['foo'] == result
-
- @parameterized.expand([
- (
- 's3_url_with_path_without_slash',
- 'a-bucket',
- 'a-path',
- 'https://a-bucket.s3.amazonaws.com/a-path',
- None
- ),
- (
- 's3_dummy_url_with_path_without_slash',
- 'dummy',
- 'dummy',
- 'https://dummy.s3.amazonaws.com/dummy',
- None
- ),
- (
- 's3_path_style_url_with_path_without_slash',
- 'a-bucket',
- 'a-path',
- 'https://s3.amazonaws.com/a-bucket/a-path',
- None,
- S3AddressingStyle.PATH.value
- ),
- (
- 's3_path_style_dummy_url_with_path_without_slash',
- 'dummy',
- 'dummy',
- 'https://s3.amazonaws.com/dummy/dummy',
- None,
- S3AddressingStyle.PATH.value
- ),
- (
- 's3_url_with_path_with_slash',
- 'a-bucket',
- '/a-path',
- 'https://a-bucket.s3.amazonaws.com/a-path',
- None,
- ),
- (
- 's3_path_style_url_with_path_with_slash',
- 'a-bucket',
- '/a-path',
- 'https://s3.amazonaws.com/a-bucket/a-path',
- None,
- S3AddressingStyle.PATH.value
- ),
- (
- 's3_url_with_custom_region_and_path_without_slash',
- 'a-bucket',
- 'a-path',
- 'https://a-bucket.s3.us-east-2.amazonaws.com/a-path',
- 'us-east-2',
- ),
- (
- 's3_path_style_url_with_custom_region_and_path_without_slash',
- 'a-bucket',
- 'a-path',
- 'https://s3.us-east-2.amazonaws.com/a-bucket/a-path',
- 'us-east-2',
- S3AddressingStyle.PATH.value
- ),
- (
- 's3_url_with_custom_region_and_path_with_slash',
- 'a-bucket',
- '/a-path',
- 'https://a-bucket.s3.us-east-3.amazonaws.com/a-path',
- 'us-east-3'
- ),
- (
- 's3_path_style_url_with_custom_region_and_path_with_slash',
- 'a-bucket',
- '/a-path',
- 'https://s3.us-east-3.amazonaws.com/a-bucket/a-path',
- 'us-east-3',
- S3AddressingStyle.PATH.value
- ),
- (
- 'custom_http_url_and_path_without_slash',
- 'http://a-bucket.com/',
- 'a-path',
- 'http://a-bucket.com/a-path',
- None
- ),
- (
- 'custom_http_url_and_path_with_slash',
- 'http://a-bucket.com/',
- '/a-path',
- 'http://a-bucket.com/a-path',
- None
- ),
- (
- 'custom_http_url_and_path_without_slash',
- 'https://a-bucket.com/',
- 'a-path',
- 'https://a-bucket.com/a-path',
- None
- ),
- (
- 'custom_http_url_and_path_with_slash',
- 'https://a-bucket.com/',
- '/a-path',
- 'https://a-bucket.com/a-path',
- None
- )
- ])
- def test_url(self, name, bucket, path, expected_result, region=None, addressing_style=None):
+ uploader.buckets["foo"] = object()
+ result = uploader.get_bucket("foo")
+ assert uploader.buckets["foo"] == result
+
+ @parameterized.expand(
+ [
+ (
+ "s3_url_with_path_without_slash",
+ "a-bucket",
+ "a-path",
+ "https://a-bucket.s3.amazonaws.com/a-path",
+ None,
+ ),
+ (
+ "s3_dummy_url_with_path_without_slash",
+ "dummy",
+ "dummy",
+ "https://dummy.s3.amazonaws.com/dummy",
+ None,
+ ),
+ (
+ "s3_path_style_url_with_path_without_slash",
+ "a-bucket",
+ "a-path",
+ "https://s3.amazonaws.com/a-bucket/a-path",
+ None,
+ S3AddressingStyle.PATH.value,
+ ),
+ (
+ "s3_path_style_dummy_url_with_path_without_slash",
+ "dummy",
+ "dummy",
+ "https://s3.amazonaws.com/dummy/dummy",
+ None,
+ S3AddressingStyle.PATH.value,
+ ),
+ (
+ "s3_url_with_path_with_slash",
+ "a-bucket",
+ "/a-path",
+ "https://a-bucket.s3.amazonaws.com/a-path",
+ None,
+ ),
+ (
+ "s3_path_style_url_with_path_with_slash",
+ "a-bucket",
+ "/a-path",
+ "https://s3.amazonaws.com/a-bucket/a-path",
+ None,
+ S3AddressingStyle.PATH.value,
+ ),
+ (
+ "s3_url_with_custom_region_and_path_without_slash",
+ "a-bucket",
+ "a-path",
+ "https://a-bucket.s3.us-east-2.amazonaws.com/a-path",
+ "us-east-2",
+ ),
+ (
+ "s3_path_style_url_with_custom_region_and_path_without_slash",
+ "a-bucket",
+ "a-path",
+ "https://s3.us-east-2.amazonaws.com/a-bucket/a-path",
+ "us-east-2",
+ S3AddressingStyle.PATH.value,
+ ),
+ (
+ "s3_url_with_custom_region_and_path_with_slash",
+ "a-bucket",
+ "/a-path",
+ "https://a-bucket.s3.us-east-3.amazonaws.com/a-path",
+ "us-east-3",
+ ),
+ (
+ "s3_path_style_url_with_custom_region_and_path_with_slash",
+ "a-bucket",
+ "/a-path",
+ "https://s3.us-east-3.amazonaws.com/a-bucket/a-path",
+ "us-east-3",
+ S3AddressingStyle.PATH.value,
+ ),
+ (
+ "custom_http_url_and_path_without_slash",
+ "http://a-bucket.com/",
+ "a-path",
+ "http://a-bucket.com/a-path",
+ None,
+ ),
+ (
+ "custom_http_url_and_path_with_slash",
+ "http://a-bucket.com/",
+ "/a-path",
+ "http://a-bucket.com/a-path",
+ None,
+ ),
+ (
+ "custom_http_url_and_path_without_slash",
+ "https://a-bucket.com/",
+ "a-path",
+ "https://a-bucket.com/a-path",
+ None,
+ ),
+ (
+ "custom_http_url_and_path_with_slash",
+ "https://a-bucket.com/",
+ "/a-path",
+ "https://a-bucket.com/a-path",
+ None,
+ ),
+ ]
+ )
+ def test_url(
+ self, name, bucket, path, expected_result, region=None, addressing_style=None
+ ):
# Arrange
- uploader = self._create_s3_uploader(region=region, addressing_style=addressing_style)
+ uploader = self._create_s3_uploader(
+ region=region, addressing_style=addressing_style
+ )
# Act
result = uploader.url(bucket, path)
@@ -415,52 +429,56 @@ def test_url(self, name, bucket, path, expected_result, region=None, addressing_
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 'implicit_s3_url_template',
- 'bucket',
- 'the key',
- 'https://bucket.s3.amazonaws.com/the%20key'
- ),
- (
- 'implicit_s3_url_template_with_custom_region',
- 'bucket',
- 'the key',
- 'https://bucket.s3.us-east-2.amazonaws.com/the%20key',
- None,
- 'us-east-2'
- ),
- (
- 'explicit_s3_url_template',
- 'bucket',
- 'the key',
- 'https://bucket.s3.amazonaws.com/the%20key',
- S3UploaderConfiguration.URL_TEMPLATE_DEFAULT
- ),
- (
- 'explicit_s3_url_template_with_custom_region',
- 'bucket',
- 'the key',
- 'https://bucket.s3.us-east-2.amazonaws.com/the%20key',
- S3UploaderConfiguration.URL_TEMPLATE_DEFAULT,
- 'us-east-2'
- ),
- (
- 'http_url_template',
- 'bucket',
- 'the këy',
- 'http://bucket/the%20k%C3%ABy',
- S3UploaderConfiguration.URL_TEMPLATE_HTTP
- ),
- (
- 'https_url_template',
- 'bucket',
- 'the këy',
- 'https://bucket/the%20k%C3%ABy',
- S3UploaderConfiguration.URL_TEMPLATE_HTTPS
- )
- ])
- def test_final_mirror_url(self, name, bucket, key, expected_result, url_transform=None, region=None):
+ @parameterized.expand(
+ [
+ (
+ "implicit_s3_url_template",
+ "bucket",
+ "the key",
+ "https://bucket.s3.amazonaws.com/the%20key",
+ ),
+ (
+ "implicit_s3_url_template_with_custom_region",
+ "bucket",
+ "the key",
+ "https://bucket.s3.us-east-2.amazonaws.com/the%20key",
+ None,
+ "us-east-2",
+ ),
+ (
+ "explicit_s3_url_template",
+ "bucket",
+ "the key",
+ "https://bucket.s3.amazonaws.com/the%20key",
+ S3UploaderConfiguration.URL_TEMPLATE_DEFAULT,
+ ),
+ (
+ "explicit_s3_url_template_with_custom_region",
+ "bucket",
+ "the key",
+ "https://bucket.s3.us-east-2.amazonaws.com/the%20key",
+ S3UploaderConfiguration.URL_TEMPLATE_DEFAULT,
+ "us-east-2",
+ ),
+ (
+ "http_url_template",
+ "bucket",
+ "the këy",
+ "http://bucket/the%20k%C3%ABy",
+ S3UploaderConfiguration.URL_TEMPLATE_HTTP,
+ ),
+ (
+ "https_url_template",
+ "bucket",
+ "the këy",
+ "https://bucket/the%20k%C3%ABy",
+ S3UploaderConfiguration.URL_TEMPLATE_HTTPS,
+ ),
+ ]
+ )
+ def test_final_mirror_url(
+ self, name, bucket, key, expected_result, url_transform=None, region=None
+ ):
# Arrange
uploader = self._create_s3_uploader(region=region)
@@ -472,62 +490,76 @@ def test_final_mirror_url(self, name, bucket, key, expected_result, url_transfor
# Assert
if not url_transform:
- assert S3UploaderConfiguration.URL_TEMPLATE_DEFAULT == uploader.url_transform
+ assert (
+ S3UploaderConfiguration.URL_TEMPLATE_DEFAULT == uploader.url_transform
+ )
assert result == expected_result
def test_key_join(self):
"""Test the code used to build S3 keys from parts."""
parts = ["Gutenberg", b"Gutenberg ID", 1234, "Die Flügelmaus+.epub"]
- assert ('Gutenberg/Gutenberg%20ID/1234/Die%20Fl%C3%BCgelmaus%2B.epub' ==
- S3Uploader.key_join(parts))
-
- @parameterized.expand([
- (
- 'with_gutenberg_cover_generator_data_source',
- 'test-book-covers-s3-bucket',
- DataSource.GUTENBERG_COVER_GENERATOR,
- 'https://test-book-covers-s3-bucket.s3.amazonaws.com/Gutenberg%20Illustrated/'
- ),
- (
- 'with_overdrive_data_source',
- 'test-book-covers-s3-bucket',
- DataSource.OVERDRIVE,
- 'https://test-book-covers-s3-bucket.s3.amazonaws.com/Overdrive/'
- ),
- (
- 'with_overdrive_data_source_and_scaled_size',
- 'test-book-covers-s3-bucket',
- DataSource.OVERDRIVE,
- 'https://test-book-covers-s3-bucket.s3.amazonaws.com/scaled/300/Overdrive/',
- 300
- ),
- (
- 'with_gutenberg_cover_generator_data_source_and_custom_region',
- 'test-book-covers-s3-bucket',
- DataSource.GUTENBERG_COVER_GENERATOR,
- 'https://test-book-covers-s3-bucket.s3.us-east-3.amazonaws.com/Gutenberg%20Illustrated/',
- None,
- 'us-east-3'
- ),
- (
- 'with_overdrive_data_source_and_custom_region',
- 'test-book-covers-s3-bucket',
- DataSource.OVERDRIVE,
- 'https://test-book-covers-s3-bucket.s3.us-east-3.amazonaws.com/Overdrive/',
- None,
- 'us-east-3'
- ),
- (
- 'with_overdrive_data_source_and_scaled_size_and_custom_region',
- 'test-book-covers-s3-bucket',
- DataSource.OVERDRIVE,
- 'https://test-book-covers-s3-bucket.s3.us-east-3.amazonaws.com/scaled/300/Overdrive/',
- 300,
- 'us-east-3'
+ assert (
+ "Gutenberg/Gutenberg%20ID/1234/Die%20Fl%C3%BCgelmaus%2B.epub"
+ == S3Uploader.key_join(parts)
)
- ])
- def test_cover_image_root(self, name, bucket, data_source_name, expected_result, scaled_size=None, region=None):
+
+ @parameterized.expand(
+ [
+ (
+ "with_gutenberg_cover_generator_data_source",
+ "test-book-covers-s3-bucket",
+ DataSource.GUTENBERG_COVER_GENERATOR,
+ "https://test-book-covers-s3-bucket.s3.amazonaws.com/Gutenberg%20Illustrated/",
+ ),
+ (
+ "with_overdrive_data_source",
+ "test-book-covers-s3-bucket",
+ DataSource.OVERDRIVE,
+ "https://test-book-covers-s3-bucket.s3.amazonaws.com/Overdrive/",
+ ),
+ (
+ "with_overdrive_data_source_and_scaled_size",
+ "test-book-covers-s3-bucket",
+ DataSource.OVERDRIVE,
+ "https://test-book-covers-s3-bucket.s3.amazonaws.com/scaled/300/Overdrive/",
+ 300,
+ ),
+ (
+ "with_gutenberg_cover_generator_data_source_and_custom_region",
+ "test-book-covers-s3-bucket",
+ DataSource.GUTENBERG_COVER_GENERATOR,
+ "https://test-book-covers-s3-bucket.s3.us-east-3.amazonaws.com/Gutenberg%20Illustrated/",
+ None,
+ "us-east-3",
+ ),
+ (
+ "with_overdrive_data_source_and_custom_region",
+ "test-book-covers-s3-bucket",
+ DataSource.OVERDRIVE,
+ "https://test-book-covers-s3-bucket.s3.us-east-3.amazonaws.com/Overdrive/",
+ None,
+ "us-east-3",
+ ),
+ (
+ "with_overdrive_data_source_and_scaled_size_and_custom_region",
+ "test-book-covers-s3-bucket",
+ DataSource.OVERDRIVE,
+ "https://test-book-covers-s3-bucket.s3.us-east-3.amazonaws.com/scaled/300/Overdrive/",
+ 300,
+ "us-east-3",
+ ),
+ ]
+ )
+ def test_cover_image_root(
+ self,
+ name,
+ bucket,
+ data_source_name,
+ expected_result,
+ scaled_size=None,
+ region=None,
+ ):
# Arrange
uploader = self._create_s3_uploader(region=region)
data_source = DataSource.lookup(self._db, data_source_name)
@@ -538,19 +570,21 @@ def test_cover_image_root(self, name, bucket, data_source_name, expected_result,
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 'with_default_region',
- 'test-open-access-s3-bucket',
- 'https://test-open-access-s3-bucket.s3.amazonaws.com/'
- ),
- (
- 'with_custom_region',
- 'test-open-access-s3-bucket',
- 'https://test-open-access-s3-bucket.s3.us-east-3.amazonaws.com/',
- 'us-east-3'
- )
- ])
+ @parameterized.expand(
+ [
+ (
+ "with_default_region",
+ "test-open-access-s3-bucket",
+ "https://test-open-access-s3-bucket.s3.amazonaws.com/",
+ ),
+ (
+ "with_custom_region",
+ "test-open-access-s3-bucket",
+ "https://test-open-access-s3-bucket.s3.us-east-3.amazonaws.com/",
+ "us-east-3",
+ ),
+ ]
+ )
def test_content_root(self, name, bucket, expected_result, region=None):
# Arrange
uploader = self._create_s3_uploader(region=region)
@@ -561,34 +595,28 @@ def test_content_root(self, name, bucket, expected_result, region=None):
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 's3_url',
- 'test-marc-s3-bucket',
- 'SHORT',
- 'https://test-marc-s3-bucket.s3.amazonaws.com/SHORT/'
- ),
- (
- 's3_url_with_custom_region',
- 'test-marc-s3-bucket',
- 'SHORT',
- 'https://test-marc-s3-bucket.s3.us-east-2.amazonaws.com/SHORT/',
- 'us-east-2'
- ),
- (
- 'custom_http_url',
- 'http://my-feed/',
- 'SHORT',
- 'http://my-feed/SHORT/'
- ),
- (
- 'custom_https_url',
- 'https://my-feed/',
- 'SHORT',
- 'https://my-feed/SHORT/'
- ),
- ])
- def test_marc_file_root(self, name, bucket, library_name, expected_result, region=None):
+ @parameterized.expand(
+ [
+ (
+ "s3_url",
+ "test-marc-s3-bucket",
+ "SHORT",
+ "https://test-marc-s3-bucket.s3.amazonaws.com/SHORT/",
+ ),
+ (
+ "s3_url_with_custom_region",
+ "test-marc-s3-bucket",
+ "SHORT",
+ "https://test-marc-s3-bucket.s3.us-east-2.amazonaws.com/SHORT/",
+ "us-east-2",
+ ),
+ ("custom_http_url", "http://my-feed/", "SHORT", "http://my-feed/SHORT/"),
+ ("custom_https_url", "https://my-feed/", "SHORT", "https://my-feed/SHORT/"),
+ ]
+ )
+ def test_marc_file_root(
+ self, name, bucket, library_name, expected_result, region=None
+ ):
# Arrange
uploader = self._create_s3_uploader(region=region)
library = self._library(short_name=library_name)
@@ -599,100 +627,103 @@ def test_marc_file_root(self, name, bucket, library_name, expected_result, regio
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 'with_identifier',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK.epub'
- ),
- (
- 'with_custom_extension',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK.pdf',
- 'pdf'
- ),
- (
- 'with_custom_dotted_extension',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK.pdf',
- '.pdf'
- ),
- (
- 'with_custom_data_source',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK.epub',
- None,
- DataSource.UNGLUE_IT
- ),
- (
- 'with_custom_title',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK/On%20Books.epub',
- None,
- None,
- 'On Books'
- ),
- (
- 'with_custom_extension_and_title_and_data_source',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/On%20Books.pdf',
- '.pdf',
- DataSource.UNGLUE_IT,
- 'On Books'
- ),
- (
- 'with_custom_extension_and_title_and_data_source_and_region',
- {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.us-east-3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/On%20Books.pdf',
- '.pdf',
- DataSource.UNGLUE_IT,
- 'On Books',
- 'us-east-3'
- ),
- (
- 'with_protected_access_and_custom_extension_and_title_and_data_source_and_region',
- {S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: 'thebooks'},
- 'ABOOK',
- 'https://thebooks.s3.us-east-3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/On%20Books.pdf',
- '.pdf',
- DataSource.UNGLUE_IT,
- 'On Books',
- 'us-east-3',
- False,
- )
- ])
+ @parameterized.expand(
+ [
+ (
+ "with_identifier",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK.epub",
+ ),
+ (
+ "with_custom_extension",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK.pdf",
+ "pdf",
+ ),
+ (
+ "with_custom_dotted_extension",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK.pdf",
+ ".pdf",
+ ),
+ (
+ "with_custom_data_source",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK.epub",
+ None,
+ DataSource.UNGLUE_IT,
+ ),
+ (
+ "with_custom_title",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.amazonaws.com/Gutenberg%20ID/ABOOK/On%20Books.epub",
+ None,
+ None,
+ "On Books",
+ ),
+ (
+ "with_custom_extension_and_title_and_data_source",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/On%20Books.pdf",
+ ".pdf",
+ DataSource.UNGLUE_IT,
+ "On Books",
+ ),
+ (
+ "with_custom_extension_and_title_and_data_source_and_region",
+ {S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.us-east-3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/On%20Books.pdf",
+ ".pdf",
+ DataSource.UNGLUE_IT,
+ "On Books",
+ "us-east-3",
+ ),
+ (
+ "with_protected_access_and_custom_extension_and_title_and_data_source_and_region",
+ {S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY: "thebooks"},
+ "ABOOK",
+ "https://thebooks.s3.us-east-3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/On%20Books.pdf",
+ ".pdf",
+ DataSource.UNGLUE_IT,
+ "On Books",
+ "us-east-3",
+ False,
+ ),
+ ]
+ )
def test_book_url(
- self,
- name,
- buckets,
- identifier,
- expected_result,
- extension=None,
- data_source_name=None,
- title=None,
- region=None,
- open_access=True):
+ self,
+ name,
+ buckets,
+ identifier,
+ expected_result,
+ extension=None,
+ data_source_name=None,
+ title=None,
+ region=None,
+ open_access=True,
+ ):
# Arrange
identifier = self._identifier(foreign_id=identifier)
uploader = self._create_s3_uploader(region=region, **buckets)
- parameters = {'identifier': identifier, 'open_access': open_access}
+ parameters = {"identifier": identifier, "open_access": open_access}
if extension:
- parameters['extension'] = extension
+ parameters["extension"] = extension
if title:
- parameters['title'] = title
+ parameters["title"] = title
if data_source_name:
data_source = DataSource.lookup(self._db, DataSource.UNGLUE_IT)
- parameters['data_source'] = data_source
+ parameters["data_source"] = data_source
# Act
result = uploader.book_url(**parameters)
@@ -700,55 +731,58 @@ def test_book_url(
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 'without_scaled_size',
- {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'thecovers'},
- DataSource.UNGLUE_IT,
- 'ABOOK',
- 'filename',
- 'https://thecovers.s3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/filename'
- ),
- (
- 'without_scaled_size_and_with_custom_region',
- {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'thecovers'},
- DataSource.UNGLUE_IT,
- 'ABOOK',
- 'filename',
- 'https://thecovers.s3.us-east-3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/filename',
- None,
- 'us-east-3'
- ),
- (
- 'with_scaled_size',
- {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'thecovers'},
- DataSource.UNGLUE_IT,
- 'ABOOK',
- 'filename',
- 'https://thecovers.s3.amazonaws.com/scaled/601/unglue.it/Gutenberg%20ID/ABOOK/filename',
- 601
- ),
- (
- 'with_scaled_size_and_custom_region',
- {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: 'thecovers'},
- DataSource.UNGLUE_IT,
- 'ABOOK',
- 'filename',
- 'https://thecovers.s3.us-east-3.amazonaws.com/scaled/601/unglue.it/Gutenberg%20ID/ABOOK/filename',
- 601,
- 'us-east-3'
- )
- ])
+ @parameterized.expand(
+ [
+ (
+ "without_scaled_size",
+ {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "thecovers"},
+ DataSource.UNGLUE_IT,
+ "ABOOK",
+ "filename",
+ "https://thecovers.s3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/filename",
+ ),
+ (
+ "without_scaled_size_and_with_custom_region",
+ {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "thecovers"},
+ DataSource.UNGLUE_IT,
+ "ABOOK",
+ "filename",
+ "https://thecovers.s3.us-east-3.amazonaws.com/unglue.it/Gutenberg%20ID/ABOOK/filename",
+ None,
+ "us-east-3",
+ ),
+ (
+ "with_scaled_size",
+ {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "thecovers"},
+ DataSource.UNGLUE_IT,
+ "ABOOK",
+ "filename",
+ "https://thecovers.s3.amazonaws.com/scaled/601/unglue.it/Gutenberg%20ID/ABOOK/filename",
+ 601,
+ ),
+ (
+ "with_scaled_size_and_custom_region",
+ {S3UploaderConfiguration.BOOK_COVERS_BUCKET_KEY: "thecovers"},
+ DataSource.UNGLUE_IT,
+ "ABOOK",
+ "filename",
+ "https://thecovers.s3.us-east-3.amazonaws.com/scaled/601/unglue.it/Gutenberg%20ID/ABOOK/filename",
+ 601,
+ "us-east-3",
+ ),
+ ]
+ )
def test_cover_image_url(
- self,
- name,
- buckets,
- data_source_name,
- identifier,
- filename,
- expected_result,
- scaled_size=None,
- region=None):
+ self,
+ name,
+ buckets,
+ data_source_name,
+ identifier,
+ filename,
+ expected_result,
+ scaled_size=None,
+ region=None,
+ ):
# identifier = self._identifier(foreign_id="ABOOK")
# buckets = {S3Uploader.BOOK_COVERS_BUCKET_KEY : 'thecovers'}
# uploader = self._uploader(**buckets)
@@ -765,68 +799,73 @@ def test_cover_image_url(
uploader = self._create_s3_uploader(region=region, **buckets)
# Act
- result = uploader.cover_image_url(data_source, identifier, filename, scaled_size=scaled_size)
+ result = uploader.cover_image_url(
+ data_source, identifier, filename, scaled_size=scaled_size
+ )
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 'with_s3_bucket_and_end_time',
- 'marc',
- 'SHORT',
- 'Lane',
- datetime_utc(2020, 1, 1, 0, 0, 0),
- 'https://marc.s3.amazonaws.com/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00/Lane.mrc'
- ),
- (
- 'with_s3_bucket_and_end_time_and_start_time',
- 'marc',
- 'SHORT',
- 'Lane',
- datetime_utc(2020, 1, 2, 0, 0, 0),
- 'https://marc.s3.amazonaws.com/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc',
- datetime_utc(2020, 1, 1, 0, 0, 0),
- ),
- (
- 'with_s3_bucket_and_end_time_and_start_time_and_custom_region',
- 'marc',
- 'SHORT',
- 'Lane',
- datetime_utc(2020, 1, 2, 0, 0, 0),
- 'https://marc.s3.us-east-2.amazonaws.com/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc',
- datetime_utc(2020, 1, 1, 0, 0, 0),
- 'us-east-2'
- ),
- (
- 'with_http_bucket_and_end_time_and_start_time',
- 'http://marc',
- 'SHORT',
- 'Lane',
- datetime_utc(2020, 1, 2, 0, 0, 0),
- 'http://marc/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc',
- datetime_utc(2020, 1, 1, 0, 0, 0)
- ),
- (
- 'with_https_bucket_and_end_time_and_start_time',
- 'https://marc',
- 'SHORT',
- 'Lane',
- datetime_utc(2020, 1, 2, 0, 0, 0),
- 'https://marc/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc',
- datetime_utc(2020, 1, 1, 0, 0, 0)
- )
- ])
+ @parameterized.expand(
+ [
+ (
+ "with_s3_bucket_and_end_time",
+ "marc",
+ "SHORT",
+ "Lane",
+ datetime_utc(2020, 1, 1, 0, 0, 0),
+ "https://marc.s3.amazonaws.com/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00/Lane.mrc",
+ ),
+ (
+ "with_s3_bucket_and_end_time_and_start_time",
+ "marc",
+ "SHORT",
+ "Lane",
+ datetime_utc(2020, 1, 2, 0, 0, 0),
+ "https://marc.s3.amazonaws.com/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc",
+ datetime_utc(2020, 1, 1, 0, 0, 0),
+ ),
+ (
+ "with_s3_bucket_and_end_time_and_start_time_and_custom_region",
+ "marc",
+ "SHORT",
+ "Lane",
+ datetime_utc(2020, 1, 2, 0, 0, 0),
+ "https://marc.s3.us-east-2.amazonaws.com/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc",
+ datetime_utc(2020, 1, 1, 0, 0, 0),
+ "us-east-2",
+ ),
+ (
+ "with_http_bucket_and_end_time_and_start_time",
+ "http://marc",
+ "SHORT",
+ "Lane",
+ datetime_utc(2020, 1, 2, 0, 0, 0),
+ "http://marc/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc",
+ datetime_utc(2020, 1, 1, 0, 0, 0),
+ ),
+ (
+ "with_https_bucket_and_end_time_and_start_time",
+ "https://marc",
+ "SHORT",
+ "Lane",
+ datetime_utc(2020, 1, 2, 0, 0, 0),
+ "https://marc/SHORT/2020-01-01%2000%3A00%3A00%2B00%3A00-2020-01-02%2000%3A00%3A00%2B00%3A00/Lane.mrc",
+ datetime_utc(2020, 1, 1, 0, 0, 0),
+ ),
+ ]
+ )
def test_marc_file_url(
- self,
- name,
- bucket,
- library_name,
- lane_name,
- end_time,
- expected_result,
- start_time=None,
- region=None):
+ self,
+ name,
+ bucket,
+ library_name,
+ lane_name,
+ end_time,
+ expected_result,
+ start_time=None,
+ region=None,
+ ):
# Arrange
library = self._library(short_name=library_name)
lane = self._lane(display_name=lane_name)
@@ -839,54 +878,56 @@ def test_marc_file_url(
# Assert
assert result == expected_result
- @parameterized.expand([
- (
- 's3_path_style_request_without_region',
- 'https://s3.amazonaws.com/bucket/directory/filename.jpg',
- ('bucket', 'directory/filename.jpg')
- ),
- (
- 's3_path_style_request_with_region',
- 'https://s3.us-east-2.amazonaws.com/bucket/directory/filename.jpg',
- ('bucket', 'directory/filename.jpg')
- ),
- (
- 's3_virtual_hosted_style_request_with_global_endpoint',
- 'https://bucket.s3.amazonaws.com/directory/filename.jpg',
- ('bucket', 'directory/filename.jpg')
- ),
- (
- 's3_virtual_hosted_style_request_with_dashed_region',
- 'https://bucket.s3-us-east-2.amazonaws.com/directory/filename.jpg',
- ('bucket', 'directory/filename.jpg')
- ),
- (
- 's3_virtual_hosted_style_request_with_dotted_region',
- 'https://bucket.s3.us-east-2.amazonaws.com/directory/filename.jpg',
- ('bucket', 'directory/filename.jpg')
- ),
- (
- 'http_url',
- 'http://book-covers.nypl.org/directory/filename.jpg',
- ('book-covers.nypl.org', 'directory/filename.jpg')
- ),
- (
- 'https_url',
- 'https://book-covers.nypl.org/directory/filename.jpg',
- ('book-covers.nypl.org', 'directory/filename.jpg')
- ),
- (
- 'http_url_with_escaped_symbols',
- 'http://book-covers.nypl.org/directory/filename+with+spaces%21.jpg',
- ('book-covers.nypl.org', 'directory/filename with spaces!.jpg')
- ),
- (
- 'http_url_with_escaped_symbols_but_unquote_set_to_false',
- 'http://book-covers.nypl.org/directory/filename+with+spaces%21.jpg',
- ('book-covers.nypl.org', 'directory/filename+with+spaces%21.jpg'),
- False
- ),
- ])
+ @parameterized.expand(
+ [
+ (
+ "s3_path_style_request_without_region",
+ "https://s3.amazonaws.com/bucket/directory/filename.jpg",
+ ("bucket", "directory/filename.jpg"),
+ ),
+ (
+ "s3_path_style_request_with_region",
+ "https://s3.us-east-2.amazonaws.com/bucket/directory/filename.jpg",
+ ("bucket", "directory/filename.jpg"),
+ ),
+ (
+ "s3_virtual_hosted_style_request_with_global_endpoint",
+ "https://bucket.s3.amazonaws.com/directory/filename.jpg",
+ ("bucket", "directory/filename.jpg"),
+ ),
+ (
+ "s3_virtual_hosted_style_request_with_dashed_region",
+ "https://bucket.s3-us-east-2.amazonaws.com/directory/filename.jpg",
+ ("bucket", "directory/filename.jpg"),
+ ),
+ (
+ "s3_virtual_hosted_style_request_with_dotted_region",
+ "https://bucket.s3.us-east-2.amazonaws.com/directory/filename.jpg",
+ ("bucket", "directory/filename.jpg"),
+ ),
+ (
+ "http_url",
+ "http://book-covers.nypl.org/directory/filename.jpg",
+ ("book-covers.nypl.org", "directory/filename.jpg"),
+ ),
+ (
+ "https_url",
+ "https://book-covers.nypl.org/directory/filename.jpg",
+ ("book-covers.nypl.org", "directory/filename.jpg"),
+ ),
+ (
+ "http_url_with_escaped_symbols",
+ "http://book-covers.nypl.org/directory/filename+with+spaces%21.jpg",
+ ("book-covers.nypl.org", "directory/filename with spaces!.jpg"),
+ ),
+ (
+ "http_url_with_escaped_symbols_but_unquote_set_to_false",
+ "http://book-covers.nypl.org/directory/filename+with+spaces%21.jpg",
+ ("book-covers.nypl.org", "directory/filename+with+spaces%21.jpg"),
+ False,
+ ),
+ ]
+ )
def test_split_url(self, name, url, expected_result, unquote=True):
# Arrange
s3_uploader = self._create_s3_uploader()
@@ -900,22 +941,24 @@ def test_split_url(self, name, url, expected_result, unquote=True):
def test_mirror_one(self):
edition, pool = self._edition(with_license_pool=True)
original_cover_location = "http://example.com/a-cover.png"
- content = open(
- self.sample_cover_path("test-book-cover.png"), 'rb'
- ).read()
+ content = open(self.sample_cover_path("test-book-cover.png"), "rb").read()
cover, ignore = pool.add_link(
- Hyperlink.IMAGE, original_cover_location, edition.data_source,
+ Hyperlink.IMAGE,
+ original_cover_location,
+ edition.data_source,
Representation.PNG_MEDIA_TYPE,
- content=content
+ content=content,
)
cover_rep = cover.resource.representation
assert None == cover_rep.mirrored_at
original_epub_location = "https://books.com/a-book.epub"
epub, ignore = pool.add_link(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, original_epub_location,
- edition.data_source, Representation.EPUB_MEDIA_TYPE,
- content="i'm an epub"
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ original_epub_location,
+ edition.data_source,
+ Representation.EPUB_MEDIA_TYPE,
+ content="i'm an epub",
)
epub_rep = epub.resource.representation
assert None == epub_rep.mirrored_at
@@ -925,9 +968,7 @@ def test_mirror_one(self):
# Mock final_mirror_url so we can verify that it's called with
# the right arguments
def mock_final_mirror_url(bucket, key):
- return "final_mirror_url was called with bucket %s, key %s" % (
- bucket, key
- )
+ return "final_mirror_url was called with bucket %s, key %s" % (bucket, key)
s3.final_mirror_url = mock_final_mirror_url
@@ -935,29 +976,33 @@ def mock_final_mirror_url(bucket, key):
cover_url = "http://s3.amazonaws.com/covers-go/here.png"
s3.mirror_one(cover.resource.representation, cover_url)
s3.mirror_one(epub.resource.representation, book_url)
- [[data1, bucket1, key1, args1, ignore1],
- [data2, bucket2, key2, args2, ignore2], ] = s3.client.uploads
+ [
+ [data1, bucket1, key1, args1, ignore1],
+ [data2, bucket2, key2, args2, ignore2],
+ ] = s3.client.uploads
# Both representations have had .mirror_url set and been
# mirrored to those URLs.
- assert data1.startswith(b'\x89')
+ assert data1.startswith(b"\x89")
assert "covers-go" == bucket1
assert "here.png" == key1
- assert Representation.PNG_MEDIA_TYPE == args1['ContentType']
+ assert Representation.PNG_MEDIA_TYPE == args1["ContentType"]
assert (utc_now() - cover_rep.mirrored_at).seconds < 10
assert b"i'm an epub" == data2
assert "books-go" == bucket2
assert "here.epub" == key2
- assert Representation.EPUB_MEDIA_TYPE == args2['ContentType']
+ assert Representation.EPUB_MEDIA_TYPE == args2["ContentType"]
# In both cases, mirror_url was set to the result of final_mirror_url.
assert (
- 'final_mirror_url was called with bucket books-go, key here.epub' ==
- epub_rep.mirror_url)
+ "final_mirror_url was called with bucket books-go, key here.epub"
+ == epub_rep.mirror_url
+ )
assert (
- 'final_mirror_url was called with bucket covers-go, key here.png' ==
- cover_rep.mirror_url)
+ "final_mirror_url was called with bucket covers-go, key here.png"
+ == cover_rep.mirror_url
+ )
# mirrored-at was set when the representation was 'mirrored'
for rep in epub_rep, cover_rep:
@@ -967,9 +1012,11 @@ def test_mirror_failure(self):
edition, pool = self._edition(with_license_pool=True)
original_epub_location = "https://books.com/a-book.epub"
epub, ignore = pool.add_link(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, original_epub_location,
- edition.data_source, Representation.EPUB_MEDIA_TYPE,
- content="i'm an epub"
+ Hyperlink.OPEN_ACCESS_DOWNLOAD,
+ original_epub_location,
+ edition.data_source,
+ Representation.EPUB_MEDIA_TYPE,
+ content="i'm an epub",
)
epub_rep = epub.resource.representation
@@ -1015,18 +1062,21 @@ def test_svg_mirroring(self):
"""
hyperlink, ignore = pool.add_link(
- Hyperlink.IMAGE, original, edition.data_source,
+ Hyperlink.IMAGE,
+ original,
+ edition.data_source,
Representation.SVG_MEDIA_TYPE,
- content=svg)
+ content=svg,
+ )
# 'Upload' it to S3.
s3 = self._create_s3_uploader(MockS3Client)
s3.mirror_one(hyperlink.resource.representation, self._url)
[[data, bucket, key, args, ignore]] = s3.client.uploads
- assert Representation.SVG_MEDIA_TYPE == args['ContentType']
- assert b'svg' in data
- assert b'PNG' not in data
+ assert Representation.SVG_MEDIA_TYPE == args["ContentType"]
+ assert b"svg" in data
+ assert b"PNG" not in data
def test_multipart_upload(self):
class MockMultipartS3Upload(MultipartS3Upload):
@@ -1048,13 +1098,18 @@ def abort(self):
MockMultipartS3Upload.aborted = True
rep, ignore = create(
- self._db, Representation, url="http://books.mrc",
- media_type=Representation.MARC_MEDIA_TYPE)
+ self._db,
+ Representation,
+ url="http://books.mrc",
+ media_type=Representation.MARC_MEDIA_TYPE,
+ )
s3 = self._create_s3_uploader(MockS3Client)
# Successful upload
- with s3.multipart_upload(rep, rep.url, upload_class=MockMultipartS3Upload) as upload:
+ with s3.multipart_upload(
+ rep, rep.url, upload_class=MockMultipartS3Upload
+ ) as upload:
assert [] == upload.parts
assert False == upload.completed
assert False == upload.aborted
@@ -1073,7 +1128,9 @@ def upload_part(self, content):
raise Exception("Error!")
# Failed during upload
- with s3.multipart_upload(rep, rep.url, upload_class=FailingMultipartS3Upload) as upload:
+ with s3.multipart_upload(
+ rep, rep.url, upload_class=FailingMultipartS3Upload
+ ) as upload:
upload.upload_part("Part 1")
assert False == MockMultipartS3Upload.completed
@@ -1086,24 +1143,36 @@ def complete(self):
rep.mirror_exception = None
# Failed during completion
- with s3.multipart_upload(rep, rep.url, upload_class=AnotherFailingMultipartS3Upload) as upload:
+ with s3.multipart_upload(
+ rep, rep.url, upload_class=AnotherFailingMultipartS3Upload
+ ) as upload:
upload.upload_part("Part 1")
assert False == MockMultipartS3Upload.completed
assert True == MockMultipartS3Upload.aborted
assert "Error!" == rep.mirror_exception
- @parameterized.expand([
- ('default_expiration_parameter', None, int(S3UploaderConfiguration.S3_DEFAULT_PRESIGNED_URL_EXPIRATION)),
- ('empty_expiration_parameter', {S3UploaderConfiguration.S3_PRESIGNED_URL_EXPIRATION: 100}, 100)
- ])
+ @parameterized.expand(
+ [
+ (
+ "default_expiration_parameter",
+ None,
+ int(S3UploaderConfiguration.S3_DEFAULT_PRESIGNED_URL_EXPIRATION),
+ ),
+ (
+ "empty_expiration_parameter",
+ {S3UploaderConfiguration.S3_PRESIGNED_URL_EXPIRATION: 100},
+ 100,
+ ),
+ ]
+ )
def test_sign_url(self, name, expiration_settings, expected_expiration):
# Arrange
- region = 'us-east-1'
- bucket = 'bucket'
- filename = 'filename'
- url = 'https://{0}.s3.{1}.amazonaws.com/{2}'.format(bucket, region, filename)
- expected_url = url + '?AWSAccessKeyId=KEY&Expires=1&Signature=S'
+ region = "us-east-1"
+ bucket = "bucket"
+ filename = "filename"
+ url = "https://{0}.s3.{1}.amazonaws.com/{2}".format(bucket, region, filename)
+ expected_url = url + "?AWSAccessKeyId=KEY&Expires=1&Signature=S"
settings = expiration_settings if expiration_settings else {}
s3_uploader = self._create_s3_uploader(region=region, **settings)
s3_uploader.split_url = MagicMock(return_value=(bucket, filename))
@@ -1116,19 +1185,20 @@ def test_sign_url(self, name, expiration_settings, expected_expiration):
assert result == expected_url
s3_uploader.split_url.assert_called_once_with(url)
s3_uploader.client.generate_presigned_url.assert_called_once_with(
- 'get_object',
+ "get_object",
ExpiresIn=expected_expiration,
- Params={
- 'Bucket': bucket,
- 'Key': filename
- })
+ Params={"Bucket": bucket, "Key": filename},
+ )
class TestMultiPartS3Upload(S3UploaderTest):
def _representation(self):
rep, ignore = create(
- self._db, Representation, url="http://bucket/books.mrc",
- media_type=Representation.MARC_MEDIA_TYPE)
+ self._db,
+ Representation,
+ url="http://bucket/books.mrc",
+ media_type=Representation.MARC_MEDIA_TYPE,
+ )
return rep
def test_init(self):
@@ -1152,12 +1222,27 @@ def test_upload_part(self):
upload = MultipartS3Upload(uploader, rep, rep.url)
upload.upload_part("Part 1")
upload.upload_part("Part 2")
- assert ([{'Body': 'Part 1', 'UploadId': 1, 'PartNumber': 1, 'Bucket': 'bucket', 'Key': 'books.mrc'},
- {'Body': 'Part 2', 'UploadId': 1, 'PartNumber': 2, 'Bucket': 'bucket', 'Key': 'books.mrc'}] ==
- uploader.client.parts)
+ assert [
+ {
+ "Body": "Part 1",
+ "UploadId": 1,
+ "PartNumber": 1,
+ "Bucket": "bucket",
+ "Key": "books.mrc",
+ },
+ {
+ "Body": "Part 2",
+ "UploadId": 1,
+ "PartNumber": 2,
+ "Bucket": "bucket",
+ "Key": "books.mrc",
+ },
+ ] == uploader.client.parts
assert 3 == upload.part_number
- assert ([{'ETag': 'etag', 'PartNumber': 1}, {'ETag': 'etag', 'PartNumber': 2}] ==
- upload.parts)
+ assert [
+ {"ETag": "etag", "PartNumber": 1},
+ {"ETag": "etag", "PartNumber": 2},
+ ] == upload.parts
uploader.client.fail_with = Exception("Error!")
pytest.raises(Exception, upload.upload_part, "Part 3")
@@ -1169,9 +1254,19 @@ def test_complete(self):
upload.upload_part("Part 1")
upload.upload_part("Part 2")
upload.complete()
- assert [{'Bucket': 'bucket', 'Key': 'books.mrc', 'UploadId': 1, 'MultipartUpload': {
- 'Parts': [{'ETag': 'etag', 'PartNumber': 1}, {'ETag': 'etag', 'PartNumber': 2}],
- }}] == uploader.client.uploads
+ assert [
+ {
+ "Bucket": "bucket",
+ "Key": "books.mrc",
+ "UploadId": 1,
+ "MultipartUpload": {
+ "Parts": [
+ {"ETag": "etag", "PartNumber": 1},
+ {"ETag": "etag", "PartNumber": 2},
+ ],
+ },
+ }
+ ] == uploader.client.uploads
def test_abort(self):
uploader = self._create_s3_uploader(MockS3Client)
@@ -1185,48 +1280,60 @@ def test_abort(self):
@pytest.mark.minio
class TestS3UploaderIntegration(S3UploaderIntegrationTest):
- @parameterized.expand([
- (
- 'using_s3_uploader_and_open_access_bucket',
- functools.partial(S3Uploader, host=S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_HOST),
- S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY,
- 'test-bucket',
- True
- ),
- (
- 'using_s3_uploader_and_protected_access_bucket',
- functools.partial(S3Uploader, host=S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_HOST),
- S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY,
- 'test-bucket',
- False
- ),
- (
- 'using_minio_uploader_and_open_access_bucket',
- MinIOUploader,
- S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY,
- 'test-bucket',
- True,
- {
- MinIOUploaderConfiguration.ENDPOINT_URL: S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
- }
- ),
- (
- 'using_minio_uploader_and_protected_access_bucket',
- MinIOUploader,
- S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY,
- 'test-bucket',
- False,
- {
- MinIOUploaderConfiguration.ENDPOINT_URL: S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
- }
- )
- ])
- def test_mirror(self, name, uploader_class, bucket_type, bucket_name, open_access, settings=None):
+ @parameterized.expand(
+ [
+ (
+ "using_s3_uploader_and_open_access_bucket",
+ functools.partial(
+ S3Uploader,
+ host=S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_HOST,
+ ),
+ S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY,
+ "test-bucket",
+ True,
+ ),
+ (
+ "using_s3_uploader_and_protected_access_bucket",
+ functools.partial(
+ S3Uploader,
+ host=S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_HOST,
+ ),
+ S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY,
+ "test-bucket",
+ False,
+ ),
+ (
+ "using_minio_uploader_and_open_access_bucket",
+ MinIOUploader,
+ S3UploaderConfiguration.OA_CONTENT_BUCKET_KEY,
+ "test-bucket",
+ True,
+ {
+ MinIOUploaderConfiguration.ENDPOINT_URL: S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
+ },
+ ),
+ (
+ "using_minio_uploader_and_protected_access_bucket",
+ MinIOUploader,
+ S3UploaderConfiguration.PROTECTED_CONTENT_BUCKET_KEY,
+ "test-bucket",
+ False,
+ {
+ MinIOUploaderConfiguration.ENDPOINT_URL: S3UploaderIntegrationTest.SIMPLIFIED_TEST_MINIO_ENDPOINT_URL
+ },
+ ),
+ ]
+ )
+ def test_mirror(
+ self, name, uploader_class, bucket_type, bucket_name, open_access, settings=None
+ ):
# Arrange
- book_title = '1234567890'
- book_content = '1234567890'
+ book_title = "1234567890"
+ book_content = "1234567890"
identifier = Identifier(type=Identifier.ISBN, identifier=book_title)
- representation = Representation(content=book_content, media_type=Representation.EPUB_MEDIA_TYPE)
+ representation = Representation(
+ content=book_content, media_type=Representation.EPUB_MEDIA_TYPE
+ )
buckets = {
bucket_type: bucket_name,
}
@@ -1236,7 +1343,9 @@ def test_mirror(self, name, uploader_class, bucket_type, bucket_name, open_acces
else:
settings = buckets
- s3_uploader = self._create_s3_uploader(uploader_class=uploader_class, **settings)
+ s3_uploader = self._create_s3_uploader(
+ uploader_class=uploader_class, **settings
+ )
self.minio_s3_client.create_bucket(Bucket=bucket_name)
@@ -1246,9 +1355,9 @@ def test_mirror(self, name, uploader_class, bucket_type, bucket_name, open_acces
# Assert
response = self.minio_s3_client.list_objects(Bucket=bucket_name)
- assert 'Contents' in response
- assert len(response['Contents']) == 1
+ assert "Contents" in response
+ assert len(response["Contents"]) == 1
- [object] = response['Contents']
+ [object] = response["Contents"]
- assert object['Key'] == 'ISBN/{0}.epub'.format(book_title)
+ assert object["Key"] == "ISBN/{0}.epub".format(book_title)
diff --git a/tests/test_scripts.py b/tests/test_scripts.py
index 5b7dd1baf..12b801f3d 100644
--- a/tests/test_scripts.py
+++ b/tests/test_scripts.py
@@ -4,30 +4,17 @@
import stat
import tempfile
from io import StringIO
+
import pytest
from parameterized import parameterized
-from ..testing import (
- DatabaseTest,
-)
from ..classifier import Classifier
-from ..config import (
- CannotLoadConfiguration,
-)
+from ..config import CannotLoadConfiguration
from ..external_search import MockExternalSearchIndex
-from ..lane import (
- Lane,
- WorkList,
-)
-from ..metadata_layer import (
- LinkData,
- TimestampData,
-)
+from ..lane import Lane, WorkList
+from ..metadata_layer import LinkData, TimestampData
from ..mirror import MirrorUploader
from ..model import (
- create,
- dump_query,
- get_one,
CachedFeed,
Collection,
Complaint,
@@ -43,19 +30,19 @@
Timestamp,
Work,
WorkCoverageRecord,
+ create,
+ dump_query,
+ get_one,
)
from ..model.configuration import ExternalIntegrationLink
-from ..monitor import (
- Monitor,
- CollectionMonitor,
- ReaperMonitor,
-)
-from ..s3 import S3Uploader, MinIOUploader, MinIOUploaderConfiguration
+from ..monitor import CollectionMonitor, Monitor, ReaperMonitor
+from ..s3 import MinIOUploader, MinIOUploaderConfiguration, S3Uploader
from ..scripts import (
AddClassificationScript,
CheckContributorNamesInDB,
CollectionArgumentsScript,
CollectionInputScript,
+ CollectionType,
ConfigureCollectionScript,
ConfigureIntegrationScript,
ConfigureLaneScript,
@@ -93,23 +80,17 @@
WhereAreMyBooksScript,
WorkClassificationScript,
WorkProcessingScript,
- CollectionType)
+)
from ..testing import (
AlwaysSuccessfulCollectionCoverageProvider,
AlwaysSuccessfulWorkCoverageProvider,
+ DatabaseTest,
)
-from ..util.worker_pools import (
- DatabasePool,
-)
-from ..util.datetime_helpers import (
- datetime_utc,
- strptime_utc,
- to_utc,
- utc_now,
-)
+from ..util.datetime_helpers import datetime_utc, strptime_utc, to_utc, utc_now
+from ..util.worker_pools import DatabasePool
-class TestScript(DatabaseTest):
+class TestScript(DatabaseTest):
def test_parse_time(self):
reference_date = datetime_utc(2016, 1, 1)
@@ -123,9 +104,7 @@ def test_parse_time(self):
pytest.raises(ValueError, Script.parse_time, "201601-01")
-
def test_script_name(self):
-
class Sample(Script):
pass
@@ -134,14 +113,12 @@ class Sample(Script):
script = Sample(self._db)
assert "Sample" == script.script_name
-
# If a script does define .name, that's used instead.
script.name = "I'm a script"
assert script.name == script.script_name
class TestTimestampScript(DatabaseTest):
-
def _ts(self, script):
"""Convenience method to look up the Timestamp for a script.
@@ -157,6 +134,7 @@ def test_update_timestamp(self):
class Noisy(TimestampScript):
def do_run(self):
pass
+
script = Noisy(self._db)
script.run()
@@ -215,6 +193,7 @@ def test_normal_script_has_no_timestamp(self):
class Silent(Script):
def do_run(self):
pass
+
script = Silent(self._db)
script.run()
assert None == self._ts(script)
@@ -230,16 +209,15 @@ def test_process_contribution_local(self):
identifier_type=Identifier.GUTENBERG_ID,
identifier_id="1",
with_open_access_download=True,
- title="Alice Writes Books")
+ title="Alice Writes Books",
+ )
alice, new = self._contributor(sort_name="Alice Alrighty")
alice._sort_name = "Alice Alrighty"
- alice.display_name="Alice Alrighty"
+ alice.display_name = "Alice Alrighty"
- edition_alice.add_contributor(
- alice, [Contributor.PRIMARY_AUTHOR_ROLE]
- )
- edition_alice.sort_author="Alice Rocks"
+ edition_alice.add_contributor(alice, [Contributor.PRIMARY_AUTHOR_ROLE])
+ edition_alice.sort_author = "Alice Rocks"
# everything is set up as we expect
assert "Alice Alrighty" == alice.sort_name
@@ -251,15 +229,14 @@ def test_process_contribution_local(self):
identifier_type=Identifier.GUTENBERG_ID,
identifier_id="2",
with_open_access_download=True,
- title="Bob Writes Books")
+ title="Bob Writes Books",
+ )
bob, new = self._contributor(sort_name="Bob")
- bob.display_name="Bob Bitshifter"
+ bob.display_name = "Bob Bitshifter"
- edition_bob.add_contributor(
- bob, [Contributor.PRIMARY_AUTHOR_ROLE]
- )
- edition_bob.sort_author="Bob Rocks"
+ edition_bob.add_contributor(bob, [Contributor.PRIMARY_AUTHOR_ROLE])
+ edition_bob.sort_author = "Bob Rocks"
assert "Bob" == bob.sort_name
assert "Bob Bitshifter" == bob.display_name
@@ -281,38 +258,41 @@ def test_process_contribution_local(self):
assert "Bob Rocks" == edition_bob.sort_author
# and we lodged a proper complaint
- q = self._db.query(Complaint).filter(Complaint.source==CheckContributorNamesInDB.COMPLAINT_SOURCE)
- q = q.filter(Complaint.type==CheckContributorNamesInDB.COMPLAINT_TYPE).filter(Complaint.license_pool==pool_bob)
+ q = self._db.query(Complaint).filter(
+ Complaint.source == CheckContributorNamesInDB.COMPLAINT_SOURCE
+ )
+ q = q.filter(Complaint.type == CheckContributorNamesInDB.COMPLAINT_TYPE).filter(
+ Complaint.license_pool == pool_bob
+ )
complaints = q.all()
assert 1 == len(complaints)
assert None == complaints[0].resolved
-
class TestIdentifierInputScript(DatabaseTest):
-
def test_parse_list_as_identifiers(self):
i1 = self._identifier()
i2 = self._identifier()
- args = [i1.identifier, 'no-such-identifier', i2.identifier]
+ args = [i1.identifier, "no-such-identifier", i2.identifier]
identifiers = IdentifierInputScript.parse_identifier_list(
self._db, i1.type, None, args
)
assert [i1, i2] == identifiers
assert [] == IdentifierInputScript.parse_identifier_list(
- self._db, i1.type, None, [])
+ self._db, i1.type, None, []
+ )
def test_parse_list_as_identifiers_with_autocreate(self):
type = Identifier.OVERDRIVE_ID
- args = ['brand-new-identifier']
+ args = ["brand-new-identifier"]
[i] = IdentifierInputScript.parse_identifier_list(
self._db, type, None, args, autocreate=True
)
assert type == i.type
- assert 'brand-new-identifier' == i.identifier
+ assert "brand-new-identifier" == i.identifier
def test_parse_list_as_identifiers_with_data_source(self):
lp1 = self._licensepool(None, data_source_name=DataSource.UNGLUE_IT)
@@ -325,7 +305,8 @@ def test_parse_list_as_identifiers_with_data_source(self):
# Only URIs with a FeedBooks LicensePool are selected.
identifiers = IdentifierInputScript.parse_identifier_list(
- self._db, Identifier.URI, source, [])
+ self._db, Identifier.URI, source, []
+ )
assert [i2] == identifiers
def test_parse_list_as_identifiers_by_database_id(self):
@@ -338,27 +319,27 @@ def test_parse_list_as_identifiers_by_database_id(self):
ids = [id1.id, "10000000", "abcde", id2.id]
identifiers = IdentifierInputScript.parse_identifier_list(
- self._db, IdentifierInputScript.DATABASE_ID, None, ids)
+ self._db, IdentifierInputScript.DATABASE_ID, None, ids
+ )
assert [id1, id2] == identifiers
def test_parse_command_line(self):
i1 = self._identifier()
i2 = self._identifier()
# We pass in one identifier on the command line...
- cmd_args = ["--identifier-type",
- i1.type, i1.identifier]
+ cmd_args = ["--identifier-type", i1.type, i1.identifier]
# ...and another one into standard input.
stdin = MockStdin(i2.identifier)
- parsed = IdentifierInputScript.parse_command_line(
- self._db, cmd_args, stdin
- )
+ parsed = IdentifierInputScript.parse_command_line(self._db, cmd_args, stdin)
assert [i1, i2] == parsed.identifiers
assert i1.type == parsed.identifier_type
def test_parse_command_line_no_identifiers(self):
cmd_args = [
- "--identifier-type", Identifier.OVERDRIVE_ID,
- "--identifier-data-source", DataSource.STANDARD_EBOOKS
+ "--identifier-type",
+ Identifier.OVERDRIVE_ID,
+ "--identifier-data-source",
+ DataSource.STANDARD_EBOOKS,
]
parsed = IdentifierInputScript.parse_command_line(
self._db, cmd_args, MockStdin()
@@ -370,13 +351,16 @@ def test_parse_command_line_no_identifiers(self):
class SuccessMonitor(Monitor):
"""A simple Monitor that alway succeeds."""
+
SERVICE_NAME = "Success"
+
def run(self):
self.ran = True
class OPDSCollectionMonitor(CollectionMonitor):
"""Mock Monitor for use in tests of Run*MonitorScript."""
+
SERVICE_NAME = "Test Monitor"
PROTOCOL = ExternalIntegration.OPDS_IMPORT
@@ -390,8 +374,10 @@ def run_once(self, progress):
class DoomedCollectionMonitor(CollectionMonitor):
"""Mock CollectionMonitor that always raises an exception."""
+
SERVICE_NAME = "Doomed Monitor"
PROTOCOL = ExternalIntegration.OPDS_IMPORT
+
def run(self, *args, **kwargs):
self.ran = True
self.collection.doomed = True
@@ -406,10 +392,15 @@ class TestCollectionMonitorWithDifferentRunners(DatabaseTest):
names are specified, then the monitor will be run only on the ones specified.
"""
- @parameterized.expand([
- ('run CollectionMonitor from RunMonitorScript', RunMonitorScript),
- ('run CollectionMonitor from RunCollectionMonitorScript', RunCollectionMonitorScript),
- ])
+ @parameterized.expand(
+ [
+ ("run CollectionMonitor from RunMonitorScript", RunMonitorScript),
+ (
+ "run CollectionMonitor from RunCollectionMonitorScript",
+ RunCollectionMonitorScript,
+ ),
+ ]
+ )
def test_run_collection_monitor_with_no_args(self, name, script_runner):
# Run CollectionMonitor via RunMonitor for all applicable collections.
c1 = self._collection()
@@ -421,33 +412,46 @@ def test_run_collection_monitor_with_no_args(self, name, script_runner):
for c in [c1, c2]:
assert "test value" == c.ran_with_argument
- @parameterized.expand([
- ('run CollectionMonitor with collection args from RunMonitorScript', RunMonitorScript),
- ('run CollectionMonitor with collection args from RunCollectionMonitorScript', RunCollectionMonitorScript),
- ])
+ @parameterized.expand(
+ [
+ (
+ "run CollectionMonitor with collection args from RunMonitorScript",
+ RunMonitorScript,
+ ),
+ (
+ "run CollectionMonitor with collection args from RunCollectionMonitorScript",
+ RunCollectionMonitorScript,
+ ),
+ ]
+ )
def test_run_collection_monitor_with_collection_args(self, name, script_runner):
# Run CollectionMonitor via RunMonitor for only specified collections.
- c1 = self._collection(name='Collection 1')
- c2 = self._collection(name='Collection 2')
- c3 = self._collection(name='Collection 3')
+ c1 = self._collection(name="Collection 1")
+ c2 = self._collection(name="Collection 2")
+ c3 = self._collection(name="Collection 3")
all_collections = [c1, c2, c3]
monitored_collections = [c1, c3]
monitored_names = [c.name for c in monitored_collections]
script = script_runner(
- OPDSCollectionMonitor, self._db, cmd_args=monitored_names, test_argument="test value"
+ OPDSCollectionMonitor,
+ self._db,
+ cmd_args=monitored_names,
+ test_argument="test value",
)
script.run()
for c in monitored_collections:
- assert hasattr(c, 'ran_with_argument')
+ assert hasattr(c, "ran_with_argument")
assert "test value" == c.ran_with_argument
- for c in [collection for collection in all_collections
- if collection not in monitored_collections]:
- assert not hasattr(c, 'ran_with_argument')
+ for c in [
+ collection
+ for collection in all_collections
+ if collection not in monitored_collections
+ ]:
+ assert not hasattr(c, "ran_with_argument")
class TestRunMultipleMonitorsScript(DatabaseTest):
-
def test_do_run(self):
m1 = SuccessMonitor(self._db)
m2 = DoomedCollectionMonitor(self._db, self._default_collection)
@@ -455,6 +459,7 @@ def test_do_run(self):
class MockScript(RunMultipleMonitorsScript):
name = "Run three monitors"
+
def monitors(self, **kwargs):
self.kwargs = kwargs
return [m1, m2, m3]
@@ -476,11 +481,10 @@ def monitors(self, **kwargs):
# The exception that crashed the second monitor was stored as
# .exception, in case we want to look at it.
assert "Doomed!" == str(m2.exception)
- assert None == getattr(m1, 'exception', None)
+ assert None == getattr(m1, "exception", None)
class TestRunCollectionMonitorScript(DatabaseTest):
-
def test_monitors(self):
# Here we have three OPDS import Collections...
o1 = self._collection()
@@ -490,7 +494,9 @@ def test_monitors(self):
# ...and a Bibliotheca collection.
b1 = self._collection(protocol=ExternalIntegration.BIBLIOTHECA)
- script = RunCollectionMonitorScript(OPDSCollectionMonitor, self._db, cmd_args=[])
+ script = RunCollectionMonitorScript(
+ OPDSCollectionMonitor, self._db, cmd_args=[]
+ )
# Calling monitors() instantiates an OPDSCollectionMonitor
# for every OPDS import collection. The Bibliotheca collection
@@ -503,7 +509,6 @@ def test_monitors(self):
class TestRunReaperMonitorsScript(DatabaseTest):
-
def test_monitors(self):
"""This script instantiates a Monitor for every class in
ReaperMonitor.REGISTRY.
@@ -517,7 +522,6 @@ def test_monitors(self):
class TestPatronInputScript(DatabaseTest):
-
def test_parse_patron_list(self):
"""Test that patrons can be identified with any unique identifier."""
l1 = self._library()
@@ -534,15 +538,22 @@ def test_parse_patron_list(self):
p4 = self._patron()
p4.external_identifier = self._str
p4.library_id = l2.id
- args = [p1.authorization_identifier, 'no-such-patron',
- '', p2.username, p3.external_identifier]
- patrons = PatronInputScript.parse_patron_list(
- self._db, l1, args
- )
+ args = [
+ p1.authorization_identifier,
+ "no-such-patron",
+ "",
+ p2.username,
+ p3.external_identifier,
+ ]
+ patrons = PatronInputScript.parse_patron_list(self._db, l1, args)
assert [p1, p2, p3] == patrons
assert [] == PatronInputScript.parse_patron_list(self._db, l1, [])
- assert [p1] == PatronInputScript.parse_patron_list(self._db, l1, [p1.external_identifier, p4.external_identifier])
- assert [p4] == PatronInputScript.parse_patron_list(self._db, l2, [p1.external_identifier, p4.external_identifier])
+ assert [p1] == PatronInputScript.parse_patron_list(
+ self._db, l1, [p1.external_identifier, p4.external_identifier]
+ )
+ assert [p4] == PatronInputScript.parse_patron_list(
+ self._db, l2, [p1.external_identifier, p4.external_identifier]
+ )
def test_parse_command_line(self):
l1 = self._library()
@@ -556,9 +567,7 @@ def test_parse_command_line(self):
cmd_args = [l1.short_name, p1.authorization_identifier]
# ...and another one into standard input.
stdin = MockStdin(p2.authorization_identifier)
- parsed = PatronInputScript.parse_command_line(
- self._db, cmd_args, stdin
- )
+ parsed = PatronInputScript.parse_command_line(self._db, cmd_args, stdin)
assert [p1, p2] == parsed.patrons
def test_patron_different_library(self):
@@ -581,9 +590,11 @@ def test_do_run(self):
"""Test that PatronInputScript.do_run() calls process_patron()
for every patron designated by the command-line arguments.
"""
+
class MockPatronInputScript(PatronInputScript):
def process_patron(self, patron):
patron.processed = True
+
l1 = self._library()
p1 = self._patron()
p2 = self._patron()
@@ -606,15 +617,12 @@ def process_patron(self, patron):
class TestLibraryInputScript(DatabaseTest):
-
def test_parse_library_list(self):
"""Test that libraries can be identified with their full name or short name."""
l1 = self._library()
l2 = self._library()
- args = [l1.name, 'no-such-library', '', l2.short_name]
- libraries = LibraryInputScript.parse_library_list(
- self._db, args
- )
+ args = [l1.name, "no-such-library", "", l2.short_name]
+ libraries = LibraryInputScript.parse_library_list(self._db, args)
assert [l1, l2] == libraries
assert [] == LibraryInputScript.parse_library_list(self._db, [])
@@ -632,19 +640,18 @@ def test_parse_command_line_no_identifiers(self):
"""If you don't specify any libraries on the command
line, we will process all libraries in the system.
"""
- parsed =LibraryInputScript.parse_command_line(
- self._db, []
- )
+ parsed = LibraryInputScript.parse_command_line(self._db, [])
assert self._db.query(Library).all() == parsed.libraries
-
def test_do_run(self):
"""Test that LibraryInputScript.do_run() calls process_library()
for every library designated by the command-line arguments.
"""
+
class MockLibraryInputScript(LibraryInputScript):
def process_library(self, library):
library.processed = True
+
l1 = self._library()
l2 = self._library()
l2.processed = False
@@ -656,9 +663,7 @@ def process_library(self, library):
class TestLaneSweeperScript(DatabaseTest):
-
def test_process_library(self):
-
class Mock(LaneSweeperScript):
def __init__(self, _db):
super(Mock, self).__init__(_db)
@@ -667,7 +672,7 @@ def __init__(self, _db):
def should_process_lane(self, lane):
self.considered.append(lane)
- return lane.display_name == 'process me'
+ return lane.display_name == "process me"
def process_lane(self, lane):
self.processed.append(lane)
@@ -696,11 +701,15 @@ def process_lane(self, lane):
class TestRunCoverageProviderScript(DatabaseTest):
-
def test_parse_command_line(self):
identifier = self._identifier()
- cmd_args = ["--cutoff-time", "2016-05-01", "--identifier-type",
- identifier.type, identifier.identifier]
+ cmd_args = [
+ "--cutoff-time",
+ "2016-05-01",
+ "--identifier-type",
+ identifier.type,
+ identifier.identifier,
+ ]
parsed = RunCoverageProviderScript.parse_command_line(
self._db, cmd_args, MockStdin()
)
@@ -710,7 +719,6 @@ def test_parse_command_line(self):
class TestRunThreadedCollectionCoverageProviderScript(DatabaseTest):
-
def test_run(self):
provider = AlwaysSuccessfulCollectionCoverageProvider
script = RunThreadedCollectionCoverageProviderScript(
@@ -735,8 +743,10 @@ def test_run(self):
# Set a timestamp for the provider.
timestamp = Timestamp.stamp(
- self._db, provider.SERVICE_NAME, Timestamp.COVERAGE_PROVIDER_TYPE,
- collection=collection
+ self._db,
+ provider.SERVICE_NAME,
+ Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection=collection,
)
original_timestamp = timestamp.finish
self._db.commit()
@@ -752,7 +762,9 @@ def test_run(self):
# All relevant identifiers have been given coverage.
source = DataSource.lookup(self._db, provider.DATA_SOURCE_NAME)
identifiers_missing_coverage = Identifier.missing_coverage_from(
- self._db, provider.INPUT_IDENTIFIER_TYPES, source,
+ self._db,
+ provider.INPUT_IDENTIFIER_TYPES,
+ source,
)
assert [id3] == identifiers_missing_coverage.all()
@@ -762,22 +774,21 @@ def test_run(self):
assert CoverageRecord.SUCCESS == record2.status
assert (False, False) == (was_registered1, was_registered2)
-
# The timestamp for the provider has been updated.
new_timestamp = Timestamp.value(
- self._db, provider.SERVICE_NAME, Timestamp.COVERAGE_PROVIDER_TYPE,
- collection
+ self._db,
+ provider.SERVICE_NAME,
+ Timestamp.COVERAGE_PROVIDER_TYPE,
+ collection,
)
assert new_timestamp != original_timestamp
assert new_timestamp > original_timestamp
class TestRunWorkCoverageProviderScript(DatabaseTest):
-
def test_constructor(self):
script = RunWorkCoverageProviderScript(
- AlwaysSuccessfulWorkCoverageProvider, _db=self._db,
- batch_size=123
+ AlwaysSuccessfulWorkCoverageProvider, _db=self._db, batch_size=123
)
[provider] = script.providers
assert isinstance(provider, AlwaysSuccessfulWorkCoverageProvider)
@@ -785,7 +796,6 @@ def test_constructor(self):
class TestWorkProcessingScript(DatabaseTest):
-
def test_make_query(self):
# Create two Gutenberg works and one Overdrive work
g1 = self._work(with_license_pool=True, with_open_access_download=True)
@@ -794,27 +804,28 @@ def test_make_query(self):
overdrive_edition = self._edition(
data_source_name=DataSource.OVERDRIVE,
identifier_type=Identifier.OVERDRIVE_ID,
- with_license_pool=True
+ with_license_pool=True,
)[0]
overdrive_work = self._work(presentation_edition=overdrive_edition)
ugi_edition = self._edition(
data_source_name=DataSource.UNGLUE_IT,
identifier_type=Identifier.URI,
- with_license_pool=True
+ with_license_pool=True,
)[0]
unglue_it = self._work(presentation_edition=ugi_edition)
se_edition = self._edition(
data_source_name=DataSource.STANDARD_EBOOKS,
identifier_type=Identifier.URI,
- with_license_pool=True
+ with_license_pool=True,
)[0]
standard_ebooks = self._work(presentation_edition=se_edition)
everything = WorkProcessingScript.make_query(self._db, None, None, None)
- assert (set([g1, g2, overdrive_work, unglue_it, standard_ebooks]) ==
- set(everything.all()))
+ assert set([g1, g2, overdrive_work, unglue_it, standard_ebooks]) == set(
+ everything.all()
+ )
all_gutenberg = WorkProcessingScript.make_query(
self._db, Identifier.GUTENBERG_ID, [], None
@@ -839,37 +850,36 @@ class TestTimestampInfo(DatabaseTest):
def test_find(self):
# If there isn't a timestamp for the given service,
# nothing is returned.
- result = self.TimestampInfo.find(self, 'test')
+ result = self.TimestampInfo.find(self, "test")
assert None == result
# But an empty Timestamp has been placed into the database.
- timestamp = self._db.query(Timestamp).filter(Timestamp.service=='test').one()
+ timestamp = self._db.query(Timestamp).filter(Timestamp.service == "test").one()
assert None == timestamp.start
assert None == timestamp.finish
assert None == timestamp.counter
# A repeat search for the empty Timestamp also results in None.
script = DatabaseMigrationScript(self._db)
- assert None == self.TimestampInfo.find(script, 'test')
+ assert None == self.TimestampInfo.find(script, "test")
# If the Timestamp is stamped, it is returned.
timestamp.finish = utc_now()
timestamp.counter = 1
self._db.flush()
- result = self.TimestampInfo.find(script, 'test')
+ result = self.TimestampInfo.find(script, "test")
assert timestamp.finish == result.finish
assert 1 == result.counter
def test_update(self):
# Create a Timestamp to be updated.
- past = strptime_utc('19980101', '%Y%m%d')
+ past = strptime_utc("19980101", "%Y%m%d")
stamp = Timestamp.stamp(
- self._db, 'test', Timestamp.SCRIPT_TYPE, None, start=past,
- finish=past
+ self._db, "test", Timestamp.SCRIPT_TYPE, None, start=past, finish=past
)
script = DatabaseMigrationScript(self._db)
- timestamp_info = self.TimestampInfo.find(script, 'test')
+ timestamp_info = self.TimestampInfo.find(script, "test")
now = utc_now()
timestamp_info.update(self._db, now, 2)
@@ -882,11 +892,11 @@ def test_update(self):
def save(self):
# The Timestamp doesn't exist.
- timestamp_qu = self._db.query(Timestamp).filter(Timestamp.service=='test')
+ timestamp_qu = self._db.query(Timestamp).filter(Timestamp.service == "test")
assert False == timestamp_qu.exists()
now = utc_now()
- timestamp_info = self.TimestampInfo('test', now, 47)
+ timestamp_info = self.TimestampInfo("test", now, 47)
timestamp_info.save(self._db)
# The Timestamp exists now.
@@ -896,12 +906,11 @@ def save(self):
class DatabaseMigrationScriptTest(DatabaseTest):
-
@pytest.fixture
def migration_dirs(self, tmp_path):
# create migration file structure
- server = tmp_path / 'migration'
- core = tmp_path / 'server_core' / 'migation'
+ server = tmp_path / "migration"
+ core = tmp_path / "server_core" / "migation"
server.mkdir()
core.mkdir(parents=True)
@@ -916,36 +925,43 @@ def recursive_delete(path):
if file.is_dir():
recursive_delete(file)
file.rmdir()
+
recursive_delete(tmp_path)
@pytest.fixture()
def migration_file(self, tmp_path):
- def create_migration_file(directory, unique_string, migration_type, migration_date=None):
- suffix = '.'+migration_type
+ def create_migration_file(
+ directory, unique_string, migration_type, migration_date=None
+ ):
+ suffix = "." + migration_type
- if migration_type=='sql':
+ if migration_type == "sql":
# Create unique, innocuous content for a SQL file.
# This SQL inserts a timestamp into the test database.
service = "Test Database Migration Script - %s" % unique_string
- content = (("insert into timestamps(service, finish)"
- " values ('%s', '%s');") % (service, '1970-01-01'))
- elif migration_type=='py':
+ content = (
+ "insert into timestamps(service, finish)" " values ('%s', '%s');"
+ ) % (service, "1970-01-01")
+ elif migration_type == "py":
# Create unique, innocuous content for a Python file.
content = (
- "#!/usr/bin/env python\n\n"+
- "import tempfile\nimport os\n\n"+
- "file_info = tempfile.mkstemp(prefix='"+
- unique_string+"-', suffix='.py', dir='"+str(tmp_path)+"')\n\n"+
- "# Close file descriptor\n"+
- "os.close(file_info[0])\n"
+ "#!/usr/bin/env python\n\n"
+ + "import tempfile\nimport os\n\n"
+ + "file_info = tempfile.mkstemp(prefix='"
+ + unique_string
+ + "-', suffix='.py', dir='"
+ + str(tmp_path)
+ + "')\n\n"
+ + "# Close file descriptor\n"
+ + "os.close(file_info[0])\n"
)
else:
content = ""
if not migration_date:
# Default date is just after self.timestamp.
- migration_date = '20260811'
- prefix = migration_date + '-'
+ migration_date = "20260811"
+ prefix = migration_date + "-"
fd, migration_file = tempfile.mkstemp(
prefix=prefix, suffix=suffix, dir=directory, text=True
@@ -953,7 +969,7 @@ def create_migration_file(directory, unique_string, migration_type, migration_da
os.write(fd, content.encode("utf-8"))
# If it's a python migration, make it executable.
- if migration_file.endswith('py'):
+ if migration_file.endswith("py"):
original_mode = os.stat(migration_file).st_mode
mode = original_mode | (stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
os.chmod(migration_file, mode)
@@ -963,6 +979,7 @@ def create_migration_file(directory, unique_string, migration_type, migration_da
# return the filename
return migration_file
+
return create_migration_file
@pytest.fixture()
@@ -971,44 +988,40 @@ def migrations(self, migration_file, migration_dirs):
core_migration_files = []
server_migration_files = []
[core_dir, server_dir] = migration_dirs
- core_migration_files.append(migration_file(core_dir, 'CORE', 'sql'))
- core_migration_files.append(migration_file(core_dir, 'CORE', 'py'))
- server_migration_files.append(migration_file(server_dir, 'SERVER', 'sql'))
- server_migration_files.append(migration_file(server_dir, 'SERVER', 'py'))
+ core_migration_files.append(migration_file(core_dir, "CORE", "sql"))
+ core_migration_files.append(migration_file(core_dir, "CORE", "py"))
+ server_migration_files.append(migration_file(server_dir, "SERVER", "sql"))
+ server_migration_files.append(migration_file(server_dir, "SERVER", "py"))
return core_migration_files, server_migration_files
def teardown_method(self):
self._db.query(Timestamp).filter(
- Timestamp.service.like('%Database Migration%')
+ Timestamp.service.like("%Database Migration%")
).delete(synchronize_session=False)
super(DatabaseMigrationScriptTest, self).teardown_method()
class TestDatabaseMigrationScript(DatabaseMigrationScriptTest):
-
@pytest.fixture()
def script(self, monkeypatch, migration_dirs):
# Patch DatabaseMigrationScript to use test directories for migrations
- monkeypatch.setattr(DatabaseMigrationScript, "directories_by_priority", migration_dirs)
+ monkeypatch.setattr(
+ DatabaseMigrationScript, "directories_by_priority", migration_dirs
+ )
return DatabaseMigrationScript(self._db)
@pytest.fixture()
def timestamp(self, script):
- stamp = strptime_utc('20260810', '%Y%m%d')
- timestamp = Timestamp(
- service=script.name, start=stamp, finish=stamp
- )
+ stamp = strptime_utc("20260810", "%Y%m%d")
+ timestamp = Timestamp(service=script.name, start=stamp, finish=stamp)
python_timestamp = Timestamp(
- service=script.PY_TIMESTAMP_SERVICE_NAME, start=stamp,
- finish=stamp
+ service=script.PY_TIMESTAMP_SERVICE_NAME, start=stamp, finish=stamp
)
self._db.add_all([timestamp, python_timestamp])
self._db.flush()
- timestamp_info = script.TimestampInfo(
- timestamp.service, timestamp.start
- )
+ timestamp_info = script.TimestampInfo(timestamp.service, timestamp.start)
return timestamp, python_timestamp, timestamp_info
@@ -1021,7 +1034,7 @@ def test_name(self, script):
assert "Database Migration" == script.name
# A python-only script returns a Python-specific timestamp name.
- script.python_only=True
+ script.python_only = True
assert "Database Migration - Python" == script.name
def test_timestamp_properties(self, script):
@@ -1031,7 +1044,7 @@ def test_timestamp_properties(self, script):
# If there aren't any Database Migrations in the database, no
# timestamps are returned.
timestamps = self._db.query(Timestamp).filter(
- Timestamp.service.like('Database Migration%')
+ Timestamp.service.like("Database Migration%")
)
for timestamp in timestamps:
self._db.delete(timestamp)
@@ -1043,12 +1056,16 @@ def test_timestamp_properties(self, script):
# If the Timestamps exist in the database, but they don't have
# a timestamp, nothing is returned. Timestamps must be initialized.
- overall = self._db.query(Timestamp).filter(
- Timestamp.service==script.SERVICE_NAME
- ).one()
- python = self._db.query(Timestamp).filter(
- Timestamp.service==script.PY_TIMESTAMP_SERVICE_NAME
- ).one()
+ overall = (
+ self._db.query(Timestamp)
+ .filter(Timestamp.service == script.SERVICE_NAME)
+ .one()
+ )
+ python = (
+ self._db.query(Timestamp)
+ .filter(Timestamp.service == script.PY_TIMESTAMP_SERVICE_NAME)
+ .one()
+ )
# Neither Timestamp object has a timestamp.
assert (None, None) == (python.finish, overall.finish)
@@ -1057,8 +1074,8 @@ def test_timestamp_properties(self, script):
assert None == script.overall_timestamp
# If you give the Timestamps data, suddenly they show up.
- overall.finish = script.parse_time('1998-08-25')
- python.finish = script.parse_time('1993-06-11')
+ overall.finish = script.parse_time("1998-08-25")
+ python.finish = script.parse_time("1993-06-11")
python.counter = 2
self._db.flush()
@@ -1074,15 +1091,13 @@ def test_timestamp_properties(self, script):
def test_directories_by_priority(self):
core = os.path.split(os.path.split(__file__)[0])[0]
parent = os.path.split(core)[0]
- expected_core = os.path.join(core, 'migration')
- expected_parent = os.path.join(parent, 'migration')
+ expected_core = os.path.join(core, "migration")
+ expected_parent = os.path.join(parent, "migration")
# This is the only place we're testing the real script.
# Everywhere else should use the mock.
script = DatabaseMigrationScript()
- assert (
- [expected_core, expected_parent] ==
- script.directories_by_priority)
+ assert [expected_core, expected_parent] == script.directories_by_priority
def test_fetch_migration_files(self, script, migrations, migration_dirs):
result = script.fetch_migration_files()
@@ -1100,27 +1115,37 @@ def test_fetch_migration_files(self, script, migrations, migration_dirs):
# the 'core' directory array in migrations_by_directory.
assert 2 == len(core_migrations)
for filename in core_migrations:
- assert os.path.split(filename)[1] in result_migrations_by_dir[core_migration_dir]
+ assert (
+ os.path.split(filename)[1]
+ in result_migrations_by_dir[core_migration_dir]
+ )
# Ensure that all the expected migrations from the parent server
# are included in the appropriate array in migrations_by_directory.
assert 2 == len(server_migrations)
for filename in server_migrations:
- assert os.path.split(filename)[1] in result_migrations_by_dir[server_migration_dir]
+ assert (
+ os.path.split(filename)[1]
+ in result_migrations_by_dir[server_migration_dir]
+ )
# When the script is python_only, only python migrations are returned.
script.python_only = True
result_migrations, result_migrations_by_dir = script.fetch_migration_files()
- py_migration_files = [m for m in all_migrations if m.endswith('.py')]
+ py_migration_files = [m for m in all_migrations if m.endswith(".py")]
py_migration_filenames = [os.path.split(f)[1] for f in py_migration_files]
assert sorted(py_migration_filenames) == sorted(result_migrations)
- core_migration_files = [os.path.split(m)[1] for m in core_migrations if m.endswith('.py')]
+ core_migration_files = [
+ os.path.split(m)[1] for m in core_migrations if m.endswith(".py")
+ ]
assert 1 == len(core_migration_files)
assert result_migrations_by_dir[core_migration_dir] == core_migration_files
- server_migration_files = [os.path.split(m)[1] for m in server_migrations if m.endswith('.py')]
+ server_migration_files = [
+ os.path.split(m)[1] for m in server_migrations if m.endswith(".py")
+ ]
assert 1 == len(server_migration_files)
assert result_migrations_by_dir[server_migration_dir] == server_migration_files
@@ -1128,19 +1153,22 @@ def test_migratable_files(self, script):
"""Returns migrations that end with particular extensions."""
migrations = [
- '.gitkeep', '20250521-make-bananas.sql', '20260810-do-a-thing.py',
- '20260802-did-a-thing.pyc', 'why-am-i-here.rb'
+ ".gitkeep",
+ "20250521-make-bananas.sql",
+ "20260810-do-a-thing.py",
+ "20260802-did-a-thing.pyc",
+ "why-am-i-here.rb",
]
- result = script.migratable_files(migrations, ['.sql', '.py'])
+ result = script.migratable_files(migrations, [".sql", ".py"])
assert 2 == len(result)
- assert ['20250521-make-bananas.sql', '20260810-do-a-thing.py'] == result
+ assert ["20250521-make-bananas.sql", "20260810-do-a-thing.py"] == result
- result = script.migratable_files(migrations, ['.rb'])
+ result = script.migratable_files(migrations, [".rb"])
assert 1 == len(result)
- assert ['why-am-i-here.rb'] == result
+ assert ["why-am-i-here.rb"] == result
- result = script.migratable_files(migrations, ['banana'])
+ result = script.migratable_files(migrations, ["banana"])
assert [] == result
def test_get_new_migrations(self, script, timestamp):
@@ -1148,23 +1176,23 @@ def test_get_new_migrations(self, script, timestamp):
timestamp, python_timestamp, timestamp_info = timestamp
migrations = [
- '20271204-far-future-migration-funtime.sql',
- '20271202-future-migration-funtime.sql',
- '20271203-do-another-thing.py',
- '20250521-make-bananas.sql',
- '20260810-last-timestamp',
- '20260811-do-a-thing.py',
- '20260809-already-done.sql',
+ "20271204-far-future-migration-funtime.sql",
+ "20271202-future-migration-funtime.sql",
+ "20271203-do-another-thing.py",
+ "20250521-make-bananas.sql",
+ "20260810-last-timestamp",
+ "20260811-do-a-thing.py",
+ "20260809-already-done.sql",
]
result = script.get_new_migrations(timestamp_info, migrations)
# Expected migrations will be sorted by timestamp. Python migrations
# will be sorted after SQL migrations.
expected = [
- '20271202-future-migration-funtime.sql',
- '20271204-far-future-migration-funtime.sql',
- '20260811-do-a-thing.py',
- '20271203-do-another-thing.py',
+ "20271202-future-migration-funtime.sql",
+ "20271204-far-future-migration-funtime.sql",
+ "20260811-do-a-thing.py",
+ "20271203-do-another-thing.py",
]
assert 4 == len(result)
@@ -1173,17 +1201,17 @@ def test_get_new_migrations(self, script, timestamp):
# If the timestamp has a counter, the filter only finds new migrations
# past the counter.
migrations = [
- '20260810-last-timestamp.sql',
- '20260810-1-do-a-thing.sql',
- '20271202-future-migration-funtime.sql',
- '20260810-2-do-all-the-things.sql',
- '20260809-already-done.sql'
+ "20260810-last-timestamp.sql",
+ "20260810-1-do-a-thing.sql",
+ "20271202-future-migration-funtime.sql",
+ "20260810-2-do-all-the-things.sql",
+ "20260809-already-done.sql",
]
timestamp_info.counter = 1
result = script.get_new_migrations(timestamp_info, migrations)
expected = [
- '20260810-2-do-all-the-things.sql',
- '20271202-future-migration-funtime.sql',
+ "20260810-2-do-all-the-things.sql",
+ "20271202-future-migration-funtime.sql",
]
assert 2 == len(result)
@@ -1193,19 +1221,19 @@ def test_get_new_migrations(self, script, timestamp):
# migrations with the same datetime, migrations with counters are
# sorted after migrations without them.
migrations = [
- '20260810-do-a-thing.sql',
- '20271202-1-more-future-migration-funtime.sql',
- '20260810-1-do-all-the-things.sql',
- '20260809-already-done.sql',
- '20271202-future-migration-funtime.sql',
+ "20260810-do-a-thing.sql",
+ "20271202-1-more-future-migration-funtime.sql",
+ "20260810-1-do-all-the-things.sql",
+ "20260809-already-done.sql",
+ "20271202-future-migration-funtime.sql",
]
timestamp_info.counter = None
result = script.get_new_migrations(timestamp_info, migrations)
expected = [
- '20260810-1-do-all-the-things.sql',
- '20271202-future-migration-funtime.sql',
- '20271202-1-more-future-migration-funtime.sql'
+ "20260810-1-do-all-the-things.sql",
+ "20271202-future-migration-funtime.sql",
+ "20271202-1-more-future-migration-funtime.sql",
]
assert 3 == len(result)
assert expected == result
@@ -1214,7 +1242,7 @@ def test_update_timestamps(self, script, timestamp):
"""Resets a timestamp according to the date of a migration file"""
timestamp, python_timestamp, timestamp_info = timestamp
- migration = '20271202-future-migration-funtime.sql'
+ migration = "20271202-future-migration-funtime.sql"
py_last_run_time = python_timestamp.finish
def assert_unchanged_python_timestamp():
@@ -1222,95 +1250,90 @@ def assert_unchanged_python_timestamp():
def assert_timestamp_matches_migration(timestamp, migration, counter=None):
self._db.refresh(timestamp)
- timestamp_str = timestamp.finish.strftime('%Y%m%d')
+ timestamp_str = timestamp.finish.strftime("%Y%m%d")
assert migration[0:8] == timestamp_str
assert counter == timestamp.counter
- assert timestamp_info.finish.strftime('%Y%m%d') != migration[0:8]
+ assert timestamp_info.finish.strftime("%Y%m%d") != migration[0:8]
script.update_timestamps(migration)
assert_timestamp_matches_migration(timestamp, migration)
assert_unchanged_python_timestamp()
# It also takes care of counter digits when multiple migrations
# exist for the same date.
- migration = '20280810-2-do-all-the-things.sql'
+ migration = "20280810-2-do-all-the-things.sql"
script.update_timestamps(migration)
assert_timestamp_matches_migration(timestamp, migration, counter=2)
assert_unchanged_python_timestamp()
# And removes those counter digits when the timestamp is updated.
- migration = '20280901-what-it-do.sql'
+ migration = "20280901-what-it-do.sql"
script.update_timestamps(migration)
assert_timestamp_matches_migration(timestamp, migration)
assert_unchanged_python_timestamp()
# If the migration is earlier than the existing timestamp,
# the timestamp is not updated.
- migration = '20280801-before-the-existing-timestamp.sql'
+ migration = "20280801-before-the-existing-timestamp.sql"
script.update_timestamps(migration)
- assert timestamp.finish.strftime('%Y%m%d') == '20280901'
+ assert timestamp.finish.strftime("%Y%m%d") == "20280901"
# Python migrations update both timestamps.
- migration = '20281001-new-task.py'
+ migration = "20281001-new-task.py"
script.update_timestamps(migration)
assert_timestamp_matches_migration(timestamp, migration)
assert_timestamp_matches_migration(python_timestamp, migration)
- def test_running_a_migration_updates_the_timestamps(self, timestamp, migration_file, migration_dirs, script):
+ def test_running_a_migration_updates_the_timestamps(
+ self, timestamp, migration_file, migration_dirs, script
+ ):
timestamp, python_timestamp, timestamp_info = timestamp
- future_time = strptime_utc('20261030', '%Y%m%d')
+ future_time = strptime_utc("20261030", "%Y%m%d")
timestamp_info.finish = future_time
[core_dir, server_dir] = migration_dirs
# Create a test migration after that point and grab relevant info about it.
migration_filepath = migration_file(
- core_dir, 'SINGLE', 'sql',
- migration_date='20261202'
+ core_dir, "SINGLE", "sql", migration_date="20261202"
)
# Run the migration with the relevant information.
migration_filename = os.path.split(migration_filepath)[1]
- migrations_by_dir = {
- core_dir : [migration_filename],
- server_dir : []
- }
+ migrations_by_dir = {core_dir: [migration_filename], server_dir: []}
# Running the migration updates the timestamps
- script.run_migrations(
- [migration_filename], migrations_by_dir, timestamp_info
- )
- assert timestamp.finish.strftime('%Y%m%d') == '20261202'
+ script.run_migrations([migration_filename], migrations_by_dir, timestamp_info)
+ assert timestamp.finish.strftime("%Y%m%d") == "20261202"
# Even when there are counters.
migration_filepath = migration_file(
- core_dir, 'COUNTER', 'sql',
- migration_date='20261203-3'
+ core_dir, "COUNTER", "sql", migration_date="20261203-3"
)
migration_filename = os.path.split(migration_filepath)[1]
migrations_by_dir[core_dir] = [migration_filename]
- script.run_migrations(
- [migration_filename], migrations_by_dir, timestamp_info
- )
- assert timestamp.finish.strftime('%Y%m%d') == '20261203'
+ script.run_migrations([migration_filename], migrations_by_dir, timestamp_info)
+ assert timestamp.finish.strftime("%Y%m%d") == "20261203"
assert timestamp.counter == 3
def test_all_migration_files_are_run(self, script, migrations, timestamp, tmp_path):
script.run(
- test_db=self._db, test=True,
- cmd_args=["--last-run-date", "2010-01-01"]
+ test_db=self._db, test=True, cmd_args=["--last-run-date", "2010-01-01"]
)
# There are two test timestamps in the database, confirming that
# the test SQL files created by the migrations fixture
# have been run.
- timestamps = self._db.query(Timestamp).filter(
- Timestamp.service.like('Test Database Migration Script - %')
- ).order_by(Timestamp.service).all()
+ timestamps = (
+ self._db.query(Timestamp)
+ .filter(Timestamp.service.like("Test Database Migration Script - %"))
+ .order_by(Timestamp.service)
+ .all()
+ )
assert 2 == len(timestamps)
# A timestamp has been generated from each migration directory.
- assert True == timestamps[0].service.endswith('CORE')
- assert True == timestamps[1].service.endswith('SERVER')
+ assert True == timestamps[0].service.endswith("CORE")
+ assert True == timestamps[1].service.endswith("SERVER")
for timestamp in timestamps:
self._db.delete(timestamp)
@@ -1318,103 +1341,131 @@ def test_all_migration_files_are_run(self, script, migrations, timestamp, tmp_pa
# There are two temporary files created in tmp_path,
# confirming that the test Python files created by
# migrations fixture have been run.
- test_generated_files = sorted([f.name for f in tmp_path.iterdir()
- if f.name.startswith(('CORE', 'SERVER')) and f.is_file()])
+ test_generated_files = sorted(
+ [
+ f.name
+ for f in tmp_path.iterdir()
+ if f.name.startswith(("CORE", "SERVER")) and f.is_file()
+ ]
+ )
assert 2 == len(test_generated_files)
# A file has been generated from each migration directory.
- assert 'CORE' in test_generated_files[0]
- assert 'SERVER' in test_generated_files[1]
+ assert "CORE" in test_generated_files[0]
+ assert "SERVER" in test_generated_files[1]
- def test_python_migration_files_can_be_run_independently(self, script, migrations, timestamp, tmp_path):
+ def test_python_migration_files_can_be_run_independently(
+ self, script, migrations, timestamp, tmp_path
+ ):
script.run(
- test_db=self._db, test=True,
- cmd_args=["--last-run-date", "2010-01-01", "--python-only"]
+ test_db=self._db,
+ test=True,
+ cmd_args=["--last-run-date", "2010-01-01", "--python-only"],
)
# There are no test timestamps in the database, confirming that
# no test SQL files created by the migrations fixture
# have been run.
- timestamps = self._db.query(Timestamp).filter(
- Timestamp.service.like('Test Database Migration Script - %')
- ).order_by(Timestamp.service).all()
+ timestamps = (
+ self._db.query(Timestamp)
+ .filter(Timestamp.service.like("Test Database Migration Script - %"))
+ .order_by(Timestamp.service)
+ .all()
+ )
assert [] == timestamps
# There are two temporary files in tmp_path, confirming that the test
# Python files created by the migrations fixture were run.
test_dir = os.path.split(__file__)[0]
all_files = os.listdir(test_dir)
- test_generated_files = sorted([f.name for f in tmp_path.iterdir()
- if f.name.startswith(('CORE', 'SERVER')) and f.is_file()])
+ test_generated_files = sorted(
+ [
+ f.name
+ for f in tmp_path.iterdir()
+ if f.name.startswith(("CORE", "SERVER")) and f.is_file()
+ ]
+ )
assert 2 == len(test_generated_files)
# A file has been generated from each migration directory.
- assert 'CORE' in test_generated_files[0]
- assert 'SERVER' in test_generated_files[1]
+ assert "CORE" in test_generated_files[0]
+ assert "SERVER" in test_generated_files[1]
class TestDatabaseMigrationInitializationScript(DatabaseMigrationScriptTest):
-
@pytest.fixture()
def script(self, monkeypatch, migration_dirs, migrations):
# Patch DatabaseMigrationInitializationScript to use test directories for migrations
- monkeypatch.setattr(DatabaseMigrationInitializationScript, "directories_by_priority", migration_dirs)
+ monkeypatch.setattr(
+ DatabaseMigrationInitializationScript,
+ "directories_by_priority",
+ migration_dirs,
+ )
return DatabaseMigrationInitializationScript(self._db)
def assert_matches_latest_python_migration(self, timestamp, script):
migrations = script.fetch_migration_files()[0]
migrations_sorted = script.sort_migrations(migrations)
- last_migration_date = [x for x in migrations_sorted if x.endswith('.py')][-1][0:8]
+ last_migration_date = [x for x in migrations_sorted if x.endswith(".py")][-1][
+ 0:8
+ ]
self.assert_matches_timestamp(timestamp, last_migration_date)
def assert_matches_latest_migration(self, timestamp, script):
migrations = script.fetch_migration_files()[0]
migrations_sorted = script.sort_migrations(migrations)
- py_migration = [x for x in migrations_sorted if x.endswith('.py')][-1][0:8]
- sql_migration = [x for x in migrations_sorted if x.endswith('.sql')][-1][0:8]
- last_migration_date = py_migration if int(py_migration) > int(sql_migration) else sql_migration
+ py_migration = [x for x in migrations_sorted if x.endswith(".py")][-1][0:8]
+ sql_migration = [x for x in migrations_sorted if x.endswith(".sql")][-1][0:8]
+ last_migration_date = (
+ py_migration if int(py_migration) > int(sql_migration) else sql_migration
+ )
self.assert_matches_timestamp(timestamp, last_migration_date)
def assert_matches_timestamp(self, timestamp, migration_date):
- assert timestamp.finish.strftime('%Y%m%d') == migration_date
+ assert timestamp.finish.strftime("%Y%m%d") == migration_date
def test_accurate_timestamps_created(self, script):
- assert (
- None ==
- Timestamp.value(
- self._db, script.name, Timestamp.SCRIPT_TYPE,
- collection=None
- ))
+ assert None == Timestamp.value(
+ self._db, script.name, Timestamp.SCRIPT_TYPE, collection=None
+ )
script.run()
self.assert_matches_latest_migration(script.overall_timestamp, script)
self.assert_matches_latest_python_migration(script.python_timestamp, script)
- def test_accurate_python_timestamp_created_python_later(self, script, migration_dirs, migration_file):
+ def test_accurate_python_timestamp_created_python_later(
+ self, script, migration_dirs, migration_file
+ ):
[core_migration_dir, server_migration_dir] = migration_dirs
- assert None == Timestamp.value(self._db, script.name, Timestamp.SCRIPT_TYPE, collection=None)
+ assert None == Timestamp.value(
+ self._db, script.name, Timestamp.SCRIPT_TYPE, collection=None
+ )
# If the last python migration and the last SQL migration have
# different timestamps, they're set accordingly.
- migration_file(core_migration_dir, 'CORE', 'sql', '20310101')
- migration_file(server_migration_dir, 'SERVER', 'py', '20300101')
+ migration_file(core_migration_dir, "CORE", "sql", "20310101")
+ migration_file(server_migration_dir, "SERVER", "py", "20300101")
script.run()
- self.assert_matches_timestamp(script.overall_timestamp, '20310101')
- self.assert_matches_timestamp(script.python_timestamp, '20300101')
+ self.assert_matches_timestamp(script.overall_timestamp, "20310101")
+ self.assert_matches_timestamp(script.python_timestamp, "20300101")
- def test_accurate_python_timestamp_created_python_earlier(self, script, migration_dirs, migration_file):
+ def test_accurate_python_timestamp_created_python_earlier(
+ self, script, migration_dirs, migration_file
+ ):
[core_migration_dir, server_migration_dir] = migration_dirs
- assert None == Timestamp.value(self._db, script.name, Timestamp.SCRIPT_TYPE, collection=None)
+ assert None == Timestamp.value(
+ self._db, script.name, Timestamp.SCRIPT_TYPE, collection=None
+ )
# If the last python migration and the last SQL migration have
# different timestamps, they're set accordingly.
- migration_file(core_migration_dir, 'CORE', 'sql', '20310101')
- migration_file(server_migration_dir, 'SERVER', 'py', '20350101')
+ migration_file(core_migration_dir, "CORE", "sql", "20310101")
+ migration_file(server_migration_dir, "SERVER", "py", "20350101")
script.run()
- self.assert_matches_timestamp(script.overall_timestamp, '20350101')
- self.assert_matches_timestamp(script.python_timestamp, '20350101')
+ self.assert_matches_timestamp(script.overall_timestamp, "20350101")
+ self.assert_matches_timestamp(script.python_timestamp, "20350101")
def test_error_raised_when_timestamp_exists(self):
script = DatabaseMigrationInitializationScript(self._db)
@@ -1422,38 +1473,38 @@ def test_error_raised_when_timestamp_exists(self):
pytest.raises(RuntimeError, script.run)
def test_error_not_raised_when_timestamp_forced(self, script):
- past = script.parse_time('19951127')
+ past = script.parse_time("19951127")
Timestamp.stamp(self._db, script.name, Timestamp.SCRIPT_TYPE, None, finish=past)
- script.run(['-f'])
+ script.run(["-f"])
self.assert_matches_latest_migration(script.overall_timestamp, script)
self.assert_matches_latest_python_migration(script.python_timestamp, script)
def test_accepts_last_run_date(self, script):
# A timestamp can be passed via the command line.
- script.run(['--last-run-date', '20101010'])
- expected_stamp = strptime_utc('20101010', '%Y%m%d')
+ script.run(["--last-run-date", "20101010"])
+ expected_stamp = strptime_utc("20101010", "%Y%m%d")
assert expected_stamp == script.overall_timestamp.finish
# It will override an existing timestamp if forced.
- script.run(['--last-run-date', '20111111', '--force'])
- expected_stamp = strptime_utc('20111111', '%Y%m%d')
+ script.run(["--last-run-date", "20111111", "--force"])
+ expected_stamp = strptime_utc("20111111", "%Y%m%d")
assert expected_stamp == script.overall_timestamp.finish
assert expected_stamp == script.python_timestamp.finish
def test_accepts_last_run_counter(self, script):
# If a counter is passed without a date, an error is raised.
- pytest.raises(ValueError, script.run, ['--last-run-counter', '7'])
+ pytest.raises(ValueError, script.run, ["--last-run-counter", "7"])
# With a date, the counter can be set.
- script.run(['--last-run-date', '20101010', '--last-run-counter', '7'])
- expected_stamp = strptime_utc('20101010', '%Y%m%d')
+ script.run(["--last-run-date", "20101010", "--last-run-counter", "7"])
+ expected_stamp = strptime_utc("20101010", "%Y%m%d")
assert expected_stamp == script.overall_timestamp.finish
assert 7 == script.overall_timestamp.counter
# When forced, the counter can be reset on an existing timestamp.
previous_timestamp = script.overall_timestamp.finish
- script.run(['--last-run-date', '20121212', '--last-run-counter', '2', '-f'])
- expected_stamp = strptime_utc('20121212', '%Y%m%d')
+ script.run(["--last-run-date", "20121212", "--last-run-counter", "2", "-f"])
+ expected_stamp = strptime_utc("20121212", "%Y%m%d")
assert expected_stamp == script.overall_timestamp.finish
assert expected_stamp == script.python_timestamp.finish
assert 2 == script.overall_timestamp.counter
@@ -1461,7 +1512,6 @@ def test_accepts_last_run_counter(self, script):
class TestAddClassificationScript(DatabaseTest):
-
def test_end_to_end(self):
work = self._work(with_license_pool=True)
identifier = work.license_pools[0].identifier
@@ -1469,14 +1519,17 @@ def test_end_to_end(self):
assert Classifier.AUDIENCE_ADULT == work.audience
cmd_args = [
- "--identifier-type", identifier.type,
- "--subject-type", Classifier.FREEFORM_AUDIENCE,
- "--subject-identifier", Classifier.AUDIENCE_CHILDREN,
- "--weight", "42", '--create-subject'
+ "--identifier-type",
+ identifier.type,
+ "--subject-type",
+ Classifier.FREEFORM_AUDIENCE,
+ "--subject-identifier",
+ Classifier.AUDIENCE_CHILDREN,
+ "--weight",
+ "42",
+ "--create-subject",
]
- script = AddClassificationScript(
- _db=self._db, cmd_args=cmd_args, stdin=stdin
- )
+ script = AddClassificationScript(_db=self._db, cmd_args=cmd_args, stdin=stdin)
script.do_run()
# The identifier has been classified under 'children'.
@@ -1497,13 +1550,14 @@ def test_autocreate(self):
assert Classifier.AUDIENCE_ADULT == work.audience
cmd_args = [
- "--identifier-type", identifier.type,
- "--subject-type", Classifier.TAG,
- "--subject-identifier", "some random tag"
+ "--identifier-type",
+ identifier.type,
+ "--subject-type",
+ Classifier.TAG,
+ "--subject-identifier",
+ "some random tag",
]
- script = AddClassificationScript(
- _db=self._db, cmd_args=cmd_args, stdin=stdin
- )
+ script = AddClassificationScript(_db=self._db, cmd_args=cmd_args, stdin=stdin)
script.do_run()
# Nothing has happened. There was no Subject with that
@@ -1514,10 +1568,8 @@ def test_autocreate(self):
# command-line arguments, the Subject is created and the
# classification happens.
stdin = MockStdin(identifier.identifier)
- cmd_args.append('--create-subject')
- script = AddClassificationScript(
- _db=self._db, cmd_args=cmd_args, stdin=stdin
- )
+ cmd_args.append("--create-subject")
+ script = AddClassificationScript(_db=self._db, cmd_args=cmd_args, stdin=stdin)
script.do_run()
[classification] = identifier.classifications
@@ -1526,7 +1578,6 @@ def test_autocreate(self):
class TestShowLibrariesScript(DatabaseTest):
-
def test_with_no_libraries(self):
output = StringIO()
ShowLibrariesScript().do_run(self._db, output=output)
@@ -1534,13 +1585,19 @@ def test_with_no_libraries(self):
def test_with_multiple_libraries(self):
l1, ignore = create(
- self._db, Library, name="Library 1", short_name="L1",
+ self._db,
+ Library,
+ name="Library 1",
+ short_name="L1",
)
- l1.library_registry_shared_secret="a"
+ l1.library_registry_shared_secret = "a"
l2, ignore = create(
- self._db, Library, name="Library 2", short_name="L2",
+ self._db,
+ Library,
+ name="Library 2",
+ short_name="L2",
)
- l2.library_registry_shared_secret="b"
+ l2.library_registry_shared_secret = "b"
# The output of this script is the result of running explain()
# on both libraries.
@@ -1551,13 +1608,10 @@ def test_with_multiple_libraries(self):
assert expect_1 + "\n" + expect_2 + "\n" == output.getvalue()
-
# We can tell the script to only list a single library.
output = StringIO()
ShowLibrariesScript().do_run(
- self._db,
- cmd_args=["--short-name=L2"],
- output=output
+ self._db, cmd_args=["--short-name=L2"], output=output
)
assert expect_2 + "\n" == output.getvalue()
@@ -1565,9 +1619,7 @@ def test_with_multiple_libraries(self):
# shared secret.
output = StringIO()
ShowLibrariesScript().do_run(
- self._db,
- cmd_args=["--show-secrets"],
- output=output
+ self._db, cmd_args=["--show-secrets"], output=output
)
expect_1 = "\n".join(l1.explain(include_secrets=True))
expect_2 = "\n".join(l2.explain(include_secrets=True))
@@ -1575,23 +1627,24 @@ def test_with_multiple_libraries(self):
class TestConfigureSiteScript(DatabaseTest):
-
def test_unknown_setting(self):
script = ConfigureSiteScript()
with pytest.raises(ValueError) as excinfo:
- script.do_run(self._db, [
- "--setting=setting1=value1"
- ])
- assert "'setting1' is not a known site-wide setting. Use --force to set it anyway." in str(excinfo.value)
+ script.do_run(self._db, ["--setting=setting1=value1"])
+ assert (
+ "'setting1' is not a known site-wide setting. Use --force to set it anyway."
+ in str(excinfo.value)
+ )
assert None == ConfigurationSetting.sitewide(self._db, "setting1").value
# Running with --force sets the setting.
script.do_run(
- self._db, [
+ self._db,
+ [
"--setting=setting1=value1",
"--force",
- ]
+ ],
)
assert "value1" == ConfigurationSetting.sitewide(self._db, "setting1").value
@@ -1599,48 +1652,53 @@ def test_unknown_setting(self):
def test_settings(self):
class TestConfig(object):
SITEWIDE_SETTINGS = [
- { "key": "setting1" },
- { "key": "setting2" },
- { "key": "setting_secret" },
+ {"key": "setting1"},
+ {"key": "setting2"},
+ {"key": "setting_secret"},
]
script = ConfigureSiteScript(config=TestConfig)
output = StringIO()
script.do_run(
- self._db, [
+ self._db,
+ [
"--setting=setting1=value1",
- "--setting=setting2=[1,2,\"3\"]",
+ '--setting=setting2=[1,2,"3"]',
"--setting=setting_secret=secretvalue",
],
- output
+ output,
)
# The secret was set, but is not shown.
expect = "\n".join(
ConfigurationSetting.explain(self._db, include_secrets=False)
)
assert expect == output.getvalue()
- assert 'setting_secret' not in expect
+ assert "setting_secret" not in expect
assert "value1" == ConfigurationSetting.sitewide(self._db, "setting1").value
assert '[1,2,"3"]' == ConfigurationSetting.sitewide(self._db, "setting2").value
- assert "secretvalue" == ConfigurationSetting.sitewide(self._db, "setting_secret").value
+ assert (
+ "secretvalue"
+ == ConfigurationSetting.sitewide(self._db, "setting_secret").value
+ )
# If we run again with --show-secrets, the secret is shown.
output = StringIO()
script.do_run(self._db, ["--show-secrets"], output)
- expect = "\n".join(
- ConfigurationSetting.explain(self._db, include_secrets=True)
- )
+ expect = "\n".join(ConfigurationSetting.explain(self._db, include_secrets=True))
assert expect == output.getvalue()
- assert 'setting_secret' in expect
+ assert "setting_secret" in expect
-class TestConfigureLibraryScript(DatabaseTest):
+class TestConfigureLibraryScript(DatabaseTest):
def test_bad_arguments(self):
script = ConfigureLibraryScript()
library, ignore = create(
- self._db, Library, name="Library 1", short_name="L1",
+ self._db,
+ Library,
+ name="Library 1",
+ short_name="L1",
)
- library.library_registry_shared_secret='secret'
+ library.library_registry_shared_secret = "secret"
self._db.commit()
with pytest.raises(ValueError) as excinfo:
script.do_run(self._db, [])
@@ -1657,12 +1715,13 @@ def test_create_library(self):
script = ConfigureLibraryScript()
output = StringIO()
script.do_run(
- self._db, [
+ self._db,
+ [
"--short-name=L1",
"--name=Library 1",
- '--setting=customkey=value',
+ "--setting=customkey=value",
],
- output
+ output,
)
# Now there is one library.
@@ -1670,48 +1729,57 @@ def test_create_library(self):
assert "Library 1" == library.name
assert "L1" == library.short_name
assert "value" == library.setting("customkey").value
- expect_output = "Configuration settings stored.\n" + "\n".join(library.explain()) + "\n"
+ expect_output = (
+ "Configuration settings stored.\n" + "\n".join(library.explain()) + "\n"
+ )
assert expect_output == output.getvalue()
def test_reconfigure_library(self):
# The library exists.
library, ignore = create(
- self._db, Library, name="Library 1", short_name="L1",
+ self._db,
+ Library,
+ name="Library 1",
+ short_name="L1",
)
script = ConfigureLibraryScript()
output = StringIO()
# We're going to change one value and add a setting.
script.do_run(
- self._db, [
+ self._db,
+ [
"--short-name=L1",
"--name=Library 1 New Name",
- '--setting=customkey=value',
+ "--setting=customkey=value",
],
- output
+ output,
)
assert "Library 1 New Name" == library.name
assert "value" == library.setting("customkey").value
- expect_output = "Configuration settings stored.\n" + "\n".join(library.explain()) + "\n"
+ expect_output = (
+ "Configuration settings stored.\n" + "\n".join(library.explain()) + "\n"
+ )
assert expect_output == output.getvalue()
class TestShowCollectionsScript(DatabaseTest):
-
def test_with_no_collections(self):
output = StringIO()
ShowCollectionsScript().do_run(self._db, output=output)
assert "No collections found.\n" == output.getvalue()
def test_with_multiple_collections(self):
- c1 = self._collection(name="Collection 1",
- protocol=ExternalIntegration.OVERDRIVE)
- c1.collection_password="a"
- c2 = self._collection(name="Collection 2",
- protocol=ExternalIntegration.BIBLIOTHECA)
- c2.collection_password="b"
+ c1 = self._collection(
+ name="Collection 1", protocol=ExternalIntegration.OVERDRIVE
+ )
+ c1.collection_password = "a"
+ c2 = self._collection(
+ name="Collection 2", protocol=ExternalIntegration.BIBLIOTHECA
+ )
+ c2.collection_password = "b"
# The output of this script is the result of running explain()
# on both collections.
@@ -1722,22 +1790,17 @@ def test_with_multiple_collections(self):
assert expect_1 + "\n" + expect_2 + "\n" == output.getvalue()
-
# We can tell the script to only list a single collection.
output = StringIO()
ShowCollectionsScript().do_run(
- self._db,
- cmd_args=["--name=Collection 2"],
- output=output
+ self._db, cmd_args=["--name=Collection 2"], output=output
)
assert expect_2 + "\n" == output.getvalue()
# We can tell the script to include the collection password
output = StringIO()
ShowCollectionsScript().do_run(
- self._db,
- cmd_args=["--show-secrets"],
- output=output
+ self._db, cmd_args=["--show-secrets"], output=output
)
expect_1 = "\n".join(c1.explain(include_secrets=True))
expect_2 = "\n".join(c2.explain(include_secrets=True))
@@ -1745,11 +1808,13 @@ def test_with_multiple_collections(self):
class TestConfigureCollectionScript(DatabaseTest):
-
def test_bad_arguments(self):
script = ConfigureCollectionScript()
library, ignore = create(
- self._db, Library, name="Library 1", short_name="L1",
+ self._db,
+ Library,
+ name="Library 1",
+ short_name="L1",
)
self._db.commit()
@@ -1757,36 +1822,54 @@ def test_bad_arguments(self):
# necessary to create it.
with pytest.raises(ValueError) as excinfo:
script.do_run(self._db, ["--name=collection"])
- assert 'No collection called "collection". You can create it, but you must specify a protocol.' in str(excinfo.value)
+ assert (
+ 'No collection called "collection". You can create it, but you must specify a protocol.'
+ in str(excinfo.value)
+ )
# Incorrect format for the 'setting' argument.
with pytest.raises(ValueError) as excinfo:
- script.do_run(self._db, [
- "--name=collection", "--protocol=Overdrive",
- "--setting=key"
- ])
- assert 'Incorrect format for setting: "key". Should be "key=value"' in str(excinfo.value)
+ script.do_run(
+ self._db, ["--name=collection", "--protocol=Overdrive", "--setting=key"]
+ )
+ assert 'Incorrect format for setting: "key". Should be "key=value"' in str(
+ excinfo.value
+ )
# Try to add the collection to a nonexistent library.
with pytest.raises(ValueError) as excinfo:
- script.do_run(self._db, [
- "--name=collection", "--protocol=Overdrive",
- "--library=nosuchlibrary"
- ])
- assert 'No such library: "nosuchlibrary". I only know about: "L1"' in str(excinfo.value)
-
+ script.do_run(
+ self._db,
+ [
+ "--name=collection",
+ "--protocol=Overdrive",
+ "--library=nosuchlibrary",
+ ],
+ )
+ assert 'No such library: "nosuchlibrary". I only know about: "L1"' in str(
+ excinfo.value
+ )
def test_success(self):
script = ConfigureCollectionScript()
l1, ignore = create(
- self._db, Library, name="Library 1", short_name="L1",
+ self._db,
+ Library,
+ name="Library 1",
+ short_name="L1",
)
l2, ignore = create(
- self._db, Library, name="Library 2", short_name="L2",
+ self._db,
+ Library,
+ name="Library 2",
+ short_name="L2",
)
l3, ignore = create(
- self._db, Library, name="Library 3", short_name="L3",
+ self._db,
+ Library,
+ name="Library 3",
+ short_name="L3",
)
self._db.commit()
@@ -1794,14 +1877,19 @@ def test_success(self):
# setting, and associate it with two libraries.
output = StringIO()
script.do_run(
- self._db, ["--name=New Collection", "--protocol=Overdrive",
- "--library=L2", "--library=L1",
- "--setting=library_id=1234",
- "--external-account-id=acctid",
- "--url=url",
- "--username=username",
- "--password=password",
- ], output
+ self._db,
+ [
+ "--name=New Collection",
+ "--protocol=Overdrive",
+ "--library=L2",
+ "--library=L1",
+ "--setting=library_id=1234",
+ "--external-account-id=acctid",
+ "--url=url",
+ "--username=username",
+ "--password=password",
+ ],
+ output,
)
# The collection was created and configured properly.
@@ -1824,41 +1912,42 @@ def test_success(self):
assert "1234" == setting.value
# The output explains the collection settings.
- expect = ("Configuration settings stored.\n"
- + "\n".join(collection.explain()) + "\n")
+ expect = (
+ "Configuration settings stored.\n" + "\n".join(collection.explain()) + "\n"
+ )
assert expect == output.getvalue()
def test_reconfigure_collection(self):
# The collection exists.
collection = self._collection(
- name="Collection 1",
- protocol=ExternalIntegration.OVERDRIVE
+ name="Collection 1", protocol=ExternalIntegration.OVERDRIVE
)
script = ConfigureCollectionScript()
output = StringIO()
# We're going to change one value and add a new one.
script.do_run(
- self._db, [
+ self._db,
+ [
"--name=Collection 1",
"--url=foo",
- "--protocol=%s" % ExternalIntegration.BIBLIOTHECA
+ "--protocol=%s" % ExternalIntegration.BIBLIOTHECA,
],
- output
+ output,
)
# The collection has been changed.
assert "foo" == collection.external_integration.url
assert ExternalIntegration.BIBLIOTHECA == collection.protocol
- expect = ("Configuration settings stored.\n"
- + "\n".join(collection.explain()) + "\n")
+ expect = (
+ "Configuration settings stored.\n" + "\n".join(collection.explain()) + "\n"
+ )
assert expect == output.getvalue()
class TestShowIntegrationsScript(DatabaseTest):
-
def test_with_no_integrations(self):
output = StringIO()
ShowIntegrationsScript().do_run(self._db, output=output)
@@ -1866,17 +1955,13 @@ def test_with_no_integrations(self):
def test_with_multiple_integrations(self):
i1 = self._external_integration(
- name="Integration 1",
- goal="Goal",
- protocol=ExternalIntegration.OVERDRIVE
+ name="Integration 1", goal="Goal", protocol=ExternalIntegration.OVERDRIVE
)
- i1.password="a"
+ i1.password = "a"
i2 = self._external_integration(
- name="Integration 2",
- goal="Goal",
- protocol=ExternalIntegration.BIBLIOTHECA
+ name="Integration 2", goal="Goal", protocol=ExternalIntegration.BIBLIOTHECA
)
- i2.password="b"
+ i2.password = "b"
# The output of this script is the result of running explain()
# on both integrations.
@@ -1887,22 +1972,17 @@ def test_with_multiple_integrations(self):
assert expect_1 + "\n" + expect_2 + "\n" == output.getvalue()
-
# We can tell the script to only list a single integration.
output = StringIO()
ShowIntegrationsScript().do_run(
- self._db,
- cmd_args=["--name=Integration 2"],
- output=output
+ self._db, cmd_args=["--name=Integration 2"], output=output
)
assert expect_2 + "\n" == output.getvalue()
# We can tell the script to include the integration secrets
output = StringIO()
ShowIntegrationsScript().do_run(
- self._db,
- cmd_args=["--show-secrets"],
- output=output
+ self._db, cmd_args=["--show-secrets"], output=output
)
expect_1 = "\n".join(i1.explain(include_secrets=True))
expect_2 = "\n".join(i2.explain(include_secrets=True))
@@ -1910,13 +1990,15 @@ def test_with_multiple_integrations(self):
class TestConfigureIntegrationScript(DatabaseTest):
-
def test_load_integration(self):
m = ConfigureIntegrationScript._integration
with pytest.raises(ValueError) as excinfo:
m(self._db, None, None, "protocol", None)
- assert "An integration must by identified by either ID, name, or the combination of protocol and goal." in str(excinfo.value)
+ assert (
+ "An integration must by identified by either ID, name, or the combination of protocol and goal."
+ in str(excinfo.value)
+ )
with pytest.raises(ValueError) as excinfo:
m(self._db, "notanid", None, None, None)
@@ -1924,20 +2006,18 @@ def test_load_integration(self):
with pytest.raises(ValueError) as excinfo:
m(self._db, None, "Unknown integration", None, None)
- assert 'No integration with name "Unknown integration". To create it, you must also provide protocol and goal.' in str(excinfo.value)
-
- integration = self._external_integration(
- protocol="Protocol", goal="Goal"
+ assert (
+ 'No integration with name "Unknown integration". To create it, you must also provide protocol and goal.'
+ in str(excinfo.value)
)
+
+ integration = self._external_integration(protocol="Protocol", goal="Goal")
integration.name = "An integration"
- assert (integration ==
- m(self._db, integration.id, None, None, None))
+ assert integration == m(self._db, integration.id, None, None, None)
- assert (integration ==
- m(self._db, None, integration.name, None, None))
+ assert integration == m(self._db, None, integration.name, None, None)
- assert (integration ==
- m(self._db, None, None, "Protocol", "Goal"))
+ assert integration == m(self._db, None, None, "Protocol", "Goal")
# An integration may be created given a protocol and goal.
integration2 = m(self._db, None, "I exist now", "Protocol", "Goal2")
@@ -1951,23 +2031,27 @@ def test_add_settings(self):
output = StringIO()
script.do_run(
- self._db, [
+ self._db,
+ [
"--protocol=aprotocol",
"--goal=agoal",
"--setting=akey=avalue",
],
- output
+ output,
)
# An ExternalIntegration was created and configured.
- integration = get_one(self._db, ExternalIntegration,
- protocol="aprotocol", goal="agoal")
+ integration = get_one(
+ self._db, ExternalIntegration, protocol="aprotocol", goal="agoal"
+ )
- expect_output = "Configuration settings stored.\n" + "\n".join(integration.explain()) + "\n"
+ expect_output = (
+ "Configuration settings stored.\n" + "\n".join(integration.explain()) + "\n"
+ )
assert expect_output == output.getvalue()
-class TestShowLanesScript(DatabaseTest):
+class TestShowLanesScript(DatabaseTest):
def test_with_no_lanes(self):
output = StringIO()
ShowLanesScript().do_run(self._db, output=output)
@@ -1988,28 +2072,24 @@ def test_with_multiple_lanes(self):
# We can tell the script to only list a single lane.
output = StringIO()
- ShowLanesScript().do_run(
- self._db,
- cmd_args=["--id=%s" % l2.id],
- output=output
- )
+ ShowLanesScript().do_run(self._db, cmd_args=["--id=%s" % l2.id], output=output)
assert expect_2 + "\n\n" == output.getvalue()
-class TestConfigureLaneScript(DatabaseTest):
+class TestConfigureLaneScript(DatabaseTest):
def test_bad_arguments(self):
script = ConfigureLaneScript()
# No lane id but no library short name for creating it either.
with pytest.raises(ValueError) as excinfo:
script.do_run(self._db, [])
- assert 'Library short name is required to create a new lane' in str(excinfo.value)
+ assert "Library short name is required to create a new lane" in str(
+ excinfo.value
+ )
# Try to create a lane for a nonexistent library.
with pytest.raises(ValueError) as excinfo:
- script.do_run(self._db, [
- "--library-short-name=nosuchlibrary"
- ])
+ script.do_run(self._db, ["--library-short-name=nosuchlibrary"])
assert 'No such library: "nosuchlibrary".' in str(excinfo.value)
def test_create_lane(self):
@@ -2019,11 +2099,14 @@ def test_create_lane(self):
# Create a lane and set its attributes.
output = StringIO()
script.do_run(
- self._db, ["--library-short-name=%s" % self._default_library.short_name,
- "--parent-id=%s" % parent.id,
- "--priority=3",
- "--display-name=NewLane",
- ], output
+ self._db,
+ [
+ "--library-short-name=%s" % self._default_library.short_name,
+ "--parent-id=%s" % parent.id,
+ "--priority=3",
+ "--display-name=NewLane",
+ ],
+ output,
)
# The lane was created and configured properly.
@@ -2033,8 +2116,7 @@ def test_create_lane(self):
assert 3 == lane.priority
# The output explains the lane settings.
- expect = ("Lane settings stored.\n"
- + "\n".join(lane.explain()) + "\n")
+ expect = "Lane settings stored.\n" + "\n".join(lane.explain()) + "\n"
assert expect == output.getvalue()
def test_reconfigure_lane(self):
@@ -2048,19 +2130,19 @@ def test_reconfigure_lane(self):
output = StringIO()
script.do_run(
- self._db, [
+ self._db,
+ [
"--id=%s" % lane.id,
"--priority=1",
"--parent-id=%s" % parent.id,
],
- output
+ output,
)
# The lane has been changed.
assert 1 == lane.priority
assert parent == lane.parent
- expect = ("Lane settings stored.\n"
- + "\n".join(lane.explain()) + "\n")
+ expect = "Lane settings stored.\n" + "\n".join(lane.explain()) + "\n"
assert expect == output.getvalue()
@@ -2069,11 +2151,8 @@ class TestCollectionInputScript(DatabaseTest):
"""Test the ability to name collections on the command line."""
def test_parse_command_line(self):
-
def collections(cmd_args):
- parsed = CollectionInputScript.parse_command_line(
- self._db, cmd_args
- )
+ parsed = CollectionInputScript.parse_command_line(self._db, cmd_args)
return parsed.collections
# No collections named on command line -> no collections
@@ -2097,11 +2176,8 @@ class TestCollectionArgumentsScript(DatabaseTest):
"""Test the ability to take collection arguments on the command line."""
def test_parse_command_line(self):
-
def collections(cmd_args):
- parsed = CollectionArgumentsScript.parse_command_line(
- self._db, cmd_args
- )
+ parsed = CollectionArgumentsScript.parse_command_line(self._db, cmd_args)
return parsed.collections
# No collections named on command line -> no collections
@@ -2109,8 +2185,8 @@ def collections(cmd_args):
# Nonexistent collection -> ValueError
with pytest.raises(ValueError) as excinfo:
- collections(['no such collection'])
- assert 'Unknown collection: no such collection' in str(excinfo.value)
+ collections(["no such collection"])
+ assert "Unknown collection: no such collection" in str(excinfo.value)
# Collections are presented in the order they were encountered
# on the command line.
@@ -2130,6 +2206,7 @@ def collections(cmd_args):
# Mock classes used by TestOPDSImportScript
class MockOPDSImportMonitor(object):
"""Pretend to monitor an OPDS feed for new titles."""
+
INSTANCES = []
def __init__(self, _db, collection, *args, **kwargs):
@@ -2142,22 +2219,25 @@ def __init__(self, _db, collection, *args, **kwargs):
def run(self):
self.was_run = True
+
class MockOPDSImporter(object):
"""Pretend to import titles from an OPDS feed."""
+
pass
+
class MockOPDSImportScript(OPDSImportScript):
"""Actually instantiate a monitor that will pretend to do something."""
+
MONITOR_CLASS = MockOPDSImportMonitor
IMPORTER_CLASS = MockOPDSImporter
class TestOPDSImportScript(DatabaseTest):
-
def test_do_run(self):
- self._default_collection.external_integration.setting(Collection.DATA_SOURCE_NAME_SETTING).value = (
- DataSource.OA_CONTENT_SERVER
- )
+ self._default_collection.external_integration.setting(
+ Collection.DATA_SOURCE_NAME_SETTING
+ ).value = DataSource.OA_CONTENT_SERVER
script = MockOPDSImportScript(self._db)
script.do_run([])
@@ -2167,7 +2247,7 @@ def test_do_run(self):
monitor = MockOPDSImportMonitor.INSTANCES.pop()
assert self._default_collection == monitor.collection
- args = ['--collection=%s' % self._default_collection.name]
+ args = ["--collection=%s" % self._default_collection.name]
script.do_run(args)
# If we provide the collection name, a MockOPDSImportMonitor is
@@ -2179,22 +2259,23 @@ def test_do_run(self):
# Our replacement OPDS importer class was passed in to the
# monitor constructor. If this had been a real monitor, that's the
# code we would have used to import OPDS feeds.
- assert MockOPDSImporter == monitor.kwargs['import_class']
- assert False == monitor.kwargs['force_reimport']
+ assert MockOPDSImporter == monitor.kwargs["import_class"]
+ assert False == monitor.kwargs["force_reimport"]
# Setting --force changes the 'force_reimport' argument
# passed to the monitor constructor.
- args.append('--force')
+ args.append("--force")
script.do_run(args)
monitor = MockOPDSImportMonitor.INSTANCES.pop()
assert self._default_collection == monitor.collection
- assert True == monitor.kwargs['force_reimport']
+ assert True == monitor.kwargs["force_reimport"]
class MockWhereAreMyBooks(WhereAreMyBooksScript):
"""A mock script that keeps track of its output in an easy-to-test
form, so we don't have to mess around with StringIO.
"""
+
def __init__(self, _db=None, output=None, search=None):
# In most cases a list will do fine for `output`.
output = output or []
@@ -2213,7 +2294,6 @@ def out(self, s, *args):
class TestWhereAreMyBooksScript(DatabaseTest):
-
def test_no_search_integration(self):
# We can't even get started without a working search integration.
@@ -2223,17 +2303,20 @@ def test_no_search_integration(self):
# out(), so this verifies that output actually gets written
# out.
output = StringIO()
- pytest.raises(CannotLoadConfiguration, WhereAreMyBooksScript,
- self._db, output=output)
+ pytest.raises(
+ CannotLoadConfiguration, WhereAreMyBooksScript, self._db, output=output
+ )
assert (
- "Here's your problem: the search integration is missing or misconfigured.\n" ==
- output.getvalue())
+ "Here's your problem: the search integration is missing or misconfigured.\n"
+ == output.getvalue()
+ )
def test_overall_structure(self):
# Verify that run() calls the methods we expect.
class Mock(MockWhereAreMyBooks):
"""Used to verify that the correct methods are called."""
+
def __init__(self, *args, **kwargs):
super(Mock, self).__init__(*args, **kwargs)
self.delete_cached_feeds_called = False
@@ -2252,8 +2335,10 @@ def explain_collection(self, collection):
# If there are no libraries in the system, that's a big problem.
script = Mock(self._db)
script.run()
- assert (["There are no libraries in the system -- that's a problem.", "\n"] ==
- script.output)
+ assert [
+ "There are no libraries in the system -- that's a problem.",
+ "\n",
+ ] == script.output
# We still run the other checks, though.
assert True == script.delete_cached_feeds_called
@@ -2275,8 +2360,7 @@ def explain_collection(self, collection):
assert True == script.delete_cached_feeds_called
# Every collection in the database was explained.
- assert (set([collection1, collection2]) ==
- set(script.explained_collections))
+ assert set([collection1, collection2]) == set(script.explained_collections)
# There only output were the newlines after the five method
# calls. All other output happened inside the methods we
@@ -2300,21 +2384,18 @@ def test_check_library(self):
script.check_library(library)
checking, has_collection, has_lanes = script.output
- assert ('Checking library %s', [library.name]) == checking
- assert ((' Associated with collection %s.', [collection.name]) ==
- has_collection)
- assert (' Associated with %s lanes.', [1]) == has_lanes
+ assert ("Checking library %s", [library.name]) == checking
+ assert (" Associated with collection %s.", [collection.name]) == has_collection
+ assert (" Associated with %s lanes.", [1]) == has_lanes
# This library has no collections and no lanes.
library2 = self._library()
script.output = []
script.check_library(library2)
checking, no_collection, no_lanes = script.output
- assert ('Checking library %s', [library2.name]) == checking
- assert (" This library has no collections -- that's a problem." ==
- no_collection)
- assert (" This library has no lanes -- that's a problem." ==
- no_lanes)
+ assert ("Checking library %s", [library2.name]) == checking
+ assert " This library has no collections -- that's a problem." == no_collection
+ assert " This library has no lanes -- that's a problem." == no_lanes
def test_delete_cached_feeds(self):
groups = CachedFeed(type=CachedFeed.GROUPS_TYPE, pagination="")
@@ -2327,8 +2408,10 @@ def test_delete_cached_feeds(self):
script = MockWhereAreMyBooks(self._db)
script.delete_cached_feeds()
how_many, theyre_gone = script.output
- assert (('%d feeds in cachedfeeds table, not counting grouped feeds.', [1]) ==
- how_many)
+ assert (
+ "%d feeds in cachedfeeds table, not counting grouped feeds.",
+ [1],
+ ) == how_many
assert " Deleting them all." == theyre_gone
# Call it again, and we don't see "Deleting them all". There aren't
@@ -2336,13 +2419,20 @@ def test_delete_cached_feeds(self):
script.output = []
script.delete_cached_feeds()
[how_many] = script.output
- assert (('%d feeds in cachedfeeds table, not counting grouped feeds.', [0]) ==
- how_many)
+ assert (
+ "%d feeds in cachedfeeds table, not counting grouped feeds.",
+ [0],
+ ) == how_many
def check_explanation(
- self, presentation_ready=1, not_presentation_ready=0,
- no_delivery_mechanisms=0, suppressed=0, not_owned=0,
- in_search_index=0, **kwargs
+ self,
+ presentation_ready=1,
+ not_presentation_ready=0,
+ no_delivery_mechanisms=0,
+ suppressed=0,
+ not_owned=0,
+ in_search_index=0,
+ **kwargs
):
"""Runs explain_collection() and verifies expected output."""
script = MockWhereAreMyBooks(self._db, **kwargs)
@@ -2350,43 +2440,46 @@ def check_explanation(
out = script.output
# This always happens.
- assert (('Examining collection "%s"', [self._default_collection.name]) ==
- out.pop(0))
- assert ((' %d presentation-ready works.', [presentation_ready]) ==
- out.pop(0))
- assert ((' %d works not presentation-ready.', [not_presentation_ready]) ==
- out.pop(0))
+ assert (
+ 'Examining collection "%s"',
+ [self._default_collection.name],
+ ) == out.pop(0)
+ assert (" %d presentation-ready works.", [presentation_ready]) == out.pop(0)
+ assert (
+ " %d works not presentation-ready.",
+ [not_presentation_ready],
+ ) == out.pop(0)
# These totals are only given if the numbers are nonzero.
#
if no_delivery_mechanisms:
assert (
- (" %d works are missing delivery mechanisms and won't show up.", [no_delivery_mechanisms]) ==
- out.pop(0))
+ " %d works are missing delivery mechanisms and won't show up.",
+ [no_delivery_mechanisms],
+ ) == out.pop(0)
if suppressed:
assert (
- (" %d works have suppressed LicensePools and won't show up.",
- [suppressed]) ==
- out.pop(0))
+ " %d works have suppressed LicensePools and won't show up.",
+ [suppressed],
+ ) == out.pop(0)
if not_owned:
assert (
- (" %d non-open-access works have no owned licenses and won't show up.",
- [not_owned]
- ) ==
- out.pop(0))
+ " %d non-open-access works have no owned licenses and won't show up.",
+ [not_owned],
+ ) == out.pop(0)
# Search engine statistics are always shown.
assert (
- (" %d works in the search index, expected around %d.",
- [in_search_index, presentation_ready]) ==
- out.pop(0))
+ " %d works in the search index, expected around %d.",
+ [in_search_index, presentation_ready],
+ ) == out.pop(0)
def test_no_presentation_ready_works(self):
# This work is not presentation-ready.
work = self._work(with_license_pool=True)
- work.presentation_ready=False
+ work.presentation_ready = False
script = MockWhereAreMyBooks(self._db)
self.check_explanation(presentation_ready=0, not_presentation_ready=1)
@@ -2422,7 +2515,6 @@ def test_search_engine(self):
class TestExplain(DatabaseTest):
-
def test_explain(self):
"""Make sure the Explain script runs without crashing."""
work = self._work(with_license_pool=True, genre="Science Fiction")
@@ -2449,11 +2541,11 @@ def test_explain(self):
# CoverageRecords associated with the primary identifier were
# printed out.
- assert 'OCLC Linked Data | an operation | success' in output
+ assert "OCLC Linked Data | an operation | success" in output
# WorkCoverageRecords associated with the work were
# printed out.
- assert 'generate-opds | success' in output
+ assert "generate-opds | success" in output
# There is an active LicensePool that is fulfillable and has
# copies owned.
@@ -2461,51 +2553,55 @@ def test_explain(self):
assert "Fulfillable" in output
assert "ACTIVE" in output
-class TestReclassifyWorksForUncheckedSubjectsScript(DatabaseTest):
+class TestReclassifyWorksForUncheckedSubjectsScript(DatabaseTest):
def test_constructor(self):
"""Make sure that we're only going to classify works
with unchecked subjects.
"""
script = ReclassifyWorksForUncheckedSubjectsScript(self._db)
- assert (WorkClassificationScript.policy ==
- ReclassifyWorksForUncheckedSubjectsScript.policy)
+ assert (
+ WorkClassificationScript.policy
+ == ReclassifyWorksForUncheckedSubjectsScript.policy
+ )
assert 100 == script.batch_size
- assert (dump_query(Work.for_unchecked_subjects(self._db)) ==
- dump_query(script.query))
+ assert dump_query(Work.for_unchecked_subjects(self._db)) == dump_query(
+ script.query
+ )
class TestListCollectionMetadataIdentifiersScript(DatabaseTest):
-
def test_do_run(self):
output = StringIO()
- script = ListCollectionMetadataIdentifiersScript(
- _db=self._db, output=output
- )
+ script = ListCollectionMetadataIdentifiersScript(_db=self._db, output=output)
# Create two collections.
c1 = self._collection(external_account_id=self._url)
c2 = self._collection(
- name='Local Over', protocol=ExternalIntegration.OVERDRIVE,
- external_account_id='banana'
+ name="Local Over",
+ protocol=ExternalIntegration.OVERDRIVE,
+ external_account_id="banana",
)
script.do_run()
def expected(c):
- return '(%s) %s/%s => %s\n' % (
- str(c.id), c.name, c.protocol, c.metadata_identifier
+ return "(%s) %s/%s => %s\n" % (
+ str(c.id),
+ c.name,
+ c.protocol,
+ c.metadata_identifier,
)
# In the output, there's a header, a line describing the format,
# metdata identifiers for each collection, and a count of the
# collections found.
output = output.getvalue()
- assert 'COLLECTIONS' in output
- assert '(id) name/protocol => metadata_identifier\n' in output
+ assert "COLLECTIONS" in output
+ assert "(id) name/protocol => metadata_identifier\n" in output
assert expected(c1) in output
assert expected(c2) in output
- assert '2 collections found.\n' in output
+ assert "2 collections found.\n" in output
class TestMirrorResourcesScript(DatabaseTest):
@@ -2549,39 +2645,49 @@ def process_collection(self, collection, policy):
processed = script.processed.pop()
assert (has_uploader, mock_uploader) == processed
- @parameterized.expand([
- (
- 'containing_open_access_books_with_s3_uploader',
- CollectionType.OPEN_ACCESS,
- ExternalIntegrationLink.OPEN_ACCESS_BOOKS,
- ExternalIntegration.S3,
- S3Uploader
- ),
- (
- 'containing_protected_access_books_with_s3_uploader',
- CollectionType.PROTECTED_ACCESS,
- ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
- ExternalIntegration.S3,
- S3Uploader
- ),
- (
- 'containing_open_access_books_with_minio_uploader',
- CollectionType.OPEN_ACCESS,
- ExternalIntegrationLink.OPEN_ACCESS_BOOKS,
- ExternalIntegration.MINIO,
- MinIOUploader,
- {MinIOUploaderConfiguration.ENDPOINT_URL: 'http://localhost'}
- ),
- (
- 'containing_protected_access_books_with_minio_uploader',
- CollectionType.PROTECTED_ACCESS,
- ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
- ExternalIntegration.MINIO,
- MinIOUploader,
- {MinIOUploaderConfiguration.ENDPOINT_URL: 'http://localhost'}
- )
- ])
- def test_collections(self, name, collection_type, book_mirror_type, protocol, uploader_class, settings=None):
+ @parameterized.expand(
+ [
+ (
+ "containing_open_access_books_with_s3_uploader",
+ CollectionType.OPEN_ACCESS,
+ ExternalIntegrationLink.OPEN_ACCESS_BOOKS,
+ ExternalIntegration.S3,
+ S3Uploader,
+ ),
+ (
+ "containing_protected_access_books_with_s3_uploader",
+ CollectionType.PROTECTED_ACCESS,
+ ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
+ ExternalIntegration.S3,
+ S3Uploader,
+ ),
+ (
+ "containing_open_access_books_with_minio_uploader",
+ CollectionType.OPEN_ACCESS,
+ ExternalIntegrationLink.OPEN_ACCESS_BOOKS,
+ ExternalIntegration.MINIO,
+ MinIOUploader,
+ {MinIOUploaderConfiguration.ENDPOINT_URL: "http://localhost"},
+ ),
+ (
+ "containing_protected_access_books_with_minio_uploader",
+ CollectionType.PROTECTED_ACCESS,
+ ExternalIntegrationLink.PROTECTED_ACCESS_BOOKS,
+ ExternalIntegration.MINIO,
+ MinIOUploader,
+ {MinIOUploaderConfiguration.ENDPOINT_URL: "http://localhost"},
+ ),
+ ]
+ )
+ def test_collections(
+ self,
+ name,
+ collection_type,
+ book_mirror_type,
+ protocol,
+ uploader_class,
+ settings=None,
+ ):
class Mock(MirrorResourcesScript):
mock_policy = object()
@@ -2596,9 +2702,7 @@ def replacement_policy(cls, uploader):
# The default collection does not have an uploader.
# This new collection does.
has_uploader = self._collection()
- mirror = self._external_integration(
- protocol, ExternalIntegration.STORAGE_GOAL
- )
+ mirror = self._external_integration(protocol, ExternalIntegration.STORAGE_GOAL)
if settings:
for key, value in settings.items():
@@ -2607,7 +2711,7 @@ def replacement_policy(cls, uploader):
integration_link = self._external_integration_link(
integration=has_uploader._external_integration,
other_integration=mirror,
- purpose=ExternalIntegrationLink.COVERS
+ purpose=ExternalIntegrationLink.COVERS,
)
# Calling collections_with_uploader will do nothing for collections
@@ -2616,7 +2720,7 @@ def replacement_policy(cls, uploader):
# and yield the result.
result = script.collections_with_uploader(
[self._default_collection, has_uploader, self._default_collection],
- collection_type
+ collection_type,
)
[(collection, policy)] = result
@@ -2626,7 +2730,8 @@ def replacement_policy(cls, uploader):
# expect to have one MirrorUploader.
assert Mock.replacement_policy_called_with[book_mirror_type] == None
assert isinstance(
- Mock.replacement_policy_called_with[ExternalIntegrationLink.COVERS], MirrorUploader
+ Mock.replacement_policy_called_with[ExternalIntegrationLink.COVERS],
+ MirrorUploader,
)
# Add another storage for books.
@@ -2637,21 +2742,25 @@ def replacement_policy(cls, uploader):
integration_link = self._external_integration_link(
integration=has_uploader._external_integration,
other_integration=another_mirror,
- purpose=book_mirror_type
+ purpose=book_mirror_type,
)
result = script.collections_with_uploader(
[self._default_collection, has_uploader, self._default_collection],
- collection_type
+ collection_type,
)
[(collection, policy)] = result
assert has_uploader == collection
assert Mock.mock_policy == policy
# There should be two MirrorUploaders, one for each purpose.
- assert isinstance(Mock.replacement_policy_called_with[ExternalIntegrationLink.COVERS], uploader_class)
assert isinstance(
- Mock.replacement_policy_called_with[book_mirror_type], uploader_class)
+ Mock.replacement_policy_called_with[ExternalIntegrationLink.COVERS],
+ uploader_class,
+ )
+ assert isinstance(
+ Mock.replacement_policy_called_with[book_mirror_type], uploader_class
+ )
def test_replacement_policy(self):
uploader = object()
@@ -2662,15 +2771,16 @@ def test_replacement_policy(self):
assert False == p.rights
def test_process_collection(self):
-
class MockScript(MirrorResourcesScript):
process_item_called_with = []
+
def process_item(self, collection, link, policy):
self.process_item_called_with.append((collection, link, policy))
# Mock the Hyperlink.unmirrored method
link1 = object()
link2 = object()
+
def unmirrored(collection):
assert collection == self._default_collection
yield link1
@@ -2721,8 +2831,7 @@ def test_derive_rights_status(self):
# no LicensePoolDeliveryMechanisms at all, then we just don't
# know.
pool2.set_delivery_mechanism(
- content_type="text/plain", drm_scheme=None,
- rights_uri=RightsStatus.CC_BY_ND
+ content_type="text/plain", drm_scheme=None, rights_uri=RightsStatus.CC_BY_ND
)
assert None == m(pool2, None)
@@ -2740,6 +2849,7 @@ def __init__(self):
def mirror_link(self, **kwargs):
self.mirrored.append(kwargs)
+
mirror = MockMirrorUtility()
class MockScript(MirrorResourcesScript):
@@ -2774,9 +2884,7 @@ def __init__(self, rel, href, identifier):
# associated with Identifiers licensed through a Collection.)
identifier = self._identifier()
policy = object()
- download_link = MockLink(
- Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, identifier
- )
+ download_link = MockLink(Hyperlink.OPEN_ACCESS_DOWNLOAD, self._url, identifier)
self._default_collection.data_source = DataSource.GUTENBERG
m(self._default_collection, download_link, policy)
assert [] == mirror.mirrored
@@ -2799,34 +2907,31 @@ def __init__(self, rel, href, identifier):
# don't do anything.
m(self._default_collection, download_link, policy)
assert [] == mirror.mirrored
- assert ((pool, download_link.resource) ==
- script.derive_rights_status_called_with)
+ assert (pool, download_link.resource) == script.derive_rights_status_called_with
# If we _can_ determine the rights status, a mirror attempt is made.
script.RIGHTS_STATUS = object()
m(self._default_collection, download_link, policy)
attempt = mirror.mirrored.pop()
- assert policy == attempt['policy']
- assert pool.data_source == attempt['data_source']
- assert pool == attempt['model_object']
- assert download_link == attempt['link_obj']
+ assert policy == attempt["policy"]
+ assert pool.data_source == attempt["data_source"]
+ assert pool == attempt["model_object"]
+ assert download_link == attempt["link_obj"]
- link = attempt['link']
+ link = attempt["link"]
assert isinstance(link, LinkData)
assert download_link.resource.url == link.href
# For other types of links, we rely on fair use, so the "rights
# status" doesn't matter.
script.RIGHTS_STATUS = None
- thumb_link = MockLink(Hyperlink.THUMBNAIL_IMAGE, self._url,
- pool.identifier)
+ thumb_link = MockLink(Hyperlink.THUMBNAIL_IMAGE, self._url, pool.identifier)
m(self._default_collection, thumb_link, policy)
attempt = mirror.mirrored.pop()
- assert thumb_link.resource.url == attempt['link'].href
+ assert thumb_link.resource.url == attempt["link"].href
class TestRebuildSearchIndexScript(DatabaseTest):
-
def test_do_run(self):
class MockSearchIndex(object):
def setup_index(self):
@@ -2846,12 +2951,10 @@ def bulk_update(self, works):
# Set up some coverage records.
for operation in decoys + [wcr.UPDATE_SEARCH_INDEX_OPERATION]:
for w in (work, work2):
- wcr.add_for(
- w, operation, status=random.choice(wcr.ALL_STATUSES)
- )
+ wcr.add_for(w, operation, status=random.choice(wcr.ALL_STATUSES))
coverage_qu = self._db.query(wcr).filter(
- wcr.operation==wcr.UPDATE_SEARCH_INDEX_OPERATION
+ wcr.operation == wcr.UPDATE_SEARCH_INDEX_OPERATION
)
original_coverage = [x.id for x in coverage_qu]
@@ -2868,8 +2971,9 @@ def bulk_update(self, works):
# information about what happened (from the CoverageProvider's
# point of view).
assert (
- 'Items processed: 2. Successes: 2, transient failures: 0, persistent failures: 0' ==
- progress.achievements)
+ "Items processed: 2. Successes: 2, transient failures: 0, persistent failures: 0"
+ == progress.achievements
+ )
# The old WorkCoverageRecords for the works were deleted. Then
# the CoverageProvider did its job and new ones were added.
@@ -2891,9 +2995,7 @@ def test_do_run(self):
# Set up some coverage records.
for operation in decoys + [wcr.UPDATE_SEARCH_INDEX_OPERATION]:
for w in (work, work2):
- wcr.add_for(
- w, operation, status=random.choice(wcr.ALL_STATUSES)
- )
+ wcr.add_for(w, operation, status=random.choice(wcr.ALL_STATUSES))
# Run the script.
script = SearchIndexCoverageRemover(self._db)
@@ -2909,7 +3011,6 @@ def test_do_run(self):
class TestUpdateLaneSizeScript(DatabaseTest):
-
def test_do_run(self):
lane = self._lane()
lane.size = 100
@@ -2927,7 +3028,6 @@ def test_should_process_lane(self):
class TestUpdateCustomListSizeScript(DatabaseTest):
-
def test_do_run(self):
customlist, ignore = self._customlist(num_entries=1)
customlist.library = self._default_library
@@ -2938,29 +3038,35 @@ def test_do_run(self):
class TestWorkConsolidationScript(object):
"""TODO"""
+
pass
class TestWorkPresentationScript(object):
"""TODO"""
+
pass
class TestWorkClassificationScript(object):
"""TODO"""
+
pass
class TestWorkOPDSScript(object):
"""TODO"""
+
pass
class TestCustomListManagementScript(object):
"""TODO"""
+
pass
class TestNYTBestSellerListsScript(object):
"""TODO"""
+
pass
diff --git a/tests/test_selftest.py b/tests/test_selftest.py
index 84d24083d..55de92c5b 100644
--- a/tests/test_selftest.py
+++ b/tests/test_selftest.py
@@ -6,15 +6,12 @@
"""
import datetime
-from ..testing import DatabaseTest
-
-from ..selftest import (
- SelfTestResult,
- HasSelfTests,
-)
-from ..util.http import IntegrationException
+from ..selftest import HasSelfTests, SelfTestResult
+from ..testing import DatabaseTest
from ..util.datetime_helpers import utc_now
+from ..util.http import IntegrationException
+
class TestSelfTestResult(DatabaseTest):
@@ -32,23 +29,25 @@ def test_success_representation(self):
result.result = "The result"
result.success = True
assert (
- "" ==
- repr(result))
+ ""
+ == repr(result)
+ )
# A SelfTestResult may have an associated Collection.
self._default_collection.name = "CollectionA"
result.collection = self._default_collection
assert (
- "" ==
- repr(result))
+ ""
+ == repr(result)
+ )
d = result.to_dict
- assert "success1" == d['name']
- assert "The result" == d['result']
- assert 5.0 == d['duration']
- assert True == d['success']
- assert None == d['exception']
- assert 'CollectionA' == d['collection']
+ assert "success1" == d["name"]
+ assert "The result" == d["result"]
+ assert 5.0 == d["duration"]
+ assert True == d["success"]
+ assert None == d["exception"]
+ assert "CollectionA" == d["collection"]
# A test result can be either a string (which will be displayed
# in a fixed-width font) or a list of strings (which will be hidden
@@ -56,13 +55,13 @@ def test_success_representation(self):
list_result = ["list", "of", "strings"]
result.result = list_result
d = result.to_dict
- assert list_result == d['result']
+ assert list_result == d["result"]
# Other .result values don't make it into the dictionary because
# it's not defined how to display them.
result.result = {"a": "dictionary"}
d = result.to_dict
- assert None == d['result']
+ assert None == d["result"]
def test_repr_failure(self):
"""Show the string representation of a failed test result."""
@@ -75,21 +74,21 @@ def test_repr_failure(self):
result.exception = exception
result.result = "The result"
assert (
- "" ==
- repr(result))
+ ""
+ == repr(result)
+ )
d = result.to_dict
- assert "failure1" == d['name']
- assert "The result" == d['result']
- assert 5.0 == d['duration']
- assert False == d['success']
- assert 'IntegrationException' == d['exception']['class']
- assert 'basic info' == d['exception']['message']
- assert 'debug info' == d['exception']['debug_message']
+ assert "failure1" == d["name"]
+ assert "The result" == d["result"]
+ assert 5.0 == d["duration"]
+ assert False == d["success"]
+ assert "IntegrationException" == d["exception"]["class"]
+ assert "basic info" == d["exception"]["message"]
+ assert "debug info" == d["exception"]["debug_message"]
class TestHasSelfTests(DatabaseTest):
-
def test_run_self_tests(self):
"""See what might happen when run_self_tests tries to instantiate an
object and run its self-tests.
@@ -98,7 +97,7 @@ def test_run_self_tests(self):
class Tester(HasSelfTests):
def __init__(self, extra_arg=None):
"""This constructor works."""
- self.invoked_with = (extra_arg)
+ self.invoked_with = extra_arg
@classmethod
def good_alternate_constructor(self, another_extra_arg=None):
@@ -119,6 +118,7 @@ def external_integration(self, _db):
def _run_self_tests(self, _db):
self._run_self_tests_called_with = _db
return [SelfTestResult("a test result")]
+
mock_db = object()
# This integration will be used to store the test results.
@@ -127,9 +127,7 @@ def _run_self_tests(self, _db):
# By default, the default constructor is instantiated and its
# _run_self_tests method is called.
- data, [setup, test] = Tester.run_self_tests(
- mock_db, extra_arg="a value"
- )
+ data, [setup, test] = Tester.run_self_tests(mock_db, extra_arg="a value")
assert mock_db == setup.result._run_self_tests_called_with
# There are two results -- `setup` from the initial setup
@@ -141,13 +139,13 @@ def _run_self_tests(self, _db):
# The `data` variable contains a dictionary describing the test
# suite as a whole.
- assert data['duration'] < 1
- for key in 'start', 'end':
+ assert data["duration"] < 1
+ for key in "start", "end":
assert key in data
# `data['results']` contains dictionary versions of the self-tests
# that were returned separately.
- r1, r2 = data['results']
+ r1, r2 = data["results"]
assert r1 == setup.to_dict
assert r2 == test.to_dict
@@ -167,8 +165,9 @@ def _run_self_tests(self, _db):
# constructor. Once the object is instantiated, the same basic
# code runs.
data, [setup, test] = Tester.run_self_tests(
- mock_db, Tester.good_alternate_constructor,
- another_extra_arg="another value"
+ mock_db,
+ Tester.good_alternate_constructor,
+ another_extra_arg="another value",
)
assert "Initial setup." == setup.name
assert True == setup.success
@@ -185,7 +184,8 @@ def _run_self_tests(self, _db):
# single SelfTestResult describing that failure. Since there is
# no instance, _run_self_tests can't be called.
data, [result] = Tester.run_self_tests(
- mock_db, Tester.bad_alternate_constructor,
+ mock_db,
+ Tester.bad_alternate_constructor,
)
assert isinstance(result, SelfTestResult)
assert False == result.success
@@ -195,6 +195,7 @@ def test_exception_in_has_self_tests(self):
"""An exception raised in has_self_tests itself is converted into a
test failure.
"""
+
class Tester(HasSelfTests):
def _run_self_tests(self, _db):
yield SelfTestResult("everything's ok so far")
@@ -218,19 +219,19 @@ def test_run_test_success(self):
# This self-test method will succeed.
def successful_test(arg, kwarg):
return arg, kwarg
- result = o.run_test(
- "A successful test", successful_test, "arg1", kwarg="arg2"
- )
+
+ result = o.run_test("A successful test", successful_test, "arg1", kwarg="arg2")
assert True == result.success
assert "A successful test" == result.name
assert ("arg1", "arg2") == result.result
- assert (result.end-result.start).total_seconds() < 1
+ assert (result.end - result.start).total_seconds() < 1
def test_run_test_failure(self):
o = HasSelfTests()
# This self-test method will fail.
def unsuccessful_test(arg, kwarg):
raise IntegrationException(arg, kwarg)
+
result = o.run_test(
"An unsuccessful test", unsuccessful_test, "arg1", kwarg="arg2"
)
@@ -239,7 +240,7 @@ def unsuccessful_test(arg, kwarg):
assert None == result.result
assert "arg1" == str(result.exception)
assert "arg2" == result.exception.debug_message
- assert (result.end-result.start).total_seconds() < 1
+ assert (result.end - result.start).total_seconds() < 1
def test_test_failure(self):
o = HasSelfTests()
@@ -253,7 +254,7 @@ def test_test_failure(self):
assert "a failure" == result.name
assert isinstance(result.exception, IntegrationException)
assert "argh" == str(result.exception)
- assert (result.start-now).total_seconds() < 1
+ assert (result.start - now).total_seconds() < 1
# ... or you can pass in arguments to an IntegrationException
result = o.test_failure("another failure", "message", "debug")
diff --git a/tests/test_summary_evaluator.py b/tests/test_summary_evaluator.py
index 841486de3..9b7748c00 100644
--- a/tests/test_summary_evaluator.py
+++ b/tests/test_summary_evaluator.py
@@ -2,10 +2,11 @@
from textblob import TextBlob
from textblob.exceptions import MissingCorpusError
+
from ..util.summary import SummaryEvaluator
-class TestSummaryEvaluator(object):
+class TestSummaryEvaluator(object):
def _best(self, *summaries):
e = SummaryEvaluator()
for s in summaries:
@@ -47,11 +48,9 @@ def test_non_english_is_penalized(self):
evaluator.add(dutch)
evaluator.ready()
- dutch_no_language_penalty = evaluator.score(
- dutch, apply_language_penalty=False)
+ dutch_no_language_penalty = evaluator.score(dutch, apply_language_penalty=False)
- dutch_language_penalty = evaluator.score(
- dutch, apply_language_penalty=True)
+ dutch_language_penalty = evaluator.score(dutch, apply_language_penalty=True)
def test_english_is_not_penalized(self):
"""If description text appears to be in English, it is not rated down
@@ -65,10 +64,10 @@ def test_english_is_not_penalized(self):
evaluator.ready()
english_no_language_penalty = evaluator.score(
- english, apply_language_penalty=False)
+ english, apply_language_penalty=False
+ )
- english_language_penalty = evaluator.score(
- english, apply_language_penalty=True)
+ english_language_penalty = evaluator.score(english, apply_language_penalty=True)
assert english_language_penalty == english_no_language_penalty
def test_missing_corpus_error_ignored(self):
diff --git a/tests/test_user_profile.py b/tests/test_user_profile.py
index 64ff34019..b60010d28 100644
--- a/tests/test_user_profile.py
+++ b/tests/test_user_profile.py
@@ -1,23 +1,25 @@
import json
-from ..user_profile import (
- ProfileController,
- MockProfileStorage,
-)
-class TestProfileController(object):
+from ..user_profile import MockProfileStorage, ProfileController
+
+class TestProfileController(object):
def setup_method(self):
self.read_only_settings = dict(key="value")
self.writable_settings = dict(writable_key="old_value")
- self.storage = MockProfileStorage(self.read_only_settings, self.writable_settings)
+ self.storage = MockProfileStorage(
+ self.read_only_settings, self.writable_settings
+ )
self.controller = ProfileController(self.storage)
def test_profile_document(self):
"""Test that the default setup becomes a dictionary ready for
conversion to JSON.
"""
- assert ({'key': 'value', 'settings': {'writable_key': 'old_value'}} ==
- self.storage.profile_document)
+ assert {
+ "key": "value",
+ "settings": {"writable_key": "old_value"},
+ } == self.storage.profile_document
def test_get_success(self):
"""Test that sending a GET request to the controller results in the
@@ -25,7 +27,7 @@ def test_get_success(self):
"""
body, status_code, headers = self.controller.get()
assert 200 == status_code
- assert ProfileController.MEDIA_TYPE == headers['Content-Type']
+ assert ProfileController.MEDIA_TYPE == headers["Content-Type"]
assert json.dumps(self.storage.profile_document) == body
def test_put_success(self):
@@ -33,7 +35,7 @@ def test_put_success(self):
leads to changes in the writable part of the store, but not in
the read-only part.
"""
- headers = {"Content-Type" : ProfileController.MEDIA_TYPE}
+ headers = {"Content-Type": ProfileController.MEDIA_TYPE}
expected_new_state = dict(writable_key="new value")
old_read_only = dict(self.storage.read_only_settings)
body = json.dumps(dict(settings=expected_new_state))
@@ -46,11 +48,9 @@ def test_put_noop(self):
"""Test that sending an empty dictionary of key-value pairs
succeeds but does nothing.
"""
- headers = {"Content-Type" : ProfileController.MEDIA_TYPE}
+ headers = {"Content-Type": ProfileController.MEDIA_TYPE}
expected_new_state = dict(self.storage.writable_settings)
- body, status_code, headers = self.controller.put(
- headers, json.dumps({})
- )
+ body, status_code, headers = self.controller.put(headers, json.dumps({}))
assert 200 == status_code
assert expected_new_state == self.storage.writable_settings
@@ -70,8 +70,7 @@ def profile_document(self):
assert "Oh no" == problem.debug_message
def test_get_non_dictionary_profile_document(self):
- """Test what happens if the profile_document is not a dictionary.
- """
+ """Test what happens if the profile_document is not a dictionary."""
class BadStorage(MockProfileStorage):
@property
@@ -81,12 +80,13 @@ def profile_document(self):
self.controller.storage = BadStorage()
problem = self.controller.get()
assert 500 == problem.status_code
- assert ("Profile profile_document is not a JSON object: 'Here it is!'." ==
- problem.debug_message)
+ assert (
+ "Profile profile_document is not a JSON object: 'Here it is!'."
+ == problem.debug_message
+ )
def test_get_non_dictionary_profile_document(self):
- """Test what happens if the profile_document cannot be converted to JSON.
- """
+ """Test what happens if the profile_document cannot be converted to JSON."""
class BadStorage(MockProfileStorage):
@property
@@ -102,45 +102,45 @@ def profile_document(self):
def test_put_bad_media_type(self):
"""You must send the proper media type with your PUT request."""
- headers = {"Content-Type" : "application/json"}
+ headers = {"Content-Type": "application/json"}
body = json.dumps(dict(settings={}))
problem = self.controller.put(headers, body)
assert 415 == problem.status_code
- assert ('Expected vnd.librarysimplified/user-profile+json' ==
- problem.detail)
+ assert "Expected vnd.librarysimplified/user-profile+json" == problem.detail
def test_put_invalid_json(self):
"""You can't send any random string that's not JSON."""
- headers = {"Content-Type" : ProfileController.MEDIA_TYPE}
+ headers = {"Content-Type": ProfileController.MEDIA_TYPE}
problem = self.controller.put(headers, "blah blah")
assert 400 == problem.status_code
assert "Submitted profile document was not valid JSON." == problem.detail
def test_put_non_object(self):
"""You can't send any random JSON string, it has to be an object."""
- headers = {"Content-Type" : ProfileController.MEDIA_TYPE}
+ headers = {"Content-Type": ProfileController.MEDIA_TYPE}
problem = self.controller.put(headers, json.dumps("blah blah"))
assert 400 == problem.status_code
- assert ('Submitted profile document was not a JSON object.' ==
- problem.detail)
+ assert "Submitted profile document was not a JSON object." == problem.detail
def test_attempt_to_set_read_only_setting(self):
"""You can't change the value of a setting that's not
writable.
"""
- headers = {"Content-Type" : ProfileController.MEDIA_TYPE}
+ headers = {"Content-Type": ProfileController.MEDIA_TYPE}
body = json.dumps(dict(settings=dict(key="new value")))
problem = self.controller.put(headers, body)
assert 400 == problem.status_code
assert '"key" is not a writable setting.' == problem.detail
def test_update_raises_exception(self):
-
class BadStorage(MockProfileStorage):
def update(self, settable, full):
raise Exception("Oh no")
- self.controller.storage = BadStorage(self.read_only_settings, self.writable_settings)
- headers = {"Content-Type" : ProfileController.MEDIA_TYPE}
+
+ self.controller.storage = BadStorage(
+ self.read_only_settings, self.writable_settings
+ )
+ headers = {"Content-Type": ProfileController.MEDIA_TYPE}
body = json.dumps(dict(settings=dict(writable_key="new value")))
problem = self.controller.put(headers, body)
assert 500 == problem.status_code
diff --git a/tests/util/test_datetime_helpers.py b/tests/util/test_datetime_helpers.py
index 6461c3b8b..158e0cf69 100644
--- a/tests/util/test_datetime_helpers.py
+++ b/tests/util/test_datetime_helpers.py
@@ -1,8 +1,9 @@
import datetime
+from pdb import set_trace
+
import pytest
import pytz
from parameterized import parameterized
-from pdb import set_trace
from ...util.datetime_helpers import (
datetime_utc,
@@ -12,13 +13,20 @@
utc_now,
)
+
class TestDatetimeUTC(object):
- @parameterized.expand([
- ([2021, 1, 1], "2021-01-01T00:00:00", "2021-01-01T00:00:00+00:00"),
- ([1955, 11, 5, 12], "1955-11-05T12:00:00", "1955-11-05T12:00:00+00:00"),
- ([2015, 10, 21, 4, 29], "2015-10-21T04:29:00", "2015-10-21T04:29:00+00:00"),
- ([2015, 5, 9, 9, 30, 15], "2015-05-09T09:30:15", "2015-05-09T09:30:15+00:00"),
- ])
+ @parameterized.expand(
+ [
+ ([2021, 1, 1], "2021-01-01T00:00:00", "2021-01-01T00:00:00+00:00"),
+ ([1955, 11, 5, 12], "1955-11-05T12:00:00", "1955-11-05T12:00:00+00:00"),
+ ([2015, 10, 21, 4, 29], "2015-10-21T04:29:00", "2015-10-21T04:29:00+00:00"),
+ (
+ [2015, 5, 9, 9, 30, 15],
+ "2015-05-09T09:30:15",
+ "2015-05-09T09:30:15+00:00",
+ ),
+ ]
+ )
def test_datetime_utc(self, time, formatted, isoformat):
"""`datetime_utc` is a wrapper around `datetime.datetime` but it also
includes UTC information when it is created.
@@ -38,6 +46,7 @@ def test_datetime_utc(self, time, formatted, isoformat):
assert util_dt.month == time[1]
assert util_dt.day == time[2]
+
class TestFromTimestamp(object):
def test_from_timestamp(self):
"""`from_timestamp` is a wrapper around `datetime.fromtimestamp`
@@ -56,6 +65,7 @@ def test_from_timestamp(self):
assert util_from_ts.tzinfo is not None
assert util_from_ts.tzinfo == pytz.UTC
+
class TestUTCNow(object):
def test_utc_now(self):
"""`utc_now` is a wrapper around `datetime.now` but it also includes
@@ -63,23 +73,24 @@ def test_utc_now(self):
"""
datetime_now = datetime.datetime.now(tz=pytz.UTC)
util_now = utc_now()
-
+
# Same time but it's going to be off by a few milliseconds.
assert (datetime_now - util_now).total_seconds() < 2
-
+
# The UTC information for this datetime object is the pytz UTC value.
assert util_now.tzinfo == pytz.UTC
-
+
+
class TestToUTC(object):
def test_to_utc(self):
# `utc` marks a naive datetime object as being UTC, or
# converts a timezone-aware datetime object to UTC.
d1 = datetime.datetime(2021, 1, 1)
d2 = datetime.datetime.strptime("2020", "%Y")
-
+
assert d1.tzinfo is None
assert d2.tzinfo is None
-
+
d1_utc = to_utc(d1)
d2_utc = to_utc(d2)
@@ -90,7 +101,7 @@ def test_to_utc(self):
# The timezone information is from pytz UTC.
assert d1_utc.tzinfo == pytz.UTC
assert d2_utc.tzinfo == pytz.UTC
-
+
# Passing in None gets you None.
assert to_utc(None) == None
@@ -103,10 +114,12 @@ def test_to_utc(self):
d1_eastern = d1_utc.astimezone(pytz.timezone("US/Eastern"))
assert d1_utc == to_utc(d1_eastern)
- @parameterized.expand([
- ([2021, 1, 1], "2021-01-01", "%Y-%m-%d"),
- ([1955, 11, 5, 12], "1955-11-05T12:00:00", "%Y-%m-%dT%H:%M:%S")
- ])
+ @parameterized.expand(
+ [
+ ([2021, 1, 1], "2021-01-01", "%Y-%m-%d"),
+ ([1955, 11, 5, 12], "1955-11-05T12:00:00", "%Y-%m-%dT%H:%M:%S"),
+ ]
+ )
def test_strptime_utc(self, expect, date_string, format):
assert strptime_utc(date_string, format) == datetime_utc(*expect)
@@ -115,5 +128,7 @@ def test_strptime_utc_error(self):
# mention a timezone.
with pytest.raises(ValueError) as excinfo:
strptime_utc("2020-01-01T12:00:00+0300", "%Y-%m-%dT%H:%M:%S%z")
- assert ("Cannot use strptime_utc with timezone-aware format %Y-%m-%dT%H:%M:%S%z"
- in str(excinfo.value))
+ assert (
+ "Cannot use strptime_utc with timezone-aware format %Y-%m-%dT%H:%M:%S%z"
+ in str(excinfo.value)
+ )
diff --git a/tests/util/test_flask_util.py b/tests/util/test_flask_util.py
index 37667bbda..0b4a53df3 100644
--- a/tests/util/test_flask_util.py
+++ b/tests/util/test_flask_util.py
@@ -3,23 +3,25 @@
import datetime
import time
-from flask import Response as FlaskResponse
from wsgiref.handlers import format_date_time
-from ...util.flask_util import (
- OPDSEntryResponse,
- OPDSFeedResponse,
- Response,
-)
-from ...util.opds_writer import OPDSFeed
+from flask import Response as FlaskResponse
+
from ...util.datetime_helpers import utc_now
+from ...util.flask_util import OPDSEntryResponse, OPDSFeedResponse, Response
+from ...util.opds_writer import OPDSFeed
-class TestResponse(object):
+class TestResponse(object):
def test_constructor(self):
response = Response(
- "content", 401, dict(Header="value"), "mime/type",
- "content/type", True, 1002
+ "content",
+ 401,
+ dict(Header="value"),
+ "mime/type",
+ "content/type",
+ True,
+ 1002,
)
assert 1002 == response.max_age
assert isinstance(response, FlaskResponse)
@@ -29,9 +31,9 @@ def test_constructor(self):
# Response.headers is tested in more detail below.
headers = response.headers
- assert "value" == headers['Header']
- assert 'Cache-Control' in headers
- assert 'Expires' in headers
+ assert "value" == headers["Header"]
+ assert "Cache-Control" in headers
+ assert "Expires" in headers
def test_headers(self):
# First, test cases where the response should be private and
@@ -39,37 +41,36 @@ def test_headers(self):
# messages.
def assert_not_cached(max_age):
headers = Response(max_age=max_age).headers
- assert "private, no-cache" == headers['Cache-Control']
- assert 'Authorization' == headers['Vary']
- assert 'Expires' not in headers
+ assert "private, no-cache" == headers["Cache-Control"]
+ assert "Authorization" == headers["Vary"]
+ assert "Expires" not in headers
+
assert_not_cached(max_age=None)
assert_not_cached(max_age=0)
assert_not_cached(max_age="Not a number")
# Test the case where the response is public but should not be cached.
headers = Response(max_age=0, private=False).headers
- assert "public, no-cache" == headers['Cache-Control']
- assert 'Vary' not in headers
+ assert "public, no-cache" == headers["Cache-Control"]
+ assert "Vary" not in headers
# Test the case where the response is private but may be
# cached privately.
headers = Response(max_age=300, private=True).headers
- assert "private, no-transform, max-age=300" == headers['Cache-Control']
- assert 'Authorization' == headers['Vary']
+ assert "private, no-transform, max-age=300" == headers["Cache-Control"]
+ assert "Authorization" == headers["Vary"]
# Test the case where the response is public and may be cached,
# including by intermediaries.
- max_age = 60*60*24*12
+ max_age = 60 * 60 * 24 * 12
obj = Response(max_age=max_age)
headers = obj.headers
- cc = headers['Cache-Control']
- assert cc == 'public, no-transform, max-age=1036800, s-maxage=518400'
+ cc = headers["Cache-Control"]
+ assert cc == "public, no-transform, max-age=1036800, s-maxage=518400"
# We expect the Expires header to look basically like this.
- expect_expires = (
- utc_now() + datetime.timedelta(seconds=max_age)
- )
+ expect_expires = utc_now() + datetime.timedelta(seconds=max_age)
expect_expires_string = format_date_time(
time.mktime(expect_expires.timetuple())
)
@@ -77,17 +78,17 @@ def assert_not_cached(max_age):
# We'll only check the date part of the Expires header, to
# minimize the changes of spurious failures based on
# unfortunate timing.
- expires = headers['Expires']
+ expires = headers["Expires"]
assert expires[:17] == expect_expires_string[:17]
# It's possible to have a response that is private but should
# be cached. The feed of a patron's current loans is a good
# example.
response = Response(max_age=30, private=True)
- cache_control = response.headers['Cache-Control']
- assert 'private' in cache_control
- assert 'max-age=30' in cache_control
- assert 'Authorization' == response.headers['Vary']
+ cache_control = response.headers["Cache-Control"]
+ assert "private" in cache_control
+ assert "max-age=30" in cache_control
+ assert "Authorization" == response.headers["Vary"]
def test_unicode(self):
# You can easily convert a Response object to Unicode
@@ -98,6 +99,7 @@ def test_unicode(self):
class TestOPDSFeedResponse(object):
"""Test the OPDS feed-specific specialization of Response."""
+
def test_defaults(self):
# OPDSFeedResponse provides reasonable defaults for
# `mimetype` and `max_age`.
@@ -113,8 +115,7 @@ def test_defaults(self):
# These defaults can be overridden.
override_defaults = c(
- "a feed", 200, dict(Header="value"), "mime/type",
- "content/type", True, 1002
+ "a feed", 200, dict(Header="value"), "mime/type", "content/type", True, 1002
)
assert 1002 == override_defaults.max_age
@@ -127,8 +128,10 @@ def test_defaults(self):
do_not_cache = c(max_age=0)
assert 0 == do_not_cache.max_age
+
class TestOPDSEntryResponse(object):
"""Test the OPDS entry-specific specialization of Response."""
+
def test_defaults(self):
# OPDSEntryResponse provides a reasonable defaults for
# `mimetype`.
@@ -142,6 +145,6 @@ def test_defaults(self):
assert OPDSFeed.ATOM_TYPE == use_defaults.mimetype
# These defaults can be overridden.
- override_defaults = c("an entry", content_type= "content/type")
+ override_defaults = c("an entry", content_type="content/type")
assert "content/type" == override_defaults.content_type
assert "content/type" == override_defaults.mimetype
diff --git a/tests/util/test_http.py b/tests/util/test_http.py
index 7cf3b02b0..ac071a2e4 100644
--- a/tests/util/test_http.py
+++ b/tests/util/test_http.py
@@ -1,30 +1,32 @@
+import json
+
import pytest
import requests
-import json
+
+from ...problem_details import INVALID_INPUT
+from ...testing import MockRequestsResponse
from ...util.http import (
HTTP,
+ INTEGRATION_ERROR,
BadResponseException,
RemoteIntegrationException,
RequestNetworkException,
RequestTimedOut,
- INTEGRATION_ERROR,
)
-from ...testing import MockRequestsResponse
from ...util.problem_detail import ProblemDetail
-from ...problem_details import INVALID_INPUT
-class TestHTTP(object):
+class TestHTTP(object):
def test_series(self):
m = HTTP.series
assert "2xx" == m(201)
assert "3xx" == m(399)
assert "5xx" == m(500)
-
def test_request_with_timeout_success(self):
called_with = None
+
def fake_200_response(*args, **kwargs):
# The HTTP method and URL are passed in the order
# requests.request would expect.
@@ -35,7 +37,7 @@ def fake_200_response(*args, **kwargs):
assert "value" == kwargs["kwarg"]
# A default timeout is added.
- assert 20 == kwargs['timeout']
+ assert 20 == kwargs["timeout"]
return MockRequestsResponse(200, content="Success!")
response = HTTP._request_with_timeout(
@@ -45,7 +47,6 @@ def fake_200_response(*args, **kwargs):
assert b"Success!" == response.content
def test_request_with_timeout_failure(self):
-
def immediately_timeout(*args, **kwargs):
raise requests.exceptions.Timeout("I give up")
@@ -54,7 +55,6 @@ def immediately_timeout(*args, **kwargs):
assert "Timeout accessing http://url/: I give up" in str(excinfo.value)
def test_request_with_network_failure(self):
-
def immediately_fail(*args, **kwargs):
raise requests.exceptions.ConnectionError("a disaster")
@@ -63,13 +63,15 @@ def immediately_fail(*args, **kwargs):
assert "Network error contacting http://url/: a disaster" in str(excinfo.value)
def test_request_with_response_indicative_of_failure(self):
-
def fake_500_response(*args, **kwargs):
return MockRequestsResponse(500, content="Failure!")
with pytest.raises(BadResponseException) as excinfo:
HTTP._request_with_timeout("http://url/", fake_500_response, "a", "b")
- assert "Bad response from http://url/: Got status code 500 from external server" in str(excinfo.value)
+ assert (
+ "Bad response from http://url/: Got status code 500 from external server"
+ in str(excinfo.value)
+ )
def test_allowed_response_codes(self):
"""Test our ability to raise BadResponseException when
@@ -92,8 +94,11 @@ def fake_200_response(*args, **kwargs):
# You can say that certain codes are specifically allowed, and
# all others are forbidden.
with pytest.raises(BadResponseException) as excinfo:
- m(url, fake_401_response, allowed_response_codes = [201, 200])
- assert "Bad response from http://url/: Got status code 401 from external server, but can only continue on: 200, 201." in str(excinfo.value)
+ m(url, fake_401_response, allowed_response_codes=[201, 200])
+ assert (
+ "Bad response from http://url/: Got status code 401 from external server, but can only continue on: 200, 201."
+ in str(excinfo.value)
+ )
response = m(url, fake_401_response, allowed_response_codes=[401])
response = m(url, fake_401_response, allowed_response_codes=["4xx"])
@@ -101,27 +106,34 @@ def fake_200_response(*args, **kwargs):
# In this way you can even raise an exception on a 200 response code.
with pytest.raises(BadResponseException) as excinfo:
m(url, fake_200_response, allowed_response_codes=[401])
- assert "Bad response from http://url/: Got status code 200 from external server, but can only continue on: 401." in str(excinfo.value)
+ assert (
+ "Bad response from http://url/: Got status code 200 from external server, but can only continue on: 401."
+ in str(excinfo.value)
+ )
# You can say that certain codes are explicitly forbidden, and
# all others are allowed.
with pytest.raises(BadResponseException) as excinfo:
m(url, fake_401_response, disallowed_response_codes=[401])
- assert "Bad response from http://url/: Got status code 401 from external server, cannot continue." in str(excinfo.value)
+ assert (
+ "Bad response from http://url/: Got status code 401 from external server, cannot continue."
+ in str(excinfo.value)
+ )
with pytest.raises(BadResponseException) as excinfo:
m(url, fake_200_response, disallowed_response_codes=["2xx", 301])
- assert "Bad response from http://url/: Got status code 200 from external server, cannot continue." in str(excinfo.value)
+ assert (
+ "Bad response from http://url/: Got status code 200 from external server, cannot continue."
+ in str(excinfo.value)
+ )
- response = m(url, fake_401_response,
- disallowed_response_codes=["2xx"])
+ response = m(url, fake_401_response, disallowed_response_codes=["2xx"])
assert 401 == response.status_code
# The exception can be turned into a useful problem detail document.
exc = None
try:
- m(url, fake_200_response,
- disallowed_response_codes=["2xx"])
+ m(url, fake_200_response, disallowed_response_codes=["2xx"])
except Exception as e:
exc = e
pass
@@ -135,18 +147,28 @@ def fake_200_response(*args, **kwargs):
#
assert 502 == debug_doc.status_code
assert "Bad response" == debug_doc.title
- assert 'The server made a request to http://url/, and got an unexpected or invalid response.' == debug_doc.detail
- assert 'Bad response from http://url/: Got status code 200 from external server, cannot continue.\n\nResponse content: Hurray' == debug_doc.debug_message
+ assert (
+ "The server made a request to http://url/, and got an unexpected or invalid response."
+ == debug_doc.detail
+ )
+ assert (
+ "Bad response from http://url/: Got status code 200 from external server, cannot continue.\n\nResponse content: Hurray"
+ == debug_doc.debug_message
+ )
no_debug_doc = exc.as_problem_detail_document(debug=False)
assert "Bad response" == no_debug_doc.title
- assert 'The server made a request to url, and got an unexpected or invalid response.' == no_debug_doc.detail
+ assert (
+ "The server made a request to url, and got an unexpected or invalid response."
+ == no_debug_doc.detail
+ )
assert None == no_debug_doc.debug_message
def test_unicode_converted_to_utf8(self):
"""Any Unicode that sneaks into the URL, headers or body is
converted to UTF-8.
"""
+
class ResponseGenerator(object):
def __init__(self):
self.requests = []
@@ -158,18 +180,20 @@ def response(self, *args, **kwargs):
generator = ResponseGenerator()
url = "http://foo"
response = HTTP._request_with_timeout(
- url, generator.response, "POST",
- headers = { "unicode header": "unicode value"},
- data="unicode data"
+ url,
+ generator.response,
+ "POST",
+ headers={"unicode header": "unicode value"},
+ data="unicode data",
)
[(args, kwargs)] = generator.requests
url, method = args
- headers = kwargs['headers']
- data = kwargs['data']
+ headers = kwargs["headers"]
+ data = kwargs["data"]
# All the Unicode data was converted to bytes before being sent
# "over the wire".
- for k,v in list(headers.items()):
+ for k, v in list(headers.items()):
assert isinstance(k, bytes)
assert isinstance(v, bytes)
assert isinstance(data, bytes)
@@ -180,6 +204,7 @@ class Mock(HTTP):
def _request_with_timeout(cls, *args, **kwargs):
cls.called_with = (args, kwargs)
return "response"
+
def mock_request(*args, **kwargs):
response = MockRequestsResponse(200, "Success!")
return response
@@ -215,25 +240,26 @@ def test_process_debuggable_response(self):
problem = m("url", error)
assert isinstance(problem, ProblemDetail)
assert INTEGRATION_ERROR.uri == problem.uri
- assert ("Remote service returned a problem detail document: %r" % content ==
- problem.detail)
+ assert (
+ "Remote service returned a problem detail document: %r" % content
+ == problem.detail
+ )
assert content == problem.debug_message
# You can force a response to be treated as successful by
# passing in its response code as allowed_response_codes.
assert error == m("url", error, allowed_response_codes=[400])
assert error == m("url", error, allowed_response_codes=["400"])
- assert error == m("url", error, allowed_response_codes=['4xx'])
+ assert error == m("url", error, allowed_response_codes=["4xx"])
-class TestRemoteIntegrationException(object):
+class TestRemoteIntegrationException(object):
def test_with_service_name(self):
"""You don't have to provide a URL when creating a
RemoteIntegrationException; you can just provide the service
name.
"""
exc = RemoteIntegrationException(
- "Unreliable Service",
- "I just can't handle your request right now."
+ "Unreliable Service", "I just can't handle your request right now."
)
# Since only the service name is provided, there are no details to
@@ -242,11 +268,13 @@ def test_with_service_name(self):
other_detail = exc.document_detail(debug=False)
assert debug_detail == other_detail
- assert ('The server tried to access Unreliable Service but the third-party service experienced an error.' ==
- debug_detail)
+ assert (
+ "The server tried to access Unreliable Service but the third-party service experienced an error."
+ == debug_detail
+ )
-class TestBadResponseException(object):
+class TestBadResponseException(object):
def test_helper_constructor(self):
response = MockRequestsResponse(102, content="nonsense")
exc = BadResponseException.from_response(
@@ -258,54 +286,66 @@ def test_helper_constructor(self):
doc, status_code, headers = exc.as_problem_detail_document(debug=True).response
doc = json.loads(doc)
- assert 'Bad response' == doc['title']
- assert 'The server made a request to http://url/, and got an unexpected or invalid response.' == doc['detail']
+ assert "Bad response" == doc["title"]
assert (
- 'Bad response from http://url/: Terrible response, just terrible\n\nStatus code: 102\nContent: nonsense' ==
- doc['debug_message']
+ "The server made a request to http://url/, and got an unexpected or invalid response."
+ == doc["detail"]
+ )
+ assert (
+ "Bad response from http://url/: Terrible response, just terrible\n\nStatus code: 102\nContent: nonsense"
+ == doc["debug_message"]
)
# Unless debug is turned off, in which case none of that
# information is present.
doc, status_code, headers = exc.as_problem_detail_document(debug=False).response
- assert 'debug_message' not in json.loads(doc)
+ assert "debug_message" not in json.loads(doc)
def test_bad_status_code_helper(object):
response = MockRequestsResponse(500, content="Internal Server Error!")
- exc = BadResponseException.bad_status_code(
- "http://url/", response
- )
+ exc = BadResponseException.bad_status_code("http://url/", response)
doc, status_code, headers = exc.as_problem_detail_document(debug=True).response
doc = json.loads(doc)
- assert doc['debug_message'].startswith("Bad response from http://url/: Got status code 500 from external server, cannot continue.")
+ assert doc["debug_message"].startswith(
+ "Bad response from http://url/: Got status code 500 from external server, cannot continue."
+ )
def test_as_problem_detail_document(self):
exception = BadResponseException(
- "http://url/", "What even is this",
- debug_message="some debug info"
+ "http://url/", "What even is this", debug_message="some debug info"
)
document = exception.as_problem_detail_document(debug=True)
assert 502 == document.status_code
assert "Bad response" == document.title
- assert ("The server made a request to http://url/, and got an unexpected or invalid response." ==
- document.detail)
- assert "Bad response from http://url/: What even is this\n\nsome debug info" == document.debug_message
+ assert (
+ "The server made a request to http://url/, and got an unexpected or invalid response."
+ == document.detail
+ )
+ assert (
+ "Bad response from http://url/: What even is this\n\nsome debug info"
+ == document.debug_message
+ )
class TestRequestTimedOut(object):
-
def test_as_problem_detail_document(self):
exception = RequestTimedOut("http://url/", "I give up")
debug_detail = exception.as_problem_detail_document(debug=True)
assert "Timeout" == debug_detail.title
- assert 'The server made a request to http://url/, and that request timed out.' == debug_detail.detail
+ assert (
+ "The server made a request to http://url/, and that request timed out."
+ == debug_detail.detail
+ )
# If we're not in debug mode, we hide the URL we accessed and just
# show the hostname.
standard_detail = exception.as_problem_detail_document(debug=False)
- assert "The server made a request to url, and that request timed out." == standard_detail.detail
+ assert (
+ "The server made a request to url, and that request timed out."
+ == standard_detail.detail
+ )
# The status code corresponding to an upstream timeout is 502.
document, status_code, headers = standard_detail.response
@@ -313,18 +353,23 @@ def test_as_problem_detail_document(self):
class TestRequestNetworkException(object):
-
def test_as_problem_detail_document(self):
exception = RequestNetworkException("http://url/", "Colossal failure")
debug_detail = exception.as_problem_detail_document(debug=True)
assert "Network failure contacting third-party service" == debug_detail.title
- assert 'The server experienced a network error while contacting http://url/.' == debug_detail.detail
+ assert (
+ "The server experienced a network error while contacting http://url/."
+ == debug_detail.detail
+ )
# If we're not in debug mode, we hide the URL we accessed and just
# show the hostname.
standard_detail = exception.as_problem_detail_document(debug=False)
- assert "The server experienced a network error while contacting url." == standard_detail.detail
+ assert (
+ "The server experienced a network error while contacting url."
+ == standard_detail.detail
+ )
# The status code corresponding to an upstream timeout is 502.
document, status_code, headers = standard_detail.response
diff --git a/tests/util/test_languages.py b/tests/util/test_languages.py
index efb3867d6..6fa8aa7e6 100644
--- a/tests/util/test_languages.py
+++ b/tests/util/test_languages.py
@@ -2,55 +2,49 @@
"""Test language lookup capabilities."""
import pytest
-from ...util.languages import (
- LanguageCodes,
- LanguageNames,
- LookupTable,
-)
+from ...util.languages import LanguageCodes, LanguageNames, LookupTable
class TestLookupTable(object):
-
def test_lookup(self):
d = LookupTable()
- d['key'] = 'value'
- assert 'value' == d['key']
- assert None == d['missing']
- assert False == ('missing' in d)
- assert None == d['missing']
+ d["key"] = "value"
+ assert "value" == d["key"]
+ assert None == d["missing"]
+ assert False == ("missing" in d)
+ assert None == d["missing"]
class TestLanguageCodes(object):
-
def test_lookups(self):
c = LanguageCodes
- assert "eng" == c.two_to_three['en']
- assert "en" == c.three_to_two['eng']
- assert ["English"] == c.english_names['en']
- assert ["English"] == c.english_names['eng']
- assert ["English"] == c.native_names['en']
- assert ["English"] == c.native_names['eng']
-
- assert "spa" == c.two_to_three['es']
- assert "es" == c.three_to_two['spa']
- assert ['Spanish', 'Castilian'] == c.english_names['es']
- assert ['Spanish', 'Castilian'] == c.english_names['spa']
- assert ["español", "castellano"] == c.native_names['es']
- assert ["español", "castellano"] == c.native_names['spa']
-
- assert "chi" == c.two_to_three['zh']
- assert "zh" == c.three_to_two['chi']
- assert ["Chinese"] == c.english_names['zh']
- assert ["Chinese"] == c.english_names['chi']
+ assert "eng" == c.two_to_three["en"]
+ assert "en" == c.three_to_two["eng"]
+ assert ["English"] == c.english_names["en"]
+ assert ["English"] == c.english_names["eng"]
+ assert ["English"] == c.native_names["en"]
+ assert ["English"] == c.native_names["eng"]
+
+ assert "spa" == c.two_to_three["es"]
+ assert "es" == c.three_to_two["spa"]
+ assert ["Spanish", "Castilian"] == c.english_names["es"]
+ assert ["Spanish", "Castilian"] == c.english_names["spa"]
+ assert ["español", "castellano"] == c.native_names["es"]
+ assert ["español", "castellano"] == c.native_names["spa"]
+
+ assert "chi" == c.two_to_three["zh"]
+ assert "zh" == c.three_to_two["chi"]
+ assert ["Chinese"] == c.english_names["zh"]
+ assert ["Chinese"] == c.english_names["chi"]
# We don't have this translation yet.
- assert [] == c.native_names['zh']
- assert [] == c.native_names['chi']
+ assert [] == c.native_names["zh"]
+ assert [] == c.native_names["chi"]
- assert None == c.two_to_three['nosuchlanguage']
- assert None == c.three_to_two['nosuchlanguage']
- assert [] == c.english_names['nosuchlanguage']
- assert [] == c.native_names['nosuchlanguage']
+ assert None == c.two_to_three["nosuchlanguage"]
+ assert None == c.three_to_two["nosuchlanguage"]
+ assert [] == c.english_names["nosuchlanguage"]
+ assert [] == c.native_names["nosuchlanguage"]
def test_locale(self):
m = LanguageCodes.iso_639_2_for_locale
@@ -76,10 +70,10 @@ def test_name_for_languageset(self):
assert "" == m([])
assert "English" == m(["en"])
assert "English" == m(["eng"])
- assert "español" == m(['es'])
+ assert "español" == m(["es"])
assert "English/español" == m(["eng", "spa"])
assert "español/English" == m("spa,eng")
- assert "español/English/Chinese" == m(["spa","eng","chi"])
+ assert "español/English/Chinese" == m(["spa", "eng", "chi"])
pytest.raises(ValueError, m, ["eng, nxx"])
@@ -110,43 +104,42 @@ def coded(name, code):
coded("espanol", "spa")
coded("castellano", "spa")
for item in LanguageCodes.NATIVE_NAMES_RAW_DATA:
- coded(item['nativeName'].lower(),
- LanguageCodes.two_to_three[item['code']])
+ coded(item["nativeName"].lower(), LanguageCodes.two_to_three[item["code"]])
# Languages associated with a historical period are not mapped
# to codes.
- assert set() == d['irish, old (to 900)']
+ assert set() == d["irish, old (to 900)"]
# This general rule would exclude Greek ("Greek, Modern
# (1453-)") and Occitan ("Occitan (post 1500)"), so we added
# them manually.
- coded('greek', 'gre')
- coded('occitan', 'oci')
+ coded("greek", "gre")
+ coded("occitan", "oci")
# Languages associated with a geographical area, such as "Luo
# (Kenya and Tanzania)", can be looked up without that area.
- coded('luo', 'luo')
+ coded("luo", "luo")
# This causes a little problem for Tonga: there are two
# unrelated languages called 'Tonga', and the geographic area
# is the only way to distinguish them. For now, we map 'tonga'
# to both ISO codes. (This is why name_to_codes is called that
# rather than name_to_code.)
- assert set(['ton', 'tog']) == d['tonga']
+ assert set(["ton", "tog"]) == d["tonga"]
# Language families such as "Himacahli languages" can be
# looked up without the " languages".
- coded('himachali', 'him')
+ coded("himachali", "him")
# Language groups such as "Bantu (Other)" can be looked up
# without the "(Other)".
- coded('south american indian', 'sai')
- coded('bantu', 'bnt')
+ coded("south american indian", "sai")
+ coded("bantu", "bnt")
# If a language is known by multiple English names, lookup on
# any of those names will work.
for i in "Blissymbols; Blissymbolics; Bliss".split(";"):
- coded(i.strip().lower(), 'zbl')
+ coded(i.strip().lower(), "zbl")
def test_name_re(self):
# Verify our ability to find language names inside text.
diff --git a/tests/util/test_opds_writer.py b/tests/util/test_opds_writer.py
index 7fcb7379d..fc9110ef4 100644
--- a/tests/util/test_opds_writer.py
+++ b/tests/util/test_opds_writer.py
@@ -1,20 +1,18 @@
-import re
import datetime
-from parameterized import parameterized
+import re
+
import pytz
from lxml import etree
-from ...util.opds_writer import (
- AtomFeed,
- OPDSMessage
-)
+from parameterized import parameterized
+from ...util.opds_writer import AtomFeed, OPDSMessage
-class TestOPDSMessage(object):
+class TestOPDSMessage(object):
def test_equality(self):
a = OPDSMessage("urn", 200, "message")
- assert a ==a
+ assert a == a
assert a != None
assert a != "message"
@@ -30,21 +28,20 @@ def test_tag(self):
assert text == str(a)
# Verify that we start with a simplified:message tag.
- assert text.startswith('urn