1 votos

definición de una epoch en el método fit en keras

Entiendo que una época es una pasada por los datos de entrenamiento. Estoy entrenando una CNN usando las siguientes líneas de código

cnn = tf.keras.models.Sequential()
# ... code to define network layers ..
cnn.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])
cnn.fit(x = training_set, validation_data = test_set, epochs = 2)

training_set se generó utilizando

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   shear_range = 0.2,
                                   zoom_range = 0.2,
                                   horizontal_flip = True)
training_set = train_datagen.flow_from_directory('dataset/training_set_cut',
                                                 target_size = (64, 64),
                                                 batch_size = 32,
                                                 class_mode = 'binary')

Y test_set se generó utilizando un código similar. training_set y test_set parecen ser generadores que nunca dejan de rendir o suben StopIteration . Si ese es el caso, entonces ¿cómo cnn.fit saber que se ha completado una época?

1voto

lyinch Puntos 166

En Keras, los generadores generan infinitos elementos. Para definir lo que es una época, tienes que decirle al generador cuándo debe rendir. Esto se puede hacer con steps_per_epoch y epochs en el model.fit llamada. De la documentación de Keras, aquí hay un ejemplo de cómo entrenar un modelo con generadores:

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)
datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(x_train)
# fits the model on batches with real-time data augmentation:
model.fit(datagen.flow(x_train, y_train, batch_size=32),
          steps_per_epoch=len(x_train) / 32, epochs=epochs)

Manualmente, sólo se generan tantas imágenes como se desee por generador (utilizando, por ejemplo. zip si tiene varios generadores) y de la misma fuente obtenemos:

for e in range(epochs):
    print('Epoch', e)
    batches = 0
    for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
        model.fit(x_batch, y_batch)
        batches += 1
        if batches >= len(x_train) / 32:
            # we need to break the loop by hand because
            # the generator loops indefinitely
            break

Ambos ejemplos están tomados textualmente de aquí

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X