È possibile modificare la metrica utilizzata dal callback di Early Stopping in Keras?


13

Quando si utilizza il callback con arresto anticipato in Keras, l'allenamento si interrompe quando alcune metriche (generalmente perdita di convalida) non aumentano. Esiste un modo per utilizzare un'altra metrica (come precisione, richiamo, misura f) anziché perdita di convalida? Tutti gli esempi che ho visto finora sono simili a questo: callbacks.EarlyStopping (monitor = 'val_loss', pazienza = 5, verbose = 0, mode = 'auto')

Risposte:


11

È possibile utilizzare qualsiasi funzione metrica specificata durante la compilazione del modello.

Supponiamo che tu abbia la seguente funzione metrica:

def my_metric(y_true, y_pred):
     return some_metric_computation(y_true, y_pred)

L'unico requisito per questa funzione è che accetta il vero y e il previsto y.

Quando compili il modello, specifichi questa metrica, analogamente a come specifichi metriche incorporate come "accuratezza":

model.compile(metrics=['accuracy', my_metric], ...)

Si noti che stiamo usando il nome della funzione my_metric senza '' (in contrasto con l'accuratezza di 'accuratezza').

Quindi, se definisci EarlyStopping, usa semplicemente il nome della funzione (questa volta con ''):

EarlyStopping(monitor='my_metric', mode='min')

Assicurati di specificare la modalità (min se inferiore è meglio, max se superiore è migliore).

Puoi usarlo come qualsiasi metrica incorporata. Questo probabilmente funziona anche con altri callback come ModelCheckpoint (ma non l'ho provato). Internamente, Keras aggiunge semplicemente la nuova metrica all'elenco delle metriche disponibili per questo modello usando il nome della funzione.

Se specifichi i dati per la validazione in model.fit (...), puoi anche usarli per EarlyStopping usando 'val_my_metric'.


3

Certo, creane uno tuo!

class EarlyStopByF1(keras.callbacks.Callback):
    def __init__(self, value = 0, verbose = 0):
        super(keras.callbacks.Callback, self).__init__()
        self.value = value
        self.verbose = verbose


    def on_epoch_end(self, epoch, logs={}):
         predict = np.asarray(self.model.predict(self.validation_data[0]))
         target = self.validation_data[1]
         score = f1_score(target, prediction)
         if score > self.value:
            if self.verbose >0:
                print("Epoch %05d: early stopping Threshold" % epoch)
            self.model.stop_training = True


callbacks = [EarlyStopByF1(value = .90, verbose =1)]
model.fit(X, y, batch_size = 32, nb_epoch=nb_epoch, verbose = 1, 
validation_data(X_val,y_val), callbacks=callbacks)

Non l'ho provato, ma quello dovrebbe essere il sapore generale di come lo fai. Se non funziona fammelo sapere e riproverò nel fine settimana. Suppongo anche che tu abbia già implementato il tuo punteggio f1. Se non solo importare per sklearn.


+1 Funziona ancora dal 2/11/2020 utilizzando gli ultimi Keras e Python 3.7
Austin,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.