https://github.com/Microsoft/CNTK
Raw File
Tip revision: 5051c9e478170a083b0178cb262436aa38ce5926 authored by Mark Hillebrand on 18 January 2016, 08:33:02 UTC
License change
Tip revision: 5051c9e
IDistGradAggregator.h
#pragma once

#include "DistGradHeader.h"
#include "MPIWrapper.h"

namespace Microsoft { namespace MSR { namespace CNTK {

    // Abstract interface for aggregating gradients across the processes of a
    // distributed run. Concrete aggregators implement AggregateGradients(); the
    // remaining members are thin forwarding helpers around the MPIWrapper that
    // was supplied at construction.
    template<class ElemType>
    class IDistGradAggregator
    {
    public:
        // Stores the given MPI wrapper pointer for use by the helpers below.
        // The pointer is not checked for null and is never deleted by this class,
        // so the caller keeps ownership and must keep the wrapper alive for as
        // long as this aggregator is used.
        // 'explicit' keeps a bare MPIWrapper* from being silently treated as an
        // aggregator in derived-class or overload contexts.
        explicit IDistGradAggregator(MPIWrapper* mpi)
            : m_mpi(mpi)
        {
        }

        // Virtual so that concrete aggregators can be destroyed through a
        // pointer to this interface.
        virtual ~IDistGradAggregator()
        {
        }

        // Aggregates the given gradient matrices; implemented by derived classes.
        // Returns a boolean indicating if any samples were processed.
        // NOTE(review): the exact meaning of headerCPU, nBits and epochNumber is
        // defined by the implementations and is not visible from this header.
        virtual bool AggregateGradients(const std::vector<Matrix<ElemType>*>& gradients, DistGradHeader *headerCPU, int nBits, int epochNumber) = 0;

        // Forwards to MPIWrapper::NumNodesInUse(). Const because it only reads
        // through the stored pointer and does not modify this object.
        size_t NumProc() const
        {
            return m_mpi->NumNodesInUse();
        }

        // Forwards to MPIWrapper::CurrentNodeRank(). Const for the same reason
        // as NumProc().
        size_t MyRank() const
        {
            return m_mpi->CurrentNodeRank();
        }

        // Forwards to MPIWrapper::WaitAll(). Left non-const on purpose: it
        // triggers an operation on the wrapper rather than just querying it.
        void WaitAll()
        {
            m_mpi->WaitAll();
        }

    protected:
        // MPI wrapper supplied at construction; not deleted by this class.
        MPIWrapper* m_mpi;
    };

// Convenience macro for class templates deriving from IDistGradAggregator<ElemType>:
// re-declares the listed base members (under 'protected:') so they can be named
// without qualification inside the derived template. It must be expanded where a
// template parameter called ElemType is in scope. The last using-declaration is
// intentionally left unterminated, so the use site supplies the trailing ';'.
#define UsingIDistGradAggregatorMembers                 \
protected:                                              \
    using IDistGradAggregator<ElemType>::m_mpi;         \
    using IDistGradAggregator<ElemType>::NumProc;       \
    using IDistGradAggregator<ElemType>::MyRank

}}}
back to top