28 votos

Lo que está sucediendo aquí, cuando yo uso el cuadrado de la pérdida en la regresión logística?

Estoy tratando de usar el cuadrado de la pérdida a hacer de clasificación binaria. La pérdida es $\sum_i (y_i-p_i)^2$ donde $y_i$ es el fundamento de la verdad de la etiqueta (0 o 1) y $p_i$ es la predicción de la probabilidad de $p_i=\text{Logit}^{-1}(\beta^Tx_i)$.

En otras palabras, estoy reemplazar logística de la pérdida con el cuadrado de la pérdida en la clasificación de configuración, otras partes son los mismos.

Un juguete ejemplo mtcars de datos, en muchos casos, tengo un modelo "similar" a la regresión logística (ver figura siguiente, con la semilla aleatoria 0).

enter image description here

Pero en algo (si lo hacemos set.seed(1)), el cuadrado de la pérdida parece que no funciona bien.

¿Qué está ocurriendo aquí? La optimización no converge? Logística de la pérdida es más fácil optimizar comparación con el cuadrado de la pérdida? Cualquier ayuda se agradece.


Código

d=mtcars[,c("am","mpg","wt")]
plot(d$mpg,d$wt,col=factor(d$am))
lg_fit=glm(am~.,d, family = binomial())
abline(-lg_fit$coefficients[1]/lg_fit$coefficients[3],
       -lg_fit$coefficients[2]/lg_fit$coefficients[3])
grid()

# sq loss
lossSqOnBinary<-function(x,y,w){
  p=plogis(x %*% w)
  return(sum((y-p)^2))
}

# ----------------------------------------------------------------
# note, this random seed is important for squared loss work
# ----------------------------------------------------------------
set.seed(0)

x0=runif(3)
x=as.matrix(cbind(1,d[,2:3]))
y=d$am
opt=optim(x0, lossSqOnBinary, method="BFGS", x=x,y=y)

abline(-opt$par[1]/opt$par[3],
       -opt$par[2]/opt$par[3], lty=2)
legend(25,5,c("logisitc loss","squared loss"), lty=c(1,2))

42voto

Paulius Puntos 369

Parece que has solucionado el problema en tu caso en particular, pero creo que es todavía vale la pena un estudio más atento de la diferencia entre los mínimos cuadrados y máxima probabilidad de regresión logística.

Veamos algo de notación. Deje $L_S(y_i, \hat y_i) = \frac 12(y_i - \hat y_i)^2$$L_L(y_i, \hat y_i) = y_i \log \hat y_i + (1 - y_i) \log(1 - \hat y_i)$. Si estamos haciendo el de máxima verosimilitud (o mínimo de la negativa de registro de probabilidad como estoy haciendo aquí), tenemos $$ \hat \beta_L := \text{argmin}_{b \in \mathbb R^p} -\sum_{i=1}^n y_i \log g^{-1}(x_i^T b) + (1-y_i)\log(1 - g^{-1}(x_i^T b)) $$ con $g$ siendo la nuestra la función de enlace.

Como alternativa tenemos $$ \hat \beta_S := \text{argmin}_{b \in \mathbb R^p} \frac 12 \sum_{i=1}^n (y_i - g^{-1}(x_i^T b))^2 $$ como la solución de mínimos cuadrados. Por lo tanto $\hat \beta_S$ minimiza $L_S$ y de manera similar para $L_L$.

Deje $f_S$ $f_L$ ser el objetivo de las funciones correspondientes a minimizar $L_S$ $L_L$ respectivamente, como se hace para $\hat \beta_S$$\hat \beta_L$. Por último, vamos a $h = g^{-1}$$\hat y_i = h(x_i^T b)$. Tenga en cuenta que si estamos utilizando el enlace canónico tenemos $$ h(z) = \frac{1}{1+e^{-z}} \implica que h'(z) = h(z) (1 - h(z)). $$


