https://github.com/OlgaIsupova/dynamic-hdp
state.h
#ifndef STATE_H
#define STATE_H

#include "corpus.h"
#include <map>
#include <vector>
using namespace std;

/// hyperparameters: Gamma priors on the concentration parameters and the sampler settings
class hdp_hyperparameter
{
public:
    double m_gamma_a;   // Gamma prior on the concentration parameter gamma (first parameter)
    double m_gamma_b;   // Gamma prior on the concentration parameter gamma (second parameter)
    double m_alpha_a;   // Gamma prior on the concentration parameter alpha (first parameter)
    double m_alpha_b;   // Gamma prior on the concentration parameter alpha (second parameter)
    int    m_max_iter;  // maximum number of Gibbs sampling iterations

public:
    void setup_parameters(double _gamma_a, double _gamma_b,
                        double _alpha_a, double _alpha_b,
                        int _max_iter)
    {
        m_gamma_a   = _gamma_a;
        m_gamma_b   = _gamma_b;
        m_alpha_a   = _alpha_a;
        m_alpha_b   = _alpha_b;
        m_max_iter  = _max_iter;
    }
    // copies all hyperparameter values from _input_hyperparam
    void copy_parameters(const hdp_hyperparameter _input_hyperparam);
};
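
// Usage sketch: filling the hyperparameter container via setup_parameters().
// The numeric values below are placeholders chosen for illustration only, not
// settings prescribed anywhere in this code. Excluded from compilation.
#if 0
hdp_hyperparameter hyper;
hyper.setup_parameters(1.0, 1.0,   // Gamma prior on the concentration gamma
                       1.0, 1.0,   // Gamma prior on the concentration alpha
                       500);       // maximum number of Gibbs iterations
#endif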

typedef vector<int> int_vec;        // vector of ints
typedef vector<double> double_vec;  // vector of doubles

/// word info structure used in the main class
struct word_info
{
public:
    int m_word_index;        // index of this token's word in the vocabulary
    int m_table_assignment;  // table this token is currently assigned to
    //int m_topic_assignment; // this is extra information (recoverable via the table-to-topic map)
};

class doc_state
{
public:
    int m_doc_id;        // document id
    int m_doc_length;    // document length
    int m_num_tables;    // number of tables in this document
    word_info * m_words; // per-token word indices and table assignments

    int_vec m_table_to_topic;   // maps each table in this document to its global topic index
    int_vec m_word_counts_by_t; // word counts for each table
    
    //vector < vector<int> > m_words_by_zi; // stores the word idx indexed by z then i
public:
    doc_state();
    ~doc_state();
public:
    void setup_state_from_doc(const document * doc);
    void free_doc_state();
};
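
// Usage sketch: how the per-document bookkeeping fits together. Each token stores
// a table assignment, and m_table_to_topic maps that table to a global topic, so a
// token's topic is m_table_to_topic[m_words[i].m_table_assignment]. The pointer
// `doc` is assumed to be a const document* from corpus.h. Excluded from compilation.
#if 0
doc_state* d_state = new doc_state();
d_state->setup_state_from_doc(doc);                   // initialise from one document
int i = 0;                                            // first token in the document
int table = d_state->m_words[i].m_table_assignment;   // table of token i
int topic = d_state->m_table_to_topic[table];         // topic assigned to that table
int table_size = d_state->m_word_counts_by_t[table];  // number of tokens at that table
delete d_state;                                       // clean up
#endif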

class counts
{
public:
    int m_num_words;
    int m_total_num_tables;            // total number of tables across all topics
    int_vec m_num_tables_by_z;         // number of tables for each topic
    int_vec m_word_counts_by_z;        // word counts for each topic
    vector <int*> m_word_counts_by_zw; // word counts for [each topic, each word]

    counts();
    counts(int _m_num_words, int _m_total_num_tables,
           const int_vec& _m_num_tables_by_z,
           const int_vec& _m_word_counts_by_z,
           const vector<int*>& _m_word_counts_by_zw);
    ~counts();

    void set_counts(int _m_num_words, int _m_total_num_tables,
                    const int_vec& _m_num_tables_by_z,
                    const int_vec& _m_word_counts_by_z,
                    const vector<int*>& _m_word_counts_by_zw);
};

class word_counts
{
public:
    int_vec m_word_counts_by_z;        // word counts for each topic
    vector<int*> m_word_counts_by_zw;  // word counts for [each topic, each word]

    word_counts();
    ~word_counts();
};

class hdp_state
{
public:

/// corpus information (fixed values)
    int m_size_vocab;   // vocabulary size
    int m_total_words;  // total number of words in the corpus
    int m_num_docs;     // number of documents

/// document states
    doc_state** m_doc_states;

/// number of topics
    int m_num_topics;


/// total number of tables for all topics
    int m_total_num_tables;

/// by_z, by topic
/// by_d, by document, for each topic
/// by_w, by word, for each topic
/// by_t, by table for each document
    int_vec   m_num_tables_by_z; // how many tables each topic has
    int_vec   m_word_counts_by_z;   // word counts for each topic
    vector <int*> m_word_counts_by_zd; // word counts for [each topic, each doc]
    vector <int*> m_word_counts_by_zw; // word counts for [each topic, each word]

/// for online updates (local changes for the current document)
//  int_vec m_num_tables_by_z_update;
//  int_vec m_word_counts_by_z_update;
    int_vec m_word_counts_by_zd_online;  // word counts for the current document, for each topic
//  vector <int*> m_word_counts_by_zw_update;
    int m_num_topics_before_update;      // number of topics from the batch plus possible new topics
    
///

/// topic Dirichlet parameter
    double m_eta;

/// concentration parameters
    double m_gamma;   // top-level (corpus-level) concentration parameter
    double m_alpha;   // document-level concentration parameter
public:
    hdp_state();
    virtual ~hdp_state();
public:
    void   set_vocab_size(int _vocab_size);
    void   setup_state_from_corpus(const corpus* c);
    void   setup_doc_info_from_document(const document * doc);
    void   allocate_initial_space();
    void   allocate_space_for_online_updates();
    void   free_state();
    void   free_doc_info();
    void   free_counts_update();
    void   free_online_info();

    void   iterate_gibbs_state(bool remove, bool permute,
                               hdp_hyperparameter* hdp_hyperparam,
                               bool table_sampling=false);
    void   iterate_gibbs_state_online(bool remove, hdp_hyperparameter* hdp_hyperparam,
                                      bool table_sampling=false);

    void   sample_tables(doc_state* d_state, double_vec & q, double_vec & f);
    void   sample_tables_online(doc_state* d_state, double_vec & q, double_vec & f);

    void   sample_table_assignment(doc_state* d_state, int t, int* words, double_vec & q, double_vec & f);
    void   sample_table_assignment_online(doc_state* d_state, int t, int* words, double_vec & q, double_vec & f);

    void   sample_word_assignment(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);
    void   sample_word_assignment_online(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);

    void   doc_state_update(doc_state* d_state, int i, int update, int k=-1);
    void   doc_state_update_online(doc_state* d_state, int i, int update, int k=-1);

    void   compact_doc_state(doc_state* d_state, int* k_to_new_k);
    void   compact_doc_state_online(doc_state* d_state, int* k_to_new_k);

    void   compact_hdp_state();
    void   compact_hdp_state_online();

    void   save_state_to_file(char * name);
    void   save_state(char * name);
    void   save_state_ex(char * name, int sampler_num = -1);
    void   load_state_from_file(char * name);
    void   load_state_ex(char * name, int sampler_num = -1);
};
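
// Usage sketch: a typical batch Gibbs-sampling loop over this state. The flag values,
// hyperparameter settings, eta value, and the ordering of the setup calls are
// assumptions made for illustration and may differ from the implementation file;
// `c` is a corpus* loaded via corpus.h. Excluded from compilation.
#if 0
hdp_hyperparameter hyper;
hyper.setup_parameters(1.0, 1.0, 1.0, 1.0, 100);

hdp_state state;
state.m_eta = 0.5;                     // topic Dirichlet parameter (placeholder value)
state.setup_state_from_corpus(c);      // initialise counts and doc states from the corpus
state.allocate_initial_space();

for (int iter = 0; iter < hyper.m_max_iter; ++iter)
    state.iterate_gibbs_state(true, true, &hyper, true);  // remove, permute, table sampling

char name[] = "hdp_state.dat";         // placeholder output file name
state.save_state_to_file(name);
#endif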

class hdp_state_dynamic
{
public:

/// corpus information (fixed values)
    int m_size_vocab;   // vocabulary size
    int m_total_words;  // total number of words in the corpus
    int m_num_docs;     // number of documents

/// document states
    doc_state** m_doc_states;

/// number of topics
    int m_num_topics;

/// total number of tables across all topics (during the batch iterations this is the total number of tables for the documents seen so far)
    int m_total_num_tables;


/// by_z, by topic
/// by_d, by document, for each topic
/// by_w, by word, for each topic
/// by_t, by table for each document
 
    int_vec   m_word_counts_by_z;      // word counts for each topic
    int_vec   m_num_tables_by_z;       // number of tables for each topic (during the batch iterations this is the number of tables for the documents seen so far)
    vector<int*> m_num_tables_by_zd;   // number of tables for [each topic, each doc]
    int_vec m_total_num_tables_by_d;   // total number of tables for each doc
    vector <int*> m_word_counts_by_zd; // word counts for [each topic, each doc]
    vector <int*> m_word_counts_by_zw; // word counts for [each topic, each word]

/// for online updates (local changes for the current document)
    int_vec m_word_counts_by_zd_online;              // word counts for the current document, for each topic
    int_vec m_num_tables_by_zd_online;               // number of tables for the current document, for each topic
    int_vec m_num_tables_by_zd_for_prev_doc_online;  // number of tables for the previous document, for each topic
    int m_total_num_tables_by_d_online;              // total number of tables for the current document
    int m_total_num_tables_by_d_for_prev_doc_online; // total number of tables for the previous document
    int m_num_topics_before_update;                  // number of topics from the batch
    
///

/// topic Dirichlet parameter
    double m_eta;

/// model parameters, including the concentration parameters
    double m_gamma;   // top-level (corpus-level) concentration parameter
    double m_alpha;   // document-level concentration parameter
    double m_delta;
public:
    hdp_state_dynamic();
    virtual ~hdp_state_dynamic();
public:
    void   set_vocab_size(int _vocab_size);
    void   setup_state_from_corpus(const corpus* c);
    void   setup_doc_info_from_document(const document * doc);
    void   allocate_initial_space();
    void   allocate_space_for_online_updates();

    void   free_state();
    void   free_doc_info();
    void   free_counts_update();
    void   free_online_info();
    void   free_current_doc_counts_update();

    void   iterate_gibbs_state(bool remove,
                               hdp_hyperparameter* hdp_hyperparam,
                               bool table_sampling=false);
    void   iterate_gibbs_state_online(bool remove, hdp_hyperparameter* hdp_hyperparam,
                                      bool table_sampling=false);

    void   reset_counts_up_to_date(int doc_id = 0);

    void   sample_tables(doc_state* d_state, bool not_init, double_vec & q, double_vec & f);
    void   sample_tables_online(doc_state* d_state, double_vec & q, double_vec & f);

    void   sample_table_assignment(doc_state* d_state, int t, bool not_init,
                                   int* words, double_vec & q, double_vec & f);
    void   sample_table_assignment_online(doc_state* d_state, int t, int* words, double_vec & q, double_vec & f);

    void   sample_word_assignment(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);
    void   sample_word_assignment_online(doc_state* d_state, int i, bool remove, double_vec & q, double_vec & f);

    int    sample_new_topic_for_new_table_online(const vector<double>& likelihood_per_topic_assignment) const;

    double compute_prior_probability_for_topic_sampling(int doc_id, int topic_id,
                                                        bool decrease_table_counts = false) const;
    double compute_prior_probability_for_topic_sampling_online(int topic_id,
                                                               bool decrease_table_counts = false) const;

    double compute_normalisation_constant_for_topic_prior_online() const;

    void   doc_state_update(doc_state* d_state, int i, int update, int k=-1);
    void   doc_state_update_online(doc_state* d_state, int i, int update, int k=-1);

    void   compact_doc_state(doc_state* d_state, int* k_to_new_k);
    void   compact_doc_state_online(doc_state* d_state, int* k_to_new_k);

    void   compact_hdp_state();
    void   compact_hdp_state_online();

    double compute_next_doc_table_assignment_likelihood(int doc_id, int old_topic_id = -1,
                                                        int new_topic_id = -1, bool new_table = false) const;
    double compute_next_table_assignment_likelihood(int doc_id, int old_topic_id = -1,
                                                    int new_topic_id = -1, bool new_table = false) const;
    double compute_next_table_assignment_likelihood_for_non_born_topics(int doc_id,
                                                                        const vector<int>& updating_num_tables_by_z,
                                                                        int old_topic_id = -1,
                                                                        int new_topic_id = -1,
                                                                        bool update_counts_for_current_doc = false) const;

    void   save_state_to_file(char * name);
    void   save_state_ex(char * name, int sampler_num = -1);
    void   load_state_from_file(char * name);
    void   load_state_ex(char * name, int sampler_num = -1);
    
    void save_online_update_for_next_iteration();
    void update_state_after_online_update();
    void save_last_document_info_for_online_update();
};
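
// Usage sketch: a possible per-document online update cycle for the dynamic state.
// The ordering of calls is an assumption based only on the method names declared
// above; the implementation file defines the actual protocol. `c`, `doc`, and
// `hyper` are assumed to exist as in the sketches above. Excluded from compilation.
#if 0
hdp_state_dynamic state;
state.setup_state_from_corpus(c);            // batch initialisation
state.allocate_space_for_online_updates();

// process one incoming document
state.setup_doc_info_from_document(doc);
state.iterate_gibbs_state_online(true, &hyper, true);
state.update_state_after_online_update();
state.save_last_document_info_for_online_update();  // keep counts needed by the dynamic prior
state.free_doc_info();
#endif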

#endif // STATE_H