// ScenarioPoliEval.cpp
#include "ScenarioPoliEval.h"
#include "sim/NNController.h"
#include "sim/GroundVar2D.h"
#include "sim/DogControllerCacla.h"
#include "util/FileUtil.h"

// Number of controller cycles that must have been counted before a cycle is
// treated as valid: IsValidCycle() compares mCycleCount against this value.
const int gNumWarmupCycles = 1;

// Constructs the scenario with zeroed statistics and with every optional
// recorder turned off and no output path configured.
cScenarioPoliEval::cScenarioPoliEval()
{
	// episode / cycle statistics
	mEpisodeCount = 0;
	mCycleCount = 0;
	mAvgDist = 0;
	mPosStart.setZero();

	// state used to estimate velocity between recorded cycles
	mPrevTime = 0;
	mPrevCOMPos.setZero();

	// analysis recorders: flags off ...
	mRecordNNActivation = false;
	mRecordActions = false;
	mRecordVel = false;
	mRecordActionIDState = false;

	// ... and no layer / output files selected
	mNNActivationLayer.clear();
	mNNActivationOutputFile.clear();
	mActionOutputFile.clear();
	mVelOutputFile.clear();
	mActionIDStateOutputFile.clear();
}

// Nothing to release explicitly; members clean up after themselves.
cScenarioPoliEval::~cScenarioPoliEval() = default;

// Reads the scenario's command-line options from `parser`.
// The base class parses its own options first; afterwards the network files
// and the settings of each optional recorder are read into the members.
void cScenarioPoliEval::ParseArgs(const cArgParser& parser)
{
	cScenarioSimChar::ParseArgs(parser);

	// policy network definition and weights
	parser.ParseString("policy_net", mPoliNetFile);
	parser.ParseString("policy_model", mPoliModelFile);

	// critic network definition and weights
	parser.ParseString("critic_net", mCriticNetFile);
	parser.ParseString("critic_model", mCriticModelFile);

	// recorder: network layer activations
	parser.ParseBool("record_nn_activation", mRecordNNActivation);
	parser.ParseString("nn_activation_output_file", mNNActivationOutputFile);
	parser.ParseString("nn_activation_layer", mNNActivationLayer);

	// recorder: actions
	parser.ParseBool("record_actions", mRecordActions);
	parser.ParseString("action_output_file", mActionOutputFile);

	// recorder: velocity
	parser.ParseBool("record_vel", mRecordVel);
	parser.ParseString("vel_output_file", mVelOutputFile);

	// recorder: action id + policy state
	parser.ParseBool("record_action_id_state", mRecordActionIDState);
	parser.ParseString("action_id_state_output_file", mActionIDStateOutputFile);
}

// Initializes the base scenario, wipes the evaluation statistics, captures
// the reference COM position / time for velocity recording, and prepares
// the output of every recorder that is enabled.
// NOTE(review): mPosStart is not set here, only in Reset(); confirm that the
// first episode's start position is established elsewhere.
void cScenarioPoliEval::Init()
{
	cScenarioSimChar::Init();

	// fresh statistics
	mDistLog.clear();
	mCycleCount = 0;
	mEpisodeCount = 0;
	mAvgDist = 0;

	// reference sample for RecordVel()
	mPrevTime = mTime;
	mPrevCOMPos = mChar->CalcCOM();

	// prepare the output of each active recorder
	if (EnableRecordNNActivation())
	{
		InitNNActivation(mNNActivationOutputFile);
	}
	if (EnableRecordActions())
	{
		InitActionRecord(mActionOutputFile);
	}
	if (EnableRecordVel())
	{
		InitVelRecord(mVelOutputFile);
	}
	if (EnableRecordActionIDState())
	{
		InitActionIDState(mActionIDStateOutputFile);
	}
}

// Resets the base scenario, then re-captures the reference values of the new
// episode: the root position distance is measured from, and the COM
// position / time used for velocity recording.
void cScenarioPoliEval::Reset()
{
	cScenarioSimChar::Reset();

	mPrevTime = mTime;
	mPrevCOMPos = mChar->CalcCOM();
	mPosStart = mChar->GetRootPos();
}

// Clears the base scenario and discards all accumulated statistics
// (distance log, average distance, episode and cycle counters).
void cScenarioPoliEval::Clear()
{
	cScenarioSimChar::Clear();

	mDistLog.clear();
	mCycleCount = 0;
	mEpisodeCount = 0;
	mAvgDist = 0;
}

// Advances the base scenario by `time_elapsed`. When time actually advanced
// and the character has fallen, the episode ends: the travelled distance is
// recorded (only if the current cycle is valid) and the scenario is reset.
void cScenarioPoliEval::Update(double time_elapsed)
{
	cScenarioSimChar::Update(time_elapsed);

	// HasFallen() is only consulted when time moved forward
	const bool episode_over = (time_elapsed > 0) && HasFallen();
	if (!episode_over)
	{
		return;
	}

	if (IsValidCycle())
	{
		RecordDistTraveled();
	}
	Reset();
}

