Qual è una spiegazione intuitiva della tecnica di massimizzazione delle aspettative? [chiuso]


109

Expectation Maximization (EM) è una sorta di metodo probabilistico per classificare i dati. Per favore correggimi se sbaglio se non è un classificatore.

Qual è una spiegazione intuitiva di questa tecnica EM? Cosa c'è expectationqui e cosa è essere maximized?


12
Qual è l'algoritmo di massimizzazione delle aspettative? , Nature Biotechnology 26 , 897–899 (2008) ha una bella immagine che illustra come funziona l'algoritmo.
chl

@chl Nella parte b della bella immagine , come hanno ottenuto i valori della distribuzione di probabilità su Z (cioè 0,45xA, 0,55xB, ecc.)?
Noob Saibot

3
Puoi guardare questa domanda math.stackexchange.com/questions/25111/…
v4r

3
Link aggiornato all'immagine menzionata da @chl.
n1k31t4

Risposte:


120

Nota: il codice dietro questa risposta può essere trovato qui .


Supponiamo di avere alcuni dati campionati da due diversi gruppi, rosso e blu:

inserisci qui la descrizione dell'immagine

Qui possiamo vedere quale punto dati appartiene al gruppo rosso o blu. Questo rende facile trovare i parametri che caratterizzano ogni gruppo. Ad esempio, la media del gruppo rosso è di circa 3, la media del gruppo blu è di circa 7 (e potremmo trovare la media esatta se volessimo).

Questa è, in generale, nota come stima di massima verosimiglianza . Dati alcuni dati, calcoliamo il valore di un parametro (o parametri) che meglio spiega quei dati.

Ora immagina di non poter vedere quale valore è stato campionato da quale gruppo. Tutto sembra viola per noi:

inserisci qui la descrizione dell'immagine

Qui abbiamo la consapevolezza che ci sono due gruppi di valori, ma non sappiamo a quale gruppo appartiene un valore particolare.

Possiamo ancora stimare le medie per il gruppo rosso e il gruppo blu che meglio si adattano a questi dati?

Sì, spesso possiamo! La massimizzazione delle aspettative ci offre un modo per farlo. L'idea molto generale dietro l'algoritmo è questa:

  1. Inizia con una stima iniziale di ciò che potrebbe essere ogni parametro.
  2. Calcola la probabilità che ogni parametro produca il punto dati.
  3. Calcola i pesi per ogni punto dati che indica se è più rosso o più blu in base alla probabilità che sia prodotto da un parametro. Combina i pesi con i dati ( aspettativa ).
  4. Calcolare una stima migliore per i parametri utilizzando i dati aggiustati in base al peso ( massimizzazione ).
  5. Ripetere i passaggi da 2 a 4 finché la stima del parametro non converge (il processo interrompe la produzione di una stima diversa).

Questi passaggi richiedono ulteriori spiegazioni, quindi esaminerò il problema sopra descritto.

Esempio: stima della media e della deviazione standard

Userò Python in questo esempio, ma il codice dovrebbe essere abbastanza facile da capire se non hai familiarità con questo linguaggio.

Supponiamo di avere due gruppi, rosso e blu, con i valori distribuiti come nell'immagine sopra. Nello specifico, ogni gruppo contiene un valore tratto da una distribuzione normale con i seguenti parametri:

import numpy as np
from scipy import stats

np.random.seed(110) # for reproducible results

# set parameters
red_mean = 3
red_std = 0.8

blue_mean = 7
blue_std = 2

# draw 20 samples from normal distributions with red/blue parameters
red = np.random.normal(red_mean, red_std, size=20)
blue = np.random.normal(blue_mean, blue_std, size=20)

both_colours = np.sort(np.concatenate((red, blue))) # for later use...

Ecco di nuovo un'immagine di questi gruppi rossi e blu (per evitare di dover scorrere verso l'alto):

inserisci qui la descrizione dell'immagine

Quando possiamo vedere il colore di ogni punto (cioè a quale gruppo appartiene), è molto facile stimare la media e la deviazione standard per ogni gruppo. Passiamo semplicemente i valori rosso e blu alle funzioni incorporate in NumPy. Per esempio:

>>> np.mean(red)
2.802
>>> np.std(red)
0.871
>>> np.mean(blue)
6.932
>>> np.std(blue)
2.195

Ma cosa succede se non possiamo vedere i colori dei punti? Cioè, invece di rosso o blu, ogni punto è stato colorato di viola.

Per provare a recuperare i parametri della media e della deviazione standard per i gruppi rosso e blu, possiamo utilizzare la massimizzazione delle aspettative.

