21 votos

RNNs: ¿Cuándo aplicar el BPTT y/o actualizar los pesos?

Estoy intentando comprender la aplicación de alto nivel de las RNN al etiquetado de secuencias a través de (entre otros) el artículo de 2005 de Graves sobre clasificación de los fonemas.

Para resumir el problema: tenemos un gran conjunto de entrenamiento que consiste en (entrada) archivos de audio de frases individuales y (salida) tiempos de inicio, tiempos de parada y etiquetas etiquetadas por expertos para los fonemas individuales (incluyendo algunos fonemas "especiales" como el silencio, de manera que cada muestra en cada archivo de audio está etiquetada con algún símbolo de fonema).

La idea central del artículo es aplicar a este problema una RNN con células de memoria LSTM en la capa oculta. (Aplica varias variantes y otras técnicas como comparación. Por el momento sólo me interesa la LSTM unidireccional, para simplificar las cosas).

Creo que entiendo la arquitectura de la red: Una capa de entrada que corresponde a ventanas de 10 ms de los archivos de audio, preprocesados de forma estándar para el trabajo de audio; una capa oculta de células LSTM, y una capa de salida con una codificación de un solo disparo de todos los 61 símbolos telefónicos posibles.

Creo que entiendo las ecuaciones (intrincadas pero directas) del paso hacia delante y hacia atrás por las unidades LSTM. No son más que cálculos y la regla de la cadena.

Lo que no entiendo, después de leer varias veces este documento y otros similares, es cuando exactamente para aplicar el algoritmo de retropropagación y cuando exactamente para actualizar los distintos pesos de las neuronas.

Existen dos métodos plausibles:

1) Retroalimentación y actualización de fotogramas

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
- Apply backpropagation to this frame's error
- Update weights accordingly
At end of sentence, reset memory
load another sentence and continue.

o,

2) Retroalimentación y actualización por frases:

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
At end of sentence:
- Apply backprop to average of sentence error function
- Update weights accordingly
- Reset memory
Load another sentence and continue.

Tenga en cuenta que esta es una pregunta general sobre el entrenamiento de RNN utilizando el documento de Graves como un ejemplo puntiagudo (y personalmente relevante): Cuando se entrenan RNNs en secuencias, ¿se aplica backprop en cada paso de tiempo? ¿Se ajustan los pesos en cada paso de tiempo? O, en una analogía poco precisa con el entrenamiento por lotes en arquitecturas estrictamente de avance, ¿se acumulan los errores y se promedian en una secuencia concreta antes de aplicar la retropropulsión y las actualizaciones de peso?

¿O estoy más confundido de lo que creo?

39voto

throwaway Puntos 18

Asumiré que estamos hablando de redes neuronales recurrentes (RNN) que producen una salida en cada paso de tiempo (si la salida sólo está disponible al final de la secuencia, sólo tiene sentido ejecutar la retropropagación al final). Las RNN de este tipo suelen entrenarse mediante retropropagación truncada en el tiempo (BPTT), operando secuencialmente en "trozos" de una secuencia. El procedimiento es el siguiente:

  1. Pase hacia adelante: Paso a través del siguiente $k_1$ pasos de tiempo, calculando los estados de entrada, ocultos y de salida.
  2. Calcule la pérdida, sumada a los pasos de tiempo anteriores (véase más abajo).
  3. Pase hacia atrás: Calcular el gradiente de la pérdida con respecto a todos los parámetros, acumulando sobre el anterior $k_2$ pasos de tiempo (esto requiere haber almacenado todas las activaciones para estos pasos de tiempo). Recortar los gradientes para evitar el problema de la explosión del gradiente (ocurre raramente).
  4. Actualizar los parámetros (esto ocurre una vez por chunk, no de forma incremental en cada paso de tiempo).
  5. Si se procesan varios trozos de una secuencia más larga, almacene el estado oculto en el último paso de tiempo (se utilizará para inicializar el estado oculto para el comienzo del siguiente trozo). Si hemos llegado al final de la secuencia, reiniciar la memoria/el estado oculto y pasar al principio de la siguiente secuencia (o al principio de la misma secuencia, si sólo hay una).
  6. Repite desde el paso 1.

La forma de sumar las pérdidas depende de $k_1$ y $k_2$ . Por ejemplo, cuando $k_1 = k_2$ la pérdida se suma sobre el pasado $k_1 = k_2$ pasos de tiempo, pero el procedimiento es diferente cuando $k_2 > k_1$ (véase Williams y Peng 1990).

El cálculo del gradiente y las actualizaciones se realizan cada $k_1$ pasos de tiempo porque es computacionalmente más barato que actualizar en cada paso de tiempo. Actualizar varias veces por secuencia (es decir, establecer $k_1$ menor que la longitud de la secuencia) puede acelerar el entrenamiento porque las actualizaciones del peso son más frecuentes.

