// ScenarioExp.cpp
#include "ScenarioExp.h"

#include <memory>
#include <ctime>
#include "sim/SimDog.h"
#include "sim/DogController.h"
#include "sim/DogControllerCacla.h"
#include "sim/DogControllerMACE.h"
#include "sim/GroundFlat.h"
#include "sim/GroundVar2D.h"

// Warmup threshold used by IsValidTuple(): a tuple is kept only once the
// per-episode cycle counter (mCycleCount) is strictly greater than this value.
constexpr int gNumWarmupCycles = 1;

/// Constructs the exploration scenario with default settings. ParseArgs() may
/// override the buffer size and exploration parameters before Init() runs.
cScenarioExp::cScenarioExp()
{
	mTupleBufferSize = 16;
	ResetTupleBuffer();
	ResetParams();

	mEnableExplore = true;
	mExpRate = 0.1;
	// This constructor used to leave mExpTemp unset. Init() still forwards it
	// to the controller via SetExpTemp(mExpTemp) when no "exp_temp" argument
	// was parsed, so give it a defined value instead of reading an
	// uninitialized double.
	// NOTE(review): 1 is assumed to be the neutral exploration temperature for
	// the controller. Confirm against the controller's SetExpTemp semantics.
	mExpTemp = 1;
	mExpBaseActionRate = 0.01;
}

// No explicit cleanup is required here. Member and base-class destructors run
// automatically.
cScenarioExp::~cScenarioExp() = default;

/// Reads scenario settings from the argument parser. Base-scenario arguments
/// are consumed first. Then come the network/model file paths and the
/// buffer/exploration settings. A key that is absent leaves its member unchanged
/// (assumed cArgParser behaviour; confirm).
void cScenarioExp::ParseArgs(const cArgParser& parser)
{
	cScenarioSimChar::ParseArgs(parser);

	// policy and critic network definitions plus optional trained weights
	parser.ParseString("policy_net", mPoliNetFile);
	parser.ParseString("policy_model", mPoliModelFile);
	parser.ParseString("critic_net", mCriticNetFile);
	parser.ParseString("critic_model", mCriticModelFile);

	// experience-buffer capacity
	parser.ParseInt("tuple_buffer_size", mTupleBufferSize);

	// exploration knobs later pushed to the controller in Init()
	parser.ParseDouble("exp_rate", mExpRate);
	parser.ParseDouble("exp_temp", mExpTemp);
	parser.ParseDouble("exp_base_rate", mExpBaseActionRate);
}

/// Initializes the base scenario, allocates the tuple buffer, forwards the
/// exploration settings to the character's controller and optionally kicks
/// things off with a random action.
void cScenarioExp::Init()
{
	cScenarioSimChar::Init();

	// fresh cycle counter and a buffer sized from mTupleBufferSize
	ResetParams();
	mTupleBuffer.resize(mTupleBufferSize);
	ResetTupleBuffer();

	// hand the exploration configuration to the controller
	EnableExplore(true);
	SetExpRate(mExpRate);
	SetExpTemp(mExpTemp);
	SetExpBaseActionRate(mExpBaseActionRate);

	// a random first action diversifies the initial states
	const bool start_random = EnableRandInitAction();
	if (start_random)
	{
		CommandRandAction();
	}
}

/// Resets the base scenario and the per-episode cycle counter. If enabled, it
/// then commands a random action so that episodes begin from varied states.
void cScenarioExp::Reset()
{
	cScenarioSimChar::Reset();
	ResetParams();

	if (!EnableRandInitAction())
	{
		return;
	}

	// start off with a random action to get more diverse initial states
	CommandRandAction();
}

/// Clears the base scenario and drops all stored tuples. The buffer is left
/// empty, so Init() must run again before any tuples are recorded.
void cScenarioExp::Clear()
{
	cScenarioSimChar::Clear();
	ResetParams();

	// release the stored tuples and rewind the write counter
	mTupleBuffer.clear();
	ResetTupleBuffer();
}

/// Advances the simulation. If the character falls while the current cycle is
/// still in progress, the pending tuple is closed out and the episode restarts.
void cScenarioExp::Update(double time_elapsed)
{
	cScenarioSimChar::Update(time_elapsed);

	// Cycle boundaries are handled by PostSubstepUpdate(). Here we only react
	// to a fall between boundaries. Short-circuiting keeps the original call
	// order: IsNewCycle() first, then HasFallen().
	const bool clock_advanced = time_elapsed > 0;
	if (clock_advanced && !IsNewCycle() && HasFallen())
	{
		NewCycleUpdate();
		Reset();
	}
}

