diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..84918163 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,33 @@ +{ + "name": "Python 3", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", + "customizations": { + "codespaces": { + "openFiles": [ + "README.md", + "app.py" + ] + }, + "vscode": { + "settings": {}, + "extensions": [ + "ms-python.python", + "ms-python.vscode-pylance" + ] + } + }, + "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y + + Support This Project + + + +This is an enhanced version of the AI-powered hedge fund proof of concept. The goal of this project is to explore the use of AI to make trading decisions with an interactive web interface. This project is for **educational** purposes only and is not intended for real trading or investment. This system employs several agents working together: 1. Ben Graham Agent - The godfather of value investing, only buys hidden gems with a margin of safety -2. Bill Ackman Agent - An activist investors, takes bold positions and pushes for change +2. Bill Ackman Agent - An activist investor, takes bold positions and pushes for change 3. Cathie Wood Agent - The queen of growth investing, believes in the power of innovation and disruption 4. Charlie Munger Agent - Warren Buffett's partner, only buys wonderful businesses at fair prices 5. Stanley Druckenmiller Agent - Macro trading legend who hunts for asymmetric opportunities with explosive growth potential @@ -17,12 +23,10 @@ This system employs several agents working together: 11. Risk Manager - Calculates risk metrics and sets position limits 12. 
Portfolio Manager - Makes final trading decisions and generates orders -Screenshot 2025-03-08 at 4 45 22 PM +Screenshot 2025-03-08 at 4 45 22 PM **Note**: the system simulates trading decisions, it does not actually trade. -[![Twitter Follow](https://img.shields.io/twitter/follow/virattt?style=social)](https://twitter.com/virattt) - ## Disclaimer This project is for **educational and research purposes only**. @@ -38,8 +42,12 @@ By using this software, you agree to use it solely for learning purposes. ## Table of Contents - [Setup](#setup) - [Usage](#usage) - - [Running the Hedge Fund](#running-the-hedge-fund) - - [Running the Backtester](#running-the-backtester) + - [Running the Web App](#running-the-web-app) + - [Running the Hedge Fund CLI](#running-the-hedge-fund-cli) + - [Running the Backtester CLI](#running-the-backtester-cli) +- [Payment Integration](#payment-integration) + - [Stripe Setup](#stripe-setup) + - [Paddle Setup](#paddle-setup) - [Project Structure](#project-structure) - [Contributing](#contributing) - [Feature Requests](#feature-requests) @@ -49,8 +57,8 @@ By using this software, you agree to use it solely for learning purposes. Clone the repository: ```bash -git clone https://github.com/virattt/ai-hedge-fund.git -cd ai-hedge-fund +git clone https://github.com/virattt/ai-hedge-fund-plus.git +cd ai-hedge-fund-plus ``` 1. Install Poetry (if not already installed): @@ -79,12 +87,20 @@ OPENAI_API_KEY=your-openai-api-key # Get your Groq API key from https://groq.com/ GROQ_API_KEY=your-groq-api-key +# For running LLMs hosted by Anthropic (claude-3-opus, etc.) 
+# Get your Anthropic API key from https://anthropic.com/ +ANTHROPIC_API_KEY=your-anthropic-api-key + +# For running LLMs hosted by DeepSeek +# Get your DeepSeek API key from https://deepseek.com/ +DEEPSEEK_API_KEY=your-deepseek-api-key + # For getting financial data to power the hedge fund # Get your Financial Datasets API key from https://financialdatasets.ai/ FINANCIAL_DATASETS_API_KEY=your-financial-datasets-api-key ``` -**Important**: You must set `OPENAI_API_KEY`, `GROQ_API_KEY`, `ANTHROPIC_API_KEY`, or `DEEPSEEK_API_KEY` for the hedge fund to work. If you want to use LLMs from all providers, you will need to set all API keys. +**Important**: You must set at least one of `OPENAI_API_KEY`, `GROQ_API_KEY`, `ANTHROPIC_API_KEY`, or `DEEPSEEK_API_KEY` for the hedge fund to work. If you want to use LLMs from all providers, you will need to set all API keys. Financial data for AAPL, GOOGL, MSFT, NVDA, and TSLA is free and does not require an API key. @@ -92,13 +108,36 @@ For any other ticker, you will need to set the `FINANCIAL_DATASETS_API_KEY` in t ## Usage -### Running the Hedge Fund +### Running the Web App + +The AI Hedge Fund Plus comes with an interactive Streamlit web interface that allows you to: +- Configure and run backtests with different parameters +- Visualize portfolio performance +- Analyze trading decisions and signals +- Compare different analysts' perspectives + +To run the web app: + +```bash +# Using the provided script +./run_app.sh + +# Or directly with Poetry +poetry run streamlit run app.py +``` + +The web app will be available at http://localhost:8501 in your browser. + +### Running the Hedge Fund CLI + +For command-line usage: + ```bash poetry run python src/main.py --ticker AAPL,MSFT,NVDA ``` **Example Output:** -Screenshot 2025-01-06 at 5 50 17 PM +Screenshot 2025-01-06 at 5 50 17 PM You can also specify a `--show-reasoning` flag to print the reasoning of each agent to the console. 
@@ -111,14 +150,14 @@ You can optionally specify the start and end dates to make decisions for a speci poetry run python src/main.py --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 ``` -### Running the Backtester +### Running the Backtester CLI ```bash poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA ``` **Example Output:** -Screenshot 2025-01-06 at 5 47 52 PM +Screenshot 2025-01-06 at 5 47 52 PM You can optionally specify the start and end dates to backtest over a specific time period. @@ -126,41 +165,170 @@ You can optionally specify the start and end dates to backtest over a specific t poetry run python src/backtester.py --ticker AAPL,MSFT,NVDA --start-date 2024-01-01 --end-date 2024-03-01 ``` +## Payment Integration + +You can monetize the AI Hedge Fund Plus application by integrating a payment system. This section provides instructions for setting up either Stripe or Paddle as your payment processor. + +### Stripe Setup + +1. **Create a Stripe Account**: + - Sign up at [stripe.com](https://stripe.com) + - Complete the verification process + - Set up your business details + +2. **Install Stripe Dependencies**: + ```bash + poetry add stripe streamlit-stripe + ``` + +3. **Configure Stripe Keys**: + Add these to your `.env` file: + ``` + STRIPE_PUBLISHABLE_KEY=your_publishable_key + STRIPE_SECRET_KEY=your_secret_key + STRIPE_PRICE_ID=your_price_id + ``` + +4. **Create Products and Pricing**: + - Log into your Stripe Dashboard + - Go to Products > Create Product + - Set up subscription tiers (e.g., Basic, Pro, Enterprise) + - Note the Price IDs for each tier + +5. 
**Implement Stripe in Your App**: + Create a new file `src/payment/stripe_integration.py`: + ```python + import os + import stripe + import streamlit as st + from dotenv import load_dotenv + + load_dotenv() + + stripe.api_key = os.getenv("STRIPE_SECRET_KEY") + + def create_checkout_session(price_id, success_url, cancel_url): + try: + checkout_session = stripe.checkout.Session.create( + payment_method_types=["card"], + line_items=[{"price": price_id, "quantity": 1}], + mode="subscription", + success_url=success_url, + cancel_url=cancel_url, + ) + return checkout_session + except Exception as e: # NOTE: on failure the caller receives the error text, not a session object + return str(e) + + def display_payment_options(): + st.header("Choose Your Subscription Plan") + + col1, col2, col3 = st.columns(3) + + with col1: + st.subheader("Basic") + st.write("$9.99/month") + st.write("- Access to basic features") + st.write("- Limited number of stocks") + if st.button("Subscribe to Basic"): + session = create_checkout_session( + os.getenv("STRIPE_BASIC_PRICE_ID"), + "http://localhost:8501/success", + "http://localhost:8501/cancel" + ) + st.markdown(f"[Proceed to Payment]({session.url})") + + with col2: + st.subheader("Pro") + st.write("$19.99/month") + st.write("- All basic features") + st.write("- Unlimited stocks") + st.write("- Advanced analytics") + if st.button("Subscribe to Pro"): + session = create_checkout_session( + os.getenv("STRIPE_PRO_PRICE_ID"), + "http://localhost:8501/success", + "http://localhost:8501/cancel" + ) + st.markdown(f"[Proceed to Payment]({session.url})") + + with col3: + st.subheader("Enterprise") + st.write("$49.99/month") + st.write("- All pro features") + st.write("- Priority support") + st.write("- Custom analytics") + if st.button("Subscribe to Enterprise"): + session = create_checkout_session( + os.getenv("STRIPE_ENTERPRISE_PRICE_ID"), + "http://localhost:8501/success", + "http://localhost:8501/cancel" + ) + st.markdown(f"[Proceed to Payment]({session.url})") + ``` + +6. 
**Integrate with Your Streamlit App**: + Update `app.py` to include the payment page: + ```python + # Add this import at the top + from src.payment.stripe_integration import display_payment_options + + # Add this to your sidebar or as a separate page + if st.sidebar.button("Subscription Plans"): + display_payment_options() + ``` + +7. **Set Up Webhook for Subscription Management**: + - Create a webhook endpoint in your application + - Configure the webhook in your Stripe Dashboard + - Handle events like `customer.subscription.created`, `customer.subscription.updated`, etc. + ## Project Structure ``` -ai-hedge-fund/ +ai-hedge-fund-plus/ +├── app.py # Streamlit web application +├── run_app.sh # Script to run the web app ├── src/ -│ ├── agents/ # Agent definitions and workflow -│ │ ├── bill_ackman.py # Bill Ackman agent -│ │ ├── fundamentals.py # Fundamental analysis agent +│ ├── agents/ # Agent definitions and workflow +│ │ ├── bill_ackman.py # Bill Ackman agent +│ │ ├── fundamentals.py # Fundamental analysis agent +│ │ ├── portfolio_manager.py # Portfolio management agent +│ │ ├── risk_manager.py # Risk management agent +│ │ ├── sentiment.py # Sentiment analysis agent +│ │ ├── technicals.py # Technical analysis agent +│ │ ├── valuation.py # Valuation analysis agent +│ │ ├── warren_buffett.py # Warren Buffett agent +│ │ ├── ben_graham.py # Ben Graham agent +│ │ ├── cathie_wood.py # Cathie Wood agent +│ │ ├── charlie_munger.py # Charlie Munger agent +│ │ ├── stanley_druckenmiller.py # Stanley Druckenmiller agent +│ │ ├── warren_buffett.py # Warren Buffett agent +│ │ ├── valuation.py # Valuation analysis agent +│ │ ├── sentiment.py # Sentiment analysis agent +│ │ ├── fundamentals.py # Fundamental analysis agent +│ │ ├── technicals.py # Technical analysis agent +│ │ ├── risk_manager.py # Risk management agent │ │ ├── portfolio_manager.py # Portfolio management agent -│ │ ├── risk_manager.py # Risk management agent -│ │ ├── sentiment.py # Sentiment analysis agent -│ │ 
├── technicals.py # Technical analysis agent -│ │ ├── valuation.py # Valuation analysis agent -│ │ ├── warren_buffett.py # Warren Buffett agent -│ ├── tools/ # Agent tools -│ │ ├── api.py # API tools -│ ├── backtester.py # Backtesting tools -│ ├── main.py # Main entry point -├── pyproject.toml -├── ... +│ │ ├── ben_graham.py # Ben Graham agent +│ │ ├── cathie_wood.py # Cathie Wood agent +│ │ ├── charlie_munger.py # Charlie Munger agent +│ │ ├── stanley_druckenmiller.py # Stanley Druckenmiller agent +│ │ ├── warren_buffett.py # Warren Buffett agent +│ │ └── valuation.py # Valuation analysis agent +│ ├── data/ # Data handling and processing +│ ├── graph/ # Visualization components +│ ├── llm/ # LLM integration and models +│ ├── tools/ # Agent tools +│ │ ├── api.py # API tools +│ ├── utils/ # Utility functions +│ ├── backtester.py # Backtesting engine +│ ├── main.py # Main CLI entry point +├── backtester.py # CLI backtester entry point +├── pyproject.toml # Poetry configuration +├── .env.example # Example environment variables +├── LICENSE # MIT License ``` -## Contributing - -1. Fork the repository -2. Create a feature branch -3. Commit your changes -4. Push to the branch -5. Create a Pull Request - -**Important**: Please keep your pull requests small and focused. This will make it easier to review and merge. - -## Feature Requests - -If you have a feature request, please open an [issue](https://github.com/virattt/ai-hedge-fund/issues) and make sure it is tagged with `enhancement`. - ## License This project is licensed under the MIT License - see the LICENSE file for details. 
diff --git a/app.py b/app.py new file mode 100644 index 00000000..d4939711 --- /dev/null +++ b/app.py @@ -0,0 +1,1250 @@ +import streamlit as st +import pandas as pd +import matplotlib.pyplot as plt +from datetime import datetime, timedelta +from dateutil.relativedelta import relativedelta +import numpy as np +import sys +import os +from dotenv import load_dotenv +import matplotlib as mpl + +# Add src directory to path +sys.path.append(os.path.join(os.path.dirname(__file__), "src")) + +# Import necessary modules from the project +from src.backtester import Backtester +from src.main import run_hedge_fund +from src.llm.models import LLM_ORDER, get_model_info, ModelProvider +from src.utils.analysts import ANALYST_ORDER + +# Set up matplotlib to use colors that work in both light and dark mode +plt.style.use('default') # Reset to default style + +# Define a function to get theme-based colors +def get_theme_colors(is_dark_theme=False): + if is_dark_theme: + return { + 'primary': '#8ab4f8', # Light blue for dark theme + 'success': '#81c995', # Light green for dark theme + 'error': '#f28b82', # Light red for dark theme + 'warning': '#fdd663', # Light yellow for dark theme + 'neutral': '#9aa0a6', # Light gray for dark theme + 'background': '#202124', # Dark background + 'text': '#e8eaed', # Light text for dark theme + 'grid': '#5f6368', # Grid lines for dark theme + } + else: + return { + 'primary': '#1a73e8', # Blue for light theme + 'success': '#0f9d58', # Green for light theme + 'error': '#d93025', # Red for light theme + 'warning': '#f9ab00', # Yellow/orange for light theme + 'neutral': '#5f6368', # Gray for light theme + 'background': '#ffffff', # Light background + 'text': '#202124', # Dark text for light theme + 'grid': '#dadce0', # Grid lines for light theme + } + +# Function to detect if Streamlit is in dark mode +def is_dark_theme(): + try: + # This is a hack to detect dark theme in Streamlit + # It may not always work, but it's a reasonable approximation + 
return st.get_option("theme.base") == "dark" + except Exception: # option lookup failed; fall back to the light palette without swallowing KeyboardInterrupt/SystemExit + return False + +# Get theme colors based on current theme +theme_colors = get_theme_colors(is_dark_theme()) + +# Set default matplotlib colors based on theme +mpl.rcParams['axes.facecolor'] = 'none' # Transparent background +mpl.rcParams['figure.facecolor'] = 'none' # Transparent figure +mpl.rcParams['axes.edgecolor'] = theme_colors['grid'] +mpl.rcParams['axes.labelcolor'] = theme_colors['text'] +mpl.rcParams['xtick.color'] = theme_colors['text'] +mpl.rcParams['ytick.color'] = theme_colors['text'] +mpl.rcParams['grid.color'] = theme_colors['grid'] +mpl.rcParams['text.color'] = theme_colors['text'] + +# Set page configuration +st.set_page_config( + page_title="AI Hedge Fund Backtester", + page_icon="📈", + layout="wide", + initial_sidebar_state="expanded", +) + +# Add custom CSS +st.markdown(""" + +""", unsafe_allow_html=True) + +# Header with support button in the right corner +st.markdown(""" +
+
+
AI Hedge Fund Backtester
+

Backtest your AI-powered hedge fund strategy with historical data.

+
+ +
+""", unsafe_allow_html=True) + +# Sidebar for inputs +st.sidebar.markdown('
Configuration
', unsafe_allow_html=True) + +# Ticker input +ticker_input = st.sidebar.text_input( + "Stock Tickers (comma-separated)", + value="AAPL,MSFT", + help="Enter stock tickers separated by commas (e.g., AAPL,MSFT,GOOGL)" +) +tickers = [ticker.strip() for ticker in ticker_input.split(",") if ticker.strip()] + +# Date range selection +today = datetime.now() +day_before_yesterday = today - timedelta(days=2) +default_end_date = day_before_yesterday.strftime("%Y-%m-%d") +default_start_date = (today - relativedelta(days=3)).strftime("%Y-%m-%d") + +col1, col2 = st.sidebar.columns(2) +with col1: + start_date = st.date_input( + "Start Date", + value=datetime.strptime(default_start_date, "%Y-%m-%d"), + max_value=today - timedelta(days=1), + ) +with col2: + end_date = st.date_input( + "End Date", + value=datetime.strptime(default_end_date, "%Y-%m-%d"), + max_value=day_before_yesterday, + ) + +# Convert date inputs to string format +start_date_str = start_date.strftime("%Y-%m-%d") +end_date_str = end_date.strftime("%Y-%m-%d") + +# Initial capital +initial_capital = st.sidebar.number_input( + "Initial Capital ($)", + min_value=1000.0, + max_value=10000000.0, + value=100000.0, + step=10000.0, + help="Initial capital amount for the backtest" +) + +# Margin requirement +margin_requirement = st.sidebar.slider( + "Margin Requirement (%)", + min_value=0.0, + max_value=100.0, + value=0.0, + step=5.0, + help="Margin ratio for short positions (e.g., 50% means 50% margin required)" +) / 100.0 + +# LLM model selection +# Filter LLM_ORDER to only include DeepSeek models +deepseek_models = [(display, value, provider) for display, value, provider in LLM_ORDER if provider == ModelProvider.DEEPSEEK.value] + +# Check if there are any DeepSeek models available +if not deepseek_models: + st.error("No DeepSeek models are available. 
Please check your configuration.") + st.stop() + +model_options = [display for display, value, _ in deepseek_models] +model_values = [value for display, value, _ in deepseek_models] +model_display_to_value = {display: value for display, value, _ in deepseek_models} + +# Add a note about DeepSeek models +st.sidebar.info("Only DeepSeek models are available for selection.") + +selected_model_display = st.sidebar.selectbox( + "LLM Model", + options=model_options, + index=0, + help="Select the LLM model to use for trading decisions" +) +selected_model = model_display_to_value[selected_model_display] + +# Get model provider +model_info = get_model_info(selected_model) +model_provider = model_info.provider.value if model_info else "Unknown" + +# Analyst selection +analyst_options = [display for display, value in ANALYST_ORDER] +analyst_values = [value for display, value in ANALYST_ORDER] +analyst_display_to_value = {display: value for display, value in ANALYST_ORDER} + +selected_analyst_displays = st.sidebar.multiselect( + "Select Analysts", + options=analyst_options, + default=analyst_options[:3], # Default to first 3 analysts + help="Select the analysts to include in the backtest" +) +selected_analysts = [analyst_display_to_value[display] for display in selected_analyst_displays] + +# Run backtest button +run_button = st.sidebar.button("Run Backtest", type="primary") + +# Main content area +if not tickers: + st.warning("Please enter at least one ticker symbol.") +elif start_date >= end_date: + st.warning("Start date must be before end date.") +elif not selected_analysts: + st.warning("Please select at least one analyst.") +elif end_date > day_before_yesterday.date(): + st.warning("End date cannot be later than the day before yesterday. 
Please select a valid end date.") +elif run_button: + with st.spinner("Running backtest..."): + # Create progress bar + progress_bar = st.progress(0) + progress_value = 0.0 + + # Create and run the backtester + backtester = Backtester( + agent=run_hedge_fund, + tickers=tickers, + start_date=start_date_str, + end_date=end_date_str, + initial_capital=initial_capital, + model_name=selected_model, + model_provider=model_provider, + selected_analysts=selected_analysts, + initial_margin_requirement=margin_requirement, + ) + + # Override the print function to update progress + original_print = print + # Create a class to store and update progress value + class ProgressTracker: + def __init__(self, initial_value=0.0): + self.value = initial_value + + def increment(self): + self.value = min(0.9, self.value + 0.05) + return self.value + + # Initialize progress tracker + progress_tracker = ProgressTracker() + + def progress_print(*args, **kwargs): + original_print(*args, **kwargs) + # Update progress bar (this is approximate since we don't know total steps) + progress_bar.progress(progress_tracker.increment()) + + # Monkey patch print function + import builtins + builtins.print = progress_print + + try: + # Run the backtest + performance_metrics = backtester.run_backtest() + performance_df = backtester.analyze_performance() + + # Store results in session state for later use + st.session_state.performance_metrics = performance_metrics + st.session_state.performance_df = performance_df + st.session_state.backtester = backtester + st.session_state.initial_capital = initial_capital + # Flag the run as complete only after every result is stored, so a + # failed backtest never leaves the results section reading missing state + st.session_state.backtest_run = True + + # Set progress to complete + progress_bar.progress(1.0) + + finally: + # Restore original print function + builtins.print = original_print + +# Display results if a backtest has been run +if 'backtest_run' in st.session_state and st.session_state.backtest_run: + # Get stored results from session state + 
performance_metrics = st.session_state.performance_metrics + performance_df = st.session_state.performance_df + backtester = st.session_state.backtester + initial_capital = st.session_state.initial_capital + + # Display results + st.markdown('
Backtest Results
', unsafe_allow_html=True) + + # Performance metrics in cards + col1, col2, col3 = st.columns(3) + + # Calculate total return + if not performance_df.empty: + final_portfolio_value = performance_df["Portfolio Value"].iloc[-1] + total_return = ((final_portfolio_value - initial_capital) / initial_capital) * 100 + + with col1: + st.markdown(f""" +
+

Total Return

+

{total_return:.2f}%

+
+ """, unsafe_allow_html=True) + + with col2: + sharpe_ratio = performance_metrics.get('sharpe_ratio', 0) if performance_metrics else 0 + st.markdown(f""" +
+

Sharpe Ratio

+

{sharpe_ratio if sharpe_ratio is not None else 0:.2f}

+
+ """, unsafe_allow_html=True) + + with col3: + max_drawdown = performance_metrics.get('max_drawdown', 0) if performance_metrics else 0 + st.markdown(f""" +
+

Max Drawdown

+

{max_drawdown if max_drawdown is not None else 0:.2f}%

+
+ """, unsafe_allow_html=True) + + # Portfolio value chart + st.markdown("### Portfolio Value Over Time") + fig, ax = plt.subplots(figsize=(12, 6)) + ax.plot(performance_df.index, performance_df["Portfolio Value"], color=theme_colors['primary']) + ax.set_title("Portfolio Value Over Time") + ax.set_ylabel("Portfolio Value ($)") + ax.set_xlabel("Date") + ax.grid(True, alpha=0.3) + st.pyplot(fig) + + # Daily returns chart + st.markdown("### Daily Returns") + fig, ax = plt.subplots(figsize=(12, 6)) + # Ensure we have valid data for the daily returns + if "Daily Return" in performance_df.columns and not performance_df["Daily Return"].isnull().all(): + performance_df["Daily Return"].plot(kind="bar", ax=ax, color=performance_df["Daily Return"].apply( + lambda x: theme_colors['success'] if x >= 0 else theme_colors['error'])) + ax.set_title("Daily Returns") + ax.set_ylabel("Return (%)") + ax.set_xlabel("Date") + ax.grid(True, alpha=0.3) + st.pyplot(fig) + else: + st.warning("No daily return data available.") + + # Exposure chart + if "Long Exposure" in performance_df.columns and "Short Exposure" in performance_df.columns: + st.markdown("### Long/Short Exposure") + fig, ax = plt.subplots(figsize=(12, 6)) + ax.plot(performance_df.index, performance_df["Long Exposure"], color=theme_colors['success'], label="Long Exposure") + ax.plot(performance_df.index, performance_df["Short Exposure"], color=theme_colors['error'], label="Short Exposure") + ax.plot(performance_df.index, performance_df["Net Exposure"], color=theme_colors['primary'], label="Net Exposure") + ax.set_title("Portfolio Exposure Over Time") + ax.set_ylabel("Exposure ($)") + ax.set_xlabel("Date") + ax.legend() + ax.grid(True, alpha=0.3) + st.pyplot(fig) + + # Detailed performance table + st.markdown("### Detailed Performance") + st.dataframe(performance_df) + + # Portfolio Manager Decisions + st.markdown("### Portfolio Manager Decisions") + + # Get the trading decisions from the backtester + if hasattr(backtester, 
'trading_decisions') and backtester.trading_decisions: + # Group decisions by date + decisions_by_date = {} + for date, ticker_decisions in backtester.trading_decisions.items(): + if date not in decisions_by_date: + decisions_by_date[date] = [] + + if ticker_decisions: # Check if ticker_decisions is not None + for ticker, decision in ticker_decisions.items(): + if decision and (decision.get('action') != 'hold' or decision.get('quantity', 0) > 0): + decisions_by_date[date].append({ + 'ticker': ticker, + 'action': decision.get('action', 'hold'), + 'quantity': decision.get('quantity', 0), + 'confidence': decision.get('confidence', 0), + 'reasoning': decision.get('reasoning', 'No reasoning provided') + }) + + # Display decisions by date + if decisions_by_date: # Check if we have any decisions to display + # Get all trading dates + trading_dates = sorted(decisions_by_date.keys()) + + # Initialize session state for selected trading day if not already set + if 'selected_trading_day' not in st.session_state: + st.session_state.selected_trading_day = trading_dates[0] if trading_dates else "" + + # Create callback function to update session state + def update_trading_day(): + st.session_state.selected_trading_day = st.session_state.trading_day_selector + + # Create a date selector with key and on_change callback + trading_day_selector = st.selectbox( + "Select Trading Day", + trading_dates, + index=trading_dates.index(st.session_state.selected_trading_day) if st.session_state.selected_trading_day in trading_dates else 0, + key="trading_day_selector", + on_change=update_trading_day + ) + + # Use the value from session state + selected_trading_day = st.session_state.selected_trading_day + + # Create a container for the decisions to avoid refreshing the entire page + decisions_container = st.container() + + with decisions_container: + st.markdown(f"#### Trading Day: {selected_trading_day}") + + if not decisions_by_date[selected_trading_day]: + st.info("No trading actions taken 
on this day.") + else: + for decision in decisions_by_date[selected_trading_day]: + action_class = f"decision-action-{decision['action']}" + + st.markdown(f""" +
+
{decision['ticker']}
+
{decision['action'].upper()} {decision['quantity']} shares (Confidence: {decision['confidence']:.1f}%)
+
Reasoning: {decision['reasoning']}
+
+ """, unsafe_allow_html=True) + else: + st.info("No trading decisions were made during the backtest period.") + else: + st.info("No detailed trading decisions available. This may be because the backtester didn't store the decision reasoning.") + + # Suggest a modification to the backtester + with st.expander("How to enable detailed decision tracking"): + st.markdown(""" + To track detailed portfolio manager decisions, you need to modify the backtester to store the decision reasoning: + + 1. In `src/backtester.py`, add a dictionary to store decisions in the `__init__` method: + ```python + self.trading_decisions = {} + ``` + + 2. In the `run_backtest` method, store the decisions with their reasoning: + ```python + # Store decisions with reasoning for this date + self.trading_decisions[current_date_str] = { + ticker: { + 'action': decision.get('action', 'hold'), + 'quantity': decision.get('quantity', 0), + 'confidence': decision.get('confidence', 0), + 'reasoning': decision.get('reasoning', '') + } + for ticker, decision in decisions.items() + } + ``` + """) + + # Signal Explanation Section + st.markdown("### Signal Analysis and Weighting") + + # Add a brief introduction outside the expander + st.markdown(""" + The AI Hedge Fund uses a multi-analyst approach where different specialists analyze stocks from various perspectives. + Each analyst produces a signal (bullish, bearish, or neutral) with a confidence level. + The Portfolio Manager then considers all these signals to make the final trading decision. 
+ """) + + # Create tabs for different aspects of signal analysis + signal_tabs = st.tabs(["Analysts Overview", "Technical Analysis", "Sentiment Analysis", "Decision Process"]) + + with signal_tabs[0]: + st.markdown(""" + ## Available Analysts + + The AI Hedge Fund incorporates perspectives from multiple analysts, each with their own specialty and approach: + """) + + # Create columns for different analyst categories + col1, col2 = st.columns(2) + + with col1: + st.markdown(""" + ### Value Investors + +
+ Warren Buffett
+ Focuses on companies with strong competitive advantages, good management, and reasonable valuations +
+ +
+ Charlie Munger
+ Emphasizes mental models and avoiding psychological biases in investment decisions +
+ +
+ Ben Graham
+ The father of value investing, focuses on margin of safety and quantitative analysis +
+ + ### Active Investors + +
+ Bill Ackman
+ Activist investor approach, looking for companies with potential for significant operational improvements +
+ +
+ Cathie Wood
+ Growth-focused, emphasizing disruptive innovation and technological trends +
+ +
+ Stanley Druckenmiller
+ Macro-focused approach with emphasis on capital preservation and concentrated positions +
+ """, unsafe_allow_html=True) + + with col2: + st.markdown(""" + ### Specialized Analysts + +
+ Technical Analyst
+ Uses price patterns and technical indicators to generate trading signals +
+ +
+ Fundamentals Analyst
+ Analyzes financial statements and business fundamentals +
+ +
+ Sentiment Analyst
+ Evaluates market sentiment from news and insider trading activity +
+ +
+ Valuation Analyst
+ Focuses specifically on valuation metrics and fair value estimates +
+ """, unsafe_allow_html=True) + + with signal_tabs[1]: + st.markdown(""" + ## Technical Analysis Methodology + + The Technical Analyst combines multiple strategies to generate trading signals: + """) + + # Create a DataFrame for the technical strategies + tech_strategies = pd.DataFrame({ + "Strategy": ["Trend Following", "Momentum", "Mean Reversion", "Volatility", "Statistical Arbitrage"], + "Weight": [25, 25, 20, 15, 15], + "Description": [ + "Identifies directional price movements using moving averages and ADX", + "Measures the rate of price changes using RSI and other momentum indicators", + "Identifies overbought/oversold conditions using Bollinger Bands", + "Analyzes price volatility patterns using ATR and other volatility measures", + "Uses statistical methods to identify pricing inefficiencies" + ] + }) + + # Display the strategies as a table + st.table(tech_strategies) + + # Create a pie chart for the weights + fig, ax = plt.subplots(figsize=(8, 8)) + ax.pie(tech_strategies["Weight"], labels=tech_strategies["Strategy"], autopct='%1.1f%%', + startangle=90, shadow=True, explode=[0.05, 0.05, 0, 0, 0], + colors=[theme_colors['error'], theme_colors['primary'], theme_colors['success'], + theme_colors['warning'], theme_colors['neutral']]) + ax.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle + ax.set_title("Technical Analysis Strategy Weights") + st.pyplot(fig) + + st.markdown(""" + ### How Technical Signals Are Combined + + The Technical Analyst calculates individual signals for each strategy, then combines them using a weighted average approach. + The final signal (bullish, bearish, or neutral) is determined by the weighted score: + + - Score > 0.2: Bullish + - Score < -0.2: Bearish + - Otherwise: Neutral + + The confidence level is derived from the absolute value of the final score. 
+ """) + + with signal_tabs[2]: + st.markdown(""" + ## Sentiment Analysis Methodology + + The Sentiment Analyst evaluates market sentiment from multiple sources: + """) + + # Create a DataFrame for the sentiment sources + sentiment_sources = pd.DataFrame({ + "Source": ["News Sentiment", "Insider Trading"], + "Weight": [70, 30], + "Description": [ + "Analyzes sentiment from recent news articles about the company", + "Evaluates recent insider buying and selling patterns" + ] + }) + + # Display the sources as a table + st.table(sentiment_sources) + + # Create a pie chart for the weights + fig, ax = plt.subplots(figsize=(8, 8)) + ax.pie(sentiment_sources["Weight"], labels=sentiment_sources["Source"], autopct='%1.1f%%', + startangle=90, shadow=True, explode=[0.05, 0], + colors=[theme_colors['primary'], theme_colors['error']]) + ax.axis('equal') # Equal aspect ratio ensures that pie is drawn as a circle + ax.set_title("Sentiment Analysis Source Weights") + st.pyplot(fig) + + st.markdown(""" + ### How Sentiment Signals Are Generated + + The Sentiment Analyst: + + 1. Collects recent news articles and insider trading data + 2. Classifies each news article as positive, negative, or neutral + 3. Classifies insider transactions as bullish (buying) or bearish (selling) + 4. Applies weights to each source + 5. Determines the final signal based on which weighted count is higher + 6. Calculates confidence based on the proportion of the dominant signal + """) + + with signal_tabs[3]: + st.markdown(""" + ## Portfolio Manager Decision Process + + The Portfolio Manager is the final decision-maker that considers all analyst signals along with portfolio constraints: + """) + + st.markdown(""" +
+

Inputs to the Decision Process

+ +
+ +
+

Trading Rules

+ +
+ +
+

Available Actions

+ +
+ +
+

Output Decision

+

For each ticker, the Portfolio Manager outputs:

+ +
+ """, unsafe_allow_html=True) + + # Signal Visualization + if hasattr(backtester, 'trading_decisions') and backtester.trading_decisions: + st.markdown("### Analyst Signal Visualization") + + # Check if analyst signals are available + if hasattr(backtester, 'analyst_signals') and backtester.analyst_signals: + # Get a list of all dates + all_dates = sorted(backtester.analyst_signals.keys()) + + if all_dates: + # Create tabs for different visualization modes + viz_tabs = st.tabs(["Single Date View", "Date Comparison", "Signal Trend"]) + + # Initialize session state for selected date and ticker if not already set + if 'selected_date' not in st.session_state: + st.session_state.selected_date = all_dates[0] + if 'selected_ticker' not in st.session_state: + st.session_state.selected_ticker = tickers[0] if tickers else "" + if 'comparison_date' not in st.session_state and len(all_dates) > 1: + st.session_state.comparison_date = all_dates[1] if len(all_dates) > 1 else all_dates[0] + if 'selected_analyst' not in st.session_state: + st.session_state.selected_analyst = "All Analysts" + + # Create callback functions to update session state + def update_selected_date(): + st.session_state.selected_date = st.session_state.date_selector + + def update_selected_ticker(): + st.session_state.selected_ticker = st.session_state.ticker_selector + + def update_comparison_date(): + st.session_state.comparison_date = st.session_state.comparison_date_selector + + def update_selected_analyst(): + st.session_state.selected_analyst = st.session_state.analyst_selector + + # Single Date View Tab + with viz_tabs[0]: + # Create a date selector with key and on_change callback + date_selector = st.selectbox( + "Select Trading Day", + all_dates, + index=all_dates.index(st.session_state.selected_date) if st.session_state.selected_date in all_dates else 0, + key="date_selector", + on_change=update_selected_date + ) + + # Create a ticker selector with key and on_change callback + ticker_selector = 
st.selectbox( + "Select Ticker", + tickers, + index=tickers.index(st.session_state.selected_ticker) if st.session_state.selected_ticker in tickers else 0, + key="ticker_selector", + on_change=update_selected_ticker + ) + + # Use the values from session state for visualization + selected_date = st.session_state.selected_date + selected_ticker = st.session_state.selected_ticker + + # Get the analyst signals for the selected date + date_signals = backtester.analyst_signals.get(selected_date, {}) + + # Create a DataFrame for the signals + signal_data = [] + + # Process signals for each analyst + for analyst_name, signals in date_signals.items(): + if analyst_name != "risk_management_agent" and isinstance(signals, dict): + # Get the signal for the selected ticker + ticker_signal = signals.get(selected_ticker, {}) + if isinstance(ticker_signal, dict): + signal = ticker_signal.get("signal", "N/A") + confidence = ticker_signal.get("confidence", 0) + signal_data.append({ + "Analyst": analyst_name.replace("_agent", "").replace("_", " ").title(), + "Signal": signal.title() if signal else "N/A", + "Confidence": confidence if confidence else 0, + "Color": theme_colors['success'] if signal == "bullish" else (theme_colors['error'] if signal == "bearish" else theme_colors['neutral']) + }) + + if signal_data: + signal_df = pd.DataFrame(signal_data) + + # Create a horizontal bar chart + fig, ax = plt.subplots(figsize=(10, len(signal_data) * 0.5 + 2)) + bars = ax.barh(signal_df["Analyst"], signal_df["Confidence"], color=signal_df["Color"]) + ax.set_xlabel("Confidence (%)") + ax.set_title(f"Analyst Signals for {selected_ticker} on {selected_date}") + ax.set_xlim(0, 100) + + # Add the signal labels to the bars + for i, bar in enumerate(bars): + ax.text( + bar.get_width() + 2, + bar.get_y() + bar.get_height()/2, + signal_df["Signal"].iloc[i], + va='center' + ) + + st.pyplot(fig) + + # Display the signal data in a table + st.dataframe(signal_df[["Analyst", "Signal", "Confidence"]]) + 
else: + st.info(f"No analyst signals available for {selected_ticker} on {selected_date}") + + # Date Comparison Tab + with viz_tabs[1]: + # Create two columns for date selection + date_col1, date_col2 = st.columns(2) + + with date_col1: + # Create a date selector with key and on_change callback + comparison_date1_selector = st.selectbox( + "Date 1", + all_dates, + index=all_dates.index(st.session_state.selected_date) if st.session_state.selected_date in all_dates else 0, + key="comparison_date1_selector", + on_change=update_selected_date + ) + + with date_col2: + # Create a comparison date selector + comparison_date2_selector = st.selectbox( + "Date 2", + all_dates, + index=all_dates.index(st.session_state.comparison_date) if st.session_state.comparison_date in all_dates else (1 if len(all_dates) > 1 else 0), + key="comparison_date_selector", + on_change=update_comparison_date + ) + + # Create a ticker selector for comparison + comparison_ticker_selector = st.selectbox( + "Select Ticker", + tickers, + index=tickers.index(st.session_state.selected_ticker) if st.session_state.selected_ticker in tickers else 0, + key="comparison_ticker_selector", + on_change=update_selected_ticker + ) + + # Use the values from session state for visualization + selected_date = st.session_state.selected_date + comparison_date = st.session_state.comparison_date + selected_ticker = st.session_state.selected_ticker + + # Create two columns for side-by-side comparison + viz_col1, viz_col2 = st.columns(2) + + # Function to create signal visualization for a date + def create_signal_viz(date, container): + with container: + st.subheader(f"Signals for {selected_ticker} on {date}") + + # Get the analyst signals for the selected date + date_signals = backtester.analyst_signals.get(date, {}) + + # Create a DataFrame for the signals + signal_data = [] + + # Process signals for each analyst + for analyst_name, signals in date_signals.items(): + if analyst_name != "risk_management_agent" and 
isinstance(signals, dict): + # Get the signal for the selected ticker + ticker_signal = signals.get(selected_ticker, {}) + if isinstance(ticker_signal, dict): + signal = ticker_signal.get("signal", "N/A") + confidence = ticker_signal.get("confidence", 0) + signal_data.append({ + "Analyst": analyst_name.replace("_agent", "").replace("_", " ").title(), + "Signal": signal.title() if signal else "N/A", + "Confidence": confidence if confidence else 0, + "Color": theme_colors['success'] if signal == "bullish" else (theme_colors['error'] if signal == "bearish" else theme_colors['neutral']) + }) + + if signal_data: + signal_df = pd.DataFrame(signal_data) + + # Create a horizontal bar chart + fig, ax = plt.subplots(figsize=(6, len(signal_data) * 0.5 + 2)) + bars = ax.barh(signal_df["Analyst"], signal_df["Confidence"], color=signal_df["Color"]) + ax.set_xlabel("Confidence (%)") + ax.set_xlim(0, 100) + + # Add the signal labels to the bars + for i, bar in enumerate(bars): + ax.text( + bar.get_width() + 2, + bar.get_y() + bar.get_height()/2, + signal_df["Signal"].iloc[i], + va='center' + ) + + st.pyplot(fig) + + # Display the signal data in a table + st.dataframe(signal_df[["Analyst", "Signal", "Confidence"]]) + else: + st.info(f"No analyst signals available for {selected_ticker} on {date}") + + # Create visualizations for both dates + create_signal_viz(selected_date, viz_col1) + create_signal_viz(comparison_date, viz_col2) + + # Add a section to show changes between dates + st.subheader("Signal Changes Analysis") + + # Get signals for both dates + date1_signals = backtester.analyst_signals.get(selected_date, {}) + date2_signals = backtester.analyst_signals.get(comparison_date, {}) + + # Find all analysts that appear in either date + all_analysts = set() + for signals in [date1_signals, date2_signals]: + for analyst_name in signals.keys(): + if analyst_name != "risk_management_agent": + all_analysts.add(analyst_name) + + # Create a DataFrame to track changes + changes_data = 
[] + + for analyst_name in all_analysts: + # Get signals for the selected ticker from both dates + signal1 = None + confidence1 = None + if analyst_name in date1_signals and selected_ticker in date1_signals[analyst_name]: + ticker_signal = date1_signals[analyst_name].get(selected_ticker, {}) + if isinstance(ticker_signal, dict): + signal1 = ticker_signal.get("signal", "N/A") + confidence1 = ticker_signal.get("confidence", 0) + + signal2 = None + confidence2 = None + if analyst_name in date2_signals and selected_ticker in date2_signals[analyst_name]: + ticker_signal = date2_signals[analyst_name].get(selected_ticker, {}) + if isinstance(ticker_signal, dict): + signal2 = ticker_signal.get("signal", "N/A") + confidence2 = ticker_signal.get("confidence", 0) + + # Only add to changes if we have data for both dates + if signal1 and signal2: + signal_changed = signal1 != signal2 + confidence_change = (confidence2 - confidence1) if confidence1 is not None and confidence2 is not None else None + + changes_data.append({ + "Analyst": analyst_name.replace("_agent", "").replace("_", " ").title(), + "Signal Date 1": signal1.title() if signal1 else "N/A", + "Confidence Date 1": confidence1 if confidence1 is not None else 0, + "Signal Date 2": signal2.title() if signal2 else "N/A", + "Confidence Date 2": confidence2 if confidence2 is not None else 0, + "Signal Changed": signal_changed, + "Confidence Change": confidence_change if confidence_change is not None else 0 + }) + + if changes_data: + changes_df = pd.DataFrame(changes_data) + + # Style the DataFrame to highlight changes + def highlight_changes(val): + if isinstance(val, bool) and val: + return 'background-color: yellow' + elif isinstance(val, (int, float)) and val != 0: + return 'color: green' if val > 0 else 'color: red' + return '' + + styled_changes = changes_df.style.applymap(highlight_changes, subset=['Signal Changed', 'Confidence Change']) + st.dataframe(styled_changes) + else: + st.info("No comparable signals 
available for both dates.") + + # Signal Trend Tab + with viz_tabs[2]: + # Create a ticker selector for trend analysis + trend_ticker_selector = st.selectbox( + "Select Ticker", + tickers, + index=tickers.index(st.session_state.selected_ticker) if st.session_state.selected_ticker in tickers else 0, + key="trend_ticker_selector", + on_change=update_selected_ticker + ) + + # Get all analysts from the signals + all_analysts = set() + for date, date_signals in backtester.analyst_signals.items(): + for analyst_name in date_signals.keys(): + if analyst_name != "risk_management_agent": + all_analysts.add(analyst_name.replace("_agent", "").replace("_", " ").title()) + + # Create a list of analysts with "All Analysts" as the first option + analyst_options = ["All Analysts"] + sorted(list(all_analysts)) + + # Create an analyst selector + analyst_selector = st.selectbox( + "Select Analyst", + analyst_options, + index=analyst_options.index(st.session_state.selected_analyst) if st.session_state.selected_analyst in analyst_options else 0, + key="analyst_selector", + on_change=update_selected_analyst + ) + + # Use the values from session state for visualization + selected_ticker = st.session_state.selected_ticker + selected_analyst = st.session_state.selected_analyst + + # Create a DataFrame to track signals over time + trend_data = [] + + for date in all_dates: + date_signals = backtester.analyst_signals.get(date, {}) + + if selected_analyst == "All Analysts": + # Aggregate signals from all analysts + bullish_count = 0 + bearish_count = 0 + neutral_count = 0 + total_confidence = 0 + analyst_count = 0 + + for analyst_name, signals in date_signals.items(): + if analyst_name != "risk_management_agent" and isinstance(signals, dict): + ticker_signal = signals.get(selected_ticker, {}) + if isinstance(ticker_signal, dict): + signal = ticker_signal.get("signal", "") + confidence = ticker_signal.get("confidence", 0) + + if signal == "bullish": + bullish_count += 1 + elif signal == 
"bearish": + bearish_count += 1 + elif signal == "neutral": + neutral_count += 1 + + total_confidence += confidence + analyst_count += 1 + + if analyst_count > 0: + # Calculate consensus signal + if bullish_count > bearish_count and bullish_count > neutral_count: + consensus_signal = "Bullish" + elif bearish_count > bullish_count and bearish_count > neutral_count: + consensus_signal = "Bearish" + else: + consensus_signal = "Neutral" + + # Calculate average confidence + avg_confidence = total_confidence / analyst_count + + trend_data.append({ + "Date": date, + "Signal": consensus_signal, + "Confidence": avg_confidence, + "Bullish Count": bullish_count, + "Bearish Count": bearish_count, + "Neutral Count": neutral_count, + "Analyst Count": analyst_count + }) + else: + # Get signals for the selected analyst + for analyst_name, signals in date_signals.items(): + analyst_display = analyst_name.replace("_agent", "").replace("_", " ").title() + + if analyst_display == selected_analyst and isinstance(signals, dict): + ticker_signal = signals.get(selected_ticker, {}) + if isinstance(ticker_signal, dict): + signal = ticker_signal.get("signal", "") + confidence = ticker_signal.get("confidence", 0) + + if signal: + trend_data.append({ + "Date": date, + "Signal": signal.title(), + "Confidence": confidence + }) + + if trend_data: + trend_df = pd.DataFrame(trend_data) + + # Create a figure with two subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [2, 1]}) + + # Convert signals to numeric values for plotting + signal_values = {"Bullish": 1, "Neutral": 0, "Bearish": -1} + trend_df["Signal Value"] = trend_df["Signal"].map(signal_values) + + # Plot signal trend + ax1.plot(trend_df["Date"], trend_df["Signal Value"], marker='o', linestyle='-', color=theme_colors['primary']) + ax1.set_ylabel("Signal") + ax1.set_title(f"Signal Trend for {selected_ticker}" + (f" - {selected_analyst}" if selected_analyst != "All Analysts" else " - Consensus")) 
+ ax1.set_ylim([-1.5, 1.5]) + ax1.set_yticks([-1, 0, 1]) + ax1.set_yticklabels(["Bearish", "Neutral", "Bullish"]) + ax1.grid(True) + + # Plot confidence trend + ax2.bar(trend_df["Date"], trend_df["Confidence"], color=trend_df["Signal"].map({ + "Bullish": theme_colors['success'], + "Neutral": theme_colors['neutral'], + "Bearish": theme_colors['error'] + })) + ax2.set_ylabel("Confidence (%)") + ax2.set_xlabel("Date") + ax2.set_ylim([0, 100]) + ax2.grid(True) + + plt.tight_layout() + st.pyplot(fig) + + # Display additional information for "All Analysts" + if selected_analyst == "All Analysts" and "Bullish Count" in trend_df.columns: + # Create a stacked bar chart of analyst counts + fig, ax = plt.subplots(figsize=(12, 6)) + + # Create the stacked bars + ax.bar(trend_df["Date"], trend_df["Bullish Count"], label="Bullish", color=theme_colors['success']) + ax.bar(trend_df["Date"], trend_df["Neutral Count"], bottom=trend_df["Bullish Count"], label="Neutral", color=theme_colors['neutral']) + ax.bar(trend_df["Date"], trend_df["Bearish Count"], bottom=trend_df["Bullish Count"] + trend_df["Neutral Count"], label="Bearish", color=theme_colors['error']) + + ax.set_ylabel("Analyst Count") + ax.set_xlabel("Date") + ax.set_title(f"Analyst Signal Distribution for {selected_ticker}") + ax.legend() + ax.grid(True) + + st.pyplot(fig) + + # Display the trend data in a table + st.dataframe(trend_df) + else: + st.info(f"No trend data available for {selected_ticker}" + (f" with analyst {selected_analyst}" if selected_analyst != "All Analysts" else "")) + else: + st.info("No analyst signals data available for visualization.") + else: + st.warning("No performance data available. The backtest may not have completed successfully.") +else: + # Display instructions when not running + st.info(""" + ### How to use this backtester: + + 1. Enter stock tickers separated by commas in the sidebar + 2. Select the date range for your backtest + 3. Set your initial capital amount + 4. 
def __init__(
    self,
    agent: Callable,
    tickers: list[str],
    start_date: str,
    end_date: str,
    initial_capital: float,
    model_name: str = "gpt-4o",
    model_provider: str = "OpenAI",
    selected_analysts: "list[str] | None" = None,
    initial_margin_requirement: float = 0.0,
):
    """
    Store the backtest configuration and build an empty long/short portfolio.

    :param agent: The trading agent (Callable).
    :param tickers: List of tickers to backtest.
    :param start_date: Start date string (YYYY-MM-DD).
    :param end_date: End date string (YYYY-MM-DD).
    :param initial_capital: Starting portfolio cash.
    :param model_name: Which LLM model name to use (gpt-4, etc).
    :param model_provider: Which LLM provider (OpenAI, etc).
    :param selected_analysts: List of analyst names or IDs to incorporate.
        ``None`` (the default) means an empty list.
    :param initial_margin_requirement: The margin ratio (e.g. 0.5 = 50%).
    """
    self.agent = agent
    self.tickers = tickers
    self.start_date = start_date
    self.end_date = end_date
    self.initial_capital = initial_capital
    self.model_name = model_name
    self.model_provider = model_provider

    # Use a fresh list per instance: a literal `[]` default would be one
    # shared object, so a mutation made through one Backtester would leak
    # into every other instance created without this argument.
    self.selected_analysts = [] if selected_analysts is None else selected_analysts

    # Store the margin ratio (e.g. 0.5 means 50% margin required).
    self.margin_ratio = initial_margin_requirement

    # Initialize portfolio with support for long/short positions
    self.portfolio_values = []
    self.portfolio = {
        "cash": initial_capital,
        "margin_used": 0.0,  # total margin usage across all short positions
        "margin_requirement": initial_margin_requirement,  # Store margin requirement in portfolio
        "positions": {
            ticker: {
                "long": 0,  # Number of shares held long
                "short": 0,  # Number of shares held short
                "long_cost_basis": 0.0,  # Average cost basis per share (long)
                "short_cost_basis": 0.0,  # Average cost basis per share (short)
                "short_margin_used": 0.0,  # Dollars of margin used for this ticker's short
            }
            for ticker in tickers
        },
        "realized_gains": {
            ticker: {
                "long": 0.0,  # Realized gains from long positions
                "short": 0.0,  # Realized gains from short positions
            }
            for ticker in tickers
        },
    }

    # Dictionary to store trading decisions with reasoning
    self.trading_decisions = {}

    # Dictionary to store analyst signals for each date
    self.analyst_signals = {}
+ position["long"] -= quantity + self.portfolio["cash"] += quantity * current_price + + if position["long"] == 0: + position["long_cost_basis"] = 0.0 + + return quantity + + elif action == "short": + """ + Typical short sale flow: + 1) Receive proceeds = current_price * quantity + 2) Post margin_required = proceeds * margin_ratio + 3) Net effect on cash = +proceeds - margin_required + """ + proceeds = current_price * quantity + margin_required = proceeds * self.margin_ratio + if margin_required <= self.portfolio["cash"]: + # Weighted average short cost basis + old_short_shares = position["short"] + old_cost_basis = position["short_cost_basis"] + new_shares = quantity + total_shares = old_short_shares + new_shares + + if total_shares > 0: + total_old_cost = old_cost_basis * old_short_shares + total_new_cost = current_price * new_shares + position["short_cost_basis"] = (total_old_cost + total_new_cost) / total_shares + + position["short"] += quantity + + # Update margin usage + position["short_margin_used"] += margin_required + self.portfolio["margin_used"] += margin_required + + # Increase cash by proceeds, then subtract the required margin + self.portfolio["cash"] += proceeds + self.portfolio["cash"] -= margin_required + return quantity + else: + # Calculate maximum shortable quantity + if self.margin_ratio > 0: + max_quantity = int(self.portfolio["cash"] / (current_price * self.margin_ratio)) + else: + max_quantity = 0 + + if max_quantity > 0: + proceeds = current_price * max_quantity + margin_required = proceeds * self.margin_ratio + + old_short_shares = position["short"] + old_cost_basis = position["short_cost_basis"] + total_shares = old_short_shares + max_quantity + + if total_shares > 0: + total_old_cost = old_cost_basis * old_short_shares + total_new_cost = current_price * max_quantity + position["short_cost_basis"] = (total_old_cost + total_new_cost) / total_shares + + position["short"] += max_quantity + position["short_margin_used"] += margin_required + 
def calculate_portfolio_value(self, current_prices):
    """
    Return the total portfolio value at the supplied prices.

    The total is the cash balance, plus the market value of every long
    position, plus the unrealized gain/loss of every open short position
    measured against its average short cost basis.

    :param current_prices: Mapping of ticker -> price; it is indexed with
        every ticker in ``self.tickers``.
    :return: The total portfolio value.
    """
    holdings = self.portfolio["positions"]
    value = self.portfolio["cash"]

    for symbol in self.tickers:
        holding = holdings[symbol]
        mark = current_prices[symbol]

        # Longs are marked to market.
        value += holding["long"] * mark

        # Shorts contribute only their unrealized PnL:
        # shares * (average short price - current price).
        shares_short = holding["short"]
        if shares_short > 0:
            value += shares_short * (holding["short_cost_basis"] - mark)

    return value
def run_backtest(self):
    """
    Simulate the strategy over every business day in the configured range.

    For each day this method fetches one closing price per ticker, asks the
    agent for decisions, executes them through ``execute_trade``, appends the
    resulting portfolio value and exposures to ``self.portfolio_values``, and
    prints the accumulated results table. A day whose prices cannot be
    fetched is skipped. Decisions and analyst signals returned by the agent
    are also recorded per date in ``self.trading_decisions`` and
    ``self.analyst_signals``.

    :return: Dict with the keys ``sharpe_ratio``, ``sortino_ratio``,
        ``max_drawdown``, ``long_short_ratio``, ``gross_exposure`` and
        ``net_exposure``. All start as ``None``; the dict is handed to
        ``_update_performance_metrics`` once more than three portfolio
        values have been recorded.
    """
    # Pre-fetch all data at the start
    self.prefetch_data()

    dates = pd.date_range(self.start_date, self.end_date, freq="B")
    table_rows = []
    performance_metrics = {
        'sharpe_ratio': None,
        'sortino_ratio': None,
        'max_drawdown': None,
        'long_short_ratio': None,
        'gross_exposure': None,
        'net_exposure': None
    }

    print("\nStarting backtest...")

    # Initialize portfolio values list with initial capital
    if len(dates) > 0:
        self.portfolio_values = [{"Date": dates[0], "Portfolio Value": self.initial_capital}]
    else:
        self.portfolio_values = []

    for current_date in dates:
        lookback_start = (current_date - timedelta(days=30)).strftime("%Y-%m-%d")
        current_date_str = current_date.strftime("%Y-%m-%d")
        previous_date_str = (current_date - timedelta(days=1)).strftime("%Y-%m-%d")

        # Get current prices for all tickers
        try:
            current_prices = {
                ticker: get_price_data(ticker, previous_date_str, current_date_str).iloc[-1]["close"]
                for ticker in self.tickers
            }
        except Exception:
            # Deliberate best-effort: if data is missing or there's an API
            # error, skip this day instead of aborting the whole backtest.
            print(f"Error fetching prices between {previous_date_str} and {current_date_str}")
            continue

        # ---------------------------------------------------------------
        # 1) Execute the agent's trades
        # ---------------------------------------------------------------
        output = self.agent(
            tickers=self.tickers,
            start_date=lookback_start,
            end_date=current_date_str,
            portfolio=self.portfolio,
            model_name=self.model_name,
            model_provider=self.model_provider,
            selected_analysts=self.selected_analysts,
        )
        # The agent may hand back None for either entry; normalize to empty
        # dicts so every later lookup is safe.
        decisions = output["decisions"] or {}
        analyst_signals = output["analyst_signals"] or {}

        # Store decisions with reasoning for this date
        if decisions:
            self.trading_decisions[current_date_str] = {
                ticker: {
                    'action': decision.get('action', 'hold'),
                    'quantity': decision.get('quantity', 0),
                    'confidence': decision.get('confidence', 0),
                    'reasoning': decision.get('reasoning', 'No reasoning provided')
                }
                for ticker, decision in decisions.items() if decision  # Skip None values
            }

        # Store analyst signals for this date
        if analyst_signals:
            self.analyst_signals[current_date_str] = analyst_signals

        # Execute trades for each ticker
        executed_trades = {}
        for ticker in self.tickers:
            # A missing or None decision is treated as "hold".
            decision = decisions.get(ticker) or {"action": "hold", "quantity": 0}
            action = decision.get("action", "hold")
            quantity = decision.get("quantity") or 0

            executed_trades[ticker] = self.execute_trade(
                ticker, action, quantity, current_prices[ticker]
            )

        # ---------------------------------------------------------------
        # 2) Now that trades have executed, recalculate the final
        #    portfolio value for this day.
        # ---------------------------------------------------------------
        total_value = self.calculate_portfolio_value(current_prices)

        # Also compute long/short exposures for final post-trade state
        long_exposure = sum(
            self.portfolio["positions"][t]["long"] * current_prices[t]
            for t in self.tickers
        )
        short_exposure = sum(
            self.portfolio["positions"][t]["short"] * current_prices[t]
            for t in self.tickers
        )

        # Calculate gross and net exposures
        gross_exposure = long_exposure + short_exposure
        net_exposure = long_exposure - short_exposure
        long_short_ratio = (
            long_exposure / short_exposure if short_exposure > 1e-9 else float('inf')
        )

        # Track each day's portfolio value in self.portfolio_values
        self.portfolio_values.append({
            "Date": current_date,
            "Portfolio Value": total_value,
            "Long Exposure": long_exposure,
            "Short Exposure": short_exposure,
            "Gross Exposure": gross_exposure,
            "Net Exposure": net_exposure,
            "Long/Short Ratio": long_short_ratio
        })

        # ---------------------------------------------------------------
        # 3) Build the table rows to display
        # ---------------------------------------------------------------
        date_rows = []

        # For each ticker, record signals/trades
        for ticker in self.tickers:
            ticker_signals = {
                agent_name: signals[ticker]
                for agent_name, signals in analyst_signals.items()
                if isinstance(signals, dict) and ticker in signals
            }

            # Normalize once; tolerate entries that are not dicts or whose
            # "signal" is None.
            labels = [
                (entry.get("signal") or "").lower()
                for entry in ticker_signals.values()
                if isinstance(entry, dict)
            ]
            bullish_count = labels.count("bullish")
            bearish_count = labels.count("bearish")
            neutral_count = labels.count("neutral")

            # Calculate net position value
            pos = self.portfolio["positions"][ticker]
            long_val = pos["long"] * current_prices[ticker]
            short_val = pos["short"] * current_prices[ticker]
            net_position_value = long_val - short_val

            # Requested action comes from the decision; the quantity shown
            # is what was actually executed.
            action = (decisions.get(ticker) or {}).get("action", "hold")
            quantity = executed_trades.get(ticker, 0)

            # Append the agent action to the table rows
            date_rows.append(
                format_backtest_row(
                    date=current_date_str,
                    ticker=ticker,
                    action=action,
                    quantity=quantity,
                    price=current_prices[ticker],
                    shares_owned=pos["long"] - pos["short"],  # net shares
                    position_value=net_position_value,
                    bullish_count=bullish_count,
                    bearish_count=bearish_count,
                    neutral_count=neutral_count,
                )
            )

        # ---------------------------------------------------------------
        # 4) Calculate performance summary metrics
        # ---------------------------------------------------------------
        # Cumulative return vs. initial capital. Realized gains are NOT
        # added here: execute_trade already credits sale/cover proceeds to
        # cash, which total_value includes, so adding them again would
        # count every realized gain twice.
        portfolio_return = (total_value / self.initial_capital - 1) * 100

        # Add summary row for this day
        date_rows.append(
            format_backtest_row(
                date=current_date_str,
                ticker="",
                action="",
                quantity=0,
                price=0,
                shares_owned=0,
                position_value=0,
                bullish_count=0,
                bearish_count=0,
                neutral_count=0,
                is_summary=True,
                total_value=total_value,
                return_pct=portfolio_return,
                cash_balance=self.portfolio["cash"],
                total_position_value=total_value - self.portfolio["cash"],
                sharpe_ratio=performance_metrics["sharpe_ratio"],
                sortino_ratio=performance_metrics["sortino_ratio"],
                max_drawdown=performance_metrics["max_drawdown"],
            ),
        )

        table_rows.extend(date_rows)
        print_backtest_results(table_rows)

        # Update performance metrics if we have enough data
        if len(self.portfolio_values) > 3:
            self._update_performance_metrics(performance_metrics)

    return performance_metrics
(mean_excess_return / downside_std) + else: + performance_metrics["sortino_ratio"] = float('inf') if mean_excess_return > 0 else 0 + else: + performance_metrics["sortino_ratio"] = float('inf') if mean_excess_return > 0 else 0 + + # Maximum drawdown + rolling_max = values_df["Portfolio Value"].cummax() + drawdown = (values_df["Portfolio Value"] - rolling_max) / rolling_max + performance_metrics["max_drawdown"] = drawdown.min() * 100 + + def analyze_performance(self): + """Creates a performance DataFrame, prints summary stats, and plots equity curve.""" + if not self.portfolio_values: + print("No portfolio data found. Please run the backtest first.") + return pd.DataFrame() + + performance_df = pd.DataFrame(self.portfolio_values).set_index("Date") + if performance_df.empty: + print("No valid performance data to analyze.") + return performance_df + + final_portfolio_value = performance_df["Portfolio Value"].iloc[-1] + total_realized_gains = sum( + sum(self.portfolio["realized_gains"][ticker][side] for side in ("long", "short")) for ticker in self.tickers # long and short legs, same as the daily summary in run_backtest + ) + total_return = ((final_portfolio_value - self.initial_capital) / self.initial_capital) * 100 + + print(f"\n{Fore.WHITE}{Style.BRIGHT}PORTFOLIO PERFORMANCE SUMMARY:{Style.RESET_ALL}") + print(f"Total Return: {Fore.GREEN if total_return >= 0 else Fore.RED}{total_return:.2f}%{Style.RESET_ALL}") + print(f"Total Realized Gains/Losses: {Fore.GREEN if total_realized_gains >= 0 else Fore.RED}${total_realized_gains:,.2f}{Style.RESET_ALL}") + + # Plot the portfolio value over time + plt.figure(figsize=(12, 6)) + plt.plot(performance_df.index, performance_df["Portfolio Value"], color="blue") + plt.title("Portfolio Value Over Time") + plt.ylabel("Portfolio Value ($)") + plt.xlabel("Date") + plt.grid(True) + plt.show() + + # Compute daily returns + performance_df["Daily Return"] = performance_df["Portfolio Value"].pct_change().fillna(0) + daily_rf = 0.0434 / 252 # daily risk-free rate + mean_daily_return = performance_df["Daily Return"].mean() + 
std_daily_return = performance_df["Daily Return"].std() + + # Annualized Sharpe Ratio + if std_daily_return != 0: + annualized_sharpe = np.sqrt(252) * ((mean_daily_return - daily_rf) / std_daily_return) + else: + annualized_sharpe = 0 + print(f"\nSharpe Ratio: {Fore.YELLOW}{annualized_sharpe:.2f}{Style.RESET_ALL}") + + # Max Drawdown + rolling_max = performance_df["Portfolio Value"].cummax() + drawdown = (performance_df["Portfolio Value"] - rolling_max) / rolling_max + max_drawdown = drawdown.min() + max_drawdown_date = drawdown.idxmin() + if pd.notnull(max_drawdown_date): + print(f"Maximum Drawdown: {Fore.RED}{max_drawdown * 100:.2f}%{Style.RESET_ALL} (on {max_drawdown_date.strftime('%Y-%m-%d')})") + else: + print(f"Maximum Drawdown: {Fore.RED}0.00%{Style.RESET_ALL}") + + # Win Rate + winning_days = len(performance_df[performance_df["Daily Return"] > 0]) + total_days = max(len(performance_df) - 1, 1) + win_rate = (winning_days / total_days) * 100 + print(f"Win Rate: {Fore.GREEN}{win_rate:.2f}%{Style.RESET_ALL}") + + # Average Win/Loss Ratio + positive_returns = performance_df[performance_df["Daily Return"] > 0]["Daily Return"] + negative_returns = performance_df[performance_df["Daily Return"] < 0]["Daily Return"] + avg_win = positive_returns.mean() if not positive_returns.empty else 0 + avg_loss = abs(negative_returns.mean()) if not negative_returns.empty else 0 + if avg_loss != 0: + win_loss_ratio = avg_win / avg_loss + else: + win_loss_ratio = float('inf') if avg_win > 0 else 0 + print(f"Win/Loss Ratio: {Fore.GREEN}{win_loss_ratio:.2f}{Style.RESET_ALL}") + + # Maximum Consecutive Wins / Losses + returns_binary = (performance_df["Daily Return"] > 0).astype(int) + if len(returns_binary) > 0: + max_consecutive_wins = max((len(list(g)) for k, g in itertools.groupby(returns_binary) if k == 1), default=0) + max_consecutive_losses = max((len(list(g)) for k, g in itertools.groupby(returns_binary) if k == 0), default=0) + else: + max_consecutive_wins = 0 + 
max_consecutive_losses = 0 + + print(f"Max Consecutive Wins: {Fore.GREEN}{max_consecutive_wins}{Style.RESET_ALL}") + print(f"Max Consecutive Losses: {Fore.RED}{max_consecutive_losses}{Style.RESET_ALL}") + + return performance_df + + +### 4. Run the Backtest ##### +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run backtesting simulation") + parser.add_argument( + "--tickers", + type=str, + required=False, + help="Comma-separated list of stock ticker symbols (e.g., AAPL,MSFT,GOOGL)", + ) + parser.add_argument( + "--end-date", + type=str, + default=datetime.now().strftime("%Y-%m-%d"), + help="End date in YYYY-MM-DD format", + ) + parser.add_argument( + "--start-date", + type=str, + default=(datetime.now() - relativedelta(months=1)).strftime("%Y-%m-%d"), + help="Start date in YYYY-MM-DD format", + ) + parser.add_argument( + "--initial-capital", + type=float, + default=100000, + help="Initial capital amount (default: 100000)", + ) + parser.add_argument( + "--margin-requirement", + type=float, + default=0.0, + help="Margin ratio for short positions, e.g. 0.5 for 50%% (default: 0.0)", # '%%' because argparse %-formats help strings; a bare '%' breaks --help + ) + + args = parser.parse_args() + + # Parse tickers from comma-separated string + tickers = [ticker.strip() for ticker in args.tickers.split(",")] if args.tickers else [] + + # Choose analysts + selected_analysts = None + choices = questionary.checkbox( + "Use the Space bar to select/unselect analysts.", + choices=[questionary.Choice(display, value=value) for display, value in ANALYST_ORDER], + instruction="\n\nPress 'a' to toggle all.\n\nPress Enter when done to run the hedge fund.", + validate=lambda x: len(x) > 0 or "You must select at least one analyst.", + style=questionary.Style( + [ + ("checkbox-selected", "fg:green"), + ("selected", "fg:green noinherit"), + ("highlighted", "noinherit"), + ("pointer", "noinherit"), + ] + ), + ).ask() + + if not choices: + print("\n\nInterrupt received. 
Exiting...") + sys.exit(0) + else: + selected_analysts = choices + print( + f"\nSelected analysts: " + f"{', '.join(Fore.GREEN + choice.title().replace('_', ' ') + Style.RESET_ALL for choice in choices)}" + ) + + # Select LLM model + model_choice = questionary.select( + "Select your LLM model:", + choices=[questionary.Choice(display, value=value) for display, value, _ in LLM_ORDER], + style=questionary.Style([ + ("selected", "fg:green bold"), + ("pointer", "fg:green bold"), + ("highlighted", "fg:green"), + ("answer", "fg:green bold"), + ]) + ).ask() + + if not model_choice: + print("\n\nInterrupt received. Exiting...") + sys.exit(0) + else: + model_info = get_model_info(model_choice) + if model_info: + model_provider = model_info.provider.value + print(f"\nSelected {Fore.CYAN}{model_provider}{Style.RESET_ALL} model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n") + else: + model_provider = "Unknown" + print(f"\nSelected model: {Fore.GREEN + Style.BRIGHT}{model_choice}{Style.RESET_ALL}\n") + + # Create and run the backtester + backtester = Backtester( + agent=run_hedge_fund, + tickers=tickers, + start_date=args.start_date, + end_date=args.end_date, + initial_capital=args.initial_capital, + model_name=model_choice, + model_provider=model_provider, + selected_analysts=selected_analysts, + initial_margin_requirement=args.margin_requirement, + ) + + performance_metrics = backtester.run_backtest() + performance_df = backtester.analyze_performance() diff --git a/run_app.sh b/run_app.sh new file mode 100755 index 00000000..e893f6cc --- /dev/null +++ b/run_app.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Activate poetry environment if it exists +if command -v poetry &> /dev/null; then + echo "Using Poetry to run Streamlit app..." + # Make sure dependencies are installed + poetry install + # Run the app + poetry run streamlit run app.py +else + # Fall back to regular Python if Poetry is not installed + echo "Poetry not found, using regular Python..." 
+ # Install dependencies + pip install -r requirements.txt + # Run the app + streamlit run app.py +fi \ No newline at end of file diff --git a/src/backtester.py b/src/backtester.py index 90836653..23ac0a00 100644 --- a/src/backtester.py +++ b/src/backtester.py @@ -62,6 +62,10 @@ def __init__( # Store the margin ratio (e.g. 0.5 means 50% margin required). self.margin_ratio = initial_margin_requirement + # Store trading decisions and analyst signals for UI display + self.trading_decisions = {} + self.analyst_signals = {} + # Initialize portfolio with support for long/short positions self.portfolio_values = [] self.portfolio = { @@ -362,6 +366,20 @@ def run_backtest(self): ) decisions = output["decisions"] analyst_signals = output["analyst_signals"] + + # Store decisions with reasoning for this date + self.trading_decisions[current_date_str] = { + ticker: { + 'action': decision.get('action', 'hold'), + 'quantity': decision.get('quantity', 0), + 'confidence': decision.get('confidence', 0), + 'reasoning': decision.get('reasoning', 'No reasoning provided') + } + for ticker, decision in decisions.items() + } + + # Store analyst signals for this date + self.analyst_signals[current_date_str] = analyst_signals # Execute trades for each ticker executed_trades = {} diff --git a/src/llm/models.py b/src/llm/models.py index 2eaccb35..6580e494 100644 --- a/src/llm/models.py +++ b/src/llm/models.py @@ -7,7 +7,7 @@ from enum import Enum from pydantic import BaseModel from typing import Tuple - +import streamlit as st class ModelProvider(str, Enum): """Enum for supported LLM providers""" @@ -60,13 +60,13 @@ def is_gemini(self) -> bool: provider=ModelProvider.ANTHROPIC ), LLMModel( - display_name="[deepseek] deepseek-r1", - model_name="deepseek-reasoner", + display_name="[deepseek] deepseek-v3", + model_name="deepseek-chat", provider=ModelProvider.DEEPSEEK ), LLMModel( - display_name="[deepseek] deepseek-v3", - model_name="deepseek-chat", + display_name="[deepseek] deepseek-r1", + 
model_name="deepseek-reasoner", provider=ModelProvider.DEEPSEEK ), LLMModel( @@ -137,10 +137,15 @@ def get_model(model_name: str, model_provider: ModelProvider) -> ChatOpenAI | Ch return ChatAnthropic(model=model_name, api_key=api_key) elif model_provider == ModelProvider.DEEPSEEK: api_key = os.getenv("DEEPSEEK_API_KEY") - if not api_key: - print(f"API Key Error: Please make sure DEEPSEEK_API_KEY is set in your .env file.") - raise ValueError("DeepSeek API key not found. Please make sure DEEPSEEK_API_KEY is set in your .env file.") - return ChatDeepSeek(model=model_name, api_key=api_key) + if api_key: + return ChatDeepSeek(model=model_name, api_key=api_key) + else: + try: + api_key = st.secrets["DEEPSEEK_API_KEY"] # indexing raises when the secret is absent, so a missing key reaches the handler instead of passing api_key=None + except Exception as e: + print(f"API Key Error: Please make sure DEEPSEEK_API_KEY is set in your .env file.") + raise ValueError("DeepSeek API key not found. Please make sure DEEPSEEK_API_KEY is set in your .env file.") from e + return ChatDeepSeek(model=model_name, api_key=api_key) elif model_provider == ModelProvider.GEMINI: api_key = os.getenv("GOOGLE_API_KEY") if not api_key: diff --git a/src/main.py b/src/main.py index f69f2dba..f1af19ec 100644 --- a/src/main.py +++ b/src/main.py @@ -64,34 +64,43 @@ def run_hedge_fund( progress.start() try: - # Create a new workflow if analysts are customized - if selected_analysts: - workflow = create_workflow(selected_analysts) + # Get all available analyst nodes + analyst_nodes = get_analyst_nodes() + + # Filter out any invalid analyst keys + valid_selected_analysts = [analyst for analyst in (selected_analysts or []) if analyst in analyst_nodes] # 'or []' keeps None/empty selections working, as the removed 'if selected_analysts:' did + + # Create a new workflow with valid analysts + if valid_selected_analysts: + workflow = create_workflow(valid_selected_analysts) agent = workflow.compile() else: - agent = app - - final_state = agent.invoke( - { - "messages": [ - HumanMessage( - content="Make trading decisions based on the provided data.", - ) - ], - "data": { - 
"tickers": tickers, - "portfolio": portfolio, - "start_date": start_date, - "end_date": end_date, - "analyst_signals": {}, - }, - "metadata": { - "show_reasoning": show_reasoning, - "model_name": model_name, - "model_provider": model_provider, - }, + # If no valid analysts, use default workflow + workflow = create_workflow() + agent = workflow.compile() + + # Initialize the agent state with proper structure + initial_state = { + "messages": [ + HumanMessage( + content="Make trading decisions based on the provided data.", + ) + ], + "data": { + "tickers": tickers, + "portfolio": portfolio, + "start_date": start_date, + "end_date": end_date, + "analyst_signals": {}, + }, + "metadata": { + "show_reasoning": show_reasoning, + "model_name": model_name, + "model_provider": model_provider, }, - ) + } + + final_state = agent.invoke(initial_state) return { "decisions": parse_hedge_fund_response(final_state["messages"][-1].content), @@ -104,6 +113,9 @@ def run_hedge_fund( def start(state: AgentState): """Initialize the workflow with the input message.""" + # Ensure the data structure is properly initialized + if "analyst_signals" not in state["data"]: + state["data"]["analyst_signals"] = {} return state @@ -118,8 +130,12 @@ def create_workflow(selected_analysts=None): # Default to all analysts if none selected if selected_analysts is None: selected_analysts = list(analyst_nodes.keys()) + + # Filter out any invalid analyst keys + valid_selected_analysts = [analyst for analyst in selected_analysts if analyst in analyst_nodes] + # Add selected analyst nodes - for analyst_key in selected_analysts: + for analyst_key in valid_selected_analysts: node_name, node_func = analyst_nodes[analyst_key] workflow.add_node(node_name, node_func) workflow.add_edge("start_node", node_name) @@ -129,7 +145,7 @@ def create_workflow(selected_analysts=None): workflow.add_node("portfolio_management_agent", portfolio_management_agent) # Connect selected analysts to risk management - for analyst_key in 
selected_analysts: + for analyst_key in valid_selected_analysts: node_name = analyst_nodes[analyst_key][0] workflow.add_edge(node_name, "risk_management_agent") diff --git a/src/tools/api.py b/src/tools/api.py index d2cc064b..c4a00cc0 100644 --- a/src/tools/api.py +++ b/src/tools/api.py @@ -1,6 +1,7 @@ import os import pandas as pd import requests +import streamlit as st from data.cache import get_cache from data.models import ( @@ -20,6 +21,40 @@ _cache = get_cache() +def get_financial_datasets_api_key(): + """ + Get the Financial Datasets API key from environment variables or Streamlit secrets. + First checks environment variables, then falls back to Streamlit secrets if available. + + Returns: + str: The API key + + Raises: + ValueError: If the API key is not found in either environment variables or Streamlit secrets + """ + # First try to get from environment variables + api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY") + if api_key: + return api_key + + # If not in environment variables, try Streamlit secrets + try: + # This will raise an exception if we're not in a Streamlit context + # (e.g., when running from CLI or in tests) + api_key = st.secrets.get("FINANCIAL_DATASETS_API_KEY") + if api_key: + return api_key + except (FileNotFoundError, KeyError, RuntimeError, AttributeError): + # No secrets file, not in a Streamlit context, or key doesn't exist in secrets + pass + + # If we get here, the API key was not found in either place + raise ValueError( + "Financial Datasets API key not found. " + "Please set FINANCIAL_DATASETS_API_KEY in either environment variables or Streamlit secrets." 
+ ) + + def get_prices(ticker: str, start_date: str, end_date: str) -> list[Price]: """Fetch price data from cache or API.""" # Check cache first @@ -31,7 +66,7 @@ def get_prices(ticker: str, start_date: str, end_date: str) -> list[Price]: # If not in cache or no data in range, fetch from API headers = {} - if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"): + if api_key := get_financial_datasets_api_key(): headers["X-API-KEY"] = api_key url = f"https://api.financialdatasets.ai/prices/?ticker={ticker}&interval=day&interval_multiplier=1&start_date={start_date}&end_date={end_date}" @@ -68,7 +103,7 @@ def get_financial_metrics( # If not in cache or insufficient data, fetch from API headers = {} - if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"): + if api_key := get_financial_datasets_api_key(): headers["X-API-KEY"] = api_key url = f"https://api.financialdatasets.ai/financial-metrics/?ticker={ticker}&report_period_lte={end_date}&limit={limit}&period={period}" @@ -99,7 +134,7 @@ def search_line_items( """Fetch line items from API.""" # If not in cache or insufficient data, fetch from API headers = {} - if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"): + if api_key := get_financial_datasets_api_key(): headers["X-API-KEY"] = api_key url = "https://api.financialdatasets.ai/financials/search/line-items" @@ -143,7 +178,7 @@ def get_insider_trades( # If not in cache or insufficient data, fetch from API headers = {} - if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"): + if api_key := get_financial_datasets_api_key(): headers["X-API-KEY"] = api_key all_trades = [] @@ -206,7 +241,7 @@ def get_company_news( # If not in cache or insufficient data, fetch from API headers = {} - if api_key := os.environ.get("FINANCIAL_DATASETS_API_KEY"): + if api_key := get_financial_datasets_api_key(): headers["X-API-KEY"] = api_key all_news = []