Cattura di schemi iniziali quando si utilizza la backpropagation troncata nel tempo (RNN / LSTM)


12

Supponiamo che io usi un RNN / LSTM per fare l'analisi del sentiment, che è un approccio molti-a-uno (vedi questo blog ). La rete viene addestrata attraverso una backpropagation troncata nel tempo (BPTT), dove la rete viene srotolata solo per gli ultimi 30 passaggi.

Nel mio caso, ciascuna delle mie sezioni di testo che voglio classificare sono molto più lunghe dei 30 passaggi che vengono srotolati (~ 100 parole). Sulla base delle mie conoscenze, BPTT viene eseguito una sola volta per una singola sezione di testo, ovvero quando ha superato l'intera sezione di testo e ha calcolato il target di classificazione binaria, , che quindi confronta con la funzione di perdita per trovare l'errore.y

I gradienti non verranno mai calcolati rispetto alle prime parole di ogni sezione di testo. In che modo RNN / LSTM può ancora regolare i suoi pesi per acquisire schemi specifici che si verificano solo nelle prime parole? Ad esempio, supponiamo che tutte le frasi contrassegnate come inizino con "I love this" e che tutte le frasi contrassegnate come inizino con "Lo odio". In che modo RNN / LSTM lo catturerebbe quando viene srotolato solo per gli ultimi 30 passaggi quando raggiunge la fine di una sequenza lunga 100 passaggi?positivenegative


di solito, l'abbreviazione è TBPTT per Troncated Back-Propagation Through Time.
Charlie Parker,

Risposte:


11

È vero che limitare la propagazione del gradiente a 30 passi temporali gli impedirà di apprendere tutto il possibile nel set di dati. Tuttavia, dipende fortemente dal tuo set di dati se ciò gli impedirà di apprendere cose importanti sulle funzionalità del tuo modello!

Limitare il gradiente durante l'allenamento è più come limitare la finestra su cui il modello può assimilare le funzionalità di input e lo stato nascosto con elevata sicurezza. Poiché al momento del test si applica il modello all'intera sequenza di input, sarà comunque in grado di incorporare informazioni su tutte le funzionalità di input nel suo stato nascosto. Potrebbe non sapere esattamente come conservare tali informazioni fino a quando non farà la sua previsione finale per la frase, ma potrebbero esserci alcune connessioni (sicuramente più deboli) che sarebbe ancora in grado di stabilire.

Pensa prima a un esempio inventato. Supponiamo che la tua rete generi un 1 se c'è un 1 ovunque nel suo input e uno 0 altrimenti. Supponi di allenare la rete su sequenze di lunghezza 20 e di limitare poi il gradiente a 10 passi. Se il set di dati di addestramento non contiene mai un 1 negli ultimi 10 passaggi di un input, la rete avrà un problema con gli input di test di qualsiasi configurazione. Tuttavia, se il set di formazione ha alcuni esempi come [1 0 0 ... 0 0 0] e altri come [0 0 0 ... 1 0 0], la rete sarà in grado di rilevare la "presenza di una funzione da 1 "ovunque nel suo input.

Torna quindi all'analisi del sentiment. Diciamo che durante l'allenamento il tuo modello incontra una lunga frase negativa come "Odio questo perché ... intorno e intorno" con, diciamo, 50 parole tra i puntini di sospensione. Limitando la propagazione del gradiente a 30 fasi temporali, il modello non collegherà "Lo odio perché" all'etichetta di output, quindi non prenderà in considerazione "I", "odio" o "questo" da questo addestramento esempio. Ma riprenderà le parole che si trovano entro 30 intervalli di tempo dalla fine della frase. Se il tuo set di allenamento contiene altri esempi che contengono quelle stesse parole, possibilmente insieme a "odio", allora ha la possibilità di raccogliere il legame tra "odio" e l'etichetta del sentimento negativo. Inoltre, se hai esempi di allenamento più brevi, dì "Odiamo questo perché è terribile!" quindi il tuo modello sarà in grado di collegare le caratteristiche "odio" e "questo" all'etichetta di destinazione. Se hai abbastanza di questi esempi di addestramento, il modello dovrebbe essere in grado di apprendere in modo efficace la connessione.

Al momento del test, supponiamo che tu presenti al modello un'altra lunga frase come "Lo odio perché ... sul geco!" L'input del modello inizierà con "Lo odio", che verrà passato in qualche modo allo stato nascosto del modello. Questo stato nascosto viene utilizzato per influenzare i futuri stati nascosti del modello, quindi anche se potrebbero esserci 50 parole prima della fine della frase, lo stato nascosto di quelle parole iniziali ha una probabilità teorica di influenzare l'output, anche se non è mai stato addestrato su campioni che contenevano una distanza così grande tra "Odio questo" e la fine della frase.


0

@ Imjohns3 ha ragione, se si elaborano sequenze lunghe (dimensione N) e si limita la backpropagation agli ultimi passi K, la rete non apprenderà i modelli all'inizio.

Ho lavorato con testi lunghi e utilizzo l'approccio in cui computo la perdita ed eseguo la backpropagation dopo ogni passaggio di K. Supponiamo che la mia sequenza avesse N = 1000 token, il mio processo RNN prima K = 100 quindi provo a fare previsioni (perdita di calcolo) e backpropagate. Successivamente, mantenendo lo stato RNN, frenare la catena del gradiente (in pytorch-> staccare) e iniziare un altro k = 100 passi.

Un buon esempio di questa tecnica puoi trovare qui: https://github.com/ksopyla/pytorch_neural_networks/blob/master/RNN/lstm_imdb_tbptt.py

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.