Come utilizzare l'output di GridSearch?


23

Attualmente sto lavorando con Python e Scikit per imparare a fini di classificazione, e facendo alcune letture su GridSearch ho pensato che questo fosse un ottimo modo per ottimizzare i miei parametri di stima per ottenere i migliori risultati.

La mia metodologia è questa:

  1. Dividi i miei dati in allenamento / test.
  2. Usa GridSearch con validazione 5Fold Cross per addestrare e testare i miei stimatori (Random Forest, Gradient Boost, SVC tra gli altri) per ottenere i migliori stimatori con la combinazione ottimale di iper parametri.
  3. Calcolo quindi le metriche su ciascuno dei miei stimatori come Precisione, Richiamo, FMeasure e Coefficiente di correlazione di Matthews, usando il mio set di test per prevedere le classificazioni e confrontarle con le etichette delle classi effettive.

È a questo punto che vedo comportamenti strani e non sono sicuro di come procedere. Prendo il .best_estimator_ da GridSearch e lo utilizzo come output "ottimale" dalla ricerca della griglia ed eseguo una previsione utilizzando questo stimatore? Se lo faccio, trovo che le metriche della fase 3 sono in genere molto più basse rispetto a quando mi alleno semplicemente su tutti i dati di allenamento e collaudo sul set di test. Oppure, prendo semplicemente l'oggetto GridSearchCV di output come nuovo stimatore ? Se lo faccio ottengo punteggi migliori per le mie metriche dello stage 3, ma sembra strano usare un oggetto GridSearchCV invece del classificatore previsto (ad esempio una foresta casuale) ...

EDIT: Quindi la mia domanda è qual è la differenza tra l'oggetto GridSearchCV restituito e l'attributo .best_estimator_? Quale di questi dovrei usare per calcolare ulteriori parametri? Posso usare questo output come un normale classificatore (ad es. Usando predict), oppure come dovrei usarlo?

Risposte:


27

Decisi di andare via e trovare le risposte che avrebbero soddisfatto la mia domanda, e scriverle qui per chiunque si chiedesse.

L'attributo .best_estimator_ è un'istanza del tipo di modello specificato, che ha la combinazione "migliore" di determinati parametri da param_grid. L'utilità di questa istanza dipende dal fatto che il parametro refit sia impostato su True (è per impostazione predefinita). Per esempio:

clf = GridSearchCV(estimator=RandomForestClassifier(), 
                    param_grid=parameter_candidates,
                    cv=5,
                    refit=True,
                    error_score=0,
                    n_jobs=-1)

clf.fit(training_set, training_classifications)
optimised_random_forest = clf.best_estimator_
return optimised_random_forest

Restituirà un RandomForestClassifier. Questo è tutto abbastanza chiaro dalla documentazione . Ciò che non è chiaro dalla documentazione è perché la maggior parte degli esempi non utilizza specificamente il .best_estimator_ e invece fa questo:

clf = GridSearchCV(estimator=RandomForestClassifier(), 
                    param_grid=parameter_candidates,
                    cv=5,
                    refit=True,
                    error_score=0,
                    n_jobs=-1)

clf.fit(training_set, training_classifications)
return clf

Questo secondo approccio restituisce un'istanza GridSearchCV, con tutte le campane e fischietti di GridSearchCV come .best_estimator_, .best_params, ecc., Che può essere usato come un classificatore addestrato perché:

Optimised Random Forest Accuracy:  0.916970802919708
[[139  47]
 [ 44 866]]
GridSearchCV Accuracy:  0.916970802919708
[[139  47]
 [ 44 866]]

Utilizza solo la stessa migliore istanza dello stimatore durante le previsioni. Quindi in pratica non c'è alcuna differenza tra questi due a meno che non si desideri solo l'istanza dello stimatore stessa. Come nota a margine, le mie differenze nelle metriche non erano correlate e dipendevano da una funzione di ponderazione della classe con errori.


Grazie per il tuo post @Dan, è molto utile. Volevo chiederti un chiarimento. In quest'ultimo caso, se ho refit=Falsequindi clf.fitnon sarà presente con i migliori classificatore?
Poete Maudit,

