Quali parametri dovrebbero essere usati per l'arresto anticipato?


97

Sto addestrando una rete neurale per il mio progetto usando Keras. Keras ha fornito una funzione per l'arresto anticipato. Posso sapere quali parametri devono essere osservati per evitare che la mia rete neurale si adatti eccessivamente utilizzando l'arresto anticipato?

Risposte:


157

arresto anticipato

L'arresto anticipato è fondamentalmente l'interruzione dell'allenamento una volta che la perdita inizia ad aumentare (o in altre parole l'accuratezza della convalida inizia a diminuire). Secondo i documenti è usato come segue;

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=0,
                              verbose=0, mode='auto')

I valori dipendono dalla tua implementazione (problema, dimensione del batch ecc ...) ma generalmente per evitare l'overfitting userei;

  1. Monitorare la perdita di convalida (è necessario utilizzare la convalida incrociata o almeno set di addestramento / test) impostando l' monitor argomento su 'val_loss'.
  2. min_deltaè una soglia per quantificare o meno una perdita in un'epoca come miglioramento. Se la differenza di perdita è inferiore min_delta, viene quantificata come nessun miglioramento. Meglio lasciarlo a 0 poiché siamo interessati a quando la perdita peggiora.
  3. patiencel'argomento rappresenta il numero di epoche prima di fermarsi una volta che la perdita inizia ad aumentare (smette di migliorare). Questo dipende dalla tua implementazione, se usi lotti molto piccoli o un grande tasso di apprendimento la tua perdita a zig-zag (la precisione sarà più rumorosa), quindi è meglio impostare un patienceargomento ampio . Se utilizzi lotti di grandi dimensioni e una bassa velocità di apprendimento, la perdita sarà più agevole, quindi puoi utilizzare un patienceargomento più piccolo . In ogni caso, lo lascerò come 2 in modo da dare più possibilità al modello.
  4. verbose decide cosa stampare, lasciarlo al valore predefinito (0).
  5. model'argomento dipende dalla direzione che ha la quantità monitorata (dovrebbe diminuire o aumentare), poiché monitoriamo la perdita, possiamo usarla min. Ma lasciamo che sia Keras a gestirlo per noi e lo impostiauto

Quindi userei qualcosa di simile e sperimenterei tracciando la perdita di errore con e senza arresto anticipato.

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=2,
                              verbose=0, mode='auto')

Per possibili ambiguità su come funzionano i callback, cercherò di spiegare di più. Una volta richiamato fit(... callbacks=[es])il modello, Keras chiama funzioni predeterminate di oggetti callback dati. Queste funzioni possono essere chiamati on_train_begin, on_train_end, on_epoch_begin, on_epoch_ende on_batch_begin, on_batch_end. La richiamata anticipata viene richiamata alla fine di ogni epoca, confronta il miglior valore monitorato con quello attuale e si ferma se le condizioni sono soddisfatte (quante epoche sono trascorse dall'osservazione del miglior valore monitorato ed è più di un argomento di pazienza, la differenza tra l'ultimo valore è maggiore di min_delta ecc ..).

Come indicato da @BrentFaust nei commenti, l'addestramento del modello continuerà fino a quando non vengono soddisfatte le condizioni di arresto anticipato o il epochsparametro (predefinito = 10) in fit(). L'impostazione di un callback di arresto anticipato non farà sì che il modello si alleni oltre il suo epochsparametro. Quindi chiamare una fit()funzione con un epochsvalore maggiore trarrebbe maggiori benefici dal callback di arresto anticipato.


3
@AizuddinAzman close, min_deltaè una soglia per quantificare o meno la variazione del valore monitorato come miglioramento. Quindi sì, se diamo monitor = 'val_loss', si riferirebbe alla differenza tra l'attuale perdita di convalida e la precedente perdita di convalida. In pratica, se si dà min_delta=0.1una diminuzione della perdita di convalida (corrente - precedente) inferiore a 0,1 non si quantifica, quindi si interromperà l'allenamento (se si ha patience = 0).
umutto

3
Nota che callbacks=[EarlyStopping(patience=2)]non ha alcun effetto, a meno che non sia dato epoche model.fit(..., epochs=max_epochs).
Brent Faust

1
@BrentFaust Anche questa è la mia comprensione, ho scritto la risposta supponendo che il modello venga addestrato con almeno 10 epoche (come impostazione predefinita). Dopo il tuo commento, mi sono reso conto che potrebbe esserci un caso in cui il programmatore chiama in forma epoch=1in un ciclo for (per vari casi d'uso) in cui questa richiamata fallirebbe. Se c'è ambiguità nella mia risposta, cercherò di metterla in un modo migliore.
umutto

4
@AdmiralWen Da quando ho scritto la risposta, il codice è leggermente cambiato. Se stai usando l'ultima versione di Keras, puoi usare l' restore_best_weightsargomento (non ancora sulla documentazione), che carica il modello con i pesi migliori dopo l'allenamento. Ma, per i tuoi scopi, ModelCheckpointuserei la richiamata con save_best_onlyargomenti. Puoi controllare la documentazione, è semplice da usare ma devi caricare manualmente i pesi migliori dopo l'allenamento.
umutto

1
@umutto Ciao grazie per il suggerimento di restore_best_weights, tuttavia non sono in grado di usarlo, `es = EarlyStopping (monitor = 'val_acc', min_delta = 1e-4, patience = patience_, verbose = 1, restore_best_weights = True) TypeError: __init __ () ha ricevuto un argomento parola chiave imprevisto 'restore_best_weights'`. Qualche idea? keras 2.2.2, tf, 1.10 qual è la tua versione?
Haramoz
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.