KL Perdita con un'unità gaussiana


10

Ho implementato un VAE e ho notato online due diverse implementazioni della divergenza KL gaussiana univaria semplificata. La divergenza originale come qui è Se assumiamo che il nostro precedente sia un'unità gaussiana, cioè e , questo si semplifica fino a Ed ecco dove riposa la mia confusione. Anche se ho trovato alcuni oscuri repository github con l'implementazione di cui sopra, quello che trovo più comunemente usato è:

KLloss=log(σ2σ1)+σ12+(μ1μ2)22σ2212
μ2=0σ2=1
KLloss=log(σ1)+σ12+μ12212
KLloss=12(2log(σ1)σ12μ12+1)

=12(log(σ1)σ1μ12+1)
Ad esempio nell'esercitazione sul codificatore automatico di Keras ufficiale . La mia domanda è quindi: cosa mi sto perdendo tra questi due? La differenza principale è l'eliminazione del fattore 2 sul termine log e non la quadratura della varianza. Analiticamente ho usato quest'ultimo con successo, per quello che vale. Grazie in anticipo per qualsiasi aiuto!

Risposte:


7

Si noti che sostituendo con nell'ultima equazione si recupera il precedente (cioè ). Portandomi a pensare che nel primo caso l'encoder è usato per prevedere la varianza, mentre nel secondo è usato per prevedere la deviazione standard.σ1σ12log(σ1)σ12log(σ1)σ12

Entrambe le formulazioni sono equivalenti e l'obiettivo è invariato.


Non credo che possano essere equivalenti. Sì, entrambi sono minimizzati quando per zero e unit . Tuttavia, nell'equazione originale (con la varianza), la penalità per allontanare dall'unità è molto più grande che nella seconda equazione (basata sulla deviazione standard). La penalità per le variazioni in è la stessa per entrambi, e l'errore di ricostruzione sarebbe lo stesso, quindi l'uso della seconda versione cambia drasticamente l'importanza relativa delle partenze di dall'unità. Cosa mi sto perdendo? μσσμσ
TheBamf

0

Credo che la risposta sia più semplice. Nel VAE, le persone di solito usano una distribuzione normale multivariata, che ha una matrice di covarianza invece della varianza . Sembra confuso in un pezzo di codice ma ha la forma desiderata.Σσ2

Qui puoi trovare la derivazione di una divergenza KL per le distribuzioni normali multivariate: derivare la perdita di divergenza KL per i VAE

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.