// Returns the running average of the per-episode distances accumulated by
// RecordDistTraveled().
double cScenarioPoliEval::GetAvgDist() const
{
	return mAvgDist;
}

// Restarts the running distance average: zeroes both the episode counter and
// the average itself. The distance log and cycle counter are left untouched.
void cScenarioPoliEval::ResetAvgDist()
{
	mEpisodeCount = 0;
	mAvgDist = 0;
}

// Returns the number of episodes whose distance has been recorded since the
// counter was last zeroed.
int cScenarioPoliEval::GetNumEpisodes() const
{
	return mEpisodeCount;
}

// Returns the number of controller cycles counted by NewCycleUpdate() since
// the counter was last zeroed.
int cScenarioPoliEval::GetNumCycles() const
{
	return mCycleCount;
}

// Returns a read-only reference to the distance log; RecordDistTraveled()
// appends one entry per recorded episode.
const std::vector<double>& cScenarioPoliEval::GetDistLog() const
{
	return mDistLog;
}

void cScenarioPoliEval::SetRandSeed(unsigned long seed)
{
	if ((typeid(*mGround.get()).hash_code() == typeid(cGroundVar2D).hash_code()))
	{
		std::shared_ptr<cGroundVar2D> ground = std::static_pointer_cast<cGroundVar2D>(mGround);
		ground->SeedRand(seed);
	}
}

// Returns the fixed display name of this scenario.
std::string cScenarioPoliEval::GetName() const
{
	return "Policy Evaluation";
}

bool cScenarioPoliEval::BuildController(std::shared_ptr<cCharController>& out_ctrl)
{
	bool succ = cScenarioSimChar::BuildController(out_ctrl);

	if (mPoliNetFile != "")
	{
		std::shared_ptr<cNNController> nn_ctrl = std::static_pointer_cast<cNNController>(out_ctrl);
		succ &= nn_ctrl->LoadNet(mPoliNetFile);

		if (succ && mPoliModelFile != "")
		{
			nn_ctrl->LoadModel(mPoliModelFile);
		}
	}

	return succ;
}

bool cScenarioPoliEval::BuildDogControllerCacla(std::shared_ptr<cCharController>& out_ctrl) const
{
	bool succ = cScenarioSimChar::BuildDogControllerCacla(out_ctrl);
	std::shared_ptr<cDogControllerCacla> dog_ctrl = std::dynamic_pointer_cast<cDogControllerCacla>(out_ctrl);
	
	if (mCriticNetFile != "")
	{
		bool critic_succ = dog_ctrl->LoadCriticNet(mCriticNetFile);
		if (critic_succ && mCriticModelFile != "")
		{
			dog_ctrl->LoadCriticModel(mCriticModelFile);
		}
	}

	return succ;
}

void cScenarioPoliEval::RecordDistTraveled()
{
	tVector curr_pos = mChar->GetRootPos();
	tVector delta = curr_pos - mPosStart;
	double dist = delta[0];

	mAvgDist = cMathUtil::AddAverage(mAvgDist, mEpisodeCount, dist, 1);
	++mEpisodeCount;

	mDistLog.push_back(dist);

#if defined (ENABLE_DEBUG_PRINT)
	printf("\nEpisodes: %i\n", mEpisodeCount);
	printf("Avg dist: %.5f\n", mAvgDist);
#endif
}

// Forwards to the character controller's IsNewCycle().
bool cScenarioPoliEval::IsNewCycle() const
{
	return mChar->GetController()->IsNewCycle();
}

// Called after each simulation substep; triggers the per-cycle bookkeeping
// whenever the controller reports a new cycle. `time_step` is not used.
void cScenarioPoliEval::PostSubstepUpdate(double time_step)
{
	if (IsNewCycle())
	{
		NewCycleUpdate();
	}
}

void cScenarioPoliEval::NewCycleUpdate()
{
	if (IsValidCycle())
	{
		if (EnableRecordNNActivation())
		{
			RecordNNActivation(mNNActivationLayer, mNNActivationOutputFile);
		}

		if (EnableRecordActions())
		{
			RecordAction(mActionOutputFile);
		}

		if (EnableRecordVel())
		{
			RecordVel(mVelOutputFile);
		}

		if (EnableRecordActionIDState())
		{
			RecordActionIDState(mActionIDStateOutputFile);
		}
	}
	++mCycleCount;
}

// Prepares the activation output by calling cFileUtil::ClearFile on
// `out_file` (presumably emptying any previous contents — confirm against
// cFileUtil).
void cScenarioPoliEval::InitNNActivation(const std::string& out_file)
{
	cFileUtil::ClearFile(out_file);
}

// Activation recording is active only when the flag is set and both a layer
// name and an output file have been provided.
bool cScenarioPoliEval::EnableRecordNNActivation() const
{
	if (!mRecordNNActivation)
	{
		return false;
	}
	return !mNNActivationLayer.empty() && !mNNActivationOutputFile.empty();
}

