Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def tearDown(self):

@classmethod
def tearDownClass(cls):
cls.app_context.pop()
db.engine.dispose() # fix too many client issues
cls.app_context.pop()


class ApiBaseTest(BaseTestCase):
Expand Down
62 changes: 62 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from flask import current_app
from tests.common import BaseTestCase
from sqlalchemy import text
import unittest
from unittest.mock import patch
from webservices.rest import create_app
from webservices.common.models import db


class TestBaseTestCase(BaseTestCase):
Expand All @@ -27,3 +31,61 @@ def test_cors_headers(self):
response = self.client.get('/swagger')
assert "Access-Control-Allow-Origin" in response.headers
assert response.headers["Access-Control-Allow-Origin"] == "*"


class StatementTimeoutTest(unittest.TestCase):
"""Test statement_timeout configuration - should only apply to follower (read-only) databases."""
def test_follower_engines_have_statement_timeout(self):
"""Test that follower (read-only) database engines ARE configured with statement_timeout."""
with patch.dict('os.environ', {'SQLA_STATEMENT_TIMEOUT': '30000'}):
app = create_app(test_config='follower')
app_context = app.app_context()
app_context.push()

try:
# Get the follower engines
follower_engines = app.config.get('SQLALCHEMY_FOLLOWERS', [])

# Skip test if no followers configured
if not follower_engines:
self.skipTest("No follower engines configured")

for follower_engine in follower_engines:
self.assertIsNotNone(follower_engine)

# Execute a query to verify the timeout is set
with follower_engine.connect() as conn:
result = conn.execute(text("SHOW statement_timeout")).scalar()
self.assertIsNotNone(result, "Follower should have statement_timeout configured")
# Verify it's not the default (should be 30s or 30000ms)
self.assertNotEqual(result, '0', "Follower statement_timeout should not be 0")

finally:
app_context.pop()

def test_statement_timeout_only_on_followers(self):
"""Test that statement_timeout is only applied to followers, not primary."""
with patch.dict('os.environ', {'SQLA_STATEMENT_TIMEOUT': '60000'}):
app = create_app(test_config='follower')
app_context = app.app_context()
app_context.push()

try:
primary_engine = db.engine
follower_engines = app.config.get('SQLALCHEMY_FOLLOWERS', [])

if not follower_engines:
self.skipTest("No follower engines configured")

# Primary should NOT have timeout
with primary_engine.connect() as conn:
primary_timeout = conn.execute(text("SHOW statement_timeout")).scalar()
self.assertEqual(primary_timeout, '0', "Primary should not have statement_timeout")

# Followers SHOULD have timeout
with follower_engines[0].connect() as conn:
follower_timeout = conn.execute(text("SHOW statement_timeout")).scalar()
self.assertNotEqual(follower_timeout, '0', "Follower should have statement_timeout")

finally:
app_context.pop()
10 changes: 4 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from tests.common import ApiBaseTest

from webservices import args
from flask import current_app
from webservices import sorting
from webservices.resources import candidate_aggregates
from webservices.resources import elections
Expand Down Expand Up @@ -310,12 +309,11 @@ def test_hide_null_election(self):
self.assertEqual(results[0].total_disbursements, 0.0)


class TestArgs(TestCase):
class TestArgs(ApiBaseTest):
def test_currency(self):
if current_app.config['TESTING']:
with current_app.test_request_context('?dollars=$24.50'):
parsed = flaskparser.parser.parse({'dollars': args.Currency()}, request, location='query')
self.assertEqual(parsed, {'dollars': 24.50})
with self.application.test_request_context('?dollars=$24.50'):
parsed = flaskparser.parser.parse({'dollars': args.Currency()}, request, location='query')
self.assertEqual(parsed, {'dollars': 24.50})


class TestEnvVarSplit(TestCase):
Expand Down
14 changes: 12 additions & 2 deletions webservices/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from marshmallow import EXCLUDE
import ujson
import sqlalchemy as sa
from sqlalchemy.exc import OperationalError
import flask_cors as cors
from nplusone.ext.flask_sqlalchemy import NPlusOne

Expand Down Expand Up @@ -120,6 +121,7 @@ def create_app(test_config=None):
app.config['PROPAGATE_EXCEPTIONS'] = True
query_cache_size = int(env.get_credential('QUERY_CACHE_SIZE', '100'))
pool_pre_ping = bool(env.get_credential('POOL_PRE_PING_BOOL', 'False'))
statement_timeout = int(env.get_credential('SQLA_STATEMENT_TIMEOUT', '300000'))
app.config['SQLALCHEMY_ENGINE_OPTIONS'] = {
'query_cache_size': query_cache_size,
'max_overflow': 50,
Expand All @@ -132,7 +134,9 @@ def create_sqlalchemy_followers(env_var_name: str, default_value: str = '') -> l
followers = utils.split_env_var(env.get_credential(env_var_name, default_value))
return [sa.create_engine(follower.strip(), query_cache_size=query_cache_size,
pool_size=50, max_overflow=50, pool_timeout=120,
pool_pre_ping=pool_pre_ping) for follower in followers if follower.strip()
pool_pre_ping=pool_pre_ping,
connect_args={"options": f"-c statement_timeout={statement_timeout}"},)
for follower in followers if follower.strip()
]
# app.config['SQLALCHEMY_ECHO'] = True

Expand Down Expand Up @@ -364,7 +368,6 @@ def api_spec():
return jsonify(spec.spec.to_dict())

app.register_blueprint(docs)
app.app_context().push()

parser = FlaskRestParser()
app.config['APISPEC_WEBARGS_PARSER'] = parser
Expand Down Expand Up @@ -499,6 +502,13 @@ def add_secure_headers(response):
response.headers.add(header, value)
return response

@app.errorhandler(OperationalError)
def handle_db_timeout(e):
if hasattr(e, 'orig') and 'canceling statement due to statement timeout' in str(e.orig):
app.logger.warning('Statement timeout on %s', request.path)
return jsonify({'message': 'Query timed out', 'status': 408}), 408
return handle_exception(e)

@app.errorhandler(Exception)
def handle_exception(exception):
wrapped = ResponseException(str(exception), ErrorCode.INTERNAL_ERROR, type(exception))
Expand Down