Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions services/refactored_news.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
"""
Refactored News Service using Base Service pattern
"""

import ssl
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

import nltk
import requests
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from sqlalchemy import desc
from sqlalchemy.orm import Session

from config import Config
from etl.news_etl import run_news_etl_pipeline
from models.db_models import NewsArticle
from services.base_service import BaseDataService
from utils.cache import adaptive_ttl_cache, rate_limited_api
from utils.logging_config import logger

def _ensure_vader_lexicon() -> None:
    """Make sure the VADER lexicon is installed, downloading it only if missing.

    The lexicon is looked up locally first, so importing this module does not
    touch the network once the resource is present.
    """
    try:
        nltk.data.find("sentiment/vader_lexicon.zip")
    except LookupError:
        pass
    else:
        return

    # Some environments fail certificate verification on NLTK downloads.
    # Keep that workaround, but scope it to this single download and restore
    # the default HTTPS context afterwards instead of leaving certificate
    # verification disabled for every HTTPS client in the process.
    default_context_factory = ssl._create_default_https_context
    unverified_context_factory = getattr(ssl, "_create_unverified_context", None)
    if unverified_context_factory is not None:
        ssl._create_default_https_context = unverified_context_factory
    try:
        nltk.download("vader_lexicon", quiet=True)
    finally:
        ssl._create_default_https_context = default_context_factory


# Ensure the lexicon is available before building the shared analyzer
_ensure_vader_lexicon()
sia = SentimentIntensityAnalyzer()


class NewsService(BaseDataService):
    """Service for fetching company news articles.

    Supplies the news-specific pieces (database query, record formatting,
    ETL trigger and direct-API fallback) on top of ``BaseDataService``.
    ``fetch_data`` is not defined here and is presumably inherited from the
    base class.
    """

    model_class = NewsArticle
    data_type = "news"
    cache_ttl = 3600 * 2  # 2 hours
    cache_max_ttl = 3600 * 6  # 6 hours
    etl_timeout = 10  # seconds

    # The decorator arguments are evaluated inside the class body, so they can
    # reuse the TTL attributes above instead of repeating the literals.
    @classmethod
    @adaptive_ttl_cache(base_ttl=cache_ttl, max_ttl=cache_max_ttl, error_ttl=300)
    def fetch_news(cls, symbol: str, days: int = 30) -> List[Dict[str, Any]]:
        """Public API to fetch news for a given stock symbol.

        Opens a database session, delegates to ``cls.fetch_data`` and always
        closes the session afterwards.

        Args:
            symbol: Stock ticker to fetch news for.
            days: Forwarded to ``fetch_data`` as the ``days`` keyword.

        Returns:
            The list stored under ``"data"`` in the ``fetch_data`` result, or
            an empty list when that entry is missing or empty.
        """
        from models.db_models import SessionLocal

        session = SessionLocal()
        try:
            result = cls.fetch_data(session, symbol, days=days)
            # ``or []`` keeps the declared list return type even if the
            # payload carries ``"data": None``.
            return result.get("data") or []
        finally:
            session.close()

    @classmethod
    def _query_database(
        cls, session: Session, symbol: str, **kwargs
    ) -> List[NewsArticle]:
        """Query the most recent news articles for ``symbol`` from the database.

        Args:
            session: SQLAlchemy session used to run the query.
            symbol: Stock ticker to filter on.
            **kwargs: ``days`` (default 30) caps the number of rows returned.

        Returns:
            Matching ``NewsArticle`` rows, newest first.
        """
        days = kwargs.get("days", 30)

        # NOTE(review): ``days`` is used here as a row limit, whereas
        # ``_legacy_fetch_news`` treats it as a date window. Behaviour is kept
        # as-is; confirm whether a date filter was intended.
        return (
            session.query(NewsArticle)
            .filter(NewsArticle.symbol == symbol)
            .order_by(desc(NewsArticle.datetime))
            .limit(days)  # Use days as limit for number of articles
            .all()
        )

    @classmethod
    def _format_records(
        cls, records: List[NewsArticle], source: str = "database"
    ) -> Dict[str, Any]:
        """Format news records for the API response.

        Args:
            records: Articles to serialise.
            source: Label stored under ``"source"`` in the result.

        Returns:
            ``{"data": [...], "source": source}`` with one dict per article.
            ``"datetime"`` is an integer timestamp, or ``None`` when the
            record has no datetime.
        """
        data = [
            {
                "headline": record.headline,
                "summary": record.summary,
                "url": record.url,
                "source": record.source,
                "datetime": (
                    int(record.datetime.timestamp()) if record.datetime else None
                ),
                "sentiment": record.sentiment,
                "category": record.category,
                "related": record.related,
                "image_url": record.image_url,
            }
            for record in records
        ]
        return {"data": data, "source": source}

    @classmethod
    def _run_etl_pipeline(cls, symbol: str) -> None:
        """Run the news ETL pipeline for ``symbol``."""
        run_news_etl_pipeline(symbol)

    @classmethod
    def _try_alternative_sources(cls, symbol: str, **kwargs) -> Dict[str, Any]:
        """Try alternative data sources for news.

        Args:
            symbol: Stock ticker to fetch news for.
            **kwargs: ``days`` (default 30) is forwarded to the direct API call.

        Returns:
            The result dict of ``_legacy_fetch_news``.
        """
        days = kwargs.get("days", 30)

        logger.info(f"No news found in database for {symbol}, trying direct API")

        # For news, we don't have alternative sources like Yahoo Finance
        # So we fall back directly to the legacy API call
        return cls._legacy_fetch_news(symbol, days)

    @classmethod
    @rate_limited_api(calls_per_minute=10)
    def _legacy_fetch_news(cls, symbol: str, days: int = 30) -> Dict[str, Any]:
        """Legacy method that directly calls the Finnhub company-news API.

        Each returned article gets a VADER compound sentiment score computed
        from its headline.

        Args:
            symbol: Stock ticker to fetch news for.
            days: Size of the date window, counted back from today.

        Returns:
            ``{"data": [...], "source": "finnhub"}`` on success, otherwise
            ``{"data": [], "error": <message>}``. Exceptions are logged and
            reported through the ``"error"`` entry rather than raised.
        """
        # Read the clock once so both ends of the window share the same "today".
        today = datetime.today()
        end_date = today.strftime("%Y-%m-%d")
        start_date = (today - timedelta(days=days)).strftime("%Y-%m-%d")

        url = "https://finnhub.io/api/v1/company-news"
        params = {
            "symbol": symbol,
            "from": start_date,
            "to": end_date,
        }
        # Send the API key as a header rather than a query parameter:
        # requests embeds the full URL in its exception messages, which are
        # logged and returned to callers below, so a ``token=...`` query
        # parameter would leak the key.
        headers = {"X-Finnhub-Token": Config.FINNHUB_API_KEY}

        try:
            logger.info(f"[LEGACY] Fetching news for {symbol} via API")
            response = requests.get(url, params=params, headers=headers, timeout=10)
            response.raise_for_status()
            articles = response.json()

            if isinstance(articles, list):
                # Add sentiment analysis to each article
                news_data = []
                for article in articles:
                    # ``or ""`` also covers an explicit JSON null headline,
                    # which would otherwise make the analyzer raise and
                    # discard the whole batch.
                    headline = article.get("headline") or ""
                    sentiment_score = sia.polarity_scores(headline)["compound"]

                    news_item = {
                        "headline": headline,
                        "summary": article.get("summary", None),
                        "url": article.get("url", None),
                        "source": article.get("source", None),
                        "datetime": article.get("datetime", None),
                        "sentiment": sentiment_score,
                        "category": article.get("category", None),
                        "related": article.get("related", None),
                        "image_url": article.get("image", None),
                    }
                    news_data.append(news_item)

                logger.info(
                    f"[LEGACY] Retrieved {len(news_data)} news articles for {symbol}"
                )
                return {"data": news_data, "source": "finnhub"}
            else:
                logger.warning(f"[LEGACY] Unexpected news format for {symbol}")
                return {"data": [], "error": "No news data available"}

        except requests.exceptions.Timeout:
            logger.error(f"[LEGACY] Timeout fetching news for {symbol}")
            return {"data": [], "error": "API timeout"}
        except Exception as e:
            # Deliberately broad: this fallback reports failures in the result
            # dict instead of propagating them.
            logger.error(f"[LEGACY] Error fetching news for {symbol}", exc_info=True)
            return {"data": [], "error": str(e)}
