Dimentica il livello in una rete neurale ricorrente (RNN) -


13

Sto cercando di capire le dimensioni di ciascuna variabile in un RNN nel livello di dimenticanza, tuttavia, non sono sicuro di essere sulla strada giusta. L'immagine e l'equazione successive sono tratte dal post sul blog di Colah "Capire le reti LSTM" :

inserisci qui la descrizione dell'immagine

dove:

  • xt è un input di dimensione vettorem1
  • ht1 è lo stato nascosto della dimensione vettoren1
  • [xt,ht1] è una concatenazione (ad esempio, se , quindi )xt=[1,2,3],ht1=[4,5,6][xt,ht1]=[1,2,3,4,5,6]
  • wf è pesi di dimensione matrice, dove è il numero di stati di cella (se , e nell'esempio precedente e se abbiamo 3 stati di cella, allora matrice)k(m+n)km=3n=3wf=33
  • bf è un bias di dimensione vettore, dove è il numero di stati di cella (poiché come nell'esempio precedente, allora è un vettore ).k1kk=3bf31

Se impostiamo su: wf

[1234565678910345678]

E essere:bf[1,2,3]

QuindiWf.[ht1,xt]=

[1234565678910345678].[123456]=[91175133]

Quindi possiamo aggiungere il bias,Wf.[ht1,xt]+bf=

[91175133]+[123]=[92177136]

Quindi li inseriamo in una funzione sigmoide: , dove , quindi eseguiamo questo elemento funzione saggio e ottieni .11+exx=[92177136]

[111]

Il che significa per ogni stato di cella, , (ci sono stati di cella), gli permettiamo di passare al livello successivo.Ct1k=3

L'assunzione di cui sopra è corretta?

Questo significa anche che il numero di stato delle celle e stato nascosto è lo stesso?

Risposte:


14

Ottima domanda!

tl; dr: lo stato della cella e lo stato nascosto sono due cose diverse, ma lo stato nascosto dipende dallo stato della cella e hanno effettivamente le stesse dimensioni.

Spiegazione più lunga

La differenza tra i due può essere vista dallo schema seguente (parte dello stesso blog):

Lo stato della cella è la linea in grassetto che viaggia da ovest a est attraverso la cima.

Lo stato della cella è la linea in grassetto che viaggia da ovest a est attraverso la cima. L'intero blocco verde è chiamato "cellula".

Lo stato nascosto della fase temporale precedente viene trattato come parte dell'input nella fase temporale corrente.

Tuttavia, è un po 'più difficile vedere la dipendenza tra i due senza fare una passeggiata completa. Lo farò qui, per fornire un'altra prospettiva, ma fortemente influenzato dal blog. La mia notazione sarà la stessa e userò le immagini del blog nella mia spiegazione.

Mi piace pensare all'ordine delle operazioni in modo leggermente diverso dal modo in cui sono state presentate nel blog. Personalmente, come partire dal gate di input. Presenterò questo punto di vista qui sotto, ma tieni presente che il blog potrebbe benissimo essere il modo migliore per impostare un LSTM a livello computazionale e questa spiegazione è puramente concettuale.

Ecco cosa sta succedendo:

La porta d'ingresso

inserisci qui la descrizione dell'immagine

L'input al momento è e . Questi vengono concatenati e alimentati in una funzione non lineare (in questo caso un sigmoide). Questa funzione sigmoide si chiama "input gate", perché funge da stopgap per l'ingresso. Decide stocasticamente quali valori aggiorneremo in questo momento, in base all'input corrente.txtht1

Cioè, (seguendo il tuo esempio), se abbiamo un vettore di input e uno stato nascosto precedente , allora il gate di ingresso fa quanto segue:xt=[1,2,3]ht=[4,5,6]

a) Concatenare e per darcixtht1[1,2,3,4,5,6]

b) Calcola volte il vettore concatenato e aggiungi il bias (in matematica: , dove è la matrice di peso dal vettore di input alla non linearità; è l'input bias).WiWi[xt,ht1]+biWibi

Supponiamo di passare da un input a sei dimensioni (la lunghezza del vettore di input concatenato) a una decisione tridimensionale su quali stati aggiornare. Ciò significa che abbiamo bisogno di una matrice di peso 3x6 e di un vettore di polarizzazione 3x1. Diamo alcuni valori:

Wi=[111111222222333333]

bi=[111]

Il calcolo sarebbe:

[111111222222333333][123456]+[111]=[224262]

c) Alimenta il calcolo precedente in una non linearità:it=σ(Wi[xt,ht1]+bi)

