https://github.com/cran/precrec
Tip revision: 9c56b91df45d18e10abd76aa1359b1482955fbbc authored by Takaya Saito on 04 December 2015, 15:40:36 UTC
version 0.1.1
version 0.1.1
Tip revision: 9c56b91
etc_utils_plot.R
#' Plot performance evaluation measures
#'
#' The \code{plot} function plots performance evaluation measures
#'
#' @param x An S3 object generated by \code{\link{evalmod}}.
#' The \code{plot} function takes one of the following S3 objects.
#'
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#' \tabular{lllll}{
#' \strong{S3 object}
#' \tab \tab \strong{# of models}
#' \tab \tab \strong{# of test datasets} \cr
#'
#' sscurves \tab \tab single \tab \tab single \cr
#' mscurves \tab \tab multiple \tab \tab single \cr
#' smcurves \tab \tab single \tab \tab multiple \cr
#' mmcurves \tab \tab multiple \tab \tab multiple
#' }
#'
#' \item Basic evaluation measures (mode = "basic")
#'
#' \tabular{lllll}{
#' \strong{S3 object}
#' \tab \tab \strong{# of models}
#' \tab \tab \strong{# of test datasets} \cr
#'
#' sspoints \tab \tab single \tab \tab single \cr
#' mspoints \tab \tab multiple \tab \tab single \cr
#' smpoints \tab \tab single \tab \tab multiple \cr
#' mmpoints \tab \tab multiple \tab \tab multiple
#' }
#' }
#'
#' @param y Equivalent to \code{curvetype}.
#'
#' @param ... All the following arguments can be specified.
#'
#' \describe{
#' \item{curvetype}{
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#' \describe{
#' \item{"ROC"}{ROC curve}
#' \item{"PRC"}{Precision-Recall curve}
#' \item{c("ROC", "PRC")}{ROC and Precision-Recall curves}
#' }
#'
#' \item Basic evaluation measures (mode = "basic")
#'
#' \describe{
#' \item{"error"}{Normalized threshold values vs. error rate}
#' \item{"accuracy"}{Normalized threshold values vs. accuracy}
#' \item{"specificity"}{Normalized threshold values vs. specificity}
#' \item{"sensitivity"}{Normalized threshold values vs. sensitivity}
#' \item{"precision"}{Normalized threshold values vs. precision}
#' \item{c("error", "accuracy", "specificity", "sensitivity",
#' "precision")}{All of the above}
#' }
#' }
#' }
#' \item{type}{
#' \describe{
#' \item{"l"}{lines}
#' \item{"p"}{points}
#' \item{"b"}{both lines and points}
#' }
#' }
#' \item{show_cb}{
#' A Boolean value to specify whether point-wise confidence
#' bounds are drawn. It is effective only when the \code{calc_avg}
#' argument of the \code{\link{evalmod}} function is set to \code{TRUE}.
#' }
#' \item{raw_curves}{
#' A Boolean value to specify whether raw curves are
#' shown instead of the average curve. It is effective only
#' when the \code{raw_curves} argument of the \code{\link{evalmod}}
#' function is set to \code{TRUE}.
#' }
#' \item{show_legend}{
#' A Boolean value to specify whether the legend is shown.
#' }
#' }
#'
#' @return The \code{plot} function shows a plot and returns NULL.
#'
#' @seealso \code{\link{evalmod}} for generating an S3 object.
#' \code{\link{autoplot}} for plotting the equivalent curves
#' with \code{\link[ggplot2]{ggplot2}}.
#'
#' @examples
#'
#' #############################################################################
#' ### Single model & single test dataset
#' ###
#'
#'\dontrun{
#'
#' ## Load a dataset with 10 positives and 10 negatives
#' data(P10N10)
#'
#' ## Generate an sscurves object that contains ROC and Precision-Recall curves
#' sscurves <- evalmod(scores = P10N10$scores, labels = P10N10$labels)
#'
#' ## Plot both ROC and Precision-Recall curves
#' plot(sscurves)
#'
#' ## Plot a ROC curve
#' plot(sscurves, curvetype = "ROC")
#'
#' ## Plot a Precision-Recall curve
#' plot(sscurves, curvetype = "PRC")
#'
#' ## Generate an sspoints object that contains basic evaluation measures
#' sspoints <- evalmod(mode = "basic", scores = P10N10$scores,
#' labels = P10N10$labels)
#'
#' ## Plot threshold values vs. basic evaluation measures
#' plot(sspoints)
#'
#' ## Plot threshold vs. precision
#' plot(sspoints, curvetype = "precision")
#'
#'}
#'
#' #############################################################################
#' ### Multiple models & single test dataset
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(1, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]])
#'
#' ## Generate an mscurves object that contains ROC and Precision-Recall curves
#' mscurves <- evalmod(mdat)
#'
#' ## Plot both ROC and Precision-Recall curves
#' plot(mscurves)
#'
#' ## Hide the legend
#' plot(mscurves, show_legend = FALSE)
#'
#' ## Generate an mspoints object that contains basic evaluation measures
#' mspoints <- evalmod(mdat, mode = "basic")
#'
#' ## Plot threshold values vs. basic evaluation measures
#' plot(mspoints)
#'
#' ## Hide the legend
#' plot(mspoints, show_legend = FALSE)
#'
#'}
#'
#' #############################################################################
#' ### Single model & multiple test datasets
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "good_er")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]],
#' dsids = samps[["dsids"]])
#'
#' ## Generate an smcurves object that contains ROC and Precision-Recall curves
#' smcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Plot average ROC and Precision-Recall curves
#' plot(smcurves)
#'
#' ## Hide confidence bounds
#' plot(smcurves, show_cb = FALSE)
#'
#' ## Plot raw ROC and Precision-Recall curves
#' plot(smcurves, raw_curves = TRUE)
#'
#' ## Generate an smpoints object that contains basic evaluation measures
#' smpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Plot threshold values vs. average basic evaluation measures
#' plot(smpoints)
#'
#'}
#'
#' #############################################################################
#' ### Multiple models & multiple test datasets
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]],
#' dsids = samps[["dsids"]])
#'
#' ## Generate an mmcurves object that contains ROC and Precision-Recall curves
#' mmcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Plot average ROC and Precision-Recall curves
#' plot(mmcurves)
#'
#' ## Show confidence bounds
#' plot(mmcurves, show_cb = TRUE)
#'
#' ## Plot raw ROC and Precision-Recall curves
#' plot(mmcurves, raw_curves = TRUE)
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' mmpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Plot threshold values vs. average basic evaluation measures
#' plot(mmpoints)
#'}
#'
#' @name plot
NULL
#
# Normalise ROC / Precision-Recall curve type names
#
# Each element of `vals` that is a single string is lower-cased and compared
# against "roc" and "prc" with pmatch(); a (possibly abbreviated) match is
# replaced by the canonical upper-case name "ROC" or "PRC". Every other
# element is passed through unchanged.
#
# vals: vector or list of requested curve types.
# Returns the unlist()-ed vector of normalised values.
#
.pmatch_curvetype_rocprc <- function(vals) {
  canonical <- c("ROC", "PRC")

  normalise_one <- function(val) {
    if (!assertthat::is.string(val)) {
      return(val)
    }
    # "roc" and "prc" differ in their first letter, so at most one of the
    # two targets can be hit by a prefix match.
    hit <- pmatch(tolower(val), tolower(canonical))
    if (is.na(hit)) {
      val
    } else {
      canonical[hit]
    }
  }

  unlist(lapply(vals, normalise_one))
}
#
# Normalise names of basic evaluation measures
#
# Each element of `vals` that is a single, non-missing string is lower-cased
# and mapped to a canonical measure name:
#   * "error"       - unique prefix of "error rate"
#   * "accuracy"    - unique prefix of "accuracy"
#   * "specificity" - unique prefix of "specificity", or exactly "tnr"
#   * "sensitivity" - unique prefix of "sensitivity" or "recall",
#                     or exactly "tpr" / "sn"
#   * "precision"   - unique prefix of "precision", or exactly "ppv"
# The rules are tried in that order and the first hit wins. Anything else
# (non-strings, NA, unknown names) is returned unchanged.
#
# vals: vector or list of requested measure names.
# Returns the unlist()-ed vector of normalised values.
#
.pmatch_curvetype_basic <- function(vals) {
  pfunc <- function(val) {
    # Base equivalent of assertthat::is.string().
    is_string <- is.character(val) && length(val) == 1L
    # An NA string must be passed through untouched: comparing it with "=="
    # below would yield NA and make `if` fail with an unrelated error.
    if (!is_string || is.na(val)) {
      return(val)
    }

    sval <- tolower(val)
    abbreviates <- function(full_name) !is.na(pmatch(sval, full_name))

    if (abbreviates("error rate")) {
      "error"
    } else if (abbreviates("accuracy")) {
      "accuracy"
    } else if (abbreviates("specificity") || sval == "tnr") {
      "specificity"
    } else if (abbreviates("sensitivity") || abbreviates("recall") ||
                 sval %in% c("tpr", "sn")) {
      "sensitivity"
    } else if (abbreviates("precision") || sval == "ppv") {
      "precision"
    } else {
      val
    }
  }

  unlist(lapply(vals, pfunc))
}
#
# Collect plot options from `...` and fill in defaults
#
# The arguments passed through `...` are gathered into a list. A non-NULL `y`
# overrides any `curvetype` entry. Afterwards every option that is still
# NULL/absent is set to its `def_*` value, in the order curvetype, type,
# show_cb, raw_curves, add_np_nn, show_legend.
#
# y:       positional alternative for `curvetype` (may be NULL).
# def_*:   default value for the option of the same name.
# ...:     user supplied options.
# Returns the completed option list.
#
.get_plot_arglist <- function(y, def_curvetype, def_type, def_show_cb,
                              def_raw_curves, def_add_np_nn, def_show_legend,
                              ...) {
  # Set `key` to `fallback` only when it has no non-NULL value yet. `fallback`
  # is evaluated lazily, i.e. only if it is actually needed.
  fill_missing <- function(opts, key, fallback) {
    if (is.null(opts[[key]])) {
      opts[[key]] <- fallback
    }
    opts
  }

  opts <- list(...)
  if (!is.null(y)) {
    opts[["curvetype"]] <- y
  }

  opts <- fill_missing(opts, "curvetype", def_curvetype)
  opts <- fill_missing(opts, "type", def_type)
  opts <- fill_missing(opts, "show_cb", def_show_cb)
  opts <- fill_missing(opts, "raw_curves", def_raw_curves)
  opts <- fill_missing(opts, "add_np_nn", def_add_np_nn)
  opts <- fill_missing(opts, "show_legend", def_show_legend)
  opts
}
#
# Plot one panel per requested curve type
#
# Reads the plot options from `arglist`, runs the validation helpers on them,
# and then draws every entry of `curvetype` with .plot_single(). When more
# than one curve type is requested the panels share one layout created by
# .set_layout(): the per-panel legend is switched off and a single legend is
# drawn after the last panel. The layout is reset to a single figure on exit.
#
# x:       object to plot (checked by .validate()).
# arglist: option list with the entries curvetype, type, show_cb,
#          raw_curves, add_np_nn and show_legend.
#
.plot_multi <- function(x, arglist) {
  # === Unpack options ===
  curvetype <- arglist[["curvetype"]]
  type <- arglist[["type"]]
  show_cb <- arglist[["show_cb"]]
  raw_curves <- arglist[["raw_curves"]]
  add_np_nn <- arglist[["add_np_nn"]]
  show_legend <- arglist[["show_legend"]]

  # === Validate input arguments ===
  .validate(x)
  .check_curvetype(curvetype, x)
  .check_type(type)
  .check_show_cb(show_cb, x)
  .check_raw_curves(raw_curves, x)
  .check_add_np_nn(add_np_nn)
  .check_show_legend(show_legend)

  # === Create a plot ===
  n_panels <- length(curvetype)
  multi_panel <- n_panels > 1

  panel_legend <- show_legend
  if (multi_panel) {
    .set_layout(n_panels, show_legend)
    on.exit(graphics::layout(1), add = TRUE)
    # The shared legend is drawn once, after all panels.
    panel_legend <- FALSE
  }

  for (idx in seq_along(curvetype)) {
    .plot_single(x, curvetype[[idx]], type = type, show_cb = show_cb,
                 raw_curves = raw_curves, add_np_nn = add_np_nn,
                 show_legend = panel_legend)
  }

  # Five panels leave one layout cell unused; fill it with an empty plot so
  # that the legend ends up in the cell that follows.
  if (n_panels == 5) {
    graphics::plot(1, type = "n", axes = FALSE, xlab = "", ylab = "")
  }

  if (multi_panel) {
    .show_legend(x, show_legend)
  }
}
#
# Set the figure layout for a number of plot panels
#
# Panels are arranged row by row on the following grids:
#   1 panel -> 1 x 1,  2 -> 1 x 2,  3 -> 1 x 3,  4 -> 2 x 2,  5+ -> 2 x 3
# With `show_legend = TRUE` an extra bottom row spanning all columns is
# appended for the legend; it takes 15% of the height and the panel rows
# share the remaining 85% equally.
#
# The panel order is the same (row-wise) with and without the legend row.
#
# ctype_len:   number of panels; a single positive whole number.
# show_legend: whether to reserve a legend row.
# Returns the value of graphics::layout().
#
.set_layout <- function(ctype_len, show_legend) {
  is_count <- is.numeric(ctype_len) && length(ctype_len) == 1L &&
    !is.na(ctype_len) && ctype_len >= 1 && ctype_len %% 1 == 0
  if (!is_count) {
    stop("'ctype_len' must be a single positive whole number", call. = FALSE)
  }

  # === Grid dimensions of the panel area ===
  if (ctype_len > 4) {
    n_rows <- 2L
    n_cols <- 3L
  } else if (ctype_len == 4) {
    n_rows <- 2L
    n_cols <- 2L
  } else {
    n_rows <- 1L
    n_cols <- as.integer(ctype_len)
  }

  n_cells <- n_rows * n_cols
  panels <- matrix(seq_len(n_cells), nrow = n_rows, ncol = n_cols,
                   byrow = TRUE)

  if (show_legend) {
    # The legend is figure number n_cells + 1 and spans the whole bottom row.
    legend_row <- rep(n_cells + 1L, n_cols)
    m <- rbind(panels, legend_row, deparse.level = 0)
    heights <- c(rep(0.85 / n_rows, n_rows), 0.15)
    graphics::layout(mat = m, heights = heights)
  } else {
    graphics::layout(mat = panels)
  }
}
#
# Draw the raw curves of one curve type with graphics::matplot()
#
# The curves stored in `obj[[curvetype]]` are validated, converted to x/y
# matrices by .make_matplot_mats() and drawn on the unit square. A "single"
# model type (attribute `model_type` of `obj`) is drawn in black; otherwise
# the colours come from .make_multi_colors().
#
# obj:              curve object holding one entry per curve type.
# type:             plot type passed on to matplot().
# curvetype:        name of the entry of `obj` to draw.
# main, xlab, ylab: title and axis labels.
#
.matplot_wrapper <- function(obj, type, curvetype, main, xlab, ylab) {
  # === Validate input arguments ===
  .validate(obj[[curvetype]])

  # === Create line colours ===
  line_col <- if (attr(obj, "model_type") == "single") {
    "black"
  } else {
    .make_multi_colors(obj)
  }

  # === Create a plot ===
  xy <- .make_matplot_mats(obj[[curvetype]])
  graphics::matplot(
    x = xy[["x"]], y = xy[["y"]],
    type = type, lty = 1, pch = 19, col = line_col,
    main = main, xlab = xlab, ylab = ylab,
    xlim = c(0, 1), ylim = c(0, 1)
  )
}
#
# Make matrices for matplot
#
# Converts a list of curves, each with the elements `x` and `y`, into two
# numeric matrices with one column per curve. The number of rows is the
# largest length of any `x`; shorter curves are padded with NA at the end.
#
# Curves of length zero produce an all-NA column, and an empty list produces
# two 0 x 0 matrices.
#
# obj: list of curves.
# Returns list(x = <matrix>, y = <matrix>).
#
.make_matplot_mats <- function(obj) {
  n_curves <- length(obj)
  x_lens <- vapply(obj, function(o) length(o[["x"]]), integer(1))
  # max(0L, ...) keeps the result finite when `obj` is empty.
  max_nrow <- max(0L, x_lens)

  x <- matrix(NA_real_, nrow = max_nrow, ncol = n_curves)
  y <- matrix(NA_real_, nrow = max_nrow, ncol = n_curves)

  for (i in seq_along(obj)) {
    xi <- obj[[i]][["x"]]
    yi <- obj[[i]][["y"]]
    # seq_along() (unlike 1:length()) yields an empty index for empty curves.
    x[seq_along(xi), i] <- xi
    y[seq_along(yi), i] <- yi
  }

  list(x = x, y = y)
}
#
# Make colours for multiple models and multiple datasets
#
# Assigns one rainbow() colour to each unique model name (attribute
# `uniq_modnames` of `obj`) and returns, for every entry of
# `attr(obj, "data_info")[["modnames"]]`, the colour of its model. Names that
# do not occur in `uniq_modnames` get NA.
#
# obj: curve object carrying the attributes described above.
# Returns a character vector of colours, one per entry of `modnames`.
#
.make_multi_colors <- function(obj) {
  uniq_modnames <- attr(obj, "uniq_modnames")
  modnames <- attr(obj, "data_info")[["modnames"]]

  uniq_col <- grDevices::rainbow(length(uniq_modnames))
  # Vectorised lookup: position of each model name among the unique names.
  uniq_col[match(modnames, uniq_modnames)]
}
#
# Plot average line with CI
#
# Opens an empty plot on the unit square and adds every average curve stored
# in `attr(obj, "grp_avg")[[curvetype]]` with .add_curve_with_ci(). A single
# curve is drawn in blue; several curves get rainbow() colours. The
# confidence band is always grey.
#
# If there are no average curves for `curvetype`, only the empty frame is
# drawn.
#
# obj:              curve object with the attribute `grp_avg`.
# type:             plot type of the average lines.
# curvetype:        name of the entry of `grp_avg` to draw.
# main, xlab, ylab: title and axis labels.
# show_cb:          whether the confidence bounds are drawn.
# Returns NULL invisibly.
#
.plot_avg <- function(obj, type, curvetype, main, xlab, ylab, show_cb) {
  # === Create a plot ===
  grp_avg <- attr(obj, "grp_avg")
  avgcurves <- grp_avg[[curvetype]]

  graphics::plot(1, type = "n", main = main, xlab = xlab, ylab = ylab,
                 ylim = c(0, 1), xlim = c(0, 1))

  if (length(avgcurves) == 1) {
    lcols <- "blue"
  } else {
    lcols <- grDevices::rainbow(length(avgcurves))
  }

  # seq_along() instead of 1:length(): the latter iterates over c(1, 0) when
  # `avgcurves` is empty.
  for (i in seq_along(avgcurves)) {
    .add_curve_with_ci(avgcurves, type, i, "grey", lcols[i], show_cb)
  }

  invisible(NULL)
}
#
# Add one average curve, optionally with its confidence band
#
# Draws curve number `idx` of `avgcurves` onto the current plot. With
# `show_cb` the area between the elements `y_ci_l` and `y_ci_h` is filled
# first, using `pcol` with alpha 180/255; the average line (`y_avg` against
# `x`) is then drawn on top in `lcol` with alpha 200/255.
#
# avgcurves: list of curves with the elements x, y_avg, y_ci_l and y_ci_h.
# type:      plot type passed on to lines().
# idx:       index of the curve to draw.
# pcol:      colour of the confidence band.
# lcol:      colour of the average line.
# show_cb:   whether the confidence band is drawn.
#
.add_curve_with_ci <- function(avgcurves, type, idx, pcol, lcol, show_cb) {
  # Same colour with the given alpha value (0-255).
  translucent <- function(colour, alpha) {
    channels <- grDevices::col2rgb(colour)
    grDevices::rgb(channels[1], channels[2], channels[3], alpha,
                   maxColorValue = 255)
  }

  avg_curve <- avgcurves[[idx]]
  xs <- avg_curve[["x"]]
  ys <- avg_curve[["y_avg"]]

  if (show_cb) {
    lower <- avg_curve[["y_ci_l"]]
    upper <- avg_curve[["y_ci_h"]]
    band_col <- translucent(pcol, 180)
    # Closed outline: lower bound left to right, upper bound back again.
    graphics::polygon(c(xs, rev(xs)), c(lower, rev(upper)), border = FALSE,
                      col = band_col)
  }

  line_col <- translucent(lcol, 200)
  graphics::lines(xs, ys, type = type, lty = 1, pch = 19, col = line_col)
}
#
# Plot a single panel (ROC, Precision-Recall or a basic measure)
#
# Looks up title and axis labels with .get_titiles(), optionally appends the
# numbers of positives and negatives to the title, and draws either the raw
# curves (.matplot_wrapper()) or the average curves (.plot_avg()) in a square
# plot region. A dotted grey reference line is added for "ROC" (diagonal) and
# for "PRC" (horizontal line at np / (np + nn)). With `show_legend` the panel
# gets its own layout with a legend row, which is reset on exit.
#
# x:           curve object; `np` and `nn` are taken from the first entry of
#              `attr(x, "data_info")`.
# curvetype:   "ROC", "PRC" or the name of a basic measure.
# type:        plot type; defaults to "l". (The previous default `type = type`
#              referred to itself and failed whenever `type` was omitted.)
# show_cb:     whether confidence bounds are drawn for average curves.
# raw_curves:  draw raw curves instead of average curves.
# add_np_nn:   append " - P: <np>, N: <nn>" to the title.
# show_legend: whether a legend is drawn below the panel.
#
.plot_single <- function(x, curvetype, type = "l", show_cb = FALSE,
                         raw_curves = FALSE, add_np_nn = TRUE,
                         show_legend = TRUE) {
  tlist <- .get_titiles(curvetype)
  main <- tlist[["main"]]

  np <- attr(x, "data_info")[["np"]][[1]]
  nn <- attr(x, "data_info")[["nn"]][[1]]
  if (add_np_nn) {
    main <- paste0(main, " - P: ", np, ", N: ", nn)
  }

  # Square plot region; the previous setting is restored on exit.
  old_pty <- graphics::par(pty = "s")
  on.exit(graphics::par(old_pty), add = TRUE)

  if (show_legend) {
    .set_layout(1, show_legend)
    on.exit(graphics::layout(1), add = TRUE)
  }

  # === Create a plot ===
  if (raw_curves) {
    .matplot_wrapper(x, type, tlist[["ctype"]], main, tlist[["xlab"]],
                     tlist[["ylab"]])
  } else {
    .plot_avg(x, type, tlist[["ctype"]], main, tlist[["xlab"]],
              tlist[["ylab"]], show_cb)
  }

  # === Reference lines ===
  if (curvetype == "ROC") {
    graphics::abline(a = 0, b = 1, col = "grey", lty = 3)
  } else if (curvetype == "PRC") {
    graphics::abline(h = np / (np + nn), col = "grey", lty = 3)
  }

  .show_legend(x, show_legend)
}
#
# Get plot title, axis labels and internal curve name for a curve type
#
# "ROC" and "PRC" have fixed texts. Any other value is treated as the name of
# a basic evaluation measure: the title is the name with its first letter in
# upper case, the x axis is "threshold" and the y axis is the name itself.
#
# curvetype: a single string.
# Returns a list with the elements main, xlab, ylab and ctype, where ctype is
# the short name of the curve type ("rocs", "prcs", "err", "acc", "sp", "sn",
# "prec"). For an unknown measure name the ctype element is absent.
#
.get_titiles <- function(curvetype) {
  if (curvetype == "ROC") {
    return(list(main = "ROC", xlab = "1 - Specificity",
                ylab = "Sensitivity", ctype = "rocs"))
  }
  if (curvetype == "PRC") {
    return(list(main = "Precision-Recall", xlab = "Recall",
                ylab = "Precision", ctype = "prcs"))
  }

  # === Basic evaluation measures ===
  short_names <- list(error = "err", accuracy = "acc", specificity = "sp",
                      sensitivity = "sn", precision = "prec")
  first_letter <- toupper(substr(curvetype, 1, 1))
  remainder <- substring(curvetype, 2)

  titles <- list(main = paste0(first_letter, remainder),
                 xlab = "threshold",
                 ylab = curvetype)
  # Assigning NULL (unknown measure) leaves the list without a ctype element.
  titles[["ctype"]] <- short_names[[curvetype]]
  titles
}
#
# Draw the legend in a figure region of its own
#
# Does nothing unless `show_legend` is true. Otherwise an empty plot without
# margins and axes is opened and a horizontal legend is placed at its top.
# The legend entries are read from the attribute "uniq_<gnames>" of `obj`
# and are paired with rainbow() colours. The graphical parameters `mar` and
# `pty` are restored on exit.
#
# obj:         curve object carrying the "uniq_<gnames>" attribute.
# show_legend: whether the legend is drawn.
# gnames:      suffix of the attribute name; "modnames" by default.
#
.show_legend <- function(obj, show_legend, gnames = "modnames") {
  if (show_legend) {
    # No margins and maximal plot region for the legend panel.
    saved_par <- graphics::par(mar = c(0, 0, 0, 0), pty = "m")
    on.exit(graphics::par(saved_par), add = TRUE)

    labels <- attr(obj, paste0("uniq_", gnames))
    label_cols <- grDevices::rainbow(length(labels))

    graphics::plot(1, type = "n", axes = FALSE, xlab = "", ylab = "")
    graphics::legend(x = "top", lty = 1, legend = labels, col = label_cols,
                     horiz = TRUE)
  }
}