---
title: "Distributional Regression Illustration"
author: "Tim Wolock"
date: "05/03/2021"
output:
  bookdown::pdf_document2:
    toc: true
    number_sections: false
bibliography: references.bib
---

```{r setup, include=FALSE}
library(knitr)
library(kableExtra)
opts_chunk$set(echo = TRUE, message = F, warning = F,
               error = F,
               fig.align = 'center', fig.height = 3)
```

In this document, which accompanies Wolock *et al.* (2021), we will demonstrate how to fit a distributional model with a sinh-arcsinh likelihood in `brms` [@bürkner2018]. First, we will outline how to implement a custom family in `brms`. Then we will simulate sinh-arcsinh distributed data whose moments vary with covariates. Finally, we will fit one conventional regression model and two distributional models.

The likelihood we will be using comes from @jones2009. The density of a sinh-arcsinh random variable, $X$, is

$$
p(x\;|\;\mu, \sigma, \epsilon, \delta) = \frac{1}{\sigma\sqrt{2\pi}}\frac{\delta C_{\epsilon, \delta}(x_z)}{\sqrt{1+x_z^2}} \exp\left[-\frac{S_{\epsilon,\delta}(x_z)^2}{2}\right],
$$

where $x_z = (x-\mu)/\sigma$, $S_{\epsilon,\delta}(x) = \sinh(\epsilon + \delta\,\mathrm{asinh}(x))$, and $C_{\epsilon,\delta}(x) = \cosh(\epsilon + \delta\,\mathrm{asinh}(x))$. As expected, $\mu$ and $\sigma$ control the location and scale of the distribution, while $\epsilon$ controls its skewness and $\delta$ its tail weight. Both $\sigma$ and $\delta$ must be greater than zero.
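
As a sanity check on this formula, the density can be written directly in R. The following is a minimal sketch and is not part of the repository: the name `dsinhasinh` is an assumption chosen to mirror R's `d`/`r` naming convention.

```r
# Hypothetical helper (not defined in R/stan_funs.R): the sinh-arcsinh
# density written term by term from the formula above.
dsinhasinh <- function(x, mu = 0, sigma = 1, eps = 0, delta = 1) {
  z <- (x - mu) / sigma          # standardized value x_z
  w <- eps + delta * asinh(z)    # argument shared by S and C
  delta * cosh(w) / (sigma * sqrt(2 * pi) * sqrt(1 + z^2)) *
    exp(-sinh(w)^2 / 2)
}

# With eps = 0 and delta = 1 the density reduces to the normal density
x <- seq(-3, 3, by = 0.5)
normal_check <- all.equal(dsinhasinh(x), dnorm(x))

# For other (illustrative) parameter values it should integrate to one
total_mass <- integrate(dsinhasinh, -Inf, Inf,
                        mu = 0.1, sigma = 0.5, eps = -0.2, delta = 0.6)$value
```

The first check works because $\sinh(\mathrm{asinh}(x_z)) = x_z$ and $\cosh(\mathrm{asinh}(x_z)) = \sqrt{1 + x_z^2}$, so the extra factors cancel and only the normal kernel remains.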

## Setup

To step through this analysis yourself, clone or download the repository (<https://github.com/twolock/distreg-illustration>), open `distreg-illustration.Rproj` in RStudio, and open the `.Rmd` file in the RStudio editor. Be sure to have the following packages installed:

1.  `rstan`
2.  `brms`
3.  `data.table`
4.  `ggplot2`

`data.table` is not strictly necessary for the analysis, but it will be useful when constructing posterior summaries. Knitting this document also requires `knitr`, `kableExtra`, and `bookdown`.

First, we will load these required packages and configure some useful options.

```{r options, include=T, results='hide', message=F, warning=F}
library(brms)
library(rstan)
library(data.table)
library(ggplot2)

set.seed(10)
options(mc.cores = parallel::detectCores())
theme_set(theme_bw())
```

Now, we will load the functions necessary to fit a custom family in `brms`.

```{r stan_funs, include=T}
source('R/stan_funs.R')
stanvars <- stanvar(scode = stan_funs, block = "functions")
```

This `.R` file loads a `brms` family, two functions, and a string containing Stan functions into the global environment. The custom family, `sinhasinh`, will allow us to fit a model in `brms` with a likelihood that is not included in `brms` by default. It is defined as

```{r sinhFamily, eval=F}
sinhasinh <- custom_family(
  name = "sinhasinh",
  dpars = c("mu", "sigma", "eps", "delta"),
  links = c("identity", "log", "identity", "log"),
  lb = c(NA, NA, NA, NA),
  type = "real"
)
```

In the `dpars` argument, we define the names of the parameters of the custom family. Note that every custom family in `brms` must have "mu" as its first distributional parameter. With `links`, we define the link function for the linear model of each distributional parameter: both $\sigma$ and $\delta$ must be greater than zero, so we assign them `log` link functions. The `lb` argument sets the lower bound of each parameter; we leave all four as `NA`, which is sufficient here because every model below gives $\sigma$ and $\delta$ their own formula with a `log` link. The `type` argument declares whether the response is continuous (`"real"`) or integer-valued (`"int"`).

The R script we ran also adds two functions to the global environment:

1.  `log_lik_sinhasinh`: allows `brms` to calculate the log-likelihood of the sinh-arcsinh distribution using Stan functions that we will expose later on.
2.  `posterior_predict_sinhasinh`: allows `brms` to sample from the sinh-arcsinh distribution similarly using exposed functions from Stan.

Both of these functions are named in accordance with `brms` convention (`log_lik_FAMILYNAME` and `posterior_predict_FAMILYNAME`) so that we will be able to use the posterior prediction and checking tools built into the package.

Finally, the script sets a string called `stan_funs` in the global environment. This string contains two Stan functions, `sinhasinh_lpdf` and `sinhasinh_rng`, which allow us to calculate the sinh-arcsinh log-density and to draw sinh-arcsinh samples, respectively. The log-density function must have that particular name for the custom family to work. We will expose both functions to R after fitting the first model, and the two R functions defined above will wrap them.
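
The contents of `R/stan_funs.R` are not reproduced in this document. Wrappers of this kind usually follow the pattern shown in the `brms` custom-family vignette, so the sketch below is an assumption about their general shape rather than the repository's code. It presumes that `sinhasinh_lpdf` and `sinhasinh_rng` have already been exposed from Stan and take their arguments in the order shown.

```r
# Assumed shape of the wrappers, modelled on the brms custom-family vignette.
# `prep` is the prediction object brms passes in; `i` indexes one observation.
log_lik_sinhasinh <- function(i, prep) {
  mu    <- brms::get_dpar(prep, "mu", i = i)
  sigma <- brms::get_dpar(prep, "sigma", i = i)
  eps   <- brms::get_dpar(prep, "eps", i = i)
  delta <- brms::get_dpar(prep, "delta", i = i)
  y <- prep$data$Y[i]                       # observed outcome for row i
  sinhasinh_lpdf(y, mu, sigma, eps, delta)  # exposed Stan function
}

posterior_predict_sinhasinh <- function(i, prep, ...) {
  mu    <- brms::get_dpar(prep, "mu", i = i)
  sigma <- brms::get_dpar(prep, "sigma", i = i)
  eps   <- brms::get_dpar(prep, "eps", i = i)
  delta <- brms::get_dpar(prep, "delta", i = i)
  sinhasinh_rng(mu, sigma, eps, delta)      # exposed Stan function
}
```

Each wrapper receives one posterior draw per element of the parameter vectors, which is why the Stan functions need to be exposed in vectorized form before these wrappers can be called.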

For convenience, we also define a native R function to produce sinh-arcsinh samples.

```{r sinhSampler}
rsinhasinh <- function(n, mu = 0, sigma = 1, eps = 0, delta = 1) {
  mu + sigma * sinh((asinh(rnorm(n, 0, 1)) - eps) / delta)
}
```
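
One quick way to check a sampler like this is to compare a sample statistic with its closed-form value. The sketch below repeats the sampler so that it is self-contained and uses illustrative parameter values that are assumptions, not estimates from the paper. Because the transformation from the standard normal draw to the outcome is monotone increasing (for $\sigma, \delta > 0$), the median of the outcome is the image of the normal median, zero.

```r
# Same sampler as above, repeated so this sketch runs on its own
rsinhasinh <- function(n, mu = 0, sigma = 1, eps = 0, delta = 1) {
  mu + sigma * sinh((asinh(rnorm(n, 0, 1)) - eps) / delta)
}

# Illustrative parameter values (assumed for this check only)
mu <- 0.1; sigma <- 0.1; eps <- -0.2; delta <- 0.6
x <- rsinhasinh(1e5, mu, sigma, eps, delta)

# Median implied by pushing the normal median (0) through the transform
theoretical_median <- mu + sigma * sinh(-eps / delta)
median_check <- c(sample = median(x), theoretical = theoretical_median)
median_check
```

The two numbers should agree up to Monte Carlo error.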

## Simulated Data

We will generate sinh-arcsinh distributed data whose moments vary with covariates, and then fit our models to it. First, we define the population of interest.

```{r definePop, results='asis'}
N <- 5000
ages <- round(runif(N, 15, 64))
pct_f <- 0.6
sexes <- c('Male', 'Female')[rbinom(N, 1, pct_f) + 1]

# Load everything into a dataframe
age_df <- data.frame(age = ages, sex = sexes, id = 1:N)
age_df$scaled_age <- (ages - mean(ages)) / sd(ages)
age_df$sex <- relevel(factor(age_df$sex), ref = 'Female')
age_df$age_bin <- with(age_df, age - age %% 5)
```

Now we will use distributional regression coefficients from Wolock *et al.* (2021) to predict distributional parameters for each individual. These coefficients come from a model with linear age-sex interactions for all four parameters, so we need to be sure to define the model matrix in the same way.

```{r predictParams}
# Hard-coded coefficients from paper
B_mu <- c(0.1, -0.02, -0.18, 0.01)
B_sigma <- c(-2.44, -0.11, 0.00, 0.26)
B_eps <- c(-0.2, 0.01, 0.36, 0.06)
B_delta <- c(-0.46, -0.05, 0.01, 0.04)

# Create model matrix for simulating data
X <- model.matrix(id ~ scaled_age * sex, data = age_df)

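# Get true distributional parameter values
# (as.vector() turns the one-column matrix products into plain vectors)
age_df$mu <- as.vector(X %*% B_mu)
age_df$delta <- exp(as.vector(X %*% B_delta))
age_df$sigma <- exp(as.vector(X %*% B_sigma)) * age_df$delta
age_df$eps <- as.vector(X %*% B_eps)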
```
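
The coefficient vectors are only meaningful if their order matches the columns of the model matrix. The sketch below, using a four-row toy data frame (an assumption for illustration), shows the column order that `scaled_age * sex` produces with `Female` as the reference level, and how the inverse `log` links turn linear predictors into positive parameter values.

```r
# Toy data frame with the same variable coding as age_df
toy <- data.frame(
  scaled_age = c(0, 0, 1, 1),
  sex = relevel(factor(c("Female", "Male", "Female", "Male")), ref = "Female")
)
X_toy <- model.matrix(~ scaled_age * sex, data = toy)
colnames(X_toy)
# "(Intercept)" "scaled_age" "sexMale" "scaled_age:sexMale"

# Coefficients copied from the chunk above
B_sigma <- c(-2.44, -0.11, 0.00, 0.26)
B_delta <- c(-0.46, -0.05, 0.01, 0.04)

# Inverse log links; sigma is additionally multiplied by delta, as above
delta_toy <- exp(as.vector(X_toy %*% B_delta))
sigma_toy <- exp(as.vector(X_toy %*% B_sigma)) * delta_toy
cbind(toy, sigma = sigma_toy, delta = delta_toy)
```

For the first row (a woman of mean age) only the intercepts contribute, so `delta` is `exp(-0.46)` and `sigma` is `exp(-2.44 - 0.46)`.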

With all four distributional parameters in hand for every individual, we can sample sinh-arcsinh distributed outcomes. In the model these coefficients come from, the dependent variable is the log-ratio of partner's age to respondent's age.

```{r simOutcome}
# Sample with function we defined earlier
age_df$log_ratio <- with(age_df, rsinhasinh(N, mu, sigma, eps, delta))
# Calculate partner age
age_df$p_age <- round(exp(age_df$log_ratio) * age_df$age)
```
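
As a small worked example of the conversion from log-ratio to partner age, consider a 30-year-old respondent and three illustrative log-ratios (values chosen for this sketch, not drawn from the simulation):

```r
age <- 30
log_ratio <- c(-0.1, 0, 0.1)   # illustrative values only

# exp(log_ratio) is the ratio of partner age to respondent age
p_age_toy <- round(exp(log_ratio) * age)
p_age_toy
# 27 30 33
```

A log-ratio of zero corresponds to a partner of the same age, and equal steps on the log scale correspond to equal multiplicative changes in partner age.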

We can plot the resulting distribution for several five-year age bins by sex.

```{r simHist, echo=F, fig.height=4}
ggplot(age_df[age_df$age_bin %in% c(20, 25, 30),],
       aes(x = p_age)) +
  geom_bar() +
  facet_grid(sex ~ age_bin, scales='free_y') +
  coord_cartesian(expand = 0) +
  labs(x = 'Partner age', y = 'Number of observations')
```

## Fitting Models

Now, we will fit three models in `brms`:

1.  **Conventional:** regression with linear age-sex interaction for the location parameter and constant higher-order parameters
2.  **Distributional 1:** distributional regression with linear age-sex interaction for location and independent age and sex effects for higher-order parameters
3.  **Distributional 2:** distributional regression with linear age-sex interactions for all four parameters

### Conventional Regression

First, we fit the conventional model. We use the `bf` function to build a `brms` formula with models for all four distributional parameters in the `sinhasinh` family. Note that we use the first formula (corresponding to `mu`) to set the outcome variable. We include the `stanvars` object defined by `R/stan_funs.R` in the model using the `stanvars` argument of `brm`.

```{r convBRM, results='hide'}
bf_1 <- bf(
  log_ratio ~ scaled_age * sex,
  sigma ~ 1,
  eps ~ 1,
  delta ~ 1
)

brm_1 <- brm(formula = bf_1,
             data = age_df,
             family = sinhasinh,
             stanvars = stanvars)
```

```{r brm1PostSummary, results='asis'}
kbl(posterior_summary(brm_1), digits = 2, booktabs = T) %>%
    kable_styling(c("striped"))
```

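Because we gave the last three parameters explicit intercept-only formulas, which is a slightly unconventional specification, `brm` reports an intercept for each of the four distributional parameters (for example, `b_sigma_Intercept`). The intercepts for `sigma` and `delta` are on the log scale.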

The R functions we defined in `R/stan_funs.R` give us what we need to use the posterior checking tools built into `brms`. We just need to expose the Stan functions that the R functions rely on.

```{r exposeFns, results='hide', error=F, warning=F, message=F}
expose_functions(brm_1, vectorize = T, show_compiler_warnings=F)
```

### Distributional Regression 1

To fit a distributional model in `brms`, we just modify the `bf` formula.

```{r distBRM1, results='hide'}
bf_2 <- bf(
  log_ratio ~ scaled_age * sex,
  sigma ~ scaled_age + sex,
  eps ~ scaled_age + sex,
  delta ~ scaled_age + sex
)

brm_2 <- brm(formula = bf_2,
             data = age_df,
             family = sinhasinh,
             stanvars = stanvars)
```

```{r brm2PostSummary, results='asis'}
kbl(posterior_summary(brm_2), digits = 2, booktabs = T) %>%
    kable_styling(c("striped"))
```

We can see that we now have slopes with respect to age and sex for all four distributional parameters.

### Distributional Regression 2

Finally, we fit a distributional model with age-sex interactions for all four distributional parameters, which we know is the correct model.

```{r distBRM2, results='hide'}
bf_3 <- bf(
  log_ratio ~ scaled_age * sex,
  sigma ~ scaled_age * sex,
  eps ~ scaled_age * sex,
  delta ~ scaled_age * sex
)

brm_3 <- brm(formula = bf_3,
             data = age_df,
             family = sinhasinh,
             stanvars = stanvars)
```

```{r brm3PostSummary, results='asis'}
kbl(posterior_summary(brm_3), digits = 2, booktabs = T) %>%
    kable_styling(c("striped"))
```

## Model Comparison

Because we have defined the `log_lik` and `posterior_predict` functions `brms` expects, we can use the LOO-CV functions [@vehtari2017] built into the package. Note that this step can be computationally intensive.

```{r loo}
loo_res <- loo(brm_1, brm_2, brm_3)
```

This function estimates the expected log pointwise predictive density (ELPD) for each of the three models and compares them. We can print a single model's results first.

```{r looBRM1}
print(loo_res$loos$brm_1)
```

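We can also print the comparison. A negative value of `elpd_diff` indicates that a model has a lower ELPD than the best-performing model, which appears in the first row.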

```{r looDiffs}
print(loo_res$diffs)
```

We can see that, when we compare the two distributional models (`brm_2` and `brm_3`), the absolute value of the ratio of the ELPD difference to the standard error of the difference is `r sprintf('%0.2f', with(loo_res, abs(diffs[2,'elpd_diff']/diffs[2,'se_diff'])))`, suggesting that the second distributional model fits substantially better than the first. This calculation uses the second row of the comparison table, so it assumes that `brm_3` ranks first and `brm_2` second.
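
The ratio used in the sentence above can be wrapped in a small helper. The sketch below is hypothetical: `elpd_ratio` is not part of `brms` or `loo`, and `toy_diffs` contains invented placeholder numbers that only mimic the row and column layout of a comparison table, so that the arithmetic can be followed by hand.

```r
# Hypothetical helper: |ELPD difference| / SE of the difference for one row
elpd_ratio <- function(diffs, row = 2) {
  abs(diffs[row, "elpd_diff"] / diffs[row, "se_diff"])
}

# Placeholder numbers (NOT results from the models in this document);
# rows are ordered from best to worst, as in a comparison table
toy_diffs <- matrix(
  c(0, -10, -50,    # elpd_diff
    0,   4,  12),   # se_diff
  ncol = 2,
  dimnames = list(c("model_a", "model_b", "model_c"),
                  c("elpd_diff", "se_diff"))
)

elpd_ratio(toy_diffs)           # 10 / 4 = 2.5
elpd_ratio(toy_diffs, row = 3)  # 50 / 12
```

The first row always has a difference of zero because the best model is compared with itself, which is why the helper defaults to the second row.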

## Posterior Prediction

We can create an evenly spaced `data.frame` to predict over.

```{r predDF}
pred_df <- merge(15:64, c('Female', 'Male'))
names(pred_df) <- c('age', 'sex')
pred_df$scaled_age <- (pred_df$age - mean(ages)) / sd(ages)
pred_df$sex <- relevel(factor(pred_df$sex), ref = 'Female')
pred_df$log_ratio <- rep(0, nrow(pred_df))
pred_df$p_age <- round(exp(pred_df$log_ratio) * pred_df$age)
```
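
An equivalent way to build such a grid is `expand.grid`, which names the columns directly. The sketch below uses a short stand-in vector, `ages_toy`, in place of the simulated `ages` (an assumption made so that the example runs on its own). The point to note is that the grid is standardized with the mean and standard deviation of the ages used for fitting, not of the grid itself.

```r
# Stand-in for the simulated ages used to fit the models (assumed values)
ages_toy <- c(18, 25, 25, 40, 63)

# One row per age-sex combination
pred_toy <- expand.grid(age = 15:64,
                        sex = c("Female", "Male"),
                        stringsAsFactors = FALSE)

# Standardize with the fitting data's mean and sd so that scaled_age
# means the same thing at prediction time as it did at fitting time
pred_toy$scaled_age <- (pred_toy$age - mean(ages_toy)) / sd(ages_toy)
pred_toy$sex <- relevel(factor(pred_toy$sex), ref = "Female")

dim(pred_toy)   # 100 rows (50 ages x 2 sexes), 3 columns
```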

We put all three fit objects into a named list to make prediction slightly easier: we can then apply the same prediction function to every model with `lapply`.

```{r pred}
fit_l <- list('Conv' = brm_1,
              'Dist 1' = brm_2,
              'Dist 2' = brm_3)
```

### Posterior Predictive Distributions

We can get posterior predictive samples and construct a `data.frame` for histograms.

```{r getPost}
posterior_l <- lapply(fit_l, posterior_predict, newdata=pred_df)

# Using data.table for convenience
post_dt <- rbindlist(lapply(posterior_l, function(x) {
  melt(cbind(data.table(pred_df), t(x)),
       id.vars=names(pred_df))}),
  idcol = 'Model')
post_dt[, p_age_post := round(exp(value) * age)]

post_hist_dt <- post_dt[, .N, by = .(Model, sex, age, p_age = p_age_post)]
post_hist_dt[, total_N := sum(N), by=.(Model, sex, age)]
post_hist_dt[, share := N / total_N]

post_hist_df <- data.frame(post_hist_dt)
```

We can compare our predictive distributions to the observed distributions.

```{r postPlot, fig.height=4}
hist_dt <- data.table(age_df)[, .N, by=.(age, sex, p_age)]
hist_dt[, total_N := sum(N), by=.(age, sex)]
hist_dt[, share := N / total_N]
hist_df <- data.frame(hist_dt)

ggplot() + 
  geom_bar(data = hist_df[hist_df$age %in% c(24, 32, 41),],
           aes(x = p_age,
               y = share),
           stat = 'identity') +
  geom_step(data = post_hist_df[post_hist_df$age %in% c(24, 32, 41) &
                                  post_hist_df$Model == 'Dist 2',],
            aes(x = p_age,
                y = share),
            direction = 'hv') +
  facet_grid(sex ~ age) +
  coord_cartesian(expand = 0)

```

### Posterior Distributional Parameter Summaries

Now, we will extract estimates of the distributional parameters.

```{r getDpar}
# Get all dpar predictions
prediction_l <- lapply(fit_l, prepare_predictions, newdata=pred_df)
# A data.table-y way to get posterior summaries of dpars
dpar_l <- lapply(c('mu', 'sigma', 'eps', 'delta'), function(d) {
  lapply(prediction_l, function(x) {
    cbind(data.table(pred_df), t(apply(brms:::get_dpar(x, dpar = d), 2,
                                       quantile,
                                       probs=c(0.025, 0.5, 0.975))))
  })
})
names(dpar_l) <- c('mu', 'sigma', 'eps', 'delta')

dpar_df <- data.frame(rbindlist(lapply(dpar_l, rbindlist, idcol = 'Model'), idcol = 'dpar'))
dpar_df$dpar <- factor(dpar_df$dpar, levels = c('mu', 'sigma', 'eps', 'delta'))
```
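
The quantile step can be examined in isolation. The sketch below uses a matrix of random noise as a stand-in for a draws-by-observations matrix (an assumption about the shape being summarized; the values carry no meaning). It shows the shape of the summary and how the quantile column names change once the result is converted to a `data.frame`.

```r
# Stand-in draws matrix: 4000 posterior draws (rows) for 3 observations
# (columns). Random noise, used only to show shapes and names.
draws <- matrix(rnorm(4000 * 3), nrow = 4000, ncol = 3)

# apply() returns quantiles in rows, so transpose to one row per observation
q <- t(apply(draws, 2, quantile, probs = c(0.025, 0.5, 0.975)))
dim(q)                 # 3 observations by 3 quantiles
colnames(q)            # "2.5%" "50%" "97.5%"

# data.frame() makes the names syntactically valid
names(data.frame(q))   # "X2.5." "X50." "X97.5."
```

This renaming is why summary columns end up with names such as `X2.5.` after conversion with `data.frame()`.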

We can plot the posterior CIs to see the effects of adding distributional models.

```{r dparPlot, fig.height=6}
ggplot(dpar_df,
       aes(x = age,
           ymin = `X2.5.`,
           y = `X50.`,
           ymax = `X97.5.`,
           fill = sex,
           color = sex)) +
  geom_ribbon(alpha=0.3, color=NA) +
  geom_line() + 
  labs(x=NULL, y=NULL) +
  coord_cartesian(expand = 0) +
  facet_grid(dpar ~ Model, scales='free') + 
  theme_bw() + theme(legend.position = 'bottom')
```

From here, we should be able to use any other `brms` post-processing tools we like.

## References

<div id="refs"></div>

## Session Info

```{r sessionInfo, echo=F}
sessionInfo()
```
