# Source: https://github.com/cran/precrec
# Raw File
# Tip revision: 9c56b91df45d18e10abd76aa1359b1482955fbbc authored by Takaya Saito on 04 December 2015, 15:40:36 UTC
# version 0.1.1
# Tip revision: 9c56b91
# etc_utils_autoplot.R
#' Plot performance evaluation measures with ggplot2
#'
#' The \code{autoplot} function plots performance evaluation measures
#'   by using \code{\link[ggplot2]{ggplot2}} instead of the general R plot.
#'
#' @param object An S3 object generated by \code{\link{evalmod}}.
#'   The \code{autoplot} function takes one of the following S3 objects.
#'
#' \enumerate{
#'
#'   \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#'   \tabular{lllll}{
#'     \strong{S3 object}
#'     \tab \tab \strong{# of models}
#'     \tab \tab \strong{# of test datasets} \cr
#'
#'     sscurves \tab \tab single   \tab \tab single   \cr
#'     mscurves \tab \tab multiple \tab \tab single   \cr
#'     smcurves \tab \tab single   \tab \tab multiple \cr
#'     mmcurves \tab \tab multiple \tab \tab multiple
#'   }
#'
#'   \item Basic evaluation measures (mode = "basic")
#'
#'   \tabular{lllll}{
#'     \strong{S3 object}
#'     \tab \tab \strong{# of models}
#'     \tab \tab \strong{# of test datasets} \cr
#'
#'     sspoints \tab \tab single   \tab \tab single   \cr
#'     mspoints \tab \tab multiple \tab \tab single   \cr
#'     smpoints \tab \tab single   \tab \tab multiple \cr
#'     mmpoints \tab \tab multiple \tab \tab multiple
#'   }
#' }
#'
#' @param curvetype A character vector that specifies one or more of the
#'   following curve types:
#' \enumerate{
#'
#'   \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#'     \describe{
#'       \item{"ROC"}{ROC curve}
#'       \item{"PRC"}{Precision-Recall curve}
#'       \item{c("ROC", "PRC")}{ROC and Precision-Recall curves}
#'     }
#'
#'   \item Basic evaluation measures (mode = "basic")
#'
#'     \describe{
#'       \item{"error"}{Normalized threshold values vs. error rate}
#'       \item{"accuracy"}{Normalized threshold values vs. accuracy}
#'       \item{"specificity"}{Normalized threshold values vs. specificity}
#'       \item{"sensitivity"}{Normalized threshold values vs. sensitivity}
#'       \item{"precision"}{Normalized threshold values vs. precision}
#'       \item{c("error", "accuracy", "specificity", "sensitivity",
#'               "precision")}{All of the above}
#'     }
#' }
#'
#' @param ... All the following arguments can be specified.
#'
#' \describe{
#'   \item{curvetype}{
#'     \enumerate{
#'
#'       \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#'         \describe{
#'           \item{"ROC"}{ROC curve}
#'           \item{"PRC"}{Precision-Recall curve}
#'           \item{c("ROC", "PRC")}{ROC and Precision-Recall curves}
#'          }
#'
#'       \item Basic evaluation measures (mode = "basic")
#'
#'         \describe{
#'           \item{"error"}{Normalized threshold values vs. error rate}
#'           \item{"accuracy"}{Normalized threshold values vs. accuracy}
#'           \item{"specificity"}{Normalized threshold values vs. specificity}
#'           \item{"sensitivity"}{Normalized threshold values vs. sensitivity}
#'           \item{"precision"}{Normalized threshold values vs. precision}
#'           \item{c("error", "accuracy", "specificity", "sensitivity",
#'               "precision")}{All of the above}
#'         }
#'      }
#'   }
#'   \item{type}{
#'     \describe{
#'       \item{"l"}{lines}
#'       \item{"p"}{points}
#'       \item{"b"}{both lines and points}
#'     }
#'   }
#'   \item{show_cb}{
#'     A Boolean value to specify whether point-wise confidence
#'     bounds are drawn. It is effective only when \code{calc_avg} is
#'     set to \code{TRUE} in the \code{\link{evalmod}} function.
#'   }
#'   \item{raw_curves}{
#'     A Boolean value to specify whether raw curves are
#'     shown instead of the average curve. It is effective only
#'     when \code{raw_curves} is set to \code{TRUE}
#'     in the \code{\link{evalmod}} function.
#'   }
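#'   \item{show_legend}{
#'     A Boolean value to specify whether the legend is shown.
#'   }
#'   \item{add_np_nn}{
#'     A Boolean value to specify whether the numbers of positives and
#'     negatives are added to the titles of ROC and Precision-Recall plots.
#'   }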
#'   \item{ret_grob}{
#'     A logical value to indicate whether
#'     \code{autoplot} returns a \code{grob} object. The \code{grob} object
#'     is internally generated by \code{\link[gridExtra]{arrangeGrob}}.
#'     The \code{\link[grid]{grid.draw}} function takes a \code{grob} object and
#'     generates a plot. It is effective only for multiple-panel plots, which
#'     are generated, for example, when \code{curvetype} is
#'     \code{c("ROC", "PRC")}. When \code{FALSE}, a multiple-panel plot is
#'     drawn on the current device instead.
#'   }
#' }
#'
#' @return The \code{autoplot} function returns a \code{ggplot} object
#'   for a single-panel plot. A multiple-panel plot is drawn directly, and
#'   its frame-grob object is returned only when \code{ret_grob} is
#'   \code{TRUE}.
#'
#' @seealso \code{\link{evalmod}} for generating an S3 object.
#'   \code{\link{fortify}} for converting a curves or points object
#'   to a data frame.  \code{\link{plot}} for plotting the equivalent curves
#'   with the general R plot.
#'
#' @examples
#'
#' ## Load libraries
#' library(ggplot2)
#' library(grid)
#' library(gridExtra)
#'
#' #############################################################################
#' ### Single model & single test dataset
#' ###
#'
#'\dontrun{
#'
#' ## Load a dataset with 10 positives and 10 negatives
#' data(P10N10)
#'
#' ## Generate an sscurve object that contains ROC and Precision-Recall curves
#' sscurves <- evalmod(scores = P10N10$scores, labels = P10N10$labels)
#'
#' ## Plot both ROC and Precision-Recall curves
#' autoplot(sscurves)
#'
#' ## Get a grob object for multiple plots
#' pp1 <- autoplot(sscurves, ret_grob = TRUE)
#' plot.new()
#' grid.draw(pp1)
#'
#' ## A ROC curve
#' autoplot(sscurves, curvetype = "ROC")
#'
#' ## A Precision-Recall curve
#' autoplot(sscurves, curvetype = "PRC")
#'
#' ## Generate an sspoints object that contains basic evaluation measures
#' sspoints <- evalmod(mode = "basic", scores = P10N10$scores,
#'                     labels = P10N10$labels)
#'
#' ## Threshold values vs. basic evaluation measures
#' autoplot(sspoints)
#'
#' ## Threshold vs. precision
#' autoplot(sspoints, curvetype = "precision")
#'
#'}
#'
#' #############################################################################
#' ### Multiple models & single test dataset
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(1, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'                modnames = samps[["modnames"]])
#'
#' ## Generate an mscurve object that contains ROC and Precision-Recall curves
#' mscurves <- evalmod(mdat)
#'
#' ## ROC and Precision-Recall curves
#' autoplot(mscurves)
#'
#' ## Hide the legend
#' autoplot(mscurves, show_legend = FALSE)
#'
#' ## Generate an mspoints object that contains basic evaluation measures
#' mspoints <- evalmod(mdat, mode = "basic")
#'
#' ## Threshold values vs. basic evaluation measures
#' autoplot(mspoints)
#'
#' ## Hide the legend
#' autoplot(mspoints, show_legend = FALSE)
#'
#'}
#'
#' #############################################################################
#' ### Single model & multiple test datasets
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "good_er")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'                modnames = samps[["modnames"]],
#'                dsids = samps[["dsids"]])
#'
#' ## Generate an smcurve object that contains ROC and Precision-Recall curves
#' smcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(smcurves)
#'
#' ## Hide confidence bounds
#' autoplot(smcurves, show_cb = FALSE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(smcurves, raw_curves = TRUE)
#'
#' ## Generate an smpoints object that contains basic evaluation measures
#' smpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Threshold values vs. average basic evaluation measures
#' autoplot(smpoints)
#'
#'}
#'
#' #############################################################################
#' ### Multiple models & multiple test datasets
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#'                modnames = samps[["modnames"]],
#'                dsids = samps[["dsids"]])
#'
#' ## Generate an mmcurve object that contains ROC and Precision-Recall curves
#' mmcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(mmcurves)
#'
#' ## Show confidence bounds
#' autoplot(mmcurves, show_cb = TRUE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(mmcurves, raw_curves = TRUE)
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' mmpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Threshold values vs. average basic evaluation measures
#' autoplot(mmpoints)
#'
#'}
#'
#' @name autoplot
NULL  # placeholder object that roxygen2 attaches the @name autoplot docs to

#
# Process ... for curve objects
#
.get_autoplot_arglist <- function(def_curvetype, def_type, def_show_cb,
                                  def_raw_curves, def_add_np_nn,
                                  def_show_legend, def_ret_grob, ...) {
  # Collect the user-supplied plotting options from `...` and fill in the
  # caller-provided default for every option that is missing or NULL.
  # Options the caller did supply keep their position in the list; missing
  # ones are appended in the order handled below. Any other entries in
  # `...` are passed through untouched.

  # Return `opts` with `value` stored under `key` unless `opts` already holds
  # a non-NULL entry there. `value` is a promise, so a default expression is
  # only evaluated when it is actually needed.
  fill_default <- function(opts, key, value) {
    if (!is.null(opts[[key]])) {
      return(opts)
    }
    opts[[key]] <- value
    opts
  }

  opts <- list(...)
  opts <- fill_default(opts, "curvetype", def_curvetype)
  opts <- fill_default(opts, "type", def_type)
  opts <- fill_default(opts, "show_cb", def_show_cb)
  opts <- fill_default(opts, "raw_curves", def_raw_curves)
  opts <- fill_default(opts, "add_np_nn", def_add_np_nn)
  opts <- fill_default(opts, "show_legend", def_show_legend)
  opts <- fill_default(opts, "ret_grob", def_ret_grob)
  opts
}

#
# Prepare autoplot and return a data frame
#
.prepare_autoplot <- function(object, curve_df = NULL, curvetype = NULL, ...) {
  # Validate `object`, build (or reuse) its plotting data frame, and
  # optionally keep only the rows of one curve type.
  #
  # object:    S3 curves/points object; checked by .validate().
  # curve_df:  data frame already produced by fortify(); when NULL it is
  #            generated here with `...` (e.g. raw_curves) forwarded to
  #            ggplot2::fortify().
  # curvetype: value(s) matched against the `curvetype` column; NULL keeps
  #            all rows.
  # Returns the (filtered) data frame.

  # === Check package availability  ===
  .load_ggplot2()

  # === Validate input arguments ===
  .validate(object)

  # === Prepare a data frame for ggplot2 ===
  if (is.null(curve_df)) {
    curve_df <- ggplot2::fortify(object, ...)
  }

  # === Keep only the rows of the requested curve type ===
  # Explicit column indexing instead of subset(): subset() evaluates its
  # condition with non-standard evaluation, where the `curvetype` column and
  # this function's `curvetype` argument share a name. If the column were
  # absent, subset() would silently compare against the argument and keep
  # every row.
  if (!is.null(curvetype)) {
    if (is.null(curve_df[["curvetype"]])) {
      stop("The data frame has no 'curvetype' column.", call. = FALSE)
    }
    # which() drops NA comparisons, matching subset()'s behaviour
    keep <- which(curve_df[["curvetype"]] == curvetype)
    curve_df <- curve_df[keep, , drop = FALSE]
  }

  curve_df
}

#
# Load ggplot2
#
.load_ggplot2 <- function() {
  # Make sure the ggplot2 namespace can be loaded; stop with an
  # installation hint otherwise. Returns NULL invisibly on success.
  pkg_ok <- requireNamespace("ggplot2", quietly = TRUE)
  if (pkg_ok) {
    return(invisible(NULL))
  }
  stop(paste("This function should not be called directly,",
             "and ggplot2 is needed to work.",
             "Please install it."),
       call. = FALSE)
}

#
# Load grid
#
.load_grid <- function() {
  # Make sure the grid namespace can be loaded; stop with an installation
  # hint otherwise. Returns NULL invisibly on success.
  if (requireNamespace("grid", quietly = TRUE)) {
    return(invisible(NULL))
  }
  stop("grid needed for this function to work. Please install it.",
       call. = FALSE)
}

#
# Load gridExtra
#
.load_gridExtra <- function() {
  # Make sure the gridExtra namespace can be loaded; stop with an
  # installation hint otherwise. Returns NULL invisibly on success.
  pkg_missing <- !requireNamespace("gridExtra", quietly = TRUE)
  if (pkg_missing) {
    msg <- "gridExtra needed for this function to work. Please install it."
    stop(msg, call. = FALSE)
  }
  invisible(NULL)
}

#
# Plot one or more curve types (e.g. ROC and Precision-Recall)
#
.autoplot_multi <- function(object, arglist) {
  # Build one ggplot per entry of arglist$curvetype. A single plot is
  # returned as is; several plots are handed to .combine_plots(), which
  # either returns a grob or draws it (see `ret_grob`).
  #
  # object:  S3 curves/points object from evalmod().
  # arglist: named list of plotting options, typically produced by
  #          .get_autoplot_arglist().

  ctypes <- arglist[["curvetype"]]
  plot_type <- arglist[["type"]]
  cb_flag <- arglist[["show_cb"]]
  raw_flag <- arglist[["raw_curves"]]
  np_nn_flag <- arglist[["add_np_nn"]]
  legend_flag <- arglist[["show_legend"]]
  grob_flag <- arglist[["ret_grob"]]

  # === Check package availability and validate every option ===
  .load_ggplot2()
  .validate(object)
  .check_curvetype(ctypes, object)
  .check_type(plot_type)
  .check_show_cb(cb_flag, object)
  .check_raw_curves(raw_flag, object)
  .check_show_legend(legend_flag)
  .check_add_np_nn(np_nn_flag)
  .check_ret_grob(grob_flag)

  # === One data frame shared by all panels ===
  curve_df <- ggplot2::fortify(object, raw_curves = raw_flag)

  # === One ggplot object per requested curve type ===
  panels <- lapply(ctypes, function(ctype) {
    .autoplot_single(object, curve_df, curvetype = ctype, type = plot_type,
                     show_cb = cb_flag, raw_curves = raw_flag,
                     show_legend = legend_flag, add_np_nn = np_nn_flag)
  })
  names(panels) <- ctypes

  n_panels <- length(panels)
  if (n_panels < 2) {
    return(panels[[1]])
  }

  # Panels become `...` of .combine_plots(); the options are named args
  combine_args <- c(panels, show_legend = legend_flag, ret_grob = grob_flag,
                    nplots = n_panels)
  do.call(.combine_plots, combine_args)
}

#
# .grid_arrange_shared_legend
#
#   Modified version of grid_arrange_shared_legend from RPubs
#   URL of the original version:
#     http://rpubs.com/sjackman/grid_arrange_shared_legend
#
.grid_arrange_shared_legend <- function(..., main_ncol = 2) {
  # Arrange the ggplot objects in `...` in `main_ncol` columns without
  # their own legends, and put a single legend, taken from the first plot,
  # underneath. Returns the arrangeGrob() result.
  plots <- list(...)

  # Render the first plot with its legend at the bottom to extract it
  grobs <- ggplot2::ggplotGrob(plots[[1]]
                               + ggplot2::theme(legend.position = "bottom"))$grobs
  grob_names <- vapply(grobs,
                       function(x) if (is.null(x$name)) "" else x$name,
                       character(1))
  # ggplot2 < 3.5.0 names the legend grob "guide-box". Newer versions use
  # position-specific names ("guide-box-bottom", ...) and keep zeroGrob
  # placeholders for unused positions, so those are skipped.
  is_empty <- vapply(grobs, inherits, logical(1), what = "zeroGrob")
  legend_idx <- which(grepl("^guide-box", grob_names) & !is_empty)
  if (length(legend_idx) == 0) {
    stop("No legend found in the first plot.", call. = FALSE)
  }
  legend <- grobs[[legend_idx[1]]]
  # Spell out `heights`; `legend$height` only worked via partial matching
  lheight <- sum(legend$heights)

  fncol <- function(...) gridExtra::arrangeGrob(..., ncol = main_ncol)
  fnolegend <- function(x) x + ggplot2::theme(legend.position = "none")

  # Panels fill the remaining height; the legend row gets exactly its own
  gridExtra::arrangeGrob(
    do.call(fncol, lapply(plots, fnolegend)),
    legend,
    heights = grid::unit.c(grid::unit(1, "npc") - lheight, lheight),
    ncol = 1)
}

#
# Combine ROC and Precision-Recall plots
#
.combine_plots <- function(..., show_legend, ret_grob, nplots) {
  # Lay out the plots in `...` as one multi-panel figure.
  #
  # show_legend: share one legend below the panels (TRUE) or show none.
  # ret_grob:    return the arranged grob (TRUE) or draw it on a new page.
  # nplots:      number of plots in `...`; chooses the column count.
  .load_grid()
  .load_gridExtra()

  # Two or four panels fit a 2-column grid; any other count uses 3 columns
  n_cols <- if (nplots == 2 || nplots == 4) 2 else 3

  grobframe <- if (show_legend) {
    .grid_arrange_shared_legend(..., main_ncol = n_cols)
  } else {
    gridExtra::arrangeGrob(..., ncol = n_cols)
  }

  if (!ret_grob) {
    # Start a fresh page and draw the figure instead of returning it
    graphics::plot.new()
    grid::grid.draw(grobframe)
  } else {
    grobframe
  }
}

#
# Plot a single curve type (ROC, Precision-Recall, or a basic measure)
#
.autoplot_single <- function(object, curve_df, curvetype = "ROC", type = "l",
                             show_cb = FALSE, raw_curves = FALSE,
                             show_legend = FALSE, add_np_nn = TRUE, ...) {
  # Build a single-panel ggplot for one curve type.
  #
  # object:      S3 curves/points object from evalmod().
  # curve_df:    data frame from fortify(), or NULL to build it here.
  # curvetype:   "ROC", "PRC", or a basic-measure name (e.g. "precision").
  # type:        "l" lines, "p" points, "b" points over faded lines.
  # show_cb:     draw the ymin/ymax band around the average curve.
  # raw_curves:  draw one line per dataset/model pair instead.
  # show_legend, add_np_nn: passed on to the title/axis helpers.
  # Returns a ggplot object.
  #
  # NOTE(review): aes_string() is soft-deprecated since ggplot2 3.0.0.

  curve_df <- .prepare_autoplot(object, curve_df = curve_df,
                                curvetype = curvetype,
                                raw_curves = raw_curves, ...)

  # Add lines and/or points according to `type`. `mapping` is used by both
  # layers; NULL means the layers inherit the plot-level aesthetics.
  add_lines_points <- function(p, mapping = NULL) {
    if (type == "l") {
      p <- p + ggplot2::geom_line(mapping = mapping)
    } else if (type == "b" || type == "p") {
      if (type == "b") {
        p <- p + ggplot2::geom_line(mapping = mapping, alpha = 0.25)
      }
      p <- p + ggplot2::geom_point(mapping = mapping)
    }
    p
  }

  # === Create a ggplot object ===
  if (raw_curves) {
    # One line per dataset/model pair, coloured by model
    p <- ggplot2::ggplot(curve_df,
                         ggplot2::aes_string(x = 'x', y = 'y',
                                             group = 'dsid_modname',
                                             color = 'modname'))
    p <- add_lines_points(p)

  } else if (show_cb) {
    # Average curve with its point-wise bounds taken from ymin/ymax
    p <- ggplot2::ggplot(curve_df,
                         ggplot2::aes_string(x = 'x', y = 'y',
                                             ymin = 'ymin', ymax = 'ymax'))
    modname_aes <- ggplot2::aes_string(color = 'modname')
    if (type == "l") {
      # stat = "identity" draws the precomputed band instead of smoothing
      p <- p + ggplot2::geom_smooth(modname_aes, stat = "identity")
    } else if (type == "b" || type == "p") {
      # `ymin` (previously misspelled `min`) is the ribbon's lower bound
      p <- p + ggplot2::geom_ribbon(ggplot2::aes_string(ymin = 'ymin',
                                                        ymax = 'ymax'),
                                    stat = "identity", alpha = 0.25,
                                    fill = "grey25")
      p <- add_lines_points(p, modname_aes)
    }

  } else {
    p <- ggplot2::ggplot(curve_df, ggplot2::aes_string(x = 'x', y = 'y',
                                                       color = 'modname'))
    p <- add_lines_points(p)
  }

  # === Titles, axes and reference lines ===
  # The helpers read attributes (e.g. data_info) from object[[1]]
  func_g <- switch(curvetype,
                   ROC = .geom_basic_roc,
                   PRC = .geom_basic_prc,
                   .geom_basic_point)
  func_g(p, object[[1]], show_legend = show_legend, add_np_nn = add_np_nn,
         curve_df = curve_df, ...)
}

#
# Geom basic
#
.geom_basic <- function(p, main, xlab, ylab, show_legend) {
  # Apply the styling shared by every panel: black-and-white theme, title
  # `main`, axis labels `xlab`/`ylab`, and a legend without a title.
  # The legend is hidden entirely when `show_legend` is FALSE.
  p <- p +
    ggplot2::theme_bw() +
    ggplot2::ggtitle(main) +
    ggplot2::xlab(xlab) +
    ggplot2::ylab(ylab) +
    ggplot2::theme(legend.title = ggplot2::element_blank())

  if (show_legend) {
    return(p)
  }
  p + ggplot2::theme(legend.position = "none")
}

#
# Make main title
#
.make_rocprc_title <- function(object, pt) {
  # Build a panel title of the form "<pt> - P: <np>, N: <nn>", with np and nn
  # read from the object's "data_info" attribute.
  #
  # Returns the title visibly. The original ended with an assignment and so
  # returned it invisibly through an unused local.
  data_info <- attr(object, "data_info")
  paste0(pt, " - P: ", data_info[["np"]], ", N: ", data_info[["nn"]])
}

#
# Geom basic for ROC
#
.geom_basic_roc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            ...) {
  # Decorate ROC plot `p`: title, reference diagonal, fixed aspect ratio,
  # axis labels and legend handling. When `add_np_nn` is TRUE, the title
  # includes the numbers of positives and negatives from `object`.
  title <- "ROC"
  if (add_np_nn) {
    title <- .make_rocprc_title(object, title)
  }

  # Grey dotted diagonal y = x, drawn on a 1:1 aspect ratio
  p <- p +
    ggplot2::geom_abline(intercept = 0, slope = 1, colour = "grey",
                         linetype = 3) +
    ggplot2::coord_fixed(ratio = 1)

  .geom_basic(p, title, "1 - Specificity", "Sensitivity", show_legend)
}

#
# Geom basic for Precision-Recall
#
.geom_basic_prc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            ...) {
  # Decorate Precision-Recall plot `p`: title, horizontal reference line,
  # y limits [0, 1], fixed aspect ratio, axis labels and legend handling.
  # When `add_np_nn` is TRUE, the title includes the numbers of positives
  # and negatives from `object`.
  main <- if (add_np_nn) {
    .make_rocprc_title(object, "Precision-Recall")
  } else {
    "Precision-Recall"
  }

  # Reference line at the ratio of positives, np / (np + nn)
  info <- attr(object, "data_info")
  pos_ratio <- info[["np"]] / (info[["np"]] + info[["nn"]])

  p <- p +
    ggplot2::geom_hline(yintercept = pos_ratio, colour = "grey",
                        linetype = 3) +
    ggplot2::scale_y_continuous(limits = c(0.0, 1.0)) +
    ggplot2::coord_fixed(ratio = 1)

  .geom_basic(p, main, "Recall", "Precision", show_legend)
}

#
# Geom basic for basic evaluation measures (threshold vs. measure)
#
.geom_basic_point <- function(p, object, show_legend = TRUE, curve_df, ...) {
  # Decorate a basic-measure plot `p` (threshold vs. measure). The measure
  # name is taken from the first row of curve_df$curvetype and used as the
  # y label, and, with its first letter capitalised, as the title.
  #
  # `curve_df` is required. The previous default `curve_df = curve_df` was
  # self-referential and failed with "recursive default argument reference"
  # whenever the argument was omitted. `object` is accepted for a signature
  # shared with the ROC/PRC helpers and is not used here.

  # as.character() guards against a factor column
  s <- as.character(curve_df[["curvetype"]][1])
  main <- paste0(toupper(substring(s, 1, 1)), substring(s, 2))

  p <- p + ggplot2::scale_y_continuous(limits = c(0.0, 1.0))
  p <- p + ggplot2::coord_fixed(ratio = 1)

  .geom_basic(p, main, "threshold", s, show_legend)
}
# back to top