https://github.com/cran/cutpointr
Tip revision: 7e56c827a694247d212e9a0167a119f917e1f31b authored by Christian Thiele on 31 August 2018, 15:50:10 UTC
version 0.7.4
version 0.7.4
Tip revision: 7e56c82
plot_precision_recall.R
#' Precision recall plot from a cutpointr object
#'
#' Given a \code{cutpointr} object this function plots the precision recall
#' curve(s), one curve per subgroup if subgroups are present.
#' @param x A cutpointr object.
#' @param display_cutpoint (logical) Whether or not to display the optimal
#' cutpoint as a dot on the precision recall curve.
#' @param ... Additional arguments (unused).
#' @return A \code{ggplot} object with Recall on the x axis and Precision on
#' the y axis, colored by subgroup if \code{x} has a \code{subgroup} column.
#' @details Precision is computed from each ROC curve as
#' \code{tp / (tp + fp)} and Recall as \code{tp / (tp + fn)}. Rows of the ROC
#' curve whose cutpoint (\code{x.sorted}) is not finite are not drawn.
#' @examples
#' library(cutpointr)
#'
#' ## Optimal cutpoint for dsi
#' data(suicide)
#' opt_cut <- cutpointr(suicide, dsi, suicide)
#' plot_precision_recall(opt_cut)
#' @family cutpointr plotting functions
#' @export
plot_precision_recall <- function(x, display_cutpoint = TRUE, ...) {
  stopifnot(inherits(x, "cutpointr"))
  has_subgroup <- has_column(x, "subgroup")
  if (has_subgroup) {
    clr_pr <- ~ subgroup
    plot_title <- ggplot2::ggtitle("Precision recall plot", "by class")
  } else {
    clr_pr <- NULL
    plot_title <- ggplot2::ggtitle("Precision recall plot")
  }

  # Add Precision and Recall to every ROC curve. Plain column arithmetic
  # replaces the deprecated standard-evaluation verb dplyr::mutate_().
  # seq_len() keeps this safe when x has zero rows.
  curves <- lapply(seq_len(nrow(x)), function(i) {
    rc <- x$roc_curve[[i]]
    rc$Precision <- rc$tp / (rc$tp + rc$fp)
    rc$Recall <- rc$tp / (rc$tp + rc$fn)
    rc
  })

  if (display_cutpoint) {
    # Iterate over row indices instead of apply(), which would coerce the
    # data frame to a matrix. [[i]] extracts a single cell whether the column
    # is atomic or a list column.
    optcut_coords <- lapply(seq_len(nrow(x)), function(i) {
      opt_ind <- get_opt_ind(roc_curve = curves[[i]],
                             oc = x$optimal_cutpoint[[i]],
                             direction = x$direction[[i]])
      data.frame(Precision = curves[[i]]$Precision[opt_ind],
                 Recall = curves[[i]]$Recall[opt_ind])
    })
    optcut_coords <- do.call(rbind, optcut_coords)
  }

  # Stack the ROC curves into one long data frame. This replaces the
  # deprecated select_() / unnest_() pair. When subgroups exist, each curve is
  # tagged with its subgroup so it can be mapped to color.
  if (has_subgroup) {
    curves <- lapply(seq_along(curves), function(i) {
      rc <- curves[[i]]
      rc$subgroup <- rep(x$subgroup[[i]], nrow(rc))
      rc
    })
  }
  res_unnested <- dplyr::bind_rows(curves)
  # Drop rows with a non-finite cutpoint (e.g. the +/-Inf ends of the curve).
  res_unnested <- res_unnested[is.finite(res_unnested$x.sorted), ]

  pr <- ggplot2::ggplot(res_unnested,
                        ggplot2::aes_(x = ~ Recall, y = ~ Precision,
                                      color = clr_pr)) +
    ggplot2::geom_line() +
    plot_title +
    ggplot2::xlab("Recall") +
    ggplot2::ylab("Precision") +
    ggplot2::theme(aspect.ratio = 1) +
    ggplot2::coord_cartesian(xlim = c(0, 1), ylim = c(0, 1))
  if (display_cutpoint) {
    # A fixed color parameter overrides the inherited subgroup color mapping,
    # so optcut_coords does not need a subgroup column.
    pr <- pr + ggplot2::geom_point(data = optcut_coords, color = "black")
  }
  pr
}