Previsione probabilistica casuale della foresta vs voto a maggioranza


10

Scikit Learn sembra usare la previsione probabilistica invece del voto della maggioranza per la tecnica di aggregazione del modello senza una spiegazione del perché (1.9.2.1. Foreste casuali).

C'è una chiara spiegazione del perché? Inoltre, esiste un buon articolo o articolo di revisione per le varie tecniche di aggregazione dei modelli che possono essere utilizzate per l'insaccamento della foresta casuale?

Grazie!

Risposte:


10

A queste domande si risponde sempre meglio guardando il codice, se si parla fluentemente Python.

RandomForestClassifier.predict, almeno nell'attuale versione 0.16.1, prevede la classe con la stima di probabilità più elevata, come indicato da predict_proba. ( questa linea )

La documentazione per predict_probadice:

Le probabilità di classe previste di un campione di input vengono calcolate come probabilità di classe previste medie degli alberi nella foresta. La probabilità di classe di un singolo albero è la frazione di campioni della stessa classe in una foglia.

La differenza rispetto al metodo originale è probabilmente solo quella che predictfornisce previsioni coerenti con predict_proba. Il risultato è talvolta chiamato "voto debole", piuttosto che il voto di maggioranza "duro" usato nel documento originale di Breiman. Nella ricerca rapida non sono riuscito a trovare un confronto adeguato delle prestazioni dei due metodi, ma entrambi sembrano abbastanza ragionevoli in questa situazione.

La predictdocumentazione è alquanto fuorviante; Ho inviato una richiesta pull per risolverlo.

Se invece vuoi fare una previsione del voto a maggioranza, ecco una funzione per farlo. Chiamalo come predict_majvote(clf, X)piuttosto che clf.predict(X). (Basato su predict_proba; solo leggermente testato, ma penso che dovrebbe funzionare.)

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.
    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce
    modes, counts = mode(all_preds, axis=0)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_.dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[:, k], axis=0)
        return preds

Sul muto caso sintetico che ho provato, le previsioni concordavano sempre con il predictmetodo.


Ottima risposta, Dougal! Grazie per aver dedicato del tempo a spiegarlo attentamente. Considera anche di andare oltre lo stack di overflow e di rispondere a questa domanda lì .
user1745038,

1
C'è anche un documento, qui , che affronta la previsione probabilistica.
user1745038,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.