En casi todos los ejemplos de código que he visto de un VAE, las funciones de pérdida se definen de la siguiente manera (este es el código de tensorflow, pero he visto similar para theano, torch etc. También es para un convnet, pero eso tampoco es demasiado relevante, sólo afecta a los ejes sobre los que se toman las sumas):
# latent space loss. KL divergence between latent space distribution and unit gaussian, for each batch.
# first half of eq 10. in https://arxiv.org/abs/1312.6114
kl_loss = -0.5 * tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)
# reconstruction error, using pixel-wise L2 loss, for each batch
rec_loss = tf.reduce_sum(tf.squared_difference(y, x), axis=[1,2,3])
# or binary cross entropy (assuming 0...1 values)
y = tf.clip_by_value(y, 1e-8, 1-1e-8) # prevent nan on log(0)
rec_loss = -tf.reduce_sum(x * tf.log(y) + (1-x) * tf.log(1-y), axis=[1,2,3])
# sum the two and average over batches
loss = tf.reduce_mean(kl_loss + rec_loss)
Sin embargo, el rango numérico de kl_loss y rec_loss depende en gran medida de las dimensiones del espacio latente y del tamaño de la característica de entrada (por ejemplo, la resolución de píxeles) respectivamente. ¿Sería sensato reemplazar los reduce_sum por reduce_mean para obtener por z-dim KLD y por píxel (o característica) LSE o BCE? Y lo que es más importante, ¿cómo ponderamos la pérdida latente con la pérdida de reconstrucción al sumarlas para obtener la pérdida final? ¿Es sólo prueba y error? o ¿hay alguna teoría (o al menos una regla general) para ello? No he podido encontrar ninguna información al respecto en ningún sitio (incluido el documento original).
El problema que tengo es que si el equilibrio entre las dimensiones de mis características de entrada (x) y las dimensiones del espacio latente (z) no es "óptimo", o bien mis reconstrucciones son muy buenas pero el espacio latente aprendido no está estructurado (si las dimensiones x son muy altas y el error de reconstrucción domina sobre el KLD), o viceversa (las reconstrucciones no son buenas pero el espacio latente aprendido está bien estructurado si el KLD domina).
Me encuentro con que tengo que normalizar la pérdida de reconstrucción (dividiendo por el tamaño de la característica de entrada), y KLD (dividiendo por las dimensiones z) y luego ponderar manualmente el término KLD con un factor de peso arbitrario (La normalización es para que pueda utilizar el mismo o similar peso independientemente de las dimensiones de x o z ). Empíricamente he encontrado alrededor de 0,1 para proporcionar un buen equilibrio entre la reconstrucción y el espacio latente estructurado que se siente como un "punto dulce" para mí. Estoy buscando trabajos anteriores en esta área.
A petición, notación matemática de lo anterior (centrándose en la pérdida L2 para el error de reconstrucción)
$$\mathcal{L}_{latent}^{(i)} = -\frac{1}{2} \sum_{j=1}^{J}(1+\log (\sigma_j^{(i)})^2 - (\mu_j^{(i)})^2 - (\sigma_j^{(i)})^2)$$
$$\mathcal{L}_{recon}^{(i)} = -\sum_{k=1}^{K}(y_k^{(i)}-x_k^{(i)})^2$$
$$\mathcal{L}^{(m)} = \frac{1}{M}\sum_{i=1}^{M}(\mathcal{L}_{latent}^{(i)} + \mathcal{L}_{recon}^{(i)})$$
donde $J$ es la dimensionalidad del vector latente $z$ (y la correspondiente media $\mu$ y la varianza $\sigma^2$ ), $K$ es la dimensionalidad de las características de entrada, $M$ es el tamaño del minilote, el superíndice $(i)$ denota el $i$ punto de datos y $\mathcal{L}^{(m)}$ es la pérdida para el $m$ a la mini-lote.