// QNetTrainer.cpp
// From https://github.com/xbpeng/DeepTerrainRL, revision ed82e2e (19 October 2016)
#include "QNetTrainer.h"
#include "util/FileUtil.h"

//#define FREEZE_TARGET_NET
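// When FREEZE_TARGET_NET is defined, Step() updates only the current active
// net, and GetNextActiveID() advances the active net only every
// mParams.mFreezeTargetIters iterations, mimicking the frozen target network
// of the standard DQN setup.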

cQNetTrainer::cQNetTrainer()
{
}

cQNetTrainer::~cQNetTrainer()
{
}

void cQNetTrainer::Init(const tParams& params)
{
	cNeuralNetTrainer::Init(params);
	InitBatchBuffers();
}

int cQNetTrainer::AddTuple(const tExpTuple& tuple)
{
	int state_size = GetStateSize();
	assert(tuple.mStateEnd.size() == state_size);
	return cNeuralNetTrainer::AddTuple(tuple);
}

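// Builds the regression targets Y for a minibatch of experience tuples.
// The targets follow the Double Q-learning rule: the current net selects the
// greedy action at the next state, while a randomly chosen reference net from
// the pool scores it (see the per-tuple loop below).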
void cQNetTrainer::BuildProblemY(int net_id, const std::vector<int>& tuple_ids, 
								const Eigen::MatrixXd& X, cNeuralNet::tProblem& out_prob)
{
	int num_data = static_cast<int>(tuple_ids.size());
	int input_size = GetInputSize();
	int output_size = GetOutputSize();
	assert(num_data == GetBatchSize());
	assert(X.cols() == input_size);
	assert(out_prob.mY.rows() == num_data);
	assert(out_prob.mY.cols() == output_size);

	int ref_id = GetRandRefID(net_id);
	const auto& ref_net = mNetPool[ref_id];
	const auto& curr_net = mNetPool[net_id];

	for (int i = 0; i < num_data; ++i)
	{
		int t = tuple_ids[i];
		tExpTuple tuple = GetTuple(t);
		mBatchXBuffer.row(i) = tuple.mStateEnd;
	}

	ref_net->EvalBatch(mBatchXBuffer, mBatchYBuffer0);
	curr_net->EvalBatch(mBatchXBuffer, mBatchYBuffer1);
	curr_net->EvalBatch(X, out_prob.mY);

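	// Double Q-learning target computed by the loop below:
	//   a* = argmax_a Q_curr(s', a)              (from mBatchYBuffer1)
	//   y  = r_norm + discount * Q_ref(s', a*)   (from mBatchYBuffer0)
	// with y = r_norm for failure tuples, where r_norm is the reward scaled
	// by CalcDiscountNorm(discount). Only the taken action's entry in Y is
	// overwritten; the other entries keep the current net's predictions.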
	for (int i = 0; i < num_data; ++i)
	{
		int t = tuple_ids[i];
		tExpTuple tuple = GetTuple(t);
		
		double new_q = 0;
		double r = tuple.mReward;

		double discount = GetDiscount();
		double norm = CalcDiscountNorm(discount);
		r *= norm;

		int action_idx = 0;
		tuple.mAction.maxCoeff(&action_idx);

		bool fail = tuple.GetFlag(eFlagFail);
		if (fail)
		{
			new_q = r;
		}
		else
		{
			auto y_next_ref = mBatchYBuffer0.row(i);
			auto y_next = mBatchYBuffer1.row(i);

			int next_action_idx = 0;
			y_next.maxCoeff(&next_action_idx);

			double q_end = y_next_ref[next_action_idx];
			new_q = r + discount * q_end;
		}
		out_prob.mY(i, action_idx) = new_q;
	}
}

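// Single-tuple analogue of BuildProblemY: the current net's prediction for
// the start state seeds out_y, and only the taken action's entry is replaced
// with the Double Q-learning target.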
void cQNetTrainer::BuildTupleY(int net_id, const tExpTuple& tuple, Eigen::VectorXd& out_y)
{
	double new_q = 0;
	double r = tuple.mReward;

	double discount = GetDiscount();
	double norm = CalcDiscountNorm(discount);
	r *= norm;
	
	int action_idx = 0;
	tuple.mAction.maxCoeff(&action_idx);

	const auto& curr_net = mNetPool[net_id];
	curr_net->Eval(tuple.mStateBeg, out_y);

	bool fail = tuple.GetFlag(eFlagFail);
	if (fail)
	{
		new_q = r;
	}
	else
	{
		Eigen::VectorXd y_next;
		curr_net->Eval(tuple.mStateEnd, y_next);

		int next_action_idx = 0;
		y_next.maxCoeff(&next_action_idx);

		int ref_id = GetRandRefID(net_id);
		const auto& ref_net = mNetPool[ref_id];
		ref_net->Eval(tuple.mStateEnd, y_next);
		double q_end = y_next[next_action_idx];

		new_q = r + discount * q_end;
	}

#if defined (ENABLE_DEBUG_PRINT)
	double old_q = out_y[action_idx];
	printf("Update action %i: old: %.5f, new: %.5f\n", action_idx, old_q, new_q);
#endif

	out_y[action_idx] = new_q;
}

void cQNetTrainer::InitBatchBuffers()
{
	int batch_size = GetBatchSize();
	if (batch_size > 0)
	{
		int input_size = GetInputSize();
		int output_size = GetOutputSize();
		mBatchXBuffer.resize(batch_size, input_size);
		mBatchYBuffer0.resize(batch_size, output_size);
		mBatchYBuffer1.resize(batch_size, output_size);
	}
}

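// Runs one training iteration: builds a problem and updates each net in the
// pool, or only the current active net when FREEZE_TARGET_NET is defined.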
bool cQNetTrainer::Step()
{
	int i = 0;
	int max_idx = GetPoolSize();

#if defined FREEZE_TARGET_NET
	i = mCurrActiveNet;
	max_idx = i + 1;
#endif // FREEZE_TARGET_NET

	for (; i < max_idx; ++i)
	{
		printf("Update Net %i:\n", i);
		bool succ = BuildProblem(i, mProb);
		if (succ)
		{
			UpdateNet(i, mProb);
		}
	}

	return true;
}

void cQNetTrainer::UpdateCurrActiveNetID()
{
	mCurrActiveNet = GetNextActiveID();
}

int cQNetTrainer::GetNextActiveID() const
{
	int next_idx = (mCurrActiveNet + 1) % GetPoolSize();
#if defined FREEZE_TARGET_NET
	int iters = GetIter();
	if ((iters == 0) || (iters % mParams.mFreezeTargetIters != 0))
	{
		next_idx = mCurrActiveNet;
	}
#endif // FREEZE_TARGET_NET
	return next_idx;
}

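// Returns a random reference (target) net id different from the given id, so
// that the net selecting the greedy action is not the one scoring it. With a
// pool of size 1, the net serves as its own reference.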
int cQNetTrainer::GetRandRefID(int id) const
{
	int rand_id = 0;
	int pool_size = GetPoolSize();
	if (pool_size > 1)
	{
		rand_id = cMathUtil::RandIntExclude(0, pool_size, id);
	}
	return rand_id;
}

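// Replay-memory row layout addressed by the index getters below:
//   [ reward | state_beg | action | state_end ]
// so a row holds 1 + 2 * state_size + action_size floats.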
int cQNetTrainer::CalcBufferSize() const
{
	return 1 + GetStateSize() * 2 + GetActionSize();
}

int cQNetTrainer::GetRewardIdx() const
{
	return 0;
}

int cQNetTrainer::GetStateBegIdx() const
{
	return 1;
}

int cQNetTrainer::GetStateEndIdx() const
{
	return 1 + GetStateSize() + GetActionSize();
}

int cQNetTrainer::GetActionIdx() const
{
	return 1 + GetStateSize();
}

void cQNetTrainer::SetTuple(int t, const tExpTuple& tuple)
{
	auto curr_row = mPlaybackMem.row(t);

	curr_row(GetRewardIdx()) = static_cast<float>(tuple.mReward);

	int state_size = GetStateSize();
	int state_beg_idx = GetStateBegIdx();
	int state_end_idx = GetStateEndIdx();
	int action_idx = GetActionIdx();
	int action_size = GetActionSize();

	for (int j = 0; j < state_size; ++j)
	{
		curr_row(state_beg_idx + j) = static_cast<float>(tuple.mStateBeg(j));
		curr_row(state_end_idx + j) = static_cast<float>(tuple.mStateEnd(j));
	}

	for (int j = 0; j < action_size; ++j)
	{
		curr_row(action_idx + j) = static_cast<float>(tuple.mAction(j));
	}

	mFlagBuffer[t] = tuple.mFlags;
}

tExpTuple cQNetTrainer::GetTuple(int t) const
{
	tExpTuple tuple;
	auto curr_row = mPlaybackMem.row(t);

	tuple.mID = t;
	tuple.mReward = curr_row[GetRewardIdx()];

	int state_size = GetStateSize();
	int state_beg_idx = GetStateBegIdx();
	int state_end_idx = GetStateEndIdx();
	int action_idx = GetActionIdx();
	int action_size = GetActionSize();

	tuple.mStateBeg.resize(state_size);
	tuple.mStateEnd.resize(state_size);
	tuple.mAction.resize(action_size);

	for (int j = 0; j < state_size; ++j)
	{
		tuple.mStateBeg(j) = curr_row(state_beg_idx + j);
		tuple.mStateEnd(j) = curr_row(state_end_idx + j);
	}

	for (int j = 0; j < action_size; ++j)
	{
		tuple.mAction(j) = curr_row(action_idx + j);
	}

	tuple.mFlags = mFlagBuffer[t];

	return tuple;
}
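
// A minimal usage sketch (hypothetical driver; the tParams fields and the
// CollectExperience() helper are assumptions for illustration, not part of
// this file):
//
//   cQNetTrainer trainer;
//   cQNetTrainer::tParams params;     // e.g. net definition, batch size
//   trainer.Init(params);
//   while (training)
//   {
//       tExpTuple tuple = CollectExperience();  // hypothetical helper
//       trainer.AddTuple(tuple);
//       trainer.Step();                         // update nets in the pool
//   }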