Implementazione di t-SNE Python: divergenza di Kullback-Leibler


11

t-SNE, come in [1], agisce riducendo progressivamente la divergenza di Kullback-Leibler (KL), fino a quando non viene soddisfatta una certa condizione. I creatori di t-SNE suggeriscono di usare la divergenza di KL come criterio di prestazione per le visualizzazioni:

puoi confrontare le divergenze di Kullback-Leibler riportate da t-SNE. Va benissimo eseguire t-SNE dieci volte e selezionare la soluzione con la divergenza di KL più bassa [2]

Ho provato due implementazioni di t-SNE:

  • pitone : sklearn.manifold.TSNE ().
  • R : tsne, dalla biblioteca (tsne).

Entrambe queste implementazioni, quando è impostata la verbosità, stampano l'errore (divergenza di Kullback-Leibler) per ogni iterazione. Tuttavia, non consentono all'utente di ottenere queste informazioni, il che mi sembra un po 'strano.

Ad esempio, il codice:

import numpy as np
from sklearn.manifold import TSNE
X = np.array([[0, 0, 0], [0, 1, 1], [1, 0, 1], [1, 1, 1]])
model = TSNE(n_components=2, verbose=2, n_iter=200)
t = model.fit_transform(X)

produce:

[t-SNE] Computing pairwise distances...
[t-SNE] Computed conditional probabilities for sample 4 / 4
[t-SNE] Mean sigma: 1125899906842624.000000
[t-SNE] Iteration 10: error = 6.7213750, gradient norm = 0.0012028
[t-SNE] Iteration 20: error = 6.7192064, gradient norm = 0.0012062
[t-SNE] Iteration 30: error = 6.7178683, gradient norm = 0.0012114
...
[t-SNE] Error after 200 iterations: 0.270186

Ora, per quanto ho capito, 0.270186 dovrebbe essere la divergenza di KL. Tuttavia non riesco a ottenere queste informazioni, né dal modello né da t (che è un semplice numpy.ndarray).

Per risolvere questo problema ho potuto: i) Calcolare la divergenza di KL per me stesso, ii) Fare qualcosa di brutto in Python per catturare e analizzare l'output della funzione TSNE () [3]. Tuttavia: i) sarebbe abbastanza stupido ricalcolare la divergenza di KL, quando TSNE () l'ha già calcolata, ii) sarebbe un po 'insolito in termini di codice.

Hai qualche altro suggerimento? Esiste un modo standard per ottenere queste informazioni utilizzando questa libreria?

Ho detto di aver provato la libreria tsne di R , ma preferirei che le risposte si concentrassero sull'implementazione di Python Sklearn.


Riferimenti

[1] http://nbviewer.ipython.org/urls/gist.githubusercontent.com/AlexanderFabisch/1a0c648de22eff4a2a3e/raw/59d5bc5ed8f8bfd9ff1f7faa749d1b095aa97d5a/t-SNE.ipynb

[2] http://homepage.tudelft.nl/19j49/t-SNE.html

[3] /programming/16571150/how-to-capture-stdout-output-from-a-python-function-call

Risposte:


4

La fonte TSNE in scikit-learn è in puro Python. Il fit_transform()metodo Fit in realtà chiama una _fit()funzione privata che quindi chiama una _tsne()funzione privata . Tale _tsne()funzione ha una variabile locale errorche viene stampata alla fine dell'adattamento. Sembra che potresti facilmente cambiare una o due righe di codice sorgente per riportare quel valore fit_transform().


In sostanza, ciò che ho potuto fare è impostare self.error = error alla fine di _tsne (), al fine di recuperarlo dall'istanza TSNE in seguito. Sì, ma ciò significherebbe cambiare il codice sklearn.manifold, e mi chiedevo se gli sviluppatori pensassero ad altri modi per ottenere le informazioni o se non perché non lo facessero (cioè: l'errore è considerato da loro inutile?). Inoltre, se avessi cambiato quel codice avrei bisogno che tutte le persone che eseguivano il mio codice avessero lo stesso hack sulle loro installazioni sklearn. È quello che mi hai suggerito o ho sbagliato?
Joker,

Sì, è quello che ho suggerito come possibile soluzione. Poiché scikit-learn è open source, puoi anche inviare la tua soluzione come richiesta pull e vedere se gli autori lo includeranno nelle versioni future. Non posso parlare del perché abbiano incluso o meno varie cose.
Trey,

2
Grazie. Se qualcun altro è interessato a questo, github.com/scikit-learn/scikit-learn/pull/3422 .
Joker,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.