# https://github.com/cran/precrec
# Raw File
# Tip revision: 9c56b91df45d18e10abd76aa1359b1482955fbbc authored by Takaya Saito on 04 December 2015, 15:40:36 UTC
# version 0.1.1
# Tip revision: 9c56b91
# mm3_reformat_data.R
#
# Reformat input data for Precision-Recall and ROC evaluation
#
reformat_data <- function(scores, labels,
                          modname = as.character(NA), dsid = 1L,
                          posclass = NULL, na_worst = TRUE,
                          ties_method = "equiv", ...) {
  # Build an "fmdat" S3 object from raw scores and labels.
  #
  # Args:
  #   scores, labels: input data, checked by .validate_reformat_data_args().
  #   modname, dsid: model name and dataset ID, stored as attributes.
  #   posclass: label value treated as positive; NULL leaves the choice to
  #     .factor_labels().
  #   na_worst, ties_method: forwarded to .rank_scores().
  #   ...: forwarded to the argument validator, which rejects named extras.
  #
  # Returns:
  #   Whatever .validate() returns for the assembled object (presumably the
  #   validated "fmdat" object via .validate.fmdat() -- confirm the generic).

  # All arguments are checked once here, so the helpers below are called
  # with validate = FALSE.
  .validate_reformat_data_args(scores, labels, modname = modname, dsid = dsid,
                               posclass = posclass, na_worst = na_worst,
                               ties_method = ties_method, ...)

  # Score ranks and sorted indices
  rank_info <- .rank_scores(scores, na_worst, ties_method, validate = FALSE)

  # Formatted labels together with the nn/np counts
  label_info <- .factor_labels(labels, posclass, validate = FALSE)

  # Assemble the list items, then tag the list with its S3 class
  fmdat <- list(labels = label_info[["labels"]],
                ranks = rank_info[["ranks"]],
                rank_idx = rank_info[["rank_idx"]])
  class(fmdat) <- "fmdat"

  # Keep a record of the arguments this object was created with
  call_args <- list(posclass = posclass, na_worst = na_worst,
                    ties_method = ties_method,
                    modname = modname, dsid = dsid)

  attr(fmdat, "modname") <- modname
  attr(fmdat, "dsid") <- dsid
  attr(fmdat, "nn") <- label_info[["nn"]]
  attr(fmdat, "np") <- label_info[["np"]]
  attr(fmdat, "args") <- call_args
  # Starts as FALSE; .validate.fmdat() sets it to TRUE after its checks pass
  attr(fmdat, "validated") <- FALSE

  .validate(fmdat)
}

#
# Factor labels
#
.factor_labels <- function(labels, posclass, validate = TRUE) {
  # Convert labels into the format returned by format_labels().
  #
  # Args:
  #   labels: label vector or factor.
  #   posclass: label value to treat as positive, or NULL to pass NA on to
  #     format_labels(). For factor labels it must be one of the levels.
  #   validate: if TRUE, run .validate_labels() and .validate_posclass()
  #     first.
  #
  # Returns:
  #   The result of format_labels(), after .check_cpp_func_error() has
  #   inspected it.
  #
  # Raises:
  #   An error when posclass is not a level of a factor `labels`, or when its
  #   type differs from the type of the label values.

  # === Validate input arguments ===
  if (validate) {
    .validate_labels(labels)
    .validate_posclass(posclass)
  }

  # Update posclass if necessary
  if (is.null(posclass)) {
    posclass <- NA
  } else if (is.factor(labels)) {
    # Factors are stored as integer codes, so translate posclass into the
    # index of the matching level.
    level_idx <- which(levels(labels) == posclass)

    # which() yields integer(0) for an unknown level; without this guard the
    # type check below would fail on a zero-length condition with a cryptic
    # message.
    if (length(level_idx) != 1L) {
      stop("posclass must be one of the levels of labels", call. = FALSE)
    }
    posclass <- level_idx
  }

  # Check the data type of posclass
  if (!is.na(posclass) && typeof(posclass) != typeof(labels[1])) {
    stop("posclass must be the same data type as labels", call. = FALSE)
  }

  # === Generate label factors ===
  flabels <- format_labels(labels, posclass)
  .check_cpp_func_error(flabels, "format_labels")

  flabels
}

