Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

  • 901326c
  • /
  • modular_diags.R
Raw File Download

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
content badge
swh:1:cnt:e6a06eb33c915211acdedd03bdb75aaa282e6f6c
directory badge
swh:1:dir:901326ca4b947cf045195ee799572e4832d86525

This interface generates software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTex format (requires biblatex-software package)
Generating citation ...
modular_diags.R
# Helper Functions --------------------------------------------------------

# Call `.f` once per position across the parallel elements of `.l`.
#
# A thin wrapper around mapply() for the case where the vectorised arguments
# are already collected in a list.
#   .l   list whose elements are iterated over in parallel
#   .f   function called once per position
#   ...  extra arguments handed unchanged to every call (via MoreArgs)
# Returns a list with one entry per position; results are never simplified.
papply <- function(.l, .f, ...) {
  constant_args <- list(...)
  mapply_args <- c(
    .l,
    list(FUN = .f, MoreArgs = constant_args, SIMPLIFY = FALSE)
  )
  do.call(mapply, mapply_args)
}

# Determine whether a list of fitted chains is 'fixed' or 'random' effects.
#
# The type is read from the name of the function recorded in each object's
# call (`HLSMfixedEF` -> 'fixed', `HLSMrandomEF` -> 'random').
#   object_list  list of fitted objects for which getCall() returns the call
# Returns the single type shared by all chains. Stops if the list is empty,
# if the chains disagree, or if the function name is not recognised.
get_HLSM_type <- function(object_list) {
  if (length(object_list) == 0) {
    stop("HLSM list must contain at least one chain.")
  }

  fit_funcs <- vapply(object_list, function(obj) {
    fit_call <- getCall(obj)
    if (!is.call(fit_call)) {
      stop("Unknown HLSM type found in object.")
    }
    # The head of the call is a bare symbol for `HLSMrandomEF(...)` but a `::`
    # call for `HLSM::HLSMrandomEF(...)`. as.character() on the latter yields
    # c("::", "HLSM", "HLSMrandomEF"), so the last element is the function
    # name in both cases.
    head_chr <- as.character(fit_call[[1]])
    head_chr[length(head_chr)]
  }, character(1))

  type <- unique(sub('^HLSM(.*)EF$', '\\1', fit_funcs))
  if (length(type) > 1) {
    stop("HLSM list must be all of the same type.")
  } else if (!(type %in% c('fixed', 'random'))) {
    stop("Unknown HLSM type found in object.")
  }
  type
}


# Draws Extraction & Conversion --------------------------------------------

# Collect the intercept and slope draws of one chain into a single array.
#
#   chain   fitted chain passed to getIntercept() and getBeta()
#   type    'random' or 'fixed'; any other value is an error
#   burnin  forwarded to getIntercept() / getBeta()
#   thin    forwarded to getIntercept() / getBeta()
# Returns the result of binding the intercept array in front of the slope
# array along the second ("variables") axis.
extract_param <- function(chain, type, burnin = 0, thin = 1) {
  # For a random-effects fit the intercept draws are niter x nnet and the
  # slope draws are niter x nvar x nnet.
  icpt_draws <- getIntercept(chain, burnin = burnin, thin = thin)
  slope_draws <- getBeta(chain, burnin = burnin, thin = thin)

  # Labels for the slope array: iterations are numbered, variables are X1..Xp.
  slope_dim <- dim(slope_draws)
  slope_labels <- list(
    iterations = seq_len(slope_dim[1]),
    variables = paste0('X', seq_len(slope_dim[2]))
  )

  # The intercept takes the slope array's shape with the variable axis
  # collapsed to length one, so both arrays can be bound along that axis.
  icpt_dim <- slope_dim
  icpt_dim[2] <- 1
  icpt_labels <- list(
    iterations = seq_len(icpt_dim[1]),
    variables = 'Intercept'
  )

  # Only the random-effects case gets a labelled network axis; a fixed-effects
  # fit is shared across networks and needs no extra labels.
  if (type == "random") {
    slope_labels$network <- paste0('Net', seq_len(slope_dim[3]))
    icpt_labels$network <- paste0('Net', seq_len(icpt_dim[3]))
  } else if (type != "fixed") {
    stop("Type must be either 'fixed' or 'random'")
  }

  # Re-wrap both sets of draws with their final shape and labels, then put
  # the intercept first on the variable axis.
  slope_array <- array(slope_draws, dim = slope_dim, dimnames = slope_labels)
  icpt_array <- array(icpt_draws, dim = icpt_dim, dimnames = icpt_labels)
  abind(icpt_array, slope_array, along = 2)
}

# Convert a parameter array into an mcmc object.
#
# The array is first flattened with as.data.frame() and the resulting data
# frame is then handed to as.mcmc().
param_to_mcmc <- function(param) {
  as.mcmc(as.data.frame(param))
}


