-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_random_forest.py
More file actions
115 lines (89 loc) · 3.81 KB
/
train_random_forest.py
File metadata and controls
115 lines (89 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from matplotlib import pyplot as plt
from sklearn.tree import plot_tree
from ucimlrepo import fetch_ucirepo
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, r2_score, root_mean_squared_error
import pickle
def load_data():
    """Download UCI repository dataset 492 and return its features and target.

    Returns:
        (features, targets): copies of ``ds.data.features`` and
        ``ds.data.targets``. When the targets come back as a one-column
        DataFrame they are reduced to that single column (a Series).
    """
    dataset = fetch_ucirepo(id=492)
    features = dataset.data.features.copy()
    targets = dataset.data.targets.copy()

    # Downstream code expects a 1-D target, so unwrap a single-column frame.
    single_column_frame = isinstance(targets, pd.DataFrame) and targets.shape[1] == 1
    if single_column_frame:
        targets = targets.iloc[:, 0]
    return features, targets
def build_features(X, y, lags=(1, 3, 24)):
    """Assemble calendar, weather and lagged-target features for training.

    Args:
        X: raw feature DataFrame. If it has a ``date_time`` column, rows are
            sorted chronologically by it and ``hour`` / ``dayofweek`` are read
            from it; otherwise both are derived from row position.
        y: target values, aligned with ``X`` by position (both indexes are
            reset before combining).
        lags: row offsets; for each ``k`` a ``lag_<k>`` column holds the
            target value ``k`` rows earlier.

    Returns:
        (X_features, y_target): a DataFrame with columns ``weather_main`` /
        ``weather_description`` (only those present in ``X``), ``hour``,
        ``dayofweek`` and ``lag_<k>``, plus the matching ``traffic_volume``
        Series. Rows with any missing feature or a missing target are removed
        and both results get a fresh 0..n-1 index.
    """
    df = X.copy().reset_index(drop=True)
    # Positional alignment: X and y both get a fresh RangeIndex first.
    df['traffic_volume'] = y.copy().reset_index(drop=True)

    if 'date_time' in df.columns:
        df['date_time'] = pd.to_datetime(df['date_time'])
        # Stable sort: rows sharing a timestamp keep their input order, so
        # the row-shifted lag values below are deterministic.
        df = df.sort_values('date_time', kind='mergesort').reset_index(drop=True)
        df['hour'] = df['date_time'].dt.hour
        df['dayofweek'] = df['date_time'].dt.dayofweek
    else:
        # NOTE(review): this fallback assumes one row per hour starting at
        # hour 0 of day 0 — confirm for inputs without date_time.
        df['hour'] = df.index % 24
        df['dayofweek'] = (df.index // 24) % 7

    # Lags are row offsets, not time offsets.
    for lag in lags:
        df[f'lag_{lag}'] = df['traffic_volume'].shift(lag)

    candidate_cols = [col for col in ('weather_main', 'weather_description') if col in df.columns]
    candidate_cols += ['hour', 'dayofweek']
    candidate_cols += [f'lag_{lag}' for lag in lags]

    X_features = df[candidate_cols].copy()
    y_target = df['traffic_volume'].copy()

    # Keep only complete rows: features may be NaN from lagging (leading rows
    # have no history), and a NaN target cannot serve as a training label.
    valid = X_features.notna().all(axis=1) & y_target.notna()
    return (
        X_features.loc[valid].reset_index(drop=True),
        y_target.loc[valid].reset_index(drop=True),
    )
def load_pipeline(path='rf_pipeline.pkl'):
    """Unpickle and return the object stored at ``path``.

    Args:
        path: pickle file to read; a relative path is resolved against the
            current working directory.

    Returns:
        The unpickled object (the saved pipeline).

    Raises:
        FileNotFoundError: if nothing exists at ``path``.
    """
    # Security: pickle.load can execute arbitrary code — only load trusted files.
    if os.path.exists(path):
        with open(path, 'rb') as pipeline_file:
            return pickle.load(pipeline_file)
    raise FileNotFoundError(
        f"Pipeline file not found: {path}. Run train_random_forest.py first."
    )
def preprocess_and_train(X, y, lags=(1, 3, 24), *, test_size=0.2, shuffle=True,
                         n_estimators=200, random_state=42):
    """Build features and fit a one-hot-encoding + random-forest pipeline.

    Args:
        X: raw feature DataFrame, forwarded to ``build_features``.
        y: raw target, forwarded to ``build_features``.
        lags: target lags to create, forwarded to ``build_features``.
        test_size: fraction of rows held out by ``train_test_split``. The
            held-out rows are discarded here — they are only excluded from
            fitting.
        shuffle: forwarded to ``train_test_split``. ``True`` (default)
            reproduces the original random split; ``False`` holds out the
            last rows in feature order instead.
        n_estimators: number of trees in the ``RandomForestRegressor``.
        random_state: seed used for both the split and the forest.

    Returns:
        The fitted ``Pipeline`` with steps ``'preproc'`` (ColumnTransformer)
        and ``'rf'`` (RandomForestRegressor).
    """
    X_features, y_target = build_features(X, y, lags=lags)

    # String-like columns get one-hot encoded. 'string' also catches pandas'
    # dedicated string dtype, which 'object' alone would miss.
    cat_cols = X_features.select_dtypes(
        include=['object', 'category', 'string']
    ).columns.tolist()

    preprocessor = ColumnTransformer(
        transformers=[
            ('cat', OneHotEncoder(handle_unknown='ignore', sparse_output=False), cat_cols),
        ],
        # Remaining columns (hour, dayofweek, lag_*) are passed through as-is.
        remainder='passthrough',
    )
    pipeline = Pipeline([
        ('preproc', preprocessor),
        ('rf', RandomForestRegressor(n_estimators=n_estimators,
                                     random_state=random_state, n_jobs=-1)),
    ])

    # NOTE(review): features are time-ordered lags; a shuffled split
    # interleaves held-out and training rows. If the held-out rows are
    # evaluated elsewhere, consider shuffle=False to avoid optimistic scores.
    X_train, _, y_train, _ = train_test_split(
        X_features, y_target,
        test_size=test_size, shuffle=shuffle, random_state=random_state,
    )
    pipeline.fit(X_train, y_train)
    return pipeline
def visualize_tree(pipeline, feature_names):
    """Plot the top levels (max_depth=3) of the forest's first tree.

    Args:
        pipeline: object whose ``named_steps['rf']`` is the forest model.
        feature_names: labels for the model's input columns, passed to
            ``plot_tree``.

    If the model has no ``estimators_`` attribute, a message is printed
    instead and nothing is plotted.
    """
    forest = pipeline.named_steps['rf']
    # Guard: only proceed when the model exposes its individual trees.
    if not hasattr(forest, 'estimators_'):
        print("The model does not have individual trees to visualize.")
        return
    plt.figure(figsize=(20, 10))
    plot_tree(forest.estimators_[0], feature_names=feature_names, filled=True, max_depth=3)
    plt.title('Visualization of first tree in the Random Forest')
    plt.show()
def main():
    """Train the pipeline, save it to rf_pipeline.pkl, and plot one tree."""
    features, target = load_data()
    fitted = preprocess_and_train(features, target, lags=(1, 3, 24))

    # Persist preprocessing and model together so inference reuses the
    # same encoding.
    output_path = 'rf_pipeline.pkl'
    with open(output_path, 'wb') as out_file:
        pickle.dump(fitted, out_file)
    print('Saved pipeline to rf_pipeline.pkl')

    # Column names emitted by the preprocessor label the plotted splits.
    names = fitted.named_steps['preproc'].get_feature_names_out()
    visualize_tree(fitted, names)


if __name__ == '__main__':
    main()