XGBoost for regression
XGBoost (eXtreme Gradient Boosting) is a popular supervised-learning algorithm used for regression and classification on large datasets. It builds shallow decision trees sequentially, with each new tree correcting the errors of the trees before it. This yields accurate results and a highly scalable training method that helps limit overfitting.
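The following XGBoost functions create a regression model and make predictions with it: XGB_REGRESSOR, which trains the model, and PREDICT_XGB_REGRESSOR, which applies the model to input data.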
Example
This example uses a small dataset named "mtcars", which contains design and performance data for 32 automobiles from 1973-1974, and creates an XGBoost regression model to predict the value of the variable carb (the number of carburetors).
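The example expects a table that holds at least the response column and the five predictor columns. The definition below is only a sketch: the column types are assumptions inferred from the model summary shown in this example (the type of carb is assumed to be INT), and the original mtcars data contains additional columns that are omitted here.

```sql
-- Hypothetical minimal table definition for the columns used in this example.
-- Types are assumptions based on the predictor types reported by GET_MODEL_SUMMARY.
CREATE TABLE mtcars (
    mpg  FLOAT,  -- miles per gallon
    cyl  INT,    -- number of cylinders
    hp   INT,    -- gross horsepower
    drat FLOAT,  -- rear axle ratio
    wt   FLOAT,  -- weight
    carb INT     -- number of carburetors (the response column)
);
```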
- Use XGB_REGRESSOR to create the XGBoost regression model xgb_cars from the mtcars dataset:

    => SELECT XGB_REGRESSOR ('xgb_cars', 'mtcars', 'carb', 'mpg, cyl, hp, drat, wt'
        USING PARAMETERS learning_rate=0.5);
     XGB_REGRESSOR
    ---------------
     Finished
    (1 row)
  You can then view a summary of the model with GET_MODEL_SUMMARY:

    => SELECT GET_MODEL_SUMMARY(USING PARAMETERS model_name='xgb_cars');
                      GET_MODEL_SUMMARY
    ------------------------------------------------------
    ===========
    call_string
    ===========
    xgb_regressor('public.xgb_cars', 'mtcars', '"carb"', 'mpg, cyl, hp, drat, wt'
    USING PARAMETERS exclude_columns='', max_ntree=10, max_depth=5, nbins=32, objective=squarederror,
    split_proposal_method=global, epsilon=0.001, learning_rate=0.5, min_split_loss=0, weight_reg=0, sampling_size=1)

    =======
    details
    =======
    predictor|      type
    ---------+----------------
       mpg   |float or numeric
       cyl   |      int
       hp    |      int
      drat   |float or numeric
       wt    |float or numeric

    ===============
    Additional Info
    ===============
           Name       |Value
    ------------------+-----
        tree_count    | 10
    rejected_row_count|  0
    accepted_row_count| 32

    (1 row)
- Use PREDICT_XGB_REGRESSOR to predict the number of carburetors:

    => SELECT carb, PREDICT_XGB_REGRESSOR (mpg,cyl,hp,drat,wt USING PARAMETERS model_name='xgb_cars') FROM mtcars;
     carb | PREDICT_XGB_REGRESSOR
    ------+-----------------------
        4 |      4.00335213618023
        2 |       2.0038188946536
        6 |      5.98866003194438
        1 |      1.01774386191546
        2 |       1.9959801016274
        2 |       2.0038188946536
        4 |      3.99545403625739
        8 |      7.99211056556231
        2 |      1.99291901733151
        3 |       2.9975688946536
        3 |       2.9975688946536
        1 |      1.00320357711227
        2 |       2.0038188946536
        4 |      3.99545403625739
        4 |      4.00124134679445
        1 |      1.00759516721382
        4 |      3.99700517763435
        4 |      3.99580193056138
        4 |      4.00009088187525
        3 |       2.9975688946536
        2 |      1.98625064560888
        1 |      1.00355294416998
        2 |      2.00666247039502
        1 |      1.01682931210169
        4 |      4.00124134679445
        1 |      1.01007809485918
        2 |      1.98438405824605
        4 |      3.99580193056138
        2 |      1.99291901733151
        4 |      4.00009088187525
        2 |      2.0038188946536
        1 |      1.00759516721382
    (32 rows)
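To condense the row-by-row comparison into a single figure, you can aggregate the prediction error. The query below is a minimal sketch that uses only standard SQL aggregates together with PREDICT_XGB_REGRESSOR. The aliases pred, scored, mean_abs_error, and root_mean_sq_error are illustrative names and are not part of the original example.

```sql
-- Assumes the xgb_cars model and the mtcars table from the steps above.
SELECT AVG(ABS(carb - pred))            AS mean_abs_error,      -- mean absolute error
       SQRT(AVG(POWER(carb - pred, 2))) AS root_mean_sq_error   -- root mean squared error
FROM (
    -- Score every row with the trained model.
    SELECT carb,
           PREDICT_XGB_REGRESSOR (mpg, cyl, hp, drat, wt
               USING PARAMETERS model_name='xgb_cars') AS pred
    FROM mtcars
) AS scored;
```

Because the model is scored on the same 32 rows it was trained on, these values measure in-sample fit only. Error on data held out from training is usually a better guide to how the model will perform on new rows.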