# https://github.com/cran/brms
# Raw File
# Tip revision: 8588ac33825f2c759d020f917a29eaea2a0fc073 authored by Paul-Christian Bürkner on 19 January 2020, 19:00:21 UTC
# version 2.11.1
# Tip revision: 8588ac3
# model_weights.R
#' Model Weighting Methods
#' 
#' Compute model weights in various ways, for instance via
#' stacking of predictive distributions, Akaike weights, or
#' marginal likelihoods.
#' 
#' @inheritParams loo.brmsfit
#' @param weights Name of the criterion to compute weights from. Should be one
#'   of \code{"loo"}, \code{"waic"}, \code{"kfold"}, \code{"stacking"} (current
#'   default), \code{"pseudobma"}, or \code{"bma"}. For the former three
#'   options, Akaike weights will be computed based on the information criterion
#'   values returned by the respective methods. For the \code{"stacking"} and
#'   \code{"pseudobma"} methods, \code{\link{loo_model_weights}} will be used to
#'   obtain weights. For \code{"bma"}, method \code{\link{post_prob}} will be
#'   used to compute Bayesian model averaging weights based on log marginal
#'   likelihood values (make sure to specify reasonable priors in this case).
#'   For some methods, \code{weights} may also be a numeric vector of
#'   pre-specified weights.
#'   
#' @return A numeric vector of weights for the models, normalized to sum to
#'   one and named after the models.
#'   
#' @examples 
#' \dontrun{
#' # model with 'treat' as predictor
#' fit1 <- brm(rating ~ treat + period + carry, data = inhaler)
#' summary(fit1)
#' 
#' # model without 'treat' as predictor
#' fit2 <- brm(rating ~ period + carry, data = inhaler)
#' summary(fit2)
#' 
#' # obtain Akaike weights based on the WAIC
#' model_weights(fit1, fit2, weights = "waic")
#' }
#' 
#' @export
model_weights.brmsfit <- function(x, ..., weights = "stacking", 
                                  model_names = NULL) {
  weights <- validate_weights_method(weights)
  # separate the fitted models from further arguments passed via '...'
  args <- split_dots(x, ..., model_names = model_names)
  models <- args$models
  args$models <- NULL
  model_names <- names(models)
  if (weights %in% c("loo", "waic", "kfold")) {
    # Akaike weights: w_i proportional to exp(-(IC_i - min(IC)) / 2)
    ics <- rep(NA_real_, length(models))
    for (i in seq_along(ics)) {
      args$x <- models[[i]]
      args$model_names <- model_names[i]
      # NOTE(review): assumes row 3, column 1 of 'estimates' holds the
      # information criterion point estimate (e.g. 'looic'), as laid out
      # by the loo package -- confirm if the loo output format changes
      ics[i] <- SW(do_call(weights, args))$estimates[3, 1]
    }
    # subtracting the minimum keeps exp() from underflowing; the constant
    # factor cancels in the normalization below
    ic_diffs <- ics - min(ics)
    out <- exp(-ic_diffs / 2)
  } else if (weights %in% c("stacking", "pseudobma")) {
    args <- c(unname(models), args)
    args$method <- weights
    out <- do_call("loo_model_weights", args)
  } else if (weights == "bma") {
    args <- c(unname(models), args)
    out <- do_call("post_prob", args)
  }
  out <- as.numeric(out)
  out <- out / sum(out)
  names(out) <- model_names
  out
}

#' @rdname model_weights.brmsfit
#' @export
model_weights <- function(x, ...) UseMethod("model_weights")

# validate name of the applied weighting method
# @param method A single character string naming the weighting method.
#   Matching is case-insensitive. The deprecated aliases 'loo2' and 'marglik'
#   are mapped to 'stacking' and 'bma', respectively, with a warning.
# @return The method name as resolved by match.arg(), i.e. one of
#   "loo", "waic", "kfold", "stacking", "pseudobma", or "bma"
validate_weights_method <- function(method) {
  method <- tolower(as_one_character(method))
  if (method == "loo2") {
    warning2("Weight method 'loo2' is deprecated. Use 'stacking' instead.")
    method <- "stacking"
  } else if (method == "marglik") {
    warning2("Weight method 'marglik' is deprecated. Use 'bma' instead.")
    method <- "bma"
  }
  valid_methods <- c("loo", "waic", "kfold", "stacking", "pseudobma", "bma")
  match.arg(method, valid_methods)
}

