# Source: https://doi.org/10.5281/zenodo.15731718
# File: 04_cross_validation.R
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
## CROSS VALIDATION ============================================================
# Description:
# Run three different rolling-origin cross validation schemes on selected
# candidate models, where:
# - Scheme 1: Moving 12-month window with issue month 9 and target month 12
# - Scheme 2: Set year index to NA
# - Scheme 3: Remove yearly effect from random effects specification
# Paper:
# Compound and cascading effects of climatic extremes on dengue outbreak
# risk in the Caribbean: an impact-based modelling framework with long-lag
# and short-lag interactions
# Script authors:
# Chloe Fletcher ORCID: 0000-0002-6705-7605
# Dr Giovenale Moirano ORCID: 0000-0001-8748-3321
# Prof. Rachel Lowe ORCID: 0000-0003-3939-7343
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
## Source packages, functions and data -----------------------------------------
# project setup script: loads the packages and helper functions used below
source("functions/00_packages_functions.R")
# row positions (in spi6.mod.out$mod.gof) of the candidate models that are
# taken forward to cross validation
candidate_mods <- c(23, 29, 30, 58, 59, 138, 173)
# harmonised input data
data <- read.csv("data/barbados_data.csv")
# previously saved model output; its mod.gof table supplies the covariate
# specifications of the candidate models
spi6.mod.out <- readRDS("outputs/spi6_mods.rds")
## Set model parameters --------------------------------------------------------
# calendar position of the first row of the data, and the row index where the
# out-of-sample period starts (the last 10 years of the dataset)
startyear <- data$cal_year[1]
startmonth <- data$cal_month[1]
startid <- nrow(data) - 12 * 10 + 1
# covariates: the baseline has none; candidates use their mod.gof specification
vars.b <- NA
vars.p <- spi6.mod.out$mod.gof$vars[candidate_mods]
# forecast lead time of 3 for the baseline and for every candidate model
lags.b <- 3
lags.p <- rep(3, times = length(vars.p))
# model terms as strings in R-INLA f() syntax: re1 = cyclic month effect,
# re2 = year_index effect, fe = fixed-effect covariate
re1 <- paste("f(month, model = 'rw2', cyclic = TRUE, constr = TRUE,",
             "scale.model = TRUE, hyper = precision.prior)")
re2 <- "f(year_index, model = 'rw2', hyper = precision.prior)"
fe <- "logdir.4"
# term sets: baseline = month only; res.f = month + fe (no yearly effect);
# res.p = month + year + fe
res.b <- list(list(re1))
res.f <- list(list(re1, fe))
res.p <- list(list(re1, re2, fe))
## Run cross validation --------------------------------------------------------
# The out_of_sample() calls below are left commented out; their results are
# read back from saved .rds files at the end of this section. To regenerate a
# result, uncomment the call together with the source() line above it.
# NOTE(review): the calls save with file_path = "outputs" but results are read
# from "outputs/cross_validation" -- confirm where out_of_sample() writes.

# run cross validation for baseline model (month effect only, res.b)
# source("functions/out_of_sample_progressive_adjusted.R")
# pred.b <- out_of_sample(data = data, response = "cases", save = TRUE,
#                         vars = vars.b, res = res.b, lags = lags.b,
#                         start_id = startid, source_path = "functions",
#                         file_path = "outputs", file_name = "preds_b")

# run cross validation for candidate models
# scheme 1: moving 12-month window with issue month 9 and target month 12
# source("functions/out_of_sample_progressive_adjusted.R")
# pred.i.ay <- out_of_sample(data = data, response = "cases", save = TRUE,
#                            vars = vars.p, res = res.p, lags = lags.p,
#                            start_id = startid, source_path = "functions",
#                            file_path = "outputs",
#                            file_name = "preds_i_adjustedyear_logdir")

