// MACETrainer.h
#pragma once

#include "learning/NeuralNetTrainer.h"
#include "util/CircularBuffer.h"

// Compile-time feature switch. Nothing in this header tests it; presumably the
// trainer implementation checks it to enable a multi-sample actor update path
// -- TODO confirm against the .cpp before removing or renaming.
#define ENABLE_ACTOR_MULTI_SAMPLE_UPDATE

// Trainer derived from cNeuralNetTrainer that additionally keeps a separate
// actor problem, an actor iteration counter, and bookkeeping for "action
// fragments" (a fragment count and a fragment size).
//
// NOTE(review): only declarations are visible in this header. The layout of
// the parameter vectors handled by the static helpers, and the training
// schedule, are defined in the implementation file -- confirm there before
// relying on any of the hedged notes below.
class cMACETrainer : public cNeuralNetTrainer
{
public:
	// Flag identifiers. Enumerators take consecutive values starting at 0, so
	// eFlagMax equals the number of flags declared before it.
	// NOTE(review): presumably these index per-tuple flags stored in tExpTuple
	// -- confirm against the tuple definition.
	enum eFlag
	{
		eFlagFail,
		eFlagExpCritic,
		eFlagExpActor,
		eFlagMax
	};

	// Stateless helpers that read or write pieces of a flat parameter vector
	// described by a fragment count (num_frags) and a fragment size (frag_size).
	// a_idx selects which fragment / value is accessed; results are either
	// returned or written through the out_* reference.
	// NOTE(review): presumably `params` stores one scalar value per fragment
	// alongside the fragment data, with "max" meaning the highest such value
	// -- the exact layout is not visible here, verify in the .cpp.
	static int GetMaxFragIdx(const Eigen::VectorXd& params, int num_frags);
	static double GetMaxFragVal(const Eigen::VectorXd& params, int num_frags);
	static void GetFrag(const Eigen::VectorXd& params, int num_frags, int frag_size, int a_idx, Eigen::VectorXd& out_action);
	static void SetFrag(const Eigen::VectorXd& frag, int a_idx, int num_frags, int frag_size, Eigen::VectorXd& out_params);
	static double GetVal(const Eigen::VectorXd& params, int a_idx);
	static void SetVal(double val, int a_idx, Eigen::VectorXd& out_params);
	// Derives a fragment count from a total parameter size and a fragment size.
	static int CalcNumFrags(int param_size, int frag_size);

	// Helpers for the action vector recorded in experience tuples: access the
	// fragment index and the fragment parameters stored in `action_params`.
	// NOTE(review): the packing of index and fragment inside the action vector
	// is defined in the implementation -- confirm before constructing one by hand.
	static int GetActionFragIdx(const Eigen::VectorXd& action_params);
	static void SetActionFragIdx(int a_idx, Eigen::VectorXd& out_action_params);
	static void GetActionFrag(const Eigen::VectorXd& action_params, Eigen::VectorXd& out_frag_params);
	static void SetActionFrag(const Eigen::VectorXd& frag_params, Eigen::VectorXd& out_action_params);

	cMACETrainer();
	virtual ~cMACETrainer();
	
	// Trainer lifecycle and tuple intake.
	// NOTE(review): presumably these override cNeuralNetTrainer entry points
	// (the base class is not visible here); the meaning of AddTuple's int
	// return value is not shown -- check the base class.
	virtual void Init(const tParams& params);
	virtual void Reset();
	virtual int AddTuple(const tExpTuple& tuple);

	// Accessors for the fragment configuration (see mNumActionFrags and
	// mActionFragSize).
	virtual void SetNumActionFrags(int num);
	virtual void SetActionFragSize(int size);
	virtual int GetNumActionFrags() const;
	virtual int GetActionFragSize() const;

	// Returns a const reference to the owning pointer of the actor network;
	// ownership is not transferred to the caller.
	virtual const std::unique_ptr<cNeuralNet>& GetActor() const;

protected:
	// Fragment configuration exposed through the Set/Get accessors above.
	int mNumActionFrags;
	int mActionFragSize;
	// Problem instance used for actor updates (see InitActorProblem and
	// BuildActorProblem).
	cNeuralNet::tProblem mActorProb;

	// Actor iteration counter (see IncActorIter).
	int mActorIter;
	// Integer buffers. NOTE(review): presumably tuple ids -- mActorBatchBuffer
	// for the current actor minibatch, mCriticBuffer / mActorBuffer for tuples
	// classified by IsExpCritic / IsExpActor -- confirm in UpdateBuffers.
	std::vector<int> mActorBatchBuffer;
	std::vector<int> mCriticBuffer;
	std::vector<int> mActorBuffer;
	
	// Scratch storage for batched computation (see InitBatchBuffers).
	// NOTE(review): dimensions and row/column meaning are set in the .cpp.
	Eigen::MatrixXd mBatchXBuffer;
	Eigen::MatrixXd mBatchYBuffer;
	Eigen::VectorXd mBatchValBuffer0;
	Eigen::VectorXd mBatchValBuffer1;