#' Posterior predictive samples averaged across models
#' 
#' Compute posterior predictive samples averaged across models.
#' Weighting can be done in various ways, for instance using
#' Akaike weights based on information criteria or 
#' marginal likelihoods.
#' 
#' @inheritParams model_weights.brmsfit
#' @param method Method used to obtain predictions to average over. Should be
#'   one of \code{"posterior_predict"} (default), \code{"pp_expect"}, or
#'   \code{"predictive_error"}.
#' @param control Optional \code{list} of further arguments 
#'   passed on to the weighting method specified in \code{weights}.
#' @param nsamples Total number of posterior samples to use. These are split
#'   across models in proportion to their weights. Defaults to the number of
#'   posterior samples of the first model.
#' @param seed A single numeric value passed to \code{\link{set.seed}}
#'   to make results reproducible.
#' @param summary Should summary statistics 
#'   (i.e. means, sds, and 95\% intervals) be returned
#'  instead of the raw values? Default is \code{TRUE}.
#' @param robust If \code{FALSE} (the default) the mean is used as 
#'  the measure of central tendency and the standard deviation as 
#'  the measure of variability. If \code{TRUE}, the median and the 
#'  median absolute deviation (MAD) are applied instead.
#'  Only used if \code{summary} is \code{TRUE}.
#' @param probs The percentiles to be computed by the \code{quantile} 
#'  function. Only used if \code{summary} is \code{TRUE}. 
#' 
#' @return Same as the output of the method specified 
#'   in argument \code{method}, with the samples of all models bound
#'   together row-wise (summarized if \code{summary} is \code{TRUE}).
#'   Attributes \code{"weights"} and \code{"nsamples"} hold the normalized
#'   model weights and the number of samples taken from each model.
#'   
#' @details Weights are computed with the \code{\link{model_weights}} method.
#'   
#' @seealso \code{\link{model_weights}}, \code{\link{posterior_average}}
#'   
#' @examples 
#' \dontrun{
#' # model with 'treat' as predictor
#' fit1 <- brm(rating ~ treat + period + carry, data = inhaler)
#' summary(fit1)
#' 
#' # model without 'treat' as predictor
#' fit2 <- brm(rating ~ period + carry, data = inhaler)
#' summary(fit2)
#' 
#' # compute model-averaged predicted values
#' (df <- unique(inhaler[, c("treat", "period", "carry")]))
#' pp_average(fit1, fit2, newdata = df)
#' 
#' # compute model-averaged fitted values
#' pp_average(fit1, fit2, method = "fitted", newdata = df)
#' }
#' 
#' @export
pp_average.brmsfit <- function(
  x, ..., weights = "stacking", method = "posterior_predict",
  nsamples = NULL, summary = TRUE, probs = c(0.025, 0.975), robust = FALSE,
  model_names = NULL, control = list(), seed = NULL
) {
  if (!is.null(seed)) {
    set.seed(seed)
  }
  method <- validate_pp_method(method)
  # the per-model number of samples is controlled internally
  if ("subset" %in% names(list(...))) {
    stop2("Cannot use argument 'subset' in pp_average.")
  }
  pred_args <- split_dots(x, ..., model_names = model_names)
  fits <- pred_args$models
  pred_args$models <- NULL
  # raw samples are needed for averaging; summaries are computed at the end
  pred_args$summary <- FALSE
  if (!match_response(fits)) {
    stop2("Can only average models predicting the same response.")
  }
  total_nsamples <- nsamples
  if (is.null(total_nsamples)) {
    total_nsamples <- nsamples(fits[[1]])
  }
  weights <- validate_weights(weights, fits, control)
  # split the total number of samples across models according to the weights
  model_nsamples <- round_largest_remainder(weights * total_nsamples)
  names(weights) <- names(model_nsamples) <- names(fits)
  draws <- named_list(names(fits))
  for (i in seq_along(fits)) {
    # models without allocated samples keep a NULL entry, which rbind() drops
    if (!(model_nsamples[i] > 0)) {
      next
    }
    pred_args$object <- fits[[i]]
    pred_args$nsamples <- model_nsamples[i]
    draws[[i]] <- do_call(method, pred_args)
  }
  draws <- do_call(rbind, draws)
  if (summary) {
    draws <- posterior_summary(draws, probs = probs, robust = robust)
  }
  attr(draws, "weights") <- weights
  attr(draws, "nsamples") <- model_nsamples
  draws
}

