È pratica comune ridurre al minimo la perdita media sui lotti anziché sulla somma?


15

Tensorflow ha un tutorial di esempio sulla classificazione di CIFAR-10 . Nell'esercitazione viene minimizzata la perdita media di entropia trasversale nel lotto.

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.
  Add summary for for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]
  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits, labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

Vedi cifar10.py , linea 267.

Perché invece non minimizza la somma attraverso il batch? Fa la differenza? Non capisco come ciò influirebbe sul calcolo del backprop.


Non si tratta esattamente di somma / avg, ma la scelta della perdita è una scelta di progettazione dell'applicazione. Ad esempio, se sei bravo ad avere ragione in media, ottimizza la media. Se l'applicazione è sensibile allo scenario peggiore (ad esempio un incidente automobilistico), è necessario ottimizzare il valore massimo.
Alex Kreimer,

Risposte:


15

Come menzionato da pkubik, di solito esiste un termine di regolarizzazione per i parametri che non dipende dall'input, ad esempio in tensorflow è come

# Loss function using L2 Regularization
regularizer = tf.nn.l2_loss(weights)
loss = tf.reduce_mean(loss + beta * regularizer)

In questo caso la media sul mini-batch aiuta a mantenere un rapporto fisso tra la cross_entropyperdita e la regularizerperdita mentre le dimensioni del batch vengono modificate.

Inoltre, il tasso di apprendimento è anche sensibile all'entità della perdita (gradiente), quindi per normalizzare il risultato di diverse dimensioni del lotto, prendere la media sembra un'opzione migliore.


Aggiornare

Questo articolo di Facebook (Accuratezza, Minibatch SGD di grandi dimensioni: formazione di ImageNet in 1 ora) mostra che, effettivamente ridimensionare il tasso di apprendimento in base alle dimensioni del lotto funziona abbastanza bene:

Regola di ridimensionamento lineare: quando la dimensione del minibatch viene moltiplicata per k, moltiplicare il tasso di apprendimento per k.

che è essenzialmente lo stesso di moltiplicare il gradiente per k e mantenere invariato il tasso di apprendimento, quindi suppongo che non sia necessario prendere la media.


8

Mi concentrerò sulla parte:

Non capisco come ciò influirebbe sul calcolo del backprop.

Prima di tutto probabilmente avrai già notato che l'unica differenza tra i valori di perdita risultanti è che la perdita media è ridotta rispetto alla somma dal fattore di , cioè che , dove è la dimensione del batch. Possiamo facilmente dimostrare che la stessa relazione è vera per un derivato di qualsiasi variabile wrt. le funzioni di perdita ( ) osservando la definizione di derivata: Ora, vorremmo moltiplicare il valore della funzione e vedere come influenza la derivata: 1BLSUM=BLAVGBdLSUMdx=BdLAVGdx

dLdx=limΔ0L(x+Δ)L(x)Δ
d(cL)dx=limΔ0cL(x+Δ)cL(x)Δ
Quando fattorizziamo la costante e la spostiamo prima del limite dovremmo vedere che troviamo la definizione della derivata originale moltiplicata per una costante, che è esattamente quello che volevamo dimostrare:
d(cL)dx=climΔ0L(x+Δ)L(x)Δ=cdLdx

In SGD aggiorneremo i pesi usando il loro gradiente moltiplicato per il tasso di apprendimento e possiamo vedere chiaramente che possiamo scegliere questo parametro in modo che gli aggiornamenti finali dei pesi siano uguali. La prima regola di aggiornamento: e la seconda regola di aggiornamento (immagina che ): λ

W:=W+λ1dLSUMdW
λ1=λ2B
W:=W+λ1dLAVGdW=W+λ2BdLSUMdW


L'eccellente scoperta di dontloo potrebbe suggerire che l'uso della somma potrebbe essere un approccio un po 'più appropriato. Per giustificare la media che sembra essere più popolare aggiungerei che l'uso della somma potrebbe probabilmente causare alcuni problemi con la regolarizzazione del peso. L'ottimizzazione del fattore di ridimensionamento per i regolarizzatori per lotti di dimensioni diverse può essere fastidiosa quanto l'ottimizzazione del tasso di apprendimento.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.