	// Setup hooks for the batch buffers and the actor problem.
	virtual void InitBatchBuffers();
	virtual void InitActorProblem(cNeuralNet::tProblem& out_prob) const;

	// Minibatch selection and problem construction. The *Actor* variants build
	// the inputs (X) and targets (Y) of the actor problem from the tuples
	// identified by tuple_ids; results are written to the out_* arguments.
	virtual void FetchMinibatch(int size, std::vector<int>& out_batch);
	virtual void FetchActorMinibatch(int size, std::vector<int>& out_batch);
	virtual void BuildNetPool(const std::string& net_file, const std::string& solver_file, int pool_size);
	virtual void BuildProblemY(int net_id, const std::vector<int>& tuple_ids, const Eigen::MatrixXd& X, cNeuralNet::tProblem& out_prob);
	virtual void BuildTupleActorY(const tExpTuple& tuple, Eigen::VectorXd& out_y);
	virtual void BuildActorProblemX(int num_data, const std::vector<int>& tuple_ids, cNeuralNet::tProblem& out_prob);
	virtual void BuildActorProblemY(int num_data, const std::vector<int>& tuple_ids, const Eigen::MatrixXd& X, cNeuralNet::tProblem& out_prob);
	virtual int GetActorBatchSize() const;

	// Training step hooks.
	// NOTE(review): the meaning of Step's bool result is not visible here.
	virtual bool Step();
	virtual void BuildTupleY(int net_id, const tExpTuple& tuple, Eigen::VectorXd& out_y);
	virtual void ApplySteps(int num_steps);

	// Network pool / replay buffer sizing queries.
	virtual int GetPoolSize() const;
	virtual int GetTargetNetID(int net_id) const;
	virtual int CalcBufferSize() const;
	
	// Index / size queries for the reward, state and action fields.
	// NOTE(review): presumably offsets into the flattened tuple storage --
	// confirm against SetTuple / GetTuple.
	virtual int GetRewardIdx() const;
	virtual int GetStateBegIdx() const;
	virtual int GetStateEndIdx() const;
	virtual int GetActionIdx() const;
	virtual int GetActionSize() const;

	// Cumulative-reward estimates for a single tuple or a batch of tuple ids,
	// evaluated with the network identified by net_id. Batch variants write one
	// result vector to out_vals.
	// NOTE(review): "Curr" vs "New" presumably distinguishes the value at the
	// tuple's start state from the bootstrapped value after the transition --
	// verify in the implementation.
	virtual double CalcCurrCumulativeReward(int net_id, const tExpTuple& tuple);
	virtual double CalcNewCumulativeReward(int net_id, const tExpTuple& tuple);
	virtual void CalcCurrCumulativeRewardBatch(int net_id, const std::vector<int>& tuple_ids, Eigen::VectorXd& out_vals);
	virtual void CalcNewCumulativeRewardBatch(int net_id, const std::vector<int>& tuple_ids, Eigen::VectorXd& out_vals);
	
	// Store / load the tuple at slot t. GetTuple returns the tuple by value.
	virtual void SetTuple(int t, const tExpTuple& tuple);
	virtual tExpTuple GetTuple(int t) const;

	// Actor update pipeline.
	// NOTE(review): call order between these hooks is defined in the .cpp.
	virtual void UpdateActorBatchBuffer();
	virtual void UpdateActor();
	virtual void StepActor();
	virtual void BuildActorProblem(cNeuralNet::tProblem& out_prob);
	virtual void UpdateActorNet(const cNeuralNet::tProblem& prob);
	virtual void IncActorIter();

	// Per-tuple buffer maintenance and classification for slot t.
	// NOTE(review): presumably driven by the eFlagExpCritic / eFlagExpActor
	// flags -- confirm.
	virtual void UpdateBuffers(int t);
	virtual bool IsExpCritic(int t) const;
	virtual bool IsExpActor(int t) const;

	// Target network control.
	virtual bool EnableTargetNet() const;
	virtual void UpdateTargetNet();

	// Instance-level counterparts of the static fragment helpers, used for
	// handling the (possibly separate) actor and critic outputs from the network.
	// They take no num_frags / frag_size arguments.
	// NOTE(review): presumably they forward to the static helpers using
	// mNumActionFrags / mActionFragSize -- confirm.
	virtual int GetMaxFragIdxAux(const Eigen::VectorXd& params);
	virtual double GetMaxFragValAux(const Eigen::VectorXd& params);
	virtual void GetFragAux(const Eigen::VectorXd& params, int a_idx, Eigen::VectorXd& out_action);
	virtual void SetFragAux(const Eigen::VectorXd& frag, int a_idx, Eigen::VectorXd& out_params);
	virtual double GetValAux(const Eigen::VectorXd& params, int a_idx);
	virtual void SetValAux(double val, int a_idx, Eigen::VectorXd& out_params);
};
// end of MACETrainer.h