|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Simple example showing how to use the FinRL environment with OpenEnv. |
| 4 | +
|
| 5 | +This example demonstrates: |
| 6 | +1. Connecting to a FinRL environment server |
| 7 | +2. Resetting the environment |
| 8 | +3. Executing random trading actions |
| 9 | +4. Tracking portfolio value over time |
| 10 | +""" |
| 11 | + |
| 12 | +import sys |
| 13 | +from pathlib import Path |
| 14 | + |
| 15 | +import numpy as np |
| 16 | + |
| 17 | +# Add src to path |
| 18 | +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) |
| 19 | + |
| 20 | +from envs.finrl_env import FinRLAction, FinRLEnv |
| 21 | + |
| 22 | + |
| 23 | +def main(): |
| 24 | + """Run a simple FinRL environment example.""" |
| 25 | + print("=" * 70) |
| 26 | + print("FinRL Environment - Simple Example") |
| 27 | + print("=" * 70) |
| 28 | + print() |
| 29 | + |
| 30 | + # Connect to server |
| 31 | + print("Connecting to FinRL environment at http://localhost:8000...") |
| 32 | + try: |
| 33 | + client = FinRLEnv(base_url="http://localhost:8000") |
| 34 | + except Exception as e: |
| 35 | + print(f"❌ Failed to connect to server: {e}") |
| 36 | + print() |
| 37 | + print("Make sure the server is running:") |
| 38 | + print(" docker run -p 8000:8000 finrl-env:latest") |
| 39 | + return False |
| 40 | + |
| 41 | + print("✅ Connected successfully!") |
| 42 | + print() |
| 43 | + |
| 44 | + # Get configuration |
| 45 | + try: |
| 46 | + config = client.get_config() |
| 47 | + print("Environment Configuration:") |
| 48 | + print(f" Stock dimension: {config['stock_dim']}") |
| 49 | + print(f" Initial amount: ${config['initial_amount']:,.0f}") |
| 50 | + print(f" Action space: {config['action_space']}") |
| 51 | + print(f" State space: {config['state_space']}") |
| 52 | + print(f" Technical indicators: {', '.join(config['tech_indicators'])}") |
| 53 | + print() |
| 54 | + except Exception as e: |
| 55 | + print(f"⚠️ Could not fetch config: {e}") |
| 56 | + print() |
| 57 | + |
| 58 | + # Reset environment |
| 59 | + print("Resetting environment...") |
| 60 | + result = client.reset() |
| 61 | + print(f"✅ Environment reset successfully!") |
| 62 | + print(f" Initial portfolio value: ${result.observation.portfolio_value:,.2f}") |
| 63 | + print(f" State dimension: {len(result.observation.state)}") |
| 64 | + if result.observation.date: |
| 65 | + print(f" Starting date: {result.observation.date}") |
| 66 | + print() |
| 67 | + |
| 68 | + # Run trading simulation |
| 69 | + print("-" * 70) |
| 70 | + print("Running 20-step trading simulation with random actions...") |
| 71 | + print("-" * 70) |
| 72 | + print() |
| 73 | + |
| 74 | + portfolio_history = [result.observation.portfolio_value] |
| 75 | + cumulative_reward = 0 |
| 76 | + |
| 77 | + for step in range(20): |
| 78 | + # Get current state |
| 79 | + state = result.observation.state |
| 80 | + |
| 81 | + # Generate random actions (in real use, replace with your RL policy) |
| 82 | + num_stocks = config.get("stock_dim", 1) |
| 83 | + actions = np.random.uniform(-0.5, 0.5, size=num_stocks).tolist() |
| 84 | + |
| 85 | + # Execute action |
| 86 | + result = client.step(FinRLAction(actions=actions)) |
| 87 | + |
| 88 | + # Track metrics |
| 89 | + portfolio_history.append(result.observation.portfolio_value) |
| 90 | + cumulative_reward += result.reward or 0 |
| 91 | + |
| 92 | + # Print progress |
| 93 | + print( |
| 94 | + f"Step {step + 1:2d}: " |
| 95 | + f"Portfolio=${result.observation.portfolio_value:>12,.2f} | " |
| 96 | + f"Reward={result.reward:>8.2f} | " |
| 97 | + f"Date={result.observation.date}" |
| 98 | + ) |
| 99 | + |
| 100 | + if result.done: |
| 101 | + print() |
| 102 | + print("Episode finished!") |
| 103 | + break |
| 104 | + |
| 105 | + # Summary |
| 106 | + print() |
| 107 | + print("-" * 70) |
| 108 | + print("Trading Simulation Complete") |
| 109 | + print("-" * 70) |
| 110 | + print(f"Initial portfolio value: ${portfolio_history[0]:,.2f}") |
| 111 | + print(f"Final portfolio value: ${portfolio_history[-1]:,.2f}") |
| 112 | + print( |
| 113 | + f"Total return: ${portfolio_history[-1] - portfolio_history[0]:,.2f} " |
| 114 | + f"({((portfolio_history[-1] / portfolio_history[0] - 1) * 100):.2f}%)" |
| 115 | + ) |
| 116 | + print(f"Cumulative reward: {cumulative_reward:.2f}") |
| 117 | + print(f"Steps executed: {len(portfolio_history) - 1}") |
| 118 | + print() |
| 119 | + |
| 120 | + # Plot portfolio value over time (if matplotlib available) |
| 121 | + try: |
| 122 | + import matplotlib.pyplot as plt |
| 123 | + |
| 124 | + plt.figure(figsize=(10, 6)) |
| 125 | + plt.plot(portfolio_history, marker="o", linewidth=2) |
| 126 | + plt.title("Portfolio Value Over Time", fontsize=14, fontweight="bold") |
| 127 | + plt.xlabel("Step", fontsize=12) |
| 128 | + plt.ylabel("Portfolio Value ($)", fontsize=12) |
| 129 | + plt.grid(True, alpha=0.3) |
| 130 | + plt.tight_layout() |
| 131 | + plt.savefig("finrl_portfolio_history.png", dpi=150) |
| 132 | + print("📊 Portfolio chart saved to: finrl_portfolio_history.png") |
| 133 | + print() |
| 134 | + except ImportError: |
| 135 | + pass |
| 136 | + |
| 137 | + # Cleanup |
| 138 | + print("Closing connection...") |
| 139 | + client.close() |
| 140 | + print("✅ Done!") |
| 141 | + print() |
| 142 | + |
| 143 | + print("=" * 70) |
| 144 | + print("Example completed successfully! 🎉") |
| 145 | + print("=" * 70) |
| 146 | + |
| 147 | + return True |
| 148 | + |
| 149 | + |
| 150 | +if __name__ == "__main__": |
| 151 | + success = main() |
| 152 | + sys.exit(0 if success else 1) |
0 commit comments