Stratificato Train / Test-split in scikit-learn


95

Devo suddividere i miei dati in un set di addestramento (75%) e un set di test (25%). Attualmente lo faccio con il codice seguente:

X, Xt, userInfo, userInfo_train = sklearn.cross_validation.train_test_split(X, userInfo)   

Tuttavia, vorrei stratificare il mio set di dati di addestramento. Come lo faccio? Ho esaminato il StratifiedKFoldmetodo, ma non mi consente di specificare la divisione 75% / 25% e di stratificare solo il set di dati di addestramento.

Risposte:


161

[aggiornamento per 0.17]

Vedi i documenti di sklearn.model_selection.train_test_split:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    stratify=y, 
                                                    test_size=0.25)

[/ update per 0.17]

C'è una richiesta di pull qui . Ma puoi semplicemente fare train, test = next(iter(StratifiedKFold(...))) e usare il treno e testare gli indici se vuoi.


1
@AndreasMueller Esiste un modo semplice per stratificare i dati di regressione?
Jordan

3
@Jordan nulla è implementato in scikit-learn. Non conosco un modo standard. Potremmo usare i percentili.
Andreas Mueller

@AndreasMueller Hai mai visto il comportamento in cui questo metodo è notevolmente più lento di StratifiedShuffleSplit? Stavo usando il set di dati MNIST.
attivato dal

@activatedgeek che sembra molto strano, dato che train_test_split (... stratify =) sta solo chiamando StratifiedShuffleSplit e prendendo la prima divisione. Sentiti libero di aprire un problema sul tracker con un esempio riproducibile.
Andreas Mueller

@AndreasMueller In realtà non ho aperto un problema perché ho la forte sensazione che sto facendo qualcosa di sbagliato (anche se sono solo 2 righe). Ma se sono ancora in grado di riprodurlo più volte oggi, lo farò!
attivato dal

29

TL; DR: usa StratifiedShuffleSplit contest_size=0.25

Scikit-learn fornisce due moduli per la suddivisione stratificata:

  1. StratifiedKFold : questo modulo è utile come operatore di convalida incrociata k-fold diretta: poiché configurerà set di n_foldsaddestramento / test in modo tale che le classi siano ugualmente bilanciate in entrambi.

Ecco del codice (direttamente dalla documentazione sopra)

>>> skf = cross_validation.StratifiedKFold(y, n_folds=2) #2-fold cross validation
>>> len(skf)
2
>>> for train_index, test_index in skf:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
...    #fit and predict with X_train/test. Use accuracy metrics to check validation performance
  1. StratifiedShuffleSplit : questo modulo crea un singolo set di addestramento / test con classi ugualmente bilanciate (stratificate). Essenzialmente questo è ciò che vuoi con il n_iter=1. Puoi menzionare la dimensione del test qui come intrain_test_split

Codice:

>>> sss = StratifiedShuffleSplit(y, n_iter=1, test_size=0.5, random_state=0)
>>> len(sss)
1
>>> for train_index, test_index in sss:
...    print("TRAIN:", train_index, "TEST:", test_index)
...    X_train, X_test = X[train_index], X[test_index]
...    y_train, y_test = y[train_index], y[test_index]
>>> # fit and predict with your classifier using the above X/y train/test

5
Nota che a partire da 0.18.x, n_iterdovrebbe essere n_splitsper StratifiedShuffleSplit - e che c'è un'API leggermente diversa per questo: scikit-learn.org/stable/modules/generated/…
lollercoaster

2
Se yè una serie di Panda, usay.iloc[train_index], y.iloc[test_index]
Owlright

1
@Owlright Ho provato a utilizzare un dataframe panda e gli indici restituiti da StratifiedShuffleSplit non sono gli indici nel dataframe. dataframe index: 2,3,5 the first split in sss:[(array([2, 1]), array([0]))]:(
Meghna Natraj

2
@tangy perché questo è un ciclo for? non è il caso che quando una linea X_train, X_test = X[train_index], X[test_index]viene invocata sovrascrive X_traine X_test? Perché allora non solo un singolo next(sss)?
Bartek Wójcik


13

Ecco un esempio per i dati continui / di regressione (fino a quando questo problema su GitHub non viene risolto).

min = np.amin(y)
max = np.amax(y)

# 5 bins may be too few for larger datasets.
bins     = np.linspace(start=min, stop=max, num=5)
y_binned = np.digitize(y, bins, right=True)

X_train, X_test, y_train, y_test = train_test_split(
    X, 
    y, 
    stratify=y_binned
)
  • Dove startè il minimo e il stopmassimo del tuo obiettivo continuo.
  • Se non imposti right=True, renderà più o meno il tuo valore massimo un contenitore separato e la tua divisione fallirà sempre perché troppi pochi campioni saranno in quel contenitore aggiuntivo.

6

Oltre alla risposta accettata da @Andreas Mueller, voglio solo aggiungerla come @tangy menzionato sopra:

StratifiedShuffleSplit assomiglia di più a train_test_split ( stratify = y) con funzionalità aggiuntive di:

  1. stratificare per impostazione predefinita
  2. specificando n_splits , divide ripetutamente i dati

0
#train_size is 1 - tst_size - vld_size
tst_size=0.15
vld_size=0.15

X_train_test, X_valid, y_train_test, y_valid = train_test_split(df.drop(y, axis=1), df.y, test_size = vld_size, random_state=13903) 

X_train_test_V=pd.DataFrame(X_train_test)
X_valid=pd.DataFrame(X_valid)

X_train, X_test, y_train, y_test = train_test_split(X_train_test, y_train_test, test_size=tst_size, random_state=13903)

0

Aggiornamento della risposta @tangy dall'alto alla versione corrente di scikit-learn: 0.23.2 ( documentazione di StratifiedShuffleSplit ).

from sklearn.model_selection import StratifiedShuffleSplit

n_splits = 1  # We only want a single split in this case
sss = StratifiedShuffleSplit(n_splits=n_splits, test_size=0.25, random_state=0)

for train_index, test_index in sss.split(X, y):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.