diff --git a/tests/common.py b/tests/common.py index bdc083ffd..d859f0c43 100644 --- a/tests/common.py +++ b/tests/common.py @@ -76,8 +76,8 @@ def tearDown(self): @classmethod def tearDownClass(cls): - cls.app_context.pop() db.engine.dispose() # fix too many client issues + cls.app_context.pop() class ApiBaseTest(BaseTestCase): diff --git a/tests/test_app.py b/tests/test_app.py index 540f55299..07aeb7ab1 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,6 +1,10 @@ from flask import current_app from tests.common import BaseTestCase from sqlalchemy import text +import unittest +from unittest.mock import patch +from webservices.rest import create_app +from webservices.common.models import db class TestBaseTestCase(BaseTestCase): @@ -27,3 +31,61 @@ def test_cors_headers(self): response = self.client.get('/swagger') assert "Access-Control-Allow-Origin" in response.headers assert response.headers["Access-Control-Allow-Origin"] == "*" + + +class StatementTimeoutTest(unittest.TestCase): + """Test statement_timeout configuration - should only apply to follower (read-only) databases.""" + def test_follower_engines_have_statement_timeout(self): + """Test that follower (read-only) database engines ARE configured with statement_timeout.""" + with patch.dict('os.environ', {'SQLA_STATEMENT_TIMEOUT': '30000'}): + app = create_app(test_config='follower') + app_context = app.app_context() + app_context.push() + + try: + # Get the follower engines + follower_engines = app.config.get('SQLALCHEMY_FOLLOWERS', []) + + # Skip test if no followers configured + if not follower_engines: + self.skipTest("No follower engines configured") + + for follower_engine in follower_engines: + self.assertIsNotNone(follower_engine) + + # Execute a query to verify the timeout is set + with follower_engine.connect() as conn: + result = conn.execute(text("SHOW statement_timeout")).scalar() + self.assertIsNotNone(result, "Follower should have statement_timeout configured") + # Verify it's not the default (should be 30s or 30000ms) + self.assertNotEqual(result, '0', "Follower statement_timeout should not be 0") + + finally: + app_context.pop() + + def test_statement_timeout_only_on_followers(self): + """Test that statement_timeout is only applied to followers, not primary.""" + with patch.dict('os.environ', {'SQLA_STATEMENT_TIMEOUT': '60000'}): + app = create_app(test_config='follower') + app_context = app.app_context() + app_context.push() + + try: + primary_engine = db.engine + follower_engines = app.config.get('SQLALCHEMY_FOLLOWERS', []) + + if not follower_engines: + self.skipTest("No follower engines configured") + + # Primary should NOT have timeout + with primary_engine.connect() as conn: + primary_timeout = conn.execute(text("SHOW statement_timeout")).scalar() + self.assertEqual(primary_timeout, '0', "Primary should not have statement_timeout") + + # Followers SHOULD have timeout + with follower_engines[0].connect() as conn: + follower_timeout = conn.execute(text("SHOW statement_timeout")).scalar() + self.assertNotEqual(follower_timeout, '0', "Follower should have statement_timeout") + + finally: + app_context.pop() diff --git a/tests/test_utils.py b/tests/test_utils.py index 02352bdd5..b9192d1ba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,7 +9,6 @@ from tests.common import ApiBaseTest from webservices import args -from flask import current_app from webservices import sorting from webservices.resources import candidate_aggregates from webservices.resources import elections @@ -310,12 +309,11 @@ def test_hide_null_election(self): self.assertEqual(results[0].total_disbursements, 0.0) -class TestArgs(TestCase): +class TestArgs(ApiBaseTest): def test_currency(self): - if current_app.config['TESTING']: - with current_app.test_request_context('?dollars=$24.50'): - parsed = flaskparser.parser.parse({'dollars': args.Currency()}, request, location='query') - self.assertEqual(parsed, {'dollars': 24.50}) + with self.application.test_request_context('?dollars=$24.50'): + parsed = flaskparser.parser.parse({'dollars': args.Currency()}, request, location='query') + self.assertEqual(parsed, {'dollars': 24.50}) class TestEnvVarSplit(TestCase): diff --git a/webservices/rest.py b/webservices/rest.py index 75a5b5ce5..fe81faf63 100644 --- a/webservices/rest.py +++ b/webservices/rest.py @@ -10,6 +10,7 @@ from marshmallow import EXCLUDE import ujson import sqlalchemy as sa +from sqlalchemy.exc import OperationalError import flask_cors as cors from nplusone.ext.flask_sqlalchemy import NPlusOne @@ -120,6 +121,7 @@ def create_app(test_config=None): app.config['PROPAGATE_EXCEPTIONS'] = True query_cache_size = int(env.get_credential('QUERY_CACHE_SIZE', '100')) pool_pre_ping = bool(env.get_credential('POOL_PRE_PING_BOOL', 'False')) + statement_timeout = int(env.get_credential('SQLA_STATEMENT_TIMEOUT', '300000')) app.config['SQLALCHEMY_ENGINE_OPTIONS'] = { 'query_cache_size': query_cache_size, 'max_overflow': 50, @@ -132,7 +134,9 @@ def create_sqlalchemy_followers(env_var_name: str, default_value: str = '') -> l followers = utils.split_env_var(env.get_credential(env_var_name, default_value)) return [sa.create_engine(follower.strip(), query_cache_size=query_cache_size, pool_size=50, max_overflow=50, pool_timeout=120, - pool_pre_ping=pool_pre_ping) for follower in followers if follower.strip() + pool_pre_ping=pool_pre_ping, + connect_args={"options": f"-c statement_timeout={statement_timeout}"},) + for follower in followers if follower.strip() ] # app.config['SQLALCHEMY_ECHO'] = True @@ -364,7 +368,6 @@ def api_spec(): return jsonify(spec.spec.to_dict()) app.register_blueprint(docs) - app.app_context().push() parser = FlaskRestParser() app.config['APISPEC_WEBARGS_PARSER'] = parser @@ -499,6 +502,13 @@ def add_secure_headers(response): response.headers.add(header, value) return response + @app.errorhandler(OperationalError) + def handle_db_timeout(e): + if hasattr(e, 'orig') and 'canceling statement due to statement timeout' in str(e.orig): + app.logger.warning('Statement timeout on %s', request.path) + return jsonify({'message': 'Query timed out', 'status': 408}), 408 + return handle_exception(e) + @app.errorhandler(Exception) def handle_exception(exception): wrapped = ResponseException(str(exception), ErrorCode.INTERNAL_ERROR, type(exception))