"""
Refactored News Service using Base Service pattern
"""

import ssl
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import nltk
import requests
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from sqlalchemy import desc
from sqlalchemy.orm import Session

from config import Config
from etl.news_etl import run_news_etl_pipeline
from models.db_models import NewsArticle
from services.base_service import BaseDataService
from utils.cache import adaptive_ttl_cache, rate_limited_api
from utils.logging_config import logger

# Resource path under which NLTK stores the VADER lexicon once downloaded.
_VADER_RESOURCE = "sentiment/vader_lexicon.zip"


def _ensure_vader_lexicon() -> None:
    """Make sure the VADER lexicon is installed before the analyzer is built.

    The lexicon is only downloaded when it is not already present. A normal
    (certificate-verified) download is attempted first; only if that fails is
    the download retried with an unverified HTTPS context, and the process-wide
    default context factory is restored afterwards so that certificate
    verification is never left disabled for unrelated HTTPS traffic.
    """
    try:
        nltk.data.find(_VADER_RESOURCE)
        return
    except LookupError:
        pass

    if nltk.download("vader_lexicon", quiet=True):
        return

    # Some environments ship without usable CA certificates; retry unverified.
    unverified_factory = getattr(ssl, "_create_unverified_context", None)
    if unverified_factory is None:
        return

    saved_factory = getattr(ssl, "_create_default_https_context", None)
    ssl._create_default_https_context = unverified_factory
    try:
        nltk.download("vader_lexicon", quiet=True)
    finally:
        # Never leave certificate verification disabled for the whole process.
        if saved_factory is not None:
            ssl._create_default_https_context = saved_factory


# Ensure the lexicon is available
_ensure_vader_lexicon()
sia = SentimentIntensityAnalyzer()


