https://github.com/Microsoft/CNTK
Raw File
Tip revision: 8c2abcb4e37bf056b41950af275f2adba2a713d0 authored by Mark Hillebrand on 05 April 2016, 17:23:16 UTC
WIP
Tip revision: 8c2abcb
BatchNormalizationEngine.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "BatchNormalizationEngine.h"
#include "CuDnnFactories.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// Validates the arguments (debug builds only, via assert) and runs the forward pass of batch
// normalization by delegating to the engine-specific ForwardCore().
//
// Contract enforced by the assertions below:
//  - in and out have m_inOutT.GetNumElements() rows and the same number of columns.
//  - expAvgFactor and blendFactor are finite and lie in [0, 1]; epsilon is finite and positive.
//  - scale, bias, runMean and runInvStdDev are single-column matrices.
//    In non-spatial mode each has exactly m_inOutT.GetNumElements() rows; in spatial mode its
//    row count is non-zero and evenly divides m_inOutT.GetNumElements().
//  - saveMean and saveInvStdDev may be empty; when they are not, they obey the same rules.
//
// EnsureCompatible() is called before ForwardCore() so that the concrete engine gets the chance
// to reject the configuration before any work is done.
template <class ElemType>
void BatchNormEngine<ElemType>::Forward(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, double blendFactor, Mat& runMean, Mat& runInvStdDev,
                                        Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev)
{
    // Input/output geometry.
    assert(in.GetNumRows() == m_inOutT.GetNumElements());
    assert(out.GetNumRows() == m_inOutT.GetNumElements());
    assert(in.GetNumCols() == out.GetNumCols());

    // Scalar hyper-parameters.
    assert(std::isfinite(expAvgFactor) && (0 <= expAvgFactor && expAvgFactor <= 1));
    assert(std::isfinite(blendFactor) && (0 <= blendFactor && blendFactor <= 1));
    assert(std::isfinite(epsilon) && epsilon > 0);

    if (!m_spatial)
    {
        // Non-spatial: one parameter per tensor element.
        assert(m_inOutT.GetNumElements() == scale.GetNumRows());
        assert(m_inOutT.GetNumElements() == bias.GetNumRows());
        assert(m_inOutT.GetNumElements() == runMean.GetNumRows());
        assert(m_inOutT.GetNumElements() == runInvStdDev.GetNumRows());
        assert(saveMean.GetNumElements() == 0 || m_inOutT.GetNumElements() == saveMean.GetNumRows());
        assert(saveInvStdDev.GetNumElements() == 0 || m_inOutT.GetNumElements() == saveInvStdDev.GetNumRows());
    }
    else
    {
        // Spatial: the tensor size must be a whole multiple of the parameter count.
        // The row counts are used as divisors, so zero is ruled out first: x % 0 is undefined
        // behavior and would crash inside the assertion instead of reporting a failed check.
        assert(scale.GetNumRows() != 0 && (m_inOutT.GetNumElements() % scale.GetNumRows()) == 0);
        assert(bias.GetNumRows() != 0 && (m_inOutT.GetNumElements() % bias.GetNumRows()) == 0);
        assert(runMean.GetNumRows() != 0 && (m_inOutT.GetNumElements() % runMean.GetNumRows()) == 0);
        assert(runInvStdDev.GetNumRows() != 0 && (m_inOutT.GetNumElements() % runInvStdDev.GetNumRows()) == 0);
        // For the save buffers the emptiness test short-circuits before the modulo is evaluated
        // (presumably a non-empty matrix always has at least one row).
        assert(saveMean.GetNumElements() == 0 || (m_inOutT.GetNumElements() % saveMean.GetNumRows()) == 0);
        assert(saveInvStdDev.GetNumElements() == 0 || (m_inOutT.GetNumElements() % saveInvStdDev.GetNumRows()) == 0);
    }

    // All parameter and statistics matrices are column vectors.
    assert(scale.GetNumCols() == 1);
    assert(bias.GetNumCols() == 1);
    assert(runMean.GetNumCols() == 1);
    assert(runInvStdDev.GetNumCols() == 1);
    assert(saveMean.GetNumElements() == 0 || saveMean.GetNumCols() == 1);
    assert(saveInvStdDev.GetNumElements() == 0 || saveInvStdDev.GetNumCols() == 1);

    EnsureCompatible();
    ForwardCore(in, scale, bias, expAvgFactor, blendFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev);
}