4 changes: 4 additions & 0 deletions services/service_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

# Import both old and new service implementations
from services.financials import fetch_financials as old_fetch_financials
from services.news import fetch_company_news as old_fetch_news
from services.refactored_earnings import EarningsService
from services.refactored_financials import FinancialsService
from services.refactored_news import NewsService
from utils.logging_config import logger

# Feature toggle to enable/disable refactored services
Expand Down Expand Up @@ -68,6 +70,8 @@ def wrapper(*args, **kwargs):

# Public earnings entry point: pairs the legacy implementation with the
# refactored EarningsService one through with_fallback (defined earlier in
# this module; its selection/fallback rules are not visible in this excerpt).
fetch_earnings = with_fallback(old_fetch_earnings, EarningsService.fetch_earnings)

# Public news entry point: same pairing, legacy fetch_company_news first and
# the refactored NewsService.fetch_news second.
fetch_news = with_fallback(old_fetch_news, NewsService.fetch_news)


# Function to completely switch to refactored implementations
def use_refactored_services(enabled: bool = True) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from unittest import mock
from datetime import datetime, timedelta

from models.db_models import FinancialReport, Earnings
from models.db_models import FinancialReport, Earnings, NewsArticle
from services.refactored_financials import FinancialsService
from services.refactored_earnings import EarningsService
from services.refactored_news import NewsService
from utils.cache import clear_cache


