// QNetTrainer.h
#pragma once

#include "learning/NeuralNetTrainer.h"

// Trainer class derived from cNeuralNetTrainer. Only declarations are visible
// in this header; the implementations live elsewhere.
// NOTE(review): the name suggests a Q-network (Q-learning) trainer -- confirm
// against the implementation file.
class cQNetTrainer : public cNeuralNetTrainer
{
public:
	// Flag indices used by this trainer.
	// NOTE(review): eFlagMax looks like a count/sentinel entry rather than a
	// real flag -- confirm against usage.
	enum eFlag
	{
		eFlagFail,
		eFlagMax
	};

	cQNetTrainer();
	virtual ~cQNetTrainer();

	// Initializes the trainer from the given parameters.
	// tParams is not declared in this header (presumably inherited from
	// cNeuralNetTrainer -- verify).
	virtual void Init(const tParams& params);
	// Adds an experience tuple and returns an int.
	// NOTE(review): the meaning of the returned int (e.g. a tuple id/slot) is
	// not visible here -- confirm in the implementation.
	virtual int AddTuple(const tExpTuple& tuple);

protected:
	// Matrix buffers owned by the trainer.
	// NOTE(review): by name these appear to hold batch inputs (X) and two sets
	// of batch targets (Y0/Y1); shapes and usage are not visible here -- confirm.
	Eigen::MatrixXd mBatchXBuffer;
	Eigen::MatrixXd mBatchYBuffer0;
	Eigen::MatrixXd mBatchYBuffer1;

	// Presumably allocates/sizes the batch buffers above -- verify.
	virtual void InitBatchBuffers();

	// Performs one training step; returns a bool.
	// NOTE(review): the meaning of the return value is not visible here.
	virtual bool Step();
	// Fills out_prob for network net_id given tuple ids and the input matrix X.
	// NOTE(review): presumably writes the target (Y) portion of the problem,
	// going by the name -- confirm.
	virtual void BuildProblemY(int net_id, const std::vector<int>& tuple_ids, const Eigen::MatrixXd& X, cNeuralNet::tProblem& out_prob);
	// Writes the output vector out_y for a single tuple and network net_id.
	virtual void BuildTupleY(int net_id, const tExpTuple& tuple, Eigen::VectorXd& out_y);

	// Active/reference network id selection.
	// NOTE(review): the selection policy is defined in the implementation and
	// cannot be determined from this header.
	virtual void UpdateCurrActiveNetID();
	virtual int GetNextActiveID() const;
	virtual int GetRandRefID(int id) const;

	// Returns a buffer size as an int.
	// NOTE(review): units (rows, entries, bytes) are not visible here -- confirm.
	virtual int CalcBufferSize() const;

	// Index accessors.
	// NOTE(review): these look like offsets of the reward, state range and
	// action within a packed tuple record; whether the end index is inclusive
	// is not visible here -- confirm.
	virtual int GetRewardIdx() const;
	virtual int GetStateBegIdx() const;
	virtual int GetStateEndIdx() const;
	virtual int GetActionIdx() const;

	// Stores / retrieves the experience tuple at index t.
	// GetTuple returns the tuple by value.
	virtual void SetTuple(int t, const tExpTuple& tuple);
	virtual tExpTuple GetTuple(int t) const;
};
// End of QNetTrainer.h