void cScenarioExp::SetBufferSize(int size)
{
	mTupleBufferSize = size;
}

/// True once at least mTupleBufferSize tuples have been recorded since the
/// last ResetTupleBuffer(). The counter keeps growing after the ring wraps.
bool cScenarioExp::IsTupleBufferFull() const
{
	const bool has_room = mTupleCount < mTupleBufferSize;
	return !has_room;
}

/// Read-only access to the tuple ring buffer. Slots are overwritten in order by
/// RecordTuple(). Slots not yet written since Init() hold whatever value they
/// were resized with.
const std::vector<tExpTuple>& cScenarioExp::GetTuples() const
{
	const std::vector<tExpTuple>& tuples = mTupleBuffer;
	return tuples;
}

/// Human-readable scenario name.
std::string cScenarioExp::GetName() const
{
	static const char* const kScenarioName = "Exploration";
	return std::string(kScenarioName);
}

bool cScenarioExp::BuildController(std::shared_ptr<cCharController>& out_ctrl)
{
	bool succ = cScenarioSimChar::BuildController(out_ctrl);

	if (mPoliNetFile != "")
	{
		std::shared_ptr<cNNController> nn_ctrl = std::static_pointer_cast<cNNController>(out_ctrl);
		succ &= nn_ctrl->LoadNet(mPoliNetFile);

		if (succ && mPoliModelFile != "")
		{
			nn_ctrl->LoadModel(mPoliModelFile);
		}
	}

	return succ;
}


bool cScenarioExp::BuildDogControllerCacla(std::shared_ptr<cCharController>& out_ctrl) const
{
	bool succ = cScenarioSimChar::BuildDogControllerCacla(out_ctrl);
	std::shared_ptr<cDogControllerCacla> dog_ctrl = std::dynamic_pointer_cast<cDogControllerCacla>(out_ctrl);

	if (mCriticNetFile != "")
	{
		bool critic_succ = dog_ctrl->LoadCriticNet(mCriticNetFile);
		if (critic_succ && mCriticModelFile != "")
		{
			dog_ctrl->LoadCriticModel(mCriticModelFile);
		}
	}

	return succ;
}

/// Restarts the per-episode cycle counter that IsValidTuple() compares against
/// gNumWarmupCycles.
void cScenarioExp::ResetParams()
{
	mCycleCount = 0;
}

/// Rewinds the ring-buffer write counter. Stored tuples are not cleared; they
/// are simply overwritten by later RecordTuple() calls.
void cScenarioExp::ResetTupleBuffer()
{
	mTupleCount = 0;
}

/// Records the exploration toggle and forwards it to the character's controller.
void cScenarioExp::EnableExplore(bool enable)
{
	mEnableExplore = enable;
	mChar->GetController()->EnableExp(mEnableExplore);
}

/// Stores the exploration rate and forwards it to the character's controller.
void cScenarioExp::SetExpRate(double rate)
{
	mExpRate = rate;
	mChar->GetController()->SetExpRate(rate);
}

/// Stores the exploration temperature and forwards it to the character's controller.
void cScenarioExp::SetExpTemp(double temp)
{
	mExpTemp = temp;
	mChar->GetController()->SetExpTemp(temp);
}

/// Stores the base-action exploration rate and forwards it to the character's controller.
void cScenarioExp::SetExpBaseActionRate(double rate)
{
	mExpBaseActionRate = rate;
	mChar->GetController()->SetExpBaseActionRate(rate);
}

/// Exploration rate last stored by the constructor, ParseArgs() or SetExpRate().
double cScenarioExp::GetExpRate() const
{
	const double rate = mExpRate;
	return rate;
}

/// Exploration temperature last stored by ParseArgs() or SetExpTemp().
double cScenarioExp::GetExpTemp() const
{
	const double temp = mExpTemp;
	return temp;
}

/// Base-action exploration rate last stored by the constructor, ParseArgs() or
/// SetExpBaseActionRate().
double cScenarioExp::GetExpBaseActionRate() const
{
	const double rate = mExpBaseActionRate;
	return rate;
}

/// After each simulation substep, closes the current tuple and opens the next
/// one whenever the controller reports a new cycle. time_step is not used.
void cScenarioExp::PostSubstepUpdate(double time_step)
{
	if (!IsNewCycle())
	{
		return;
	}
	NewCycleUpdate();
}