Para regular la regresión logística nos han $$ \frac{\partial f_L}{\partial b_j} = -\sum_{i=1}^n h'(x_i^T b)x_{ij} \left( \frac{y_i}{h(x_i^T b)} - \frac{1-y_i}{1 - h(x_i^T b)}\right). $$ El uso de $h' = h \cdot (1 - h)$ podemos simplificar esto $$ \frac{\partial f_L}{\partial b_j} = -\sum_{i=1}^n x_{ij} \left( y_i(1 - \hat y_i) - (1-y_i)\hat y_i\right) = -\sum_{i=1}^n x_{ij}(y_i - \hat y_i) $$ así $$ \nabla f_L(b) = -X^T (Y - \hat Y). $$

Siguiente que vamos a hacer segundas derivadas.

$$ \frac{\partial^2 f_L}{\partial b_j \partial b_k} = \sum_{i=1}^n x_{ij} x_{ik} \hat y_i (1 - \hat y_i). $$ Esto significa que $H_L = X^T A X$ donde $A = \text{diag} \left(\hat Y (1 - \hat Y)\right)$. $H_L$ does depend on the current fitted values $\hat Y$ but $S$ has dropped out, and $H_L$ is PSD. Thus our optimization problem is convex in $b$.


Vamos a comparar esto con menos plazas.

$$ \frac{\partial f_S}{\partial b_j} = - \sum_{i=1}^n (y_i - \hat y_i) h'(x^T_i b)x_{ij}. $$

Esto significa que tenemos $$ \nabla f_S(b) = -X^T (Y - \hat Y). $$ Este es un punto vital: la pendiente es casi el mismo, excepto para todos $i$ $\hat y_i (1 - \hat y_i) \in (0,1)$ así que, básicamente, estamos acoplando el gradiente relativo a $\nabla f_L$. Esto va a hacer que la convergencia más lenta.

Para Hesse lo primero que se puede escribir $$ \frac{\partial f_S}{\partial b_j} = - \sum_{i=1}^n x_{ij}(y_i - \hat y_i) \hat y_i (1 - \hat y_i) = - \sum_{i=1}^n x_{ij}\left( y_i \hat y_i - (1-y_i)\hat y_i^2 + \hat y_i^3\right). $$

Esto nos lleva a $$ \frac{\partial^2 f_S}{\partial b_j \partial b_k} = - \sum_{i=1}^n x_{ij} x_{ik} h'(x_i^T b) \left( y_i - 2(1-y_i)\hat y_i + 3 \hat y_i^2 \right) $$

Deje $B = \text{diag} \left( y_i - 2(1-y_i)\hat y_i + 3 \hat y_i ^2 \right)$. Ahora tenemos $$ H_S = -X^T a B X. $$

Por desgracia para nosotros, los pesos en $B$ no están garantizados para ser no negativo: si $y_i = 0$$\hat y_i > \frac 23 \implies y_i - 2(1-y_i)\hat y_i + 3 \hat y_i ^2 > 0$, mientras que el opuesto es por $\hat y_i < \frac 23$.

Esto significa que $H_S$ no es necesariamente PSD, así que no sólo estamos aplastando nuestra gradientes de que va a dificultar el aprendizaje, pero también nos hemos metido hasta la convexidad de nuestro problema.


Con todo, no es ninguna sorpresa que los mínimos cuadrados de la regresión logística luchas a veces, y en tu ejemplo tienes suficiente equipado cerca de los valores de $0$ o $1$, de modo que $\hat y_i (1 - \hat y_i)$ puede ser bastante pequeña y por lo tanto el gradiente es aplanado.

La conexión de este a las redes neuronales, creo que usted está experimentando lo que Goodfellow, Bengio, y Courville se refiere en su Aprendizaje Profundo libro al escribir el siguiente:

Un tema recurrente a lo largo de la red neuronal de diseño es que el gradiente de la función de costo debe ser lo suficientemente grande y lo suficientemente previsible como para servir como una buena guía para el algoritmo de aprendizaje. Las funciones que saturan (llegar a ser muy plana) atentan contra este objetivo, ya que hacer el degradado llegar a ser muy pequeña. En muchos casos esto sucede porque la activación de las funciones que se usan para producir la salida de las unidades ocultas o las unidades de salida se saturan. La negativa de la log-verosimilitud ayuda a evitar este problema para muchos modelos. Muchas unidades de producción implican una función exp que puede saturar cuando su argumento es muy negativa. La función de registro en el registro negativo-de probabilidad de la función de coste deshace la exp de algunas unidades de salida. Vamos a discutir la interacción entre la función de costo y la elección de la unidad de producción en Segundos 6.2.2.

