// Source: https://github.com/cran/rstpm2 (raw file: gsm.cpp)
// Tip revision: 45536c5580212aeb43a1cdf1a690e92cc843203c (45536c5), authored by Mark Clements on 07 January 2023, 02:40:02 UTC
// Package version: 1.5.9
#include <RcppArmadillo.h>
#include <splines.h>
#include <c_optim.h>
#include <gsm.h>

namespace rstpm2 {

  // Map a survival probability S onto the linear-predictor scale.
  // PH link: complementary log-log, log(-log(S)).
  // Any other link type is not implemented here and yields the placeholder -100.0.
  double gsm::link(double S) {
    if (link_type != PH)
      return -100.0;
    const double cumhaz = -std::log(S);
    return std::log(cumhaz);
  }
  // Inverse of link(): map a linear predictor back to a survival probability.
  // PH link: S = exp(-exp(eta)).
  // Any other link type is not implemented here and yields the placeholder 1.0e-10.
  double gsm::linkinv(double eta) {
    if (link_type != PH)
      return 1.0e-10;
    const double cumhaz = std::exp(eta);
    return std::exp(-cumhaz);
  }
  // Default constructor: an empty model (no spline terms, empty etap).
  // The scalar members are assigned defined values here. Before this, a
  // default-constructed gsm held indeterminate data, and reading it (e.g. in
  // rand()) was undefined behaviour.
  // Populate the object via gsm(Rcpp::List) before drawing from it.
  gsm::gsm() {
    link_type = PH;
    tmin = 0.0;
    tmax = 0.0;
    target = 0.0;
    log_time = false;
    index = 0;
  }
  double gsm::eta(double y) {
    double eta = etap(index);
    for (std::vector<gsm_term>::size_type i=0; i<terms.size(); i++)
      if (terms[i].x(index) != 0.0)
	eta += terms[i].x(index) * arma::sum(terms[i].ns1.eval(y,0) % terms[i].gamma);
    return eta;
  }
  // Objective function for the root finder used in rand():
  // eta(y) minus the member `target`. It is zero where the linear predictor
  // equals the target.
  double gsm::operator()(double y) {
    const double value = eta(y);
    return value - target;
  }
  double gsm::rand(double tentry, int index) {
    using std::log;
    double u = R::runif(0.0,1.0);
    double ymin = tentry == 0.0 ? (log_time ? log(tmin) : tmin) : (log_time ? log(tentry) : tentry);
    double ymax = log_time ? log(tmax) : tmax;
    this->index = index;
    target = (tentry==0.0 ? link(u) : link(u*linkinv(eta(ymin))));
    double root = std::get<0>(R_zeroin2_functor_ptr<gsm>(ymin, ymax, this, 1.0e-8, 100));
    return log_time ? std::exp(root) : root;
  }

  // Construct a simulator from an R list.
  //
  // List elements read here:
  //   "link_name" : string; only "PH" is implemented (link()/linkinv() return
  //                 placeholder constants for anything else).
  //   "tmin", "tmax", "inflate" : doubles. The search range becomes
  //                 [tmin/inflate, tmax*inflate]; inflate must be positive.
  //   "etap"      : numeric vector of per-row linear predictors.
  //   "terms"     : list of lists, each with "gamma", "knots",
  //                 "Boundary_knots", "intercept", "q_const", "cure" and "x".
  //   "log_time"  : logical; whether splines are evaluated on log time.
  //
  // Errors: a C++ exception is forwarded to R via forward_exception_to_r, and
  // any other exception raises an R error via Rf_error.
  // NOTE(review): both leave by raising an R error rather than rethrowing,
  // presumably so callers without a C++ exception barrier still get an R
  // error. Destructors of already-built members may therefore not run.
  gsm::gsm(Rcpp::List list) {
    try {
      using namespace Rcpp;
      std::string link_name = as<std::string>(list("link_name"));
      // Only PH is implemented. Previously any other name left link_type
      // unassigned, so fail loudly instead.
      if (link_name != "PH")
        Rcpp::stop("gsm: unsupported link_name '" + link_name + "'; only \"PH\" is implemented");
      link_type = PH;
      tmin = as<double>(list("tmin"));
      tmax = as<double>(list("tmax"));
      double inflate = as<double>(list("inflate"));
      // inflate divides tmin: zero, negative or NaN values give an invalid search range
      if (!(inflate > 0.0))
        Rcpp::stop("gsm: inflate must be positive");
      tmin = tmin / inflate;
      tmax = tmax * inflate;
      etap = as<arma::vec>(list("etap"));
      List lterms = as<List>(list("terms"));
      for (int i = 0; i < lterms.size(); i++) {
        List lterm = as<List>(lterms(i));
        gsm_term term;
        term.gamma = as<arma::vec>(lterm("gamma"));
        arma::vec knots = as<arma::vec>(lterm("knots"));
        arma::vec Boundary_knots = as<arma::vec>(lterm("Boundary_knots"));
        int intercept = as<int>(lterm("intercept"));
        arma::mat q_const = as<arma::mat>(lterm("q_const"));
        int cure = as<int>(lterm("cure"));
        term.ns1 = ns(Boundary_knots, knots, q_const, intercept, cure);
        term.x = as<arma::vec>(lterm("x"));
        terms.push_back(std::move(term));
      }
      log_time = as<bool>(list("log_time"));
      target = 0.0;
      index = 0;
    } catch (const std::exception& ex) {
      forward_exception_to_r(ex);
    } catch (...) {
      ::Rf_error("c++ exception (unknown reason)");
    }
  }

  // Convenience constructor for .Call entry points. It coerces the SEXP to an
  // Rcpp::List and delegates to gsm(Rcpp::List).
  // NOTE(review): the coercion runs before the delegated constructor's
  // try/catch, so a conversion failure propagates as a C++ exception to this
  // constructor's caller.
  gsm::gsm(SEXP args) : gsm(Rcpp::as<Rcpp::List>(args)) { }
  
  // .Call entry point (test helper): build a gsm from the R list `args` and
  // return one draw from gsm::rand(), using its default arguments as declared
  // in gsm.h.
  // BEGIN_RCPP/END_RCPP convert any C++ exception (from Armadillo, the root
  // finder or Rcpp::wrap) into an R error. Without them an exception would
  // unwind through R's C frames and abort the session.
  // The RNGScope lives inside the try so R's RNG state is saved back even on error.
  RcppExport SEXP test_read_gsm(SEXP args) {
    BEGIN_RCPP
    Rcpp::RNGScope rngScope;
    gsm gsm1(args);
    return Rcpp::wrap(gsm1.rand());
    END_RCPP
  }
  
}
back to top