σ(x)=11+exp(x) (lo applichiamo elementally ai valori nel vettore )x

σ([224262])=[11+exp(22),11+exp(42),11+exp(62)]=[1,1,1]

In inglese, ciò significa che aggiorneremo tutti i nostri stati.

Il gate di ingresso ha una seconda parte:

d)Ct~=tanh(WC[xt,ht1]+bC)

Il punto di questa parte è calcolare come aggiorneremmo lo stato, se dovessimo farlo. È il contributo del nuovo input in questo momento allo stato della cella. Il calcolo segue la stessa procedura illustrata sopra, ma con un'unità tanh anziché un'unità sigmoidea.

L'output viene moltiplicato per quel vettore binario , ma lo quando arriveremo all'aggiornamento della cella.Ct~it

Insieme, ci dice quali stati vogliamo aggiornare e ci dice come vogliamo aggiornarli. Ci dice quali nuove informazioni vogliamo aggiungere alla nostra rappresentazione finora.itCt~

Poi arriva la porta dell'oblio, che era il nocciolo della tua domanda.

La porta dell'oblio

inserisci qui la descrizione dell'immagine

Lo scopo del cancello di dimenticanza è di rimuovere le informazioni apprese in precedenza che non sono più pertinenti. L'esempio fornito nel blog è basato sul linguaggio, ma possiamo anche pensare a una finestra scorrevole. Se stai modellando una serie temporale che è naturalmente rappresentata da numeri interi, come il conteggio di individui infetti in un'area durante un'epidemia, forse una volta che la malattia si è estinta in un'area, non vorrai più preoccuparti di considerare quell'area quando pensando a come la malattia viaggerà dopo.

Proprio come il livello di input, il livello di dimenticanza prende lo stato nascosto dal passaggio temporale precedente e il nuovo input dal passaggio temporale corrente e li concatena. Il punto è decidere stocasticamente cosa dimenticare e cosa ricordare. Nel calcolo precedente, ho mostrato un output di livello sigmoideo di tutti gli 1, ma in realtà era più vicino a 0,999 e ho arrotondato per eccesso.

Il calcolo è molto simile a quello che abbiamo fatto nel livello di input:

ft=σ(Wf[xt,ht1]+bf)

Questo ci darà un vettore di dimensione 3 con valori compresi tra 0 e 1. Facciamo finta che ci abbia dato:

[0.5,0.8,0.9]

Quindi decidiamo stocasticamente in base a questi valori quali di queste tre parti di informazioni dimenticare. Un modo per fare ciò è generare un numero da una distribuzione uniforme (0, 1) e se quel numero è inferiore alla probabilità che l'unità si accenda (0,5, 0,8 e 0,9 per le unità 1, 2 e 3 rispettivamente), quindi accendiamo quell'unità. In questo caso, ciò significherebbe che dimenticheremo tali informazioni.

Nota rapida: il livello di input e il livello di dimenticanza sono indipendenti. Se fossi un giocatore di scommesse, scommetterei che è un buon posto per la parallelizzazione.

Aggiornamento dello stato della cella

inserisci qui la descrizione dell'immagine

Ora abbiamo tutto ciò che serve per aggiornare lo stato della cella. Prendiamo una combinazione delle informazioni dall'ingresso e dalle porte di dimenticanza:

Ct=ftCt1+itCt~

Ora, sarà un po 'strano. Invece di moltiplicarsi come abbiamo fatto prima, qui indica il prodotto Hadamard, che è un prodotto entry-saggio.

A parte: prodotto Hadamard

