// NOTE(review): this file appears to have been flattened by a lossy extraction step:
//   * the original line breaks are gone, so each `//` comment below now runs to the end
//     of its physical line and swallows the code that follows it;
//   * the #include targets and every template argument list are missing
//     (e.g. `as(args("init"))`, `conv_to::from(v)`);
//   * the text between `for (size_t i=1; i` and `(wrap(betafull)))` is missing, so the
//     end of objective() and the whole of gradient(vec) are not visible here, although
//     gradient() is still called further down.
//   Restore the file from version control before compiling or editing the logic.
//
// Overview of what is visible in namespace rstpm2:
//
// class aft_mixture
//   Constructor: copies fields out of the Rcpp::List `args` -- "init", "X", "Xc", "Xc0",
//     "XD", "XD0", "event", "time", "boundaryKnots", "interiorKnots", "q.const", "cure",
//     "mixture" and "delayed"; "time0" and "X0" are read only when `delayed` is true.
//     It builds the spline object `s` as ns(boundaryKnots, interiorKnots, q_matrix, 1, cure)
//     and sets kappa to 1.0e3.
//   rmult(m, v): returns a copy of `m` in which every column has been multiplied
//     element-wise by `v` (two overloads: vec and uvec).
//   objective(betafull): splits `betafull` into beta, betas and (only when `mixture` is
//     non-zero) betac; forms eta = X*beta, etaD = XD*beta, logtstar = log(time) - eta and
//     the spline terms etas / etaDs; accumulates in `pen` the squared parts of etaDs and
//     of 1-etaD that fall below eps = 1e-8 and then clamps both at eps. The rest of the
//     function is missing (see NOTE above).
//   survival(time, X): uses `init` as the coefficient vector and returns exp(-exp(etas)).
//   haz(time, X, XD): uses `init`; clamps etaDs and 1/time - etaD at eps = 1e-8 and
//     returns exp(etas + log(etaDs) + log(1/time - etaD)).
//   gradh(time, X, XD): same setup as haz(); returns join_rows(rmult(dloghdbeta, h),
//     rmult(dloghdbetas, h)), with the (1-pindexs) / (1-pindex) factors zeroing the
//     terms of observations whose values were clamped.
//   As the TODO below states, survival(), haz() and gradh() split `init` into beta and
//   betas only; they do not handle the mixture coefficients.
//
// aft_mixture_model_output(args)
//   Constructs an aft_mixture from `args` and dispatches on list["return_type"]:
//   "nmmin" (NelderMead) and "vmmin" (BFGS) run an optimiser starting from model.init and
//   return a list with "fail", "coef" and "hessian"; "objective", "gradient", "survival",
//   "haz" and "gradh" return the corresponding model method's result; any other value
//   prints "Unknown return_type." via REprintf and returns wrap(-1).
#include #include #include #include namespace rstpm2 { using namespace arma; using namespace Rcpp; // TODO: // include mixture models in .gradh(), .haz() and .survival() class aft_mixture { public: List args; vec init; mat X, X0, Xc, Xc0; mat XD, XD0; vec event; vec time, time0; vec boundaryKnots; vec interiorKnots; mat q_matrix; int cure; int mixture; ns s; bool delayed; double kappa; // scale for the quadratic penalty for monotone splines aft_mixture(Rcpp::List args) : args(args) { init = as(args("init")); X = as(args("X")); Xc = as(args("Xc")); Xc0 = as(args("Xc0")); XD = as(args("XD")); XD0 = as(args("XD0")); event = as(args("event")); time = as(args("time")); boundaryKnots = as(args("boundaryKnots")); interiorKnots = as(args("interiorKnots")); q_matrix = as(args("q.const")); cure = as(args("cure")); mixture = as(args("mixture")); s = ns(boundaryKnots, interiorKnots, q_matrix, 1, cure); delayed = as(args("delayed")); if (delayed) { time0 = as(args("time0")); X0 = as(args("X0")); } kappa = 1.0e3; } mat rmult(mat m, vec v) { mat out(m); out.each_col() %= v; return out; } mat rmult(mat m, uvec v) { mat out(m); out.each_col() %= conv_to::from(v); return out; } double objective(vec betafull) { vec beta, betac, betas, etac; if (mixture) { beta = betafull.subvec(0,X.n_cols-1); betac = betafull.subvec(X.n_cols, X.n_cols + Xc.n_cols - 1); betas = betafull.subvec(X.n_cols+Xc.n_cols, betafull.size()-1); } else { beta = betafull.subvec(0,X.n_cols-1); betas = betafull.subvec(X.n_cols, betafull.size()-1); } vec eta = X * beta; vec etaD = XD * beta; vec logtstar = log(time) - eta; vec etas = s.basis(logtstar) * betas; vec etaDs = s.basis(logtstar,1) * betas; vec etaDs_old = etaDs; // fix bounds on etaDs vec eps = etaDs*0. 
+ 1e-8; double pen = dot(min(etaDs,eps), min(etaDs,eps)); etaDs = max(etaDs, eps); // fix bounds on etaD pen += dot(min(1-etaD,eps), min(1-etaD,eps)); etaD = 1 - max(1-etaD, eps); // add penalty for monotone splines vec betasStar = s.q_matrix.t() * betas; for (size_t i=1; i(wrap(betafull))); } NumericVector gradient(NumericVector betafull) { vec value = gradient(as(wrap(betafull))); return as(wrap(value)); } vec survival(vec time, mat X) { vec beta = init.subvec(0,X.n_cols-1); vec betas = init.subvec(X.n_cols,init.size()-1); vec eta = X * beta; vec logtstar = log(time) - eta; vec etas = s.basis(logtstar) * betas; vec S = exp(-exp(etas)); return S; } mat haz(vec time, mat X, mat XD) { vec beta = init.subvec(0,X.n_cols-1); vec betas = init.subvec(X.n_cols,init.size()-1); vec eta = X * beta; vec etaD = XD * beta; vec logtstar = log(time) - eta; mat Xs = s.basis(logtstar); mat XDs = s.basis(logtstar,1); mat XDDs = s.basis(logtstar,2); vec etas = Xs * betas; vec etaDs = XDs * betas; vec etaDDs = XDDs * betas; // penalties vec eps = etaDs*0. + 1e-8; uvec pindexs = (etaDs < eps); uvec pindex = ((1.0/time - etaD) < eps); // fix bounds on etaDs etaDs = max(etaDs, eps); // fix bounds on etaD etaD = 1/time - max(1/time-etaD, eps); vec logh = etas + log(etaDs) + log(1/time -etaD); vec h = exp(logh); return h; } mat gradh(vec time, mat X, mat XD) { vec beta = init.subvec(0,X.n_cols-1); vec betas = init.subvec(X.n_cols,init.size()-1); vec eta = X * beta; vec etaD = XD * beta; vec logtstar = log(time) - eta; mat Xs = s.basis(logtstar); mat XDs = s.basis(logtstar,1); mat XDDs = s.basis(logtstar,2); vec etas = Xs * betas; vec etaDs = XDs * betas; vec etaDDs = XDDs * betas; // penalties vec eps = etaDs*0. 
+ 1e-8; uvec pindexs = (etaDs < eps); uvec pindex = ((1.0/time - etaD) < eps); // fix bounds on etaDs etaDs = max(etaDs, eps); // fix bounds on etaD etaD = 1/time - max(1/time-etaD, eps); vec logh = etas + log(etaDs) + log(1/time -etaD); vec h = exp(logh); mat dloghdbetas = Xs+rmult(XDs,1/etaDs % (1-pindexs)); mat dloghdbeta = -rmult(X,etaDs % (1-pindexs) % (1-pindex)) - rmult(X,etaDDs/etaDs % (1-pindexs) % (1-pindex)) - rmult(XD, (1-pindexs) % (1-pindex)/(1/time-etaD)); mat gradh = join_rows(rmult(dloghdbeta,h), rmult(dloghdbetas,h)); return gradh; } }; RcppExport SEXP aft_mixture_model_output(SEXP args) { using namespace Rcpp; using namespace arma; aft_mixture model(args); List list = as(args); std::string return_type = as(list["return_type"]); if (return_type == "nmmin") { // model.pre_process(); NelderMead nm; nm.trace = as(list["trace"]); nm.maxit = as(list["maxit"]); NumericVector betafull = as(wrap(model.init)); nm.optim(betafull,model); // model.post_process(); return List::create(_("fail")=nm.fail, _("coef")=wrap(nm.coef), _("hessian")=wrap(nm.hessian)); } else if (return_type == "vmmin") { // model.pre_process(); BFGS bfgs; bfgs.trace = as(list["trace"]); bfgs.maxit = as(list["maxit"]); NumericVector betafull = as(wrap(model.init)); bfgs.optim(betafull,model); // model.post_process(); return List::create(_("fail")=bfgs.fail, _("coef")=wrap(bfgs.coef), _("hessian")=wrap(bfgs.hessian)); } else if (return_type == "objective") return wrap(model.objective(model.init)); else if (return_type == "gradient") return wrap(model.gradient(model.init)); else if (return_type == "survival") return wrap(model.survival(as(list["time"]),as(list["X"]))); else if (return_type == "haz") return wrap(model.haz(as(list["time"]),as(list["X"]),as(list["XD"]))); else if (return_type == "gradh") return wrap(model.gradh(as(list["time"]),as(list["X"]),as(list["XD"]))); else { REprintf("Unknown return_type.\n"); return wrap(-1); } } } // namespace rstpm2