Il nostro primo passaggio ( passaggio 1 sopra) consiste nell'indovinare i valori dei parametri per la media e la deviazione standard di ciascun gruppo. Non dobbiamo indovinare in modo intelligente; possiamo scegliere qualsiasi numero che ci piace:

# estimates for the mean
red_mean_guess = 1.1
blue_mean_guess = 9

# estimates for the standard deviation
red_std_guess = 2
blue_std_guess = 1.7

Queste stime dei parametri producono curve a campana simili a queste:

inserisci qui la descrizione dell'immagine

Queste sono stime sbagliate. Entrambi i mezzi (le linee tratteggiate verticali) sembrano lontani da qualsiasi tipo di "mezzo" per gruppi di punti sensibili, per esempio. Vogliamo migliorare queste stime.

Il passaggio successivo ( passaggio 2 ) consiste nel calcolare la probabilità che ogni punto dati appaia sotto le ipotesi dei parametri correnti:

likelihood_of_red = stats.norm(red_mean_guess, red_std_guess).pdf(both_colours)
likelihood_of_blue = stats.norm(blue_mean_guess, blue_std_guess).pdf(both_colours)

Qui, abbiamo semplicemente inserito ogni punto di dati nella funzione di densità di probabilità per una distribuzione normale utilizzando le nostre ipotesi attuali alla media e alla deviazione standard per il rosso e il blu. Questo ci dice, ad esempio, che con le nostre attuali ipotesi il punto dati a 1,761 è molto più probabile che sia rosso (0,189) che blu (0,00003).

Per ogni punto dati, possiamo trasformare questi due valori di probabilità in pesi ( passaggio 3 ) in modo che si sommino a 1 come segue:

likelihood_total = likelihood_of_red + likelihood_of_blue

red_weight = likelihood_of_red / likelihood_total
blue_weight = likelihood_of_blue / likelihood_total

Con le nostre stime attuali e i pesi appena calcolati, possiamo ora calcolare nuove stime per la media e la deviazione standard dei gruppi rosso e blu ( passaggio 4 ).

Calcoliamo due volte la media e la deviazione standard utilizzando tutti i punti dati, ma con le diverse ponderazioni: una volta per i pesi rossi e una volta per i pesi blu.

Il punto chiave dell'intuizione è che maggiore è il peso di un colore su un punto dati, più il punto dati influenza le stime successive per i parametri di quel colore. Questo ha l'effetto di "tirare" i parametri nella giusta direzione.

def estimate_mean(data, weight):
    """
    For each data point, multiply the point by the probability it
    was drawn from the colour's distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among our data points.
    """
    return np.sum(data * weight) / np.sum(weight)

def estimate_std(data, weight, mean):
    """
    For each data point, multiply the point's squared difference
    from a mean value by the probability it was drawn from
    that distribution (its "weight").

    Divide by the total weight: essentially, we're finding where 
    the weight is centred among the values for the difference of
    each data point from the mean.

    This is the estimate of the variance, take the positive square
    root to find the standard deviation.
    """
    variance = np.sum(weight * (data - mean)**2) / np.sum(weight)
    return np.sqrt(variance)

# new estimates for standard deviation
blue_std_guess = estimate_std(both_colours, blue_weight, blue_mean_guess)
red_std_guess = estimate_std(both_colours, red_weight, red_mean_guess)

# new estimates for mean
red_mean_guess = estimate_mean(both_colours, red_weight)
blue_mean_guess = estimate_mean(both_colours, blue_weight)

Abbiamo nuove stime per i parametri. Per migliorarli di nuovo, possiamo tornare al passaggio 2 e ripetere il processo. Lo facciamo finché le stime non convergono o dopo che è stato eseguito un certo numero di iterazioni ( passaggio 5 ).

Per i nostri dati, le prime cinque iterazioni di questo processo hanno questo aspetto (le iterazioni recenti hanno un aspetto più forte):

inserisci qui la descrizione dell'immagine

Vediamo che le medie stanno già convergendo su alcuni valori, e anche le forme delle curve (governate dalla deviazione standard) stanno diventando più stabili.

Se continuiamo per 20 iterazioni, finiamo con quanto segue:

inserisci qui la descrizione dell'immagine

Il processo EM è convergente ai seguenti valori, che risultano molto vicini ai valori effettivi (dove possiamo vedere i colori - nessuna variabile nascosta):

          | EM guess | Actual |  Delta
----------+----------+--------+-------
Red mean  |    2.910 |  2.802 |  0.108
Red std   |    0.854 |  0.871 | -0.017
Blue mean |    6.838 |  6.932 | -0.094
Blue std  |    2.227 |  2.195 |  0.032

