//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "BatchNormalizationEngine.h"
#include "CuDnnFactories.h"
#include "Mkl2017DnnCommon.h"

namespace Microsoft { namespace MSR { namespace CNTK {

template <class ElemType>
void BatchNormEngine<ElemType>::Forward(const Mat& in, const Mat& scale, const Mat& bias, bool inferenceOnly, double expAvgFactor, double blendFactor,
                                        Mat& runMean, Mat& runVariance, Mat& out, double epsilon, Mat& savedMean, Mat& savedInvStdDev)
{
    assert(in.GetNumRows() == m_inOutT.GetNumElements());
    assert(out.GetNumRows() == m_inOutT.GetNumElements());
    assert(in.GetNumCols() == out.GetNumCols());
    assert(std::isfinite(expAvgFactor) && (0 <= expAvgFactor && expAvgFactor <= 1));
    assert(std::isfinite(blendFactor) && (0 <= blendFactor && blendFactor <= 1));
    // In inference mode, must only use running statistics.
    assert(!inferenceOnly || ((expAvgFactor == 0.0) && (blendFactor == 1.0)));
    assert(std::isfinite(epsilon) && epsilon > 0);
    if (!m_spatial)
    {
        assert(m_inOutT.GetNumElements() == scale.GetNumRows());
        assert(m_inOutT.GetNumElements() == bias.GetNumRows());
        assert(m_inOutT.GetNumElements() == runMean.GetNumRows());
        assert(m_inOutT.GetNumElements() == runVariance.GetNumRows());
    }
    else
    {
        assert((m_inOutT.GetNumElements() % scale.GetNumRows()) == 0);
        assert((m_inOutT.GetNumElements() % bias.GetNumRows()) == 0);
        assert((m_inOutT.GetNumElements() % runMean.GetNumRows()) == 0);
        assert((m_inOutT.GetNumElements() % runVariance.GetNumRows()) == 0);
    }
    assert(scale.GetNumCols() == 1);
    assert(bias.GetNumCols() == 1);
    assert(runMean.GetNumCols() == 1);
    assert(runVariance.GetNumCols() == 1);

    EnsureCompatible();
    ForwardCore(in, scale, bias, inferenceOnly, expAvgFactor, blendFactor, runMean, runVariance, out, epsilon, savedMean, savedInvStdDev);

    if (!inferenceOnly)
    {
        assert(!savedMean.IsEmpty());
        assert(!savedInvStdDev.IsEmpty());
        if (!m_spatial)
        {
            assert(m_inOutT.GetNumElements() == savedMean.GetNumRows());
            assert(m_inOutT.GetNumElements() == savedInvStdDev.GetNumRows());
        }
        else
        {
            assert((m_inOutT.GetNumElements() % savedMean.GetNumRows()) == 0);
            assert((m_inOutT.GetNumElements() % savedInvStdDev.GetNumRows()) == 0);
        }
        assert(savedMean.GetNumCols() == 1);
        assert(savedInvStdDev.GetNumCols() == 1);
    }
}

template <class ElemType>
void BatchNormEngine<ElemType>::Backward(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, double blendFactor,
                                         const Mat& savedMean, const Mat& savedInvStdDev, Mat& scaleGrad, Mat& biasGrad, bool accumulateDataGrad)
{
    assert(!savedMean.IsEmpty());
    assert(!savedInvStdDev.IsEmpty());
    EnsureCompatible();
    BackwardCore(in, srcGrad, grad, scale, blendFactor, savedMean, savedInvStdDev, scaleGrad, biasGrad, accumulateDataGrad);
}
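// For reference, batch normalization computes, per channel c,
//
//     out = scale[c] * (in - mean[c]) / sqrt(var[c] + epsilon) + bias[c],
//
// where (mean, var) are the current batch statistics during training and the running
// statistics during inference. Training folds the batch statistics into the running
// ones via the exponential moving average
//
//     runMean = expAvgFactor * batchMean + (1 - expAvgFactor) * runMean,
//
// and blendFactor interpolates between running and batch statistics when normalizing;
// blendFactor == 1 uses the running statistics only, which is why inference asserts
// expAvgFactor == 0 and blendFactor == 1 above.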
template <class ElemType>
class CntkBatchNormEngine : public BatchNormEngine<ElemType>
{
public:
    using Base = BatchNormEngine<ElemType>;
    using typename Base::Mat;

public:
    CntkBatchNormEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
                        bool spatial, ImageLayoutKind imageLayout)
        : Base(deviceId, inOutT, spatial, imageLayout)
    {
    }

protected:
    using Base::m_deviceId;
    using Base::m_imageLayout;
    using Base::m_inOutT;
    using Base::m_spatial;

    void EnsureCompatible() override
    {
        if (m_spatial && m_imageLayout == ImageLayoutKind::HWC)
            InvalidArgument("CNTK batch normalization supports only cudnn(CHW) layout.");
    }

    void ForwardCore(const Mat& in, const Mat& scale, const Mat& bias, bool inferenceOnly, double expAvgFactor, double blendFactor,
                     Mat& runMean, Mat& runVariance, Mat& out, double epsilon, Mat& savedMean, Mat& savedInvStdDev) override
    {
#ifdef USE_MKL2017DNN
        if (in.GetCurrentMatrixLocation() == CPU &&
            ForwardCoreMKL(in, scale, bias, inferenceOnly, expAvgFactor, runMean, runVariance, out, epsilon, savedMean, savedInvStdDev))
            return;
#endif
        in.BatchNormalizationForward(scale, bias, inferenceOnly, expAvgFactor, blendFactor, runMean, runVariance, out, epsilon, savedMean, savedInvStdDev);
    }

    void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale, double blendFactor, const Mat& savedMean, const Mat& savedInvStdDev,
                      Mat& scaleGrad, Mat& biasGrad, bool accumulateDataGrad) override
    {
#ifdef USE_MKL2017DNN
        if (srcGrad.GetCurrentMatrixLocation() == CPU &&
            BackwardCoreMKL(in, srcGrad, grad, scale, savedMean, savedInvStdDev, scaleGrad, biasGrad, accumulateDataGrad))
            return;
#endif
        if (!accumulateDataGrad)
            grad.SetValue((ElemType)0);

        srcGrad.BatchNormalizationBackward(in, grad, scale, blendFactor, savedMean, savedInvStdDev, scaleGrad, biasGrad);
    }
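    // The private section below implements the MKL 2017 DNN fast path taken by the CPU
    // branches above. One MKL primitive is cached per context (inference forward,
    // training forward, backward); a primitive is rebuilt only when the tensor shape,
    // batch size, or epsilon changes.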
private:
#ifdef USE_MKL2017DNN
    // Default epsilon that matches cuDNN's when backward runs without a prior forward.
#define DEFAULT_EPSILON 1e-5

    enum ContextIndex
    {
        ContextIndex_ForwardInfer = 0,
        ContextIndex_ForwardTrain,
        ContextIndex_Backward,
        ContextIndex_Total
    };

    class MKLBatchNormalizationContext
    {
    private:
        int m_contextFlags = 0;

        // MKL uses a single buffer for both scale and shift, so allocate one buffer and convert.
        struct MKLScaleShiftAdapter
        {
            bool isInput;
            std::shared_ptr<Matrix<ElemType>> mat;
            dnnResourceType_t resourceType;
            size_t numChannels;

            void Create(dnnResourceType_t rt, bool userToPrim, size_t n)
            {
                Clear();
                numChannels = n;
                mat = std::make_shared<Matrix<ElemType>>(numChannels, 2, CPUDEVICE);
                isInput = userToPrim;
                resourceType = rt;
            }

            void PrepareForExecution(void* scale, void* bias, void* resources[dnnResourceNumber])
            {
                ElemType* buffer = mat->Data();
                resources[resourceType] = buffer;
                if (isInput)
                {
                    // Pack scale into the first column and bias (shift) into the second.
                    memcpy(buffer, scale, sizeof(ElemType) * numChannels);
                    memcpy(buffer + numChannels, bias, sizeof(ElemType) * numChannels);
                }
            }

            void ConvertOutput(void* scale, void* bias)
            {
                if (isInput)
                    RuntimeError("Cannot execute output ResourceAdapter for input");

                ElemType* buffer = mat->Data();
                memcpy(scale, buffer, sizeof(ElemType) * numChannels);
                memcpy(bias, buffer + numChannels, sizeof(ElemType) * numChannels);
            }

            void Clear()
            {
                if (mat)
                    mat.reset();
            }

            ~MKLScaleShiftAdapter()
            {
                Clear();
            }
        };

        struct PrimitiveContext
        {
            MKLDnnResourceAdapter<ElemType> input;
            MKLDnnResourceAdapter<ElemType> output;
            MKLScaleShiftAdapter scaleShift;
            std::shared_ptr<Mat> varianceMat; // variance matrix used for converting InvStdDev

            dnnPrimitive_t primitive = nullptr;
            dnnPrimitiveAttributes_t attributes = nullptr;

            void Clear()
            {
                if (primitive)
                {
                    dnnDelete<ElemType>(primitive);
                    primitive = nullptr;
                }
                input.Clear();
                scaleShift.Clear();
                output.Clear();
                if (attributes)
                {
                    dnnPrimitiveAttributesDestroy<ElemType>(attributes);
                    attributes = nullptr;
                }
            }

            ~PrimitiveContext()
            {
                Clear();
            }
        } m_context[ContextIndex_Total];

        TensorShape m_shape;
        size_t m_numSamples;
        ElemType m_epsilon;

    public:
        MKLBatchNormalizationContext()
            : m_numSamples(0), m_epsilon(0)
        {
        }

        bool HasPreparedFor(ContextIndex contextIndex) const
        {
            return !!(m_contextFlags & (1 << contextIndex));
        }

        void Prepare(const TensorShape& shape, bool spatial, size_t numSamples, ContextIndex contextIndex, ElemType epsilon = 0)
        {
            int flag = (1 << contextIndex);
            if (contextIndex == ContextIndex_Backward)
            {
                epsilon = HasPreparedFor(ContextIndex_ForwardTrain) ? m_epsilon : (ElemType)DEFAULT_EPSILON;
            }

            bool same = (shape == m_shape) && (numSamples == m_numSamples) && (epsilon == m_epsilon);

            if (same && !!(m_contextFlags & flag))
                return;

            if (!same)
                m_contextFlags = 0;

            if (m_contextFlags)
            {
                if ((m_numSamples != numSamples) || (m_epsilon != epsilon) || (m_shape != shape))
                    RuntimeError("MKLBatchNormalizationContext: Inconsistent shape, num samples, or epsilon between forward and backward");
            }
            else
            {
                m_shape = shape;
                m_numSamples = numSamples;
                m_epsilon = epsilon;
            }
            m_contextFlags |= flag;

            const size_t inoutDim = 4;
            size_t rank = m_shape.GetRank();
            size_t numElements = m_shape.GetNumElements();
            size_t numChannels = spatial ? ((rank > 0) ? m_shape.GetDim(rank - 1) : 1) : numElements; // flatten all dims of a sample when non-spatial
            size_t numPixels = numElements / numChannels;
            size_t dimFirst = (rank > 1 && spatial) ? m_shape.GetDim(0) : 1;
            size_t dimSecond = numPixels / dimFirst;
            size_t inoutSizes[4] = { dimFirst, dimSecond, numChannels, m_numSamples };
            size_t inoutStrides[4] = { 1, dimFirst, numPixels, numElements };

            auto& ctx = m_context[contextIndex];
            ctx.Clear();

            dnnLayout_t ltUserInput, ltPrimInput;
            dnnLayout_t ltUserOutput, ltPrimOutput;
            dnnResourceType_t inputType;
            dnnResourceType_t outputType;
            dnnResourceType_t scaleShiftType;
            switch (contextIndex)
            {
            case ContextIndex_ForwardInfer:
            case ContextIndex_ForwardTrain:
                CHECK_MKL(dnnLayoutCreate<ElemType>(&ltUserInput, inoutDim, inoutSizes, inoutStrides));
                CHECK_MKL(dnnLayoutCreate<ElemType>(&ltUserOutput, inoutDim, inoutSizes, inoutStrides));
                CHECK_MKL(dnnPrimitiveAttributesCreate<ElemType>(&ctx.attributes));
                CHECK_MKL(dnnBatchNormalizationCreateForward_v2<ElemType>(
                    &ctx.primitive,
                    ctx.attributes,
                    ltUserInput,
                    m_epsilon,
                    dnnUseScaleShift | ((contextIndex == ContextIndex_ForwardInfer) ? dnnUseInputMeanVariance : 0)));
                inputType = dnnResourceSrc;
                outputType = dnnResourceDst;
                scaleShiftType = dnnResourceScaleShift;
                break;
            case ContextIndex_Backward:
                CHECK_MKL(dnnLayoutCreate<ElemType>(&ltUserInput, inoutDim, inoutSizes, inoutStrides));
                CHECK_MKL(dnnLayoutCreate<ElemType>(&ltUserOutput, inoutDim, inoutSizes, inoutStrides));
                CHECK_MKL(dnnPrimitiveAttributesCreate<ElemType>(&ctx.attributes));
                CHECK_MKL(dnnBatchNormalizationCreateBackward_v2<ElemType>(
                    &ctx.primitive,
                    ctx.attributes,
                    ltUserInput,
                    m_epsilon,
                    dnnUseScaleShift));
                inputType = dnnResourceDiffDst;
                outputType = dnnResourceDiffSrc;
                scaleShiftType = dnnResourceDiffScaleShift;
                ctx.varianceMat = std::make_shared<Mat>(numChannels, 1, CPUDEVICE);
                break;
            default:
                RuntimeError("Unexpected context type %d", (int)contextIndex);
            }

            CHECK_MKL(dnnLayoutCreateFromPrimitive<ElemType>(&ltPrimInput, ctx.primitive, inputType));
            ctx.input.Create(ltUserInput, ltPrimInput, inputType, true);

            CHECK_MKL(dnnLayoutCreateFromPrimitive<ElemType>(&ltPrimOutput, ctx.primitive, outputType));
            ctx.output.Create(ltUserOutput, ltPrimOutput, outputType, false);

            ctx.scaleShift.Create(scaleShiftType, contextIndex != ContextIndex_Backward, numChannels);
        }
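        // Worked example of the layout mapping above (illustrative values): a spatial
        // sample of shape [W=5, H=4, C=3] with numSamples = 2 gives numElements = 60,
        // numChannels = 3, numPixels = 20, dimFirst = 5, dimSecond = 4, hence
        //     inoutSizes   = { 5, 4, 3, 2 },
        //     inoutStrides = { 1, 5, 20, 60 },
        // i.e. MKL sees a dense 4D buffer whose channel axis is the last sample
        // dimension. In the non-spatial case every element is its own channel
        // (numChannels == numElements), so statistics are per element.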
        void Forward(void* input, void* output, void* scale, void* bias, void* runMean, void* runVariance, ContextIndex contextIndex)
        {
            auto& ctx = m_context[contextIndex];
            void* resources[dnnResourceNumber] = { 0 };

            ctx.input.PrepareForExecution(input, resources);
            ctx.output.PrepareForExecution(output, resources);
            ctx.scaleShift.PrepareForExecution(scale, bias, resources);

            resources[dnnResourceMean] = runMean;
            resources[dnnResourceVariance] = runVariance;

            CHECK_MKL(dnnExecute<ElemType>(ctx.primitive, resources));

            ctx.output.ConvertOutput(output);
        }

        void Backward(void* in, void* srcGrad, void* grad, void* scale, void* savedMean, void* savedInvStdDev, void* scaleGrad, void* biasGrad)
        {
            auto& ctx = m_context[ContextIndex_Backward];
            void* resources[dnnResourceNumber] = { 0 };

            ctx.input.PrepareForExecution(srcGrad, resources);
            ctx.output.PrepareForExecution(grad, resources);
            ctx.scaleShift.PrepareForExecution(scaleGrad, biasGrad, resources);

            auto scaleShiftMat = std::make_shared<Mat>(ctx.scaleShift.numChannels, 2, CPUDEVICE);
            memcpy(scaleShiftMat->Data(), scale, ctx.scaleShift.numChannels * sizeof(ElemType));
            resources[dnnResourceScaleShift] = scaleShiftMat->Data();

            // Convert from InvStdDev back to the variance MKL expects.
            for (size_t i = 0; i < ctx.scaleShift.numChannels; i++)
            {
                ElemType& v = ctx.varianceMat->Data()[i];
                ElemType& s = ((ElemType*)savedInvStdDev)[i];
                v = (1 / (s * s) - m_epsilon);
            }

            resources[dnnResourceSrc] = in;
            resources[dnnResourceMean] = savedMean;
            resources[dnnResourceVariance] = ctx.varianceMat->Data();

            CHECK_MKL(dnnExecute<ElemType>(ctx.primitive, resources));

            ctx.output.ConvertOutput(grad);
            ctx.scaleShift.ConvertOutput(scaleGrad, biasGrad);
        }
    };

    MKLBatchNormalizationContext m_mklContext;
    std::shared_ptr<Mat> m_dataGradWorkspace;

    bool ForwardCoreMKL(const Mat& in, const Mat& scale, const Mat& bias, bool inferenceOnly, double expAvgFactor,
                        Mat& runMean, Mat& runVariance, Mat& out, double epsilon, Mat& savedMean, Mat& savedInvStdDev)
    {
        ContextIndex contextIndex = inferenceOnly ? ContextIndex_ForwardInfer : ContextIndex_ForwardTrain;
        m_mklContext.Prepare(m_inOutT, m_spatial, in.GetNumCols(), contextIndex, (ElemType)epsilon);

        if (inferenceOnly)
        {
            m_mklContext.Forward(in.Data(), out.Data(), scale.Data(), bias.Data(), runMean.Data(), runVariance.Data(), contextIndex);
        }
        else
        {
            savedMean.Resize(runMean);
            savedInvStdDev.Resize(runVariance);
            m_mklContext.Forward(in.Data(), out.Data(), scale.Data(), bias.Data(), savedMean.Data(), savedInvStdDev.Data(), contextIndex);

            // Fold the batch mean into the running mean.
            ElemType OneMinusExpAvgFactor = (ElemType)(1.0 - expAvgFactor);
            cblas_axpby((MKL_INT)runMean.GetNumElements(), (ElemType)expAvgFactor, savedMean.Data(), OneMinusExpAvgFactor, runMean.Data());

            // Note: savedInvStdDev currently holds the variance of in.Data(); convert it to InvStdDev and interpolate.
            ElemType numReduced = (ElemType)(in.GetNumElements() / runVariance.GetNumElements());
            ElemType bcf = numReduced / (numReduced - 1);
            for (size_t i = 0; i < runVariance.GetNumElements(); i++)
            {
                ElemType& v = runVariance.Data()[i];
                ElemType& s = savedInvStdDev.Data()[i];
                v = v * OneMinusExpAvgFactor + bcf * s * (ElemType)expAvgFactor;
                s = (ElemType)1 / sqrt(s + (ElemType)epsilon);
            }
        }

        return true;
    }
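    // Numerical note on the loop above: MKL produces the biased batch variance, while
    // runVariance tracks an unbiased estimate, hence the Bessel correction
    // bcf = m / (m - 1), where m is the number of samples reduced per statistic.
    // The same buffer is then converted in place to what CNTK propagates,
    // savedInvStdDev = 1 / sqrt(batchVariance + epsilon); MKLBatchNormalizationContext::Backward
    // inverts that (variance = 1 / invStdDev^2 - epsilon) before handing the saved
    // statistics back to the MKL backward primitive.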
    bool BackwardCoreMKL(const Mat& in, const Mat& srcGrad, Mat& grad, const Mat& scale,
                         const Mat& savedMean, const Mat& savedInvStdDev, Mat& scaleGrad, Mat& biasGrad, bool accumulateDataGrad)
    {
        m_mklContext.Prepare(m_inOutT, m_spatial, srcGrad.GetNumCols(), ContextIndex_Backward);

        if (accumulateDataGrad)
        {
            // MKL overwrites the data gradient, so stash the incoming gradient and add it back afterwards.
            if (!m_dataGradWorkspace)
                m_dataGradWorkspace = std::make_shared<Matrix<ElemType>>(0, 0, CPUDEVICE);

            m_dataGradWorkspace->SetValue(grad);
        }

        m_mklContext.Backward(in.Data(), srcGrad.Data(), grad.Data(), scale.Data(), savedMean.Data(), savedInvStdDev.Data(), scaleGrad.Data(), biasGrad.Data());

        if (accumulateDataGrad)
            cblas_axpby((MKL_INT)grad.GetNumElements(), (ElemType)1.0, m_dataGradWorkspace->Data(), (ElemType)1.0, grad.Data());

        return true;
    }
#endif
};

template class CntkBatchNormEngine<float>;
template class CntkBatchNormEngine<double>;

template <typename T>
bool HasFlag(T src, T testFlag)
{
    return ((int)src & (int)testFlag) != 0;
}

template <class ElemType>
std::unique_ptr<BatchNormEngine<ElemType>> BatchNormEngine<ElemType>::Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
                                                                             bool spatial, ImageLayoutKind imageLayout, BatchNormEngineKind enabledEngines)
{
    // Use CNTK as the default batch norm engine.
    if (HasFlag(enabledEngines, BatchNormEngineKind::Cntk))
    {
        if (GetMathLibTraceLevel() > 0)
            fprintf(stderr, "Using CNTK batch normalization engine.\n");

        return std::make_unique<CntkBatchNormEngine<ElemType>>(deviceId, inOutT, spatial, imageLayout);
    }

    if (HasFlag(enabledEngines, BatchNormEngineKind::CuDnn))
    {
        if (GetMathLibTraceLevel() > 0)
            fprintf(stderr, "Using cuDNN batch normalization engine.\n");

        return CuDnnBatchNormEngineFactory<ElemType>::Create(deviceId, inOutT, spatial, imageLayout);
    }

    RuntimeError("Could not find appropriate batch normalization engine.");
}

template class BatchNormEngine<float>;
template class BatchNormEngine<double>;

}}}
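// Minimal usage sketch (illustrative only; constructing the Matrix arguments with
// matching shapes follows the conventions of CNTK's Math library and is not shown):
//
//     using namespace Microsoft::MSR::CNTK;
//     auto engine = BatchNormEngine<float>::Create(CPUDEVICE, TensorShape(16), /*spatial=*/false,
//                                                  ImageLayoutKind::CHW, BatchNormEngineKind::Cntk);
//     engine->Forward(in, scale, bias, /*inferenceOnly=*/false, /*expAvgFactor=*/0.1,
//                     /*blendFactor=*/0.0, runMean, runVariance, out, /*epsilon=*/1e-5,
//                     savedMean, savedInvStdDev);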