# PSRF Functions ----------------------------------------------------------

# Run the Gelman-Rubin diagnostic when more than one chain is available.
#
#   param_mcmc_list  mcmc.list of chains
#   warn_1chain      warn when only a single chain was supplied?
# Returns the gelman.diag() result (computed with autoburnin = FALSE), or
# NULL when there is only one chain.
psrf_param <- function(param_mcmc_list, warn_1chain = TRUE) {
  has_multiple_chains <- nchain(param_mcmc_list) > 1
  if (has_multiple_chains) {
    return(gelman.diag(param_mcmc_list, autoburnin = FALSE))
  }

  # A between-chain diagnostic cannot be computed from a single chain.
  if (warn_1chain) {
    warning("You have only provided one chain. PSRF is not available and will ",
            "not be included in the output.")
  }
  NULL
}

# Print a short summary of a Gelman-Rubin (PSRF) diagnostic.
#
#   result  list with a `psrf` matrix (columns 'Point est.' and 'Upper C.I.')
#           and an `mpsrf` value
# Prints the rows whose point estimate exceeds 1.05 or, when there are none,
# the single row with the largest upper limit, followed by the largest upper
# limit and the multivariate PSRF. Missing (NA/NaN) PSRF values are ignored
# rather than aborting the summary.
psrf_summary <- function(result) {
  psrf_table <- result$psrf
  point_est <- psrf_table[, 'Point est.']
  upper_ci <- psrf_table[, 'Upper C.I.']

  # A missing estimate can neither be flagged nor cleared, so treat it as
  # "not flagged" instead of letting NA reach the `if` below.
  flagged <- !is.na(point_est) & point_est > 1.05
  if (!any(flagged)) {
    flagged <- which.max(upper_ci)
  }
  # if only 1 row, doesn't reduce to vector
  bad_psrf_table <- psrf_table[flagged, , drop = FALSE]

  max_lim_psrf <- if (all(is.na(upper_ci))) {
    NA_real_
  } else {
    max(upper_ci, na.rm = TRUE)
  }
  multi_psrf <- result$mpsrf

  cat("Potential Scale Reduction Factor:\n")
  cat("Gelman-Rubin between-chain convergence diagnostic.\n")
  cat("Upper C.I. near 1 indicates convergence.\n")
  cat("---\n")
  cat("Variable(s) with worst convergence\n")
  print(bad_psrf_table)
  cat("---\n")
  cat("Maximum Upper C.I.: ", round(max_lim_psrf, 4), '\n')
  cat("Multivariate PSRF Point Estimate: ", multi_psrf, '\n')
}

# Raftery Functions -------------------------------------------------------

# Run the Raftery-Lewis diagnostic on a single chain.
#
#   param_mcmc  one chain, with iterations in rows
# Returns a data frame with one row per variable and columns `burnin`,
# `niters` and `thinning`, or NULL (with a warning) when the chain is shorter
# than the 3746 iterations the diagnostic needs.
# NOTE(review): coda documents the resmatrix columns as M, N, Nmin and I,
# where I is the dependence factor; after Nmin is dropped, I is the column
# exposed here as 'thinning' -- confirm that is the intended meaning.
raftery_param <- function(param_mcmc) {
  min_length <- 3746
  # NROW() also copes with a chain stored without a dim attribute. A chain of
  # exactly `min_length` iterations is long enough, hence `<` rather than `<=`.
  if (NROW(param_mcmc) < min_length) {
    warning(
      "The chain length is less than the raftery diagnostic minimum length of ",
      "3746.\n",
      "If you would like the raftery diagnostic information, ensure the chain ",
      "length is >= 3746 iterations."
    )
    return(NULL)
  }

  result <- as.data.frame(raftery.diag(param_mcmc)$resmatrix)
  result$Nmin <- NULL
  colnames(result) <- c("burnin", "niters", "thinning")
  result
}

# Print the most conservative Raftery-Lewis suggestion for each chain.
#
#   results  list with one entry per chain: a table of per-variable
#            suggestions (columns burnin, niters, thinning), or NULL when no
#            diagnostic is available for that chain
# For every chain with a table, the column-wise maximum over variables is
# printed in the order niters, burnin, thinning. Chains without a table are
# skipped, and row labels keep the original chain numbers.
raftery_summary <- function(results) {
  available <- !vapply(results, is.null, logical(1))
  if (!any(available)) {
    cat("Raftery Diagnostics:\n")
    cat("No Raftery diagnostics are available for these chains.\n")
    return(invisible(NULL))
  }

  longest_stats <- lapply(results[available], apply, 2, max)
  chain_stats <- do.call(rbind, longest_stats)
  rownames(chain_stats) <- paste("Chain", which(available))
  cat("Raftery Diagnostics:\n")
  cat("Suggested Chain Specifications\n")
  # drop = FALSE keeps the "Chain i" row label when only one chain is shown
  print(chain_stats[, c(2, 1, 3), drop = FALSE])
}