Nel codice sopra potresti aver notato che la nuova stima per la deviazione standard è stata calcolata utilizzando la stima dell'iterazione precedente per la media. In definitiva, non importa se calcoliamo prima un nuovo valore per la media poiché stiamo solo trovando la varianza (ponderata) dei valori attorno a un punto centrale. Vedremo ancora convergere le stime per i parametri.


e se non conosciamo nemmeno il numero di distribuzioni normali da cui proviene? Qui hai preso un esempio di distribuzioni k = 2, possiamo anche stimare k e gli insiemi di parametri k?
stackit

1
@stackit: non sono sicuro che esista un modo generale semplice per calcolare il valore più probabile di k come parte del processo EM in questo caso. Il problema principale è che avremmo bisogno di avviare EM con stime per ciascuno dei parametri che vogliamo trovare, e questo implica che dobbiamo conoscere / stimare k prima di iniziare. È possibile, tuttavia, stimare qui la proporzione di punti appartenenti a un gruppo tramite EM. Forse se sovrastimassimo k, la proporzione di tutti i gruppi tranne due scenderebbe quasi a zero. Non l'ho sperimentato, quindi non so quanto bene funzionerebbe nella pratica.
Alex Riley il

1
@AlexRiley Puoi dire qualcosa di più sulle formule per calcolare le nuove stime di media e deviazione standard?
Lemon

2
@AlexRiley Grazie per la spiegazione. Perché le nuove stime di deviazione standard vengono calcolate utilizzando la vecchia ipotesi della media? Cosa succede se vengono trovate per prime le nuove stime della media?
GoodDeeds

1
@Lemon GoodDeeds Kaushal - mi scuso per la mia risposta in ritardo alle tue domande. Ho provato a modificare la risposta per affrontare i punti che hai sollevato. Ho anche reso accessibile tutto il codice utilizzato in questa risposta in un taccuino qui (che include anche spiegazioni più dettagliate di alcuni punti che ho toccato).
Alex Riley

36

EM è un algoritmo per massimizzare una funzione di verosimiglianza quando alcune delle variabili nel modello non sono osservate (cioè quando si hanno variabili latenti).

Ci si potrebbe chiedere, se stiamo solo cercando di massimizzare una funzione, perché non usiamo semplicemente il meccanismo esistente per massimizzare una funzione. Ebbene, se provi a massimizzarlo prendendo derivati ​​e impostandoli a zero, scopri che in molti casi le condizioni del primo ordine non hanno una soluzione. C'è un problema di gallina e uova in quanto per risolvere i parametri del tuo modello devi conoscere la distribuzione dei tuoi dati non osservati; ma la distribuzione dei dati non osservati è una funzione dei parametri del modello.

EM cerca di aggirare questo problema indovinando iterativamente una distribuzione per i dati non osservati, quindi stimando i parametri del modello massimizzando qualcosa che è un limite inferiore sulla funzione di verosimiglianza effettiva e ripetendo fino alla convergenza:

L'algoritmo EM

Inizia con un'ipotesi per i valori dei parametri del modello

Passo E: per ogni punto dati che ha valori mancanti, usa l'equazione del tuo modello per risolvere la distribuzione dei dati mancanti data la tua ipotesi attuale dei parametri del modello e dati i dati osservati (nota che stai risolvendo una distribuzione per ogni valore, non per il valore atteso). Ora che abbiamo una distribuzione per ogni valore mancante, possiamo calcolare l' aspettativa della funzione di verosimiglianza rispetto alle variabili non osservate. Se la nostra ipotesi per il parametro del modello fosse corretta, questa probabilità attesa sarà l'effettiva probabilità dei nostri dati osservati; se i parametri non fossero corretti, sarà solo un limite inferiore.

Passo M: ora che abbiamo una funzione di probabilità attesa senza variabili non osservate, massimizza la funzione come faresti nel caso completamente osservato, per ottenere una nuova stima dei parametri del tuo modello.

Ripeti fino alla convergenza.


5
Non capisco il tuo passo elettronico. Parte del problema è che mentre sto imparando queste cose, non riesco a trovare persone che usano la stessa terminologia. Allora cosa intendi per equazione modello? Non so cosa intendi per risolvere per una distribuzione di probabilità?
user678392

27

Ecco una ricetta semplice per comprendere l'algoritmo di massimizzazione delle aspettative:

1- Leggi questo tutorial paper EM di Do e Batzoglou.

