https://github.com/cran/precrec
Tip revision: 9c56b91df45d18e10abd76aa1359b1482955fbbc authored by Takaya Saito on 04 December 2015, 15:40:36 UTC
version 0.1.1
version 0.1.1
Tip revision: 9c56b91
etc_utils_autoplot.R
#' Plot performance evaluation measures with ggplot2
#'
#' The \code{autoplot} function plots performance evaluation measures
#' by using \code{\link[ggplot2]{ggplot2}} instead of the general R plot.
#'
#' @param object An S3 object generated by \code{\link{evalmod}}.
#' The \code{autoplot} function takes one of the following S3 objects.
#'
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#' \tabular{lllll}{
#' \strong{S3 object}
#' \tab \tab \strong{# of models}
#' \tab \tab \strong{# of test datasets} \cr
#'
#' sscurves \tab \tab single \tab \tab single \cr
#' mscurves \tab \tab multiple \tab \tab single \cr
#' smcurves \tab \tab single \tab \tab multiple \cr
#' mmcurves \tab \tab multiple \tab \tab multiple
#' }
#'
#' \item Basic evaluation measures (mode = "basic")
#'
#' \tabular{lllll}{
#' \strong{S3 object}
#' \tab \tab \strong{# of models}
#' \tab \tab \strong{# of test datasets} \cr
#'
#' sspoints \tab \tab single \tab \tab single \cr
#' mspoints \tab \tab multiple \tab \tab single \cr
#' smpoints \tab \tab single \tab \tab multiple \cr
#' mmpoints \tab \tab multiple \tab \tab multiple
#' }
#' }
#'
#' @param curvetype A character vector with the following curve types
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#' \describe{
#' \item{"ROC"}{ROC curve}
#' \item{"PRC"}{Precision-Recall curve}
#' \item{c("ROC", "PRC")}{ROC and Precision-Recall curves}
#' }
#'
#' \item Basic evaluation measures (mode = "basic")
#'
#' \describe{
#' \item{"error"}{Normalized threshold values vs. error rate}
#' \item{"accuracy"}{Normalized threshold values vs. accuracy}
#' \item{"specificity"}{Normalized threshold values vs. specificity}
#' \item{"sensitivity"}{Normalized threshold values vs. sensitivity}
#' \item{"precision"}{Normalized threshold values vs. precision}
#' \item{c("error", "accuracy", "specificity", "sensitivity",
#' "precision")}{All of the above}
#' }
#' }
#'
#' @param ... All the following arguments can be specified.
#'
#' \describe{
#' \item{curvetype}{
#' \enumerate{
#'
#' \item ROC and Precision-Recall curves (mode = "rocprc")
#'
#' \describe{
#' \item{"ROC"}{ROC curve}
#' \item{"PRC"}{Precision-Recall curve}
#' \item{c("ROC", "PRC")}{ROC and Precision-Recall curves}
#' }
#'
#' \item Basic evaluation measures (mode = "basic")
#'
#' \describe{
#' \item{"error"}{Normalized threshold values vs. error rate}
#' \item{"accuracy"}{Normalized threshold values vs. accuracy}
#' \item{"specificity"}{Normalized threshold values vs. specificity}
#' \item{"sensitivity"}{Normalized threshold values vs. sensitivity}
#' \item{"precision"}{Normalized threshold values vs. precision}
#' \item{c("error", "accuracy", "specificity", "sensitivity",
#' "precision")}{All of the above}
#' }
#' }
#' }
#' \item{type}{
#' \describe{
#' \item{"l"}{lines}
#' \item{"p"}{points}
#' \item{"b"}{both lines and points}
#' }
#' }
#' \item{show_cb}{
#' A Boolean value to specify whether point-wise confidence
#' bounds are drawn. It is effective only when \code{calc_avg} of the
#' \code{\link{evalmod}} function is set to \code{TRUE}.
#' }
#' \item{raw_curves}{
#' A Boolean value to specify whether raw curves are
#' shown instead of the average curve. It is effective only
#' when \code{raw_curves} of the \code{\link{evalmod}} function
#' is set to \code{TRUE}.
#' }
#' \item{show_legend}{
#' A Boolean value to specify whether the legend is shown.
#' }
#' \item{ret_grob}{
#' A logical value to indicate whether
#' \code{autoplot} returns a \code{grob} object. The \code{grob} object
#' is internally generated by \code{\link[gridExtra]{arrangeGrob}}.
#' The \code{\link[grid]{grid.draw}} function takes a \code{grob} object and
#' generates a plot. It is effective only for a multiple-panel plot that can
#' be generated, for example, when \code{curvetype} is
#' \code{c("ROC", "PRC")}.
#' }
#' }
#'
#' @return The \code{autoplot} function returns a \code{ggplot} object
#' for a single-panel plot and a frame-grob object for a multiple-panel plot.
#'
#' @seealso \code{\link{evalmod}} for generating an S3 object.
#' \code{\link{fortify}} for converting a curves and points object
#' to a data frame. \code{\link{plot}} for plotting the equivalent curves
#' with the general R plot.
#'
#' @examples
#'
#' ## Load libraries
#' library(ggplot2)
#' library(grid)
#' library(gridExtra)
#'
#' #############################################################################
#' ### Single model & single test dataset
#' ###
#'
#'\dontrun{
#'
#' ## Load a dataset with 10 positives and 10 negatives
#' data(P10N10)
#'
#' ## Generate an sscurve object that contains ROC and Precision-Recall curves
#' sscurves <- evalmod(scores = P10N10$scores, labels = P10N10$labels)
#'
#' ## Plot both ROC and Precision-Recall curves
#' autoplot(sscurves)
#'
#' ## Get a grob object for multiple plots
#' pp1 <- autoplot(sscurves, ret_grob = TRUE)
#' plot.new()
#' grid.draw(pp1)
#'
#' ## A ROC curve
#' autoplot(sscurves, curvetype = "ROC")
#'
#' ## A Precision-Recall curve
#' autoplot(sscurves, curvetype = "PRC")
#'
#' ## Generate an sspoints object that contains basic evaluation measures
#' sspoints <- evalmod(mode = "basic", scores = P10N10$scores,
#' labels = P10N10$labels)
#'
#' ## Threshold values vs. basic evaluation measures
#' autoplot(sspoints)
#'
#' ## Threshold vs. precision
#' autoplot(sspoints, curvetype = "precision")
#'
#'}
#'
#' #############################################################################
#' ### Multiple models & single test dataset
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(1, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]])
#'
#' ## Generate an mscurve object that contains ROC and Precision-Recall curves
#' mscurves <- evalmod(mdat)
#'
#' ## ROC and Precision-Recall curves
#' autoplot(mscurves)
#'
#' ## Hide the legend
#' autoplot(mscurves, show_legend = FALSE)
#'
#' ## Generate an mspoints object that contains basic evaluation measures
#' mspoints <- evalmod(mdat, mode = "basic")
#'
#' ## Threshold values vs. basic evaluation measures
#' autoplot(mspoints)
#'
#' ## Hide the legend
#' autoplot(mspoints, show_legend = FALSE)
#'
#'}
#'
#' #############################################################################
#' ### Single model & multiple test datasets
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "good_er")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]],
#' dsids = samps[["dsids"]])
#'
#' ## Generate an smcurve object that contains ROC and Precision-Recall curves
#' smcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(smcurves)
#'
#' ## Hide confidence bounds
#' autoplot(smcurves, show_cb = FALSE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(smcurves, raw_curves = TRUE)
#'
#' ## Generate an smpoints object that contains basic evaluation measures
#' smpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Threshold values vs. average basic evaluation measures
#' autoplot(smpoints)
#'
#'}
#'
#' #############################################################################
#' ### Multiple models & multiple test datasets
#' ###
#'
#'\dontrun{
#'
#' ## Create sample datasets with 100 positives and 100 negatives
#' samps <- create_sim_samples(10, 100, 100, "all")
#' mdat <- mmdata(samps[["scores"]], samps[["labels"]],
#' modnames = samps[["modnames"]],
#' dsids = samps[["dsids"]])
#'
#' ## Generate an mmcurve object that contains ROC and Precision-Recall curves
#' mmcurves <- evalmod(mdat, raw_curves = TRUE)
#'
#' ## Average ROC and Precision-Recall curves
#' autoplot(mmcurves)
#'
#' ## Show confidence bounds
#' autoplot(mmcurves, show_cb = TRUE)
#'
#' ## Raw ROC and Precision-Recall curves
#' autoplot(mmcurves, raw_curves = TRUE)
#'
#' ## Generate an mmpoints object that contains basic evaluation measures
#' mmpoints <- evalmod(mdat, mode = "basic")
#'
#' ## Threshold values vs. average basic evaluation measures
#' autoplot(mmpoints)
#'
#'}
#'
#' @name autoplot
NULL
#
# Merge the arguments passed through ... with default plot options
#
# Arguments:
#   def_curvetype, def_type, def_show_cb, def_raw_curves, def_add_np_nn,
#   def_show_legend, def_ret_grob: fallback values for the options
#     "curvetype", "type", "show_cb", "raw_curves", "add_np_nn",
#     "show_legend", and "ret_grob", respectively.
#   ...: named options supplied by the caller.
#
# Returns a list holding everything given in ... plus, for each of the seven
# options above whose entry is NULL (i.e. not supplied), the matching default.
#
.get_autoplot_arglist <- function(def_curvetype, def_type, def_show_cb,
                                  def_raw_curves, def_add_np_nn,
                                  def_show_legend, def_ret_grob, ...) {
  # 'fallback' is only evaluated when the option is actually absent
  fill_default <- function(opts, key, fallback) {
    if (is.null(opts[[key]])) {
      opts[[key]] <- fallback
    }
    opts
  }

  opts <- list(...)
  opts <- fill_default(opts, "curvetype", def_curvetype)
  opts <- fill_default(opts, "type", def_type)
  opts <- fill_default(opts, "show_cb", def_show_cb)
  opts <- fill_default(opts, "raw_curves", def_raw_curves)
  opts <- fill_default(opts, "add_np_nn", def_add_np_nn)
  opts <- fill_default(opts, "show_legend", def_show_legend)
  opts <- fill_default(opts, "ret_grob", def_ret_grob)
  opts
}
#
# Prepare autoplot and return a data frame
#
# Arguments:
#   object: the object to plot; it is checked with .validate().
#   curve_df: an already fortified data frame, or NULL to build one with
#     ggplot2::fortify(object, ...).
#   curvetype: when not NULL, only rows whose 'curvetype' column equals this
#     value are kept.
#   ...: forwarded to ggplot2::fortify() when curve_df is NULL.
#
# Returns the (possibly filtered) data frame.
#
.prepare_autoplot <- function(object, curve_df = NULL, curvetype = NULL, ...) {
  # === Check package availability ===
  .load_ggplot2()

  # === Validate input arguments ===
  .validate(object)

  # === Prepare a data frame for ggplot2 ===
  if (is.null(curve_df)) {
    curve_df <- ggplot2::fortify(object, ...)
  }

  if (!is.null(curvetype)) {
    # Explicit column indexing avoids the non-standard evaluation of subset(),
    # where a missing 'curvetype' column would silently resolve to the
    # function argument of the same name and keep every row.
    if (!("curvetype" %in% names(curve_df))) {
      stop("curve_df must contain a 'curvetype' column.", call. = FALSE)
    }
    keep <- curve_df[["curvetype"]] == curvetype
    # Rows whose comparison is NA are dropped, as subset() would do
    curve_df <- curve_df[!is.na(keep) & keep, , drop = FALSE]
  }

  curve_df
}
#
# Ensure that the ggplot2 namespace can be loaded
#
# Returns NULL invisibly when ggplot2 is available; otherwise an error is
# signalled (without the call in the message) asking the user to install it.
#
.load_ggplot2 <- function() {
  if (requireNamespace("ggplot2", quietly = TRUE)) {
    return(invisible(NULL))
  }

  msg <- paste("This function should not be called directly,",
               "and ggplot2 is needed to work.",
               "Please install it.")
  stop(msg, call. = FALSE)
}
#
# Ensure that the grid namespace can be loaded
#
# Returns NULL invisibly when grid is available; otherwise an error is
# signalled (without the call in the message) asking the user to install it.
#
.load_grid <- function() {
  if (requireNamespace("grid", quietly = TRUE)) {
    return(invisible(NULL))
  }

  stop("grid needed for this function to work. Please install it.",
       call. = FALSE)
}
#
# Ensure that the gridExtra namespace can be loaded
#
# Returns NULL invisibly when gridExtra is available; otherwise an error is
# signalled (without the call in the message) asking the user to install it.
#
.load_gridExtra <- function() {
  if (requireNamespace("gridExtra", quietly = TRUE)) {
    return(invisible(NULL))
  }

  stop("gridExtra needed for this function to work. Please install it.",
       call. = FALSE)
}
#
# Plot one panel per requested curve type
#
# Arguments:
#   object: the object to plot.
#   arglist: a list with the entries "curvetype", "type", "show_cb",
#     "raw_curves", "add_np_nn", "show_legend", and "ret_grob".
#
# Every option is validated, the object is fortified once, and one plot is
# built per element of "curvetype" with .autoplot_single(). A single plot is
# returned as is; several plots are handed to .combine_plots().
#
.autoplot_multi <- function(object, arglist) {
  ctypes <- arglist[["curvetype"]]
  plot_type <- arglist[["type"]]
  with_cb <- arglist[["show_cb"]]
  use_raw <- arglist[["raw_curves"]]
  with_np_nn <- arglist[["add_np_nn"]]
  with_legend <- arglist[["show_legend"]]
  as_grob <- arglist[["ret_grob"]]

  # === Check package availability ===
  .load_ggplot2()

  # === Validate input arguments ===
  .validate(object)
  .check_curvetype(ctypes, object)
  .check_type(plot_type)
  .check_show_cb(with_cb, object)
  .check_raw_curves(use_raw, object)
  .check_show_legend(with_legend)
  .check_add_np_nn(with_np_nn)
  .check_ret_grob(as_grob)

  # === Create a ggplot object for each requested curve type ===
  # Fortify once and let every panel filter the shared data frame
  fortified <- ggplot2::fortify(object, raw_curves = use_raw)
  panels <- lapply(ctypes, function(ct) {
    .autoplot_single(object, fortified, curvetype = ct, type = plot_type,
                     show_cb = with_cb, raw_curves = use_raw,
                     show_legend = with_legend, add_np_nn = with_np_nn)
  })
  names(panels) <- ctypes

  if (length(panels) > 1) {
    combine_args <- c(panels, show_legend = with_legend, ret_grob = as_grob,
                      nplots = length(panels))
    do.call(.combine_plots, combine_args)
  } else {
    panels[[1]]
  }
}
#
# .grid_arrange_shared_legend
#
# Modified version of grid_arrange_shared_legend from RPubs
# URL of the original version:
#   http://rpubs.com/sjackman/grid_arrange_shared_legend
#
# Arguments:
#   ...: ggplot objects; the legend is taken from the first one.
#   main_ncol: number of columns used to lay out the plots.
#
# Returns the result of gridExtra::arrangeGrob() with the legend-free plots
# on top and the shared legend in a row below them.
#
.grid_arrange_shared_legend <- function(..., main_ncol = 2) {
  plots <- list(...)

  # Render the first plot with its legend at the bottom and pull out the
  # grob named "guide-box"
  first_plot <- plots[[1]] + ggplot2::theme(legend.position = "bottom")
  g <- ggplot2::ggplotGrob(first_plot)$grobs
  # vapply() keeps the result a character vector whatever the grobs contain;
  # a grob without a name yields NA, which which() ignores
  grob_names <- vapply(g, function(x) as.character(x$name)[1], character(1))
  legend <- g[[which(grob_names == "guide-box")]]
  # Use the exact element name instead of relying on partial matching of `$`
  lheight <- sum(legend$heights)

  fncol <- function(...) gridExtra::arrangeGrob(..., ncol = main_ncol)
  fnolegend <- function(x) x + ggplot2::theme(legend.position = "none")

  # The legend row gets its own height; the plots get the remaining space
  gridExtra::arrangeGrob(
    do.call(fncol, lapply(plots, fnolegend)),
    legend,
    heights = grid::unit.c(grid::unit(1, "npc") - lheight, lheight),
    ncol = 1)
}
#
# Combine several plots into one multiple-panel layout
#
# Arguments:
#   ...: the plots to arrange.
#   show_legend: if TRUE, the plots are arranged with a shared legend.
#   ret_grob: if TRUE, the arranged grob is returned; otherwise it is drawn
#     on a new plot page.
#   nplots: number of plots; 2 or 4 plots use two columns, any other number
#     uses three.
#
.combine_plots <- function(..., show_legend, ret_grob, nplots) {
  .load_grid()
  .load_gridExtra()

  n_columns <- if (nplots == 2 || nplots == 4) 2 else 3

  grobframe <- if (show_legend) {
    .grid_arrange_shared_legend(..., main_ncol = n_columns)
  } else {
    gridExtra::arrangeGrob(..., ncol = n_columns)
  }

  if (ret_grob) {
    grobframe
  } else {
    graphics::plot.new()
    grid::grid.draw(grobframe)
  }
}
#
# Plot a single panel (ROC, Precision-Recall, or a basic measure)
#
# Arguments:
#   object: the object to plot; its first element is passed to the
#     finishing function that adds title, axes, and theme.
#   curve_df: fortified data frame, or NULL to let .prepare_autoplot()
#     create it.
#   curvetype: curve type of this panel; "ROC" and "PRC" have dedicated
#     finishing functions, anything else is handled by .geom_basic_point().
#   type: "l" for lines, "p" for points, "b" for both.
#   show_cb: if TRUE (and raw_curves is FALSE), confidence bounds are drawn
#     from the 'ymin' and 'ymax' columns.
#   raw_curves: if TRUE, curves are grouped by the 'dsid_modname' column.
#   show_legend, add_np_nn: forwarded to the finishing function.
#   ...: forwarded to .prepare_autoplot() and the finishing function.
#
# Returns a ggplot object.
#
.autoplot_single <- function(object, curve_df, curvetype = "ROC", type = "l",
                             show_cb = FALSE, raw_curves = FALSE,
                             show_legend = FALSE, add_np_nn = TRUE, ...) {
  curve_df <- .prepare_autoplot(object, curve_df = curve_df,
                                curvetype = curvetype,
                                raw_curves = raw_curves, ...)

  # Line and/or point layers shared by the plots without confidence bounds
  add_curve_layers <- function(p) {
    if (type == "l") {
      p <- p + ggplot2::geom_line()
    } else if (type == "b" || type == "p") {
      if (type == "b") {
        # Faded lines keep the points readable
        p <- p + ggplot2::geom_line(alpha = 0.25)
      }
      p <- p + ggplot2::geom_point()
    }
    p
  }

  # === Create a ggplot object ===
  if (raw_curves) {
    p <- ggplot2::ggplot(curve_df,
                         ggplot2::aes_string(x = "x", y = "y",
                                             group = "dsid_modname",
                                             color = "modname"))
    p <- add_curve_layers(p)
  } else if (show_cb) {
    p <- ggplot2::ggplot(curve_df,
                         ggplot2::aes_string(x = "x", y = "y",
                                             ymin = "ymin", ymax = "ymax"))
    if (type == "l") {
      # With stat = "identity", geom_smooth uses y, ymin, and ymax as given
      p <- p + ggplot2::geom_smooth(ggplot2::aes_string(color = "modname"),
                                    stat = "identity")
    } else if (type == "b" || type == "p") {
      p <- p + ggplot2::geom_ribbon(ggplot2::aes_string(ymin = "ymin",
                                                        ymax = "ymax"),
                                    stat = "identity", alpha = 0.25,
                                    fill = "grey25")
      if (type == "b") {
        p <- p + ggplot2::geom_line(ggplot2::aes_string(color = "modname"),
                                    alpha = 0.25)
      }
      p <- p + ggplot2::geom_point(ggplot2::aes_string(x = "x", y = "y",
                                                       color = "modname"))
    }
  } else {
    p <- ggplot2::ggplot(curve_df,
                         ggplot2::aes_string(x = "x", y = "y",
                                             color = "modname"))
    p <- add_curve_layers(p)
  }

  # === Add title, axes, and theme for the curve type ===
  if (curvetype == "ROC") {
    func_g <- .geom_basic_roc
  } else if (curvetype == "PRC") {
    func_g <- .geom_basic_prc
  } else {
    func_g <- .geom_basic_point
  }

  func_g(p, object[[1]], show_legend = show_legend, add_np_nn = add_np_nn,
         curve_df = curve_df, ...)
}
#
# Apply the common theme, title, and axis labels
#
# Arguments:
#   p: a ggplot object.
#   main: plot title.
#   xlab, ylab: axis labels.
#   show_legend: if FALSE, the legend is hidden.
#
# Returns p with the black-and-white theme, title, axis labels, and a blank
# legend title added.
#
.geom_basic <- function(p, main, xlab, ylab, show_legend) {
  styled <- p +
    ggplot2::theme_bw() +
    ggplot2::ggtitle(main) +
    ggplot2::xlab(xlab) +
    ggplot2::ylab(ylab) +
    ggplot2::theme(legend.title = ggplot2::element_blank())

  if (!show_legend) {
    styled <- styled + ggplot2::theme(legend.position = "none")
  }

  styled
}
#
# Make main title
#
# Arguments:
#   object: an object whose "data_info" attribute holds the entries "np"
#     and "nn".
#   pt: the leading part of the title, e.g. "ROC".
#
# Returns a string of the form "<pt> - P: <np>, N: <nn>".
#
.make_rocprc_title <- function(object, pt) {
  data_info <- attr(object, "data_info")

  # Return the title as a plain (visible) value rather than ending the
  # function with an assignment
  paste0(pt, " - P: ", data_info[["np"]], ", N: ", data_info[["nn"]])
}
#
# Geom basic for ROC
#
# Arguments:
#   p: a ggplot object.
#   object: passed to .make_rocprc_title() when add_np_nn is TRUE.
#   show_legend: forwarded to .geom_basic().
#   add_np_nn: if TRUE, the title is built by .make_rocprc_title();
#     otherwise it is just "ROC".
#   ...: ignored.
#
# Returns p with a dotted grey diagonal, a fixed 1:1 aspect ratio, and the
# common theme and labels.
#
.geom_basic_roc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            ...) {
  main <- if (add_np_nn) .make_rocprc_title(object, "ROC") else "ROC"

  diagonal <- ggplot2::geom_abline(intercept = 0, slope = 1, colour = "grey",
                                   linetype = 3)
  p <- p + diagonal + ggplot2::coord_fixed(ratio = 1)

  .geom_basic(p, main, "1 - Specificity", "Sensitivity", show_legend)
}
#
# Geom basic for Precision-Recall
#
# Arguments:
#   p: a ggplot object.
#   object: an object whose "data_info" attribute holds "np" and "nn".
#   show_legend: forwarded to .geom_basic().
#   add_np_nn: if TRUE, the title is built by .make_rocprc_title();
#     otherwise it is just "Precision-Recall".
#   ...: ignored.
#
# Returns p with a dotted grey horizontal line at np / (np + nn), the y axis
# limited to [0, 1], a fixed 1:1 aspect ratio, and the common theme and
# labels.
#
.geom_basic_prc <- function(p, object, show_legend = TRUE, add_np_nn = TRUE,
                            ...) {
  main <- if (add_np_nn) {
    .make_rocprc_title(object, "Precision-Recall")
  } else {
    "Precision-Recall"
  }

  data_info <- attr(object, "data_info")
  np <- data_info[["np"]]
  nn <- data_info[["nn"]]

  baseline <- ggplot2::geom_hline(yintercept = np / (np + nn),
                                  colour = "grey", linetype = 3)
  p <- p +
    baseline +
    ggplot2::scale_y_continuous(limits = c(0.0, 1.0)) +
    ggplot2::coord_fixed(ratio = 1)

  .geom_basic(p, main, "Recall", "Precision", show_legend)
}
#
# Geom basic for basic evaluation measures
#
# Arguments:
#   p: a ggplot object.
#   object: not used here; kept so that the signature matches the ROC and
#     Precision-Recall counterparts.
#   show_legend: forwarded to .geom_basic().
#   curve_df: data frame with a 'curvetype' column; its first value names the
#     measure and must be supplied.
#   ...: ignored.
#
# Returns p with the y axis limited to [0, 1], a fixed 1:1 aspect ratio, the
# capitalised measure name as title, "threshold" as the x label, and the
# measure name as the y label.
#
.geom_basic_point <- function(p, object, show_legend = TRUE,
                              curve_df = NULL, ...) {
  # A self-referential default (curve_df = curve_df) fails with an obscure
  # "promise already under evaluation" error, so report a missing value
  # explicitly
  if (is.null(curve_df)) {
    stop("curve_df must be supplied.", call. = FALSE)
  }

  # as.character() keeps the labels plain strings even for a factor column
  s <- as.character(curve_df[["curvetype"]][1])
  main <- paste0(toupper(substring(s, 1, 1)), substring(s, 2))

  p <- p + ggplot2::scale_y_continuous(limits = c(0.0, 1.0))
  p <- p + ggplot2::coord_fixed(ratio = 1)

  .geom_basic(p, main, "threshold", s, show_legend)
}