// Source: https://github.com/xbpeng/DeepTerrainRL
// Tip revision: ed82e2ebe5f14fa875cc3d0a2180c64980408e8f (ed82e2e), authored by Glen on 19 October 2016, 17:49:36 UTC
// Commit message: "Update README.md"
// File: AsyncSolver.h
#pragma once
#include "NNSolver.h"
/// Caffe solver wrapper whose ApplySteps() only applies parameter updates.
///
/// Unlike caffe::Solver::Step(), ApplySteps() does not run forward/backward
/// passes itself. It calls ApplyUpdate(), advances the iteration counter and
/// handles snapshot/stop requests.
/// NOTE(review): gradients are presumably written into the net by some other
/// (asynchronous) caller before ApplySteps() runs. Confirm against the callers.
///
/// @tparam tSolverType concrete Caffe solver (SGD, Adam, ...) that
///         cCaffeSolver wraps.
template <typename tSolverType>
class cSolverAsync : public cCaffeSolver<tSolverType>
{
public:
	/// Forwards the solver configuration unchanged to the cCaffeSolver base.
	cSolverAsync(const caffe::SolverParameter& param)
		: cCaffeSolver<tSolverType>(param)
	{
	}

	virtual ~cSolverAsync()
	{
	}

	/// Runs `steps` update-only iterations (see the out-of-class definition).
	virtual void ApplySteps(int steps);
};
/// Runs `steps` solver iterations without computing gradients.
///
/// Each iteration does the following:
///   1. Calls ApplyUpdate() to apply whatever gradients the net currently holds.
///      NOTE(review): presumably these are filled in by an external or
///      asynchronous caller. Confirm.
///   2. Advances iter_.
///   3. Writes a snapshot when the periodic snapshot interval is hit (root
///      solver only) or when a SNAPSHOT action is requested.
///   4. On a STOP request, sets requested_early_exit_ and returns early.
///
/// @param steps number of iterations to run; values <= 0 do nothing.
template <typename tSolverType>
void cSolverAsync<tSolverType>::ApplySteps(int steps)
{
	const int stop_iter = this->iter_ + steps;
	while (this->iter_ < stop_iter)
	{
		// Call through `this->` rather than a qualified name. A qualified
		// call would suppress virtual dispatch, so overrides of ApplyUpdate()
		// in further-derived classes would be silently skipped.
		this->ApplyUpdate();
		++this->iter_;

		const caffe::SolverAction::Enum request = this->GetRequestedAction();

		// Periodic snapshots are written only by the root solver. An explicit
		// SNAPSHOT request is honoured unconditionally.
		const auto snapshot_interval = this->param_.snapshot();
		const bool periodic_snapshot = snapshot_interval
			&& this->iter_ % snapshot_interval == 0
			&& caffe::Caffe::root_solver();
		if (periodic_snapshot || request == caffe::SolverAction::SNAPSHOT)
		{
			this->Snapshot();
		}

		if (request == caffe::SolverAction::STOP)
		{
			// Record the early exit so callers can detect it, then leave
			// the training loop.
			this->requested_early_exit_ = true;
			break;
		}
	}
}
typedef cSolverAsync<caffe::SGDSolver<cNeuralNet::tNNData>> cSGDSolverAsync;
typedef cSolverAsync<caffe::NesterovSolver<cNeuralNet::tNNData>> cNesterovSolverAsync;
typedef cSolverAsync<caffe::AdaGradSolver<cNeuralNet::tNNData>> cAdaGradSolverAsync;
typedef cSolverAsync<caffe::RMSPropSolver<cNeuralNet::tNNData>> cRMSPropSolverAsync;
typedef cSolverAsync<caffe::AdaDeltaSolver<cNeuralNet::tNNData>> cAdaDeltaSolverAsync;
typedef cSolverAsync<caffe::AdamSolver<cNeuralNet::tNNData>> cAdamSolverAsync;