tl; dr Anche se si tratta di un set di dati di classificazione delle immagini, rimane un compito molto semplice , per il quale è possibile trovare facilmente una mappatura diretta dagli input alle previsioni.
Risposta:
Questa è una domanda molto interessante e grazie alla semplicità della regressione logistica puoi effettivamente trovare la risposta.
Ciò che la regressione logistica fa è che ogni immagine accetti input e li moltiplichi con i pesi per generare la sua previsione. La cosa interessante è che a causa della mappatura diretta tra input e output (cioè nessun layer nascosto), il valore di ciascun peso corrisponde a quanto ciascuno dei input viene preso in considerazione quando si calcola la probabilità di ogni classe. Ora, prendendo i pesi per ogni classe e rimodellandoli in (ovvero la risoluzione dell'immagine), possiamo dire quali pixel sono più importanti per il calcolo di ogni classe .78478428×28
Si noti, ancora una volta, che questi sono i pesi .
Ora dai un'occhiata all'immagine sopra e concentrati sulle prime due cifre (ovvero zero e una). I pesi blu indicano che l'intensità di questo pixel contribuisce molto per quella classe e i valori rossi indicano che contribuisce negativamente.
Ora immagina, come fa una persona a disegnare uno ? Disegna una forma circolare che è vuota nel mezzo. Questo è esattamente ciò che i pesi hanno raccolto. In effetti se qualcuno disegna il centro dell'immagine, conta negativamente come zero. Quindi per riconoscere gli zeri non hai bisogno di filtri sofisticati e funzionalità di alto livello. Puoi semplicemente guardare le posizioni dei pixel disegnati e giudicare in base a questo.0
Stessa cosa per l' . Ha sempre una linea verticale dritta nel mezzo dell'immagine. Tutto il resto conta negativamente.1
Il resto delle cifre è un po 'più complicato, ma con poca immaginazione puoi vedere il , il , il e l' . Il resto dei numeri è un po 'più difficile, che è ciò che effettivamente limita la regressione logistica dal raggiungere gli anni '90.2378
In questo modo puoi vedere che la regressione logistica ha ottime possibilità di ottenere correttamente molte immagini ed è per questo che ha un punteggio così alto.
Il codice per riprodurre la figura sopra è un po 'datato, ma qui vai:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b
y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #
correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Train model
batch_size = 64
with tf.Session() as sess:
loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []
sess.run(tf.global_variables_initializer())
for step in range(1, 1001):
x_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})
l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
loss_tr.append(l_tr)
acc_tr.append(a_tr)
loss_ts.append(l_ts)
acc_ts.append(a_ts)
weights = sess.run(W)
print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# Plotting:
for i in range(10):
plt.subplot(2, 5, i+1)
weight = weights[:,i].reshape([28,28])
plt.title(i)
plt.imshow(weight, cmap='RdBu') # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
frame1 = plt.gca()
frame1.axes.get_xaxis().set_visible(False)
frame1.axes.get_yaxis().set_visible(False)