Estoy utilizando dos métodos para entrenar un modelo Bernoulli, y estoy tratando de entender por qué no están dando resultados similares. Para ambos métodos, tengo una longitud $N$ matriz de probabilidades $\{\hat{y}^{(n)}\}_{n=1}^{N}$ y quiero estimar la distribución de una longitud $N$ matriz de parámetros $\{\theta^{(n)}\}_{n=1}^{N}$ . En el método (1), para cada $\hat{y}^{(n)}$ , muestro $M$ veces de una distribución Bernoulli con probabilidad $\hat{y}^{(n)}$ y utilizar los datos binarios resultantes como entrada a mi modelo. En el método (2), utilizo $\hat{y}^{(n)}$ directamente incrementando la densidad logarítmica conjunta mediante la siguiente regla de actualización:
$ \begin{equation*} \begin{split} \log p(\mathbf{y} \mid \boldsymbol{\theta}) &= \sum_{n=1}^{N} \log p(y^{(n)} \mid \theta^{(n)})\\ &\approx \sum_{n=1}^{N} \frac{1}{M} \sum_{m=1}^{M} \log p(y^{(n)}_m \mid \theta^{(n)})\\ &\approx \sum_{n=1}^{N} E_{y^{(n)}_m}[\log p(y^{(n)}_m \mid \theta^{(n)})]\\ &= \sum_{n=1}^{N} E_{y^{(n)}_m}\left[\log({\theta^{(n)}}^{y^{(n)}_m} \cdot (1 - \theta^{(n)})^{1 - y^{(n)}_m})\right]\\ &= \sum_{n=1}^{N} E_{y^{(n)}_m}\left[y^{(n)}_m \log(\theta^{(n)}) + (1 - y^{(n)}_m) \log(1 - \theta^{(n)})\right]\\ &= \sum_{n=1}^{N} \Pr(y^{(n)}_m = 1) \log(\theta^{(n)}) + \Pr(y^{(n)}_m = 0) \log(1 - \theta^{(n)})\\ &= \sum_{n=1}^{N} \hat{y}^{(n)} \log(\theta^{(n)}) + (1 - \hat{y}^{(n)}) \log(1 - \theta^{(n)}) \end{split} \end{equation*} $
Estoy usando Stan, donde se especifica la densidad conjunta logarítmica de forma incremental usando cada punto de datos. El pseudocódigo para estos dos métodos es el siguiente:
Espero que los métodos (1) y (2) produzcan estimaciones similares para $\theta$ para grandes $M$ pero estoy comprobando que no es así. He reproducido este problema en un pequeño problema de juguete usando Stan, aquí está el código:
import matplotlib.pyplot as plt
import numpy as np
import pystan
def get_theta_mean(fit):
samples = fit.extract()
theta = np.moveaxis(samples['theta'], 0, -1)
return theta.mean(axis=1)
rng = np.random.RandomState(0)
N = 100
probs = rng.uniform(0, 1, N)
binary_model = '''
data {
int<lower=0> N;
int<lower=0> M;
int<lower=0, upper=1> y[M, N];
}
parameters {
real<lower=0, upper=1> theta[N];
}
model {
for (m in 1:M) {
y[m] ~ bernoulli(theta);
}
}
'''
binary_sm = pystan.StanModel(model_code=binary_model)
M_list = [10, 100, 1000]
theta_means = {}
for M in M_list:
y = np.full((M, N), np.nan)
for m in range(M):
for n in range(N):
y[m, n] = rng.binomial(1, probs[n])
y = y.astype(int)
binary_fit = binary_sm.sampling(
data={'N': N,
'M': M,
'y': y})
theta_means[M] = get_theta_mean(binary_fit)
prob_model = '''
data {
int<lower=0> N;
real<lower=0, upper=1> yhat[N];
}
parameters {
real<lower=0, upper=1> theta[N];
}
model {
for (n in 1:N) {
target += lmultiply(yhat[n], theta[n]) + lmultiply(1 - yhat[n], 1 - theta[n]);
}
}
'''
prob_sm = pystan.StanModel(model_code=prob_model)
prob_fit = prob_sm.sampling(
data={'N': N,
'yhat': probs})
prob_theta_mean = get_theta_mean(prob_fit)
for M, theta_mean in theta_means.items():
plt.scatter(theta_mean, probs, label=f'Method (1), M={M}')
plt.scatter(prob_theta_mean, probs, label='Method (2)')
plt.grid(True)
plt.xlabel(r'$E[\theta]$')
plt.ylabel('Probability')
plt.legend()
Aquí hay un gráfico de dispersión de los resultados que obtengo en el problema del juguete. Se supone que (1) se aproxima mejor a (2) como $M$ se incrementa, convergiendo finalmente a $E[\theta^{(n)}] = \hat{y}^{(n)}$ . (1) parece seguir esto, pero (2) está significativamente fuera.
ACTUALIZACIÓN:
Me he dado cuenta de que hay un error en la línea uno de mi derivación anterior, y que el método (2) representa en realidad
$ \begin{equation*} \begin{split} \log p(\vec{y} \mid \theta) &= \sum_{n=1}^{N} \log p(y^{(n)} \mid \theta^{(n)})\\ &\approx \sum_{n=1}^{N} \log\left(\frac{1}{M} \sum_{m=1}^{M} p(y^{(n)}_m \mid \theta^{(n)})\right) \end{split} \end{equation*} $
He cambiado el método (1) para reflejar esto, y ahora mis resultados entre el método (1) y (2) son consistentes, y ninguno de ellos satisface $E[\theta \mid D] \approx \hat{y}$ .