y, en 6.2.2,

Por desgracia, error cuadrático medio y el error absoluto medio a menudo conducen a buenos resultados cuando se utiliza con un gradiente basado en la optimización. Algunas unidades de salida que saturan producir muy pequeños gradientes cuando se combina con estas funciones de costo. Esta es una razón por la que la cruz de entropía función de costo es más popular que el error cuadrático medio o error absoluto medio, incluso cuando no es necesario para estimar una distribución completa $p(y|x)$.

(ambos extractos del capítulo 6).

8voto

David Puntos 41

Me gustaría agradecer las gracias a @whuber y @Chacona para ayudar. Especialmente a @Chacona, esta derivación es lo que he querido durante años.

El problema ESTÁ en la optimización de la parte. Si ponemos la semilla aleatoria a 1, el valor BFGS no funcionará. Pero si queremos cambiar el algoritmo y cambiar el máximo número de la iteración se va a trabajar de nuevo.

Como @Chacona se mencionó, el problema es el cuadrado de la pérdida para la clasificación no es convexo y más difícil de optimizar. Para agregar en @Chacona de matemáticas, me gustaría presentar algunas visualizaciones en logísticos de la pérdida y el cuadrado de la pérdida.

Vamos a cambiar la presentación de los datos de mtcars, ya que el original juguete ejemplo ha $3$ coeficientes incluyendo el intercepto. Vamos a utilizar otro juguete conjunto de datos generados a partir de mlbench, en este conjunto de datos, establecemos $2$ parámetros, lo que es mejor para la visualización.

Aquí está el demo

  • Los datos se muestran en la figura de la izquierda: tenemos dos clases en dos colores. x,y son dos de las características de los datos. Además, utilizamos la línea roja para representar el clasificador lineal de logística de la pérdida, y la línea azul representa el clasificador lineal desde el cuadrado de la pérdida.

  • El centro de la figura y la figura de la derecha muestra el contorno para la logística de la pérdida (rojo) y el cuadrado de la pérdida (azul). x, y son dos parámetros que son el montaje. El punto es el punto óptimo encontrado por BFGS.

enter image description here

Desde el contorno podemos ver fácilmente cómo por qué optimizar el cuadrado de la pérdida es más difícil: como la Chacona se mencionó, no es convexo.

Aquí es una visión más de persp3d.

enter image description here


Código

set.seed(0)
d=mlbench::mlbench.2dnormals(50,2,r=1)
x=d$x
y=ifelse(d$classes==1,1,0)

lg_loss <- function(w){
  p=plogis(x %*% w)
  L=-y*log(p)-(1-y)*log(1-p)
  return(sum(L))
}
sq_loss <- function(w){
  p=plogis(x %*% w)
  L=sum((y-p)^2)
  return(L)
}

w_grid_v=seq(-15,15,0.1)
w_grid=expand.grid(w_grid_v,w_grid_v)

opt1=optimx::optimx(c(1,1),fn=lg_loss ,method="BFGS")
z1=matrix(apply(w_grid,1,lg_loss),ncol=length(w_grid_v))

opt2=optimx::optimx(c(1,1),fn=sq_loss ,method="BFGS")
z2=matrix(apply(w_grid,1,sq_loss),ncol=length(w_grid_v))

par(mfrow=c(1,3))
plot(d,xlim=c(-3,3),ylim=c(-3,3))
abline(0,-opt1$p2/opt1$p1,col='darkred',lwd=2)
abline(0,-opt2$p2/opt2$p1,col='blue',lwd=2)
grid()
contour(w_grid_v,w_grid_v,z1,col='darkred',lwd=2, nlevels = 8)
points(opt1$p1,opt1$p2,col='darkred',pch=19)
grid()
contour(w_grid_v,w_grid_v,z2,col='blue',lwd=2, nlevels = 8)
points(opt2$p1,opt2$p2,col='blue',pch=19)
grid()


# library(rgl)
# persp3d(w_grid_v,w_grid_v,z1,col='darkred')

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X