Attenzione è un metodo per aggregare un insieme di vettori vi in un solo vettore, spesso tramite una ricerca vettore u . Solitamente, vi sono forniti dal input del modello o gli stati nascosti delle precedenti passi temporali, o gli stati nascosti uno verso il basso livello (nel caso di LSTMs impilati).
Il risultato è spesso chiamato vettore di contesto c , poiché contiene il contesto rilevante per la fase temporale corrente.
Questo vettore di contesto aggiuntivo c viene quindi inserito anche in RNN / LSTM (può essere semplicemente concatenato con l'input originale). Pertanto, il contesto può essere utilizzato per aiutare con la previsione.
Il modo più semplice per farlo è calcolare il vettore di probabilità p=softmax(VTu) e c=∑ipivi dove V è la concatenazione di tutti i precedenti vi . Un vettore di ricerca comune u è lo stato nascosto corrente ht .
Ci sono molte varianti su questo e puoi rendere le cose complicate come vuoi. Ad esempio, anziché utilizzare vTiu i logit, si può scegliere f(vi,u) invece, dove f è una rete neurale arbitrario.
Un meccanismo di attenzione comune per i modelli da sequenza a sequenza utilizza p=softmax(qTtanh(W1vi+W2ht)) , dove v sono gli stati nascosti dell'encoder e ht è l'attuale stato nascosto del decodificatore. q ed entrambi W sono parametri.
Alcuni articoli che mostrano diverse variazioni sull'idea dell'attenzione:
Le reti di puntatori prestano attenzione ai riferimenti di input per risolvere problemi di ottimizzazione combinatoria.
Le reti di entità ricorrenti mantengono stati di memoria separati per entità diverse (persone / oggetti) durante la lettura del testo e aggiornano lo stato di memoria corretto facendo attenzione.
Anche i modelli di trasformatori fanno ampio uso dell'attenzione. La loro formulazione dell'attenzione è leggermente più generale e coinvolge anche i vettori chiave ki : i pesi dell'attenzione p vengono effettivamente calcolati tra i tasti e la ricerca, e il contesto viene quindi costruito con vi .
Ecco una rapida implementazione di una forma di attenzione, anche se non posso garantire la correttezza oltre al fatto che ha superato alcuni semplici test.
RNN di base:
def rnn(inputs_split):
bias = tf.get_variable('bias', shape = [hidden_dim, 1])
weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])
hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
for i, input in enumerate(inputs_split):
input = tf.reshape(input, (batch, in_dim, 1))
last_state = hidden_states[-1]
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
hidden_states.append(hidden)
return hidden_states[-1]
Con attenzione, aggiungiamo solo poche righe prima che venga calcolato il nuovo stato nascosto:
if len(hidden_states) > 1:
logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
probs = tf.nn.softmax(logits)
probs = tf.reshape(probs, (batch, -1, 1, 1))
context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
else:
context = tf.zeros_like(last_state)
last_state = tf.concat([last_state, context], axis = 1)
hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
il codice completo