//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include <boost/algorithm/string/predicate.hpp>
#include "CNTKLibrary.h"
#include "Utils.h"
#include "fileutil.h"
#include "PerformanceProfiler.h"

namespace CNTK
{
    using namespace std;

    const static std::wstring s_trainingMinibatchSource = L"TrainingMinibatchSource";

    inline bool isNumber(const std::wstring& s)
    {
        return !s.empty() &&
            find_if(s.begin(), s.end(), [](wchar_t c) { return !isdigit(c); }) == s.end();
    }

    inline static bool IsInfinite(MinibatchSourcePtr mbSource, size_t numberOfSamples = std::numeric_limits<size_t>::max())
    {
        return (numberOfSamples == MinibatchSource::InfinitelyRepeat ||
                numberOfSamples >= (size_t)std::numeric_limits<long long>::max()) &&
            mbSource->IsInfinite();
    }

    CheckpointConfig::CheckpointConfig(
        const std::wstring& checkPointFileName,
        size_t checkpointFrequency,
        DataUnit checkpointFrequencyUnit,
        bool restoreFromCheckpointIfExists,
        bool preserveAllCheckpoints) :
        m_preserveAll(preserveAllCheckpoints),
        m_restore(restoreFromCheckpointIfExists),
        m_fileName(checkPointFileName),
        m_frequency(checkpointFrequency),
        m_frequencyUnit(checkpointFrequencyUnit)
    {
        if (m_fileName.empty())
        {
            if (checkpointFrequency != 0 && checkpointFrequency != std::numeric_limits<size_t>::max())
                InvalidArgument("Checkpoint file name must not be empty if checkpoint frequency is non-zero.");
            if (preserveAllCheckpoints)
                InvalidArgument("Checkpoint file name must not be empty if 'preserve all checkpoints' option is specified.");
            checkpointFrequency = 0;
        }
    }

    CrossValidationConfig::CrossValidationConfig(
        const MinibatchSourcePtr& crossValidationSource,
        const MinibatchSizeSchedule& crossValidationSchedule,
        size_t crossValidationFrequency,
        DataUnit crossValidationFrequencyUnit,
        size_t maxSamples,
        const std::unordered_map<Variable, StreamInformation>& inputVarToStream) :
        m_source(crossValidationSource),
        m_mbSize(crossValidationSchedule),
        m_frequency(crossValidationFrequency),
        m_frequencyUnit(crossValidationFrequencyUnit),
        m_maxSamples(maxSamples),
        m_varToStream(inputVarToStream)
    {
    }

    TestConfig::TestConfig(
        const MinibatchSourcePtr& source,
        const MinibatchSizeSchedule& schedule,
        const std::unordered_map<Variable, StreamInformation>& inputVarToStream) :
        m_source(source),
        m_mbSize(schedule),
        m_varToStream(inputVarToStream)
    {
    }

    CNTK_API TrainingSessionPtr CreateTrainingSession(
        const TrainerPtr& trainer,
        const MinibatchSourcePtr& trainingSource,
        const MinibatchSizeSchedule& minibatchSizeSchedule,
        const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
        size_t maxNumTrainingSamples,
        size_t progressFrequency,
        DataUnit progressFrequencyUnit,
        const CheckpointConfig& checkpointing,
        const CrossValidationConfig& crossValidation,
        const TestConfig& test)
    {
        return MakeSharedObject<TrainingSession>(trainer, trainingSource, minibatchSizeSchedule,
            inputVarToStream, maxNumTrainingSamples, progressFrequency, progressFrequencyUnit,
            checkpointing, crossValidation, test);
    }

    TrainingSession::TrainingSession(
        const TrainerPtr& trainer,
        const MinibatchSourcePtr& trainingSource,
        const MinibatchSizeSchedule& minibatchSizeSchedule,
        const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
        size_t maxNumTrainingSamples,
        size_t progressFrequency,
        DataUnit progressFrequencyUnit,
        const CheckpointConfig& checkpointing,
        const CrossValidationConfig& crossValidation,
        const TestConfig& test) :
        m_trainer(trainer),
        m_source(trainingSource),
        m_mbSize(minibatchSizeSchedule),
        m_varToStream(inputVarToStream),
        m_maxNumSamples(maxNumTrainingSamples),
        m_progressFrequency(progressFrequency),
        m_progressFrequencyUnit(progressFrequencyUnit),
        m_checkpoint(checkpointing),
        m_cv(crossValidation),
        m_parallelAfterSamples(0),
        m_workerRank(0),
        m_numberOfWorkers(1),
        m_test(test),
        m_mbSizeScaleFactor(1)
    {
        if (!m_trainer)
            InvalidArgument("Trainer must not be null.");

        if (!m_source)
            InvalidArgument("Training source must not be null.");

        if (m_maxNumSamples == 0)
            InvalidArgument("maxNumTrainingSamples must not be zero.");

        if (m_varToStream.empty())
            InvalidArgument("inputVarToStream mapping must not be empty.");

        // Let's calculate the warm-up period the distributed learners may need.
        // We will take the maximum warm-up period required.
        auto learners = m_trainer->ParameterLearners();
        m_parallelAfterSamples = 0;
        for (const auto& l : learners)
        {
            auto distributed = std::dynamic_pointer_cast<DistributedLearner>(l);
            if (distributed)
            {
                // This is always enforced now in the Learners class.
                m_parallelAfterSamples = std::max(m_parallelAfterSamples, distributed->ParallelizationAfter());
                m_workerRank = distributed->GetCommunicator()->CurrentWorker().m_globalRank;
                m_numberOfWorkers = distributed->GetCommunicator()->Workers().size();
                m_mbSizeScaleFactor = distributed->MinibatchSizeScaleFactor();
            }
        }

        // Fill in the required actions.
        if (m_checkpoint.m_frequency != 0)
            m_actions.push_back({ m_checkpoint.m_frequency, m_checkpoint.m_frequencyUnit, 0, 0,
                [this](size_t currentIndex, const DeviceDescriptor&)
                {
                    SaveCheckpoint(currentIndex);
                    // Enable the profiler after the first checkpoint.
                    // This has an effect only if the profiler has been globally enabled by StartProfiler().
#ifndef CNTK_UWP
                    Microsoft::MSR::CNTK::ProfilerEnable(true);
#endif
                    return true;
                } });

        // Report progress before we run cross validation, if any.
        if (m_progressFrequency != 0)
        {
            m_actions.push_back({ m_progressFrequency, m_progressFrequencyUnit, 0, 0,
                [this](size_t currentIndex, const DeviceDescriptor&) { ReportProgress(currentIndex); return true; } });
        }

        if (m_cv.m_frequency != 0)
            m_actions.push_back({ m_cv.m_frequency, m_cv.m_frequencyUnit, 0, 0,
                [this](size_t currentIndex, const DeviceDescriptor& d) { return CrossValidate(currentIndex, d); } });
    }

    void TrainingSession::Train(const DeviceDescriptor& computeDevice)
    {
        std::unordered_map<Variable, ValuePtr> minibatch;
        bool shouldTrain = m_maxNumSamples > 0;

        // Let's try to restore if required.
        size_t restoredNumberOfSamples = 0;
        if (m_checkpoint.m_restore && !m_checkpoint.m_fileName.empty())
        {
            RestoreFromCheckpoint();
            restoredNumberOfSamples = m_trainer->TotalNumberOfSamplesSeen();
        }

        if (IsInfinite(m_source, m_maxNumSamples))
            InvalidArgument("Train minibatch source must have a limited number of samples or sweeps.");

        // Main train loop.
        bool earlyExit = false;
        while (shouldTrain)
        {
            // Get the next minibatch.
            size_t samplesLeft = earlyExit || m_maxNumSamples <= Trainer()->TotalNumberOfSamplesSeen()
                ? 0
                : m_maxNumSamples - Trainer()->TotalNumberOfSamplesSeen();

            // Get the sweep-end status from GetTrainingMinibatch and use it in TrainMinibatch below.
            bool isMinibatchAtSweepEnd;

            // Note that in case of distributed training we don't want to stop if the local minibatch
            // is empty - it is possible that the other workers are still processing their minibatches.
            GetTrainingMinibatch(minibatch, &isMinibatchAtSweepEnd, samplesLeft, computeDevice);

            // Train on the minibatch.
            OnMinibatchStart();
            shouldTrain = Trainer()->TrainMinibatch(minibatch, isMinibatchAtSweepEnd, computeDevice);
            earlyExit |= !OnMinibatchEnd(); // If the callback requests an early exit - we stop training.
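            // The RAII scope object below attributes the time spent in the
            // post-minibatch work (the registered actions: checkpointing,
            // progress reporting and cross validation) to the
            // profilerEvtMainPost event when profiling is enabled.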
#ifndef CNTK_UWP
            auto profMisc = Microsoft::MSR::CNTK::ScopeProfile(Microsoft::MSR::CNTK::profilerEvtMainPost);
#endif

            // Perform actions if required. An action fires whenever the count of
            // completed frequency units crosses a new multiple of its frequency.
            for (auto& action : m_actions)
            {
                size_t totalNumberOfUnitCounts = Trainer()->TotalNumberOfUnitsSeen(action.frequencyUnit);
                size_t index = totalNumberOfUnitCounts / action.frequency;
                if (index != action.currentIndex)
                {
                    // If any action requests an early exit - we stop training.
                    earlyExit |= !action.action(action.currentIndex, computeDevice);

                    action.currentIndex = index;
                    action.unitCountWhenLastCalled = totalNumberOfUnitCounts;
                }
            }
        }

        if (restoredNumberOfSamples != Trainer()->TotalNumberOfSamplesSeen())
        {
            // Let's run all actions on the last, possibly partial, data at the end.
            for (auto& action : m_actions)
            {
                size_t totalUnitCounts = Trainer()->TotalNumberOfUnitsSeen(action.frequencyUnit);
                if (totalUnitCounts % action.frequency != 0 &&
                    totalUnitCounts != action.unitCountWhenLastCalled)
                    action.action(action.currentIndex, computeDevice);
            }
        }

        // In case of incremental checkpointing - save the final checkpoint.
        // This is required only when we keep all existing checkpoints; otherwise
        // the checkpoint has already been saved with the proper name.
        if (m_checkpoint.m_frequency &&
            m_checkpoint.m_preserveAll &&
            !fexists(m_checkpoint.m_fileName))
            SaveFinalCheckpoint();

        // Perform testing according to the test config.
        Test(computeDevice);
    }

    // TODO: Possibly expose a limiting counter on the number of samples for validation.
    bool TrainingSession::CrossValidate(size_t currentIndex, const DeviceDescriptor& computeDevice)
    {
        // Make sure we get a consistent state of the training minibatch source in case of BPTT.
        // When CV happens in the middle of training, the packer can still have some truncated
        // sequences in its buffer. CV resets the state of the DelayedNode, so the first
        // training minibatch after CV would cause an exception.
        // Checkpointing currently drops intermediate buffers.
        // TODO: This is meant as a stopgap; the minibatch source should be properly drained instead.
        auto state = m_source->GetCheckpointState();

        bool result = false;
        if (m_cv.m_source) // Running cross validation.
        {
            if (IsInfinite(m_cv.m_source, m_cv.m_maxSamples))
                InvalidArgument("Cross validation minibatch source must have a limited number of samples or sweeps.");

            std::unordered_map<Variable, ValuePtr> minibatch;
            double accumulatedError = 0;
            size_t totalNumberOfSamples = 0;
            size_t numberOfMinibatches = 0;

            std::pair<ValuePtr, size_t> errorAndCount;
            auto checkpoint = m_cv.m_source->GetCheckpointState();
            bool shouldCV = true;
            while (shouldCV)
            {
                size_t samplesLeft = m_cv.m_maxSamples <= totalNumberOfSamples ? 0 : m_cv.m_maxSamples - totalNumberOfSamples;
                GetCrossValidationMinibatch(minibatch, /*pIsMinibatchAtSweepEnd=*/ nullptr,
                    (std::min)(m_cv.m_mbSize[totalNumberOfSamples], samplesLeft), computeDevice);

                // TODO: it may be slow to rely on TestMinibatch to return the error each time, since it may require
                // a transfer of the error from the GPU each time; accumulatedError could be allocated on the GPU instead.
                shouldCV = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
                if (shouldCV)
                {
                    accumulatedError += errorAndCount.first->AsScalar<double>();
                    totalNumberOfSamples += errorAndCount.second;
                    numberOfMinibatches++;
                }
            }

            m_cv.m_source->RestoreFromCheckpoint(checkpoint);
            Trainer()->SummarizeTestProgress();
            result = OnCrossValidationEnd(currentIndex, accumulatedError / totalNumberOfSamples, totalNumberOfSamples, numberOfMinibatches);
        }
        else // Only invoking the callback.
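        // Even without a CV source the callback still runs on the configured
        // cadence with zero error and zero samples, so user code can implement
        // its own validation criterion and stop training by returning false.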
        {
            result = OnCrossValidationEnd(currentIndex, 0, 0, 0);
        }

        m_source->RestoreFromCheckpoint(state);
        return result;
    }

    void TrainingSession::Test(const DeviceDescriptor& computeDevice)
    {
        if (!m_test.m_source)
            return;

        if (IsInfinite(m_test.m_source))
            InvalidArgument("Test minibatch source must have a limited number of samples or sweeps.");

        std::unordered_map<Variable, ValuePtr> minibatch;
        size_t totalNumberOfSamples = 0;
        bool shouldTest = true;
        std::pair<ValuePtr, size_t> errorAndCount;
        while (shouldTest)
        {
            GetNextMinibatch(m_test.m_source, minibatch,
                m_test.m_varToStream.empty() ? m_varToStream : m_test.m_varToStream,
                /*pIsMinibatchAtSweepEnd=*/ nullptr,
                m_test.m_mbSize[totalNumberOfSamples],
                m_workerRank, m_numberOfWorkers, computeDevice);
            shouldTest = m_trainer->TestMinibatch(minibatch, errorAndCount, computeDevice, m_numberOfWorkers != 1);
            totalNumberOfSamples += errorAndCount.second;
        }
        m_trainer->SummarizeTestProgress();
    }

    inline void TrainingSession::ReportProgress(size_t /*currentIndex*/)
    {
        Trainer()->SummarizeTrainingProgress();
    }

    void TrainingSession::GetTrainingMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, bool* pIsMinibatchAtSweepEnd, size_t maxMbSize, const DeviceDescriptor& computeDevice)
    {
        size_t workerRank = m_workerRank, numberOfWorkers = m_numberOfWorkers;

        // Check if we are operating in distributed mode.
        size_t scaleFactor = m_mbSizeScaleFactor;
        if (m_parallelAfterSamples > Trainer()->TotalNumberOfSamplesSeen())
        {
            numberOfWorkers = 1;
            workerRank = 0;
            scaleFactor = 1;
        }

        size_t mbSize = GetMinibatchSize() * scaleFactor;
        mbSize = (std::min)(mbSize, maxMbSize);
        GetNextMinibatch(m_source, minibatch, m_varToStream, pIsMinibatchAtSweepEnd, mbSize, workerRank, numberOfWorkers, computeDevice);
    }

    void TrainingSession::GetCrossValidationMinibatch(std::unordered_map<Variable, ValuePtr>& minibatch, bool* pIsMinibatchAtSweepEnd, size_t maxMbSize, const DeviceDescriptor& computeDevice)
    {
        GetNextMinibatch(m_cv.m_source, minibatch,
            m_cv.m_varToStream.empty() ? m_varToStream : m_cv.m_varToStream,
            pIsMinibatchAtSweepEnd, maxMbSize, m_workerRank, m_numberOfWorkers, computeDevice);
    }

    void TrainingSession::GetNextMinibatch(
        const MinibatchSourcePtr& source,
        std::unordered_map<Variable, ValuePtr>& minibatch,
        const std::unordered_map<Variable, StreamInformation>& inputVarToStream,
        bool* pIsMinibatchAtSweepEnd,
        size_t mbSize,
        size_t workerRank,
        size_t numberOfWorkers,
        const DeviceDescriptor& computeDevice)
    {
        minibatch.clear();

        if (mbSize == 0)
            return;

        // TODO: is a copy really necessary here?
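        // Each worker asks the source for its own shard of the global minibatch:
        // the source partitions the data across numberOfWorkers workers and
        // returns only the slice that belongs to this workerRank.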
        auto minibatchData = source->GetNextMinibatch(0 /*numberOfSequences*/, mbSize, numberOfWorkers, workerRank, computeDevice);
        if (pIsMinibatchAtSweepEnd != nullptr)
            *pIsMinibatchAtSweepEnd = IsAtSweepEnd(minibatchData);

        if (minibatchData.empty())
            return;

        for (auto v : inputVarToStream)
        {
            auto value = minibatchData.find(v.second);
            if (value == minibatchData.end())
                RuntimeError("Minibatch source cannot find a stream with name '%ls'", v.second.m_name.c_str());
            minibatch.insert({ v.first, value->second.data });
        }
    }

    void TrainingSession::RestoreFromCheckpoint(const std::wstring& checkpointFileName)
    {
        Dictionary externalState = Trainer()->RestoreFromCheckpoint(checkpointFileName);
        m_source->RestoreFromCheckpoint(externalState[s_trainingMinibatchSource].Value<Dictionary>());
    }

    void TrainingSession::SaveCheckpoint(size_t currentIndex)
    {
        OnCheckpointStart(currentIndex);
        Dictionary externalState;
        externalState[s_trainingMinibatchSource] = m_source->GetCheckpointState();

        wstring checkpointFile = m_checkpoint.m_fileName;
        if (m_checkpoint.m_preserveAll)
            checkpointFile += std::to_wstring(currentIndex);
        Trainer()->SaveCheckpoint(checkpointFile, externalState);
        OnCheckpointEnd(currentIndex);
    }

    void TrainingSession::SaveFinalCheckpoint()
    {
        Dictionary externalState;
        externalState[s_trainingMinibatchSource] = m_source->GetCheckpointState();
        Trainer()->SaveCheckpoint(m_checkpoint.m_fileName, externalState);
    }

    // Restores from the m_checkpoint.m_fileName file.
    // If the file path exists - simply restores from the corresponding file.
    // If the file path does not exist - looks into the directory where the file is
    // located and among the files named <fileName>N picks the one with the largest N,
    // where N is some positive integer.
    void TrainingSession::RestoreFromCheckpoint()
    {
        assert(!m_checkpoint.m_fileName.empty());
        auto checkpoint = m_checkpoint.m_fileName;

        // Make sure the intermediate directories exist, so there is no need for further checks.
        msra::files::make_intermediate_dirs(checkpoint);

        size_t pos = checkpoint.find_last_of(L"\\/");
        wstring parent;
        wstring fileName;
        if (pos == wstring::npos)
        {
            parent = L".";
            fileName = checkpoint;
        }
        else
        {
            parent = checkpoint.substr(0, pos);
            fileName = checkpoint.substr(pos + 1);
        }

        std::wstring restoreFile;
        if (fexists(checkpoint))
        {
            restoreFile = checkpoint;
        }
        else
        {
            // Let's check whether there are other possible candidates to restore from.
            int maxValue = -1;
            std::vector<std::wstring> files = msra::files::get_all_files_from_directory(parent);

            for (auto f : files)
            {
                if (!boost::starts_with(f, fileName))
                {
                    continue;
                }

                auto suffix = f.substr(fileName.size());
                if (!isNumber(suffix) || !fexists(parent + L"/" + f + L".ckp"))
                {
                    continue;
                }

                auto expectedNumber = Microsoft::MSR::CNTK::ToLegacyString(Microsoft::MSR::CNTK::ToUTF8(suffix));
                char* tmp = nullptr;
                int value = strtol(expectedNumber.c_str(), &tmp, 10);
                if (tmp != expectedNumber.c_str() + expectedNumber.size())
                    continue;

                if (value > maxValue)
                {
                    // Found a better candidate.
                    maxValue = value;
                    restoreFile = parent + L"/" + f;
                }
            }
        }

        if (restoreFile.empty()) // Nothing to restore.
            return;

        // TODO: Should have proper logging instead.
        fprintf(stderr, "Restoring training session from the checkpoint '%ls'\n", restoreFile.c_str());

        this->RestoreFromCheckpoint(restoreFile);

        // Recalculate action indices.
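        // After a restore the unit counters jump straight to the checkpointed
        // totals, so currentIndex and unitCountWhenLastCalled are re-derived
        // below; otherwise each periodic action would immediately re-fire for
        // units that were already processed before the checkpoint was taken.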
        for (auto& action : m_actions)
        {
            size_t totalNumberOfUnitCounts = Trainer()->TotalNumberOfUnitsSeen(action.frequencyUnit);
            action.currentIndex = totalNumberOfUnitCounts / action.frequency;
            action.unitCountWhenLastCalled = totalNumberOfUnitCounts - totalNumberOfUnitCounts % action.frequency;
        }
    }
}