Ad esempio, se avessimo due vettori e e volessimo prendere il prodotto Hadamard, faremmo questo:x1=[1,2,3]x2=[3,2,1]

x1x2=[(13),(22),(31)]=[3,4,3]

Fine a parte.

In questo modo, uniamo ciò che vogliamo aggiungere allo stato della cella (input) con ciò che vogliamo togliere dallo stato della cella (dimentica). Il risultato è il nuovo stato della cella.

La porta di uscita

inserisci qui la descrizione dell'immagine

Questo ci darà il nuovo stato nascosto. Essenzialmente il punto del gate di uscita è decidere quali informazioni vogliamo che la parte successiva del modello prenda in considerazione quando si aggiorna il successivo stato di cella. L'esempio nel blog è di nuovo, la lingua: se il sostantivo è plurale, la coniugazione del verbo nel passaggio successivo cambierà. In un modello patologico, se la suscettibilità degli individui in una determinata area è diversa rispetto a un'altra, la probabilità di contrarre un'infezione può cambiare.

Il livello di output accetta di nuovo lo stesso input, ma considera quindi lo stato della cella aggiornato:

ot=σ(Wo[xt,ht1]+bo)

Ancora una volta, questo ci dà un vettore di probabilità. Quindi calcoliamo:

ht=ottanh(Ct)

Quindi lo stato attuale della cella e la porta di uscita devono concordare su cosa emettere.

Cioè, se il risultato di è dopo che è stata presa la decisione stocastica se ciascuna unità è accesa o spenta e il risultato di è , quindi quando prendiamo il prodotto Hadamard, otterremo e solo le unità che sono state accese sia dal gate di uscita che nello stato della cella faranno parte dell'output finale.tanh(Ct)[0,1,1]ot[0,0,1][0,0,1]

[EDIT: C'è un commento sul blog che dice che viene nuovamente trasformato in un output effettivo di , il che significa che l'output effettivo sullo schermo (supponendo che tu ne abbia alcuni) è il risultato di un'altra trasformazione non lineare.]htyt=σ(Wht)

Il diagramma mostra che va in due posizioni: la cella successiva e "output" - sullo schermo. Penso che la seconda parte sia facoltativa.ht

Ci sono molte varianti su LSTM, ma questo copre l'essenziale!


Grazie per la tua risposta! Ho una domanda in più se non ti dispiace. Una rete neurale profonda può essere profonda perché la derivata di ReLU è 1 (se l'uscita è maggiore di 0). È lo stesso caso anche per questa cellula? Non sono sicuro di come Tanh e Sigmoid possano avere una derivata costante di 1?
user1157751,

Il piacere è tutto mio! Una rete neurale è considerata "profonda" quando ha più di un livello nascosto. I derivati ​​delle funzioni di attivazione (tanh, sigmoid, ReLU) influiscono sull'addestramento della rete. Come dici tu, poiché ReLU ha una pendenza costante se il suo input è maggiore di 0, la sua derivata è 1 se ci troviamo in quella regione della funzione. Le unità tanh e sigmoid hanno una derivata vicina a 1 se siamo nel mezzo della loro regione di attivazione, ma la loro derivata non sarà costante. Forse dovrei fare un post sul blog separato sui derivati ​​....
StatsSorceress,

Puoi mostrare un esempio della loro derivata vicino a 1 nella regione di attivazione? Ho visto molte risorse che parlano del derivato ma non viene fatta matematica?
user1157751,

Buona idea, ma mi ci vorrà del tempo per scrivere un post adeguato a riguardo. Nel frattempo, pensa alla forma della funzione tanh: è una "S" allungata. Nel mezzo è dove la derivata è la più alta. Laddove la S è piatta (le code della S) la derivata è 0. Ho visto una fonte secondo cui i sigmoidi hanno una derivata massima di 0,25, ma non ho un limite equivalente per tanh.
StatsSorceress

La parte che non capisco è diversa da ReLU con derivata costante 1 dove x> 0, ma sigmoid e tanh avevano un valore variabile per entrambi i suoi derivati. Come può essere "costante"?
user1157751
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.