#ifndef STAN_IO_ARRAY_VAR_CONTEXT_HPP #define STAN_IO_ARRAY_VAR_CONTEXT_HPP #include #include #include #include #include #include #include #include #include #include #include #include #include namespace stan { namespace io { /** * An array_var_context object represents a named arrays * with dimensions constructed from an array, a vector * of names, and a vector of all dimensions for each element. */ class array_var_context : public var_context { private: // Pair used in data maps template using data_pair_t = std::pair, std::vector>; std::map> vars_r_; // Holds data for reals std::map> vars_i_; // Holds data for doubles // When search for variable name fails, return one these const std::vector empty_vec_r_; const std::vector empty_vec_i_; const std::vector empty_vec_ui_; /** * Search over the real variables to check if a name is in the map * @param name The name of the variable to search for * @return logical indicating if the variable was found in the map of reals. */ bool contains_r_only(const std::string& name) const { return vars_r_.find(name) != vars_r_.end(); } /** * Check (1) if the vector size of dimensions is no smaller * than the name vector size; (2) if the size of the input * array is large enough for what is needed. * * @param names The names for each variable * @param array_size The total size of the vector holding the values we want * to access. * @param dims Vector holding the dimensions for each variable. * @return If the array size is equal to the number of dimensions, * a vector of the cumulative sum of the dimensions of each inner element of * dims. The return of this function is used in the add_* methods to get the * sequence of values For each variable. * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ template inline std::vector validate_dims( const std::vector& names, const T array_size, const std::vector>& dims) { const size_t num_par = names.size(); stan::math::check_less_or_equal("validate_dims", "array_var_context", dims.size(), num_par); std::vector elem_dims_total(dims.size() + 1); for (int i = 0; i < dims.size(); i++) { elem_dims_total[i + 1] = std::accumulate(dims[i].begin(), dims[i].end(), 1, std::multiplies()) + elem_dims_total[i]; } stan::math::check_less_or_equal("validate_dims", "array_var_context", elem_dims_total[dims.size()], array_size); return elem_dims_total; } /** * Adds a set of floating point variables to the floating point map. * @param names Names of each variable. * @param values The real values of variable in a contiguous * column major order container. * @param dims the dimensions for each variable. * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ void add_r(const std::vector& names, const std::vector& values, const std::vector>& dims) { std::vector dim_vec = validate_dims(names, values.size(), dims); for (size_t i = 0; i < names.size(); i++) { vars_r_.emplace(names[i], data_pair_t{{values.data() + dim_vec[i], values.data() + dim_vec[i + 1]}, dims[i]}); } } /** * Adds a set of floating point variables to the floating point map. * @param names Names of each variable. * @param values The real values of variable in an Eigen column vector. * @param dims the dimensions for each variable. * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ void add_r(const std::vector& names, const Eigen::VectorXd& values, const std::vector>& dims) { std::vector dim_vec = validate_dims(names, values.size(), dims); for (size_t i = 0; i < names.size(); i++) { vars_r_.emplace(names[i], data_pair_t{{values.data() + dim_vec[i], values.data() + dim_vec[i + 1]}, dims[i]}); } } /** * Adds a set of integer variables to the integer map. * @param names Names of each variable. * @param values The integer values of variable in a vector. * @param dims the dimensions for each variable. * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ void add_i(const std::vector& names, const std::vector& values, const std::vector>& dims) { std::vector dim_vec = validate_dims(names, values.size(), dims); for (size_t i = 0; i < names.size(); i++) { vars_i_.emplace(names[i], data_pair_t{{values.data() + dim_vec[i], values.data() + dim_vec[i + 1]}, dims[i]}); } } public: /** * Construct an array_var_context from only real value arrays. * * @param names_r names for each element * @param values_r a vector of double values for all elements * @param dim_r a vector of dimensions * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ array_var_context(const std::vector& names_r, const std::vector& values_r, const std::vector>& dim_r) { add_r(names_r, values_r, dim_r); } /** * Construct an array_var_context from an Eigen column vector. * * @param names_r names for each element * @param values_r a vector of double values for all elements * @param dim_r a vector of dimensions * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ array_var_context(const std::vector& names_r, const Eigen::VectorXd& values_r, const std::vector>& dim_r) { add_r(names_r, values_r, dim_r); } /** * Construct an array_var_context from only integer value arrays. * * @param names_i names for each element * @param values_i a vector of integer values for all elements * @param dim_i a vector of dimensions * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ array_var_context(const std::vector& names_i, const std::vector& values_i, const std::vector>& dim_i) { add_i(names_i, values_i, dim_i); } /** * Construct an array_var_context from arrays of both double * and integer separately * * @param names_r names for each element * @param values_r a vector of double values for all elements * @param dim_r a vector of dimensions * @param names_i names for each element * @param values_i a vector of integer values for all elements * @param dim_i a vector of dimensions * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ array_var_context(const std::vector& names_r, const std::vector& values_r, const std::vector>& dim_r, const std::vector& names_i, const std::vector& values_i, const std::vector>& dim_i) { add_i(names_i, values_i, dim_i); add_r(names_r, values_r, dim_r); } /** * Construct an array_var_context from arrays of both double * and integer separately * * @param names_r names for each element * @param values_r Eigen column vector of double elements. * @param dim_r a vector of dimensions * @param names_i names for each element * @param values_i a vector of integer values for all elements * @param dim_i a vector of dimensions * @throw std::invalid_argument when size of dimensions is less * then array size or array is not long enough to hold * the dimensions of the data. */ array_var_context(const std::vector& names_r, const Eigen::VectorXd& values_r, const std::vector>& dim_r, const std::vector& names_i, const std::vector& values_i, const std::vector>& dim_i) { add_i(names_i, values_i, dim_i); add_r(names_r, values_r, dim_r); } /** * Return true if this dump contains the specified * variable name is defined. This method returns true * even if the values are all integers. * * @param name Variable name to test. * @return true if the variable exists. */ bool contains_r(const std::string& name) const { return contains_r_only(name) || contains_i(name); } /** * Return true if this dump contains an integer * valued array with the specified name. * * @param name Variable name to test. * @return true if the variable name has an integer * array value. */ bool contains_i(const std::string& name) const { return vars_i_.find(name) != vars_i_.end(); } /** * Return the double values for the variable with the specified * name or null. * * @param name Name of variable. * @return Values of variable. * */ std::vector vals_r(const std::string& name) const { const auto ret_val_r = vars_r_.find(name); if (ret_val_r != vars_r_.end()) { return ret_val_r->second.first; } else { const auto ret_val_i = vars_i_.find(name); if (ret_val_i != vars_i_.end()) { return {ret_val_i->second.first.begin(), ret_val_i->second.first.end()}; } } return empty_vec_r_; } /** * Return the double values for the variable with the specified * name or null. * * @param name Name of variable. * @return Values of variable. * */ std::vector> vals_c(const std::string& name) const { const auto val_r = vars_r_.find(name); if (val_r != vars_r_.end()) { std::vector> ret_c(val_r->second.first.size() / 2); int comp_iter; int real_iter; for (comp_iter = 0, real_iter = 0; real_iter < val_r->second.first.size(); comp_iter += 1, real_iter += 2) { ret_c[comp_iter] = std::complex{ val_r->second.first[real_iter], val_r->second.first[real_iter + 1]}; } return ret_c; } else { const auto val_i = vars_i_.find(name); if (val_i != vars_i_.end()) { std::vector> ret_c(val_i->second.first.size() / 2); int comp_iter; int real_iter; for (comp_iter = 0, real_iter = 0; real_iter < val_i->second.first.size(); comp_iter += 1, real_iter += 2) { ret_c[comp_iter] = std::complex{ static_cast(val_i->second.first[real_iter]), static_cast(val_i->second.first[real_iter + 1])}; } return ret_c; } } return std::vector>{}; } /** * Return the dimensions for the double variable with the specified * name. * * @param name Name of variable. * @return Dimensions of variable. */ std::vector dims_r(const std::string& name) const { const auto ret_val_r = vars_r_.find(name); if (ret_val_r != vars_r_.end()) { return ret_val_r->second.second; } else { const auto ret_val_i = vars_i_.find(name); if (ret_val_i != vars_i_.end()) { return ret_val_i->second.second; } } return empty_vec_ui_; } /** * Return the integer values for the variable with the specified * name. * * @param name Name of variable. * @return Values. */ std::vector vals_i(const std::string& name) const { auto ret_val_i = vars_i_.find(name); if (ret_val_i != vars_i_.end()) { return ret_val_i->second.first; } return empty_vec_i_; } /** * Return the dimensions for the integer variable with the specified * name. * * @param name Name of variable. * @return Dimensions of variable. */ std::vector dims_i(const std::string& name) const { auto ret_val_i = vars_i_.find(name); if (ret_val_i != vars_i_.end()) { return ret_val_i->second.second; } return empty_vec_ui_; } /** * Check variable dimensions against variable declaration. * Only used for data read in from file. * * @param stage stan program processing stage * @param name variable name * @param base_type declared stan variable type * @param dims_declared variable dimensions * @throw std::runtime_error if mismatch between declared * dimensions and dimensions found in context. */ void validate_dims(const std::string& stage, const std::string& name, const std::string& base_type, const std::vector& dims_declared) const { stan::io::validate_dims(*this, stage, name, base_type, dims_declared); } /** * Return a list of the names of the floating point variables in * the dump. * * @param names Vector to store the list of names in. */ virtual void names_r(std::vector& names) const { names.clear(); names.reserve(vars_r_.size()); for (const auto& vars_r_iter : vars_r_) { names.push_back(vars_r_iter.first); } } /** * Return a list of the names of the integer variables in * the dump. * * @param names Vector to store the list of names in. */ virtual void names_i(std::vector& names) const { names.clear(); names.reserve(vars_i_.size()); for (const auto& vars_i_iter : vars_r_) { names.push_back(vars_i_iter.first); } } /** * Remove variable from the object. * * @param name Name of the variable to remove. * @return If variable is removed returns true, else * returns false. */ bool remove(const std::string& name) { return (vars_i_.erase(name) > 0) || (vars_r_.erase(name) > 0); } }; } // namespace io } // namespace stan #endif