Raw File
#include <RcppArmadillo.h>
#include <c_optim.h>
#include <splines.h>
#include <aft.h>

namespace rstpm2 {
  using namespace arma;
  using namespace Rcpp;

  // Build the AFT model from an R list.
  // Required elements: "init", "X", "XD", "XD0", "event", "time",
  // "boundaryKnots", "interiorKnots", "q.const", "cure", "delayed";
  // when "delayed" is true, "time0" and "X0" are also read.
  // Optional element: "kappa" (weight of the monotone-spline penalty used by
  // objective()/gradientPenalty()); defaults to 1.0e3 when absent.
  aft::aft(SEXP list) : args(as<List>(list)) {
    init = as<vec>(args["init"]);
    X = as<mat>(args["X"]);
    XD = as<mat>(args["XD"]);
    XD0 = as<mat>(args["XD0"]);
    event = as<vec>(args["event"]);
    time = as<vec>(args["time"]);
    boundaryKnots = as<vec>(args["boundaryKnots"]);
    interiorKnots = as<vec>(args["interiorKnots"]);
    q_matrix = as<mat>(args["q.const"]);
    cure = as<int>(args["cure"]);
    // spline basis on log(t*); the meaning of the literal 1 and `cure` is
    // defined by ns() in splines.h (NOTE(review): confirm intercept/cure flags)
    s = ns(boundaryKnots, interiorKnots, q_matrix, 1, cure);
    delayed = as<bool>(args["delayed"]);
    if (delayed) {
      time0 = as<vec>(args["time0"]);
      X0 = as<mat>(args["X0"]);
    }
    // penalty weight: caller may override, otherwise keep the historical default
    kappa = args.containsElementNamed("kappa") ? as<double>(args["kappa"]) : 1.0e3;
  }
  // Scale row i of m by v(i), i.e. diag(v) * m.
  // m arrives by value, so its columns can be multiplied in place and the
  // modified copy returned; the caller's matrix is untouched.
  mat aft::rmult(mat m, vec v) {
    m.each_col() %= v;
    return m;
  }
  // Row-scaling by an unsigned (typically 0/1 indicator) vector: promote the
  // weights to double and reuse the double-precision overload.
  mat aft::rmult(mat m, uvec v) {
    return rmult(m, conv_to<vec>::from(v));
  }
  // Penalised negative log-likelihood at betafull = (beta, betas), where the
  // first X.n_cols entries are the regression coefficients (beta) and the
  // remainder are the spline coefficients (betas).
  //   f = pen - ( sum(event * logh) - sum(H) )   [ + delayed-entry terms ]
  // with logh = etas + log(etaDs) + log(1/time - etaD/time) and H = exp(etas),
  // evaluated on the spline basis at log(t*) = log(time) - X*beta.
  double aft::objective(vec betafull)
  {
    vec beta = betafull.subvec(0,X.n_cols-1);
    vec betas = betafull.subvec(X.n_cols,betafull.size()-1);
    vec eta = X * beta;
    vec etaD = XD * beta;
    vec logtstar = log(time) - eta;
    vec etas = s.basis(logtstar) * betas;
    vec etaDs = s.basis(logtstar,1) * betas;
    // fix bounds on etaDs: clamp to eps for the log and penalise min(etaDs,eps)^2
    // (rows with etaDs >= eps contribute the constant eps^2, so only rows
    // below eps affect the gradient)
    vec eps = etaDs*0. + 1e-8;
    double pen = dot(min(etaDs,eps), min(etaDs,eps));
    etaDs = max(etaDs, eps);
    // fix bounds on etaD: keep 1-etaD >= eps, penalising violations the same way
    pen += dot(min(1-etaD, eps), min(1-etaD,eps));
    etaD = 1 - max(1-etaD, eps);
    // quadratic penalty (weight kappa) on decreasing successive elements of
    // q_matrix' * betas, i.e. a monotonicity penalty on the spline coefficients
    vec betasStar = s.q_matrix.t() * betas;
    for (size_t i=1; i<betasStar.size(); i++) {
      double delta = betasStar(i)-betasStar(i-1);
      if (delta<0.0)
        pen += kappa*delta*delta;
    }
    vec logh = etas + log(etaDs) + log(1/time - etaD/time);
    vec H = exp(etas);
    double f = pen - (dot(logh,event) - sum(H));
    if (delayed) {
      // delayed entry: add back H at the entry time and penalise negative etaDs0
      vec eta0 = X0 * beta;
      vec logtstar0 = log(time0) - eta0;
      vec etas0 = s.basis(logtstar0) * betas;
      vec etaDs0 = s.basis(logtstar0,1) * betas;
      vec H0 = exp(etas0);
      // same 1e-8 threshold as the other clamps and as gradient()'s
      // delayed-entry penalty, so objective and gradient agree
      vec eps0 = etaDs0*0. + 1e-8;
      f += dot(min(etaDs0,eps0), min(etaDs0,eps0));
      f -= sum(H0);
    }
    return f;
  }
  // Gradient with respect to beta of
  //   kappa * sum_j min(delta_j, 0)^2,  delta = successive differences of Q*beta
  // (delta_j = (Q*beta)(j+1) - (Q*beta)(j)), i.e. the monotonicity penalty.
  // Returns a vector of length Q.n_cols. With fewer than two rows in Q there
  // are no differences and the gradient is zero (previously this underflowed
  // n-1 and read delta(0) out of bounds).
  vec aft::gradientPenalty(mat Q, vec beta) {
    const uword n = Q.n_rows;
    if (n < 2)
      return zeros<vec>(Q.n_cols);
    // row j of DQ is Q.row(j+1) - Q.row(j), i.e. the first-difference operator applied to Q
    mat DQ = Q.rows(1,n-1) - Q.rows(0,n-2);
    vec delta = DQ * beta;
    // only decreasing steps (delta < 0) are penalised; others contribute zero
    vec negDelta = min(delta, zeros<vec>(delta.n_elem));
    // d/dbeta sum_j kappa*negDelta_j^2 = 2*kappa * DQ' * negDelta
    return 2.0*kappa*(DQ.t()*negDelta);
  }
  // Analytic gradient of objective() at betafull = (beta, betas).
  // Notation: log(t*) = log(time) - X*beta, etas = Xs*betas, etaDs = XDs*betas,
  // etaDDs = XDDs*betas, so d etas/d beta = -etaDs*X and
  // d etaDs/d beta = -etaDDs*X.
  vec aft::gradient(vec betafull)
  {
    vec beta = betafull.subvec(0,X.n_cols-1);
    vec betas = betafull.subvec(X.n_cols,betafull.size()-1);
    vec eta = X * beta;
    vec etaD = XD * beta;
    vec logtstar = log(time) - eta;
    mat Xs = s.basis(logtstar);
    mat XDs = s.basis(logtstar,1);
    mat XDDs = s.basis(logtstar,2);
    vec etas = Xs * betas;
    vec etaDs = XDs * betas;
    vec etaDDs = XDDs * betas;
    // cumulative hazard H = exp(etas)
    vec H = exp(etas);
    mat dHdbetas = rmult(Xs,H);
    mat dHdbeta = -rmult(X,H % etaDs);
    // rows where objective() clamps (and penalises) etaDs and 1-etaD
    vec eps = etaDs*0. + 1e-8;
    uvec pindexs = (etaDs < eps);
    uvec pindex = ((1.0 - etaD) < eps);
    // gradient of min(etaDs,eps)^2 (= etaDs^2 on pindexs rows), unclamped etaDs
    mat pgrads = join_rows(-2.0*rmult(X,etaDs % etaDDs), 2.0*rmult(XDs,etaDs));
    // gradient of min(1-etaD,eps)^2 (= (1-etaD)^2 on pindex rows); betas-free
    mat pgrad = join_rows(-2.0*rmult(XD,1-etaD), Xs*0.0);
    // etas itself is never clamped, so its contribution to d logh/d beta uses
    // the unclamped etaDs on every row
    mat dloghdbeta = -rmult(X,etaDs);
    // clamp exactly as objective() does
    etaDs = max(etaDs, eps);
    etaD = 1 - max(1-etaD, eps);
    // logh = etas + log(etaDs) + log(1/time - etaD/time); the last two terms
    // are constant in the parameters on rows where they are clamped.
    // d/dbeta log(1/time - etaD/time) = -XD/(1-etaD) (time cancels).
    mat dloghdbetas = Xs + rmult(XDs,1/etaDs % (1-pindexs));
    dloghdbeta -= rmult(X,etaDDs/etaDs % (1-pindexs)) + rmult(XD,(1-pindex)/(1-etaD));
    // f = pen - (sum(event*logh) - sum(H))
    mat gradi = join_rows(-rmult(dloghdbeta,event)+dHdbeta, -rmult(dloghdbetas,event)+dHdbetas)
      + rmult(pgrad,pindex) + rmult(pgrads,pindexs);
    vec out = sum(gradi,0).t();
    out += join_cols(beta*0.0, gradientPenalty(s.q_matrix.t(), betas));
    if (delayed) {
      // delayed entry contributes + min(etaDs0,eps0)^2 - H0 per row; there is
      // no etaD0 term, so nothing involving XD0 is differentiated here.
      // NOTE: objective() must use the same 1e-8 threshold for this penalty.
      vec eta0 = X0 * beta;
      vec logtstar0 = log(time0) - eta0;
      mat Xs0 = s.basis(logtstar0);
      mat XDs0 = s.basis(logtstar0,1);
      mat XDDs0 = s.basis(logtstar0,2);
      vec etas0 = Xs0 * betas;
      vec etaDs0 = XDs0 * betas;
      vec etaDDs0 = XDDs0 * betas;
      vec H0 = exp(etas0);
      mat dHdbetas0 = rmult(Xs0,H0);
      mat dHdbeta0 = -rmult(X0,H0 % etaDs0);
      vec eps0 = etaDs0*0. + 1e-8;
      uvec pindexs0 = (etaDs0 < eps0);
      // penalty gradient uses the unclamped etaDs0 (the penalty is etaDs0^2 there)
      mat pgrads0 = join_rows(-2.0*rmult(X0,etaDs0 % etaDDs0), 2.0*rmult(XDs0,etaDs0));
      out += sum(join_rows(-dHdbeta0, -dHdbetas0) + rmult(pgrads0,pindexs0), 0).t();
    }
    return out;
  }
  // Rcpp-facing overload used by the optimisers: convert the R numeric
  // vector to an Armadillo column vector and evaluate the objective on it.
  double aft::objective(NumericVector betafull) {
    const vec coefficients = as<vec>(wrap(betafull));
    return objective(coefficients);
  }
  // Rcpp-facing overload used by the optimisers: evaluate the Armadillo
  // gradient and hand the result back as an R numeric vector.
  NumericVector aft::gradient(NumericVector betafull) {
    const vec coefficients = as<vec>(wrap(betafull));
    return as<NumericVector>(wrap(gradient(coefficients)));
  }
  // Survival S = exp(-exp(etas)) at the supplied times and design matrix,
  // using the stored coefficient vector `init`: its first X.n_cols entries
  // multiply X, the rest are spline coefficients on log(t*).
  vec aft::survival(vec time, mat X) {
    const uword p = X.n_cols;
    const vec beta = init.subvec(0, p-1);
    const vec betas = init.subvec(p, init.size()-1);
    const vec logtstar = log(time) - X*beta;
    return exp(-exp(s.basis(logtstar) * betas));
  }