# Plotting Functions ------------------------------------------------------

# Work out the plot grid and which variables to plot.
#
#   param_mcmc  mcmc or mcmc.list object whose varnames() are inspected
#   type        'fixed' or 'random'; any other value is an error
# Returns a list with
#   nrows, ncols  grid size intended for par(mfrow = )
#   plotdex       indices of the variables to plot, in plotting order
# For 'random' fits the variable names are expected to look like
# "<variable>.Net<k>", with the variable index varying fastest.
plot_shape <- function(param_mcmc, type) {
  if (type == "fixed") {
    n_vars <- length(varnames(param_mcmc))
    nrows <- min(4, n_vars)
    ncols <- 1
    plotdex <- seq_len(n_vars)
  } else if (type == "random") {
    # Determine the number of distinct variables and networks in the mcmc
    # object. The separator is matched literally.
    vars <- varnames(param_mcmc)
    var_splits <- strsplit(vars, '.Net', fixed = TRUE)
    n_uvars <- length(unique(vapply(var_splits, `[`, character(1), 1)))
    n_unets <- length(unique(vapply(var_splits, `[`, character(1), 2)))

    # Each grid row holds one variable and each grid column one network, so
    # rows are capped by the variable count and columns by the network count.
    nrows <- min(4, n_uvars)
    ncols <- min(3, n_unets)

    # Matrix of variable indices with one row per network and one column per
    # variable. Pick `ncols` networks spread evenly over all networks, then
    # read the indices off column by column so that each variable's chosen
    # networks are plotted next to each other.
    big_dex_mat <- t(matrix(seq_len(n_uvars * n_unets), ncol = n_unets))
    netdex <- floor(seq(1, n_unets, length.out = ncols))
    dex_mat <- big_dex_mat[netdex, , drop = FALSE]
    plotdex <- as.vector(dex_mat)
  } else {
    stop("Type must be either 'fixed' or 'random'.")
  }

  list(nrows = nrows, ncols = ncols, plotdex = plotdex)
}

# Compute the autocorrelation function of every variable in one chain.
#
#   param_mcmc  one chain, with iterations in rows and variables in columns
# Returns list(lag = , acf = ): two lists with one entry per variable,
# holding the `lag` and `acf` components of the corresponding acf() result.
param_get_acf <- function(param_mcmc) {
  series <- apply(param_mcmc, 2, as.ts)
  acf_by_var <- apply(series, 2, function(draws) acf(draws, plot = FALSE))

  # Pull the same named component out of every per-variable acf result.
  pluck_field <- function(field) {
    lapply(acf_by_var, function(one_acf) one_acf[[field]])
  }
  list(lag = pluck_field('lag'), acf = pluck_field('acf'))
}

# Draw one autocorrelation plot per variable, overlaying all chains.
#
#   param_mcmc_list  mcmc.list of chains
#   col, lty         colours and line types passed on to matplot()
# Called for its plotting side effect.
autocorr_param <- function(param_mcmc_list, col = 1:6, lty = 1) {
  var_labels <- varnames(param_mcmc_list)

  # Per-chain ACF results, regrouped by component ($lag / $acf) and then
  # column-bound so every variable has one column per chain.
  per_chain <- lapply(param_mcmc_list, param_get_acf)
  by_component <- papply(per_chain, list)
  lag_by_var <- papply(by_component$lag, cbind)
  acf_by_var <- papply(by_component$acf, cbind)

  for (v in seq_len(nvar(param_mcmc_list))) {
    plot_title <- paste("ACF Plot of", var_labels[v])
    n_chains <- nchain(param_mcmc_list)
    # With several chains the lags are jittered so that the vertical bars do
    # not sit on top of each other; with one chain the factor is zero.
    jitter_factor <- if (n_chains - 1) 2 else 0
    jittered_lags <- jitter(lag_by_var[[v]], jitter_factor)
    matplot(jittered_lags, acf_by_var[[v]], type = 'h', col = col, lty = lty,
            main = plot_title)
  }
}


# Main Function -----------------------------------------------------------