class NewsService(BaseDataService):
    """Service for fetching news data.

    Plugs news-specific behaviour (database query, record formatting, ETL
    trigger and direct-API fallback) into the hooks of ``BaseDataService``.
    """

    model_class = NewsArticle
    data_type = "news"
    cache_ttl = 3600 * 2  # 2 hours
    cache_max_ttl = 3600 * 6  # 6 hours
    etl_timeout = 10  # seconds

    @classmethod
    @adaptive_ttl_cache(base_ttl=cache_ttl, max_ttl=cache_max_ttl, error_ttl=300)
    def fetch_news(cls, symbol: str, days: int = 30) -> List[Dict[str, Any]]:
        """
        Public API to fetch news for a given stock symbol.

        Opens a database session, delegates to ``fetch_data`` and always closes
        the session again.

        Args:
            symbol: Stock ticker to fetch news for.
            days: Forwarded to ``fetch_data`` as the ``days`` keyword.

        Returns:
            The list stored under the ``"data"`` key of the ``fetch_data``
            result, or an empty list when that key is missing or empty.
        """
        # Imported lazily, as in the original implementation.
        from models.db_models import SessionLocal

        session = SessionLocal()
        try:
            result = cls.fetch_data(session, symbol, days=days)
            return (result or {}).get("data") or []
        finally:
            session.close()

    @classmethod
    def _query_database(
        cls, session: Session, symbol: str, **kwargs
    ) -> List[NewsArticle]:
        """Query news articles for ``symbol`` from the database, newest first.

        Args:
            session: SQLAlchemy session to run the query on.
            symbol: Stock ticker to filter on.
            **kwargs: ``days`` (default 30) caps the number of rows returned.

        Returns:
            At most ``days`` ``NewsArticle`` rows ordered by descending datetime.
        """
        days = kwargs.get("days", 30)

        # NOTE(review): ``days`` is used as a row limit here, whereas
        # ``_legacy_fetch_news`` treats it as a date window. Kept as-is because
        # the behaviour is deliberate in the original; confirm the intent.
        return (
            session.query(NewsArticle)
            .filter(NewsArticle.symbol == symbol)
            .order_by(desc(NewsArticle.datetime))
            .limit(days)  # Use days as limit for number of articles
            .all()
        )

    @classmethod
    def _format_records(
        cls, records: List[NewsArticle], source: str = "database"
    ) -> Dict[str, Any]:
        """Format news records for the API response.

        Args:
            records: Articles to serialise.
            source: Label reported under the ``"source"`` key.

        Returns:
            ``{"data": [...], "source": source}`` with one dict per record;
            ``datetime`` is converted to an integer Unix timestamp, or ``None``
            when the record has no datetime.
        """
        data = [
            {
                "headline": record.headline,
                "summary": record.summary,
                "url": record.url,
                "source": record.source,
                "datetime": int(record.datetime.timestamp())
                if record.datetime
                else None,
                "sentiment": record.sentiment,
                "category": record.category,
                "related": record.related,
                "image_url": record.image_url,
            }
            for record in records
        ]
        return {"data": data, "source": source}

    @classmethod
    def _run_etl_pipeline(cls, symbol: str) -> None:
        """Run the news ETL pipeline for ``symbol``."""
        run_news_etl_pipeline(symbol)

    @classmethod
    def _try_alternative_sources(cls, symbol: str, **kwargs) -> Dict[str, Any]:
        """Try alternative data sources for news.

        News has no secondary provider, so this goes straight to the direct
        Finnhub API call.

        Args:
            symbol: Stock ticker to fetch news for.
            **kwargs: ``days`` (default 30) is the look-back window in days.

        Returns:
            The result dict of ``_legacy_fetch_news``.
        """
        days = kwargs.get("days", 30)

        logger.info(f"No news found in database for {symbol}, trying direct API")

        # For news, we don't have alternative sources like Yahoo Finance
        # So we fall back directly to the legacy API call
        return cls._legacy_fetch_news(symbol, days)

    @classmethod
    @rate_limited_api(calls_per_minute=10)
    def _legacy_fetch_news(cls, symbol: str, days: int = 30) -> Dict[str, Any]:
        """Legacy method that directly calls the Finnhub company-news API.

        Each returned article is scored with VADER on its headline.

        Args:
            symbol: Stock ticker to fetch news for.
            days: Size of the look-back window, in days, ending today.

        Returns:
            ``{"data": [...], "source": "finnhub"}`` on success, otherwise
            ``{"data": [], "error": <message>}``. Exceptions are logged and
            reported through the ``"error"`` key rather than raised.
        """
        # Sample the clock once so both ends of the window share the same day.
        today = datetime.today()
        end_date = today.strftime("%Y-%m-%d")
        start_date = (today - timedelta(days=days)).strftime("%Y-%m-%d")

        url = "https://finnhub.io/api/v1/company-news"
        params = {
            "symbol": symbol,
            "from": start_date,
            "to": end_date,
        }
        # Send the API key as a header instead of a query parameter: request
        # exceptions embed the full URL in their message, and that message is
        # returned to the caller under "error" below.
        headers = {"X-Finnhub-Token": Config.FINNHUB_API_KEY}

        try:
            logger.info(f"[LEGACY] Fetching news for {symbol} via API")
            response = requests.get(url, params=params, headers=headers, timeout=10)
            response.raise_for_status()
            articles = response.json()

            if not isinstance(articles, list):
                logger.warning(f"[LEGACY] Unexpected news format for {symbol}")
                return {"data": [], "error": "No news data available"}

            # Add sentiment analysis to each article
            news_data = []
            for article in articles:
                # A null headline would make the analyzer raise and discard
                # the whole batch, so treat it like a missing one.
                headline = article.get("headline") or ""
                sentiment_score = sia.polarity_scores(headline)["compound"]

                news_data.append(
                    {
                        "headline": headline,
                        "summary": article.get("summary", None),
                        "url": article.get("url", None),
                        "source": article.get("source", None),
                        "datetime": article.get("datetime", None),
                        "sentiment": sentiment_score,
                        "category": article.get("category", None),
                        "related": article.get("related", None),
                        "image_url": article.get("image", None),
                    }
                )

            logger.info(
                f"[LEGACY] Retrieved {len(news_data)} news articles for {symbol}"
            )
            return {"data": news_data, "source": "finnhub"}

        except requests.exceptions.Timeout:
            logger.error(f"[LEGACY] Timeout fetching news for {symbol}")
            return {"data": [], "error": "API timeout"}
        except Exception as e:
            # Boundary handler: callers expect an error dict, never an exception.
            logger.error(f"[LEGACY] Error fetching news for {symbol}", exc_info=True)
            return {"data": [], "error": str(e)}
