46 votos

Entender el parámetro input_shape en LSTM con Keras

Estoy tratando de usar el ejemplo descrito en la documentación de Keras llamado "Stacked LSTM para la clasificación de secuencias" (ver código abajo) y no puedo averiguar el input_shape en el contexto de mis datos.

Tengo como entrada una matriz de secuencias de 25 posibles caracteres codificados en enteros a una secuencia acolchada de longitud máxima 31. Como resultado, mi x_train tiene la forma (1085420, 31) que significa (n_observations, sequence_length) .

from keras.models import Sequential
from keras.layers import LSTM, Dense
import numpy as np

data_dim = 16
timesteps = 8
num_classes = 10

# expected input data shape: (batch_size, timesteps, data_dim)
model = Sequential()
model.add(LSTM(32, return_sequences=True,
               input_shape=(timesteps, data_dim)))  # returns a sequence of vectors of dimension 32
model.add(LSTM(32, return_sequences=True))  # returns a sequence of vectors of dimension 32
model.add(LSTM(32))  # return a single vector of dimension 32
model.add(Dense(10, activation='softmax'))

model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Generate dummy training data
x_train = np.random.random((1000, timesteps, data_dim))
y_train = np.random.random((1000, num_classes))

# Generate dummy validation data
x_val = np.random.random((100, timesteps, data_dim))
y_val = np.random.random((100, num_classes))

model.fit(x_train, y_train,
          batch_size=64, epochs=5,
          validation_data=(x_val, y_val))

En este código x_train tiene la forma (1000, 8, 16) , como para una matriz de 1000 matrices de 8 matrices de 16 elementos. Allí me pierdo completamente en lo que es qué y cómo mis datos pueden llegar a esta forma.

Mirando el doc de Keras y varios tutoriales y Q&A, parece que me estoy perdiendo algo obvio. ¿Puede alguien darme una pista de lo que debo buscar?

Gracias por su ayuda.

43voto

PeterTecks Puntos 36

Las formas LSTM son difíciles, así que no te sientas mal, yo mismo tuve que pasar un par de días luchando contra ellas:

Si usted va a alimentar los datos de 1 carácter a la vez su forma de entrada debe ser (31,1) ya que su entrada tiene 31 pasos de tiempo, 1 carácter cada uno. Tendrá que cambiar la forma de su x_train de (1085420, 31) a (1085420, 31,1), lo que se hace fácilmente con este comando:

 x_train=x_train.reshape(x_train.shape[0],x_train.shape[1],1))

14voto

Mick Sharpe Puntos 1463

Consulta este repositorio git Diagrama resumen de LSTM Keras y creo que deberías tenerlo todo muy claro.

Este repositorio git incluye un diagrama de resumen de Keras LSTM que muestra:

  • el uso de parámetros como return_sequences , batch_size , time_step ...
  • la estructura real de las capas de lstm
  • el concepto de estas capas en keras
  • cómo manipular los datos de entrada y salida para adaptarlos a los requisitos del modelo cómo apilar las capas de LSTM

Y más

4voto

prosti Puntos 139

Sé que no es una respuesta directa a su pregunta. Este es un ejemplo simplificado con una sola célula LSTM, que me ayuda a entender la operación de remodelación de los datos de entrada.

from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM
import numpy as np

# define model
inputs1 = Input(shape=(2, 3))
lstm1, state_h, state_c = LSTM(1, return_sequences=True, return_state=True)(inputs1)
model = Model(inputs=inputs1, outputs=[lstm1, state_h, state_c])

# define input data
data = np.random.rand(2, 3)
data = data.reshape((1,2,3))

# make and show prediction
print(model.predict(data))

Este sería un ejemplo de la red LSTM con una sola célula LSTM y con los datos de entrada de forma específica.

Resulta que aquí sólo estamos prediciendo, el entrenamiento no está presente por simplicidad, pero fíjese que hemos necesitado remodelar los datos (para añadir una dimensión adicional) antes de la predict método.

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X