# run cross validation for candidate models
# scheme 2: set year index to NA (uses the non-adjusted progressive function)
# source("functions/out_of_sample_progressive.R")
# pred.i.na <- out_of_sample(data = data, response = "cases", save = TRUE,
#                            vars = vars.p, res = res.p, lags = lags.p,
#                            start_id = startid, source_path = "functions",
#                            file_path = "outputs",
#                            file_name = "preds_i_nayear_logdir")

# run cross validation for candidate models
# scheme 3: remove yearly effect from REs (res.f instead of res.p)
# source("functions/out_of_sample_progressive_adjusted.R")
# pred.i.ny <- out_of_sample(data = data, response = "cases", save = TRUE,
#                            vars = vars.p, res = res.f, lags = lags.p,
#                            start_id = startid, source_path = "functions",
#                            file_path = "outputs",
#                            file_name = "preds_i_noyeareffect_logdir")

# load saved cross-validation results (baseline, then schemes 1-3)
pred.b <- readRDS("outputs/cross_validation/preds_b.rds")
pred.i.ay <- readRDS("outputs/cross_validation/preds_i_adjustedyear_logdir.rds")
pred.i.na <- readRDS("outputs/cross_validation/preds_i_nayear_logdir.rds")
pred.i.ny <- readRDS("outputs/cross_validation/preds_i_noyeareffect_logdir.rds")
## Tabulate results ------------------------------------------------------------
# tabulate hit rate, false alarm rate, auc, trigger and crps: the baseline row
# followed by one row per candidate model for each of the three schemes
tab <- rbind(pred.b$predictive_ability, pred.i.ay$predictive_ability,
             pred.i.na$predictive_ability, pred.i.ny$predictive_ability)
# the scheme/id labels below assume exactly this row layout; fail loudly
# rather than letting data.frame assignment silently recycle the labels
stopifnot(nrow(tab) == 1 + 3 * length(vars.p))
tab$scheme <- c(NA, rep(1:3, each = length(vars.p)))
tab$id <- c(0, rep(seq_along(vars.p), times = 3))
tab <- tab %>% relocate(scheme, .after = model)
# group rows by candidate model (baseline id 0 first), schemes 1-3 within each
tab <- tab[with(tab, order(id, scheme)), ]
row.names(tab) <- NULL
## Plot final selected model ---------------------------------------------------
# select pred.i: the scheme 1 cross-validation results
pred.i <- pred.i.ay
# final model selected: position i among the candidate models, used to index
# rows of predictive_ability and elements of outbreak_probability
i <- 4
# year and month of row startid (first out-of-sample row). Both are taken from
# the same row so the pair is consistent; row 1's month (startmonth) matches
# only when the data and the test window start in the same calendar month
cv_startyear <- data[startid, "cal_year"]
cv_startmonth <- data[startid, "cal_month"]

# Build a ROC curve for model k of a cross-validation result, together with the
# confidence band of sensitivity over a grid of n_spec specificities (ci.se)
roc_with_ci <- function(pred, k, n_spec = 100) {
  ob <- pred[["outbreak_probability"]][[k]]
  curve <- roc(ob[["obs_outbreak"]], ob[["pred_outbreak_prob"]],
               direction = "<", ci = TRUE, quiet = TRUE)
  band <- ci.se(curve, specificities = seq(0, 1, length.out = n_spec))
  # columns 1 and 3 of the ci.se matrix are the lower and upper limits
  list(roc = curve,
       ci = data.frame(x = as.numeric(rownames(band)),
                       lower = band[, 1], upper = band[, 3]))
}

# Arrange a full-width top panel over two half-width bottom panels,
# labelled A, B and C, on a white background
assemble_figure <- function(top, bottom_left, bottom_right) {
  ggdraw() +
    draw_plot(top, x = 0, y = .6, width = 1, height = .4) +
    draw_plot(bottom_left, x = 0, y = 0, width = .5, height = .6) +
    draw_plot(bottom_right, x = .5, y = 0, width = .5, height = .6) +
    draw_plot_label(label = c("A", "B", "C"), size = 18,
                    x = c(0, 0, .5), y = c(1, .58, .58)) +
    theme_void() +
    theme(plot.background = element_rect(fill = "white")) +
    panel_border(remove = TRUE)
}