  // Hazard h = exp(etas) * etaDs * (1/time - etaD) at the supplied times and
  // design matrices, using the stored coefficient vector `init`.
  // etaDs and (1/time - etaD) are floored at 1e-8 before taking logs.
  // The clamped rate is used directly rather than rebuilt as
  // 1/time - (1/time - rate), which loses precision to cancellation when
  // 1/time is large. Unused second-derivative work has been dropped.
  vec aft::haz(vec time, mat X, mat XD)
  {
    vec beta = init.subvec(0,X.n_cols-1);
    vec betas = init.subvec(X.n_cols,init.size()-1);
    vec eta = X * beta;
    vec etaD = XD * beta;
    vec logtstar = log(time) - eta;
    vec etas = s.basis(logtstar) * betas;
    vec etaDs = s.basis(logtstar,1) * betas;
    vec eps = etaDs*0. + 1e-8;
    // fix bounds on etaDs
    etaDs = max(etaDs, eps);
    // fix bounds on 1/time - etaD
    vec rate = max(1/time - etaD, eps);
    return exp(etas + log(etaDs) + log(rate));
  }
  
  // Gradient of the hazard computed by haz() with respect to (beta, betas),
  // one row per observation: gradh = h * d logh/d(beta, betas), where
  // logh = etas + log(etaDs) + log(1/time - etaD) with both factors floored at 1e-8.
  // Each clamped factor contributes zero derivative only on its own clamped
  // rows. etas is never clamped, so its term always uses the unclamped etaDs.
  mat aft::gradh(vec time, mat X, mat XD)
  {
    vec beta = init.subvec(0,X.n_cols-1);
    vec betas = init.subvec(X.n_cols,init.size()-1);
    vec eta = X * beta;
    vec etaD = XD * beta;
    vec logtstar = log(time) - eta;
    mat Xs = s.basis(logtstar);
    mat XDs = s.basis(logtstar,1);
    mat XDDs = s.basis(logtstar,2);
    vec etas = Xs * betas;
    vec etaDs = XDs * betas;
    vec etaDDs = XDDs * betas;
    // clamp indicators
    vec eps = etaDs*0. + 1e-8;
    uvec pindexs = (etaDs < eps);
    vec rate = 1/time - etaD;
    uvec pindex = (rate < eps);
    // d etas/d beta = -etaDs * X (unclamped etaDs)
    mat dloghdbeta = -rmult(X,etaDs);
    // fix bounds
    etaDs = max(etaDs, eps);
    rate = max(rate, eps);
    vec h = exp(etas + log(etaDs) + log(rate));
    // d log(etaDs) = -etaDDs/etaDs * X ; d log(rate) = -XD/rate (unclamped rows only)
    mat dloghdbetas = Xs + rmult(XDs,1/etaDs % (1-pindexs));
    dloghdbeta -= rmult(X,etaDDs/etaDs % (1-pindexs)) + rmult(XD,(1-pindex)/rate);
    return join_rows(rmult(dloghdbeta,h), rmult(dloghdbetas,h));
  }
  