// Runs the backward pass of batch normalization.
// The concrete engine is first asked to confirm, through EnsureCompatible(), that it can handle
// the current configuration; afterwards every argument is handed unchanged to BackwardCore().
// No shape validation is performed at this level.
template <class ElemType>
void BatchNormEngine<ElemType>::Backward(const Mat& in,
                                         const Mat& srcGrad,
                                         Mat& grad,
                                         const Mat& scale,
                                         const Mat& saveMean,
                                         const Mat& saveInvStdDev,
                                         Mat& scaleGrad,
                                         Mat& biasGrad)
{
    EnsureCompatible();

    BackwardCore(in, srcGrad, grad,
                 scale, saveMean, saveInvStdDev,
                 scaleGrad, biasGrad);
}

// Batch normalization engine that forwards the actual computation to the batch-normalization
// member functions of the matrix type (BatchNormalizationForward / BatchNormalizationBackward).
template <class ElemType>
class CntkBatchNormEngine : public BatchNormEngine<ElemType>
{
public:
    using Base = BatchNormEngine<ElemType>;
    using typename Base::Mat;

    // All construction arguments are passed straight through to the base engine.
    CntkBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT, bool spatial, ImageLayoutKind imageLayout)
        : Base(deviceId, inOutT, spatial, imageLayout)
    {
    }

protected:
    // Make the dependent base-class members visible without this-> qualification.
    using Base::m_deviceId;
    using Base::m_imageLayout;
    using Base::m_inOutT;
    using Base::m_spatial;

    // Rejects the one configuration this engine does not handle: spatial normalization
    // combined with the HWC image layout.
    void EnsureCompatible() override
    {
        const bool isHwcLayout = (m_imageLayout == ImageLayoutKind::HWC);
        if (!m_spatial || !isHwcLayout)
            return;

        InvalidArgument("CNTK batch normalization supports only cudnn(CHW) layout.");
    }

    // Forward pass: invoked on the input matrix, which receives all remaining arguments.
    void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, double blendFactor,
                     Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon,
                     Mat& saveMean, Mat& saveInvStdDev) override
    {
        in.BatchNormalizationForward(scale, bias, expAvgFactor, blendFactor,
                                     runMean, runInvStdDev,
                                     out, epsilon,
                                     saveMean, saveInvStdDev);
    }

    // Backward pass: invoked on the incoming gradient (srcGrad); the input comes first in the argument list.
    void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale,
                      const Mat& saveMean, const Mat& saveInvStdDev,
                      Mat& scaleGrad, Mat& biasGrad) override
    {
        srcGrad.BatchNormalizationBackward(in, grad, scale,
                                           saveMean, saveInvStdDev,
                                           scaleGrad, biasGrad);
    }
};

// Explicitly instantiate the CNTK engine for the two element types used in this file (float and double).
template class CntkBatchNormEngine<float>;
template class CntkBatchNormEngine<double>;

// Returns true when src and testFlag have at least one set bit in common.
// Both values are converted to int before the bitwise AND, so T must be convertible to int
// with static_cast (e.g. an enum or an integral type).
template <typename T>
bool HasFlag(T src, T testFlag)
{
    const int srcBits = static_cast<int>(src);
    const int flagBits = static_cast<int>(testFlag);
    const int common = srcBits & flagBits;
    return common != 0;
}

// Factory: builds the batch normalization engine selected by enabledEngines.
// The CNTK engine takes precedence whenever its flag is set; the cuDNN engine is used only
// when the CNTK flag is absent and the cuDNN flag is present. If neither flag is set,
// RuntimeError() is called. The chosen engine is announced on stderr.
template <class ElemType>
std::unique_ptr<BatchNormEngine<ElemType>> BatchNormEngine<ElemType>::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
                                                                             bool spatial, ImageLayoutKind imageLayout,
                                                                             BatchNormEngineKind enabledEngines)
{
    // CNTK is the default engine, so it is checked first.
    const bool cntkEnabled = HasFlag(enabledEngines, BatchNormEngineKind::Cntk);
    if (cntkEnabled)
    {
        fprintf(stderr, "\nUsing CNTK batch normalization engine.\n");
        return std::make_unique<CntkBatchNormEngine<ElemType>>(deviceId, inOutT, spatial, imageLayout);
    }

    // NOTE(review): RuntimeError() is assumed not to return normally (presumably it throws) -
    // confirm against its declaration.
    const bool cuDnnEnabled = HasFlag(enabledEngines, BatchNormEngineKind::CuDnn);
    if (!cuDnnEnabled)
        RuntimeError("Could not find appropriate batch normalization engine.");

    fprintf(stderr, "\nUsing cuDNN batch normalization engine.\n");
    return CuDnnBatchNormEngineFactory<ElemType>::Create(deviceId, inOutT, spatial, imageLayout);
}

// Explicitly instantiate the base engine (including the member definitions above) for float and double.
template class BatchNormEngine<float>;
template class BatchNormEngine<double>;

} } }
back to top