# plot observed vs fitted (selected model and baseline)
g1 <- plot.fit.cv(pred.i, startyear, startmonth, i)
g1_b <- plot.fit.cv(pred.b, startyear, startmonth, 1, colour = "turquoise4")
# plot outbreak probability against each model's trigger threshold
g2 <- plot.prob.outbreak.cv(pred.i, cv_startyear, cv_startmonth, i,
                            threshold = pred.i$predictive_ability[i, "trigger"])
g2_b <- plot.prob.outbreak.cv(pred.b, cv_startyear, cv_startmonth, 1,
                              threshold = pred.b$predictive_ability[1, "trigger"])
# ROC curves with CI bands; selected model first, then baseline (same order
# as before, so the random-number stream used by ci.se is unchanged)
roc.i <- roc_with_ci(pred.i, i)
roc.b <- roc_with_ci(pred.b, 1)
# DeLong test comparing the baseline and selected-model ROC curves
print(roc.test(roc.b$roc, roc.i$roc, method = "delong"))
# plot double roc curve with current practice model
g3 <- plot_roc_2(roc.i$roc, roc.b$roc, roc.i$ci, roc.b$ci, "")
# assemble as single plot: selected model
assemble_figure(g1, g2, g3)
#ggsave("outputs/figs/F_02.png", dpi=600, height=12, width=15, bg="white")
# assemble as single plot: baseline model (with the shared ROC comparison)
assemble_figure(g1_b, g2_b, g3)
#ggsave("outputs/figs/SF_07.png", dpi=600, height=12, width=15, bg="white")
## Cross validation for models of reduced complexity ---------------------------
# covariate specification of the final selected model (interaction form) and
# its additive counterpart with every "*" replaced by "+"
vars.int <- spi6.mod.out$mod.gof$vars[candidate_mods[i]]
vars.add <- gsub("\\*", "+", vars.int)
# model covariates, aligned element-by-element with labels.v and res.v below:
# NA, additive, interaction, additive, interaction, additive, NA, NA
vars.v <- c(NA, rep_len(c(vars.add, vars.int), 5), NA, NA)
# forecast lead time
lags.v <- rep(3, times = length(vars.v))
# labels for model runs
labels.v <- c("seasonal_interannual",
              "temp+longspi+shortspi",
              "temp*longspi*shortspi",
              "seasonal_interannual_temp+longspi+shortspi",
              "seasonal_interannual_temp*longspi*shortspi",
              "seasonal_interannual_temp+longspi+shortspi_dir",
              "dir",
              "seasonal_interannual_dir")
# random effects
res.v <- list(list(re1, re2), list(NA), list(NA), list(re1, re2),
              list(re1, re2), list(re1, re2, fe), list(fe),
              list(re1, re2, fe))
# every model run needs exactly one covariate set, label, term set and lag
stopifnot(length(labels.v) == length(vars.v),
          length(res.v) == length(vars.v),
          length(lags.v) == length(vars.v))
# run cross validation for variants of reduced complexity
# scheme 1: moving 12-month window with issue month 9 and target month 12
# source("functions/out_of_sample_progressive_adjusted.R")
# pred.v <- out_of_sample(data = data, response = "cases", save = TRUE,
#                         vars = vars.v, res = res.v, lags = lags.v,
#                         labels = labels.v, start_id = startid,
#                         source_path = "functions", file_path = "outputs",
#                         file_name = "preds_v_adjustedyear_logdir")
pred.v <- readRDS("outputs/cross_validation/preds_v_adjustedyear_logdir.rds")
# tabulate hit rate, false alarm rate, auc, trigger and crps: baseline, the
# reduced-complexity variants, then the final selected model
tab.v <- rbind(pred.b$predictive_ability, pred.v$predictive_ability,
               pred.i.ay$predictive_ability[i, ])
