81 votos

¿Cómo consigue un modelo de regresión logística simple una precisión de clasificación del 92% en MNIST?

A pesar de que todas las imágenes del conjunto de datos MNIST están centradas, con una escala similar, y boca arriba sin rotaciones, tienen una importante variación de escritura que me desconcierta cómo un modelo lineal consigue una precisión de clasificación tan alta.

Por lo que soy capaz de visualizar, dada la importante variación de la escritura, los dígitos deberían ser linealmente inseparables en un espacio de 784 dimensiones, es decir, debería haber una frontera no lineal poco compleja (aunque no muy compleja) que separe los diferentes dígitos, similar a la bien citada $XOR$ ejemplo en el que las clases positivas y negativas no pueden ser separadas por ningún clasificador lineal. Me parece desconcertante cómo la regresión logística multiclase produce una precisión tan alta con características totalmente lineales (sin características polinómicas).

A modo de ejemplo, dado cualquier píxel de la imagen, diferentes variaciones manuscritas de los dígitos $2$ y $3$ puede hacer que ese píxel se ilumine o no. Por lo tanto, con un conjunto de pesos aprendidos, cada píxel puede hacer que un dígito se vea como $2$ así como un $3$ . Sólo con una combinación de valores de píxeles se puede decir si un dígito es un $2$ o un $3$ . Esto es cierto para la mayoría de los pares de dígitos. Entonces, ¿cómo es posible que la regresión logística, que basa ciegamente su decisión de forma independiente en todos los valores de los píxeles (sin tener en cuenta en absoluto las dependencias entre píxeles), sea capaz de alcanzar una precisión tan elevada?

Sé que me equivoco en alguna parte o simplemente estoy sobreestimando la variación de las imágenes. Sin embargo, sería genial si alguien pudiera ayudarme con una intuición sobre cómo los dígitos son "casi" linealmente separables.

105voto

Djib2011 Puntos 693

tl;dr Aunque se trata de un conjunto de datos de clasificación de imágenes, sigue siendo un muy fácil tarea, para la que se puede encontrar fácilmente un mapeo directo de las entradas a las predicciones.


Respuesta:

Se trata de una pregunta muy interesante y, gracias a la sencillez de la regresión logística, se puede averiguar la respuesta.

Lo que hace la regresión logística es que para cada imagen acepta $784$ entradas y multiplicarlas con pesos para generar su predicción. Lo interesante es que, debido al mapeo directo entre la entrada y la salida (es decir, no hay capa oculta), el valor de cada peso corresponde a lo que cada uno de los $784$ Las entradas se tienen en cuenta al calcular la probabilidad de cada clase. Ahora, tomando las ponderaciones de cada clase y reformulándolas en $28 \times 28$ (es decir, la resolución de la imagen), podemos decir qué píxeles son más importantes para el cálculo de cada clase .

Tenga en cuenta, de nuevo, que estos son los pesos .

Ahora eche un vistazo a la imagen anterior y céntrese en los dos primeros dígitos (es decir, el cero y el uno). Los pesos azules significan que la intensidad de este píxel contribuye mucho para esa clase y los valores rojos significan que contribuye negativamente.

Ahora imagínate, ¿cómo dibuja una persona un $0$ ? Dibuja una forma circular que está vacía en el medio. Eso es exactamente lo que los pesos captaron. De hecho, si alguien dibuja el centro de la imagen, cuenta negativamente como un cero. Por lo tanto, para reconocer los ceros no se necesitan filtros sofisticados ni características de alto nivel. Basta con mirar las ubicaciones de los píxeles dibujados y juzgar según esto.

Lo mismo para el $1$ . Siempre tiene una línea vertical recta en el centro de la imagen. Todo lo demás cuenta negativamente.

El resto de los dígitos son un poco más complicados, pero con un poco de imaginación se puede ver el $2$ El $3$ El $7$ y el $8$ . El resto de los números son un poco más difíciles, que es lo que realmente limita la regresión logística de llegar a los altos 90.

A través de esto se puede ver que la regresión logística tiene una muy buena oportunidad de acertar en muchas imágenes y por eso tiene una puntuación tan alta.


El código para reproducir la figura anterior es un poco anticuado, pero aquí lo tienes:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X