  // R entry point: build an aft model from `args` and dispatch on
  // args$return_type:
  //   "nmmin" / "vmmin"  -> run Nelder-Mead / BFGS from model.init and return
  //                         list(fail, coef, hessian)
  //   "objective", "gradient" -> evaluate at model.init
  //   "survival", "haz", "gradh" -> evaluate at args$time, args$X (and args$XD)
  // Any other value prints a message and returns -1.
  // (Model pre/post-processing hooks are currently not called.)
  RcppExport SEXP aft_model_output(SEXP args) {
    aft model(args);
    List options = as<List>(args);
    const std::string kind = as<std::string>(options["return_type"]);

    if (kind == "nmmin") {
      NelderMead optimizer;
      optimizer.trace = as<int>(options["trace"]);
      optimizer.maxit = as<int>(options["maxit"]);
      NumericVector start = as<NumericVector>(wrap(model.init));
      optimizer.optim<aft>(start, model);
      return List::create(_("fail")=optimizer.fail,
                          _("coef")=wrap(optimizer.coef),
                          _("hessian")=wrap(optimizer.hessian));
    }
    if (kind == "vmmin") {
      BFGS optimizer;
      optimizer.trace = as<int>(options["trace"]);
      optimizer.maxit = as<int>(options["maxit"]);
      NumericVector start = as<NumericVector>(wrap(model.init));
      optimizer.optim<aft>(start, model);
      return List::create(_("fail")=optimizer.fail,
                          _("coef")=wrap(optimizer.coef),
                          _("hessian")=wrap(optimizer.hessian));
    }
    if (kind == "objective")
      return wrap(model.objective(model.init));
    if (kind == "gradient")
      return wrap(model.gradient(model.init));
    if (kind == "survival")
      return wrap(model.survival(as<vec>(options["time"]), as<mat>(options["X"])));
    if (kind == "haz")
      return wrap(model.haz(as<vec>(options["time"]), as<mat>(options["X"]), as<mat>(options["XD"])));
    if (kind == "gradh")
      return wrap(model.gradh(as<vec>(options["time"]), as<mat>(options["X"]), as<mat>(options["XD"])));

    REprintf("Unknown return_type.\n");
    return wrap(-1);
  }
  
