tl;dr Aunque se trata de un conjunto de datos de clasificación de imágenes, sigue siendo un muy fácil tarea, para la que se puede encontrar fácilmente un mapeo directo de las entradas a las predicciones.
Respuesta:
Se trata de una pregunta muy interesante y, gracias a la sencillez de la regresión logística, se puede averiguar la respuesta.
Lo que hace la regresión logística es que para cada imagen acepta $784$ entradas y multiplicarlas con pesos para generar su predicción. Lo interesante es que, debido al mapeo directo entre la entrada y la salida (es decir, no hay capa oculta), el valor de cada peso corresponde a lo que cada uno de los $784$ Las entradas se tienen en cuenta al calcular la probabilidad de cada clase. Ahora, tomando las ponderaciones de cada clase y reformulándolas en $28 \times 28$ (es decir, la resolución de la imagen), podemos decir qué píxeles son más importantes para el cálculo de cada clase .
Tenga en cuenta, de nuevo, que estos son los pesos .
Ahora eche un vistazo a la imagen anterior y céntrese en los dos primeros dígitos (es decir, el cero y el uno). Los pesos azules significan que la intensidad de este píxel contribuye mucho para esa clase y los valores rojos significan que contribuye negativamente.
Ahora imagínate, ¿cómo dibuja una persona un $0$ ? Dibuja una forma circular que está vacía en el medio. Eso es exactamente lo que los pesos captaron. De hecho, si alguien dibuja el centro de la imagen, cuenta negativamente como un cero. Por lo tanto, para reconocer los ceros no se necesitan filtros sofisticados ni características de alto nivel. Basta con mirar las ubicaciones de los píxeles dibujados y juzgar según esto.
Lo mismo para el $1$ . Siempre tiene una línea vertical recta en el centro de la imagen. Todo lo demás cuenta negativamente.
El resto de los dígitos son un poco más complicados, pero con un poco de imaginación se puede ver el $2$ El $3$ El $7$ y el $8$ . El resto de los números son un poco más difíciles, que es lo que realmente limita la regresión logística de llegar a los altos 90.
A través de esto se puede ver que la regresión logística tiene una muy buena oportunidad de acertar en muchas imágenes y por eso tiene una puntuación tan alta.
El código para reproducir la figura anterior es un poco anticuado, pero aquí lo tienes:
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))
W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b
y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) #
correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Train model
batch_size = 64
with tf.Session() as sess:
loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []
sess.run(tf.global_variables_initializer())
for step in range(1, 1001):
x_batch, y_batch = mnist.train.next_batch(batch_size)
sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})
l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
loss_tr.append(l_tr)
acc_tr.append(a_tr)
loss_ts.append(l_ts)
acc_ts.append(a_ts)
weights = sess.run(W)
print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
# Plotting:
for i in range(10):
plt.subplot(2, 5, i+1)
weight = weights[:,i].reshape([28,28])
plt.title(i)
plt.imshow(weight, cmap='RdBu') # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
frame1 = plt.gca()
frame1.axes.get_xaxis().set_visible(False)
frame1.axes.get_yaxis().set_visible(False)