https://github.com/LCBC-UiO/VidalPineiro_BrainAge
Raw File
Tip revision: 2044c6ca40e0b8f99c9190c6edfde8ca76b559ac authored by didacvp on 01 November 2021, 13:45:46 UTC
Update README.md
Tip revision: 2044c6c
BrainAge_VidalPineiro_UKB_GenMod.Rmd
---
title: "BrainAge_VidalPineiro_ModelGeneration(XGB & LASSO)"
author: "dvp"
date: "11/2/2020"
output:
  html_document: default
  pdf_document: default
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = TRUE)
library(knitr)
library(kableExtra)
library(cowplot)
library(psych)
library(boot)
library(tidyverse)
library(magrittr)
library(xgboost)
library(caret)
library(mgcv)
library(GGally)
library(gridExtra)
library(missMDA)
library(DiagrammeR)
library(ggpointdensity)
library(glmnet)


# Input (raw UKB tables) and output (cached .Rda files) locations.
raw_folder <- "./data-raw/BrainAge/UKB/tabulated_data"
data_folder <- file.path("./data/BrainAge", "noExcl_scaled")
# Create the output folder including any missing parents. dir.create() only
# warns (never errors) on failure, so the previous try() wrapper added nothing,
# and without recursive = TRUE a missing ./data/BrainAge left data_folder
# uncreated, making every later save() into it fail.
dir.create(data_folder, recursive = TRUE, showWarnings = FALSE)
options(bitmapType = "cairo")

# Query the SLURM queue for jobs called `job` owned by `user`.
#
# Runs `squeue --name=<job> -u <user>` and splits every output line (header
# included) on runs of spaces. Returns a data.frame with one row per output
# line and columns V1, V2, ...; the callers in this file only use
# length(df$V1) > 1 to decide whether any job is still queued.
#
# Rows are padded with NA to a common width. squeue's NODELIST(REASON) field
# can contain spaces (e.g. "(ReqNodeNotAvail, Reserved)"). Previously that
# produced ragged rows: simplify2array() then returned a list and t() failed,
# which crashed the polling loop.
squeue <- function(user, job) {
  out <- system(paste0("squeue --name=", job, " -u ", user), intern = TRUE)
  fields <- strsplit(out, " +")
  width <- max(c(0L, lengths(fields)))
  rows <- lapply(fields, function(f) c(f, rep(NA_character_, width - length(f))))
  df.squeue <- as.data.frame(do.call(rbind, rows))
  return(df.squeue)
}
```


# Data Preprocessing (common steps)

## define subsets of vars 
We exclude global variables, white-matter (WM) variables and non-brain variables.
We include intensity, cortical area, cortical thickness (CTH), volume, grey-white contrast (GWC), and subcortical volume and intensity.
```{r subset vars,echo=F}

# Define (or load the cached) sets of T1w feature column names, grouped by
# measure type, from the header of the tp3 FreeSurfer IDP table.
vars_file <- file.path(data_folder, "vars.Rda")
if (file.exists(vars_file)) {
  load(file = vars_file)
} else {
  df <- read.csv(file.path(raw_folder, "FS_IDPS_all_tp3.csv"))
  print("define variables to remove. currently removing non-brain variables")
  rm_vars <- c("Grey.white.contrast.in.unknown..right.hemisphere.")

  # Column names of df matching the regex `pattern`, minus anything in rm_vars.
  pick_vars <- function(pattern) {
    hits <- grep(pattern, names(df), value = TRUE)
    hits[!(hits %in% rm_vars)]
  }
  gwc_vars <- pick_vars("Grey.white")
  cth_vars <- pick_vars("Mean.thickness.of")
  Int_vars <- pick_vars("Mean.intensity")
  Area_vars <- pick_vars("Area.of")
  Vol_vars <- pick_vars("Volume.of")
  T1w_vars <- c(Vol_vars, Area_vars, Int_vars, cth_vars, gwc_vars)

  save(T1w_vars, Vol_vars, Area_vars, Int_vars, cth_vars, gwc_vars, rm_vars,
       file = vars_file)
}

# Report how many features each set contains.
print(paste("remove vars N:", length(rm_vars)))
print(paste("grey white matter contrast;", "N:", length(gwc_vars)))
print(paste("cth; ", "N:", length(cth_vars)))
print(paste("Intensity", "N:", length(Int_vars)))
print(paste("Area", "N:", length(Area_vars)))
print(paste("Volume", "N:", length(Vol_vars)))
print(paste("MRI", "N:", length(T1w_vars)))


```


## Loading data and merging data.frames
Merge all data together (MRI, non-MRI, cross-sectional and longitudinal)
and save as a single file.
```{r loading data, echo = F}
## Build (or load the cached) merged dataset: MRI features plus non-MRI
## variables for both timepoints, restricted to the columns used downstream.
## Cached in All_raw.Rda together with subs.long (eids present at both
## timepoints) and vars_noMRI.
if (!file.exists(file.path(data_folder, "All_raw.Rda"))) {
  print("loading timpoint 1 data")
  df.tp2 = read.csv(file.path(raw_folder, "FS_IDPS_all_tp2.csv"))
  df.tp2$wave = 1

  print("loading timepoint 1 no-MRI data")
  df.tp2.noMRI = read.csv(file.path(raw_folder, "ID_age_all_tp2.csv"))

  print("join MRI and noMRIdata")
  df.tp2 = left_join(df.tp2, df.tp2.noMRI)

  print("loading timepoint 2 data")
  df.tp3 = read.csv(file.path(raw_folder, "FS_IDPS_all_tp3.csv"))
  df.tp3$wave = 2

  print("loading timepoint 2 no-MRI data")
  df.tp3.noMRI = read.csv(file.path(raw_folder, "ID_age_all_tp3.csv"))
  subs.long = df.tp3$eid

  print("loading no-MRI data")
  df.noMRI = read.csv(file.path(raw_folder, "ID_all_noMRI.csv"))

  print("checking whether subjects in timepoint2 have tp1 data")
  subs.long = subs.long[subs.long %in% df.tp2$eid]

  print("remove tp3 people without long data")
  df.tp3 = df.tp3 %>% filter(eid %in% subs.long)

  print("join MRI and noMRIdata")
  df.tp3 = left_join(df.tp3, df.tp3.noMRI)

  print("merge data")
  # Global measures are copied under short names. Centre dummies come from
  # UK_Biobank_assessment_centre_c_54.2.0; Center.Cheadle is computed but not
  # selected below (presumably the reference category).
  df =
    rbind(df.tp2, df.tp3) %>%
    left_join(., df.noMRI) %>%
    mutate(eICV = Volume.of.EstimatedTotalIntraCranial..whole.brain.,
           TGMV = Volume.of.TotalGray..whole.brain.,
           TBV = Volume.of.SupraTentorialNotVent..whole.brain.,
           Center.Cheadle = if_else(UK_Biobank_assessment_centre_c_54.2.0 == 11025, 1, 0),
           Center.Reading = if_else(UK_Biobank_assessment_centre_c_54.2.0 == 11026, 1, 0),
           Center.Newcastle = if_else(UK_Biobank_assessment_centre_c_54.2.0 == 11027, 1, 0))

  vars_noMRI = c("eid",
                 "wave",
                 "eICV",
                 "TGMV",
                 "TBV",
                 "Center.Reading",
                 "Center.Newcastle",
                 names(df.tp3.noMRI)[-1],
                 names(df.noMRI)[-1])

  # all_of(): selecting with bare external character vectors is deprecated in
  # tidyselect and is ambiguous if a column happens to share the vector's name.
  df = df %>% dplyr::select(all_of(vars_noMRI), all_of(T1w_vars))

  save(subs.long, vars_noMRI, df, file = file.path(data_folder, "All_raw.Rda"))
} else {
  print("loading data. All merged")
  load(file = file.path(data_folder, "All_raw.Rda"))
}
```

## Preprocessing T1w data
a) Scale. Use training sample and apply parameters to the whole sample.

```{r preprocess T1w data, echo=T}

preproc_file <- file.path(data_folder, "All_preproc.Rda")
if (file.exists(preproc_file)) {
  load(preproc_file)
} else {
  # check whether all long went to the same center
  #table(df$UK_Biobank_assessment_centre_c_54.2.0 == df$UK_Biobank_assessment_centre_c_54.3.0)

  # Reference sample for scaling = training (non-longitudinal) subjects only,
  # so the test subjects do not influence the scaling parameters.
  df.pp <- filter(df, !(eid %in% subs.long))
  train_means <- t(summarise_at(df.pp, T1w_vars, mean, na.rm = TRUE))
  train_sds <- t(summarise_at(df.pp, T1w_vars, sd, na.rm = TRUE))
  df.scale <- data.frame(T1w_vars, mean = train_means, sd = train_sds)

  # z-score every feature of the WHOLE dataset with the reference mean/sd
  for (vars in rownames(df.scale)) {
    print(vars)
    df[[vars]] <- (df[[vars]] - df.scale[vars, "mean"]) / df.scale[vars, "sd"]
  }
  save(df, df.scale, file = preproc_file)
}

```


## Select Train and Test 
Separate train and test data, i.e. cross-sectional (training) and longitudinal (test) data.
```{r train and test data}

# Subjects scanned only once form the training set; longitudinal subjects
# (eids in subs.long) are held out as the test set.
df.Train <- dplyr::filter(df, !(eid %in% subs.long))
df.Test <- dplyr::filter(df, eid %in% subs.long)
```

## Basic demographics of train and test samples
Sex and Age and other sociodemographic info
```{r Basic Demographic}
# Sample sizes; the last value is test scans / 2, i.e. the number of test
# subjects if each has two scans.
c("n: train/test", nrow(df.Train), nrow(df.Test), nrow(df.Test)/2)
c("age range train/test", range(df.Train$age), range(df.Test$age))
c("age mean - sd train/test", mean(df.Train$age), mean(df.Test$age), sd(df.Train$age), sd(df.Test$age))
c("Sex: Train:", table(df.Train$sex))
# One row per longitudinal subject: baseline age and follow-up interval
tmp <- df.Test %>%
  group_by(eid) %>%
  summarise(n = n(),
            AgeBsl = min(age),
            AgeGap = max(age) - min(age),
            sex = first(sex))
c("Sex: Test:", table(tmp$sex))
followup_gaps <- tmp$AgeGap[tmp$AgeGap != 0]
c("Interval: Test:", mean(followup_gaps), sd(followup_gaps))
c("Age bsl: Test:", mean(tmp$AgeBsl), sd(tmp$AgeBsl), range(tmp$AgeBsl))
```

### Fig 1a. Background plot (theoretical expectations)
```{r Fig1a}

# Simulated cloud of (chronological age, brain age) pairs: CA ~ U(0, 90) and
# BA = CA + N(0, 15). df.plot is the plot's base data; the density layer that
# drew it is currently commented out.
set.seed(1234)
n= 100000
CA = runif(n = n,
           min = 0,
           max = 90)
Delta = rnorm(n,
              sd = 15)
df.plot = data.frame(CA,
           Delta,
           BA = CA + Delta)

# Hypothetical trajectories, 91 points each over CA = 0..90:
#   AC1    - BA = CA + 1 up to CA 44, then rises from 46 to 115 (faster than CA)
#   EL2    - rises from 10 to 40 over CA 0..19, then BA = CA + 21
#   Other1 - BA = CA - 1 up to CA 40, then steps up to BA = CA + 14
# Reshaped to long format: one row per (CA, trajectory I, BA).
df.plot.trajectories = data.frame(
    CA = seq(0,90),
    EL2 = c(seq(10,40, length.out = 20), seq(41,111)),
    Other1 = c(seq(-1,39), seq(55,104)),
    AC1 = c(1:45, seq(46,115, length.out = 46))) %>%
  pivot_longer(cols = -1, names_to = "I", values_to = "BA")

gs = ggplot(df.plot, aes(CA, BA)) +
  #geom_density_2d_filled(bins = 100) +
  #scale_fill_grey(start = 1,
  #                end = .2,
  #                na.value = "white") +
  # Fixed segments use annotate() so each is drawn once. geom_segment() with
  # constant x/y/xend/yend inherited df.plot and drew 100,000 identical,
  # overlapping segments (same look, very slow to render and save).
  # Solid arrow along the identity line from (25, 25) to (89, 89):
  annotate("segment",
           x = 25,
           y = 25,
           xend = 89,
           yend = 89,
           size = 3,
           lineend = "round",
           arrow = arrow(type = "closed")) +
  # Dotted identity line (BA = CA) starting at the origin:
  annotate("segment",
           x = 0,
           y = 0,
           xend = 89,
           yend = 89,
           size = 3,
           linetype = 3) +
  theme_classic() +
  theme(legend.position = 'none',
        axis.text = element_text(size = 16),
        axis.title = element_text(size = 20)) +
  scale_x_continuous(limits = c(-2,90),
                     expand = c(0,0),
                     breaks = seq(0,80,20),
                     name = "Chronological Age") +
  # Other1 starts at BA = -1, below this lower limit, so that point is dropped
  # (ggplot warns about removed rows).
  scale_y_continuous(limits = c(0,120),
                     expand = c(0,0),
                     breaks = seq(0,100,20),
                     name = "Brain Age") +
  # size = 0 line with an arrow: presumably draws only the arrowheads; the
  # dashed trajectories come from the next layer.
  geom_line(data = df.plot.trajectories,
            mapping = aes(x = CA, y = BA, group = I, color = I),
            size = 0,
            arrow =arrow(type = "closed")) +
  geom_line(data = df.plot.trajectories,
            mapping = aes(x = CA, y = BA, group = I, color = I),
            linetype = 2,
            size = 2) +
  # Unnamed values map to the sorted trajectory names:
  # AC1 green, EL2 purple, Other1 yellow (ColorBrewer Dark2-based palette,
  # https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n6).
  scale_color_manual(values= c("#66a61e","#810f7c", '#e6ab02'))
dir.create("figures", showWarnings = FALSE)
ggsave("figures/scheme_trajectories.png",
       dpi = 500,
       plot = gs,
       width = 10,
       height = 10,
       units = "cm")

```
### Fig 1b. Schematic representation of influences
```{r Fig1b}
## Purely for illustrative purposes!!!
# Illustrate ageing: 5 simulated individuals whose brain age equals their
# chronological age until an individual onset (AgeInt), after which it grows
# at an individual rate (AgeSlope).
set.seed(123)
n = 5
AgeInt = rnorm(n, sd = 2.5) + 50
AgeSlope = rnorm(n, sd = .8) + 1
age = 57:68

# Row i = individual i, column k = chronological age age[k].
df.aging = outer(seq_len(n), age, function(i, j) {
  ifelse(AgeInt[i] >= j, as.double(j), AgeInt[i] + (j - AgeInt[i]) * AgeSlope[i])
})
df.aging = as.data.frame(t(df.aging))
# Columns ordered from the slowest to the fastest ager.
df.aging = df.aging[, order(AgeSlope)]
names(df.aging) = paste0("id", 1:5)
df.aging$CA = age
df.aging = pivot_longer(df.aging, -CA, names_to = "eid", values_to = "BA")

# Fig 1b (ageing panel): one trajectory per simulated individual, emphasised
# between CA 61 and 65 with a marker at CA 63 (the "old age" tick).
# Six colours for five individuals: unnamed manual values are used in order,
# so after rev() the lightest colour is presumably left unused.
colorscale <- c('#edf8fb','#ccece6','#99d8c9','#66c2a4','#2ca25f','#006d2c')

gs = ggplot(df.aging, aes(x = CA, y = BA, group = eid, color = eid, fill =eid)) +
  # size = 0 line with an arrow: presumably draws only the arrowhead at the end
  geom_line(size = 0, linetype = 3, arrow =arrow(type = "closed")) +
  # full trajectory, dot-dash
  geom_line(size = 1, linetype = 4) +
  # thicker segment over the highlighted window
  geom_line(data =df.aging %>% filter(CA %>%  between(61,65)),
            size = 2) +
  # filled marker at CA = 63
  geom_point(data =df.aging %>% filter(CA ==  63),
            size = 4, shape = 21, color = "black", stroke = 2) +
  theme_classic() +
  # y-axis values are hidden: the brain-age scale is illustrative only
  theme(legend.position = 'none',
        axis.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()) + 
  scale_color_manual(values =rev(colorscale)) +
  scale_fill_manual(values =rev(colorscale)) +
  scale_x_continuous(name = "Chronological Age", breaks = 63, labels = "old age") + 
  scale_y_continuous(name = "Brain Age")

ggsave("figures/aging_h0.png", 
       dpi = 500, 
       plot = gs,
       width = 10,
       height = 5, 
       units = "cm")

# Illustrate early-life effects: 5 simulated individuals whose brain-age
# offset (AgeInt) and developmental slope (AgeSlope) both scale with a
# simulated BirthWeight; the extra slope only applies until AgeDevStop.
colorscale = c('#edf8fb','#bfd3e6','#9ebcda','#8c96c6','#8856a7','#810f7c')
set.seed(123)
n = 5
BirthWeight = rnorm(n, sd = 3)
AgeSlope = rnorm(n, sd = .1) + BirthWeight*.4
AgeInt = rnorm(n, sd = 2) + BirthWeight*2
AgeDevStop = rnorm(n, sd = 1) + 10

age = 0:25 ### remember - this is purely for illustrative purposes
# BA = AgeInt + CA + min(CA, AgeDevStop) * AgeSlope; the offset freezes once
# development stops. Row i = individual i, column k = age[k].
df.earlylife = outer(seq_len(n), age, function(i, j) {
  AgeInt[i] + j + pmin(j, AgeDevStop[i]) * AgeSlope[i]
})
df.earlylife = as.data.frame(t(df.earlylife))
# Columns ordered by increasing BirthWeight.
df.earlylife = df.earlylife[, order(BirthWeight)]
names(df.earlylife) = paste0("id", 1:5)
df.earlylife$CA = age
df.earlylife = pivot_longer(df.earlylife, -CA, names_to = "eid", values_to = "BA")

# Fig 1b (early-life panel): same layout as the ageing panel, with the
# highlighted window at CA 19-23 and a marker at CA 21.
gs = ggplot(df.earlylife, 
       aes(x = CA, y = BA, group = eid, color = eid, fill =eid)) +
  # size = 0 line with an arrow: presumably draws only the arrowhead at the end
  geom_line(size = 0, linetype = 3, arrow =arrow(type = "closed")) +
  # full trajectory, dot-dash
  geom_line(size = 1, linetype = 4) +
  # thicker segment over the highlighted window
  geom_line(data =df.earlylife %>% filter(CA %>%  between(19,23)),
            size = 2) +
  # filled marker at CA = 21
  geom_point(data =df.earlylife %>% filter(CA ==  21),
            size = 4, shape = 21, color = "black", stroke = 2) +
  theme_classic() +
  # y-axis values are hidden: the brain-age scale is illustrative only
  theme(legend.position = 'none',
        axis.text = element_text(size = 14),
        axis.title = element_text(size = 18),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank()) + 
  scale_color_manual(values =rev(colorscale)) +
  scale_fill_manual(values =rev(colorscale)) +
  # qualitative life-stage ticks instead of numeric ages
  scale_x_continuous(name = "Chronological Age", breaks = c(0,9,20), labels =c("birth", "development", "old age")) + 
  scale_y_continuous(name = "Brain Age")


ggsave("figures/earlylife_h0.png", 
       dpi = 500, 
       plot = gs,
       width = 10,
       height = 5, 
       units = "cm")
```


### Fig1c plot. UKB sample distribution
```{r Fig1c}
## plot train and dataset with age
# Mirrored age densities: training set drawn as negative density, test set as
# positive density. After coord_flip() age runs vertically and the two
# densities extend to either side of the vertical line at density 0.
gs = ggplot() +
  geom_density(df.Train, mapping = aes(x = age, y = -..density..), fill="#66a61e", color="#66a61e", size =1.5, alpha =.8) + 
    geom_label( aes(x=80, y=-0.030, label="Training set"), color="#66a61e", size =4) +
  geom_density(df.Test, mapping = aes(x = age, y = ..density..), fill="#1b9e77", color="#1b9e77", size =1.5, alpha =.8) + 
    geom_label( aes(x=80, y=0.030, label="Test set"), color="#1b9e77", size =4) +
  # baseline at density 0 (vertical after the flip)
  geom_hline(yintercept = 0, color="#1b9e77", size = 1) +
    theme_classic() +
    # density values are not labelled (the training side is negative)
    theme(legend.position = 'none',
        axis.text = element_text(size = 16),
        axis.title = element_text(size = 20),
        #axis.line.x = element_blank(),
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank()) + 
  xlab("Chronological Age") +
  ylab("Density") +
  coord_flip() + 
  scale_x_continuous(expand = c(0,0))

try(dir.create("figures"))
ggsave("figures/sample_distribution.png", 
       dpi = 500, 
       plot = gs,
       width = 10,
       height = 10, 
       units = "cm")
```

### Fig1d plot. Age effects -GAM
```{r Fig1d}
## plot age effects on features
## Per-feature age effects in the training set: fit value ~ s(age) (REML) for
## every T1w feature, keep the r.sq reported by summary.gam, and tag each
## feature with its modality from data/Harmonize.csv. Cached in age.effects.Rda.
if (!file.exists(file.path(data_folder, "age.effects.Rda"))) {
  gam.age = df.Train %>%
    pivot_longer(all_of(T1w_vars),
                 names_to = "features",
                 values_to = "value") %>%
    group_by(features) %>%
    do(fit = gam(value ~ s(age), data = ., method = 'REML'))

  df.age = data.frame(
    feature = gam.age$features,
    # vapply pins the result to exactly one number per feature
    r.sq = vapply(gam.age$fit, function(x) summary(x)$r.sq, numeric(1)))

  df.Harmonize = read.csv("data/Harmonize.csv", stringsAsFactors = FALSE) %>%
    mutate(feature = UKB,
           modality = as.factor(Stats_file)) %>%
    select(feature, modality)

  df.age = left_join(df.age, df.Harmonize, by = "feature")

  # Rank modalities by their mean r.sq (order = -rank, so the best comes first)
  tmp = df.age %>%
    group_by(modality) %>%
    summarise(r.sq = mean(r.sq)) %>%
    mutate(order = -rank(r.sq)) %>%
    select(-r.sq)

  # Sort features by modality rank, then by r.sq within modality. order2 is the
  # x position in the plot. row_number() (instead of 1:length(T1w_vars)) stays
  # valid if the join changed the row count, e.g. duplicate entries in
  # Harmonize.csv.
  df.age = left_join(df.age, tmp, by = "modality") %>%
    arrange(order, -r.sq) %>%
    mutate(order2 = row_number())

  df.age$modality =
    plyr::mapvalues(df.age$modality,
                    from = c("Area_Aparc", "GWC_Aparc", "Intensity_Aseg", "Thickness_Aparc", "Volume_Aseg", "Volume_Aparc"),
                    to = c("area (c)", "gwc (c)", "intensity (s)", "thickness (c)", "volume (s)", "volume (c)"))

  save(gam.age, df.age, file = file.path(data_folder, "age.effects.Rda"))
} else {
  load(file = file.path(data_folder, "age.effects.Rda"))
}

write.csv(df.age, file = "figures/age_effects_UKB.csv")

# Fig 1d: age-explained variance (GAM r.sq) per feature. Features are placed
# on the x axis in the order given by order2 and filled by modality.
colorscale =c('#1b9e77','#d95f02','#7570b3','#e7298a','#66a61e','#e6ab02')
gs = ggplot(df.age, aes(x = order2, y = r.sq, group = modality, fill = modality)) +
  geom_point(shape = 21, size = 3) +
    theme_classic() +
    # legend placed inside the panel; feature indices are not labelled
    theme(legend.position = c(.7,.8),
        axis.text = element_text(size = 16),
        axis.title = element_text(size = 20),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        legend.title = element_blank(),
        legend.text = element_text(size = 16),
        legend.key.size = unit(5,"point")) + 
  xlab("Feature") +
  ylab(bquote('Age Variance (r' ^2*')' )) +
  scale_fill_manual(values = colorscale)

try(dir.create("figures"))
ggsave("figures/age_relationship.png", 
       dpi = 500, 
       plot = gs,
       width = 10,
       height = 10, 
       units = "cm")
```


# ML - model creation.  Use all variables 

## Select Train and Test 
Separate train and test data, i.e. cross-sectional (training) and longitudinal (test) data.
```{r train and test data. all vars}

# Feature matrices, age labels and subject ids for xgboost (one row per scan).
data.train = as.matrix(df.Train[, T1w_vars])
label.train = as.matrix(df.Train$age)
train.id = as.matrix(df.Train$eid)
data.test = as.matrix(df.Test[, T1w_vars])
label.test = as.matrix(df.Test$age)
label.id = as.matrix(df.Test$eid)

# xgboost's own data containers
dtrain <- xgb.DMatrix(data = data.train, label = label.train)
dtest <- xgb.DMatrix(data = data.test, label = label.test)

```


## Explore hyperparameter spaces with cross-validation
At the moment, moving forward with default parameters. 
```{r Hyper_parameter_search,  eval=F, echo=T}

## Randomised hyper-parameter search (this chunk has eval=F). `it` rows are
## drawn from the full grid and each one is cross-validated by a SLURM job
## (scripts/BrainAge/RandomHyperParameterSearchCV.sh). That job presumably
## writes rmse/idx back into the grid file (TODO confirm against the script).
set.seed(1234)
it = 50
if (!file.exists(file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))) {
  # it might take some minutes
  # set up the cross-validated hyper-parameter grid; result columns start as NaN
  xgb_grid_1 = expand.grid(
    nrounds = seq(100, 600, 50),
    eta = c(.01, .05, .1, .15, .2),
    max_depth = seq(2, 8, 1),
    gamma = seq(0.5, 1.5, 0.5),
    min_child_weight = seq(1, 4),
    rmse = NaN,
    Nrmse = NaN,
    train = NaN,
    idx = NaN
  )

  save(xgb_grid_1, file = file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))
} else {
  load(file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))
}

# With the fixed seed this draw is the same on every re-run. Iterations whose
# number already appears in xgb_grid_1$idx are not resubmitted.
idx = sample(dim(xgb_grid_1)[1], it)
ii = which(!(seq_len(it) %in% xgb_grid_1$idx))
for (i in ii) {
  print(i)
  eta = xgb_grid_1$eta[idx[i]]
  max_depth = xgb_grid_1$max_depth[idx[i]]
  gamma = xgb_grid_1$gamma[idx[i]]
  min_child_weight = xgb_grid_1$min_child_weight[idx[i]]
  nrounds = xgb_grid_1$nrounds[idx[i]]
  (system(paste("sbatch scripts/BrainAge/RandomHyperParameterSearchCV.sh",
                eta,
                max_depth,
                gamma,
                min_child_weight,
                nrounds,
                i,
                idx[i],
                data_folder,
                sep = " ")))
  Sys.sleep(5)
}

# Wait until no search job is left in the SLURM queue (header line only)
df.squeue = squeue("p274-didacvp", "HyperParameterSearch")
while (length(df.squeue$V1) > 1) {
  Sys.sleep(120)
  print("script running on sbatch")
  df.squeue = squeue("p274-didacvp", "HyperParameterSearch")
}

# reload
load(file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))
ii = which(!(seq_len(it) %in% xgb_grid_1$idx))
if (length(ii) == 0) {
  xgb_grid_1 = xgb_grid_1 %>% filter(!is.na(rmse))
  # NOTE(review): written to data/, not data_folder. The data_folder copy keeps
  # the full grid, which presumably keeps the seeded sample() above stable.
  save(xgb_grid_1, file = "data/RandomHyperParameterSearchCV.Rda")
} else {
  # was disp(), which is not an R function and aborted the chunk
  warning("something wrong with randomized search")
}

xgb_grid_1 %>% arrange(rmse) %>% head()
  
```


### Apply Crossvalidation
Not strictly necessary;
saved for diagnostic statistics and Fig 1c.

```{r xgboost cv}
## Cross-validated xgboost on the training set, using the best hyper-parameters
## from the random search. The SLURM job (scripts/BrainAge/BrainAgeCV.sh)
## presumably saves `xgbcv` to xgbcv.CV.Rda (TODO confirm against the script).
if (!file.exists(file.path(data_folder, "xgbcv.CV.Rda"))) {
  # The search chunk is eval=F, so xgb_grid_1 is not in memory when knitting.
  # Load the cached grid; which.min() skips rows whose rmse is NaN.
  if (!exists("xgb_grid_1")) {
    load(file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))
  }
  best = xgb_grid_1[which.min(xgb_grid_1$rmse), ]
  eta = best[["eta"]]
  max_depth = best[["max_depth"]]
  gamma = best[["gamma"]]
  min_child_weight = best[["min_child_weight"]]
  nrounds = best[["nrounds"]]

  (system(paste("sbatch scripts/BrainAge/BrainAgeCV.sh",
                eta,
                max_depth,
                gamma,
                min_child_weight,
                nrounds,
                data_folder,
                sep = " ")))

  # Block until the job has left the SLURM queue (header line only)
  df.squeue = squeue("p274-didacvp", "BrainAgeCV")
  while (length(df.squeue$V1) > 1) {
    Sys.sleep(120)
    print("script running on sbatch")
    df.squeue = squeue("p274-didacvp", "BrainAgeCV")
  }
}

load(file.path(data_folder, "xgbcv.CV.Rda"))

min_rmse = min(xgbcv$evaluation_log$test_rmse_mean)
min_rmse_nround = which.min(xgbcv$evaluation_log$test_rmse_mean)
print(paste("min rmse", round(min_rmse, 3), "at round n", min_rmse_nround))
```


# Plotting and predicting
```{r xgboost cv plotts and summary  stats}

# Out-of-fold (cross-validated) brain-age predictions for the training scans,
# with the covariates used in the bias checks further down.
df.out <- data.frame(
  eid = df.Train$eid,
  BA = xgbcv$pred,
  Age = label.train,
  sex = df.Train$sex,
  Center.Reading = df.Train$Center.Reading,
  Center.Newcastle = df.Train$Center.Newcastle,
  eiCV = df.Train$eICV
)

# Predicted vs chronological age
lm.Age <- lm(BA ~ Age, data = df.out)
summary(lm.Age)
cor.test(df.out$BA, df.out$Age)

bag.train <- df.out$BA - df.out$Age          # raw brain-age gap per scan
err <- mean(abs(bag.train))
rmse <- sqrt(sum(bag.train^2) / length(df.out$BA))
# "age bias": correlation between chronological age and the raw gap
age.bias <- cor.test(df.out$Age, bag.train)$estimate
print(paste("mean absolute error (MAE)=", round(err, 2)))
print(paste("root mean square error=", round(rmse, 2)))
print(paste("r-sq =", round(summary(lm.Age)$r.squared, 2)))
print(paste("age.bias =", age.bias))
```

### Plot for Fig 1e 
```{r plot CA vs BA}
# Training-set (cross-validated) brain age vs chronological age.
#print("plot BA vs CA. Note expect age.bias")
gs = ggplot(data = df.out,
       mapping = aes(x = Age,
                     y = BA)) +
  # points coloured by local point density
  geom_pointdensity(adjust = 2) +
  # identity line: BA == chronological age
  geom_abline(intercept = 0,
              slope = 1,
              colour = "grey60",
              linetype = 4,
              size = 1.5) +
  # linear fit (green) and default smoother (orange) of BA on Age
  geom_smooth(method = "lm",
              color = "#1b9e77",
              size = 1.5) +
  geom_smooth(color = "#d95f02",
              size = 1.5) +
  # annotate() draws a single label. geom_label(aes(...)) inherited df.out and
  # drew one identical label per training scan.
  annotate("label", x = 50, y = 80, label = "Training set", color = "#66a61e", size = 4) +
  theme_classic() +
  theme(legend.position = 'none',
        axis.text = element_text(size = 16),
        axis.title = element_text(size = 20)) +
  scale_color_viridis_c() +
  ylab("Brain Age") +
  xlab("Chronological Age")

ggsave("figures/TrainXGV.png",
       dpi = 500,
       plot = gs,
       width = 10,
       height = 10,
       units = "cm")


```

### Correct BA vs CA Bias
correct BA Bias using gam and linear coefficients
```{r Bias Correction}
# Age-bias correction on the training predictions. ExpAge is the brain age the
# linear fit lm.Age expects for a given chronological age; BAG_corr_gam uses
# the residuals of a smooth GAM fit instead.
sm.rel <- gam(BA ~ s(Age, k = 8), data = df.out)

bias_intercept <- coef(lm.Age)[[1]]
bias_slope <- coef(lm.Age)[[2]]
df.out <- df.out %>%
  mutate(BAG = BA - Age,
         ExpAge = bias_intercept + Age * bias_slope,
         BAG_corr = BA - ExpAge,
         BAG_corr_gam = sm.rel$residuals)

# Diagnostic (disabled): GGally::ggpairs() over BA, Age, BAG, ExpAge,
# BAG_corr and BAG_corr_gam, with progress = F.
# Author's observations: strong relationship between Age and the raw brain-age
# gap; linear and GAM-based bias corrections are almost equivalent.
```

## Explore effects on BAG of variables that should not be related to BAG
```{r Sex Scan Biases, echo =T}

# Brain age and gaps by sex, scan centre and eICV in the training predictions;
# these covariates should ideally show no effect on the corrected gaps.
# funs() is deprecated in dplyr; across() with a named list yields the same
# <var>_mean / <var>_sd columns (grouped by variable instead of by function).
bias_vars = c("Age", "BA", "BAG", "BAG_corr", "BAG_corr_gam")

print(df.out %>% group_by(sex) %>%
  summarise(across(all_of(bias_vars), list(mean = mean, sd = sd))))

print(df.out %>% group_by(Center.Reading) %>%
  summarise(across(all_of(bias_vars), list(mean = mean, sd = sd))))

print(df.out %>% group_by(Center.Newcastle) %>%
  summarise(across(all_of(bias_vars), list(mean = mean, sd = sd))))

print(df.out %>% summarise(across(all_of(bias_vars), ~ cor(.x, eiCV))))
```

## Predict test data with same parameters as above
```{r xgboost with full train data}


## Final xgboost model on the full training set, trained by a SLURM job
## (scripts/BrainAge/BrainAgeFull.sh). The job presumably saves `bst` to
## xgbcv.full.Rda (TODO confirm). The model then predicts the test scans.
if (!file.exists(file.path(data_folder, "xgbcv.full.Rda"))) {
  # The search chunk is eval=F, so xgb_grid_1 may not be in memory; load the
  # cached grid (which.min() skips NaN rmse rows).
  if (!exists("xgb_grid_1")) {
    load(file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))
  }
  best = xgb_grid_1[which.min(xgb_grid_1$rmse), ]
  eta = best[["eta"]]
  max_depth = best[["max_depth"]]
  gamma = best[["gamma"]]
  min_child_weight = best[["min_child_weight"]]
  nrounds = best[["nrounds"]]

  (system(paste("sbatch scripts/BrainAge/BrainAgeFull.sh",
                eta,
                max_depth,
                gamma,
                min_child_weight,
                nrounds,
                data_folder,
                sep = " ")))

  # Block until the job has left the SLURM queue (header line only)
  df.squeue = squeue("p274-didacvp", "BrainAgeFull")
  while (length(df.squeue$V1) > 1) {
    Sys.sleep(120)
    print("script running on sbatch")
    df.squeue = squeue("p274-didacvp", "BrainAgeFull")
  }
}

load(file.path(data_folder, "xgbcv.full.Rda"))
pred <- predict(bst, data.test)


# Model dump and variable-importance / tree-structure plots
xgb.dump(bst, fname = file.path(data_folder, 'bst.model.dump.txt'), with_stats = TRUE)
mat <- xgb.importance(feature_names = T1w_vars, model = bst)
xgb.plot.importance(importance_matrix = mat, top = 15, rel_to_first = TRUE, left_margin = 15)
xgb.ggplot.deepness(bst)
suppressMessages(xgb.plot.multi.trees(bst, fill = TRUE))

```

## Fig 1e & summarize prediction results
```{r test data info}

# Test-set predictions from the full-training-set model, with covariates.
df.pred = data.frame(
  eid = label.id,
  BA = pred,
  Age = label.test,
  wave = df.Test$wave,
  sex = df.Test$sex,
  Center.Reading = df.Test$Center.Reading,
  Center.Newcastle = df.Test$Center.Newcastle,
  eiCV = df.Test$eICV)


gs = ggplot(data = df.pred,
       mapping = aes(x = Age,
                     y = BA)) +
  # grey lines join the scans of each longitudinal subject
  geom_line(mapping = aes(group = eid), size = .3, color = "grey50") +
  geom_pointdensity(adjust = 4) +
  # identity line: BA == chronological age
  geom_abline(intercept = 0,
              slope = 1,
              colour = "grey60",
              linetype = 4,
              size = 1.5) +
  # linear fit (green) and default smoother (orange) of BA on Age
  geom_smooth(method = "lm",
              color = "#1b9e77",
              size = 1.5) +
  geom_smooth(color = "#d95f02",
              size = 1.5) +
  # annotate() draws a single label. geom_label(aes(...)) inherited df.pred
  # and drew one identical label per test scan.
  annotate("label", x = 50, y = 80, label = "Test set", color = "#1b9e77", size = 4) +
  theme_classic() +
  theme(legend.position = 'none',
        axis.text = element_text(size = 16),
        axis.title = element_text(size = 20)) +
  scale_color_viridis_c() +
  ylab("Brain Age") +
  xlab("Chronological Age")

ggsave("figures/TestXGV.png",
       dpi = 500,
       plot = gs,
       width = 10,
       height = 10,
       units = "cm")
```

## Summarized model info. 
```{r get brain age stats}

# Accuracy of the test-set predictions
lm.Age <- lm(BA ~ Age, data = df.pred)
summary(lm.Age)
cor.test(df.pred$BA, df.pred$Age)
resid.test <- pred - label.test              # raw brain-age gap per test scan
err <- mean(abs(resid.test))
rmse <- sqrt(sum(resid.test^2) / length(pred))
# "age bias": correlation between chronological age and the raw gap
age.bias <- cor.test(df.pred$Age, (df.pred$BA - df.pred$Age))$estimate

print(paste("mean absolute error (MAE)=", round(err, 2)))
print(paste("root mean square error=", round(rmse, 2)))
print(paste("r-sq =", round(summary(lm.Age)$r.squared, 2)))
print(paste("age.bias =", age.bias))
```
## And correct age bias
```{r}
## Age-bias correction for the test set. All bias models are fitted on the
## cross-validated TRAINING predictions (df.out) and then applied to df.pred,
## so the correction does not depend on test data. sm.rel_pred, which is fitted
## on the test set itself, is only used for the comparison at the end.
sm.rel_pred = gam(BA ~ s(Age, k = 8), data = df.pred)
sm.rel = gam(BA ~ s(Age, k = 8), data = df.out)
lm.Age2 = lm(BA ~ poly(Age, 2, raw = TRUE), data = df.out)
# Linear bias model, matching the one used for the training set's BAG_corr.
# This line used to duplicate the quadratic fit above. As a result, ExpAge
# combined the intercept and linear coefficient of a QUADRATIC model and
# silently dropped its Age^2 term.
lm.Age = lm(BA ~ Age, data = df.out)

df.pred %<>%
  mutate(BAG = BA - Age,
         ExpAge = lm.Age$coefficients[[1]] + Age*lm.Age$coefficients[[2]],
         ExpAge2 = lm.Age2$coefficients[[1]] + Age*lm.Age2$coefficients[[2]] + Age^2*lm.Age2$coefficients[[3]],
         BAG_corr = BA - ExpAge,
         BAG_corr2 = BA - ExpAge2,
         BAG_corr_gam = BA - predict(sm.rel, df.pred))

# Diagnostic (disabled): GGally::ggpairs() over BA, Age, BAG, ExpAge,
# BAG_corr, BAG_corr2 and BAG_corr_gam.

# Compare the training-based GAM correction with a GAM fitted on the test set
# itself; the author notes the two are almost equivalent.
plot(df.pred$BAG_corr_gam, sm.rel_pred$residuals)
cor(df.pred$BAG_corr_gam, sm.rel_pred$residuals)
```

## Explore effects on BAG of variables that should not be related to BAG (Test Data)
```{r Sex_Scan_Biases_Pred}
# NOTE(review): purpose of this pause is unclear; kept as-is.
Sys.sleep(1)

# Same bias checks as for the training set, now on the test predictions.
# funs() is deprecated in dplyr; across() with a named list yields the same
# <var>_mean / <var>_sd columns (grouped by variable instead of by function).
bias_vars = c("Age", "BA", "BAG", "BAG_corr", "BAG_corr_gam")

print(df.pred %>% group_by(sex) %>%
  summarise(across(all_of(bias_vars), list(mean = mean, sd = sd))))

print(df.pred %>% group_by(Center.Reading) %>%
  summarise(across(all_of(bias_vars), list(mean = mean, sd = sd))))

print(df.pred %>% group_by(Center.Newcastle) %>%
  summarise(across(all_of(bias_vars), list(mean = mean, sd = sd))))

print(df.pred %>% summarise(across(all_of(bias_vars), ~ cor(.x, eiCV))))



```

### tp1 tp2 independently
Process tp1 and tp2 independently, so there is no need to apply corrections.
```{r xgb tp1 tp2}
## Wave-specific brain age. Instead of applying the training-set model, fit a
## separate 10-fold cross-validated xgboost model within each test wave. Keep
## the out-of-fold predictions (BA_tp) and gaps (BAG_tp). Cached in
## xgbcv.tp.ind.Rda.
if (!file.exists(file.path(data_folder, "xgbcv.tp.ind.Rda"))) {
  load(file.path(data_folder, "RandomHyperParameterSearchCV.Rda"))

  best = xgb_grid_1[which.min(xgb_grid_1$rmse), ]
  eta = best[["eta"]]
  max_depth = best[["max_depth"]]
  gamma = best[["gamma"]]
  min_child_weight = best[["min_child_weight"]]
  nrounds = best[["nrounds"]]
  nfold = 10
  params = list(booster = "gbtree",
                objective = "reg:squarederror",
                eta = eta,
                max_depth = max_depth,
                gamma = gamma,
                min_child_weight = min_child_weight)

  # Cross-validate within one wave and return eid / BA_tp / age / wave for it.
  cv_wave = function(w) {
    in_wave = df.Test$wave == w
    cv = xgb.cv(params = params,
                data = as.matrix(df.Test[in_wave, T1w_vars]),
                label = as.matrix(df.Test$age[in_wave]),
                nrounds = nrounds,
                nfold = nfold,
                showsd = TRUE,
                stratified = TRUE,
                print_every_n = 10,
                # Was `early_stop_round`, which is not an xgb.cv argument: it
                # fell through `...` into params, so early stopping never ran.
                early_stopping_rounds = 10,
                maximize = FALSE,
                prediction = TRUE)

    min_rmse = min(cv$evaluation_log$test_rmse_mean)
    min_rmse_nround = which.min(cv$evaluation_log$test_rmse_mean)
    print(paste("min rmse", round(min_rmse, 3), "at round n", min_rmse_nround))

    data.frame(eid = df.Test$eid[in_wave],
               BA_tp = cv$pred,
               age = df.Test$age[in_wave],
               wave = w)
  }

  # Fixed seed so the random fold assignment is reproducible.
  set.seed(1234)
  db.out.tp = rbind(cv_wave(1), cv_wave(2)) %>%
    mutate(BAG_tp = BA_tp - age)

  save(db.out.tp, file = file.path(data_folder, "xgbcv.tp.ind.Rda"))
} else {
  load(file.path(data_folder, "xgbcv.tp.ind.Rda"))
}
```


### saving results
merge data.frame and save results
```{r Saving and Knitting}
# Attach the brain-age outputs to the per-scan data and save both sets.
# Columns carried over from the brain-age data frames:
vars = c("eid",
         "BA",
         "BAG",
         "ExpAge",
         "BAG_corr",
         "BAG_corr_gam")
# Join keys are explicit, and all_of() replaces bare external vectors in
# select(), which tidyselect has deprecated.
df.Train =
  left_join(df.Train,
            df.out %>% select(all_of(vars)),
            by = "eid")

df.Test =
  left_join(df.Test,
            df.pred %>% select(all_of(c(vars, "wave"))),
            by = c("eid", "wave"))

df.Test =
  left_join(df.Test,
            db.out.tp %>% select(eid, wave, BA_tp, BAG_tp),
            by = c("eid", "wave"))

save(df.Test, df.Train, file = file.path(data_folder, "ResultsBrainAge_XGB.Rda"))
```



# LASSO prediction
Routine implemented as in Cole 2020. Modifications were kept to a minimum to ensure as much consistency as possible with Cole 2020.

## Designate subset as training and validation/model testing
```{r redefine train and test data}
# Feature matrices, ages and ids for the LASSO models (Cole 2020 flow).
# Ids are taken from the same data frames as the ages so rows stay aligned.
train_data <- df.Train[, T1w_vars] %>% as.matrix()
train_labels <- df.Train$age %>% as.matrix()
train.id <- df.Train$eid %>% as.matrix()
test_data <- df.Test[, T1w_vars] %>% as.matrix()
test_labels <- df.Test$age %>% as.matrix()
label.id <- df.Test$eid %>% as.matrix()
```
### Scale variables.
This is essential for ANN and probably a good idea for all the models.
```{r scale cole2020}
# Standardise with training-set statistics only; the test set reuses the
# training centre/scale. (Already done upstream; kept to mirror Cole 2020.)
scaled.train_data <- scale(train_data, center = TRUE, scale = TRUE)
scaling.parameters.center <- attr(scaled.train_data, "scaled:center")
scaling.parameters.scale <- attr(scaled.train_data, "scaled:scale")
scaled.train_data <- as.data.frame(scaled.train_data)
scaled.test_data <- test_data %>%
  scale(center = scaling.parameters.center, scale = scaling.parameters.scale) %>%
  as.data.frame()
```

### Functions to output accuracy metrics and plot age by predicted age.
Take predicted age values as input.
```{r functionscole2020}
#' Accuracy metrics of predicted vs. chronological age.
#'
#' @param pred Predicted ages (numeric vector or one-column matrix).
#' @param labels Chronological ages, same length as `pred`; defaults to the
#'   global `test_labels` so existing calls keep working.
#' @return Named numeric vector with elements r, R^2, MAE, Age.bias, rmse.
accuracy_metrics <- function(pred, labels = test_labels) {
  pred <- as.numeric(pred)
  labels <- as.numeric(labels)
  err <- pred - labels
  c(
    r = unname(cor.test(labels, pred)$estimate),
    "R^2" = summary(lm(labels ~ pred))$r.squared,
    MAE = mean(abs(err), na.rm = TRUE),
    # correlation between age and prediction error (regression-to-the-mean bias)
    Age.bias = unname(cor.test(labels, err)$estimate),
    rmse = sqrt(sum(err^2) / length(pred))
  )
}

#' Formatted table of accuracy metrics (rounded to 3 decimals).
#'
#' @inheritParams accuracy_metrics
#' @return A kableExtra table with columns `metric` and `value`.
test_results <- function(pred, labels = test_labels) {
  metrics <- round(accuracy_metrics(pred, labels), 3)
  results <- data.frame(metric = names(metrics), value = unname(metrics))
  # full_width / position are kable_styling() arguments, not bootstrap options
  kable(results) %>%
    kable_styling(
      bootstrap_options = c("striped", "condensed", "responsive"),
      full_width = FALSE,
      position = "center"
    )
}

#' Scatter plot of chronological vs. brain-predicted age.
#'
#' The title is the expression passed as `pred` (e.g. "lasso.pred").
#' @inheritParams accuracy_metrics
#' @return A ggplot object.
age_plot <- function(pred, labels = test_labels) {
  plot.title <- deparse(substitute(pred))
  plot.data <- data.frame(age = as.numeric(labels), pred = as.numeric(pred))
  ggplot(plot.data, aes(x = age, y = pred)) +
    geom_abline(slope = 1, intercept = 0) +
    geom_point(shape = 21, fill = "darkgoldenrod2", size = 2) +
    geom_smooth(method = "lm", formula = y ~ x, col = "darkgrey") +
    labs(title = plot.title, x = "Age (years)", y = "Brain-predicted age (years)") +
    theme_cowplot()
}
```

### LASSO regression
Using the glmnet package. Alpha = 1 is for LASSO penalisation (0 = ridge, 0.5 = elastic net).
```{r lasso reg}
# Outcome (age) and an unnamed predictor matrix for glmnet
y.train <- as.matrix(train_labels)
x.train <- scaled.train_data %>%
  as.matrix() %>%
  unname()

## Cross-validate lambda for a pure LASSO penalty (alpha = 1)
set.seed(1234)
lasso.fit.cv <- cv.glmnet(
  x = x.train,
  y = y.train,
  family = "gaussian",
  alpha = 1
)
```

Plot results. 
```{r lassofitplot1}
# Cross-validation error curve across the lambda path
lasso.fit.cv %>% plot()
```

### LASSO model performance on validation data
Fit model using previously optimised (through CV) lambda value (1 SE value, not minimum).
```{r lasso val}
# Refit on all training data at the CV-selected 1-SE lambda
lasso.fit <- glmnet(
  x = x.train,
  y = y.train,
  family = "gaussian",
  alpha = 1,
  lambda = lasso.fit.cv$lambda.1se
)
# Predict the held-out test set and report accuracy
lasso.pred <- scaled.test_data %>%
  as.matrix() %>%
  predict(lasso.fit, newx = .)
test_results(lasso.pred)
```


```{r}
age_plot(lasso.pred)
# Save the plot just drawn (ggsave defaults to the last plot)
ggsave(
  filename = file.path("figures", "LASSO_brain_age_scatterplot.pdf"),
  width = 4,
  height = 4,
  dpi = 75,
  useDingbats = FALSE
)
```

### Variable weightings and feature selection results
```{r wieghting and selecting cole 2020}
## Coefficients at the 1-SE lambda, dropping the intercept (first entry)
LASSO.coefficient <- coef(lasso.fit, s = lasso.fit.cv$lambda.1se)[-1]
var.coefs <- data.frame(T1w_vars, LASSO.coefficient)
non.zero_vars <- subset(var.coefs, var.coefs$LASSO.coefficient != 0)
non.zero_vars$T1w_vars.list <- factor(non.zero_vars$T1w_vars)
# The inline summary after this chunk reads `imaging.vars.list` (the Cole 2020
# column name); without it, length() of the missing column reports 0 variables.
non.zero_vars$imaging.vars.list <- non.zero_vars$T1w_vars.list
```
Out of the original `r dim(var.coefs)[1]` variables, the LASSO regression set `r length(non.zero_vars$imaging.vars.list)` to non-zero, thus `r dim(var.coefs)[1] - length(non.zero_vars$imaging.vars.list)` variables were removed.

### Bootstrap LASSO
Bootstrap 95% confidence intervals. Uses the boot package.

#### Function to obtain LASSO regression coefficients
The coefficients must be converted to a dense vector so that zero (unselected) coefficients are kept.
```{r funcions2 cole2020}
#' Bootstrap statistic: LASSO coefficients on a resample.
#'
#' @param data Matrix whose first column is the outcome (age) and whose
#'   remaining columns are the predictors.
#' @param indices Row indices of the resample (supplied by boot::boot()).
#' @param lambda Penalty; defaults to the CV-selected 1-SE value.
#' @return Dense numeric vector (intercept first) that keeps zero coefficients.
lasso.coef <- function(data, indices, lambda = lasso.fit.cv$lambda.1se) {
  d <- data[indices, , drop = FALSE]
  fit <- glmnet(x = d[, -1, drop = FALSE], y = d[, 1],
                alpha = 1, family = "gaussian", lambda = lambda)
  # [, 1] turns the sparse coefficient matrix into a dense named vector
  coef(fit)[, 1]
}
```

#### Run bootstrap with n replications
Normal printing and plotting of results doesn't work for high-dimensional datasets.
Load data file if it already exists.
```{r boostrap cole 2020}
## Run 1000 bootstrap refits unless a cached result exists on disk
if (!file.exists(file.path(data_folder, "lasso.boot.out.rda"))) {
  cat("running bootstraps")
  boot.out <- boot(
    data = cbind(y.train, x.train),
    statistic = lasso.coef,
    R = 1000
  )
  save(boot.out, file = file.path(data_folder, "lasso.boot.out.rda"))
} else {
  load(file.path(data_folder, "lasso.boot.out.rda"))
  cat("loading existing bootstrap file")
}
```
There were `r table(boot.out$t0[-1] > 0 | boot.out$t0[-1] < 0)[2]` non-zero coefficients.

Check histogram of bootstrap coefficients for top variable by way of example.
```{r}
# Column of boot.out$t for the largest |coefficient| (+1 skips the intercept)
ggplot(mapping = aes(boot.out$t[, 1 + which.max(abs(boot.out$t0[-1]))])) +
  geom_histogram(bins = 100,
                 fill = "darkgoldenrod2",
                 colour = "black",
                 lwd = 0.25) +
  xlab("Top variable bootstrapped coefficients") +
  theme_cowplot()
```

#### Function for getting CIs from vector
```{r functions3 cole 2020}
#' Bootstrap CI for one statistic, as a length-1 list holding the CI row.
#'
#' @param index Position of the statistic in `boot.object$t0`.
#' @param boot.object Result of boot::boot().
#' @param ci.type CI type passed to boot::boot.ci() (e.g. "basic").
#' @return `boot.ci(...)[4]`, a list with the CI matrix. When boot.ci() returns
#'   NULL (all replicates identical, e.g. a coefficient that is always zero),
#'   a list with an all-NA row of the same width is returned instead, so that
#'   later `unlist()`-based reshaping stays aligned with the coefficients.
ci.vector <- function(index, boot.object, ci.type) {
  x <- boot.ci(boot.object, type = ci.type, index = index)
  if (is.null(x)) {
    # "norm" rows are (conf, lower, upper); other types add 2 order-stat columns
    n.cols <- if (identical(ci.type, "norm")) 3L else 5L
    return(list(rep(NA_real_, n.cols)))
  }
  x[4]
}
```

Use the ci.vector() function defined above to derive basic bootstrap confidence intervals for every coefficient.
```{r message=FALSE, warning=FALSE, paged.print=FALSE}
n <- length(boot.out$t0)
boot.ci.out <- lapply(seq_len(n), ci.vector, boot.object = boot.out, ci.type = "basic")
x <- boot.out$t0[seq_len(n)]
# A "basic" CI row is (conf, two order-statistic positions, lower, upper):
# keep the last two values per coefficient. Extracting per element (rather than
# reshaping one flat unlist()) keeps rows aligned even if boot.ci() returned
# NULL for a constant statistic; such rows become NA.
y <- boot.ci.out %>%
  vapply(function(ci) {
    bounds <- unlist(ci)
    if (length(bounds) < 2) c(NA_real_, NA_real_) else unname(tail(bounds, 2))
  }, numeric(2)) %>%
  t() %>%
  as.data.frame()
ci.df <- cbind(x, y)
names(ci.df) <- c("coef", "l.ci", "u.ci")
```

Identify variables with confidence intervals that do not overlap zero.
```{r idvar not 0 cole 2020 paged.print=FALSE}
# Drop the intercept (first row of ci.df) via [-1] and keep variables whose
# bootstrap CI excludes zero.
sig.vars.index <- which(ci.df$l.ci[-1] > 0 | ci.df$u.ci[-1] < 0)
sig.vars.list <- T1w_vars[sig.vars.index]
sig.vars.df <- ci.df[sig.vars.index + 1, ] ## add 1 to omit intercept row
sig.vars.df <- cbind(sig.vars.list, round(sig.vars.df, 3))
# full_width / position are kable_styling() arguments, not bootstrap options
kable(sig.vars.df[order(abs(sig.vars.df$coef), decreasing = TRUE), ]) %>%
  kable_styling(
    bootstrap_options = c("striped", "condensed", "responsive"),
    full_width = FALSE,
    position = "center",
    fixed_thead = list(enabled = TRUE, background = "lightgrey")
  )
```

```{r warning=FALSE}
## Sort coefficients in decreasing order, pinning the intercept to row 1 so
## that [-1, ] below always removes it (not only when it happens to be largest)
ci.df2 <- ci.df[c(1, 1 + order(ci.df$coef[-1], decreasing = TRUE)), ]
# Base-graphics calls are separate statements: chaining them with `+`
# evaluates NULL + NULL and prints a stray integer(0).
plot(ci.df2[-1, 1],
     ylim = c(min(ci.df2[-1, 2], na.rm = TRUE), max(ci.df2[-1, 3], na.rm = TRUE)),
     pch = 20, col = "darkgoldenrod2", ylab = "LASSO coefficient")
arrows(x0 = seq_len(nrow(ci.df2) - 1), y0 = ci.df2[-1, 2], y1 = ci.df2[-1, 3],
       length = 0.02, angle = 90, code = 3, col = "grey")
# lty = 2 draws the dashed zero line (`type` is not a line-type parameter)
abline(h = 0, lty = 2)
```

### Run model with only significant variables
#### Top variables OLS
```{r topOls predict cole 2020}
# OLS on the variables whose bootstrap CI excludes zero; drop = FALSE keeps a
# data frame (required by lm/predict) even if only one variable is selected.
top.ols <- lm(train_labels ~ .,
              data = scaled.train_data[, sig.vars.index, drop = FALSE])
top.ols.pred <- predict(object = top.ols,
                        newdata = scaled.test_data[, sig.vars.index, drop = FALSE])
test_results(top.ols.pred)
```

```{r}
# Scatter of chronological vs. OLS-predicted age (title: "top.ols.pred")
age_plot(pred = top.ols.pred)

```

#### Top variables LASSO
```{r top lasso cole 2020}
## fit model using optimal lambda value (1 SE value, not minimum)
set.seed(1234)
top.lasso.fit <- glmnet(x = x.train[, sig.vars.index], y = y.train,
                        alpha = 1, family = "gaussian", lambda = lasso.fit.cv$lambda.1se)
top.lasso.pred <- predict(top.lasso.fit, newx = as.matrix(scaled.test_data[, sig.vars.index]))
test_results(top.lasso.pred)

## Training-set predictions and accuracy, used later for age-bias correction
top.lasso.train <- predict(top.lasso.fit, newx = as.matrix(scaled.train_data[, sig.vars.index]))
r <- cor.test(train_labels, top.lasso.train)$estimate
r.sq <- summary(lm(train_labels ~ top.lasso.train))$r.squared
MAE <- mean(abs(top.lasso.train - train_labels), na.rm = TRUE)
rmse <- sqrt(sum((top.lasso.train - train_labels)^2) / length(top.lasso.train))
age.bias <- cor.test(train_labels, (top.lasso.train - train_labels))$estimate
value <- round(c(r, r.sq, MAE, age.bias, rmse), 3)
results <- data.frame(metric = c("r", "R^2", "MAE", "Age.bias", "rmse"),
                      value = unname(value))
# full_width / position are kable_styling() arguments, not bootstrap options
kable(results) %>%
  kable_styling(
    bootstrap_options = c("striped", "condensed", "responsive"),
    full_width = FALSE,
    position = "center"
  )
```


### Correct for age bias and save
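Estimate the age bias of the LASSO predictions in the training data (GAM, linear and quadratic fits of predicted on chronological age), and use these fits to correct the brain-age gap in the test data.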

```{r LASSO PRED df}
# Test-set LASSO predictions (one row per eid and wave)
df.pred <- data.frame(
  eid = label.id,
  LASSO.BA = top.lasso.pred %>% as.numeric(),
  Age = test_labels,
  wave = df.Test$wave)

# Training-set LASSO predictions. Ids come from df.Train, the same data frame
# train_labels was taken from, so rows are aligned by construction.
df.lasso.out <- data.frame(
  eid = df.Train$eid,
  Age = train_labels,
  LASSO.BA = top.lasso.train %>% as.numeric())

# Age-bias models fitted on the training data only
sm.rel <- gam(LASSO.BA ~ s(Age, k = 8), data = df.lasso.out)
lm.Age <- lm(LASSO.BA ~ Age, data = df.lasso.out)
lm.Age2 <- lm(LASSO.BA ~ poly(Age, 2, raw = TRUE), data = df.lasso.out)

# Training set: GAM-corrected gap is simply the GAM residual
df.lasso.out %<>%
  mutate(LASSO.BAG = LASSO.BA - Age,
         LASSO.BAG_corr_gam = sm.rel$residuals)

# Test set: subtract the brain age expected from chronological age
df.pred %<>%
  mutate(LASSO.BAG = LASSO.BA - Age,
         ExpAge = lm.Age$coefficients[[1]] + Age * lm.Age$coefficients[[2]],
         ExpAge2 = lm.Age2$coefficients[[1]] + Age * lm.Age2$coefficients[[2]] +
           Age^2 * lm.Age2$coefficients[[3]],
         LASSO.BAG_corr = LASSO.BA - ExpAge,
         LASSO.BAG_corr2 = LASSO.BA - ExpAge2,
         LASSO.BAG_corr_gam = LASSO.BA - predict(sm.rel, df.pred))

# ggpairs(df.pred %>%
#           select(LASSO.BA,
#                  Age,
#                  LASSO.BAG,
#                  ExpAge,
#                  LASSO.BAG_corr,
#                  LASSO.BAG_corr2,
#                  LASSO.BAG_corr_gam),
#         progress = F)


gs <- ggplot(data = df.pred,
             mapping = aes(x = Age,
                           y = LASSO.BA)) +
  geom_line(mapping = aes(group = eid), size = .3, color = "grey50") +
  geom_pointdensity(adjust = 4) +
  geom_abline(intercept = 0,
              slope = 1,
              colour = "grey60",
              linetype = 4,
              size = 1.5) +
  geom_smooth(method = "lm",
              color = "#1b9e77",
              size = 1.5) +
  geom_smooth(color = "#d95f02",
              size = 1.5) +
  theme_classic() +
  theme(legend.position = 'none',
        axis.text = element_text(size = 16),
        axis.title = element_text(size = 20)) +
  scale_color_viridis_c() +
  ylab("Brain Age") +
  xlab("Chronological Age")

ggsave("figures/TestLASSO.png",
       dpi = 500,
       plot = gs,
       width = 10,
       height = 10,
       units = "cm")


# Columns to attach to the main data frames; all_of() selects by an external
# character vector explicitly (a bare `select(vars)` is deprecated).
vars <- c("eid",
          "LASSO.BA",
          "LASSO.BAG",
          "LASSO.BAG_corr_gam")
df.Test <-
  left_join(df.Test,
            df.pred %>% select(all_of(c(vars, "wave"))))

df.Train <-
  left_join(df.Train,
            df.lasso.out %>% select(all_of(vars)))

save(df.Test, df.Train, T1w_vars, file = file.path(data_folder, "ResultsBrainAge_XGB_LASSO.Rda"))
```


## scaling
```{r scaling}
## apply scaling according to Smith 2.8 formula in Smith Neuroimage 2019.
## Each model gets its own lambda: reusing one variable for both models made
## BAG_tp0 (an XGB-derived gap) be scaled with the LASSO estimate.


load(file.path(data_folder, "ResultsBrainAge_XGB_LASSO.Rda"))

set <- c(rep("Test", length(df.Test$BAG_corr_gam)),
         rep("Train", length(df.Train$BAG_corr_gam)))
age <- c(df.Test$age, df.Train$age)

# Y0: age rescaled to [0, 1] over test + train
scaled_age <- (age - min(age)) / (max(age) - min(age))

## XGB
delta <- c(df.Test$BAG_corr_gam, df.Train$BAG_corr_gam)
# eq.18 (log(|delta|)) = D0 + lambda*scaled_aged
mod <- lm(log(abs(delta)) ~ scaled_age)
summary(mod)

# lambda =  exp(lambda) - 1
lambda_XGB <- exp(coef(mod)[["scaled_age"]]) - 1
mean(abs(delta)) * lambda_XGB

# eq.17 delta = delta0(1 + lamda*scaled_age)
delta_0_XGB <- delta / (1 + lambda_XGB * scaled_age)

par(mfrow = c(1, 3))
plot(age, delta)
plot(age, delta_0_XGB)
plot(delta, delta_0_XGB)


## LASSO (same procedure)
delta <- c(df.Test$LASSO.BAG_corr_gam, df.Train$LASSO.BAG_corr_gam)
mod <- lm(log(abs(delta)) ~ scaled_age)
summary(mod)

# lambda =  exp(lambda) - 1
lambda_LASSO <- exp(coef(mod)[["scaled_age"]]) - 1
mean(abs(delta)) * lambda_LASSO

# eq.17 delta = delta0(1 + lamda*scaled_age)
delta_0_LASSO <- delta / (1 + lambda_LASSO * scaled_age)

par(mfrow = c(1, 3))
plot(age, delta)
plot(age, delta_0_LASSO)
plot(delta, delta_0_LASSO)


df.Test$delta_0_XGB <- delta_0_XGB[set == "Test"]
df.Train$delta_0_XGB <- delta_0_XGB[set == "Train"]

df.Test$delta_0_LASSO <- delta_0_LASSO[set == "Test"]
df.Train$delta_0_LASSO <- delta_0_LASSO[set == "Train"]

# BAG_tp comes from the XGB within-test-set CV, so it uses the XGB lambda
df.Test$BAG_tp0 <- df.Test$BAG_tp / (1 + lambda_XGB * scaled_age[set == "Test"])
save(df.Test,
     df.Train,
     file = file.path(data_folder, "ResultsBrainAge_XGB_LASSO.Rda"))

```
back to top