#
# Rank scores
#
.rank_scores <- function(scores, na_worst = TRUE, ties_method = "equiv",
                         validate = TRUE) {
  # Rank scores via get_score_ranks().
  #
  # Args:
  #   scores: scores to rank.
  #   na_worst, ties_method: passed through to get_score_ranks() unchanged.
  #   validate: if TRUE, check the three arguments above before ranking.
  #
  # Returns:
  #   The result of get_score_ranks(); reformat_data() reads its "ranks" and
  #   "rank_idx" items.

  # Optional argument checks
  if (validate) {
    .validate_scores(scores)
    .validate_na_worst(na_worst)
    .validate_ties_method(ties_method)
  }

  # Compute the ranks, then let .check_cpp_func_error() inspect the result
  rank_result <- get_score_ranks(scores, na_worst, ties_method)
  .check_cpp_func_error(rank_result, "get_score_ranks")

  rank_result
}

#
# Validate arguments of reformat_data()
#
.validate_reformat_data_args <- function(scores, labels, modname, dsid,
                                         posclass, na_worst, ties_method,
                                         ...) {
  # Check every argument accepted by reformat_data().
  #
  # Named arguments arriving through '...' are reported as invalid. Unnamed
  # extras carry no names and therefore pass this check. The remaining
  # arguments are each handed to their dedicated validator.

  # Anything named in '...' is an argument reformat_data() does not know
  extra_names <- names(list(...))
  if (!is.null(extra_names)) {
    stop(paste0("Invalid arguments: ", paste(extra_names, collapse = ", ")),
         call. = FALSE)
  }

  # Scores and labels are validated together
  .validate_scores_and_labels(NULL, NULL, scores, labels)

  # Model name and dataset ID
  .validate_modname(modname)
  .validate_dsid(dsid)

  # Options controlling label formatting and ranking
  .validate_posclass(posclass)
  .validate_na_worst(na_worst)
  .validate_ties_method(ties_method)
}

#
# Validate 'fmdat' object generated by reformat_data()
#
.validate.fmdat <- function(fmdat) {
  # Validate an "fmdat" object generated by reformat_data().
  #
  # Args:
  #   fmdat: object to check.
  #
  # Returns:
  #   `fmdat` with its "validated" attribute set to TRUE. An object that is
  #   already marked as validated is returned unchanged.
  #
  # Raises:
  #   An error when the items have zero or mismatched lengths, or when an
  #   assertion on labels, ranks or rank_idx fails. .validate_basic() runs
  #   first, with the expected item, attribute and argument names.

  # Need to validate only once.
  # isTRUE() keeps this guard safe when the "validated" attribute is missing
  # or NA: a bare `&&` would stop with "invalid 'y' type in 'x && y'" before
  # the structured checks below could run.
  if (methods::is(fmdat, "fmdat") && isTRUE(attr(fmdat, "validated"))) {
    return(fmdat)
  }

  # Validate class items and attributes
  item_names <- c("labels", "ranks", "rank_idx")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(fmdat, "fmdat", "reformat_data", item_names, attr_names,
                  arg_names)

  # Check values of class items: all three must be non-empty and equally long
  if (length(fmdat[["labels"]]) == 0
      || length(fmdat[["labels"]]) != length(fmdat[["ranks"]])
      || length(fmdat[["labels"]]) != length(fmdat[["rank_idx"]])) {
    stop("List items in fmdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(is.atomic(fmdat[["labels"]]),
                          is.vector(fmdat[["labels"]]),
                          is.numeric(fmdat[["labels"]]))

  # Ranks
  assertthat::assert_that(is.atomic(fmdat[["ranks"]]),
                          is.vector(fmdat[["ranks"]]),
                          is.numeric(fmdat[["ranks"]]))

  # Rank index (must be integer, unlike ranks)
  assertthat::assert_that(is.atomic(fmdat[["rank_idx"]]),
                          is.vector(fmdat[["rank_idx"]]),
                          is.integer(fmdat[["rank_idx"]]))

  attr(fmdat, "validated") <- TRUE
  fmdat
}

# back to top