1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#' Utility to find where a formula call matches a name.
#' This may be cleaner than using grep on strings:)
#' @param name quoted name to match
#' @param x right-hand-side of a formula call
#' @return index of the matching positions
grep_call = function(name,x) {
    local_function = function(x)
        if(length(x)==1) x==name else any(sapply(x, local_function))
    which(sapply(x, local_function))
}

#' Extract design information from an stpm2/gsm object and newdata
#' for use in C++
#' @param object stpm2/gsm object
#' @param newdata list or data-frame used for evaluation
#' @param inflate double value to inflate minimum and maximum times for root finding
#' @return list that can be read by `gsm ssim::read_gsm(SEX args)` in C++
#' @rdname gsm_design
#' @importFrom stats predict
#' @export
gsm_design = function(object, newdata, newdata0=NULL, t0=NULL, inflate=100) {
    stopifnot(inherits(object, "stpm2"),
              is.list(newdata),
              is.numeric(inflate),
              length(inflate) == 1,
              is.null(t0) || (is.numeric(t0) && length(t0)==1),
              is.null(newdata0) || all(dim(newdata)==dim(newdata0)))
    ## Assumed patterns:
    ## timeEffect := (ns|nsx)(log(timeVar),knots,Boundary.knots,centre=FALSE,derivs=(c(2,2)|c(2,1)))
    ## effect := timeEffect | otherEffect:timeEffect | timeEffect:otherEffect
    terms = attr(object@model.frame, "terms")
    factors = attr(terms, "factors")[-1,,drop=FALSE]
    variables = attr(terms, "variables")[-(1:2)]
    predvars = attr(terms, "predvars")[-(1:2)]
    indices = grep_call(object@timeVar, variables)
    if(length(indices)==0) stop("No timeVar in the formula -- unexpected error")
    index_time_variables = grep_call(object@timeVar, variables) # time variables
    index_time_effects = grep(object@timeVar,colnames(factors)) # components in the rhs with time variables
    ## We need to know how wide is each term
    nms = names(coef(object))
    term.labels = attr(terms, "term.labels")
    coef_index <-
        sapply(strsplit(nms, ":"), function(c) {
            if (length(c)>2)
                stop("current implementation only allows for main effects and two-way interactions")
            pmatchp = function(x, table)
                !is.na(pmatch(x, table))
            if(length(c)==1) {
                for (i in seq_along(term.labels)) {
                    t = term.labels[i]
                    if (pmatchp(t,c))
                        return(i)
                }
                if (c == "(Intercept)")
                    return(0)
                return(-1)
            }
            ## => length(c) == 2
            term.labels.split = strsplit(term.labels, ":")
            for (i in seq_along(term.labels)) {
                t = term.labels.split[[i]]
                if (all(pmatchp(t,c)))
                    return(i)
            }
            return(-1)
        })
    parse_ns = function(mycall,x,index_time_effect) {
        df = length(c(mycall$knots, mycall$Boundary.knots)) - 1
        stopifnot(mycall[[1]] == quote(nsx) || mycall[[1]] == quote(ns),
                  length(mycall[[2]])>1,
                  mycall[[2]][[1]] == quote(log), # assumes log
                  mycall[[2]][[2]] == object@timeVar, # what about a scalar product or divisor?
                  is.null(mycall$deriv) || (mycall$derivs[1] == 2 && mycall$derivs[2] %in% 1:2),
                  mycall$centre == FALSE) # doesn't allow for centering
        cure = !is.null(mycall$derives) && all(mycall$derivs == c(2,1))
        time = object@args$time
        q_const = attr(nsx(log(mean(time)), knots=mycall$knots,
                           Boundary.knots=mycall$Boundary.knots,
                           intercept=mycall$intercept),
                       "q.const")
        list(call = mycall,
             knots=mycall$knots,
             Boundary_knots=mycall$Boundary.knots,
             intercept=as.integer(mycall$intercept),
             gamma=coef(object)[which(coef_index %in% index_time_effect)],
             q_const = q_const,
             cure = as.integer(cure),
             x=x)
    }
    time = object@args$time
    newdata[[object@timeVar]] = mean(time) # NB: time not used
    Xp = predict(object, newdata=newdata, type="lpmatrix")
    index2 = which(!(coef_index %in% index_time_effects))
    etap = drop(Xp[, index2, drop=FALSE] %*% coef(object)[index2])
    if (!is.null(newdata0)) {
        newdata0[[object@timeVar]] = mean(time) # NB: time not used
        Xp0 = predict(object, newdata=newdata0, type="lpmatrix")
        etap0 = drop(Xp0[, index2, drop=FALSE] %*% coef(object)[index2])
    }
    list(type="gsm",
         link_name=object@args$link,
         tmin = min(time), # not currently used?
         tmax = max(time),
         inflate=as.double(inflate),
         etap=etap,
         etap0= if(!is.null(newdata0)) etap0 else etap*0,
         t0 = if(!is.null(t0)) pmax(1e-10,t0) else Inf,
         coefp = coef(object)[index2], # for debugging
         log_time=TRUE,
         terms =
             lapply(index_time_effects,
                    function(i) {
                        j = which(factors[,i] != 0)
                        if (length(j)==1)
                            return(parse_ns(predvars[[j]], rep(1, nrow(newdata)), i))
                        else {
                            if(length(j)>3)
                                stop("Current implementation only allows for two-way interaction terms")
                            if (j[1] %in% index_time_variables)
                                return(parse_ns(predvars[[j[1]]], eval(predvars[[j[2]]], newdata), i))
                            else return(parse_ns(predvars[[j[2]]], eval(predvars[[j[1]]], newdata), i))
                        }
                    })
         )
}