https://github.com/Microsoft/CNTK
Raw File
Tip revision: c3e4f667455d8f39506e16fb14615840e1359446 authored by Mark Hillebrand on 13 September 2017, 14:20:33 UTC
Tests/Install/linux/prep-run-test.sh: do not run CNTK_104_Finance_Timeseries_Basic_with_Pandas_Numpy.ipynb
Tip revision: c3e4f66
DistributedLearnerBase.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "CNTKLibrary.h"

namespace CNTK
{
    ///
    /// Base class for distributed learners.
    ///
    ///
    /// Base class for distributed learners. Provides the checkpointing and
    /// gradient-bookkeeping plumbing shared by concrete distributed learners.
    ///
    class DistributedLearnerBase : public DistributedLearner
    {
    public:
        /// Creates a checkpoint dictionary capturing this learner's state.
        Dictionary CreateCheckpoint() override;

        /// Restores this learner's state from a checkpoint previously produced
        /// by CreateCheckpoint().
        void RestoreFromCheckpoint(const Dictionary& checkpoint) override;

    protected:
        /// @param communicator           communicator used for inter-worker exchange
        /// @param learner                the local learner being wrapped
        /// @param distributeAfterSamples number of samples after which distributed
        ///                               operation begins
        /// @param convertSparseToDense   if true, sparse gradient values are converted
        ///                               to dense form (see m_convertSparseToDense;
        ///                               exact use lives in the .cpp)
        DistributedLearnerBase(DistributedCommunicatorPtr communicator, LearnerPtr learner, size_t distributeAfterSamples, bool convertSparseToDense=true);

        // Name suggests it prepares zero-valued gradients for 'gradientValues'/'info'
        // (implementation not visible in this header -- confirm in the .cpp).
        // NOTE(review): "Prepaire" is a historical misspelling. The name is kept so
        // existing subclasses keep compiling; new code should call the correctly
        // spelled alias below.
        static void PrepaireZeroGradients(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, MinibatchInfo& info);

        // Correctly spelled, backward-compatible alias for PrepaireZeroGradients.
        static void PrepareZeroGradients(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, MinibatchInfo& info)
        {
            PrepaireZeroGradients(gradientValues, info);
        }

        // Copies the unordered 'gradientValues' map into the ordered 'result' vector;
        // optionally also records converted gradients in '*convertedGradientValues'.
        // NOTE(review): the ordering and conversion semantics are defined in the
        // .cpp -- presumably ordered by m_parameters; confirm there.
        void ConvertToOrdered(const std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, std::vector<std::pair<Parameter, NDArrayViewPtr>>& result, std::unordered_map<Parameter, NDArrayViewPtr>* convertedGradientValues = nullptr);

        // Scratch buffer of (parameter, gradient) pairs used by subclasses.
        std::vector<std::pair<Parameter, NDArrayViewPtr>> m_gradientBuffer;

        // Parameters of the wrapped learner, in a fixed order.
        std::vector<Parameter> m_parameters;

        // Whether sparse gradients should be converted to dense (set by constructor).
        bool m_convertSparseToDense;

        // Non-copyable and non-movable: one deleted declaration per line for clarity.
        DistributedLearnerBase(const DistributedLearnerBase&) = delete;
        DistributedLearnerBase& operator=(const DistributedLearnerBase&) = delete;
        DistributedLearnerBase(DistributedLearnerBase&&) = delete;
        DistributedLearnerBase& operator=(DistributedLearnerBase&&) = delete;
    };
}
back to top