From 00607ca097a81f6cb17b04bdd0f041f488b54adf Mon Sep 17 00:00:00 2001 From: ryan6073 Date: Sun, 26 Nov 2023 12:21:48 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=9D=E5=AD=98=E9=A2=84=E6=B5=8B=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E5=88=B0=E6=9C=AC=E5=9C=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- yugou_best.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/yugou_best.py b/yugou_best.py index d6b4baf..51b9c92 100644 --- a/yugou_best.py +++ b/yugou_best.py @@ -111,7 +111,6 @@ X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.55, random_state=10) - # 初始化梯度提升树分类器 gradient_boosting = GradientBoostingClassifier(n_estimators=100, random_state=42) # 这里 n_estimators 表示基学习器的数量 @@ -132,9 +131,9 @@ # 定义参数网格 param_grid = { - 'n_estimators': [15,25,50], # 调整基学习器的数量 - 'learning_rate': [0.003,0.005,0.01], # 学习率 - 'max_depth': [1,3, 5] # 调整树的深度 + 'n_estimators': [15, 25, 50], # 调整基学习器的数量 + 'learning_rate': [0.003, 0.005, 0.01], # 学习率 + 'max_depth': [1, 3, 5] # 调整树的深度 # 其他需要调整的参数 } @@ -145,7 +144,6 @@ # 输出最佳参数组合和对应的准确率 print("Best Parameters:", grid_search.best_params_) - # 使用最佳参数组合重新训练模型 best_gradient_boosting = grid_search.best_estimator_ best_gradient_boosting.fit(X_train_selected, y_train) @@ -155,4 +153,19 @@ # 计算准确率 accuracy = accuracy_score(y_test, predictions) -print("Gradient Boosting Accuracy after Parameter Tuning:", accuracy) \ No newline at end of file +print("Gradient Boosting Accuracy after Parameter Tuning:", accuracy) + +# choose = ["user_id", "merchant_id", "mlp_prob"] +# res = df_test[choose] +# res.rename(columns={"mlp_prob": "prob"}, inplace=True) +# print(res.head(10)) +# res.to_csv(path_or_buf=r"data/prediction.csv", index=False) + +pX = df_test.drop(['user_id', 'merchant_id'], axis=1) +pX_selected = select_features.transform(pX) +pPredictions = best_gradient_boosting.predict_proba(pX_selected) +df_test['prob'] = pPredictions +choose = ["user_id", "merchant_id", "label"] +res = df_test[choose] +print(res.head(10)) +res.to_csv(path_or_buf=r"data/prediction.csv", index=False) \ No newline at end of file