#' @rdname pp_average.brmsfit
#' @export
pp_average <- function(x, ...) UseMethod("pp_average")

# validate weights passed to model averaging functions
# see pp_average.brmsfit for more documentation
# @param weights Either the name of a weighting method passed on to
#   model_weights() or a numeric vector with one non-negative weight per model
# @param models A list of fitted model objects
# @param control A list of further arguments passed to model_weights()
#   (only used if 'weights' is not numeric)
# @return A numeric vector of weights normalized to sum to one
validate_weights <- function(weights, models, control = list()) {
  if (!is.numeric(weights)) {
    weight_args <- c(unname(models), control)
    weight_args$weights <- weights
    weights <- do_call(model_weights, weight_args)
  } else {
    if (length(weights) != length(models)) {
      stop2("If numeric, 'weights' must have the same length ",
            "as the number of models.")
    }
    # check this first so that NA/NaN/Inf never reach the comparisons below,
    # where 'if (NA)' would fail with an uninformative error
    if (!all(is.finite(weights))) {
      stop2("If numeric, 'weights' must be finite.")
    }
    # zero weights are allowed: callers skip models that receive no samples
    if (any(weights < 0)) {
      stop2("If numeric, 'weights' must be non-negative.")
    }
    # normalizing all-zero weights would divide by zero and yield NaN
    if (!any(weights > 0)) {
      stop2("If numeric, at least one element of 'weights' must be positive.")
    }
  }
  weights / sum(weights)
}

