63 votos

¿Cómo evita LSTM el problema de la desaparición del gradiente?

El LSTM se inventó específicamente para evitar el problema del gradiente de fuga. Se supone que lo hace con el carrusel de errores constantes (CEC), que en el diagrama siguiente (de Greff et al. ) corresponden al bucle alrededor de célula .

LSTM
(fuente: <a href="https://deeplearning4j.org/img/greff_lstm_diagram.png" rel="noreferrer">deeplearning4j.org </a>)

Y entiendo que esa parte puede verse como una especie de función identidad, por lo que la derivada es uno y el gradiente se mantiene constante.

Lo que no entiendo es cómo no se desvanece debido a las otras funciones de activación ? Las puertas de entrada, salida y olvido utilizan una sigmoide, cuya derivada es como máximo 0,25, y g y h eran tradicionalmente tanh . ¿Cómo es que la retropropagación a través de estos no hace desaparecer el gradiente?

35voto

karatchov Puntos 230

El gradiente de fuga se explica mejor en el caso unidimensional. El multidimensional es más complicado pero esencialmente análogo. Puedes repasarlo en este excelente artículo [1].

Supongamos que tenemos un estado oculto $h_t$ en el paso de tiempo $t$ . Si simplificamos las cosas y eliminamos los sesgos y las entradas, tenemos $$h_t = \sigma(w h_{t-1}).$$ Entonces puede demostrar que

\begin{align} \frac{\partial h_{t'}}{\partial h_t} &= \prod_{k=1}^{t' - t} w \sigma'(w h_{t'-k})\\ &= \underbrace{w^{t' - t}}_{!!!}\prod_{k=1}^{t' - t} \sigma'(w h_{t'-k}) \end{align} El factor marcado con !!! es el crucial. Si el peso no es igual a 1, o bien decaerá a cero exponencialmente rápido en $t'-t$ o crecer exponencialmente rápido .

En los LSTMs, tienes el estado de la célula $s_t$ . La derivada allí es de la forma $$\frac{\partial s_{t'}}{\partial s_t} = \prod_{k=1}^{t' - t} \sigma(v_{t+k}).$$ Aquí $v_t$ es la entrada a la puerta del olvido. Como puedes ver, no hay ningún factor de decaimiento exponencialmente rápido involucrado. En consecuencia, hay al menos un camino en el que el gradiente no desaparece. Para la derivación completa, véase [2].

[1] Pascanu, Razvan, Tomas Mikolov y Yoshua Bengio. "Sobre la dificultad de entrenar redes neuronales recurrentes". ICML (3) 28 (2013): 1310-1318.

[2] Bayer, Justin Simon. Aprendizaje de representaciones de secuencias. Diss. München, Technische Universität München, Diss., 2015, 2015.

25voto

Riaan Engelbrecht Puntos 544

Me gustaría añadir algunos detalles a la respuesta aceptada, porque creo que es un poco más matizada y el matiz puede no ser obvio para alguien que aprende por primera vez sobre las RNN.

Para la RNN de vainilla, $$\frac{\partial h_{t'}}{\partial h_{t}} = \prod _{k=1} ^{t'-t} w \sigma'(w h_{t'-k})$$ .

Para el LSTM, $$\frac{\partial s_{t'}}{\partial s_{t}} = \prod _{k=1} ^{t'-t} \sigma(v_{t+k})$$

  • una pregunta natural es, ¿no tienen ambas sumas de productos un término sigmoide que cuando se multiplican juntos $t'-t$ ¿los tiempos pueden desaparecer?
  • la respuesta es por lo que la LSTM también sufrirá de gradientes desvanecidos, pero no tanto como la RNN vainilla

La diferencia es que para la RNN de vainilla, el gradiente decae con $w \sigma'(\cdot)$ mientras que para la LSTM el gradiente decae con $\sigma (\cdot)$ .

Para la LSTM, hay un conjunto de pesos que se pueden aprender de manera que $$\sigma (\cdot) \approx 1$$ Supongamos que $v_{t+k} = wx$ para un poco de peso $w$ y la entrada $x$ . Entonces la red neuronal puede aprender un gran $w$ para evitar que los gradientes se desvanezcan.

Por ejemplo, en el caso 1D si $x=1$ , $w=10$ $v_{t+k}=10$ entonces el factor de decaimiento $\sigma (\cdot) = 0.99995$ o el gradiente muere como: $$(0.99995)^{t'-t}$$

Para la RNN de vainilla, no hay un conjunto de pesos que se puede aprender de manera que $$w \sigma'(w h_{t'-k}) \approx 1 $$

Por ejemplo, en el caso 1D, supongamos que $h_{t'-k}=1$ . La función $w \sigma'(w*1)$ logra un máximo de $0.224$ en $w=1.5434$ . Esto significa que el gradiente decaerá como, $$(0.224)^{t'-t}$$

3voto

Joerg Puntos 116

http://www.felixgers.de/papers/phd.pdf Consulte la sección 2.2 y 3.2.2 donde se explica la parte del error truncado. No propagan el error si se escapa de la memoria de la célula (es decir, si hay una puerta de entrada cerrada/activada), pero actualizan los pesos de la puerta basándose en el error sólo para ese instante de tiempo. Más tarde se hace cero durante la propagación posterior. Esto es una especie de hack, pero la razón para hacerlo es que el flujo de error a lo largo de las puertas de todos modos decae con el tiempo.

3voto

Dustin Laine Puntos 213

La imagen del bloque LSTM de Greff et al. (2015) describe una variante que los autores llaman LSTM de vainilla . Es un poco diferente de la definición original de Hochreiter & Schmidhuber (1997). La definición original no incluía la puerta del olvido y las conexiones de la mirilla.

El término carrusel de errores constantes se utilizó en el documento original para denotar la conexión recurrente del estado de la célula. Consideremos la definición original en la que el estado de la célula se cambia sólo por adición, cuando la puerta de entrada se abre. El gradiente del estado de la célula con respecto al estado de la célula en un paso de tiempo anterior es cero.

El error puede seguir entrando en el CEC a través de la puerta de salida y la función de activación. La función de activación reduce un poco la magnitud del error antes de que se añada a la CEC. La CEC es el único lugar donde el error puede fluir sin cambios. De nuevo, cuando la puerta de entrada se abre, el error sale a través de la puerta de entrada, la función de activación y la transformación afín, reduciendo la magnitud del error.

Así, el error se reduce cuando se retropropaga a través de una capa LSTM, pero sólo cuando entra y sale del CEC. Lo importante es que no cambia en el CEC sin importar la distancia que recorra. Esto resuelve el problema en la RNN básica de que en cada paso de tiempo se aplica una transformación afín y no lineal, lo que significa que cuanto mayor sea la distancia temporal entre la entrada y la salida, menor será el error.

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X