// ASGDHelper.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once

#include <list>
#include "ComputationNetwork.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// -----------------------------------------------------------------------
// class AdjustLearningRateAtBeginning
//       Provides options for DataParallelASGD training, so that every node
//       can adjust its learning rate every minibatch during the first N epochs.
// -----------------------------------------------------------------------
// TODO: We can remove these options once we can adjust the learning rate at the minibatch level
enum class AdjustLearningRateAtBeginning : int
{
    None = 0,             // default, don't adjust the learning rate
    Linearly = 1,         // linear adjustment: the learning rate ramps from 0 to learningRatesPerMB
    Staircase = (1 << 1), // staircase adjustment: the learning rate steps from 0 to learningRatesPerMB every adjustPerMinibatches minibatches
};
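
// -----------------------------------------------------------------------
// Illustrative sketch (not part of the API): one plausible way the two
// warm-up modes above could scale the learning rate. The actual logic
// lives in DecayCoefficient() in the ASGD implementation and may differ;
// WarmupFactor and its parameters are hypothetical names.
//
//   #include <algorithm>
//
//   inline double WarmupFactor(AdjustLearningRateAtBeginning mode,
//                              size_t minibatchesSeen,      // minibatches processed so far
//                              size_t adjustPerMinibatches, // staircase step width
//                              size_t warmupMinibatches)    // total length of the warm-up phase
//   {
//       switch (mode)
//       {
//       case AdjustLearningRateAtBeginning::Linearly:
//           // Ramp smoothly from 0 to 1 over the warm-up phase.
//           return std::min(1.0, (double)minibatchesSeen / warmupMinibatches);
//       case AdjustLearningRateAtBeginning::Staircase:
//           // Step up once every adjustPerMinibatches minibatches.
//           return std::min(1.0, (double)(minibatchesSeen / adjustPerMinibatches + 1) *
//                                adjustPerMinibatches / warmupMinibatches);
//       default: // None: no adjustment
//           return 1.0;
//       }
//   }
// -----------------------------------------------------------------------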

template<class ElemType = float>
class ASGDHelper
{
public:
    virtual ~ASGDHelper() { }
    // -----------------------------------------------------------------------
    // InitModel() -- Upload the initial model (which was pre-computed by CNTK logic)
    // to the parameter servers, so that every node starts training from the same model.
    // -----------------------------------------------------------------------
    virtual void InitModel(const std::list<ComputationNodeBasePtr> & learnableNodes) = 0;

    // -----------------------------------------------------------------------
    // PushAndPullModel() -- Push the parameters of learnableNodes to the parameter servers, then pull the latest model back.
    // -----------------------------------------------------------------------
    virtual bool PushAndPullModel(const std::list<ComputationNodeBasePtr> & learnableNodes, size_t sampleSinceLastSynced = 0) = 0;

    // -----------------------------------------------------------------------
    // WaitAll() -- Wait (barrier) for all other nodes to finish processing.
    // -----------------------------------------------------------------------
    virtual void WaitAll() = 0;

    // -----------------------------------------------------------------------
    // WaitAsyncBuffer() -- Wait for the pipeline thread to finish its job when useAsyncBuffer is true.
    // -----------------------------------------------------------------------
    virtual void WaitAsyncBuffer() = 0;

};  // Class ASGDHelper

// Factory method to create an ASGDHelper instance
template<class ElemType = float>
ASGDHelper<ElemType>* NewASGDHelper(
    const std::list<ComputationNodeBasePtr> & learnableNodes,                // Parameters that need to be trained
    size_t nodeNumRanks,                                                     // Number of worker nodes
    bool useAsyncBuffered = true,                                            // Use an asynchronous buffer to hide communication cost
    bool isSimulatedModelAveragingSGD = false,                               // Use parameter-server-based model averaging (MA) rather than ASGD
    AdjustLearningRateAtBeginning adjusttype =
        AdjustLearningRateAtBeginning::None,                                 // Adjust the learning rate per minibatch at the very beginning of training
    double adjustCoef = 0.2,                                                 // Adjustment coefficient; see DecayCoefficient()
    size_t adjustPerMinibatches = 600,                                       // Adjust the learning rate every adjustPerMinibatches minibatches
    int traceLevel = 0,                                                      // Log verbosity level
    int syncPerfStats = 0);                                                  // Show perf stats every syncPerfStats syncs
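
// -----------------------------------------------------------------------
// Example usage (illustrative sketch only): a typical data-parallel ASGD
// training loop around this interface. GetNextMinibatch() and
// UpdateLocalModel() are hypothetical stand-ins for the trainer's own
// minibatch fetching and local SGD update, and the return value of
// PushAndPullModel() is assumed here to indicate that a sync occurred.
//
//   std::unique_ptr<ASGDHelper<float>> helper(
//       NewASGDHelper<float>(learnableNodes, numWorkers));
//   helper->InitModel(learnableNodes);   // all workers start from the same model
//   size_t samplesSinceLastSync = 0;
//   while (GetNextMinibatch())
//   {
//       samplesSinceLastSync += UpdateLocalModel();  // local SGD step
//       if (helper->PushAndPullModel(learnableNodes, samplesSinceLastSync))
//           samplesSinceLastSync = 0;                // synced with the parameter servers
//   }
//   helper->WaitAsyncBuffer();           // drain the pipeline thread if useAsyncBuffered
//   helper->WaitAll();                   // barrier before shutdown
// -----------------------------------------------------------------------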

}}}