// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #pragma once #include "CNTKLibrary.h" namespace CNTK { /// /// Base class for distributed learners. /// class DistributedLearnerBase : public DistributedLearner { public: Dictionary CreateCheckpoint() override; void RestoreFromCheckpoint(const Dictionary& checkpoint) override; protected: DistributedLearnerBase(DistributedCommunicatorPtr communicator, LearnerPtr learner, size_t distributeAfterSamples, bool convertSparseToDense=true); static void PrepaireZeroGradients(std::unordered_map& gradientValues); void ConvertToOrdered(const std::unordered_map& gradientValues, std::vector>& result, std::unordered_map* convertedGradientValues = nullptr); std::vector> m_gradientBuffer; std::vector m_parameters; bool m_convertSparseToDense; DistributedLearnerBase(const DistributedLearnerBase&) = delete; DistributedLearnerBase& operator=(const DistributedLearnerBase&) = delete; DistributedLearnerBase& operator=(DistributedLearnerBase&&) = delete; DistributedLearnerBase(DistributedLearnerBase&& other) = delete; }; }