Raw File
ACLearner.cpp
#include "ACLearner.h"
#include "ACTrainer.h"

cACLearner::cACLearner(const std::shared_ptr<cNeuralNetTrainer>& trainer)
	: cNeuralNetLearner(trainer)
{
	mCriticNet = nullptr;
}

cACLearner::~cACLearner()
{
}

void cACLearner::Init()
{
	auto trainer = std::static_pointer_cast<cACTrainer>(mTrainer);
	LoadActorNet(trainer->GetActorNetFile());
	LoadCriticNet(trainer->GetCriticNetFile());
	SyncNet();
}

void cACLearner::SetNet(cNeuralNet* net)
{
	SetActorNet(net);
}

const cNeuralNet* cACLearner::GetNet() const
{
	return GetActorNet();
}

void cACLearner::SetActorNet(cNeuralNet* net)
{
	assert(net != nullptr);
	mNet = net;
}

const cNeuralNet* cACLearner::GetActorNet() const
{
	return mNet;
}

void cACLearner::SetCriticNet(cNeuralNet* net)
{
	assert(net != nullptr);
	mCriticNet = net;
}

const cNeuralNet* cACLearner::GetCriticNet() const
{
	return mCriticNet;
}

void cACLearner::LoadNet(const std::string& net_file)
{
	LoadActorNet(net_file);
}

void cACLearner::LoadSolver(const std::string& solver_file)
{
	LoadActorSolver(solver_file);
}

void cACLearner::LoadCriticNet(const std::string& net_file)
{
	if (HasCriticNet())
	{
		mCriticNet->LoadNet(net_file);
	}
}

void cACLearner::LoadCriticSolver(const std::string& solver_file)
{
	if (HasCriticNet())
	{
		mCriticNet->LoadSolver(solver_file);
	}
}

void cACLearner::LoadActorNet(const std::string& net_file)
{
	mNet->LoadNet(net_file);
}

void cACLearner::LoadActorSolver(const std::string& solver_file)
{
	mNet->LoadSolver(solver_file);
}

void cACLearner::OutputModel(const std::string& filename) const
{
	std::string critic_filename = cACTrainer::GetCriticFilename(filename);
	OutputActor(filename);
	OutputCritic(critic_filename);
}

void cACLearner::OutputCritic(const std::string& filename) const
{
	if (HasCriticNet())
	{
		mCriticNet->OutputModel(filename);
		printf("Citic model saved to %s\n", filename.c_str());
	}
}

void cACLearner::OutputActor(const std::string& filename) const
{
	mNet->OutputModel(filename);
	printf("Actor model saved to %s\n", filename.c_str());
}

void cACLearner::SyncNet()
{
	auto trainer = std::static_pointer_cast<cACTrainer>(mTrainer);

	auto& actor_net = trainer->GetActor();
	mNet->CopyModel(*actor_net);

	if (HasCriticNet())
	{
		auto& critic_net = trainer->GetCritic();
		mCriticNet->CopyModel(*critic_net);
	}
}

bool cACLearner::HasCriticNet() const
{
	return mCriticNet != nullptr;
}
back to top