#
# Reformat input data for Precision-Recall and ROC evaluation
#
# Validates the inputs, converts `labels` with .factor_labels() and wraps the
# result in an S3 object:
#   * mode "aucroc": class "sdat" with items `scores` and `labels`
#   * any other mode: class "fmdat" with items `scores`, `labels`, `ranks`
#     and `rank_idx` (taken from the result of .rank_scores())
# The attributes "modname", "dsid", "nn", "np", "args" and "validated" are set
# on the object, which is then passed through .validate() and returned.
#
# Args:
#   scores, labels: scores and observed labels for one model / dataset.
#   modname: model name, stored as the "modname" attribute.
#   dsid: dataset ID, stored as the "dsid" attribute.
#   posclass: positive class; NULL lets .factor_labels() use its default.
#   na_worst: normalised by .get_new_naworst(), then used for ranking.
#   ties_method: normalised by .pmatch_tiesmethod(), then used for ranking.
#   mode: evaluation mode, normalised by .pmatch_mode() before use.
#   ...: forwarded to the argument helpers and checked by
#        .validate_reformat_data_args().
#
reformat_data <- function(scores, labels,
                          modname = as.character(NA), dsid = 1L,
                          posclass = NULL, na_worst = TRUE,
                          ties_method = "equiv", mode = "rocprc", ...) {
  # === Validate input arguments ===
  new_ties_method <- .pmatch_tiesmethod(ties_method, ...)
  new_na_worst <- .get_new_naworst(na_worst, ...)
  new_mode <- .pmatch_mode(mode)
  .validate_reformat_data_args(scores, labels, modname = modname, dsid = dsid,
                               posclass = posclass, na_worst = new_na_worst,
                               ties_method = new_ties_method, mode = new_mode,
                               ...)

  # === Reformat input data ===
  # Get a factor with "positive" and "negative"
  fmtlabs <- .factor_labels(labels, posclass, validate = FALSE)

  # Dispatch on the normalised mode that was validated above, not on the raw
  # argument, so every spelling accepted by .pmatch_mode() picks the same
  # branch.
  if (new_mode == "aucroc") {
    # === Create an S3 object ===
    s3obj <- structure(list(scores = scores,
                            labels = fmtlabs[["labels"]]),
                       class = "sdat")
  } else {
    # Get score ranks and sorted indices
    sranks <- .rank_scores(scores, new_na_worst, new_ties_method,
                           validate = FALSE)
    ranks <- sranks[["ranks"]]
    rank_idx <- sranks[["rank_idx"]]

    # === Create an S3 object ===
    s3obj <- structure(list(scores = scores,
                            labels = fmtlabs[["labels"]],
                            ranks = ranks,
                            rank_idx = rank_idx),
                       class = "fmdat")
  }

  # Set attributes ("nn"/"np" come from format_labels(); presumably the
  # negative/positive counts -- confirm there)
  attr(s3obj, "modname") <- modname
  attr(s3obj, "dsid") <- dsid
  attr(s3obj, "nn") <- fmtlabs[["nn"]]
  attr(s3obj, "np") <- fmtlabs[["np"]]
  attr(s3obj, "args") <- list(posclass = posclass, na_worst = new_na_worst,
                              ties_method = new_ties_method,
                              modname = modname, dsid = dsid)
  attr(s3obj, "validated") <- FALSE

  # Call .validate.fmdat() / .validate.sdat()
  .validate(s3obj)
}
#
# Factor labels
#
# Converts `labels` into the internal label format with the compiled
# format_labels() routine and returns its result after
# .check_cpp_func_error() has inspected it. reformat_data() reads the
# "labels", "nn" and "np" elements of the returned value.
#
# Args:
#   labels: observed labels (a factor or an atomic vector).
#   posclass: positive class. NULL is passed on as NA. For factor labels it
#     is matched against levels(labels) and replaced by the level index.
#   validate: when TRUE, `labels` and `posclass` are validated first.
#
.factor_labels <- function(labels, posclass, validate = TRUE) {
  # === Validate input arguments ===
  if (validate) {
    .validate_labels(labels)
    .validate_posclass(posclass)
  }

  # Update posclass if necessary
  if (is.null(posclass)) {
    posclass <- NA
  } else if (is.factor(labels)) {
    # Factor labels are typed as integer codes, so posclass becomes the index
    # of its level. which() returns integer(0) for an unknown level, which
    # would make the type check below fail with an uninformative error, so
    # report it here instead.
    pos_idx <- which(levels(labels) == posclass)
    if (length(pos_idx) != 1L) {
      stop("posclass must be one of the levels of labels", call. = FALSE)
    }
    posclass <- pos_idx
  }

  # Check the data type of posclass
  if (!is.na(posclass) && typeof(posclass) != typeof(labels[1])) {
    stop("posclass must be the same data type as labels", call. = FALSE)
  }

  # === Generate label factors ===
  flabels <- format_labels(labels, posclass)
  .check_cpp_func_error(flabels, "format_labels")
  flabels
}
#
# Rank scores
#
# Ranks `scores` with the compiled get_score_ranks() routine and returns its
# result after .check_cpp_func_error() has inspected it. reformat_data()
# reads the "ranks" and "rank_idx" elements of the returned value.
#
# Args:
#   scores: scores to rank.
#   na_worst, ties_method: forwarded unchanged to get_score_ranks().
#   validate: when TRUE, each argument is checked before ranking.
#
.rank_scores <- function(scores, na_worst = TRUE, ties_method = "equiv",
                         validate = TRUE) {
  if (validate) {
    # Argument checks are skippable because callers such as reformat_data()
    # have already validated their inputs
    .validate_scores(scores)
    .validate_na_worst(na_worst)
    .validate_ties_method(ties_method)
  }

  ranked <- get_score_ranks(scores, na_worst, ties_method)
  .check_cpp_func_error(ranked, "get_score_ranks")
  ranked
}
#
# Validate arguments of reformat_data()
#
# Stops with an error if any extra argument reaches `...`, then hands each
# remaining argument to its dedicated validator. It is called for that side
# effect; the value of the last validator call is returned.
#
.validate_reformat_data_args <- function(scores, labels, modname, dsid,
                                         posclass, na_worst, ties_method,
                                         mode, ...) {
  # Check '...': every extra argument is invalid. Unnamed ones are reported
  # by position (..1, ..2, ...), so they are not silently accepted and the
  # message never contains blank entries.
  arglist <- list(...)
  if (length(arglist) > 0) {
    arg_labels <- names(arglist)
    if (is.null(arg_labels)) {
      arg_labels <- character(length(arglist))
    }
    unnamed <- is.na(arg_labels) | arg_labels == ""
    arg_labels[unnamed] <- paste0("..", which(unnamed))
    stop(paste0("Invalid arguments: ", paste(arg_labels, collapse = ", ")),
         call. = FALSE)
  }

  # Check scores and labels
  .validate_scores_and_labels(NULL, NULL, scores, labels)

  # Check model name
  .validate_modname(modname)

  # Check dataset ID
  .validate_dsid(dsid)

  # Check posclass
  .validate_posclass(posclass)

  # Check na_worst
  .validate_na_worst(na_worst)

  # Check ties_method
  .validate_ties_method(ties_method)

  # Check mode
  .validate_mode(mode)
}
#
# Validate 'fmdat' object generated by reformat_data()
#
# Returns `fmdat` unchanged when it is already an "fmdat" object flagged as
# validated. Otherwise it:
#   * checks its items and attributes with .validate_basic();
#   * checks that scores, labels, ranks and rank_idx share one non-zero
#     length;
#   * checks that labels and ranks are numeric vectors and rank_idx is an
#     integer vector.
# It then returns the object with its "validated" attribute set to TRUE.
#
.validate.fmdat <- function(fmdat) {
  # Need to validate only once. isTRUE() lets an object whose "validated"
  # attribute is missing fall through to the full checks instead of failing
  # inside `&&`.
  if (inherits(fmdat, "fmdat") && isTRUE(attr(fmdat, "validated"))) {
    return(fmdat)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels", "ranks", "rank_idx")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(fmdat, "fmdat", "reformat_data", item_names, attr_names,
                  arg_names)

  # Check values of class items: ranks and rank_idx are derived from scores,
  # so all four items must have the same non-zero length
  n_labels <- length(fmdat[["labels"]])
  if (n_labels == 0
      || length(fmdat[["scores"]]) != n_labels
      || length(fmdat[["ranks"]]) != n_labels
      || length(fmdat[["rank_idx"]]) != n_labels) {
    stop("List items in fmdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(is.atomic(fmdat[["labels"]]),
                          is.vector(fmdat[["labels"]]),
                          is.numeric(fmdat[["labels"]]))

  # Ranks
  assertthat::assert_that(is.atomic(fmdat[["ranks"]]),
                          is.vector(fmdat[["ranks"]]),
                          is.numeric(fmdat[["ranks"]]))

  # Rank index
  assertthat::assert_that(is.atomic(fmdat[["rank_idx"]]),
                          is.vector(fmdat[["rank_idx"]]),
                          is.integer(fmdat[["rank_idx"]]))

  attr(fmdat, "validated") <- TRUE
  fmdat
}
#
# Validate 'sdat' object generated by reformat_data()
#
# Returns `sdat` unchanged when it is already an "sdat" object flagged as
# validated. Otherwise it:
#   * checks its items and attributes with .validate_basic();
#   * checks that labels is non-empty and has the same length as scores;
#   * checks that labels is a numeric vector.
# It then returns the object with its "validated" attribute set to TRUE.
#
.validate.sdat <- function(sdat) {
  # Need to validate only once. isTRUE() lets an object whose "validated"
  # attribute is missing fall through to the full checks instead of failing
  # inside `&&`.
  if (inherits(sdat, "sdat") && isTRUE(attr(sdat, "validated"))) {
    return(sdat)
  }

  # Validate class items and attributes
  item_names <- c("scores", "labels")
  attr_names <- c("modname", "dsid", "nn", "np", "args", "validated")
  arg_names <- c("posclass", "na_worst", "ties_method", "modname", "dsid")
  .validate_basic(sdat, "sdat", "reformat_data", item_names, attr_names,
                  arg_names)

  # Check values of class items
  n_labels <- length(sdat[["labels"]])
  if (n_labels == 0 || length(sdat[["scores"]]) != n_labels) {
    stop("List items in sdat must be all the same lengths", call. = FALSE)
  }

  # Labels
  assertthat::assert_that(is.atomic(sdat[["labels"]]),
                          is.vector(sdat[["labels"]]),
                          is.numeric(sdat[["labels"]]))

  attr(sdat, "validated") <- TRUE
  sdat
}