//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include <boost/algorithm/string/predicate.hpp>
#include "CNTKLibrary.h"
#include "Utils.h"
#include "fileutil.h"
#include "PerformanceProfiler.h"
namespace CNTK
{
using namespace std;
const static std::wstring s_trainingMinibatchSource = L"TrainingMinibatchSource";
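    // Returns true iff 's' is non-empty and consists entirely of decimal digits.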
    inline bool isNumber(const std::wstring& s)
    {
        return !s.empty() &&
            find_if(s.begin(), s.end(), [](wchar_t c) { return !iswdigit(c); }) == s.end();
    }
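    // A source is effectively infinite when the requested number of samples is unbounded
    // (InfinitelyRepeat or larger than a signed 64-bit counter) and the source itself
    // repeats forever.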
inline static bool IsInfinite(MinibatchSourcePtr mbSource, size_t numberOfSamples = std::numeric_limits<size_t>::max())
{
return (numberOfSamples == MinibatchSource::InfinitelyRepeat ||
numberOfSamples >= (size_t)std::numeric_limits<long long>::max()) &&
mbSource->IsInfinite();
}
CheckpointConfig::CheckpointConfig(
const std::wstring& checkPointFileName,
size_t checkpointFrequency,
DataUnit checkpointFrequencyUnit,
bool restoreFromCheckpointIfExists,
bool preserveAllCheckpoints) :
m_preserveAll(preserveAllCheckpoints),
m_restore(restoreFromCheckpointIfExists),
m_fileName(checkPointFileName),
m_frequency(checkpointFrequency),
m_frequencyUnit(checkpointFrequencyUnit)
{
if (m_fileName.empty())
{
if (checkpointFrequency != 0 && checkpointFrequency != std::numeric_limits<size_t>::max())
InvalidArgument("Checkpoint file name must not be empty if checkpoint frequency is non zero.");
if (preserveAllCheckpoints)
InvalidArgument("Checkpoint file name must not be empty if 'preserve all checkpoints' option is specified.");
            // Disable checkpointing when no file name is given.
            m_frequency = 0;
}
}
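    // For illustration (argument values are hypothetical): CheckpointConfig(L"model.cnt", 1000,
    // DataUnit::Sample, true, false) checkpoints to "model.cnt" every 1000 samples, restores
    // from that file if it already exists, and overwrites the checkpoint on every save.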
CrossValidationConfig::CrossValidationConfig(
const MinibatchSourcePtr& crossValidationSource,
const MinibatchSizeSchedule& crossValidationSchedule,
size_t crossValidationFrequency,
DataUnit crossValidationFrequencyUnit,
size_t maxSamples,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream):
m_source(crossValidationSource),
m_mbSize(crossValidationSchedule),
m_frequency(crossValidationFrequency),
m_frequencyUnit(crossValidationFrequencyUnit),
m_maxSamples(maxSamples),
m_varToStream(inputVarToStream)
{
}
TestConfig::TestConfig(
const MinibatchSourcePtr& source,
const MinibatchSizeSchedule& schedule,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream) :
m_source(source),
m_mbSize(schedule),
m_varToStream(inputVarToStream)
{
}
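    // Example usage (a minimal sketch, not part of this file; assumes 'trainer', 'source',
    // the variable-to-stream map 'inputMap', and the three config objects are built elsewhere):
    //
    //   auto session = CreateTrainingSession(trainer, source, MinibatchSizeSchedule(64), inputMap,
    //                                        /*maxNumTrainingSamples=*/100000,
    //                                        /*progressFrequency=*/1000, DataUnit::Sample,
    //                                        checkpointConfig, cvConfig, testConfig);
    //   session->Train(DeviceDescriptor::UseDefaultDevice());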
CNTK_API TrainingSessionPtr CreateTrainingSession(
const TrainerPtr& trainer,
const MinibatchSourcePtr& trainingSource,
const MinibatchSizeSchedule& minibatchSizeSchedule,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
size_t maxNumTrainingSamples,
size_t progressFrequency,
DataUnit progressFrequencyUnit,
const CheckpointConfig& checkpointing,
const CrossValidationConfig& crossValidation,
const TestConfig& test)
{
return MakeSharedObject<TrainingSession>(trainer,
trainingSource,
minibatchSizeSchedule,
inputVarToStream,
maxNumTrainingSamples,
progressFrequency,
progressFrequencyUnit,
checkpointing, crossValidation, test);
}
TrainingSession::TrainingSession(
const TrainerPtr& trainer,
const MinibatchSourcePtr& trainingSource,
const MinibatchSizeSchedule& minibatchSizeSchedule,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
size_t maxNumTrainingSamples,
size_t progressFrequency,
DataUnit progressFrequencyUnit,
const CheckpointConfig& checkpointing,
const CrossValidationConfig& crossValidation,
const TestConfig& test) :
m_trainer(trainer),
m_source(trainingSource),
m_mbSize(minibatchSizeSchedule),
m_varToStream(inputVarToStream),
m_maxNumSamples(maxNumTrainingSamples),
m_progressFrequency(progressFrequency),
m_progressFrequencyUnit(progressFrequencyUnit),
m_checkpoint(checkpointing),
m_cv(crossValidation),
m_parallelAfterSamples(0),
m_workerRank(0),
m_numberOfWorkers(1),
m_test(test),
m_mbSizeScaleFactor(1)
{
if (!m_trainer)
InvalidArgument("Trainer must not be null.");
if (!m_source)
InvalidArgument("Training source must not be null.");
if (m_maxNumSamples == 0)
InvalidArgument("maxNumTrainingSamples must not be zero.");
if (m_varToStream.empty())
InvalidArgument("inputVarToStream mapping must not be empty.");
        // Calculate the warm-up period that the distributed learners may need.
        // We take the maximum warm-up period required across all learners.
auto learners = m_trainer->ParameterLearners();
m_parallelAfterSamples = 0;
for (const auto& l : learners)
{
auto distributed = std::dynamic_pointer_cast<DistributedLearner>(l);
if (distributed)
{
m_parallelAfterSamples = std::max(m_parallelAfterSamples, distributed->ParallelizationAfter());
m_workerRank = distributed->GetCommunicator()->CurrentWorker().m_globalRank;
m_numberOfWorkers = distributed->GetCommunicator()->Workers().size();
m_mbSizeScaleFactor = distributed->MinibatchSizeScaleFactor();
}
}
        // Fill in the required periodic actions.
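        // Each action is initialized as { frequency, frequencyUnit, currentIndex = 0,
        // unitCountWhenLastCalled = 0, callback } and fires whenever the integer quotient
        // of units seen by its frequency advances (see the Train loop below).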
if (m_checkpoint.m_frequency != 0)
m_actions.push_back({ m_checkpoint.m_frequency, m_checkpoint.m_frequencyUnit, 0, 0,
[this](size_t currentIndex, const DeviceDescriptor&)
{
SaveCheckpoint(currentIndex);
                // Enable the profiler after the first checkpoint.
                // This takes effect only if the profiler is globally enabled by StartProfiler().
#ifndef CNTK_UWP
Microsoft::MSR::CNTK::ProfilerEnable(true);
#endif
return true;
} });
        // Report progress before running cross validation, if any.
if (m_progressFrequency != 0)
{
m_actions.push_back({ m_progressFrequency, m_progressFrequencyUnit, 0, 0,
[this](size_t currentIndex, const DeviceDescriptor&) { ReportProgress(currentIndex); return true; } });
}
if (m_cv.m_frequency != 0)
m_actions.push_back({ m_cv.m_frequency, m_cv.m_frequencyUnit, 0, 0,
[this](size_t currentIndex, const DeviceDescriptor& d) { return CrossValidate(currentIndex, d); } });
}
void TrainingSession::Train(const DeviceDescriptor& computeDevice)
{
std::unordered_map<Variable, ValuePtr> minibatch;
bool shouldTrain = m_maxNumSamples > 0;
        // Restore from an existing checkpoint if required.
size_t restoredNumberOfSamples = 0;
if (m_checkpoint.m_restore && !m_checkpoint.m_fileName.empty())
{
RestoreFromCheckpoint();
restoredNumberOfSamples = m_trainer->TotalNumberOfSamplesSeen();
}
if (IsInfinite(m_source, m_maxNumSamples))
InvalidArgument("Train minibatch source must have a limited number of samples or sweeps.");
// Main train loop.
bool earlyExit = false;
while (shouldTrain)
{
// Get next minibatch.
size_t samplesLeft = earlyExit || m_maxNumSamples <= Trainer()->TotalNumberOfSamplesSeen()
? 0
: m_maxNumSamples - Trainer()->TotalNumberOfSamplesSeen();
            // Get the sweep-end status from GetTrainingMinibatch and use it in TrainMinibatch below.
bool isMinibatchAtSweepEnd;
// Note that in case of distributed training we don't want to stop if the local minibatch
// is empty - it is possible that the other workers are still processing their minibatches.
GetTrainingMinibatch(minibatch, &isMinibatchAtSweepEnd, samplesLeft, computeDevice);
// Train on the minibatch.
OnMinibatchStart();
shouldTrain = Trainer()->TrainMinibatch(minibatch, isMinibatchAtSweepEnd, computeDevice);
earlyExit |= !OnMinibatchEnd(); // If the callback wants to have early exit - we stop training.
#ifndef CNTK_UWP
auto profMisc = Microsoft::MSR::CNTK::ScopeProfile(Microsoft::MSR::CNTK::profilerEvtMainPost);
#endif
            // Perform the registered actions whenever their frequency boundary has been crossed.
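            // For example, with frequency = 100 and DataUnit::Sample: after 150 samples the
            // quotient is 1, which differs from the initial currentIndex of 0, so the action
            // fires; it fires again once the quotient advances to 2 at 200 samples or more.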
for (auto& action : m_actions)
{
size_t totalNumberOfUnitCounts = Trainer()->TotalNumberOfUnitsSeen(action.frequencyUnit);
size_t index = totalNumberOfUnitCounts / action.frequency;
if (index != action.currentIndex)
{
// If any action wants to have early exit - we stop training.
earlyExit |= !action.action(action.currentIndex, computeDevice);
action.currentIndex = index;
action.unitCountWhenLastCalled = totalNumberOfUnitCounts;
}
}
}
if (restoredNumberOfSamples != Trainer()->TotalNumberOfSamplesSeen())
{
            // Run all actions once more on the final, possibly partial, portion of the data.
for (auto& action: m_actions)
{
size_t totalUnitCounts = Trainer()->TotalNumberOfUnitsSeen(action.frequencyUnit);
if (totalUnitCounts % action.frequency != 0 &&
totalUnitCounts != action.unitCountWhenLastCalled)
action.action(action.currentIndex, computeDevice);
}
}
        // In the incremental case, save a final checkpoint.
        // This is required only when we keep all existing checkpoints; otherwise
        // the checkpoint has already been saved under the proper name.
if (m_checkpoint.m_frequency &&
m_checkpoint.m_preserveAll &&
!fexists(m_checkpoint.m_fileName))
SaveFinalCheckpoint();
// Perform testing according to the test config.
Test(computeDevice);
}
// TODO: Possibly expose a limiting counter on the number of samples for validation.
bool TrainingSession::CrossValidate(size_t currentIndex, const DeviceDescriptor& computeDevice)
{
        // Make sure we get a consistent state of the training minibatch source
        // in case of BPTT.
        // When CV happens in the middle of training, the packer can still have some truncated
        // sequences in its buffer. CV resets the state of the DelayedNode, so the first
        // training minibatch after CV would cause an exception.
        // Checkpointing currently drops intermediate buffers.
        // TODO: This is meant as a stopgap; the minibatch source should be properly drained instead.
auto state = m_source->GetCheckpointState();
bool result = false;
if (m_cv.m_source) // Running cross validation
{
if (IsInfinite(m_cv.m_source, m_cv.m_maxSamples))
InvalidArgument("Cross validation minibatch source must have a limited number of samples or sweeps.");
std::unordered_map<Variable, ValuePtr> minibatch;
double accumulatedError = 0;
size_t totalNumberOfSamples = 0;
size_t numberOfMinibatches = 0;
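            // 'errorAndCount' receives, for each evaluated minibatch, the aggregate error and
            // the number of samples it covers; the per-sample average is computed after the loop.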
std::pair<ValuePtr, size_t> errorAndCount;
auto checkpoint = m_cv.m_source->GetCheckpointState();
bool shouldCV = true;
while (shouldCV)
{
size_t samplesLeft = m_cv.m_maxSamples <= totalNumberOfSamples ? 0 : m_cv.m_maxSamples - totalNumberOfSamples;
GetCrossValidationMinibatch(minibatch, /*pIsMinibatchAtSweepEnd= */nullptr, (std::min)(m_cv.m_mbSize[totalNumberOfSamples], samplesLeft), computeDevice);
                // TODO: It may be slow to rely on TestMinibatch to return the error each time,
                // since that may require transferring the error from the GPU on every call;
                // accumulatedError could instead be accumulated on the GPU.
shouldCV = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
if (shouldCV)
{
accumulatedError += errorAndCount.first->AsScalar<double>();
totalNumberOfSamples += errorAndCount.second;
numberOfMinibatches++;
}
}
m_cv.m_source->RestoreFromCheckpoint(checkpoint);
Trainer()->SummarizeTestProgress();
result = OnCrossValidationEnd(currentIndex, accumulatedError / totalNumberOfSamples, totalNumberOfSamples, numberOfMinibatches);
}
else // Only invoking the callback.
{
result = OnCrossValidationEnd(currentIndex, 0, 0, 0);
}
m_source->RestoreFromCheckpoint(state);
return result;
}
void TrainingSession::Test(const DeviceDescriptor& computeDevice)
{
if (!m_test.m_source)
return;
if (IsInfinite(m_test.m_source))
InvalidArgument("Test minibatch source must have a limited number of samples or sweeps.");
std::unordered_map<Variable, ValuePtr> minibatch;
size_t totalNumberOfSamples = 0;
bool shouldTest = true;
std::pair<ValuePtr, size_t> errorAndCount;
while (shouldTest)
{
GetNextMinibatch(m_test.m_source, minibatch, m_test.m_varToStream.empty() ? m_varToStream : m_test.m_varToStream,
/*pIsMinibatchAtSweepEnd=*/ nullptr, m_test.m_mbSize[totalNumberOfSamples], m_workerRank, m_numberOfWorkers, computeDevice);
shouldTest = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
totalNumberOfSamples += errorAndCount.second;
}
m_trainer->SummarizeTestProgress();
}
inline void TrainingSession::ReportProgress(size_t /*currentIndex*/)
{
Trainer()->SummarizeTrainingProgress();
}
void TrainingSession::GetTrainingMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, bool* pIsMinibatchAtSweepEnd, size_t maxMbSize, const DeviceDescriptor& computeDevice)
{
size_t workerRank = m_workerRank, numberOfWorkers = m_numberOfWorkers;
// Check if we are operating in distributed mode.
size_t scaleFactor = m_mbSizeScaleFactor;
if (m_parallelAfterSamples > Trainer()->TotalNumberOfSamplesSeen())
{
numberOfWorkers = 1;
workerRank = 0;
scaleFactor = 1;
}
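        // For example, with m_parallelAfterSamples = 10000 (a hypothetical value), every worker
        // reads the full, unpartitioned minibatch until 10000 samples have been seen, and only
        // then switches to partitioned, distributed reading.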
size_t mbSize = GetMinibatchSize() * scaleFactor;
mbSize = (std::min)(mbSize, maxMbSize);
GetNextMinibatch(m_source, minibatch, m_varToStream, pIsMinibatchAtSweepEnd, mbSize, workerRank, numberOfWorkers, computeDevice);
}
void TrainingSession::GetCrossValidationMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, bool* pIsMinibatchAtSweepEnd, size_t maxMbSize, const DeviceDescriptor& computeDevice)
{
GetNextMinibatch(m_cv.m_source, minibatch, m_cv.m_varToStream.empty() ? m_varToStream : m_cv.m_varToStream, pIsMinibatchAtSweepEnd, maxMbSize, m_workerRank, m_numberOfWorkers, computeDevice);
}
void TrainingSession::GetNextMinibatch(
const MinibatchSourcePtr& source,
std::unordered_map<Variable, ValuePtr>& minibatch,
const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
bool* pIsMinibatchAtSweepEnd,
size_t mbSize,
size_t workerRank,
size_t numberOfWorkers,
const DeviceDescriptor& computeDevice)
{
minibatch.clear();
if (mbSize == 0)
return;
// TODO: is copy really necessary here?
auto minibatchData = source->GetNextMinibatch(0 /*numberOfSequences*/, mbSize, numberOfWorkers, workerRank, computeDevice);
if (pIsMinibatchAtSweepEnd != nullptr)
*pIsMinibatchAtSweepEnd = IsAtSweepEnd(minibatchData);
if (minibatchData.empty())
return;
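        // Map each model input variable to the Value of its corresponding stream.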
        for (const auto& v : inputVarToStream)
{
auto value = minibatchData.find(v.second);
if (value == minibatchData.end())
RuntimeError("Minibatch source cannot find a stream with name '%ls'", v.second.m_name.c_str());
minibatch.insert({ v.first, value->second.data });
}
}
void TrainingSession::RestoreFromCheckpoint(const std::wstring& checkpointFileName)
{
Dictionary externalState = Trainer()->RestoreFromCheckpoint(checkpointFileName);
m_source->RestoreFromCheckpoint(externalState[s_trainingMinibatchSource].Value<Dictionary>());
}
void TrainingSession::SaveCheckpoint(size_t currentIndex)
{
OnCheckpointStart(currentIndex);
Dictionary externalState;
externalState[s_trainingMinibatchSource] = m_source->GetCheckpointState();
wstring checkpointFile = m_checkpoint.m_fileName;
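        // When preserving all checkpoints, append the current index so that each checkpoint
        // gets a unique name, e.g. "model.cnt" becomes "model.cnt3" (names are illustrative).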
if (m_checkpoint.m_preserveAll)
checkpointFile += std::to_wstring(currentIndex);
Trainer()->SaveCheckpoint(checkpointFile, externalState);
OnCheckpointEnd(currentIndex);
}
void TrainingSession::SaveFinalCheckpoint()
{
Dictionary externalState;
externalState[s_trainingMinibatchSource] = m_source->GetCheckpointState();
Trainer()->SaveCheckpoint(m_checkpoint.m_fileName, externalState);
}
    // Restores from the checkpoint file named by m_checkpoint.m_fileName.
    // If the file exists, simply restores from it.
    // If it does not exist, looks into the directory where the file would be located
    // and picks the file with the largest N among <fileName>N files,
    // where N is a non-negative integer.
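    // For example (file names are illustrative), with m_fileName = L"model" and files
    // "model3" and "model7" present alongside "model3.ckp" and "model7.ckp",
    // "model7" is picked.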
void TrainingSession::RestoreFromCheckpoint()
{
assert(!m_checkpoint.m_fileName.empty());
auto checkpoint = m_checkpoint.m_fileName;
        // Make sure the intermediate directories exist, so no further checks are needed.
msra::files::make_intermediate_dirs(checkpoint);
size_t pos = checkpoint.find_last_of(L"\\/");
wstring parent;
wstring fileName;
if (pos == wstring::npos)
{
parent = L".";
fileName = checkpoint;
}
else
{
parent = checkpoint.substr(0, pos);
fileName = checkpoint.substr(pos + 1);
}
std::wstring restoreFile;
if (fexists(checkpoint))
{
restoreFile = checkpoint;
}
else
{
            // Check whether there are other candidate checkpoints to restore from.
int maxValue = -1;
std::vector<std::wstring> files = msra::files::get_all_files_from_directory(parent);
            for (const auto& f : files)
{
if (!boost::starts_with(f, fileName))
{
continue;
}
auto suffix = f.substr(fileName.size());
if (!isNumber(suffix) || !fexists(parent + L"/" + f + L".ckp"))
{
continue;
}
auto expectedNumber = msra::strfun::utf8(suffix);
char* tmp = nullptr;
int value = strtol(expectedNumber.c_str(), &tmp, 10);
if (tmp != expectedNumber.c_str() + expectedNumber.size())
continue;
if (value > maxValue)
{
// Found a better candidate.
maxValue = value;
restoreFile = parent + L"/" + f;
}
}
}
if (restoreFile.empty()) // Nothing to restore.
return;
        // TODO: Should have proper logging instead.
fprintf(stderr, "Restoring training session from the checkpoint '%ls'\n", restoreFile.c_str());
this->RestoreFromCheckpoint(restoreFile);
        // Recalculate the action indices.
for (auto& action : m_actions)
{
size_t totalNumberOfUnitCounts = Trainer()->TotalNumberOfUnitsSeen(action.frequencyUnit);
action.currentIndex = totalNumberOfUnitCounts / action.frequency;
action.unitCountWhenLastCalled = totalNumberOfUnitCounts - totalNumberOfUnitCounts % action.frequency;
}
}
}