2- Potresti avere dei punti interrogativi nella tua testa, dai un'occhiata alle spiegazioni in questa pagina di scambio di stack matematici .

3- Guarda questo codice che ho scritto in Python che spiega l'esempio nel documento tutorial EM dell'elemento 1:

Attenzione: il codice potrebbe essere disordinato / non ottimale, poiché non sono uno sviluppatore Python. Ma fa il lavoro.

import numpy as np
import math

#### E-M Coin Toss Example as given in the EM tutorial paper by Do and Batzoglou* #### 

def get_mn_log_likelihood(obs,probs):
    """ Return the (log)likelihood of obs, given the probs"""
    # Multinomial Distribution Log PMF
    # ln (pdf)      =             multinomial coeff            *   product of probabilities
    # ln[f(x|n, p)] = [ln(n!) - (ln(x1!)+ln(x2!)+...+ln(xk!))] + [x1*ln(p1)+x2*ln(p2)+...+xk*ln(pk)]     

    multinomial_coeff_denom= 0
    prod_probs = 0
    for x in range(0,len(obs)): # loop through state counts in each observation
        multinomial_coeff_denom = multinomial_coeff_denom + math.log(math.factorial(obs[x]))
        prod_probs = prod_probs + obs[x]*math.log(probs[x])

    multinomial_coeff = math.log(math.factorial(sum(obs))) -  multinomial_coeff_denom
    likelihood = multinomial_coeff + prod_probs
    return likelihood

# 1st:  Coin B, {HTTTHHTHTH}, 5H,5T
# 2nd:  Coin A, {HHHHTHHHHH}, 9H,1T
# 3rd:  Coin A, {HTHHHHHTHH}, 8H,2T
# 4th:  Coin B, {HTHTTTHHTT}, 4H,6T
# 5th:  Coin A, {THHHTHHHTH}, 7H,3T
# so, from MLE: pA(heads) = 0.80 and pB(heads)=0.45

# represent the experiments
head_counts = np.array([5,9,8,4,7])
tail_counts = 10-head_counts
experiments = zip(head_counts,tail_counts)

# initialise the pA(heads) and pB(heads)
pA_heads = np.zeros(100); pA_heads[0] = 0.60
pB_heads = np.zeros(100); pB_heads[0] = 0.50

# E-M begins!
delta = 0.001  
j = 0 # iteration counter
improvement = float('inf')
while (improvement>delta):
    expectation_A = np.zeros((5,2), dtype=float) 
    expectation_B = np.zeros((5,2), dtype=float)
    for i in range(0,len(experiments)):
        e = experiments[i] # i'th experiment
        ll_A = get_mn_log_likelihood(e,np.array([pA_heads[j],1-pA_heads[j]])) # loglikelihood of e given coin A
        ll_B = get_mn_log_likelihood(e,np.array([pB_heads[j],1-pB_heads[j]])) # loglikelihood of e given coin B

        weightA = math.exp(ll_A) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of A proportional to likelihood of A 
        weightB = math.exp(ll_B) / ( math.exp(ll_A) + math.exp(ll_B) ) # corresponding weight of B proportional to likelihood of B                            

        expectation_A[i] = np.dot(weightA, e) 
        expectation_B[i] = np.dot(weightB, e)

    pA_heads[j+1] = sum(expectation_A)[0] / sum(sum(expectation_A)); 
    pB_heads[j+1] = sum(expectation_B)[0] / sum(sum(expectation_B)); 

    improvement = max( abs(np.array([pA_heads[j+1],pB_heads[j+1]]) - np.array([pA_heads[j],pB_heads[j]]) ))
    j = j+1

Trovo che il tuo programma risulterà sia A che B a 0.66, lo implemento anche usando scala, inoltre trovo che il risultato è 0.66, puoi aiutarci a controllarlo?
zjffdu

Utilizzando un foglio di calcolo, trovo i tuoi risultati 0.66 solo se le mie ipotesi iniziali sono uguali. Altrimenti, posso riprodurre l'output del tutorial.
Soakley

@zjffdu, quante iterazioni esegue l'EM prima di restituirti 0.66? Se si inizializza con valori uguali, potrebbe rimanere bloccato a un massimo locale e vedrai che il numero di iterazioni è estremamente basso (poiché non ci sono miglioramenti).
Zhubarb

Puoi anche controllare questa diapositiva di Andrew Ng e la nota del corso di Harvard
Minh Phan

16

Tecnicamente il termine "EM" è un po 'sotto specificato, ma presumo che ti riferisci alla tecnica di analisi dei cluster Gaussian Mixture Modeling, che è un'istanza del principio EM generale.

