ROC medio per ripetute convalide incrociate di 10 volte con stime di probabilità


15

Sto pianificando di utilizzare una convalida incrociata ripetuta (10 volte) stratificata su circa 10.000 casi usando l'algoritmo di apprendimento automatico. Ogni volta che la ripetizione verrà eseguita con seme casuale diverso.

In questo processo creo 10 istanze di stime di probabilità per ciascun caso. 1 istanza di stima della probabilità per ciascuna delle 10 ripetizioni della convalida incrociata di 10 volte

Posso calcolare una media di 10 probabilità per ciascun caso e quindi creare una nuova curva ROC media (che rappresenta i risultati di CV 10 volte ripetuto), che può essere confrontata con altre curve ROC mediante confronti accoppiati?

Risposte:


13

Dalla tua descrizione sembra avere perfettamente senso: non solo puoi calcolare la curva ROC media, ma anche la varianza attorno ad essa per costruire intervalli di confidenza. Dovrebbe darti l'idea di quanto sia stabile il tuo modello.

Ad esempio, in questo modo:

inserisci qui la descrizione dell'immagine

Qui inserisco singole curve ROC, nonché la curva media e gli intervalli di confidenza. Ci sono aree in cui le curve sono d'accordo, quindi abbiamo meno varianze e ci sono aree in cui non sono d'accordo.

Per CV ripetuti puoi semplicemente ripeterlo più volte e ottenere la media totale su tutte le singole pieghe:

inserisci qui la descrizione dell'immagine

È abbastanza simile all'immagine precedente, ma fornisce stime più stabili (cioè affidabili) della media e della varianza.

Ecco il codice per ottenere la trama:

import matplotlib.pyplot as plt
import numpy as np
from scipy import interp

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std


plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.axes().set_aspect('equal', 'datalim')
plt.show()

Per CV ripetuti:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Fonte di ispirazione: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html


3

Non è corretto per le probabilità medie perché ciò non rappresenterebbe le previsioni che si sta tentando di convalidare e comporta la contaminazione attraverso i campioni di convalida.

Si noti che per ottenere un'adeguata precisione possono essere necessarie 100 ripetizioni di convalida incrociata di 10 volte. Oppure usa il bootstrap di ottimismo Efron-Gong che richiede meno iterazioni per la stessa precisione (vedi ad esempio le funzioni del rmspacchetto R validate).

c


Potresti per favore approfondire ulteriormente il motivo per cui la media non è corretta?
DataD'oh,

Già dichiarato. Devi convalidare la misura che utilizzerai nel campo.
Frank Harrell,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.