forked from notadamking/RLTrader
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
34 lines (24 loc) · 892 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gym
import pandas as pd
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C
from env.BitcoinTradingEnv import BitcoinTradingEnv
# Load the price history and order it by the 'Timestamp' column so that the
# positional split below follows that ordering.
df = pd.read_csv('./data/bitstamp.csv').sort_values('Timestamp')

# The final 50000 rows become the test set; everything before them is the
# training set.
# NOTE: this relies on the file having more than 50000 rows. With 50000 rows
# or fewer, slice_point is zero or negative and the slices below no longer
# form that train/test split.
slice_point = len(df) - 50000
train_df, test_df = df[:slice_point], df[slice_point:]
def _make_train_env():
    """Build the trading environment over the training slice."""
    return BitcoinTradingEnv(train_df, serial=True)


# DummyVecEnv is given a list of environment factories; only one is used here.
train_env = DummyVecEnv([_make_train_env])

# A2C agent with an MLP policy; tensorboard logs are written to ./tensorboard/.
model = A2C(
    MlpPolicy,
    train_env,
    verbose=1,
    tensorboard_log="./tensorboard/",
)
model.learn(total_timesteps=200000)
# Run the trained model against an environment built from test_df, rendering
# after every step.
test_env = DummyVecEnv(
    [lambda: BitcoinTradingEnv(test_df, serial=True)])

try:
    obs = test_env.reset()
    # 50000 steps: the same count as the number of rows sliced off for test_df.
    for _ in range(50000):
        action, _states = model.predict(obs)
        obs, rewards, done, info = test_env.step(action)
        test_env.render(mode="system", title="BTC")
finally:
    # Release the environment even when predict/step/render raises or the run
    # is interrupted, not only after all steps complete cleanly.
    test_env.close()