Raw File
ScenarioTrainCacla.cpp
#include "ScenarioTrainCacla.h"
#include "ScenarioExpCacla.h"
#include "util/FileUtil.h"
#include "sim/BaseControllerCacla.h"
#include "learning/CaclaTrainer.h"
#include "learning/AsyncCaclaTrainer.h"
#include "learning/ACLearner.h"

cScenarioTrainCacla::cScenarioTrainCacla()
{
	mTrainerParams.mPoolSize = 1;
}

cScenarioTrainCacla::~cScenarioTrainCacla()
{
}

void cScenarioTrainCacla::ParseArgs(const cArgParser& parser)
{
	cScenarioTrain::ParseArgs(parser);
	parser.ParseString("critic_solver", mCriticSolverFile);
	parser.ParseString("critic_net", mCriticNetFile);
	parser.ParseString("critic_model", mCriticModelFile);

	mActorSolverFile = mTrainerParams.mSolverFile;
	mActorNetFile = mTrainerParams.mNetFile;
	mActorModelFile = mPoliModelFile;

	mTrainerParams.mSolverFile = mCriticSolverFile;
	mTrainerParams.mNetFile = mCriticNetFile;
}

std::string cScenarioTrainCacla::GetName() const
{
	return "Train Cacla";
}

void cScenarioTrainCacla::BuildTrainer(std::shared_ptr<cTrainerInterface>& out_trainer)
{
	if (mEnableAsyncMode)
	{
		auto trainer = std::shared_ptr<cAsyncCaclaTrainer>(new cAsyncCaclaTrainer());
		trainer->SetActorFiles(mActorSolverFile, mActorNetFile);
		out_trainer = trainer;
	}
	else
	{
		auto trainer = std::shared_ptr<cCaclaTrainer>(new cCaclaTrainer());
		trainer->SetActorFiles(mActorSolverFile, mActorNetFile);
		out_trainer = trainer;
	}
}

void cScenarioTrainCacla::LoadModel()
{
	if (mCriticModelFile != "")
	{
		mTrainer->LoadCriticModel(mCriticModelFile);
	}

	if (mActorModelFile != "")
	{
		mTrainer->LoadActorModel(mActorModelFile);
	}
}

void cScenarioTrainCacla::BuildExpScene(std::shared_ptr<cScenarioExp>& out_exp) const
{
	out_exp = std::shared_ptr<cScenarioExp>(new cScenarioExpCacla());
}

void cScenarioTrainCacla::SetupLearner(const std::shared_ptr<cSimCharacter>& character, std::shared_ptr<cNeuralNetLearner>& out_learner) const
{
	cScenarioTrain::SetupLearner(character, out_learner);

	auto ac_ctrl = std::dynamic_pointer_cast<cBaseControllerCacla>(character->GetController());
	if (ac_ctrl != nullptr)
	{
		auto ac_learner = std::static_pointer_cast<cACLearner>(out_learner);
		cNeuralNet& critic = ac_ctrl->GetCritic();
		ac_learner->SetCriticNet(&critic);
	}
	else
	{
		assert(false); // controller does not support actor-critic
	}
}

void cScenarioTrainCacla::SetupTrainerOutputOffsetScale()
{
	SetupTrainerCriticOutputOffsetScale();
	SetupTrainerActorOutputOffsetScale();
}

void cScenarioTrainCacla::SetupTrainerCriticOutputOffsetScale()
{
	bool valid_init_model = mTrainer->HasCriticInitModel();
	if (!valid_init_model)
	{
		int critic_output_size = mTrainer->GetCriticOutputSize();

		const std::shared_ptr<cCharController>& ctrl = GetRefController();
		std::shared_ptr<cBaseControllerCacla> cacla_ctrl = std::dynamic_pointer_cast<cBaseControllerCacla>(ctrl);

		if (cacla_ctrl != nullptr)
		{
			Eigen::VectorXd critic_output_offset;
			Eigen::VectorXd critic_output_scale;
			cacla_ctrl->BuildCriticOutputOffsetScale(critic_output_offset, critic_output_scale);

			assert(critic_output_offset.size() == critic_output_size);
			assert(critic_output_scale.size() == critic_output_size);
			mTrainer->SetCriticOutputOffsetScale(critic_output_offset, critic_output_scale);
		}
		else
		{
			assert(false); // controller does not implement actor-critic interface
		}
	}
}

void cScenarioTrainCacla::SetupTrainerActorOutputOffsetScale()
{
	bool valid_init_model = mTrainer->HasCriticInitModel();
	if (!valid_init_model)
	{
		int actor_output_size = mTrainer->GetActorOutputSize();

		const std::shared_ptr<cCharController>& ctrl = GetRefController();
		std::shared_ptr<cBaseControllerCacla> cacla_ctrl = std::dynamic_pointer_cast<cBaseControllerCacla>(ctrl);

		if (cacla_ctrl != nullptr)
		{
			Eigen::VectorXd actor_output_offset;
			Eigen::VectorXd actor_output_scale;
			cacla_ctrl->BuildActorOutputOffsetScale(actor_output_offset, actor_output_scale);

			assert(actor_output_offset.size() == actor_output_size);
			assert(actor_output_scale.size() == actor_output_size);
			mTrainer->SetActorOutputOffsetScale(actor_output_offset, actor_output_scale);
		}
		else
		{
			assert(false); // controller does not support CACLA
		}
	}
}
back to top