services.refactored_earnings import EarningsService from services.refactored_financials import FinancialsService +from services.refactored_news import NewsService from utils.logging_config import logger # Feature toggle to enable/disable refactored services @@ -68,6 +70,8 @@ def wrapper(*args, **kwargs): fetch_earnings = with_fallback(old_fetch_earnings, EarningsService.fetch_earnings) +fetch_news = with_fallback(old_fetch_news, NewsService.fetch_news) + # Function to completely switch to refactored implementations def use_refactored_services(enabled: bool = True) -> None: diff --git a/tests/unit/test_refactored_services.py.disabled b/tests/unit/test_refactored_services.py similarity index 62% rename from tests/unit/test_refactored_services.py.disabled rename to tests/unit/test_refactored_services.py index 53acaa0..8c0939c 100644 --- a/tests/unit/test_refactored_services.py.disabled +++ b/tests/unit/test_refactored_services.py @@ -6,9 +6,10 @@ from unittest import mock from datetime import datetime, timedelta -from models.db_models import FinancialReport, Earnings +from models.db_models import FinancialReport, Earnings, NewsArticle from services.refactored_financials import FinancialsService from services.refactored_earnings import EarningsService +from services.refactored_news import NewsService from utils.cache import clear_cache @@ -205,6 +206,115 @@ def test_financials_service_fallback_path( mock_executor.submit.assert_called_once() mock_alternative.assert_called_once() + def test_news_service_query_format(self): + """Test that the news service formats data correctly from database records""" + # Create mock news articles + news_articles = [ + NewsArticle( + symbol="TEST", + headline="Test Company Announces New Product", + summary="Test Company announced a new product today.", + url="https://example.com/news1", + source="Financial Times", + datetime=datetime.now() - timedelta(days=1), + sentiment=0.5, + category="business", + related="product", + 
image_url="https://example.com/image1.jpg", + ), + NewsArticle( + symbol="TEST", + headline="Test Company Stock Up 5%", + summary="Test Company's stock increased by 5% today.", + url="https://example.com/news2", + source="Bloomberg", + datetime=datetime.now() - timedelta(days=2), + sentiment=0.8, + category="markets", + related="stock", + image_url="https://example.com/image2.jpg", + ), + ] + + # Call the format method + result = NewsService._format_records(news_articles) + + # Verify the result + self.assertIsInstance(result, dict) + self.assertIn("data", result) + self.assertEqual(len(result["data"]), 2) + self.assertEqual( + result["data"][0]["headline"], "Test Company Announces New Product" + ) + self.assertEqual(result["data"][0]["sentiment"], 0.5) + self.assertEqual(result["data"][0]["source"], "Financial Times") + self.assertEqual(result["source"], "database") + + @mock.patch("services.refactored_news.NewsService._query_database") + def test_news_service_database_path(self, mock_query): + """Test the happy path when data is in database""" + # Setup mock to return data + mock_article = NewsArticle( + symbol="TEST", + headline="Test News Article", + summary="Test summary", + url="https://example.com/news", + source="Test Source", + datetime=datetime.now(), + sentiment=0.5, + ) + mock_query.return_value = [mock_article] + + # Create a mock session + mock_session = mock.MagicMock() + + # Call the method + result = NewsService.fetch_data(mock_session, "TEST") + + # Verify the result + self.assertIsInstance(result, dict) + self.assertIn("data", result) + self.assertEqual(len(result["data"]), 1) + self.assertEqual(result["source"], "database") + + # Verify the query was called correctly + mock_query.assert_called_once_with(mock_session, "TEST") + + @mock.patch("services.refactored_news.NewsService._query_database") + @mock.patch("services.base_service.ETL_EXECUTOR") + @mock.patch("services.refactored_news.NewsService._try_alternative_sources") + def 
test_news_service_fallback_path( + self, mock_alternative, mock_executor, mock_query + ): + """Test fallback path when data is not in database""" + # Setup mocks + mock_query.return_value = [] # No data in database + mock_future = mock.MagicMock() + mock_executor.submit.return_value = mock_future + mock_future.result.side_effect = Exception("ETL error") # ETL fails + + # Setup fallback data + mock_alternative.return_value = { + "data": [{"headline": "Test News", "sentiment": 0.5}], + "source": "finnhub", + } + + # Create a mock session + mock_session = mock.MagicMock() + + # Call the method + result = NewsService.fetch_data(mock_session, "TEST") + + # Verify the result + self.assertIsInstance(result, dict) + self.assertIn("data", result) + self.assertEqual(result["source"], "finnhub") + + # Verify mocks were called correctly + mock_query.assert_called_once() + mock_executor.submit.assert_called_once() + mock_alternative.assert_called_once() + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_service_adapter.py.disabled b/tests/unit/test_service_adapter.py similarity index 86% rename from tests/unit/test_service_adapter.py.disabled rename to tests/unit/test_service_adapter.py index 098c3ed..640189b 100644 --- a/tests/unit/test_service_adapter.py.disabled +++ b/tests/unit/test_service_adapter.py @@ -8,6 +8,7 @@ from services.service_adapter import ( fetch_financials, fetch_earnings, + fetch_news, use_refactored_services, ) @@ -32,10 +33,13 @@ def test_service_adapter_structure(self): self.assertTrue(hasattr(fetch_financials, "new_func")) self.assertTrue(hasattr(fetch_earnings, "old_func")) self.assertTrue(hasattr(fetch_earnings, "new_func")) + self.assertTrue(hasattr(fetch_news, "old_func")) + self.assertTrue(hasattr(fetch_news, "new_func")) # Verify the functions are properly wrapped self.assertEqual(fetch_financials.__name__, "fetch_financials") self.assertEqual(fetch_earnings.__name__, "fetch_earnings") + self.assertEqual(fetch_news.__name__, 
"fetch_company_news") if __name__ == "__main__": diff --git a/views/news.py b/views/news.py index a4e58d6..2993a63 100644 --- a/views/news.py +++ b/views/news.py @@ -1,6 +1,6 @@ from flask import Blueprint, render_template -from services.news import fetch_company_news +from services.service_adapter import fetch_news news_bp = Blueprint("news", __name__) @@ -12,7 +12,7 @@ def news(): stocks = ["SBUX", "KDP", "BROS", "FARM"] news_sections = {} for symbol in stocks: - articles = fetch_company_news(symbol) + articles = fetch_news(symbol) html = ( "