5 votos

¿Cuál es el$k$ óptimo para el clasificador de vecino más cercano$k$ en el conjunto de datos Iris?

¿Cuál es el valor óptimo de$k$ para un clasificador kNN euclidiano no ponderado aplicado al conjunto de datos Iris ?

Donde óptimo implica el valor para$k$ que lleva al error de generalización más bajo.

23voto

Boris Tsirelson Puntos 191

Digamos que usted desea utilizar la Exactitud (o % de respuestas correctas) para evaluar la "óptima", y usted tiene tiempo para mirar a los 25 valores de k. El siguiente código R va a responder a su pregunta con 15 repeticiones de 10-fold cross-validation. Asimismo, tomará un largo tiempo para ejecutar.

library(caret)
model <- train(
    Species~., 
    data=iris, 
    method='knn',
    tuneGrid=expand.grid(.k=1:25),
    metric='Accuracy',
    trControl=trainControl(
        method='repeatedcv', 
        number=10, 
        repeats=15))

model
plot(model)
> confusionMatrix(model)
Cross-Validated (10 fold, repeated 15 times) Confusion Matrix 

(entries are percentages of table totals)

            Reference
Prediction   setosa versicolor virginica
  setosa       33.3        0.0       0.0
  versicolor    0.0       31.9       1.2
  virginica     0.0        1.4      32.1

Accuracy

Así, por este criterio, recibo una respuesta de 17 años, pero parece que el "verdadero" valor podría estar en cualquier lugar entre 5 y 20. Se puede sustituir el "Kappa" o alguna otra métrica si quieres, y añadir más cv-pliegues así. También puede probar diferentes métodos de validación cruzada, tales como leave-one-out, o bootstrap re-muestreo.

/Edit: en respuesta a su solicitud de variedad, me escribió esta función para calcular una variedad de métricas para multi-clase de problemas:

#Multi-Class Summary Function
#Based on caret:::twoClassSummary
require(compiler)
multiClassSummary <- cmpfun(function (data, lev = NULL, model = NULL){

  #Load Libraries
  require(Metrics)
  require(caret)

  #Check data
  if (!all(levels(data[, "pred"]) == levels(data[, "obs"]))) 
    stop("levels of observed and predicted data do not match")

  #Calculate custom one-vs-all stats for each class
  prob_stats <- lapply(levels(data[, "pred"]), function(class){

    #Grab one-vs-all data for the class
    pred <- ifelse(data[, "pred"] == class, 1, 0)
    obs  <- ifelse(data[,  "obs"] == class, 1, 0)
    prob <- data[,class]

    #Calculate one-vs-all AUC and logLoss and return
    cap_prob <- pmin(pmax(prob, .000001), .999999)
    prob_stats <- c(auc(obs, prob), logLoss(obs, cap_prob))
    names(prob_stats) <- c('ROC', 'logLoss')
    return(prob_stats) 
  })
  prob_stats <- do.call(rbind, prob_stats)
  rownames(prob_stats) <- paste('Class:', levels(data[, "pred"]))

  #Calculate confusion matrix-based statistics
  CM <- confusionMatrix(data[, "pred"], data[, "obs"])

  #Aggregate and average class-wise stats
  #Todo: add weights
  class_stats <- cbind(CM$byClass, prob_stats)
  class_stats <- colMeans(class_stats)

  #Aggregate overall stats
  overall_stats <- c(CM$overall)

  #Combine overall with class-wise stats and remove some stats we don't want 
  stats <- c(overall_stats, class_stats)
  stats <- stats[! names(stats) %in% c('AccuracyNull', 'Prevalence', 'Detection Prevalence')]

  #Clean names and return
  names(stats) <- gsub('[[:blank:]]+', '_', names(stats))
  return(stats)
})

Es un doozy de una función, así que va a ralentizar el símbolo de intercalación un poco, pero yo sería muy feliz si usted publicó los resultados de su 1000 repeticiones de 10 veces CV (no tengo ni el tiempo, no la capacidad computacional a intentar esto en la actualidad). Aquí está mi código para 15 repeticiones de 10 veces CV. Tenga en cuenta que usted puede fácilmente modificar este código para probar otros métodos de muestreo, tales como bootstrap de muestreo:

library(caret)
set.seed(19556)
model <- train(
  Species~., 
  data=iris, 
  method='knn',
  tuneGrid=expand.grid(.k=1:30),
  metric='Accuracy',
  trControl=trainControl(
    method='repeatedcv', 
    number=10, 
    repeats=15,
    classProbs=TRUE,
    summaryFunction=multiClassSummary))

Ambos ROC y LogLoss parecen pico alrededor de las 8: ROClogLoss

Mientras que la sensibilidad y la especificidad parecen pico alrededor de las 15: SensSpec

Aquí un poco de código para la salida de todas las parcelas como pdf:

dev.off()
pdf('plots.pdf')
for(stat in c('Accuracy', 'Kappa', 'AccuracyLower', 'AccuracyUpper', 'AccuracyPValue', 
              'Sensitivity', 'Specificity', 'Pos_Pred_Value', 
              'Neg_Pred_Value', 'Detection_Rate', 'ROC', 'logLoss')) {

  print(plot(model, metric=stat))
}
dev.off()

Si usted pone una pistola en mi cabeza, probablemente diría 8...

i-Ciencias.com

I-Ciencias es una comunidad de estudiantes y amantes de la ciencia en la que puedes resolver tus problemas y dudas.
Puedes consultar las preguntas de otros usuarios, hacer tus propias preguntas o resolver las de los demás.

Powered by:

X