https://github.com/xbpeng/DeepTerrainRL
Raw File
Tip revision: ed82e2ebe5f14fa875cc3d0a2180c64980408e8f authored by Glen on 19 October 2016, 17:49:36 UTC
Update README.md
Tip revision: ed82e2e
NeuralNetLearner.cpp
#include "NeuralNetLearner.h"
#include "NeuralNetTrainer.h"

// Creates a learner attached to the shared `trainer`.
// `trainer` must be non-null (checked by assert only). The learner starts with
// zeroed iteration/tuple counters and no network: SetNet() has to be called
// before any member that dereferences mNet is used.
cNeuralNetLearner::cNeuralNetLearner(const std::shared_ptr<cNeuralNetTrainer>& trainer)
{
	assert(trainer != nullptr);

	// Local bookkeeping starts out empty.
	mNet = nullptr;
	mNumTuples = 0;
	mIter = 0;

	// Hold a shared reference to the trainer, then register with it and keep
	// whatever RegisterLearner() hands back in mID.
	mTrainer = trainer;
	mID = mTrainer->RegisterLearner(this);
}

// Removes this learner from the trainer it registered with at construction.
// The network pointer given to SetNet() is not released here.
cNeuralNetLearner::~cNeuralNetLearner()
{
	auto& trainer = *mTrainer;
	trainer.UnregisterLearner(this);
}

// Zeroes the cached iteration and tuple counters, then re-copies the
// trainer's current model into this learner's network via SyncNet().
void cNeuralNetLearner::Reset()
{
	// Clear the cached progress counters first...
	mNumTuples = 0;
	mIter = 0;

	// ...then refresh the local network from the trainer.
	SyncNet();
}

// Loads the file reported by the trainer's GetNetFile() into this learner's
// network, then copies the trainer's model over it with SyncNet().
// A network must already have been supplied through SetNet().
void cNeuralNetLearner::Init()
{
	const auto& net_file = mTrainer->GetNetFile();
	LoadNet(net_file);

	SyncNet();
}

// Performs one training step on the shared trainer using `tuples`.
// Everything between the trainer's Lock() and Unlock() happens in this order:
// the UpdateTrainer() hook runs, the tuples are added, the trainer trains,
// the trainer's model is copied into this learner's network (SyncNet), and
// the trainer's iteration / tuple counts are cached locally.
// NOTE(review): Unlock() is not reached if any call in between throws.
void cNeuralNetLearner::Train(const std::vector<tExpTuple>& tuples)
{
	mTrainer->Lock();
	{
		// Feed the trainer and run it.
		UpdateTrainer();
		mTrainer->AddTuples(tuples);
		mTrainer->Train();

		// Pull the updated model back into the local network.
		SyncNet();

		// Cache progress so GetIter()/GetNumTuples() need not touch the trainer.
		mIter = mTrainer->GetIter();
		mNumTuples = mTrainer->GetNumTuples();
	}
	mTrainer->Unlock();
}

// Returns the iteration count cached from the trainer at the end of the most
// recent Train() call; zero after construction or Reset().
int cNeuralNetLearner::GetIter() const
{
	const int cached_iter = mIter;
	return cached_iter;
}

// Returns the tuple count cached from the trainer at the end of the most
// recent Train() call; zero after construction or Reset().
int cNeuralNetLearner::GetNumTuples() const
{
	const int cached_num_tuples = mNumTuples;
	return cached_num_tuples;
}

// Points this learner at the network it should load into and sync.
// `net` must be non-null (checked by assert only). Only the pointer is
// stored; nothing in this file deletes it.
void cNeuralNetLearner::SetNet(cNeuralNet* net)
{
	assert(net != nullptr);

	this->mNet = net;
}

// Returns a read-only pointer to the learner's network, or nullptr if
// SetNet() has not been called yet.
const cNeuralNet* cNeuralNetLearner::GetNet() const
{
	const cNeuralNet* current_net = mNet;
	return current_net;
}

// Forwards `net_file` to the learner network's LoadNet().
// Precondition: a network has been supplied through SetNet(). mNet is nullptr
// straight after construction, so the precondition is asserted here instead of
// silently dereferencing a null pointer.
void cNeuralNetLearner::LoadNet(const std::string& net_file)
{
	assert(mNet != nullptr);
	mNet->LoadNet(net_file);
}

// Forwards `solver_file` to the learner network's LoadSolver().
// Precondition: a network has been supplied through SetNet(). mNet is nullptr
// straight after construction, so the precondition is asserted here instead of
// silently dereferencing a null pointer.
void cNeuralNetLearner::LoadSolver(const std::string& solver_file)
{
	assert(mNet != nullptr);
	mNet->LoadSolver(solver_file);
}

// Asks the learner's network to write its model to `filename`, then prints a
// confirmation line to stdout.
// Precondition: a network has been supplied through SetNet(). mNet is nullptr
// straight after construction, so the precondition is asserted here instead of
// silently dereferencing a null pointer.
// NOTE(review): the confirmation is printed unconditionally; whether the
// network's OutputModel() can report failure is not visible from this file.
void cNeuralNetLearner::OutputModel(const std::string& filename) const
{
	assert(mNet != nullptr);
	mNet->OutputModel(filename);
	printf("Model saved to %s\n", filename.c_str());
}

// Copies the trainer's current model into this learner's network.
// Called from Reset(), Init() and Train().
// Precondition: a network has been supplied through SetNet(). mNet is nullptr
// straight after construction, so the precondition is asserted here instead of
// silently dereferencing a null pointer.
void cNeuralNetLearner::SyncNet()
{
	assert(mNet != nullptr);

	// GetNet() yields a pointer-like handle to the trainer's network; it is
	// dereferenced to pass the network itself to CopyModel().
	auto& trainer_net = mTrainer->GetNet();
	mNet->CopyModel(*trainer_net);
}

// Reports whether the shared trainer considers itself done; the answer comes
// straight from the trainer's IsDone().
bool cNeuralNetLearner::IsDone() const
{
	const bool trainer_done = mTrainer->IsDone();
	return trainer_done;
}

// Hook invoked by Train() right after the trainer is locked and before the
// new tuples are added. This implementation intentionally does nothing.
// NOTE(review): presumably an extension point for derived learners - confirm
// against the class declaration.
void cNeuralNetLearner::UpdateTrainer()
{
	// Nothing to do here.
}
back to top