void cScenarioPoliEval::RecordNNActivation(const std::string& layer_name, const std::string& out_file)
{
	const auto& ctrl = mChar->GetController();
	const auto& nn_ctrl = std::static_pointer_cast<cNNController const>(ctrl);
	const cNeuralNet& net = nn_ctrl->GetNet();

	Eigen::VectorXd data;
	net.GetLayerState(layer_name, data);

	int data_size = static_cast<int>(data.size());
	if (data_size > 0)
	{
		std::string data_str = "";
		int action_id = ctrl->GetCurrActionID();
		data_str += std::to_string(action_id);

		for (int i = 0; i < data_size; ++i)
		{
			data_str += ",\t";
			data_str += std::to_string(data[i]);
		}
		data_str += "\n";

		cFileUtil::AppendText(data_str, out_file);
	}
}

// Starts the action record: (re)creates `out_file` and writes one line per
// controller action in the form "<action index>, <param>, <param>, ..." with
// parameters printed to five decimals.
// If the file cannot be opened the function reports it and returns without
// writing, instead of passing a null FILE* to fprintf.
void cScenarioPoliEval::InitActionRecord(const std::string& out_file) const
{
	FILE* file = cFileUtil::OpenFile(out_file, "w");
	if (file == nullptr)
	{
		printf("Failed to open action output file %s\n", out_file.c_str());
		return;
	}

	const auto& ctrl = mChar->GetController();

	int num_actions = ctrl->GetNumActions();
	Eigen::VectorXd params;
	for (int a = 0; a < num_actions; ++a)
	{
		ctrl->BuildActionOptParams(a, params);
		fprintf(file, "%i", a);

		int param_size = static_cast<int>(params.size());
		for (int i = 0; i < param_size; ++i)
		{
			fprintf(file, ", %.5f", params[i]);
		}
		fprintf(file, "\n");
	}

	cFileUtil::CloseFile(file);
}

// Action recording is active only when the flag is set and an output file
// has been provided.
bool cScenarioPoliEval::EnableRecordActions() const
{
	if (!mRecordActions)
	{
		return false;
	}
	return !mActionOutputFile.empty();
}

void cScenarioPoliEval::RecordAction(const std::string& out_file)
{
	const auto& ctrl = mChar->GetController();
	
	std::string data_str = "";
	int action_id = ctrl->GetCurrActionID();
	data_str += std::to_string(action_id);

	Eigen::VectorXd params;
	ctrl->BuildOptParams(params);

	int data_size = static_cast<int>(params.size());
	for (int i = 0; i < data_size; ++i)
	{
		data_str += ",\t";
		data_str += std::to_string(params[i]);
	}
	data_str += "\n";

	cFileUtil::AppendText(data_str, out_file);
}

// Prepares the velocity output by calling cFileUtil::ClearFile on `out_file`
// (presumably emptying any previous contents — confirm against cFileUtil).
void cScenarioPoliEval::InitVelRecord(const std::string& out_file) const
{
	cFileUtil::ClearFile(out_file);
}

// Velocity recording is active only when the flag is set and an output file
// has been provided.
bool cScenarioPoliEval::EnableRecordVel() const
{
	if (!mRecordVel)
	{
		return false;
	}
	return !mVelOutputFile.empty();
}

void cScenarioPoliEval::RecordVel(const std::string& out_file)
{
	tVector curr_com = mChar->CalcCOM();
	double curr_time = mTime;

	tVector vel = curr_com - mPrevCOMPos;
	double dt = curr_time - mPrevTime;
	vel /= dt;
	
	std::string str = std::to_string(vel[0]) + "\n";
	cFileUtil::AppendText(str, out_file);

	mPrevCOMPos = curr_com;
	mPrevTime = curr_time;
}

// Prepares the action-id/state output by calling cFileUtil::ClearFile on
// `out_file` (presumably emptying any previous contents — confirm against
// cFileUtil).
void cScenarioPoliEval::InitActionIDState(const std::string& out_file) const
{
	cFileUtil::ClearFile(out_file);
}

// Action-id/state recording is active only when the flag is set and an
// output file has been provided.
bool cScenarioPoliEval::EnableRecordActionIDState() const
{
	if (!mRecordActionIDState)
	{
		return false;
	}
	return !mActionIDStateOutputFile.empty();
}

void cScenarioPoliEval::RecordActionIDState(const std::string& out_file)
{
	const auto& ctrl = mChar->GetController();
	auto nn_ctrl = std::static_pointer_cast<const cNNController>(mChar->GetController());

	Eigen::VectorXd state;
	nn_ctrl->RecordPoliState(state);

	std::string data_str = "";
	int action_id = ctrl->GetCurrActionID();
	data_str += std::to_string(action_id);

	for (int i = 0; i < static_cast<int>(state.size()); ++i)
	{
		data_str += ",\t";
		data_str += std::to_string(state[i]);
	}
	data_str += "\n";

	cFileUtil::AppendText(data_str, out_file);
}

// A cycle is valid once at least gNumWarmupCycles cycles have been counted.
bool cScenarioPoliEval::IsValidCycle() const
{
	return !(mCycleCount < gNumWarmupCycles);
}
// end of ScenarioPoliEval.cpp