In che modo un semplice modello di regressione logistica raggiunge una precisione di classificazione del 92% su MNIST?


73

Anche se tutte le immagini nel set di dati MNIST sono centrate, con una scala simile e rivolte verso l'alto senza rotazioni, hanno una variazione significativa della scrittura che mi confonde come un modello lineare raggiunge una precisione di classificazione così elevata.

Per quanto sono in grado di visualizzare, data la notevole variazione della grafia, le cifre dovrebbero essere linearmente inseparabili in uno spazio dimensionale di 784, vale a dire che dovrebbe esserci un piccolo confine non lineare complesso (sebbene non molto complesso) che separa le diverse cifre , simile all'esempio ben citato in cui le classi positive e negative non possono essere separate da nessun classificatore lineare. Mi sembra sconcertante come la regressione logistica multi-classe produca una precisione così elevata con funzionalità interamente lineari (nessuna funzionalità polinomiale).XOR

Ad esempio, dato qualsiasi pixel nell'immagine, diverse variazioni scritte a mano delle cifre e possono illuminare o meno quel pixel. Pertanto, con una serie di pesi appresi, ogni pixel può far apparire una cifra come un e un . Solo con una combinazione di valori di pixel dovrebbe essere possibile dire se una cifra è un o un . Questo è vero per la maggior parte delle coppie di cifre. Quindi, in che modo la regressione logistica, che basa ciecamente la propria decisione in modo indipendente su tutti i valori di pixel (senza considerare alcuna dipendenza tra pixel), è in grado di raggiungere accuratezze così elevate.232323

So che mi sbaglio da qualche parte o sto solo sopravvalutando la variazione delle immagini. Tuttavia, sarebbe bello se qualcuno potesse aiutarmi con un'intuizione su come le cifre siano "quasi" linearmente separabili.


Dai un'occhiata al libro di testo Statistical Learning with Sparsity: the Lasso and Generalizations 3.3.1 Esempio: cifre scritte a mano web.stanford.edu/~hastie/StatLearnSparsity_files/SLS.pdf
Adrian

Sono stato curioso: quanto bene fa qualcosa come un modello lineare penalizzato (cioè glmnet) sul problema? Se ricordo, ciò che stai segnalando è l'accuratezza fuori campione non aperta.
Cliff AB,

Risposte:


91

tl; dr Anche se si tratta di un set di dati di classificazione delle immagini, rimane un compito molto semplice , per il quale è possibile trovare facilmente una mappatura diretta dagli input alle previsioni.


Risposta:

Questa è una domanda molto interessante e grazie alla semplicità della regressione logistica puoi effettivamente trovare la risposta.

Ciò che la regressione logistica fa è che ogni immagine accetti input e li moltiplichi con i pesi per generare la sua previsione. La cosa interessante è che a causa della mappatura diretta tra input e output (cioè nessun layer nascosto), il valore di ciascun peso corrisponde a quanto ciascuno dei input viene preso in considerazione quando si calcola la probabilità di ogni classe. Ora, prendendo i pesi per ogni classe e rimodellandoli in (ovvero la risoluzione dell'immagine), possiamo dire quali pixel sono più importanti per il calcolo di ogni classe .78478428×28

Si noti, ancora una volta, che questi sono i pesi .

Ora dai un'occhiata all'immagine sopra e concentrati sulle prime due cifre (ovvero zero e una). I pesi blu indicano che l'intensità di questo pixel contribuisce molto per quella classe e i valori rossi indicano che contribuisce negativamente.

Ora immagina, come fa una persona a disegnare uno ? Disegna una forma circolare che è vuota nel mezzo. Questo è esattamente ciò che i pesi hanno raccolto. In effetti se qualcuno disegna il centro dell'immagine, conta negativamente come zero. Quindi per riconoscere gli zeri non hai bisogno di filtri sofisticati e funzionalità di alto livello. Puoi semplicemente guardare le posizioni dei pixel disegnati e giudicare in base a questo.0

Stessa cosa per l' . Ha sempre una linea verticale dritta nel mezzo dell'immagine. Tutto il resto conta negativamente.1

Il resto delle cifre è un po 'più complicato, ma con poca immaginazione puoi vedere il , il , il e l' . Il resto dei numeri è un po 'più difficile, che è ciò che effettivamente limita la regressione logistica dal raggiungere gli anni '90.2378

In questo modo puoi vedere che la regressione logistica ha ottime possibilità di ottenere correttamente molte immagini ed è per questo che ha un punteggio così alto.


Il codice per riprodurre la figura sopra è un po 'datato, ma qui vai:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

13
Grazie per l'illustrazione. Queste immagini di peso rendono più chiaro quanto l'accuratezza sia così elevata. La moltiplicazione dei punti di un'immagine di una cifra scritta a mano con l'immagine del peso corrispondente alla vera etichetta dell'immagine sembra "essere" la più alta in confronto al prodotto punto con altre etichette di peso per la maggior parte (ancora il 92% mi sembra molto) delle immagini in MNIST. Tuttavia, è un po 'sorprendente che e o e siano raramente classificati erroneamente l'uno nell'altro esaminando la matrice di confusione. Comunque, questo è quello che è. I dati non mentono mai. :)2378
Nitish Agarwal,

13
Ovviamente aiuta che i campioni MNIST siano centrati, ridimensionati e normalizzati in contrasto prima che il classificatore li veda. Non devi rispondere a domande come "cosa succede se il bordo dello zero attraversa effettivamente il centro del riquadro?" perché il pre-processore ha già fatto molta strada per far sembrare tutti gli zero uguali.
Hobbs,

1
@EricDuminil Ho aggiunto un encomio alla sceneggiatura con il tuo suggerimento. Grazie mille per l'input! : D
Djib2011,

1
@NitishAgarwal, Se pensi che questa risposta sia la risposta alla tua domanda, considera di contrassegnarla come tale.
sintassi

16
Per qualcuno che è interessato ma non particolarmente a conoscenza di questo tipo di elaborazione, questa risposta fornisce un fantastico esempio intuitivo della meccanica.
Chrylis
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.