Expand Down Expand Up @@ -205,6 +206,115 @@ def test_financials_service_fallback_path(
mock_executor.submit.assert_called_once()
mock_alternative.assert_called_once()

def test_news_service_query_format(self):
    """Database records should be rendered into the API payload shape."""
    # Two in-memory articles; nothing is persisted.
    product_article = NewsArticle(
        symbol="TEST",
        headline="Test Company Announces New Product",
        summary="Test Company announced a new product today.",
        url="https://example.com/news1",
        source="Financial Times",
        datetime=datetime.now() - timedelta(days=1),
        sentiment=0.5,
        category="business",
        related="product",
        image_url="https://example.com/image1.jpg",
    )
    stock_article = NewsArticle(
        symbol="TEST",
        headline="Test Company Stock Up 5%",
        summary="Test Company's stock increased by 5% today.",
        url="https://example.com/news2",
        source="Bloomberg",
        datetime=datetime.now() - timedelta(days=2),
        sentiment=0.8,
        category="markets",
        related="stock",
        image_url="https://example.com/image2.jpg",
    )

    formatted = NewsService._format_records([product_article, stock_article])

    # Envelope: a dict carrying both articles and the default source label.
    self.assertIsInstance(formatted, dict)
    self.assertIn("data", formatted)
    entries = formatted["data"]
    self.assertEqual(len(entries), 2)

    # The first entry mirrors the first record's fields.
    leading = entries[0]
    self.assertEqual(leading["headline"], "Test Company Announces New Product")
    self.assertEqual(leading["sentiment"], 0.5)
    self.assertEqual(leading["source"], "Financial Times")
    self.assertEqual(formatted["source"], "database")

@mock.patch("services.refactored_news.NewsService._query_database")
def test_news_service_database_path(self, mock_query):
    """Happy path: the database already holds an article for the symbol."""
    # Stand-in session; the patched query never touches a real database.
    session = mock.MagicMock()

    stored_article = NewsArticle(
        symbol="TEST",
        headline="Test News Article",
        summary="Test summary",
        url="https://example.com/news",
        source="Test Source",
        datetime=datetime.now(),
        sentiment=0.5,
    )
    mock_query.return_value = [stored_article]

    response = NewsService.fetch_data(session, "TEST")

    # One formatted article, labelled as coming from the database.
    self.assertIsInstance(response, dict)
    self.assertIn("data", response)
    self.assertEqual(len(response["data"]), 1)
    self.assertEqual(response["source"], "database")

    # The query hook received exactly the session and symbol.
    mock_query.assert_called_once_with(session, "TEST")

@mock.patch("services.refactored_news.NewsService._query_database")
@mock.patch("services.base_service.ETL_EXECUTOR")
@mock.patch("services.refactored_news.NewsService._try_alternative_sources")
def test_news_service_fallback_path(
    self, mock_alternative, mock_executor, mock_query
):
    """Fallback path: empty database and a failing ETL run."""
    # Database yields nothing for the symbol.
    mock_query.return_value = []

    # The ETL job is submitted, but its future raises when awaited.
    failing_future = mock.MagicMock()
    failing_future.result.side_effect = Exception("ETL error")
    mock_executor.submit.return_value = failing_future

    # Payload served by the alternative-source hook.
    fallback_payload = {
        "data": [{"headline": "Test News", "sentiment": 0.5}],
        "source": "finnhub",
    }
    mock_alternative.return_value = fallback_payload

    session = mock.MagicMock()
    response = NewsService.fetch_data(session, "TEST")

    # The response is the fallback payload's shape and source.
    self.assertIsInstance(response, dict)
    self.assertIn("data", response)
    self.assertEqual(response["source"], "finnhub")

    # Each stage of the chain ran exactly once.
    mock_query.assert_called_once()
    mock_executor.submit.assert_called_once()
    mock_alternative.assert_called_once()


# Run the test suite when this file is executed directly.
if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from services.service_adapter import (
fetch_financials,
fetch_earnings,
fetch_news,
use_refactored_services,
)

Expand All @@ -32,10 +33,13 @@ def test_service_adapter_structure(self):
self.assertTrue(hasattr(fetch_financials, "new_func"))
self.assertTrue(hasattr(fetch_earnings, "old_func"))
self.assertTrue(hasattr(fetch_earnings, "new_func"))
self.assertTrue(hasattr(fetch_news, "old_func"))
self.assertTrue(hasattr(fetch_news, "new_func"))

# Verify the functions are properly wrapped
self.assertEqual(fetch_financials.__name__, "fetch_financials")
self.assertEqual(fetch_earnings.__name__, "fetch_earnings")
self.assertEqual(fetch_news.__name__, "fetch_company_news")


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions views/news.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from flask import Blueprint, render_template

from services.news import fetch_company_news
from services.service_adapter import fetch_news

# Flask blueprint registered under the name "news"; the news view in this
# module is presumably attached to it (route decorator not visible here).
news_bp = Blueprint("news", __name__)

Expand All @@ -12,7 +12,7 @@ def news():
stocks = ["SBUX", "KDP", "BROS", "FARM"]
news_sections = {}
for symbol in stocks:
articles = fetch_company_news(symbol)
articles = fetch_news(symbol)
html = (
"<ul>"
+ "".join(
Expand Down
Loading