30 votos

ROC medio para validación cruzada repetida de 10 veces con estimaciones de probabilidad

Estoy planeando utilizar la validación cruzada estratificada 10 veces en unos 10.000 casos utilizando un algoritmo de aprendizaje automático. Cada vez la repetición se hará con semillas aleatorias diferentes.

En este proceso creo 10 instancias de estimaciones de probabilidad para cada caso. 1 instancia de estimación de probabilidad para cada una de las 10 repeticiones de la validación cruzada de 10 veces

¿Puedo promediar 10 probabilidades para cada caso y crear después una nueva curva ROC media (que represente los resultados de 10 CV repetidos), que pueda compararse con otras curvas ROC mediante comparaciones por pares?

27voto

Alexey Grigorev Puntos 1751

Por tu descripción, parece tener mucho sentido: no sólo puedes calcular la curva ROC media, sino también la varianza en torno a ella para construir intervalos de confianza. Esto debería darte una idea de lo estable que es tu modelo.

Por ejemplo, así:

enter image description here

Aquí pongo las curvas ROC individuales, así como la curva media y los intervalos de confianza. Hay zonas en las que las curvas coinciden, por lo que tenemos menos varianza, y hay zonas en las que discrepan.

Para CV repetidos, basta con repetirlo varias veces y obtener la media total de todos los pliegues individuales:

enter image description here

Es bastante similar a la imagen anterior, pero proporciona estimaciones más estables (es decir, fiables) de la media y la varianza.

Aquí está el código para obtener la trama:

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import make_classification
from sklearn.cross_validation import KFold
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

X, y = make_classification(n_samples=500, random_state=100, flip_y=0.3)

kf = KFold(n=len(y), n_folds=10)

tprs = []
base_fpr = np.linspace(0, 1, 101)

plt.figure(figsize=(5, 5))
plt.axes().set_aspect('equal', 'datalim')

for i, (train, test) in enumerate(kf):
    model = LogisticRegression().fit(X[train], y[train])
    y_score = model.predict_proba(X[test])
    fpr, tpr, _ = roc_curve(y[test], y_score[:, 1])

    plt.plot(fpr, tpr, 'b', alpha=0.15)
    tpr = np.interp(base_fpr, fpr, tpr)
    tpr[0] = 0.0
    tprs.append(tpr)

tprs = np.array(tprs)
mean_tprs = tprs.mean(axis=0)
std = tprs.std(axis=0)

tprs_upper = np.minimum(mean_tprs + std, 1)
tprs_lower = mean_tprs - std

plt.plot(base_fpr, mean_tprs, 'b')
plt.fill_between(base_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3)

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([-0.01, 1.01])
plt.ylim([-0.01, 1.01])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

Para CV repetidos:

idx = np.arange(0, len(y))

for j in np.random.randint(0, high=10000, size=10):
    np.random.shuffle(idx)
    kf = KFold(n=len(y), n_folds=10, random_state=j)

    for i, (train, test) in enumerate(kf):
        model = LogisticRegression().fit(X[idx][train], y[idx][train])
        y_score = model.predict_proba(X[idx][test])
        fpr, tpr, _ = roc_curve(y[idx][test], y_score[:, 1])

        plt.plot(fpr, tpr, 'b', alpha=0.05)
        tpr = interp(base_fpr, fpr, tpr)
        tpr[0] = 0.0
        tprs.append(tpr)

Fuente de inspiración: http://scikit-learn.org/stable/auto_examples/model_selection/plot_roc_crossval.html

7voto

dan90266 Puntos 609

No es correcto promediar las probabilidades porque eso no representaría las predicciones que se intentan validar e implica contaminación entre las muestras de validación.

Tenga en cuenta que pueden ser necesarias 100 repeticiones de validación cruzada de 10 veces para lograr una precisión adecuada. O utilice el bootstrap optimista de Efron-Gong, que requiere menos iteraciones para obtener la misma precisión (véase, por ejemplo, R rms paquete validate funciones).

Las curvas ROC no son en absoluto perspicaces para este problema. Utilice una puntuación de precisión adecuada y acompáñela de la $c$ -(probabilidad de concordancia; AUROC) que es mucho más fácil de tratar que la curva, ya que se calcula fácil y rápidamente mediante el estadístico Wilcoxon-Mann-Whitney.

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X