# Compute and plot MCMC convergence diagnostics for one or more HLSM chains.
#
#   object    a single object of class 'HLSM', or a list of such objects (one
#             per chain), all of the same HLSM type
#   burnin    burn-in forwarded to extract_param()
#   diags     diagnostics to produce: any subset of 'psrf', 'raftery',
#             'traceplot' and 'autocorr' (default: all of them)
#   col, lty  colours and line types forwarded to the plotting functions
#
# Returns an object of class "HLSMdiag" (a list holding the call plus the
# `psrf` and/or `raftery` results that could be computed); when neither is
# available, returns NULL invisibly. Trace and ACF plots are drawn as a side
# effect; the graphical parameters changed for them are restored on exit.
HLSMdiag <- function(object, burnin = 0,
                     diags = c('psrf', 'raftery', 'traceplot', 'autocorr'),
                     col = 1:6, lty = 1) {
  if (is(object, 'HLSM')) {
    object_list <- list(object)
  } else if (is(object, 'list')) {
    object_list <- object
  } else {
    stop("object must be single HLSM chain or list of HLSM chains")
  }

  # Must be read before `diags` is reassigned: a user who did not pass `diags`
  # never explicitly asked for PSRF, so the one-chain warning is skipped.
  warn_1chain <- !missing(diags)
  # the default behavior is to return all information
  diags <- match.arg(diags, several.ok = TRUE)

  type <- get_HLSM_type(object_list)
  param <- lapply(object_list, extract_param, type = type, burnin = burnin)
  param_mcmc_list <- as.mcmc.list(lapply(param, param_to_mcmc))

  output <- list(call = match.call())
  if ('psrf' %in% diags) {
    psrf_attrs <- psrf_param(param_mcmc_list, warn_1chain = warn_1chain)
    if (!is.null(psrf_attrs)) {
      output <- c(output, psrf = list(psrf_attrs))
    }
  }

  if ('raftery' %in% diags) {
    raft_attrs <- lapply(param_mcmc_list, raftery_param)
    # lapply() always returns a list, so test the entries: raftery_param()
    # yields NULL for chains that are too short, and a list holding only
    # NULLs must not be stored as a result.
    has_raftery <- !vapply(raft_attrs, is.null, logical(1))
    if (any(has_raftery)) {
      output <- c(output, raftery = list(raft_attrs))
    }
  }

  draw_trace <- 'traceplot' %in% diags
  draw_acf <- 'autocorr' %in% diags
  if (draw_trace || draw_acf) {
    if (nchain(param_mcmc_list) > 1 && missing(col)) {
      chain_dex <- seq_along(param_mcmc_list)
      # The default colours are recycled across chains by the plotting
      # functions, so recycle them the same way for the legend.
      chain_cols <- grDevices::palette()[rep_len(col, length(chain_dex))]
      legend_text <- paste("Chain", chain_dex, "=", chain_cols,
                           collapse = '\n')
      message("Plot Color Legend:\n", legend_text)
    }

    shape_args <- plot_shape(param_mcmc_list, type = type)
    plot_vars <- param_mcmc_list[, shape_args$plotdex]
    set_layout <- function() {
      par(mfrow = c(shape_args$nrows, shape_args$ncols), mar = rep(2, 4))
    }
    # par() returns the previous values; restore them once when the function
    # exits so the caller's graphics settings are left untouched.
    old_par <- set_layout()
    on.exit(par(old_par), add = TRUE)

    if (draw_trace) {
      traceplot(plot_vars, col = col, lty = lty)
    }
    if (draw_acf) {
      if (draw_trace) {
        # start the ACF plots on a fresh grid
        set_layout()
      }
      autocorr_param(plot_vars, col = col, lty = lty)
    }
  }

  if (length(output) > 1) {
    class(output) <- "HLSMdiag"
    return(output)
  }
  invisible(NULL)
}


# Output Printing ---------------------------------------------------------

# Print a call under a "Call:" heading, preceded by a blank line.
call_summary <- function(call) {
  # deparse() may split a long call over several strings; rejoin them with
  # newlines so the call prints as one block.
  call_text <- paste(deparse(call), collapse = "\n")
  cat("\nCall:\n", call_text, "\n", sep = "")
}

# Print method for "HLSMdiag" objects.
#
#   x    object of class "HLSMdiag"
#   ...  unused; present for compatibility with the print() generic
# Each element of `x` is printed by the summary function registered under the
# same name below. Returns `x` invisibly, as print methods are expected to.
print.HLSMdiag <- function(x, ...) {
  # allows this function to be flexible to adding more diagnostic summaries
  summary_funcs <- list(call = call_summary,
                        psrf = psrf_summary,
                        raftery = raftery_summary)
  if (!inherits(x, 'HLSMdiag')) {
    stop("This function does not work on non-HLSMdiag objects.")
  }

  # Every element must have a registered summary function; an unknown name
  # means a diagnostic was added to HLSMdiag() without a summary here.
  if (!all(names(x) %in% names(summary_funcs))) {
    stop("HLSMdiag function updated without updating print.HLSMdiag.\n",
         "Please contact the maintainer to fix")
  }

  for (el in names(x)) {
    func <- summary_funcs[[el]]
    obj <- x[[el]]
    func(obj)
    cat('\n')
  }

  cat("Detailed Information:\n")
  cat("To review detailed diagnostic information for each variable,\n")
  cat("access this object as a list with `$`.\n")
  invisible(x)
}

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API