  // RcppExport SEXP aft_objective_function(SEXP args)
  // {
  //   List list = as<List>(args);
  //   vec beta = as<vec>(list["beta"]);
  //   vec betas = as<vec>(list["betas"]);
  //   mat X = as<mat>(list["X"]);
  //   mat XD = as<mat>(list["XD"]);
  //   vec event = as<vec>(list["event"]);
  //   vec time = as<vec>(list["time"]);
  //   vec boundaryKnots = as<vec>(list["boundaryKnots"]);
  //   vec interiorKnots = as<vec>(list["interiorKnots"]);
  //   ns s(boundaryKnots, interiorKnots, 1);
  //   vec eta = X * beta;
  //   vec etaD = XD * beta;
  //   vec logtstar = log(time) - eta;
  //   vec etas = s.basis(logtstar) * betas;
  //   vec etaDs = s.basis(logtstar,1) * betas;
  //   vec eps = etaDs*0. + 1e-8;
  //   double pen = dot(min(etaDs,eps), min(etaDs,eps));
  //   etaDs = max(etaDs, eps);
  //   vec logh = etas + log(etaDs) + log(etaD+1.) - log(time);
  //   vec H = exp(etas);
  //   double f = pen - (dot(logh,event) - sum(H));
  //   return wrap(f);
  // }

} // namespace rstpm2

back to top