tab.v$model <- c("seasonal (baseline)", labels.v,
                 "seasonal_interannual_temp*longspi*shortspi_dir")
# display order of the reduced-complexity table, selected by label
row_order <- c("seasonal (baseline)",
               "seasonal_interannual",
               "dir",
               "temp+longspi+shortspi",
               "temp*longspi*shortspi",
               "seasonal_interannual_temp+longspi+shortspi",
               "seasonal_interannual_temp*longspi*shortspi",
               "seasonal_interannual_dir",
               "seasonal_interannual_temp+longspi+shortspi_dir",
               "seasonal_interannual_temp*longspi*shortspi_dir")
row_idx <- match(row_order, tab.v$model)
stopifnot(!anyNA(row_idx), length(row_idx) == nrow(tab.v))
tab.v <- tab.v[row_idx, ]
row.names(tab.v) <- NULL
## Compare interaction vs additive model ---------------------------------------
# specify baseline, additive and interaction models; j = 6 is the run labelled
# "seasonal_interannual_temp+longspi+shortspi_dir" in pred.v, the additive
# counterpart of the final interaction model
j <- 6
pred.fm <- pred.i$predictive_ability[i, ]
pred.bl <- pred.b$predictive_ability[1, ]
pred.am <- pred.v$predictive_ability[j, ]
# specify outbreaks (observed outbreak months in the out-of-sample window)
n_ob <- sum(pred.i$outbreak_probability[[i]]$obs_outbreak == 1)
# number of out-of-sample forecasts: rows from startid to the end of the data
# (120 = last 10 years, by construction of startid)
n_fc <- nrow(data) - startid + 1

# Share of all n_tot forecasts falling in each confusion-matrix cell, given a
# model's hit rate (tpr), false alarm rate (fpr) and n_pos observed outbreaks.
# Returned in the order: hit, correct rejection, false alarm, miss.
confusion_shares <- function(perf, n_pos, n_tot) {
  n_neg <- n_tot - n_pos
  c(perf$tpr * n_pos,
    (1 - perf$fpr) * n_neg,
    perf$fpr * n_neg,
    (1 - perf$tpr) * n_pos) / n_tot
}

# calculate true positive, true negative, false positive, false negative shares
signal <- data.frame(
  model = rep(c("Interaction", "Additive", "Baseline"), each = 4),
  condition = rep(c("4 Hit", "3 Correct Rejection", "2 False Alarm", "1 Miss"),
                  times = 3),
  value = c(confusion_shares(pred.fm, n_ob, n_fc),
            confusion_shares(pred.am, n_ob, n_fc),
            confusion_shares(pred.bl, n_ob, n_fc))
)
signal$value <- round(signal$value * 100, 4)
# plot stacked bar chart for baseline vs additive vs interaction models.
# `condition` levels sort as "1 Miss" < "2 False Alarm" < "3 Correct Rejection"
# < "4 Hit", which is the order the colours and legend labels are given in
cols <- c("#AA1155", "#F8B4D3", "#B4E4DA", "#369683")
ggplot(data = signal,
       aes(x = factor(model, levels = c("Baseline", "Additive", "Interaction")),
           y = value, fill = condition)) +
  geom_col() +
  scale_fill_manual("", values = cols,
                    labels = c("False Negative", "False Positive",
                               "True Negative", "True Positive")) +
  xlab("Model Formulation") +
  ylab("Percentage of Forecasts (%)") +
  theme_bw() +
  theme(text = element_text(size = 16),
        plot.background = element_rect(fill = "white"),
        panel.border = element_blank(),
        axis.ticks.x = element_blank(),
        axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0)),
        axis.title.x = element_text(margin = margin(t = 20, r = 0, b = 0, l = 0)))
#ggsave("outputs/figs/F_03.png", dpi=600, height=12, width=12)
## END