In che modo LSTM previene il problema del gradiente evanescente?


35

LSTM è stato inventato appositamente per evitare il problema del gradiente di sparizione. Si suppone che lo faccia con il Constant Error Carousel (CEC), che sul diagramma sottostante (da Greff et al. ) Corrisponde al loop attorno alla cella .

LSTM
(fonte: deeplearning4j.org )

E capisco che quella parte può essere vista come una sorta di funzione identitaria, quindi la derivata è una e il gradiente rimane costante.

Quello che non capisco è come non svanisce a causa delle altre funzioni di attivazione? Le porte di input, output e dimenticare usano un sigmoide, la cui derivata è al massimo di 0,25, e geh erano tradizionalmente tanh . In che modo il backpropagating attraverso quelli non fa svanire il gradiente?


2
LSTM è un modello di rete neurale ricorrente che è molto efficiente nel ricordare le dipendenze a lungo termine e che non è vulnerabile al problema del gradiente in via di estinzione. Non sono sicuro del tipo di spiegazione che stai cercando
TheWalkingCube

LSTM: memoria a breve termine. (Rif .: Hochreiter, S. e Schmidhuber, J. (1997). Memoria a breve termine. Calcolo neurale 9 (8): 1735-80 · dicembre 1997)
horaceT

Le sfumature negli LSTM svaniscono, solo più lentamente rispetto alle RNN vanigliate, consentendo loro di catturare dipendenze più distanti. Evitare il problema della scomparsa dei gradienti è ancora un'area di ricerca attiva.
Artem Sobolev,

1
Ti interessa sostenere il lento svanire con un riferimento?
Bayerj,

Risposte:


22

Il gradiente di sparizione è meglio spiegato nel caso monodimensionale. Il multidimensionale è più complicato ma essenzialmente analogo. Puoi rivederlo in questo eccellente documento [1].

Supponiamo di avere uno stato nascosto al momento . Se rendiamo le cose semplici e rimuoviamo pregiudizi e input, abbiamo Quindi puoi dimostrarlohtt

ht=σ(wht1).

htht=k=1ttwσ(whtk)=wtt!!!k=1ttσ(whtk)
Il factored contrassegnato con !!! è quello cruciale. Se il peso non è uguale a 1, decadrà a zero esponenzialmente velocemente in , o crescerà esponenzialmente velocementett .

In LSTM, hai lo stato di cella . La derivata esiste nella forma Qui è l'input per il gate di dimenticanza. Come puoi vedere, non è coinvolto alcun fattore di decadimento esponenzialmente veloce. Di conseguenza, esiste almeno un percorso in cui il gradiente non svanisce. Per la derivazione completa, vedere [2].stvt

stst=k=1ttσ(vt+k).
vt

[1] Pascanu, Razvan, Tomas Mikolov e Yoshua Bengio. "Sulla difficoltà di formare reti neuronali ricorrenti". ICML (3) 28 (2013): 1310-1318.

[2] Bayer, Justin Simon. Rappresentazioni della sequenza di apprendimento. Diss. München, Technische Universität München, Diss., 2015, 2015.


3
Per lstm, h_t non dipende anche da h_ {t-1}? Cosa intendi nel tuo documento quando dici che ds_t / d_s {t-1} "è l'unica parte in cui i gradienti scorrono nel tempo"?
user3243135

@ user3243135 h_t dipende da h_ {t-1}. Tuttavia, supponiamo che ds_t / d_s {t-1} sia mantenuto, anche se svaniscono altri flussi di gradiente, l'intero flusso di gradiente non svanisce. Questo risolve la scomparsa del gradiente.
soloice,

Ho sempre pensato che il problema principale fosse il termine perché se è di solito la derivata di un sigmoid (o qualcosa del genere con una derivata inferiore a 1) che ha causato sicuramente il gradiente di scomparsa (ad es. i sigmoidi sono <1 in grandezza e la loro derivata è che è < 1 di sicuro). Non è per questo che le ReLU sono state accettate nelle CNN? Questa è una cosa che mi ha sempre confuso sulla differenza nel modo in cui il gradiente di fuga è stato affrontato nei modelli feed forward rispetto ai modelli ricorrenti. Qualche chiarimento per questo? σ(z)σ(x)=σ(z)(1-σ(z))
ttσ(whtk)
σ(z)σ(x)=σ(z)(1σ(z))
Pinocchio il