/// Whether the character's controller reports that a new control cycle began.
bool cScenarioExp::IsNewCycle() const
{
	return mChar->GetController()->IsNewCycle();
}

// Closes the tuple that spans the previous cycle and starts the next one.
// Statement order matters: the end state recorded here is reused as the
// beginning state of the next tuple, and the validity check reads mCycleCount
// before it is incremented at the end.
void cScenarioExp::NewCycleUpdate()
{
	// finish recording tuple from previous cycle: end state, end flags
	// (fail flag via RecordFlagsEnd), and the controller's reward
	RecordState(mCurrTuple.mStateEnd);
	RecordFlagsEnd(mCurrTuple);
	mCurrTuple.mReward = CalcReward();

	// store the completed tuple unless still within the warmup cycles
	// (see IsValidTuple)
	if (IsValidTuple())
	{
		RecordTuple(mCurrTuple);
	}

	// start recording new tuple: it begins where the previous one ended and
	// uses the controller's current action
	mCurrTuple.mStateBeg = mCurrTuple.mStateEnd;
	RecordAction(mCurrTuple.mAction);
	ClearFlags(mCurrTuple);
	RecordFlagsBeg(mCurrTuple);

	// counted after the validity check, so the first NewCycleUpdate after
	// ResetParams() sees mCycleCount == 0
	++mCycleCount;
}

/// Read-only view of the character's controller as a cNNController.
/// NOTE(review): the cast is unchecked and assumes the controller really is a
/// cNNController.
std::shared_ptr<const cNNController> cScenarioExp::GetNNController() const
{
	const auto& base_ctrl = mChar->GetController();
	return std::static_pointer_cast<const cNNController>(base_ctrl);
}

/// Mutable view of the character's controller as a cNNController.
/// NOTE(review): the cast is unchecked and assumes the controller really is a
/// cNNController.
std::shared_ptr<cNNController> cScenarioExp::GetNNController()
{
	const auto& base_ctrl = mChar->GetController();
	return std::static_pointer_cast<cNNController>(base_ctrl);
}

/// Writes the controller's current policy state into out_state.
void cScenarioExp::RecordState(Eigen::VectorXd& out_state) const
{
	GetNNController()->RecordPoliState(out_state);
}

/// Writes the controller's current policy action into out_action.
void cScenarioExp::RecordAction(Eigen::VectorXd& out_action) const
{
	GetNNController()->RecordPoliAction(out_action);
}

/// Reward for the cycle that just finished, as computed by the controller.
double cScenarioExp::CalcReward() const
{
	return GetNNController()->CalcReward();
}

/// A tuple counts as failed when the character has fallen.
bool cScenarioExp::CheckFail() const
{
	return HasFallen();
}


/// Resets every flag on the tuple before a new one is recorded.
void cScenarioExp::ClearFlags(tExpTuple& out_tuple) const
{
	tExpTuple& tuple = out_tuple;
	tuple.ClearFlags();
}

/// Flags recorded at the start of a tuple. This scenario sets none; the hook is
/// presumably overridden by derived scenarios (confirm against the header).
void cScenarioExp::RecordFlagsBeg(tExpTuple& out_tuple) const
{
	(void)out_tuple;
}

/// Flags recorded when a tuple is closed: the fail flag mirrors CheckFail().
void cScenarioExp::RecordFlagsEnd(tExpTuple& out_tuple) const
{
	out_tuple.SetFlag(CheckFail(), cQNetTrainer::eFlagFail);
}

/// Appends a tuple to the ring buffer. Once mTupleBufferSize tuples have been
/// written, the oldest slots are overwritten in order.
void cScenarioExp::RecordTuple(const tExpTuple& tuple)
{
	const int slot = mTupleCount % mTupleBufferSize;
	mTupleBuffer[slot] = tuple;
	mTupleCount += 1;
}

/// Whether Init()/Reset() should begin with a random action. Always enabled in
/// this scenario.
bool cScenarioExp::EnableRandInitAction() const
{
	const bool enabled = true;
	return enabled;
}

/// Asks the character's controller to execute a random action.
void cScenarioExp::CommandRandAction()
{
	mChar->GetController()->CommandRandAction();
}

/// A completed tuple is kept only after the warmup period. mCycleCount is
/// checked before NewCycleUpdate() increments it, so the first
/// gNumWarmupCycles + 1 tuples closed after ResetParams() are discarded.
bool cScenarioExp::IsValidTuple() const
{
	return gNumWarmupCycles < mCycleCount;
}
// end of ScenarioExp.cpp