train_local.py

from clickhouse_driver import Client
from common.log_utils import logFactory
from common.database_utils import database_util
from common import constant
import numpy as np
import math
import pandas as pd
from random import shuffle
from tqdm import tqdm
import optuna
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from optuna.integration import LightGBMPruningCallback
import lightgbm as lgb
from lightgbm.sklearn import LGBMClassifier
from lightgbm import plot_importance
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns
import random

click_client = database_util.get_client()
logger = logFactory("preprocess data").log
#
# # Get the positive/negative sample counts and prepare random ids
# tb_pos = constant.insert_pos_tb_name
# tb_neg = constant.insert_neg_tb_name
# positive_sql = f"select count(1) from {tb_pos}"
# positive_sql_result = click_client.execute(positive_sql)[0][0]
# positive_ids = range(0, positive_sql_result)
# negative_sql = f"select count(1) from {tb_neg}"
# negative_sql_result = click_client.execute(negative_sql)[0][0]
# negative_ids = range(0, negative_sql_result)
#
# # Number of rows per batch
# # batch_size = 100000
# batch_size = 10
# # Number of groups
# batch_num_pos = 1
# batch_num_neg = 1
#
#
# Compute the row_ids of each batch from the per-batch row count
def partition_preserve_order(list_in, n):
    indices = list(range(len(list_in)))
    shuffle(indices)
    index_partitions = [sorted(indices[i::n]) for i in range(n)]
    return [[list_in[i] for i in index_partition]
            for index_partition in index_partitions]
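# Example (hypothetical values): partition_preserve_order(list(range(6)), 2)
# might return [[0, 2, 3], [1, 4, 5]] -- partition membership is randomized by
# shuffle(), but each partition keeps its elements in the original relative order.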
#
#
# def gen_train_tuple(pos_row_ids, neg_row_ids):
#     result = []
#     for i, e in enumerate(neg_row_ids):
#         pos_index = i % batch_num_pos
#         result.append((pos_row_ids[pos_index], neg_row_ids[i]))
#     return result
#
#
# pos_row_ids = partition_preserve_order(positive_ids, batch_num_pos)
# neg_row_ids = partition_preserve_order(negative_ids, batch_num_neg)
# neg_row_ids = [random.sample(neg_row_ids[0], 1000000)]
# pos_row_ids = [random.sample(pos_row_ids[0], 1000000)]
# train_tuple = gen_train_tuple(pos_row_ids, neg_row_ids)
DROP_COLUMNS = ['row_id', 'month',
                'EVENT_CATEGORYNAME_C_0',
                'EVENT_CATEGORYNAME_C_1',
                'EVENT_CATEGORYNAME_C_2',
                'EVENT_CATEGORYNAME_C_3',
                'EVENT_CATEGORYNAME_C_4',
                'EVENT_CATEGORYNAME_C_5',
                'EVENT_CATEGORYNAME_C_6',
                'EVENT_CATEGORYNAME_C_7',
                'EVENT_CATEGORYNAME_C_8',
                'EVENT_CATEGORYNAME_C_9',
                'EVENT_CATEGORYNAME_C_10',
                "EVENT_CANCEL_DIFF_C_0",
                "EVENT_CANCEL_DIFF_C_1",
                "EVENT_CANCEL_DIFF_C_2",
                "EVENT_CANCEL_DIFF_C_3",
                "EVENT_CANCEL_DIFF_C_4",
                "EVENT_CANCEL_DIFF_C_5",
                "EVENT_CANCEL_DIFF_C_6"
                ]


def draw_roc_auc(y_label, y_test):
    # Plot the ROC curve
    fpr, tpr, thresholds = metrics.roc_curve(y_label, y_test)
    # Compute the area under the curve
    roc_auc = metrics.auc(fpr, tpr)
    # Draw the figure
    plt.clf()
    plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc="lower right")
    plt.show()


def draw_confusion_matrix(y_label, y_test):
    # Plot the confusion matrix
    confusion_data = metrics.confusion_matrix(y_label, y_test)
    print(confusion_data)
    sns.heatmap(confusion_data, cmap="Greens", annot=True)
    plt.xlabel("Predicted labels")
    plt.ylabel("True labels")
    plt.tight_layout()
    plt.show()


def start_train():
    logger.info("Preparing to load data")
    lgb_model = None
    train_params = {
        'task': 'train',
        'objective': 'binary',
        'boosting_type': 'gbdt',
        'learning_rate': 0.1,
        'num_leaves': 30,
        'tree_learner': 'serial',
        'metric': {'binary_logloss', 'auc', 'average_precision'},  # l1: mae, l2: mse
        # 'max_bin': 3,  # a smaller max_bin is faster, a larger value improves accuracy
        'max_depth': 6,
        # 'min_child_samples': 5,
        "bagging_fraction": 0.8,  # row subsampling ratio, as in XGBoost; lowering it helps prevent overfitting and speeds up training
        "feature_fraction": 0.8,  # feature (column) subsampling ratio; lowering it helps prevent overfitting and speeds up training
        "n_jobs": 8,
        "boost_from_average": False,
        'seed': 2022,
        "lambda_l1": 1e-5,
        "lambda_l2": 1e-5,
    }
    total_train_data = pd.read_pickle('total_train_data.pkl')
    # Split into training and test sets
    data_train, data_test = train_test_split(total_train_data, train_size=0.7)
    # data_train, data_test = train_test_split(total_train_data, total_train_data['mark'].values, train_size=0.8,
    #                                          random_state=0, stratify=total_train_data['mark'].values)
    # df1 = data_train[data_train['mark'] == 0]
    # df2 = data_train[data_train['mark'] == 1]
    # df11 = data_test[data_test['mark'] == 0]
    # df12 = data_test[data_test['mark'] == 1]
    # drop() without inplace avoids SettingWithCopyWarning on the column slice
    train_x_data = data_train[constant.feature_column_names].drop(columns=DROP_COLUMNS)
    train_y_data = data_train[['mark']]
    test_x_data = data_test[constant.feature_column_names].drop(columns=DROP_COLUMNS)
    test_y_data = data_test[['mark']]
    feature_col = [x for x in constant.feature_column_names if x not in DROP_COLUMNS]
    # Build the LightGBM datasets (labels flattened to 1-D arrays)
    lgb_train = lgb.Dataset(train_x_data, train_y_data.values.ravel(), silent=True)
    lgb_eval = lgb.Dataset(test_x_data, test_y_data.values.ravel(), reference=lgb_train, silent=True)
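    # Note (environment assumption): the `silent` argument above and the `verbose_eval`
    # argument passed to lgb.train() below exist only in older LightGBM releases; on
    # LightGBM >= 4.0 they were removed and evaluation logging is configured via
    # callbacks such as lgb.log_evaluation() instead.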
    # clf = LGBMClassifier()
    # clf.fit(train_x_data, train_y_data.values)
    # train_predict = clf.predict(train_x_data)
    # test_predict = clf.predict(test_x_data)
    # train_predict_score = metrics.accuracy_score(train_y_data.values, train_predict)
    # test_predict_score = metrics.accuracy_score(test_y_data.values, test_predict)
    # print(train_predict_score)
    # print(test_predict_score)
    # confusion_data = metrics.confusion_matrix(test_y_data.values, test_predict)
    # print(confusion_data)
    # sns.heatmap(confusion_data, cmap="Greens", annot=True)
    # plt.xlabel("Predicted labels")
    # plt.ylabel("True labels")
    # plt.tight_layout()
    # plt.show()
    lgb_model = lgb.train(params=train_params, train_set=lgb_train, num_boost_round=100, valid_sets=lgb_eval,
                          init_model=lgb_model, feature_name=feature_col,
                          verbose_eval=False, keep_training_booster=True)
    # lgb_model = lgb.cv(params=train_params, train_set=lgb_train, num_boost_round=1000, nfold=5, stratified=False,
    #                    shuffle=True, early_stopping_rounds=50, verbose_eval=50, show_stdv=True)
    lgb_model.save_model("temp.model")
    # Log the model evaluation scores
    score_train = str(dict([(s[1], s[2]) for s in lgb_model.eval_train()]))
    score_valid = str(dict([(s[1], s[2]) for s in lgb_model.eval_valid()]))
    logger.info(f"On the training set: {score_train}")
    logger.info(f"On the test set: {score_valid}")
    result = pd.DataFrame({
        'column': feature_col,
        'importance': lgb_model.feature_importance(),
    }).sort_values(by='importance')
    sns.barplot(y=train_x_data.columns, x=lgb_model.feature_importance())
    plt.tight_layout()
    plt.show()
    test_predict = lgb_model.predict(test_x_data)
    # Threshold the predicted probabilities at 0.5 to get hard class labels
    test_result = []
    for x in test_predict:
        if x < 0.5:
            test_result.append(0)
        else:
            test_result.append(1)
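    # Equivalent vectorized form (numpy is already imported as np):
    # test_result = (test_predict >= 0.5).astype(int).tolist()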
    result = metrics.classification_report(test_y_data.values.reshape([-1, ]), test_result)
    logger.info(result)
    draw_confusion_matrix(test_y_data.values.reshape([-1, ]), test_result)
    draw_roc_auc(test_y_data.values.reshape([-1, ]), test_predict)
    #
    # ## Compute the ROC data; note the order of the returned values
    # fpr, tpr, thresholds = metrics.roc_curve(test_y_data.values.reshape([-1, ]), test_predict)
    # ## Compute the area under the curve
    # roc_auc = metrics.auc(fpr, tpr)
    # ## Draw the figure
    # plt.clf()
    # plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
    # plt.plot([0, 1], [0, 1], 'k--')
    # plt.xlim([0.0, 1.0])
    # plt.ylim([0.0, 1.0])
    # plt.xlabel('False Positive Rate')
    # plt.ylabel('True Positive Rate')
    # plt.legend(loc="lower right")
    # plt.show()
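

# optuna, log_loss, and LightGBMPruningCallback are imported above but never used.
# The function below is a minimal, hypothetical sketch of how they could be wired
# together to tune a few entries of train_params; it is not part of the original
# training flow, and the search ranges and trial count are illustrative assumptions.
def tune_params_with_optuna(train_x_data, train_y_data, test_x_data, test_y_data, n_trials=20):
    def objective(trial):
        params = {
            'objective': 'binary',
            'metric': 'binary_logloss',
            'verbosity': -1,
            'seed': 2022,
            'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
            'num_leaves': trial.suggest_int('num_leaves', 16, 128),
            'feature_fraction': trial.suggest_float('feature_fraction', 0.5, 1.0),
            'bagging_fraction': trial.suggest_float('bagging_fraction', 0.5, 1.0),
        }
        dtrain = lgb.Dataset(train_x_data, train_y_data.values.ravel())
        dvalid = lgb.Dataset(test_x_data, test_y_data.values.ravel(), reference=dtrain)
        # Prune unpromising trials based on the validation binary_logloss
        booster = lgb.train(params, dtrain, num_boost_round=100, valid_sets=[dvalid],
                            callbacks=[LightGBMPruningCallback(trial, 'binary_logloss')])
        preds = booster.predict(test_x_data)
        return log_loss(test_y_data.values.ravel(), preds)

    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=n_trials)
    return study.best_params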


def test_data_valid():
    logger.info("Preparing to load data")
    # Get the positive/negative test samples
    tb_pos = constant.insert_pos_tb_test_name
    tb_neg = constant.insert_neg_tb_test_name
    positive_sql = f"select t.* from {tb_pos} t limit 1000000"
    positive_sql_result = click_client.execute(positive_sql)
    negative_sql = f"select t.* from {tb_neg} t limit 1000000"
    negative_sql_result = click_client.execute(negative_sql)
    dataf_pos = pd.DataFrame(positive_sql_result, columns=constant.feature_column_names)
    dataf_pos['mark'] = 0
    dataf_neg = pd.DataFrame(negative_sql_result, columns=constant.feature_column_names)
    dataf_neg['mark'] = 1
    test_all_data = pd.concat([dataf_pos, dataf_neg], axis=0)
    test_x_data = test_all_data[constant.feature_column_names].drop(columns=DROP_COLUMNS)
    test_y_data = test_all_data[['mark']]
    lgb_model = lgb.Booster(model_file='temp.model')
    test_predict = lgb_model.predict(test_x_data, num_iteration=lgb_model.best_iteration)
    # Threshold the predicted probabilities at 0.5 to get hard class labels
    test_result = []
    for x in test_predict:
        if x < 0.5:
            test_result.append(0)
        else:
            test_result.append(1)
    result = metrics.classification_report(test_y_data.values.ravel(), test_result)
    logger.info(result)
    draw_confusion_matrix(test_y_data.values.ravel(), test_result)
    draw_roc_auc(test_y_data.values.ravel(), test_predict)


if __name__ == "__main__":
    # start_train()
    test_data_valid()