https://doi.org/10.5281/zenodo.15731718
24 June 2025, 17:15:41 UTC
    To reference or cite objects in the Software Heritage archive, use permalinks based on SoftWare Hash IDentifiers (SWHIDs):

    • content: swh:1:cnt:0480c81b36ca54af852f8b40137ba375af5cfdb2
    • directory: swh:1:dir:7c9d12f286334bcec2d56e656b51bedec0450877
    • snapshot: swh:1:snp:caa5e5c1373a152e7394b87e885b1339f8fc68ec
    • release: swh:1:rel:b8949b6da5a3c1d416881fc3e15d0848afff9f2d

    04_cross_validation.R
    # ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
    ## CROSS VALIDATION ============================================================
    
    # Description:
    #     Run three different rolling-origin cross validation schemes on selected 
    #     candidate models, where:
    #     - Scheme 1: Moving 12-month window with issue month 9 and target month 12
    #     - Scheme 2: Set year index to NA
    #     - Scheme 3: Remove yearly effect from random effects specification
    
    # Paper:
    #     Compound and cascading effects of climatic extremes on dengue outbreak
    #     risk in the Caribbean: an impact-based modelling framework with long-lag
    #     and short-lag interactions
    
    # Script authors:
    #     Chloe Fletcher        ORCID: 0000-0002-6705-7605
    #     Dr Giovenale Moirano  ORCID: 0000-0001-8748-3321
    #     Prof. Rachel Lowe     ORCID: 0000-0003-3939-7343
    
    # ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
    
    ## Source packages, functions and data -----------------------------------------
    
    # load packages and functions
    source("functions/00_packages_functions.R")
    
    # read in harmonised data
    data <- read.csv("data/barbados_data.csv")
    
    # load model output
    spi6.mod.out <- readRDS("outputs/spi6_mods.rds")
    candidate_mods <- c(23, 29, 30, 58, 59, 138, 173)
    
    
    ## Set model parameters --------------------------------------------------------
    
    # set start year, start month and start index
    startyear <- data$cal_year[1]
    startmonth <- data$cal_month[1]
    startid <- nrow(data) - 12 * 10 + 1  # last 10 years of dataset
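
    # Illustrative sketch (not part of the original analysis): how a start index
    # such as `startid` splits a monthly series for rolling-origin validation.
    # `example_rolling_origin` is a hypothetical helper and assumes each target
    # row is forecast from data ending `lead` months earlier; the authors'
    # actual scheme lives in functions/out_of_sample_progressive*.R and may differ.
    example_rolling_origin <- function(n_rows, n_test_years = 10, lead = 3) {
      start_id <- n_rows - 12 * n_test_years + 1   # first out-of-sample row
      targets <- seq(start_id, n_rows)             # rows to be predicted
      data.frame(target = targets, train_end = targets - lead)
    }
    # e.g. example_rolling_origin(240) yields 120 targets, the first at row 121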
    
    # model covariates
    vars.b <- c(NA)
    vars.p <- spi6.mod.out$mod.gof$vars[candidate_mods]
    
    # forecast lead time
    lags.b <- c(3)
    lags.p <- rep(3, length(vars.p))
    
    # random effects
    re1 <- paste("f(month, model = 'rw2', cyclic = TRUE, constr = TRUE,",
                 "scale.model = TRUE, hyper = precision.prior)")
    re2 <- "f(year_index, model = 'rw2', hyper = precision.prior)"
    fe <- "logdir.4"
    
    res.b <- list(list(re1))
    res.f <- list(list(re1, fe))
    res.p <- list(list(re1, re2, fe))
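
    # Illustrative sketch (an assumption, not the authors' code): one way the
    # covariate strings and random-effect strings above could be combined into
    # a model formula. `example_build_formula` is a hypothetical helper; the
    # real assembly presumably happens inside out_of_sample().
    example_build_formula <- function(response, vars, res) {
      terms <- c(unlist(res), vars)
      terms <- terms[!is.na(terms)]            # NA means "no term supplied"
      if (length(terms) == 0) terms <- "1"     # fall back to intercept only
      as.formula(paste(response, "~", paste(terms, collapse = " + ")))
    }
    # e.g. example_build_formula("cases", NA, list("f(month, model = 'rw2')"))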
    
    
    ## Run cross validation --------------------------------------------------------
    
    # run cross validation for baseline model
    # source("functions/out_of_sample_progressive_adjusted.R")
    # pred.b <- out_of_sample(data = data, response = "cases", save = TRUE,
    #                         vars = vars.b, res = res.b, lags = lags.b,
    #                         start_id = startid, source_path = "functions",
    #                         file_path = "outputs", file_name = "preds_b")
    
    # run cross validation for candidate models
    # scheme 1: moving 12-month window with issue month 9 and target month 12
    # source("functions/out_of_sample_progressive_adjusted.R")
    # pred.i.ay <- out_of_sample(data = data, response = "cases", save = TRUE,
    #                            vars = vars.p, res = res.p, lags = lags.p,
    #                            start_id = startid, source_path = "functions",
    #                            file_path = "outputs", file_name = "preds_i_adjustedyear_logdir")
    
    # run cross validation for candidate models
    # scheme 2: set year index to NA
    # source("functions/out_of_sample_progressive.R")
    # pred.i.na <- out_of_sample(data = data, response = "cases", save = TRUE,
    #                            vars = vars.p, res = res.p, lags = lags.p, 
    #                            start_id = startid, source_path = "functions",
    #                            file_path = "outputs", file_name = "preds_i_nayear_logdir")
    
    # run cross validation for candidate models
    # scheme 3: remove yearly effect from REs
    # source("functions/out_of_sample_progressive_adjusted.R")
    # pred.i.ny <- out_of_sample(data = data, response = "cases", save = TRUE,
    #                            vars = vars.p, res = res.f, lags = lags.p,
    #                            start_id = startid, source_path = "functions",
    #                            file_path = "outputs", file_name = "preds_i_noyeareffect_logdir")
    
    pred.b <- readRDS("outputs/cross_validation/preds_b.rds")
    pred.i.ay <- readRDS("outputs/cross_validation/preds_i_adjustedyear_logdir.rds")
    pred.i.na <- readRDS("outputs/cross_validation/preds_i_nayear_logdir.rds")
    pred.i.ny <- readRDS("outputs/cross_validation/preds_i_noyeareffect_logdir.rds")
    
    
    ## Tabulate results ------------------------------------------------------------
    
    # tabulate hit rate, false alarm rate, auc, trigger and crps
    tab <- rbind(pred.b$predictive_ability, pred.i.ay$predictive_ability,
                 pred.i.na$predictive_ability, pred.i.ny$predictive_ability)
    
    # label rows: baseline first (scheme NA, id 0), then schemes 1-3 per candidate
    tab$scheme <- c(NA, rep(1:3, each = length(vars.p)))
    tab$id <- c(0, rep(seq_along(vars.p), 3))
    
    tab <- tab %>% relocate(scheme, .after = model)
    tab <- tab[with(tab, order(id, scheme)),]
    row.names(tab) <- NULL
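
    # Illustrative check (toy sizes, not the real results): how the scheme and id
    # labels line up when one baseline row is stacked above three schemes of
    # `n_mod` candidate models each, then sorted by model id and scheme.
    # `example_label_rows` is a hypothetical helper used only for this check.
    example_label_rows <- function(n_mod) {
      lab <- data.frame(scheme = c(NA, rep(1:3, each = n_mod)),
                        id = c(0, rep(seq_len(n_mod), 3)))
      lab <- lab[order(lab$id, lab$scheme), ]
      row.names(lab) <- NULL
      lab
    }
    # e.g. example_label_rows(2) gives ids 0, 1, 1, 1, 2, 2, 2 with
    # schemes NA, 1, 2, 3, 1, 2, 3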
    
    
    ## Plot final selected model ---------------------------------------------------
    
    # select pred.i
    pred.i <- pred.i.ay
    
    # final model selected
    i <- 4
    
    # plot observed vs fitted
    g1 <- plot.fit.cv(pred.i, startyear, startmonth, i)
    g1_b <- plot.fit.cv(pred.b, startyear, startmonth, 1, colour = "turquoise4")
    
    # plot outbreak probability
    g2 <- plot.prob.outbreak.cv(pred.i, data[startid, "cal_year"], startmonth, i,
                                threshold = pred.i$predictive_ability[i, "trigger"])
    g2_b <- plot.prob.outbreak.cv(pred.b, data[startid, "cal_year"], startmonth, 1,
                                  threshold = pred.b$predictive_ability[1, "trigger"])
    
    # plot double roc curve with current practice model
    rocplt <- roc(pred.i[["outbreak_probability"]][[i]][["obs_outbreak"]],
                  pred.i[["outbreak_probability"]][[i]][["pred_outbreak_prob"]],
                  direction = "<", ci = TRUE, quiet = TRUE)
    ciobj <- ci.se(rocplt, specificities = seq(0, 1, length.out = 100))
    ciobj <- data.frame(x = as.numeric(rownames(ciobj)),
                        lower = ciobj[,1], upper = ciobj[,3])
    
    rocplt2 <- roc(pred.b[["outbreak_probability"]][[1]][["obs_outbreak"]],
                   pred.b[["outbreak_probability"]][[1]][["pred_outbreak_prob"]],
                   direction = "<", ci = TRUE, quiet = TRUE)
    ciobj2 <- ci.se(rocplt2, specificities = seq(0, 1, length.out = 100))
    ciobj2 <- data.frame(x = as.numeric(rownames(ciobj2)),
                         lower = ciobj2[,1], upper = ciobj2[,3])
    
    # DeLong test for a difference in AUC between baseline and selected model
    roc.test(rocplt2, rocplt, method = "delong")
    
    g3 <- plot_roc_2(rocplt, rocplt2, ciobj, ciobj2, "")
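
    # Illustrative sketch (base R, toy inputs): the hit rate and false alarm rate
    # that a ROC curve traces out, evaluated at a single trigger threshold.
    # `example_rates` is a hypothetical helper that is not used elsewhere in this
    # script; it assumes an outbreak is forecast when probability >= trigger.
    example_rates <- function(obs, prob, trigger) {
      pred <- as.integer(prob >= trigger)
      c(tpr = sum(pred == 1 & obs == 1) / sum(obs == 1),   # hit rate
        fpr = sum(pred == 1 & obs == 0) / sum(obs == 0))   # false alarm rate
    }
    # e.g. example_rates(c(1, 0, 1, 0), c(.9, .2, .4, .6), trigger = .5)
    # returns tpr = 0.5 and fpr = 0.5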
    
    # assemble as single plot
    ggdraw() +
      draw_plot(g1, x = 0, y = .6, width = 1, height = .4) +
      draw_plot(g2, x = 0, y = 0, width = .5, height = .6) +
      draw_plot(g3, x = .5, y = 0, width = .5, height = .6) +
      draw_plot_label(label = c("A", "B", "C"), size = 18,
                      x = c(0, 0, .5), y = c(1, .58, .58)) +
      theme_void() +
      theme(plot.background = element_rect(fill = "white")) +
      panel_border(remove = TRUE)
    #ggsave("outputs/figs/F_02.png", dpi=600, height=12, width=15, bg="white")
    
    ggdraw() +
      draw_plot(g1_b, x = 0, y = .6, width = 1, height = .4) +
      draw_plot(g2_b, x = 0, y = 0, width = .5, height = .6) +
      draw_plot(g3, x = .5, y = 0, width = .5, height = .6) +
      draw_plot_label(label = c("A", "B", "C"), size = 18,
                      x = c(0, 0, .5), y = c(1, .58, .58)) +
      theme_void() +
      theme(plot.background = element_rect(fill = "white")) +
      panel_border(remove = TRUE)
    #ggsave("outputs/figs/SF_07.png", dpi=600, height=12, width=15, bg="white")
    
    
    ## Cross validation for models of reduced complexity ---------------------------
    
    # model covariates
    # (additive and interaction forms of the selected model alternate over 5 runs)
    vars.v <- c(NA,
                rep(c(gsub("\\*", "+", spi6.mod.out$mod.gof$vars[candidate_mods[i]]),
                      spi6.mod.out$mod.gof$vars[candidate_mods[i]]), length.out = 5),
                NA, NA)
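
    # Illustrative sketch (placeholder covariate names borrowed from the labels
    # below, not the real model terms): how the additive variant is derived
    # from an interaction term and how rep() alternates the two forms.
    example_int <- "temp*longspi*shortspi"
    example_add <- gsub("\\*", "+", example_int)      # "temp+longspi+shortspi"
    example_alt <- rep(c(example_add, example_int), length.out = 5)
    # example_alt alternates additive, interaction, additive, interaction, additive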
    
    # forecast lead time
    lags.v <- rep(3, length(vars.v))
    
    # labels for model runs
    labels.v <- c("seasonal_interannual",
                  "temp+longspi+shortspi",
                  "temp*longspi*shortspi",
                  "seasonal_interannual_temp+longspi+shortspi",
                  "seasonal_interannual_temp*longspi*shortspi",
                  "seasonal_interannual_temp+longspi+shortspi_dir",
                  "dir",
                  "seasonal_interannual_dir")
    
    # random effects
    res.v <- list(list(re1, re2), list(NA), list(NA), list(re1, re2), list(re1, re2), 
                  list(re1, re2, fe), list(fe), list(re1, re2, fe))
    
    # run cross validation for variants of reduced complexity
    # scheme 1: moving 12-month window with issue month 9 and target month 12
    # source("functions/out_of_sample_progressive_adjusted.R")
    # pred.v <- out_of_sample(data = data, response = "cases", save = TRUE, vars = vars.v,
    #                         res = res.v, lags = lags.v, labels = labels.v,
    #                         start_id = startid, source_path = "functions",
    #                         file_path = "outputs", file_name = "preds_v_adjustedyear_logdir")
    
    pred.v <- readRDS("outputs/cross_validation/preds_v_adjustedyear_logdir.rds")
    
    # tabulate hit rate, false alarm rate, auc, trigger and crps
    tab.v <- rbind(pred.b$predictive_ability, pred.v$predictive_ability,
                   pred.i.ay$predictive_ability[i, ])
    
    tab.v$model <- c("seasonal (baseline)", labels.v,
                     "seasonal_interannual_temp*longspi*shortspi_dir")
    
    tab.v <- tab.v[c(1:2, 8, 3:6, 9, 7, 10), ]
    row.names(tab.v) <- NULL
    
    
    ## Compare interaction vs additive model ---------------------------------------
    
    # specify baseline, additive and interaction models
    j <- 6
    
    pred.fm <- pred.i$predictive_ability[i, ]
    pred.bl <- pred.b$predictive_ability[1, ]
    pred.am <- pred.v$predictive_ability[j, ]
    
    # specify outbreaks
    n_ob <- sum(pred.i$outbreak_probability[[i]]$obs_outbreak == 1)
    
    # number of out-of-sample forecasts (12 months x 10 years, see startid)
    n_fc <- 120
    
    # express hits, correct rejections, false alarms and misses as a
    # percentage of all forecasts
    signal <- data.frame(model = rep(c("Interaction", "Additive", "Baseline"),
                                     each = 4),
                         condition = rep(c("4 Hit", "3 Correct Rejection",
                                           "2 False Alarm", "1 Miss"), 3),
                         value = c(pred.fm$tpr * n_ob / n_fc,
                                   (1 - pred.fm$fpr) * (n_fc - n_ob) / n_fc,
                                   pred.fm$fpr * (n_fc - n_ob) / n_fc,
                                   (1 - pred.fm$tpr) * n_ob / n_fc,
                                   pred.am$tpr * n_ob / n_fc,
                                   (1 - pred.am$fpr) * (n_fc - n_ob) / n_fc,
                                   pred.am$fpr * (n_fc - n_ob) / n_fc,
                                   (1 - pred.am$tpr) * n_ob / n_fc,
                                   pred.bl$tpr * n_ob / n_fc,
                                   (1 - pred.bl$fpr) * (n_fc - n_ob) / n_fc,
                                   pred.bl$fpr * (n_fc - n_ob) / n_fc,
                                   (1 - pred.bl$tpr) * n_ob / n_fc))
    signal[, "value"] <- round(signal[, "value"] * 100, 4)
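
    # Illustrative check (toy numbers, not the study's results): the four
    # confusion-matrix shares derived from tpr, fpr and the outbreak count
    # should sum to 100% of forecasts. `example_shares` is a hypothetical helper.
    example_shares <- function(tpr, fpr, n_outbreak, n_forecast) {
      n_non <- n_forecast - n_outbreak
      100 * c(hit = tpr * n_outbreak,
              correct_rejection = (1 - fpr) * n_non,
              false_alarm = fpr * n_non,
              miss = (1 - tpr) * n_outbreak) / n_forecast
    }
    # e.g. example_shares(tpr = 0.8, fpr = 0.25, n_outbreak = 30, n_forecast = 120)
    # gives hit = 20, correct_rejection = 56.25, false_alarm = 18.75, miss = 5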
    
    # plot stacked bar chart for baseline vs additive vs interaction models
    cols <- c("#AA1155", "#F8B4D3", "#B4E4DA", "#369683")
    ggplot(data=signal, aes(x=factor(model,
                                     levels=c("Baseline", "Additive",
                                              "Interaction")),
                            y=value, fill=condition)) +
      geom_bar(stat="identity") +
      scale_fill_manual("", values = cols, labels = c("False Negative", 
                                                      "False Positive",
                                                      "True Negative",
                                                      "True Positive")) +
      xlab("Model Formulation") +
      ylab("Percentage of Forecasts (%)") +
      theme_bw() + 
      theme(text = element_text(size = 16),
            plot.background = element_rect(fill = "white"),
            panel.border = element_blank(), axis.ticks.x=element_blank(),
            axis.title.y = element_text(margin = margin(t = 0, r = 10, b = 0, l = 0)),
            axis.title.x = element_text(margin = margin(t = 20, r = 0, b = 0, l = 0)))
    #ggsave("outputs/figs/F_03.png", dpi=600, height=12, width=12)
    
    
    ## END