Anche il gradiente del sigmoide potrebbe diventare un problema, ipotizzando una distribuzione di input con grande varianza e / o media lontana da 0. Tuttavia, anche se si utilizzano ReLU, il problema principale persiste: moltiplicare ripetutamente per una matrice di pesi (di solito piccola ) provoca gradienti che svaniscono o, in alcuni casi, in cui la regolarizzazione non è stata adeguata, esplodendo gradienti.
Atassia

3

L'immagine del blocco LSTM di Greff et al. (2015) descrive una variante che gli autori chiamano vaniglia LSTM . È un po 'diverso dalla definizione originale di Hochreiter & Schmidhuber (1997). La definizione originale non includeva la porta di dimenticanza e le connessioni dello spioncino.

Il termine Carousel errore costante è stato utilizzato nel documento originale per indicare la connessione ricorrente dello stato della cella. Considera la definizione originale in cui lo stato della cella viene modificato solo per aggiunta, quando si apre la porta di ingresso. Il gradiente dello stato della cella rispetto allo stato della cella in una fase temporale precedente è zero.

L'errore può ancora entrare nel CEC attraverso il gate di uscita e la funzione di attivazione. La funzione di attivazione riduce leggermente l'entità dell'errore prima che venga aggiunto al CEC. CEC è l'unico posto in cui l'errore può fluire invariato. Ancora una volta, quando si apre la porta di ingresso, l'errore esce attraverso la porta di ingresso, la funzione di attivazione e la trasformazione affine, riducendo l'entità dell'errore.

Pertanto, l'errore viene ridotto quando viene retropropagato attraverso un livello LSTM, ma solo quando entra ed esce dal CEC. L'importante è che non cambi nel CEC, indipendentemente dalla distanza percorsa. Ciò risolve il problema nell'RNN di base che ogni volta che si applica una trasformazione affine e una non linearità, il che significa che maggiore è la distanza temporale tra input e output, minore è l'errore.


2

http://www.felixgers.de/papers/phd.pdf Fare riferimento alle sezioni 2.2 e 3.2.2 in cui viene spiegata la parte dell'errore troncata. Non propagano l'errore se fuoriesce dalla memoria della cella (ovvero se è presente un gate di ingresso chiuso / attivato), ma aggiornano i pesi del gate in base all'errore solo per quell'istante di tempo. Successivamente viene azzerato durante un'ulteriore propagazione posteriore. Questo è un tipo di hack, ma il motivo è che il flusso di errore lungo le porte decade comunque nel tempo.


7
Potresti espandere un po 'su questo? Al momento, la risposta non avrà alcun valore se la posizione del collegamento cambia o il foglio viene portato offline. Perlomeno sarebbe utile fornire una citazione (riferimento) completa che consentirà di ritrovare il documento se il collegamento smette di funzionare, ma sarebbe meglio un breve riassunto che renda questa risposta autonoma.
Silverfish,

2

Vorrei aggiungere alcuni dettagli alla risposta accettata, perché penso che sia un po 'più sfumato e la sfumatura potrebbe non essere ovvia per qualcuno che apprende prima le RNN.

Per la vaniglia RNN, .

htht=k=1ttwσ(whtk)

Per LSTM,

stst=k=1ttσ(vt+k)

  • una domanda naturale da porsi è: entrambe le somme di prodotto non hanno un termine sigmoide che quando moltiplicato insieme tempi possono svanire?tt
  • la risposta è , motivo per cui LSTM soffrirà anche di gradienti di fuga, ma non tanto quanto la RNN vanigliata

La differenza è per RNN vaniglia, il gradiente decade con mentre per LSTM il gradiente decade con .wσ()σ()

Per LSTM, esiste un insieme di pesi che può essere appreso in modo tale che Supponiamo per un po 'di peso e input . Quindi la rete neurale può imparare una grande per evitare che i gradienti scompaiano.

σ()1
vt+k=wxwxw

ad es. nel caso 1D se , il fattore di decadimento o il gradiente muore come:x=1w=10 vt+k=10σ()=0.99995

(0.99995)tt

Per la RNN alla vaniglia, non esiste un insieme di pesi che può essere appreso in modo tale che

wσ(whtk)1

ad es. nel caso 1D, supponiamo che . La funzione raggiunge un massimo di con . Questo significa che il gradiente decadrà come,htk=1wσ(w1)0.224w=1.5434

(0.224)tt

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.