Skip to content

Commit cfc590c

Browse files
committed
Create finrl_simple.py
1 parent 2df1ff5 commit cfc590c

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed

examples/finrl_simple.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Simple example showing how to use the FinRL environment with OpenEnv.
4+
5+
This example demonstrates:
6+
1. Connecting to a FinRL environment server
7+
2. Resetting the environment
8+
3. Executing random trading actions
9+
4. Tracking portfolio value over time
10+
"""
11+
12+
import sys
13+
from pathlib import Path
14+
15+
import numpy as np
16+
17+
# Add src to path
18+
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
19+
20+
from envs.finrl_env import FinRLAction, FinRLEnv
21+
22+
23+
def main():
24+
"""Run a simple FinRL environment example."""
25+
print("=" * 70)
26+
print("FinRL Environment - Simple Example")
27+
print("=" * 70)
28+
print()
29+
30+
# Connect to server
31+
print("Connecting to FinRL environment at http://localhost:8000...")
32+
try:
33+
client = FinRLEnv(base_url="http://localhost:8000")
34+
except Exception as e:
35+
print(f"❌ Failed to connect to server: {e}")
36+
print()
37+
print("Make sure the server is running:")
38+
print(" docker run -p 8000:8000 finrl-env:latest")
39+
return False
40+
41+
print("✅ Connected successfully!")
42+
print()
43+
44+
# Get configuration
45+
try:
46+
config = client.get_config()
47+
print("Environment Configuration:")
48+
print(f" Stock dimension: {config['stock_dim']}")
49+
print(f" Initial amount: ${config['initial_amount']:,.0f}")
50+
print(f" Action space: {config['action_space']}")
51+
print(f" State space: {config['state_space']}")
52+
print(f" Technical indicators: {', '.join(config['tech_indicators'])}")
53+
print()
54+
except Exception as e:
55+
print(f"⚠️ Could not fetch config: {e}")
56+
print()
57+
58+
# Reset environment
59+
print("Resetting environment...")
60+
result = client.reset()
61+
print(f"✅ Environment reset successfully!")
62+
print(f" Initial portfolio value: ${result.observation.portfolio_value:,.2f}")
63+
print(f" State dimension: {len(result.observation.state)}")
64+
if result.observation.date:
65+
print(f" Starting date: {result.observation.date}")
66+
print()
67+
68+
# Run trading simulation
69+
print("-" * 70)
70+
print("Running 20-step trading simulation with random actions...")
71+
print("-" * 70)
72+
print()
73+
74+
portfolio_history = [result.observation.portfolio_value]
75+
cumulative_reward = 0
76+
77+
for step in range(20):
78+
# Get current state
79+
state = result.observation.state
80+
81+
# Generate random actions (in real use, replace with your RL policy)
82+
num_stocks = config.get("stock_dim", 1)
83+
actions = np.random.uniform(-0.5, 0.5, size=num_stocks).tolist()
84+
85+
# Execute action
86+
result = client.step(FinRLAction(actions=actions))
87+
88+
# Track metrics
89+
portfolio_history.append(result.observation.portfolio_value)
90+
cumulative_reward += result.reward or 0
91+
92+
# Print progress
93+
print(
94+
f"Step {step + 1:2d}: "
95+
f"Portfolio=${result.observation.portfolio_value:>12,.2f} | "
96+
f"Reward={result.reward:>8.2f} | "
97+
f"Date={result.observation.date}"
98+
)
99+
100+
if result.done:
101+
print()
102+
print("Episode finished!")
103+
break
104+
105+
# Summary
106+
print()
107+
print("-" * 70)
108+
print("Trading Simulation Complete")
109+
print("-" * 70)
110+
print(f"Initial portfolio value: ${portfolio_history[0]:,.2f}")
111+
print(f"Final portfolio value: ${portfolio_history[-1]:,.2f}")
112+
print(
113+
f"Total return: ${portfolio_history[-1] - portfolio_history[0]:,.2f} "
114+
f"({((portfolio_history[-1] / portfolio_history[0] - 1) * 100):.2f}%)"
115+
)
116+
print(f"Cumulative reward: {cumulative_reward:.2f}")
117+
print(f"Steps executed: {len(portfolio_history) - 1}")
118+
print()
119+
120+
# Plot portfolio value over time (if matplotlib available)
121+
try:
122+
import matplotlib.pyplot as plt
123+
124+
plt.figure(figsize=(10, 6))
125+
plt.plot(portfolio_history, marker="o", linewidth=2)
126+
plt.title("Portfolio Value Over Time", fontsize=14, fontweight="bold")
127+
plt.xlabel("Step", fontsize=12)
128+
plt.ylabel("Portfolio Value ($)", fontsize=12)
129+
plt.grid(True, alpha=0.3)
130+
plt.tight_layout()
131+
plt.savefig("finrl_portfolio_history.png", dpi=150)
132+
print("📊 Portfolio chart saved to: finrl_portfolio_history.png")
133+
print()
134+
except ImportError:
135+
pass
136+
137+
# Cleanup
138+
print("Closing connection...")
139+
client.close()
140+
print("✅ Done!")
141+
print()
142+
143+
print("=" * 70)
144+
print("Example completed successfully! 🎉")
145+
print("=" * 70)
146+
147+
return True
148+
149+
150+
if __name__ == "__main__":
151+
success = main()
152+
sys.exit(0 if success else 1)

0 commit comments

Comments
 (0)