@PoeteMaudit Il parametro refit dice alla funzione GridSearchCV di prendere i migliori parametri trovati e riqualificare il modello usando quei parametri sull'intero set di dati. Se refit = False, allora best_estimator non è disponibile, secondo la documentazione: scikit-learn.org/stable/modules/generated/…
Dan Carter

0

GridSearchCV ti consente di combinare uno stimatore con un preambolo di ricerca della griglia per ottimizzare gli iperparametri. Il metodo seleziona il parametro ottimale dalla ricerca della griglia e lo utilizza con lo stimatore selezionato dall'utente. GridSearchCV eredita i metodi dal classificatore, quindi sì, puoi usare i metodi .score, .predict, ecc. Direttamente attraverso l'interfaccia GridSearchCV. Se desideri estrarre i migliori iperparametri identificati dalla ricerca della griglia, puoi utilizzare .best_params_ e questo restituirà il miglior iperparametro. È quindi possibile passare questo iperparametro allo stimatore separatamente.

L'uso diretto di .predict produrrà gli stessi risultati ottenuti dall'ottenimento del miglior iperparametro tramite .best_param_ e quindi dall'utilizzo nel modello. Comprendendo il funzionamento alla base della ricerca in griglia, possiamo capire perché questo è il caso.


Ricerca griglia

Questa tecnica viene utilizzata per trovare i parametri ottimali da utilizzare con un algoritmo. NON si tratta dei pesi o del modello, quelli appresi utilizzando i dati. Questo ovviamente è abbastanza confuso, quindi distinguerò tra questi parametri, chiamando un iperparametro.

Gli iperparametri sono come i k in k-vicini più vicini (k-NN). k-NN richiede all'utente di selezionare quale vicino considerare durante il calcolo della distanza. L'algoritmo quindi sintonizza un parametro, una soglia, per vedere se un nuovo esempio rientra nella distribuzione appresa, questo viene fatto con i dati.

Come scegliamo k?

Alcune persone seguono semplicemente raccomandazioni basate su studi passati del tipo di dati. Altri usano la ricerca della griglia. Questo metodo sarà in grado di determinare meglio quale k è ottimale da utilizzare per i tuoi dati.

Come funziona?

[1,2,3,...,10]

Questo va contro i principi di non usare i dati di test !!

nnn-1n

Il valore dell'iperparametro selezionato è quello che ottiene le massime prestazioni medie tra le pieghe. Una volta che sei soddisfatto del tuo algoritmo, puoi testarlo sul set di test. Se si passa direttamente al set di test, si rischia un overfitting.


Ciao Jah, questa è una buona risposta ma non sono ancora più saggio della risposta alla mia domanda. Ho aggiornato il titolo della domanda e la domanda stessa per cercare di chiarire le cose.
Dan Carter,

Scrivi la tua ricerca della griglia. È letteralmente creare un array, quindi aggiungere un ciclo for attorno al modello. Quindi, alla fine del ciclo for, registrare le prestazioni risultanti in un array. Dopo aver esaminato tutti i possibili valori nella griglia, osserva le matrici delle prestazioni e scegli quella migliore. Questo è il valore ottimale per il tuo iperparametro. Fare affidamento su funzioni integrate per le basi non è altamente raccomandato per la scienza dei dati. I dati variano così selvaggiamente ed è meglio che tu abbia il controllo!
JahKnows

Sarebbe un buon suggerimento se avessi un solo iperparametro da ottimizzare, ma se ne avessi 4? 5? Un ciclo annidato 4/5 volte è brutto e non vedo la necessità di reinventare la ruota qui, sarebbe una perdita di tempo ed è la ragione per cui esistono pacchetti come questo.
Dan Carter,

GridSearchCV ti consente di combinare uno stimatore con l'impostazione GridSearchCV. Quindi fa esattamente quello che abbiamo appena discusso. Quindi seleziona il parametro ottimale e lo utilizza con lo stimatore selezionato. GridSearchCV eredita i metodi dal classificatore, quindi sì, puoi usare i metodi .score, .predict, ecc. Direttamente attraverso l'interfaccia GridSearchCV. Non consiglio di farlo, tuttavia, strumenti più semplici significano meno controllo. Per qualcosa di così semplice come una ricerca in griglia, basta codificarlo da soli.
JahKnows

1
Questa risposta non affronta la domanda relativa all'utilizzo di GridSearchCV.
Hobbes
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.