Loading [MathJax]/jax/element/mml/optable/BasicLatin.js

2 votos

acceso a los tensores LSTM Weights en tensorflow

Estoy probando el código de : https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/recurrent_network.ipynb

y también mirando la arquitectura de la célula LSTM básica como se describe en: https://r2rt.com/written-memories-understanding-deriving-and-extending-the-lstm.html

por lo que deseo acceder a los pesos para Wi Wo Wf y Wo (como se indica en la sección "El LSTM básico")

Estoy ejecutando esto

    for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            print(v)

y me sale

   Tensor("Variable/read:0", shape=(128, 10), dtype=float32)
   Tensor("Variable_1/read:0", shape=(10,), dtype=float32)
   Tensor("rnn/basic_lstm_cell/weights/read:0", shape=(156, 512), dtype=float32)
   Tensor("rnn/basic_lstm_cell/biases/read:0", shape=(512,), dtype=float32)

Así que parece que estoy haciendo algo mal ya que no puedo conseguir las matrices de las Cuatro W...

¿Alguna ayuda para entender lo que me he perdido, por favor?

4voto

Pantelis Puntos 111

Puedes recuperar los pesos LSTM de tu sesión de tensorflow "sess" de la siguiente manera:

trainable_vars_dict = {}
for key in tvars:
    trainable_vars_dict[key.name] = sess.run(key)
    # Checking the names of the keys
    print(key) 

De este código obtendrás los nombres de las llaves. Un nombre de clave corresponde a una matriz que contiene todos los pesos del LSTM. La clave en tu caso debe tener el nombre "LSTM/rnn/basic_lstm_cell/weights:0". Asumiendo que el tamaño de tu entrada es input_size, tienes que hacer:

lstm_weight_vals = trainable_vars_dict["LSTM/rnn/basic_lstm_cell/weights:0"]
w_i, w_C, w_f, w_o = np.split(lstm_weight_vals, 4, axis=1)

w_xi = w_i[:input_size, :]
w_hi = w_i[input_size:, :]

w_xC = w_C[:input_size, :]
w_hC = w_C[input_size:, :]

w_xf = w_f[:input_size, :]
w_hf = w_f[input_size:, :]

w_xo = w_o[:input_size, :]
w_ho = w_o[input_size:, :]

Donde las matrices con "h" en ellas deben ser cuadráticas al final (de tamaño 128×128 en su caso). Creo que para ti el tamaño de entrada es 28 .

1voto

Steve Mucci Puntos 139

Las cuatro matrices W están en "rnn/basic_lstm_cell/weights/read:0". Puedes ver la dimensión de los pesos. Los 512 representan los cuatro pesos*célula (4*128), y los 156 representan las 28 características de entrada y las 128 células.

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X