#' Posterior samples of parameters averaged across models
#' 
#' Extract posterior samples of parameters averaged across models.
#' Weighting can be done in various ways, for instance using
#' Akaike weights based on information criteria or 
#' marginal likelihoods.
#' 
#' @inheritParams pp_average.brmsfit
#' @param pars Names of parameters for which to average across models.
#'   Only those parameters can be averaged that appear in every model,
#'   unless \code{missing} is specified. Defaults to all overlapping
#'   parameters (or to all parameters if \code{missing} is specified),
#'   excluding \code{lp__}.
#' @param missing An optional numeric value or a named list of numeric values 
#'   to use if a model does not contain a parameter for which posterior samples 
#'   should be averaged. Defaults to \code{NULL}, in which case only those
#'   parameters can be averaged that are present in all of the models.
#' 
#' @return A \code{data.frame} of posterior samples. Samples are rows
#'   and parameters are columns. Attributes \code{"weights"} and
#'   \code{"nsamples"} hold the normalized model weights and the number of
#'   samples taken from each model.
#' 
#' @details Weights are computed with the \code{\link{model_weights}} method.
#' 
#' @seealso \code{\link{model_weights}}, \code{\link{pp_average}}
#'   
#' @examples 
#' \dontrun{
#' # model with 'treat' as predictor
#' fit1 <- brm(rating ~ treat + period + carry, data = inhaler)
#' summary(fit1)
#' 
#' # model without 'treat' as predictor
#' fit2 <- brm(rating ~ period + carry, data = inhaler)
#' summary(fit2)
#' 
#' # compute model-averaged posteriors of overlapping parameters
#' posterior_average(fit1, fit2, weights = "waic")
#' }
#' 
#' @export
posterior_average.brmsfit <- function(
  x, ..., pars = NULL, weights = "stacking", nsamples = NULL,
  missing = NULL, model_names = NULL, control = list(),
  seed = NULL
) {
  if (!is.null(seed)) {
    set.seed(seed) 
  }
  # list of the models passed via 'x' and '...' (other = FALSE: see split_dots)
  models <- split_dots(x, ..., model_names = model_names, other = FALSE)
  # parameter names of each model and their union across models
  pars_list <- lapply(models, parnames)
  all_pars <- unique(unlist(pars_list))
  if (is.null(missing)) {
    # without 'missing', only parameters present in every model are allowed
    common_pars <- lapply(pars_list, function(x) all_pars %in% x)
    common_pars <- all_pars[Reduce("&", common_pars)]
    if (is.null(pars)) {
      pars <- setdiff(common_pars, "lp__")
    }
    pars <- as.character(pars)
    inv_pars <- setdiff(pars, common_pars)
    if (length(inv_pars)) {
      inv_pars <- collapse_comma(inv_pars)
      stop2(
        "Parameters ", inv_pars, " cannot be found in all ", 
        "of the models. Consider using argument 'missing'."
      )
    }
  } else {
    # with 'missing', a parameter only needs to be present in some model
    if (is.null(pars)) {
      pars <- setdiff(all_pars, "lp__")
    }
    pars <- as.character(pars)
    inv_pars <- setdiff(pars, all_pars)
    if (length(inv_pars)) {
      inv_pars <- collapse_comma(inv_pars)
      stop2("Parameters ", inv_pars, " cannot be found in any of the models.")
    }
    # normalize 'missing' to a named list of single numeric values
    if (is.list(missing)) {
      # every parameter absent from at least one model needs a fill value
      all_miss_pars <- unique(ulapply(
        models, function(m) setdiff(pars, parnames(m))
      ))
      inv_pars <- setdiff(all_miss_pars, names(missing))
      if (length(inv_pars)) {
        stop2("Argument 'missing' has no value for parameters ",
              collapse_comma(inv_pars), ".")
      }
      missing <- lapply(missing, as_one_numeric, allow_na = TRUE)
    } else {
      # a single value is used as fill value for all parameters
      missing <- as_one_numeric(missing, allow_na = TRUE)
      missing <- named_list(pars, missing)
    }
  }
  # total number of samples defaults to that of the first model
  if (is.null(nsamples)) {
    nsamples <- nsamples(models[[1]])
  }
  weights <- validate_weights(weights, models, control)
  # split the total number of samples across models according to the weights
  nsamples <- round_largest_remainder(weights * nsamples)
  names(weights) <- names(nsamples) <- names(models)
  out <- named_list(names(models))
  # models with zero allocated samples keep a NULL entry (dropped by rbind)
  for (i in seq_along(out)) {
    if (nsamples[i] > 0) {
      # random subset of draws (without replacement) from this model;
      # the order of these sample() calls determines the result for a seed
      # NOTE(review): sample() errors if nsamples[i] exceeds the number of
      # draws of this model, e.g. when models differ in their number of draws
      subset <- sample(seq_len(nsamples(models[[i]])), nsamples[i])
      subset <- sort(subset)
      ps <- posterior_samples(
        models[[i]], pars = pars, 
        subset = subset, fixed = TRUE
      )
      if (!is.null(ps)) {
        out[[i]] <- ps
      } else {
        # model contains none of 'pars': start from an empty data.frame
        # with the right number of rows so fill values can be added below
        out[[i]] <- as.data.frame(matrix(
          numeric(0), nrow = nsamples[i], ncol = 0
        ))
      }
      if (!is.null(missing)) {
        # fill parameters absent from this model with their 'missing' values
        miss_pars <- setdiff(pars, names(out[[i]]))
        if (length(miss_pars)) {
          out[[i]][miss_pars] <- missing[miss_pars]
        }
      }
    }
  }
  out <- do_call(rbind, out)
  rownames(out) <- NULL
  attr(out, "weights") <- weights
  attr(out, "nsamples") <- nsamples
  out
}

#' @rdname posterior_average.brmsfit
#' @export
posterior_average <- function(x, ...) UseMethod("posterior_average")
# back to top