// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #include "stdafx.h" #include "BatchNormalizationEngine.h" #include "CuDnnFactories.h" namespace Microsoft { namespace MSR { namespace CNTK { template void BatchNormEngine::Forward(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, double blendFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) { assert(in.GetNumRows() == m_inOutT.GetNumElements()); assert(out.GetNumRows() == m_inOutT.GetNumElements()); assert(in.GetNumCols() == out.GetNumCols()); assert(std::isfinite(expAvgFactor) && (0 <= expAvgFactor && expAvgFactor <= 1)); assert(std::isfinite(blendFactor) && (0 <= blendFactor && blendFactor <= 1)); assert(std::isfinite(epsilon) && epsilon > 0); if (!m_spatial) { assert(m_inOutT.GetNumElements() == scale.GetNumRows()); assert(m_inOutT.GetNumElements() == bias.GetNumRows()); assert(m_inOutT.GetNumElements() == runMean.GetNumRows()); assert(m_inOutT.GetNumElements() == runInvStdDev.GetNumRows()); assert(saveMean.GetNumElements() == 0 || m_inOutT.GetNumElements() == saveMean.GetNumRows()); assert(saveInvStdDev.GetNumElements() == 0 || m_inOutT.GetNumElements() == saveInvStdDev.GetNumRows()); } else { assert((m_inOutT.GetNumElements() % scale.GetNumRows()) == 0); assert((m_inOutT.GetNumElements() % bias.GetNumRows()) == 0); assert((m_inOutT.GetNumElements() % runMean.GetNumRows()) == 0); assert((m_inOutT.GetNumElements() % runInvStdDev.GetNumRows()) == 0); assert(saveMean.GetNumElements() == 0 || (m_inOutT.GetNumElements() % saveMean.GetNumRows()) == 0); assert(saveInvStdDev.GetNumElements() == 0 || (m_inOutT.GetNumElements() % saveInvStdDev.GetNumRows()) == 0); } assert(scale.GetNumCols() == 1); assert(bias.GetNumCols() == 1); assert(runMean.GetNumCols() == 1); assert(runInvStdDev.GetNumCols() == 1); assert(saveMean.GetNumElements() == 0 || saveMean.GetNumCols() == 1); assert(saveInvStdDev.GetNumElements() == 0 || saveInvStdDev.GetNumCols() == 1); EnsureCompatible(); ForwardCore(in, scale, bias, expAvgFactor, blendFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev); } template void BatchNormEngine::Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, Mat& scaleGrad, Mat& biasGrad) { EnsureCompatible(); BackwardCore(in, srcGrad, grad, scale, saveMean, saveInvStdDev, scaleGrad, biasGrad); } template class CntkBatchNormEngine : public BatchNormEngine { public: using Base = BatchNormEngine; using typename Base::Mat; public: CntkBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT, bool spatial, ImageLayoutKind imageLayout) : Base(deviceId, inOutT, spatial, imageLayout) { } protected: using Base::m_deviceId; using Base::m_imageLayout; using Base::m_inOutT; using Base::m_spatial; void EnsureCompatible() override { if (m_spatial && m_imageLayout == ImageLayoutKind::HWC) InvalidArgument("CNTK batch normalization supports only cudnn(CHW) layout."); } void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, double expAvgFactor, double blendFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) override { in.BatchNormalizationForward(scale, bias, expAvgFactor, blendFactor, runMean, runInvStdDev, out, epsilon, saveMean, saveInvStdDev); } void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, const Mat& saveMean, const Mat& saveInvStdDev, Mat& scaleGrad, Mat& biasGrad) override { srcGrad.BatchNormalizationBackward(in, grad, scale, saveMean, saveInvStdDev, scaleGrad, biasGrad); } }; template class CntkBatchNormEngine; template class CntkBatchNormEngine; template bool HasFlag(T src, T testFlag) { return ((int)src & (int)testFlag) != 0; } template std::unique_ptr> BatchNormEngine::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT, bool spatial, ImageLayoutKind imageLayout, BatchNormEngineKind enabledEngines) { // Use CNTK as default batch norm engine. if (HasFlag(enabledEngines, BatchNormEngineKind::Cntk)) { fprintf(stderr, "\nUsing CNTK batch normalization engine.\n"); return std::make_unique>(deviceId, inOutT, spatial, imageLayout); } if (HasFlag(enabledEngines, BatchNormEngineKind::CuDnn)) { fprintf(stderr, "\nUsing cuDNN batch normalization engine.\n"); return CuDnnBatchNormEngineFactory::Create(deviceId, inOutT, spatial, imageLayout); } RuntimeError("Could not find appropriate batch normalization engine."); } template class BatchNormEngine; template class BatchNormEngine; } } }