用于回归的 XGBoost
XGBoost (eXtreme Gradient Boosting) 是一种很受欢迎的监督式学习算法,用于对大型数据集进行回归和分类。它使用顺序构建的浅层决策树来提供准确的结果和高度可扩展的定型方法,以避免过度拟合。
以下 XGBoost 函数使用回归模型创建和执行预测:
示例
此示例使用名为 "mtcars" 的小型数据集(其中包含 1973-1974 年 32 辆汽车的设计和性能数据),并创建 XGBoost 回归模型来预测变量 carb
的值(化油器的数量)。
-
使用
XGB_REGRESSOR
从mtcars
数据集创建 XGBoost 回归模型xgb_cars
。=> SELECT XGB_REGRESSOR ('xgb_cars', 'mtcars', 'carb', 'mpg, cyl, hp, drat, wt' USING PARAMETERS learning_rate=0.5); XGB_REGRESSOR --------------- Finished (1 row)
然后,您可以使用
GET_MODEL_SUMMARY
查看模型的摘要:=> SELECT GET_MODEL_SUMMARY(USING PARAMETERS model_name='xgb_cars'); GET_MODEL_SUMMARY ------------------------------------------------------ =========== call_string =========== xgb_regressor('public.xgb_cars', 'mtcars', '"carb"', 'mpg, cyl, hp, drat, wt' USING PARAMETERS exclude_columns='', max_ntree=10, max_depth=5, nbins=32, objective=squarederror, split_proposal_method=global, epsilon=0.001, learning_rate=0.5, min_split_loss=0, weight_reg=0, sampling_size=1) ======= details ======= predictor| type ---------+---------------- mpg |float or numeric cyl | int hp | int drat |float or numeric wt |float or numeric =============== Additional Info =============== Name |Value ------------------+----- tree_count | 10 rejected_row_count| 0 accepted_row_count| 32 (1 row)
-
使用
PREDICT_XGB_REGRESSOR
预测化油器数量:=> SELECT carb, PREDICT_XGB_REGRESSOR (mpg,cyl,hp,drat,wt USING PARAMETERS model_name='xgb_cars') FROM mtcars; carb | PREDICT_XGB_REGRESSOR ------+----------------------- 4 | 4.00335213618023 2 | 2.0038188946536 6 | 5.98866003194438 1 | 1.01774386191546 2 | 1.9959801016274 2 | 2.0038188946536 4 | 3.99545403625739 8 | 7.99211056556231 2 | 1.99291901733151 3 | 2.9975688946536 3 | 2.9975688946536 1 | 1.00320357711227 2 | 2.0038188946536 4 | 3.99545403625739 4 | 4.00124134679445 1 | 1.00759516721382 4 | 3.99700517763435 4 | 3.99580193056138 4 | 4.00009088187525 3 | 2.9975688946536 2 | 1.98625064560888 1 | 1.00355294416998 2 | 2.00666247039502 1 | 1.01682931210169 4 | 4.00124134679445 1 | 1.01007809485918 2 | 1.98438405824605 4 | 3.99580193056138 2 | 1.99291901733151 4 | 4.00009088187525 2 | 2.0038188946536 1 | 1.00759516721382 (32 rows)