Strano comportamento con Adam Optimizer quando ci si allena troppo a lungo


11

Sto cercando di addestrare un singolo percettrone (1000 unità di input, 1 output, nessun layer nascosto) su 64 punti dati generati casualmente. Sto usando Pytorch usando l'ottimizzatore Adam:

import torch
from torch.autograd import Variable

torch.manual_seed(545345)
N, D_in, D_out = 64, 1000, 1

x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out))

model = torch.nn.Linear(D_in, D_out)
loss_fn = torch.nn.MSELoss(size_average=False)

optimizer = torch.optim.Adam(model.parameters())
for t in xrange(5000):
  y_pred = model(x)
  loss = loss_fn(y_pred, y)

  print(t, loss.data[0])

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

Inizialmente, la perdita diminuisce rapidamente, come previsto:

(0, 91.74887084960938)
(1, 76.85824584960938)
(2, 63.434078216552734)
(3, 51.46927261352539)
(4, 40.942893981933594)
(5, 31.819372177124023)

Circa 300 iterazioni, l'errore raggiunge quasi lo zero:

(300, 2.1734419819452455e-12)
(301, 1.90354676465887e-12)
(302, 2.3347573874232808e-12)

Questo continua per alcune migliaia di iterazioni. Tuttavia, dopo un allenamento troppo lungo, l'errore ricomincia ad aumentare:

(4997, 0.002102422062307596)
(4998, 0.0020302983466535807)
(4999, 0.0017039275262504816)

Perché sta succedendo?


Non credo che il sovradimensionamento lo spieghi: la perdita di addestramento sta aumentando, non la perdita di convalida. Ad esempio, questo non accade quando si utilizza SGD, solo con Adam.
Bai Li,

Il modello ha 1000 parametri e c'è solo 1 punto dati, quindi il modello dovrebbe adattarsi esattamente ai dati e la perdita dovrebbe essere zero.
Bai Li,

Oh scusa, hai ragione. Ci sono 64 punti dati.
Bai Li,

Ci sono 64 punti dati (cioè, vincoli) e 1000 parametri, quindi è possibile trovare scelte per i parametri in modo che l'errore sia zero (e questo è facile da fare analiticamente). La mia domanda è: perché Adamo non trova questo.
Bai Li,

Risposte:


19

Questa piccola instabilità alla fine della convergenza è una caratteristica di Adam (e RMSProp) a causa del modo in cui stima le magnitudini del gradiente negli ultimi passaggi e le divide per esse.

Una cosa che Adamo fa è mantenere una media geometrica mobile dei recenti gradienti e quadrati dei gradienti. I quadrati dei gradienti vengono usati per dividere (un'altra media mobile) del gradiente corrente per decidere il passo corrente. Tuttavia, quando il gradiente diventa e rimane molto vicino allo zero, i quadrati del gradiente diventano così bassi che presentano errori di arrotondamento elevati o sono effettivamente zero, il che può introdurre instabilità (ad esempio un gradiente stabile a lungo termine in una dimensione fa un passo relativamente piccolo da a causa di cambiamenti in altri parametri) e la dimensione del passo inizierà a saltare, prima di sistemarsi di nuovo. 10 - 51010105

Questo in realtà rende Adam meno stabile e peggiore per il tuo problema rispetto a una discesa del gradiente più elementare, supponendo che tu voglia avvicinarti numericamente alla perdita zero come i calcoli consentono il tuo problema.

In pratica sui problemi di apprendimento profondo, non ti avvicini così tanto alla convergenza (e per alcune tecniche di regolarizzazione come l'arresto anticipato, non vuoi comunque), quindi di solito non è una preoccupazione pratica sui tipi di problema che Adam è stato progettato per.

Puoi effettivamente vederlo accadere per RMSProp in un confronto di diversi ottimizzatori (RMSProp è la linea nera - guarda gli ultimi passi appena raggiunge l'obiettivo):

inserisci qui la descrizione dell'immagine

Puoi rendere Adam più stabile e in grado di avvicinarti alla vera convergenza riducendo il tasso di apprendimento. Per esempio

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

Ci vorrà più tempo per ottimizzare. Utilizzando lr=1e-5è necessario allenarsi per oltre 20.000 iterazioni prima di vedere l'instabilità e l'instabilità è meno drammatica, i valori si aggirano intorno a .107


Questa è una visualizzazione spettacolare, Neil. Quali sono le dimensioni effettive? Cosa rappresenta xey? I frame sono delta t o n epoche per frame? Immagino che la stella sia l'ottimale globale in una rappresentazione topografica della disparità (errore) in relazione a due parametri selezionati. La mia ipotesi è corretta?
Douglas Daseeco,

Non è la mia visualizzazione, la troverai in molti posti. Le dimensioni sono unità arbitrarie di parametri di input per una funzione di test e il grafico mostra le linee di contorno per quella funzione (sempre in unità arbitrarie, presumibilmente ridimensionate in modo che NN funzioni correttamente). Ogni frame è un passaggio di aggiornamento del peso. Probabilmente equivale a un aggiornamento mini-batch e, a causa del comportamento di SGD, mi aspetto che venga effettivamente risolto utilizzando esattamente il gradiente reale della funzione di test, ovvero che non vi siano set di dati o campionamenti.
Neil Slater,

1

Il motivo è esattamente come menzionato nell'altra risposta con un grande suggerimento di utilizzare un tasso di apprendimento più piccolo per evitare questo problema attorno a piccole pendenze.

Mi viene in mente un paio di approcci:

  1. È possibile tagliare i gradienti con un limite superiore / inferiore, ma ciò non garantisce la convergenza e può provocare il congelamento dell'allenamento restando intrappolato in alcuni minimi locali e non uscirne mai.

  2. Allenati con una dimensione del lotto superiore, più epoche e con un tasso di apprendimento decaduto. Ora non ho alcuna prova pratica che l'aumento della dimensione di un lotto si traduca in gradienti migliori, ma da quello che avevo osservato affrontando problemi simili ai tuoi, farlo quasi sempre ha aiutato.

Sono sicuro che ci sono altri metodi (come il tasso di apprendimento ciclico ecc.) Che cercano di trovare un tasso di apprendimento ottimale basato sulle statistiche.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.