La retropropagación se realiza sólo para $k_2$ pasos de tiempo porque es computacionalmente más barato que propagar hacia el principio de la secuencia (lo que requeriría almacenar y procesar repetidamente todos los pasos de tiempo). Los gradientes calculados de esta manera son una aproximación al "verdadero" gradiente calculado en todos los pasos de tiempo. Pero, debido al problema del gradiente evanescente, los gradientes tenderán a acercarse a cero después de un cierto número de pasos de tiempo; la propagación más allá de este límite no aportaría ninguna ventaja. Estableciendo $k_2$ demasiado corta puede limitar la escala temporal en la que la red puede aprender. Sin embargo, la memoria de la red no se limita a $k_2$ pasos de tiempo porque las unidades ocultas pueden almacenar información más allá de este período (por ejemplo, véase Mikolov 2012 y este puesto ).

Además de las consideraciones computacionales, los ajustes adecuados para $k_1$ y $k_2$ dependen de las estadísticas de los datos (por ejemplo, la escala temporal de las estructuras que son relevantes para producir buenos resultados). Probablemente también dependen de los detalles de la red. Por ejemplo, hay una serie de arquitecturas, trucos de inicialización, etc. diseñados para mitigar el problema del gradiente decreciente.

Su opción 1 ('frame-wise backprop') corresponde a la configuración de $k_1$ a $1$ y $k_2$ al número de pasos de tiempo desde el comienzo de la frase hasta el punto actual. La opción 2 ("backprop de la frase") corresponde a la configuración de ambos $k_1$ y $k_2$ a la longitud de la frase. Ambos enfoques son válidos (con las consideraciones computacionales/de rendimiento mencionadas anteriormente; el número 1 sería bastante intensivo en términos computacionales para las secuencias más largas). Ninguno de estos enfoques se llamaría "truncado" porque la retropropagación se produce en toda la secuencia. Otras configuraciones de $k_1$ y $k_2$ son posibles; a continuación enumeraré algunos ejemplos.

Referencias que describen la BPTT truncada (procedimiento, motivación, cuestiones prácticas):

  • Sutskever (2013) . Entrenamiento de redes neuronales recurrentes.
  • Mikolov (2012) . Modelos lingüísticos estadísticos basados en redes neuronales.
    • Al utilizar las RNN de vainilla para procesar datos de texto como una secuencia de palabras, recomienda establecer $k_1$ a 10-20 palabras y $k_2$ a 5 palabras
    • Realización de múltiples actualizaciones por secuencia (es decir $k_1$ menos que la longitud de la secuencia) funciona mejor que la actualización al final de la secuencia
    • Realizar las actualizaciones una vez por chunk es mejor que hacerlo de forma incremental (que puede ser inestable)
  • Williams y Peng (1990) . Un algoritmo eficiente basado en el gradiente para el entrenamiento en línea de trayectorias de redes recurrentes.
    • Propuesta original (?) del algoritmo
    • Se discute la elección de $k_1$ y $k_2$ (que llaman $h'$ y $h$ ). Sólo tienen en cuenta $k_2 \ge k_1$ .
    • Nota: Utilizan la frase "BPTT(h; h')" o 'el algoritmo mejorado' para referirse a lo que las otras referencias llaman 'BPTT truncado'. Utilizan la frase "BPTT truncado" para referirse al caso especial en el que $k_1 = 1$ .

Otros ejemplos que utilizan BPTT truncado:

  • (Karpathy 2015). char-rnn.
    • Descripción y código
    • RNN de vainilla que procesa los documentos de texto un carácter a la vez. Entrenada para predecir el siguiente carácter. $k_1 = k_2 = 25$ personajes. Red utilizada para generar un nuevo texto en el estilo del documento de entrenamiento, con resultados divertidos.
  • Graves (2014) . Generación de secuencias con redes neuronales recurrentes.
    • Véase la sección sobre la generación de artículos simulados de Wikipedia. Red LSTM que procesa datos de texto como una secuencia de bytes. Entrenada para predecir el siguiente byte. $k_1 = k_2 = 100$ bytes. Restablecimiento de la memoria LSTM cada $10,000$ bytes.
  • Sak et al. (2014) . Arquitecturas de redes neuronales recurrentes basadas en la memoria a corto plazo para el reconocimiento del habla de gran vocabulario.
    • Redes LSTM modificadas, que procesan secuencias de características acústicas. $k_1 = k_2 = 20$ .
  • Ollivier et al. (2015) . Entrenamiento de redes recurrentes en línea sin retroceso.
    • El objetivo de este artículo era proponer un algoritmo de aprendizaje diferente, pero lo compararon con el BPTT truncado. Utilizaron RNNs de vainilla para predecir secuencias de símbolos. Sólo lo menciono aquí para decir que usaron $k_1 = k_2 = 15$ .
  • Hochreiter y Schmidhuber (1997) . Memoria a corto plazo.
    • Describen un procedimiento modificado para LSTMs

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X