https://github.com/xbpeng/DeepTerrainRL
Raw File
Tip revision: ed82e2ebe5f14fa875cc3d0a2180c64980408e8f authored by Glen on 19 October 2016, 17:49:36 UTC
Update README.md
Tip revision: ed82e2e
NeuralNetTrainer.h
#pragma once
#include <cstdio>
#include <ctime>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "learning/TrainerInterface.h"
#include "learning/ExpTuple.h"
#include "learning/NeuralNet.h"
#include "learning/NeuralNetLearner.h"
#include "learning/ParamServer.h"

// Trainer that owns a pool of cNeuralNet instances (mNetPool) and trains them
// from experience tuples (tExpTuple) supplied through AddTuple/AddTuples.
//
// Ownership and lifetime, as expressed by the member types below:
//  - Networks are owned by mNetPool via std::unique_ptr. GetNet()/GetCurrNet()/
//    GetTargetNet() return references to elements of that std::vector, so such
//    references must not be held across any operation that resizes or rebuilds
//    the pool (presumably BuildNetPool -- confirm in the .cpp).
//  - Learners are kept as raw cNeuralNetLearner* in mLearners, i.e. non-owning by
//    type; NOTE(review): presumably each learner calls UnregisterLearner before it
//    is destroyed -- verify against cNeuralNetLearner.
//  - mParamServer is a raw pointer installed by SetParamServer; the type does not
//    transfer ownership to the trainer.
//  - The class derives from std::enable_shared_from_this. If the implementation
//    calls shared_from_this() (e.g. possibly from RequestLearner -- confirm), the
//    trainer must already be owned by a std::shared_ptr at that point.
//
// Thread-safety: mLock (std::mutex) is exposed through Lock()/Unlock(); callers
// must pair every Lock() with an Unlock(). Which members it guards is defined by
// the .cpp, not by this header.
class cNeuralNetTrainer : public cTrainerInterface, 
						public std::enable_shared_from_this<cNeuralNetTrainer>
{
public:
	// Training stage; eStageMax is the count sentinel, not a real stage.
	enum eStage
	{
		eStageInit,
		eStageTrain,
		eStageMax
	};

	// Static helper derived from a discount value (presumably a normalization
	// factor for discounted returns -- see the .cpp for the exact formula).
	static double CalcDiscountNorm(double discount);
	
	cNeuralNetTrainer();
	virtual ~cNeuralNetTrainer();

	// Lifecycle entry points. Exact semantics live in the .cpp.
	virtual void Init(const tParams& params);
	virtual void LoadModel(const std::string& model_file);
	virtual void LoadScale(const std::string& scale_file);
	virtual void Reset();
	virtual void EndTraining();

	// Experience input and the training driver. The meaning of AddTuple's int
	// return value is defined by the implementation -- TODO confirm.
	virtual int AddTuple(const tExpTuple& tuple);
	virtual void AddTuples(const std::vector<tExpTuple>& tuples);
	virtual void Train();

	// Returns a reference into mNetPool (see lifetime note on the class).
	virtual const std::unique_ptr<cNeuralNet>& GetNet() const;
	virtual double GetDiscount() const;
	virtual double GetAvgReward() const;
	virtual int GetIter() const;
	virtual double NormalizeReward(double r) const;

	virtual void SetNumInitSamples(int num);
	virtual void SetInputOffsetScale(const Eigen::VectorXd& offset, const Eigen::VectorXd& scale);
	virtual void SetOutputOffsetScale(const Eigen::VectorXd& offset, const Eigen::VectorXd& scale);

	virtual int GetNumInitSamples() const;
	virtual const std::string& GetNetFile() const;
	virtual const std::string& GetSolverFile() const;

	virtual eStage GetStage() const;
	virtual int GetStateSize() const;
	virtual int GetActionSize() const;
	virtual int GetInputSize() const;
	virtual int GetOutputSize() const;
	virtual int GetBatchSize() const;

	virtual int GetNumTuples() const;
	virtual void OutputModel(const std::string& filename) const;

	virtual bool HasInitModel() const;
	// Writes the network output for the given tuple into out_y (presumably using
	// the currently active net -- confirm in the .cpp).
	virtual void EvalNet(const tExpTuple& tuple, Eigen::VectorXd& out_y);

	// Learner registration. Registered learners are stored as non-owning raw
	// pointers in mLearners.
	virtual void RequestLearner(std::shared_ptr<cNeuralNetLearner>& out_learner);
	virtual int RegisterLearner(cNeuralNetLearner* learner);
	virtual void UnregisterLearner(cNeuralNetLearner* learner);

	// Manual locking of mLock; every Lock() must be matched by an Unlock().
	virtual bool EnableAsyncMode() const;
	virtual void Lock();
	virtual void Unlock();

	// Installs a non-owning parameter-server pointer (mParamServer).
	virtual void SetParamServer(cParamServer* server);
	virtual void SyncNets();

	virtual bool IsDone() const;

protected:
	// NOTE: declaration order is kept as-is on purpose; it fixes data layout and
	// member initialization order, which the constructor in the .cpp relies on.
	// In-class initializers give every scalar/pointer a defined value even if a
	// constructor path omits it; a constructor's mem-initializer list still
	// takes precedence over these defaults.
	eStage mStage = eStageInit;
	tParams mParams;
	int mIter = 0;
	bool mDone = false;

	// Experience storage. Presumably mBufferHead is the next write slot in
	// mPlaybackMem and mNumTuples/mTotalTuples count stored vs. ever-added
	// tuples -- TODO confirm in the .cpp.
	int mBufferHead = 0;
	int mNumTuples = 0;
	int mTotalTuples = 0;
	Eigen::MatrixXf mPlaybackMem;
	std::vector<unsigned int> mFlagBuffer;

	// Network pool (owned) and per-step working data.
	cNeuralNet::tProblem mProb;
	std::vector<std::unique_ptr<cNeuralNet>> mNetPool;
	int mCurrActiveNet = 0;
	std::vector<int> mBatchBuffer;
	double mAvgReward = 0.0;

	std::mutex mLock;
	// Non-owning pointers to registered learners.
	std::vector<cNeuralNetLearner*> mLearners;

	// Non-owning; null until SetParamServer is called.
	cParamServer* mParamServer = nullptr;

	const std::unique_ptr<cNeuralNet>& GetCurrNet() const;

	virtual void InitPlaybackMem(int size);
	virtual void InitBatchBuffer();
	virtual void InitProblem(cNeuralNet::tProblem& out_prob) const;
	virtual int GetPlaybackMemSize() const;
	virtual void ResetParams();
	
	// Training step pipeline: build a problem (X then Y) for a net, update it.
	virtual void Pretrain();
	virtual bool Step();
	virtual bool BuildProblem(int net_id, cNeuralNet::tProblem& out_prob);
	virtual void BuildProblemX(int net_id, const std::vector<int>& tuple_ids, cNeuralNet::tProblem& out_prob);
	virtual void BuildProblemY(int net_id, const std::vector<int>& tuple_ids, const Eigen::MatrixXd& X, cNeuralNet::tProblem& out_prob);
	virtual void UpdateMisc(const std::vector<int>& tuple_ids);

	virtual void BuildTupleX(const tExpTuple& tuple, Eigen::VectorXd& out_x);
	virtual void BuildTupleY(int net_id, const tExpTuple& tuple, Eigen::VectorXd& out_y);
	virtual void FetchMinibatch(int size, std::vector<int>& out_batch);

	virtual int GetTargetNetID(int net_id) const;
	virtual void UpdateCurrActiveNetID();
	// Returns a reference into mNetPool (see lifetime note on the class).
	virtual const std::unique_ptr<cNeuralNet>& GetTargetNet(int net_id) const;
	virtual bool CheckTuple(const tExpTuple& tuple) const;
	virtual void UpdateNet(int net_id, const cNeuralNet::tProblem& prob);

	virtual int CalcBufferSize() const;
	virtual int GetStateBegIdx() const;
	virtual int GetActionIdx() const;

	virtual void SetTuple(int t, const tExpTuple& tuple);
	virtual tExpTuple GetTuple(int t) const;

	virtual void UpdateOffsetScale();
	virtual void UpdateStage();
	virtual void InitStage();
	virtual void ApplySteps(int num_steps);
	virtual void IncIter();

	// NOTE(review): both GetNetPoolSize() and GetPoolSize() are declared; how they
	// differ is not visible from this header -- confirm in the .cpp.
	virtual int GetNetPoolSize() const;
	virtual void BuildNetPool(const std::string& net_file, const std::string& solver_file, int pool_size);
	virtual int GetPoolSize() const;

	virtual bool EnableIntOutput() const;
	virtual void OutputIntermediate();
	virtual void OutputIntermediateModel(const std::string& filename) const;

	virtual int GetNumLearners() const;
	virtual void ResetLearners();
	virtual void ResetSolvers();

	virtual void UpdateParamServerInputOffsetScale(const Eigen::VectorXd& offset, const Eigen::VectorXd& scale);
	virtual void SyncNet(int net_id);

#if defined(OUTPUT_TRAINER_LOG)
public:
	// Timing/sample counters, compiled in only when OUTPUT_TRAINER_LOG is defined.
	// Each field defaults to zero; the tLog() constructor may still set them.
	struct tLog
	{
		int mBuildTupleXSamples = 0;
		double mBuildTupleXTime = 0.0;
		int mBuildTupleYSamples = 0;
		double mBuildTupleYTime = 0.0;
		int mUpdateNetSamples = 0;
		double mUpdateNetTime = 0.0;
		int mStepSamples = 0;
		double mStepTime = 0.0;

		int mBuildActorTupleXSamples = 0;
		double mBuildActorTupleXTime = 0.0;
		int mBuildActorTupleYSamples = 0;
		double mBuildActorTupleYTime = 0.0;
		int mUpdateActorNetSamples = 0;
		double mUpdateActorNetTime = 0.0;
		int mStepActorSamples = 0;
		double mStepActorTime = 0.0;

		int mAsyncForwardBackSamples = 0;
		double mAsyncForwardBackTime = 0.0;
		int mAsyncUpdateNetSamples = 0;
		double mAsyncUpdateNetTime = 0.0;

		int mLockWaitSamples = 0;
		double mLockWaitTime = 0.0;

		double mTotalExpTime = 0.0;
		double mTotalTime = 0.0;

		int mIters = 0;
		double mAvgIterTime = 0.0;

		tLog();
		void Write(FILE* f) const;
	};

	const tLog& GetLog() const;

protected:
	std::mutex mLogLock;
	tLog mLog;
	std::clock_t mStartTime = 0;

	void InitLog();
	void EndLog();
	void WriteLog(const std::string& log_file) const;
#endif // OUTPUT_TRAINER_LOG
};
back to top