Perché xgboost è molto più veloce di sklearn GradientBoostingClassifier?


29

Sto cercando di formare un modello di aumento gradiente su 50.000 esempi con 100 funzioni numeriche. XGBClassifiergestisce 500 alberi entro 43 secondi sulla mia macchina, mentre GradientBoostingClassifiergestisce solo 10 alberi (!) in 1 minuto e 2 secondi :( Non mi sono preoccupato di provare a far crescere 500 alberi perché ci vorranno ore. Sto usando lo stesso learning_ratee le max_depthimpostazioni , vedi sotto.

Cosa rende XGBoost molto più veloce? Usa qualche nuova implementazione per aumentare il gradiente che sklearn non conosce? O è "tagliare gli angoli" e far crescere alberi meno profondi?

ps Sono a conoscenza di questa discussione: https://www.kaggle.com/c/higgs-boson/forums/t/10335/xgboost-post-competition-survey ma non ho trovato la risposta lì ...

XGBClassifier(base_score=0.5, colsample_bylevel=1, colsample_bytree=1,
gamma=0, learning_rate=0.05, max_delta_step=0, max_depth=10,
min_child_weight=1, missing=None, n_estimators=500, nthread=-1,
objective='binary:logistic', reg_alpha=0, reg_lambda=1,
scale_pos_weight=1, seed=0, silent=True, subsample=1)

GradientBoostingClassifier(init=None, learning_rate=0.05, loss='deviance',
max_depth=10, max_features=None, max_leaf_nodes=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=10,
presort='auto', random_state=None, subsample=1.0, verbose=0,
warm_start=False)

2
suppongo che presto dovrò riformularlo come "perché LightGBM è molto più veloce di XGBoost?" :)
ihadanny il

Risposte:


25

Dal momento che menzioni le funzionalità "numeriche", immagino che le tue funzionalità non siano categoriche e abbiano un'elevata arità (possono assumere molti valori diversi e quindi ci sono molti possibili punti di divisione). In tal caso, crescere alberi è difficile poiché ci sono [molte caratteristiche × molti punti di divisione] da valutare.

La mia ipotesi è che l'effetto più grande deriva dal fatto che XGBoost utilizza un'approssimazione sui punti di divisione. Se hai una funzione continua con 10000 divisioni possibili, XGBoost considera di default solo le "migliori" 300 divisioni (questa è una semplificazione). Questo comportamento è controllato dal sketch_epsparametro e puoi leggere di più al riguardo nel documento . Puoi provare ad abbassarlo e verificare la differenza che fa. Dato che non è menzionato nella documentazione di scikit-learn , immagino che non sia disponibile. Puoi imparare qual è il metodo XGBoost nel loro documento (arxiv) .

XGBoost utilizza anche un'approssimazione sulla valutazione di tali punti di divisione. Non so in base a quale criterio scikit impari stia valutando le divisioni, ma potrebbe spiegare il resto della differenza di tempo.


Commenti di indirizzo

Per quanto riguarda la valutazione dei punti di divisione

Tuttavia, cosa intendevi con "XGBoost utilizza anche un'approssimazione sulla valutazione di tali punti di divisione"? per quanto ho capito, per la valutazione usano l'esatta riduzione della funzione obiettivo ottimale, come appare nell'eq (7) nel documento.

L(y,Hio-1+hio)LyHio-1hioLLHio-1io

L(y,Hio-1+hio)L


Grazie @Winks, ho letto il documento e ho capito cosa intendevi con l'algoritmo di approssimazione per la scelta dei candidati divisi. Tuttavia, cosa intendevi con "XGBoost utilizza anche un'approssimazione sulla valutazione di tali punti di divisione"? per quanto ho capito, per la valutazione usano l'esatta riduzione della funzione obiettivo ottimale, come appare nell'eq (7) nel documento.
ihadanny,

Ho modificato la mia risposta per indirizzare il tuo commento. Controllare questo Q / A per maggiori dettagli sulla valutazione dei punti di divisione.
Strizza l'

Grazie mille, @Winks! sarebbe fantastico se tu potessi anche rispondere alla mia domanda più elaborata qui: datascience.stackexchange.com/q/10997/16050
ihadanny

Questa è un'ottima risposta Tripletta !
eliasah,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.