Revision b51e8c243a8b0080ed71e960224c7ca20f9a1363 authored by Vadim Mazalov on 06 November 2018, 00:32:59 UTC, committed by Vadim Mazalov on 06 November 2018, 00:32:59 UTC
1 parent a19ce9e
IDistGradAggregator.h
#pragma once
#include "DistGradHeader.h"
#include "MPIWrapper.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Abstract interface for gradient aggregators that operate through an MPI wrapper.
// Concrete aggregators derive from this template and implement AggregateGradients();
// the non-virtual helpers below simply forward to the wrapped MPI object.
template <class ElemType>
class IDistGradAggregator
{
public:
    // Keeps a copy of the given MPI wrapper handle for later use by the helpers below.
    // 'explicit' so that an MPIWrapperPtr is never silently treated as an aggregator.
    explicit IDistGradAggregator(const MPIWrapperPtr& mpi)
        : m_mpi(mpi)
    {
    }

    // Virtual so that derived aggregators are destroyed correctly through a base pointer.
    virtual ~IDistGradAggregator()
    {
    }

    // Aggregates the supplied gradient matrices; implemented by derived classes.
    // Returns a boolean indicating if any samples were processed.
    // NOTE(review): the exact roles of headerCPU and resetState are defined by the
    // implementations, not by this interface -- consult the concrete aggregator.
    virtual bool AggregateGradients(const std::vector<Matrix<ElemType>*>& gradients, DistGradHeader* headerCPU, bool resetState) = 0;

    // Returns m_mpi->NumNodesInUse(). Const-qualified: it does not modify the aggregator.
    size_t NumProc() const
    {
        return m_mpi->NumNodesInUse();
    }

    // Returns m_mpi->CurrentNodeRank(). Const-qualified: it does not modify the aggregator.
    size_t MyRank() const
    {
        return m_mpi->CurrentNodeRank();
    }

    // Forwards to m_mpi->WaitAll().
    void WaitAll()
    {
        m_mpi->WaitAll();
    }

protected:
    // MPI wrapper handle shared with derived classes; no null check is performed here.
    MPIWrapperPtr m_mpi;
};
// Expands to a 'protected:' label followed by using-declarations for m_mpi, NumProc and
// MyRank of IDistGradAggregator<ElemType>, so the enclosing class can name them without
// 'this->' qualification. The enclosing scope must have a template parameter called ElemType.
// The expansion has no trailing semicolon: write 'UsingIDistGradAggregatorMembers;' at the
// use site. Members declared after the macro are protected until another access label appears.
// WaitAll is not re-exported by this macro.
#define UsingIDistGradAggregatorMembers               \
protected:                                            \
    using IDistGradAggregator<ElemType>::m_mpi;       \
    using IDistGradAggregator<ElemType>::NumProc;     \
    using IDistGradAggregator<ElemType>::MyRank
} } }
Computing file changes ...