In realtà, l' analisi dei cluster EM non è un classificatore . So che alcune persone considerano il clustering una "classificazione senza supervisione", ma in realtà l'analisi dei cluster è qualcosa di completamente diverso.

La differenza fondamentale e il grande malinteso classificazione che le persone hanno sempre con l'analisi dei cluster è che: nell'analisi dei cluster non esiste una "soluzione corretta" . È un metodo di scoperta della conoscenza , in realtà ha lo scopo di trovare qualcosa di nuovo ! Questo rende la valutazione molto complicata. Viene spesso valutata utilizzando una classificazione nota come riferimento, ma non sempre è appropriata: la classificazione che hai può o meno riflettere ciò che è nei dati.

Faccio un esempio: disponi di un ampio set di dati di clienti, inclusi i dati sul sesso. Un metodo che divide questo set di dati in "maschio" e "femmina" è ottimale quando lo si confronta con le classi esistenti. In un modo di pensare "predittivo" questo è positivo, poiché per i nuovi utenti ora puoi prevedere il loro sesso. In un modo di pensare "knowledge discovery", questo è effettivamente un male, perché si voleva scoprire qualche nuova struttura nei dati. Un metodo che, ad esempio, suddividerebbe i dati in anziani e bambini, tuttavia, otterrebbe un punteggio peggiore rispetto alla classe maschile / femminile. Tuttavia, sarebbe un eccellente risultato di raggruppamento (se l'età non fosse specificata).

Ora torna a EM. Essenzialmente si presuppone che i dati siano composti da più distribuzioni normali multivariate (si noti che questo è un file presupposto molto forte, in particolare quando si fissa il numero di cluster!). Quindi cerca di trovare un modello ottimale locale per questo, migliorando alternativamente il modello e l'assegnazione degli oggetti al modello .

Per ottenere i migliori risultati in un contesto di classificazione, scegli il numero di cluster più grande del numero di classi o applica il raggruppamento solo a singole classi (per scoprire se esiste una struttura all'interno della classe!).

Supponiamo che tu voglia addestrare un classificatore a distinguere "automobili", "biciclette" e "camion". È poco utile presumere che i dati consistano esattamente di 3 distribuzioni normali. Tuttavia, puoi presumere che ci sia più di un tipo di auto (e camion e biciclette). Quindi, invece di addestrare un classificatore per queste tre classi, raggruppate auto, camion e biciclette in 10 gruppi ciascuno (o forse 10 auto, 3 camion e 3 biciclette, qualunque sia), quindi addestrate un classificatore per distinguere queste 30 classi, e poi unire il risultato della classe di nuovo alle classi originali. Potresti anche scoprire che esiste un cluster particolarmente difficile da classificare, ad esempio Trike. Sono un po 'macchine e un po' biciclette. O camion per le consegne, che sono più simili a macchine di grandi dimensioni che a camion.


come è sottospecificato EM?
sam boosalis

C'è più di una versione di esso. Tecnicamente, puoi anche chiamare lo stile Lloyd k-significa "EM". Devi specificare quale modello utilizzi.
HA QUIT - Anony-Mousse

2

Se le altre risposte sono buone, cercherò di fornire un'altra prospettiva e di affrontare la parte intuitiva della domanda.

L'algoritmo EM (Expectation-Maximization) è una variante di una classe di algoritmi iterativi che utilizzano la dualità

Estratto (enfasi mia):

In matematica, una dualità, in generale, traduce concetti, teoremi o strutture matematiche in altri concetti, teoremi o strutture, in modo uno-a-uno, spesso (ma non sempre) per mezzo di un'operazione di involuzione: se il duale di A è B, quindi il duale di B è A. Tali involuzioni a volte hanno punti fissi , quindi il duale di A è A stesso

Di solito una doppia B di un oggetto A è correlata ad A in un modo che preserva una certa simmetria o compatibilità . Ad esempio AB = const

Esempi di algoritmi iterativi, che impiegano la dualità (nel senso precedente) sono:

  1. Algoritmo euclideo per il massimo divisore comune e sue varianti
  2. Algoritmo Gram – Schmidt Vector Basis e varianti
  3. Media aritmetica - Disuguaglianza media geometrica e sue varianti
  4. Algoritmo di massimizzazione delle aspettative e sue varianti (vedere anche qui per una visualizzazione geometrica delle informazioni )
  5. (.. altri algoritmi simili ..)

In modo simile, l'algoritmo EM può anche essere visto come due passaggi di doppia massimizzazione :

.. [EM] è visto come massimizzare una funzione congiunta dei parametri e della distribuzione sulle variabili non osservate .. L'E-step massimizza questa funzione rispetto alla distribuzione sulle variabili non osservate; il passo M rispetto ai parametri ..

In un algoritmo iterativo che utilizza la dualità c'è l'assunzione esplicita (o implicita) di un punto di convergenza di equilibrio (o fisso) (per EM questo è dimostrato usando la disuguaglianza di Jensen)

Quindi lo schema di tali algoritmi è:

  1. Passo tipo E: trova la migliore soluzione x rispetto a dato y che viene mantenuto costante.
  2. Passo tipo M (doppio): trova la migliore soluzione y rispetto a x (come calcolato nel passaggio precedente) mantenuta costante.
  3. Criterio del passaggio di terminazione / convergenza: ripetere i passaggi 1, 2 con i valori aggiornati di x , y fino a raggiungere la convergenza (o il numero di iterazioni specificato)

Si noti che quando un tale algoritmo converge a un ottimo (globale), ha trovato una configurazione che è migliore in entrambi i sensi (cioè sia nel dominio x / parametri che nel dominio / parametri y ). Tuttavia, l'algoritmo può trovare solo un ottimo locale e non quello globale .

direi che questa è la descrizione intuitiva dello schema dell'algoritmo

Per gli argomenti e le applicazioni statistiche, altre risposte hanno dato buone spiegazioni (controlla anche i riferimenti in questa risposta)


2

La risposta accettata fa riferimento al Chuong EM Paper , che fa un lavoro decente nello spiegare EM. C'è anche un video di YouTube che spiega il documento in modo più dettagliato.

Per ricapitolare, ecco lo scenario:

1st:  {H,T,T,T,H,H,T,H,T,H} 5 Heads, 5 Tails; Did coin A or B generate me?
2nd:  {H,H,H,H,T,H,H,H,H,H} 9 Heads, 1 Tails
3rd:  {H,T,H,H,H,H,H,T,H,H} 8 Heads, 2 Tails
4th:  {H,T,H,T,T,T,H,H,T,T} 4 Heads, 6 Tails
5th:  {T,H,H,H,T,H,H,H,T,H} 7 Heads, 3 Tails

Two possible coins, A & B are used to generate these distributions.
A & B have an unknown parameter: their bias towards heads.

We don't know the biases, but we can simply start with a guess: A=60% heads, B=50% heads.

Nel caso della domanda della prima prova, intuitivamente penseremmo che B l'abbia generata poiché la proporzione di teste corrisponde molto bene al pregiudizio di B ... ma quel valore era solo un'ipotesi, quindi non possiamo esserne sicuri.

Con questo in mente, mi piace pensare alla soluzione EM in questo modo:

  • Ogni prova di lanci arriva a 'votare' su quale moneta gli piace di più
    • Questo si basa su come ogni moneta si adatta alla sua distribuzione
    • OPPURE, dal punto di vista della moneta, c'è un'alta aspettativa di vedere questo processo rispetto all'altra moneta (in base alla probabilità del registro ).
  • A seconda di quanto ogni prova gradisce ogni moneta, può aggiornare l'ipotesi del parametro di quella moneta (bias).
    • Più a una prova piace una moneta, più aggiorna il bias della moneta per riflettere il proprio!
    • Essenzialmente i pregiudizi della moneta vengono aggiornati combinando questi aggiornamenti ponderati in tutte le prove, un processo chiamato ( massimizzazione ), che si riferisce al tentativo di ottenere le migliori ipotesi per il bias di ciascuna moneta data una serie di prove.

Questa potrebbe essere una semplificazione eccessiva (o anche fondamentalmente sbagliata su alcuni livelli), ma spero che questo aiuti a livello intuitivo!


1

EM viene utilizzato per massimizzare la probabilità di un modello Q con variabili latenti Z.

È un'ottimizzazione iterativa.

theta <- initial guess for hidden parameters
while not converged:
    #e-step
    Q(theta'|theta) = E[log L(theta|Z)]
    #m-step
    theta <- argmax_theta' Q(theta'|theta)

e-step: data la stima corrente di Z, calcolare la funzione di verosimiglianza attesa

m-step: trova theta che massimizza questo Q

Esempio GMM:

e-step: stima le assegnazioni delle etichette per ogni datapoint data l'attuale stima del parametro gmm

m-step: massimizza un nuovo theta date le nuove assegnazioni dell'etichetta

K-means è anche un algoritmo EM e ci sono molte animazioni esplicative su K-means.


1

Utilizzando lo stesso articolo di Do e Batzoglou citato nella risposta di Zhubarb, ho implementato EM per quel problema in Java . I commenti alla sua risposta mostrano che l'algoritmo si blocca su un ottimo locale, il che si verifica anche con la mia implementazione se i parametri thetaA e thetaB sono gli stessi.

Di seguito è riportato lo standard output del mio codice, che mostra la convergenza dei parametri.

thetaA = 0.71301, thetaB = 0.58134
thetaA = 0.74529, thetaB = 0.56926
thetaA = 0.76810, thetaB = 0.54954
thetaA = 0.78316, thetaB = 0.53462
thetaA = 0.79106, thetaB = 0.52628
thetaA = 0.79453, thetaB = 0.52239
thetaA = 0.79593, thetaB = 0.52073
thetaA = 0.79647, thetaB = 0.52005
thetaA = 0.79667, thetaB = 0.51977
thetaA = 0.79674, thetaB = 0.51966
thetaA = 0.79677, thetaB = 0.51961
thetaA = 0.79678, thetaB = 0.51960
thetaA = 0.79679, thetaB = 0.51959
Final result:
thetaA = 0.79678, thetaB = 0.51960

Di seguito è riportata la mia implementazione Java di EM per risolvere il problema in (Do e Batzoglou, 2008). La parte centrale dell'implementazione è il ciclo per eseguire EM fino a quando i parametri convergono.

private Parameters _parameters;

public Parameters run()
{
    while (true)
    {
        expectation();

        Parameters estimatedParameters = maximization();

        if (_parameters.converged(estimatedParameters)) {
            break;
        }

        _parameters = estimatedParameters;
    }

    return _parameters;
}

Di seguito è riportato l'intero codice.

import java.util.*;

/*****************************************************************************
This class encapsulates the parameters of the problem. For this problem posed
in the article by (Do and Batzoglou, 2008), the parameters are thetaA and
thetaB, the probability of a coin coming up heads for the two coins A and B,
respectively.
*****************************************************************************/
class Parameters
{
    double _thetaA = 0.0; // Probability of heads for coin A.
    double _thetaB = 0.0; // Probability of heads for coin B.

    double _delta = 0.00001;

    public Parameters(double thetaA, double thetaB)
    {
        _thetaA = thetaA;
        _thetaB = thetaB;
    }

    /*************************************************************************
    Returns true if this parameter is close enough to another parameter
    (typically the estimated parameter coming from the maximization step).
    *************************************************************************/
    public boolean converged(Parameters other)
    {
        if (Math.abs(_thetaA - other._thetaA) < _delta &&
            Math.abs(_thetaB - other._thetaB) < _delta)
        {
            return true;
        }

        return false;
    }

    public double getThetaA()
    {
        return _thetaA;
    }

    public double getThetaB()
    {
        return _thetaB;
    }

    public String toString()
    {
        return String.format("thetaA = %.5f, thetaB = %.5f", _thetaA, _thetaB);
    }

}


/*****************************************************************************
This class encapsulates an observation, that is the number of heads
and tails in a trial. The observation can be either (1) one of the
experimental observations, or (2) an estimated observation resulting from
the expectation step.
*****************************************************************************/
class Observation
{
    double _numHeads = 0;
    double _numTails = 0;

    public Observation(String s)
    {
        for (int i = 0; i < s.length(); i++)
        {
            char c = s.charAt(i);

            if (c == 'H')
            {
                _numHeads++;
            }
            else if (c == 'T')
            {
                _numTails++;
            }
            else
            {
                throw new RuntimeException("Unknown character: " + c);
            }
        }
    }

    public Observation(double numHeads, double numTails)
    {
        _numHeads = numHeads;
        _numTails = numTails;
    }

    public double getNumHeads()
    {
        return _numHeads;
    }

    public double getNumTails()
    {
        return _numTails;
    }

    public String toString()
    {
        return String.format("heads: %.1f, tails: %.1f", _numHeads, _numTails);
    }

}

/*****************************************************************************
This class runs expectation-maximization for the problem posed by the article
from (Do and Batzoglou, 2008).
*****************************************************************************/
public class EM
{
    // Current estimated parameters.
    private Parameters _parameters;

    // Observations from the trials. These observations are set once.
    private final List<Observation> _observations;

    // Estimated observations per coin. These observations are the output
    // of the expectation step.
    private List<Observation> _expectedObservationsForCoinA;
    private List<Observation> _expectedObservationsForCoinB;

    private static java.io.PrintStream o = System.out;

    /*************************************************************************
    Principal constructor.
    @param observations The observations from the trial.
    @param parameters The initial guessed parameters.
    *************************************************************************/
    public EM(List<Observation> observations, Parameters parameters)
    {
        _observations = observations;
        _parameters = parameters;
    }

    /*************************************************************************
    Run EM until parameters converge.
    *************************************************************************/
    public Parameters run()
    {

        while (true)
        {
            expectation();

            Parameters estimatedParameters = maximization();

            o.printf("%s\n", estimatedParameters);

            if (_parameters.converged(estimatedParameters)) {
                break;
            }

            _parameters = estimatedParameters;
        }

        return _parameters;

    }

    /*************************************************************************
    Given the observations and current estimated parameters, compute new
    estimated completions (distribution over the classes) and observations.
    *************************************************************************/
    private void expectation()
    {

        _expectedObservationsForCoinA = new ArrayList<Observation>();
        _expectedObservationsForCoinB = new ArrayList<Observation>();

        for (Observation observation : _observations)
        {
            int numHeads = (int)observation.getNumHeads();
            int numTails = (int)observation.getNumTails();

            double probabilityOfObservationForCoinA=
                binomialProbability(10, numHeads, _parameters.getThetaA());

            double probabilityOfObservationForCoinB=
                binomialProbability(10, numHeads, _parameters.getThetaB());

            double normalizer = probabilityOfObservationForCoinA +
                                probabilityOfObservationForCoinB;

            // Compute the completions for coin A and B (i.e. the probability
            // distribution of the two classes, summed to 1.0).

            double completionCoinA = probabilityOfObservationForCoinA /
                                     normalizer;
            double completionCoinB = probabilityOfObservationForCoinB /
                                     normalizer;

            // Compute new expected observations for the two coins.

            Observation expectedObservationForCoinA =
                new Observation(numHeads * completionCoinA,
                                numTails * completionCoinA);

            Observation expectedObservationForCoinB =
                new Observation(numHeads * completionCoinB,
                                numTails * completionCoinB);

            _expectedObservationsForCoinA.add(expectedObservationForCoinA);
            _expectedObservationsForCoinB.add(expectedObservationForCoinB);
        }
    }

    /*************************************************************************
    Given new estimated observations, compute new estimated parameters.
    *************************************************************************/
    private Parameters maximization()
    {

        double sumCoinAHeads = 0.0;
        double sumCoinATails = 0.0;
        double sumCoinBHeads = 0.0;
        double sumCoinBTails = 0.0;

        for (Observation observation : _expectedObservationsForCoinA)
        {
            sumCoinAHeads += observation.getNumHeads();
            sumCoinATails += observation.getNumTails();
        }

        for (Observation observation : _expectedObservationsForCoinB)
        {
            sumCoinBHeads += observation.getNumHeads();
            sumCoinBTails += observation.getNumTails();
        }

        return new Parameters(sumCoinAHeads / (sumCoinAHeads + sumCoinATails),
                              sumCoinBHeads / (sumCoinBHeads + sumCoinBTails));

        //o.printf("parameters: %s\n", _parameters);

    }

    /*************************************************************************
    Since the coin-toss experiment posed in this article is a Bernoulli trial,
    use a binomial probability Pr(X=k; n,p) = (n choose k) * p^k * (1-p)^(n-k).
    *************************************************************************/
    private static double binomialProbability(int n, int k, double p)
    {
        double q = 1.0 - p;
        return nChooseK(n, k) * Math.pow(p, k) * Math.pow(q, n-k);
    }

    private static long nChooseK(int n, int k)
    {
        long numerator = 1;

        for (int i = 0; i < k; i++)
        {
            numerator = numerator * n;
            n--;
        }

        long denominator = factorial(k);

        return (long)(numerator / denominator);
    }

    private static long factorial(int n)
    {
        long result = 1;
        for (; n >0; n--)
        {
            result = result * n;
        }

        return result;
    }

    /*************************************************************************
    Entry point into the program.
    *************************************************************************/
    public static void main(String argv[])
    {
        // Create the observations and initial parameter guess
        // from the (Do and Batzoglou, 2008) article.

        List<Observation> observations = new ArrayList<Observation>();
        observations.add(new Observation("HTTTHHTHTH"));
        observations.add(new Observation("HHHHTHHHHH"));
        observations.add(new Observation("HTHHHHHTHH"));
        observations.add(new Observation("HTHTTTHHTT"));
        observations.add(new Observation("THHHTHHHTH"));

        Parameters initialParameters = new Parameters(0.6, 0.5);

        EM em = new EM(observations, initialParameters);

        Parameters finalParameters = em.run();

        o.printf("Final result:\n%s\n", finalParameters);
    }
}
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.