https://github.com/Microsoft/CNTK
Tip revision: 3664c4cf5aaa683898cf4119b9aa3f1d28ea1dcb authored by Spandan Tiwari on 23 May 2018, 17:33:50 UTC ("Fixed convolution_transpose test failure.")
mkldnn_convolution-inl.h
/*******************************************************************************
* Copyright 2017 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* \file mkldnn_convolution-inl.h
* \brief MKL-DNN based convolution operator for CNTK.
* \author lingyan.guo@intel.com
* zhenlin.luo@intel.com
*
*******************************************************************************/
#ifndef CNTK_OPERATOR_MKL_DNN_MKLDNN_CONVOLUTION_INL_H_
#define CNTK_OPERATOR_MKL_DNN_MKLDNN_CONVOLUTION_INL_H_
#include <string>
#include <algorithm>
#include <vector>
#include <iostream>
#include "mkl_memory.h"
#include "mkldnn_memory-inl.h"
#include "mkl_conv-common-inl.h"
#include "mkldnn_base-inl.h"
namespace Microsoft
{
namespace MSR
{
namespace CNTK
{
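// Shape helpers implemented elsewhere in CNTK's convolution engine, declared
// here for the MKL-DNN layers. GetSizesAndStrides flattens a TensorShape into
// per-dimension sizes/strides (used by init_properties below).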
extern void GetSizesAndStrides(int dimension, const TensorShape& shape, size_t lastDim, SmallVector<size_t>& sizes,
SmallVector<size_t>& strides, size_t mapCount = 0);
extern void GetInputOffsets(const ConvolveGeometry* geometry, SmallVector<int>& inputOffset);
} // namespace CNTK
} // namespace MSR
} // namespace Microsoft
#ifdef USE_MKLDNN
namespace Microsoft
{
namespace MSR
{
namespace CNTK
{
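// MKLDNNConvolutionOp wraps the MKL-DNN (v0.x) convolution primitives for
// CNTK. It lazily builds forward / backward-data / backward-weights primitive
// descriptors from the ConvolveGeometry, reorders tensors between CNTK's user
// layout (nchw) and MKL-DNN's preferred private formats, and caches the
// primitives across calls. Optionally a bias term and a fused ReLU are
// applied in the forward pass. s_id_gen hands out a per-instance id used
// only to build unique debug names.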
template <typename DType>
class MKLDNNConvolutionOp : public MKLDNNLayer<DType>, public MKLConvCommon<DType>
{
static int s_id_gen;
int m_id;
public:
using Mat = Matrix<DType>;
std::string getName()
{
std::string name = "MKLDNNConvolutionOp_";
name = name + std::to_string(m_id);
return name;
}
explicit MKLDNNConvolutionOp(ConvolveGeometryPtr geometry, ImageLayoutKind imageLayout, bool bias = false,
bool relu = false)
: MKLDNNLayer<DType>(),
dilate_w(0),
dilate_h(0),
fwd_bottom_data(NULL),
fwd_top_data(NULL),
fwd_weights_data(NULL),
fwd_bias_data(NULL),
convFwd_pd(NULL),
convBwdData_pd(NULL),
convBwdWeights_pd(NULL),
init_gbias(-1),
b_init_convBwdData(false),
b_init_convBwdWeights(false),
b_init_convFwd(false)
{
b_init_conv = false;
m_geometry = geometry;
m_imageLayout = imageLayout;
m_bias = bias;
m_relu = relu;
m_id = s_id_gen++;
}
virtual ~MKLDNNConvolutionOp() {}
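// Derive the 2-D convolution parameters (input/output/kernel sizes, strides,
// dilation, padding) from the CNTK geometry. Rank-3 input shapes go through
// ImageDimensions; higher ranks are flattened with GetSizesAndStrides.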
void init_properties(int batchSize)
{
this->num_ = batchSize;
this->group_ = 1; // TODO: does CNTK support grouped convolution?
// Check ComputeOutputShape
if (m_geometry->InputShape().GetRank() == 3)
{
ImageDimensions inT(m_geometry->InputShape(), m_imageLayout);
ImageDimensions outT(m_geometry->OutputShape(), m_imageLayout);
ImageDimensions kernelT(m_geometry->KernelShape(), m_imageLayout);
ImageDimensions strideT(m_geometry->Stride(), m_imageLayout);
this->stride_w_ = (int) strideT.w();
this->stride_h_ = (int) strideT.h();
this->width_ = (int) inT.w();
this->height_ = (int) inT.h();
this->kernel_w_ = (int) kernelT.w();
this->kernel_h_ = (int) kernelT.h();
this->channels_ = (int) inT.c();
this->width_out_ = (int) outT.w();
this->height_out_ = (int) outT.h();
this->channel_output_ = (int) outT.c();
}
else
{
int dimension = 4;
SmallVector<size_t> outputSize, outputStrides, filterSize, filterStrides, inputSize, inputStrides,
stridesSize, stridesStrides;
SmallVector<int> inputOffset;
size_t mapCount = m_geometry->GetMapCount(m_geometry->KernelShape().GetRank() - 1);
GetSizesAndStrides(dimension, m_geometry->OutputShape(), batchSize, outputSize, outputStrides, mapCount);
GetSizesAndStrides(dimension, m_geometry->KernelShape(), mapCount, filterSize, filterStrides);
GetSizesAndStrides(dimension, m_geometry->InputShape(), batchSize, inputSize, inputStrides);
GetSizesAndStrides(dimension, m_geometry->Stride(), batchSize, stridesSize, stridesStrides);
this->width_ = (int) inputSize[0];
this->height_ = (int) inputSize[1];
this->channels_ = (int) inputSize[2];
this->kernel_w_ = (int) filterSize[0];
this->kernel_h_ = (int) filterSize[1];
this->width_out_ = (int) outputSize[0];
this->height_out_ = (int) outputSize[1];
this->channel_output_ = (int) outputSize[2];
this->stride_w_ = (int) stridesSize[0];
this->stride_h_ = (int) stridesSize[1];
}
if (m_geometry->GetDilation(0) > 1)
this->dilate_w = (int) m_geometry->GetDilation(0) - 1;
if (m_geometry->GetDilation(1) > 1)
this->dilate_h = (int) m_geometry->GetDilation(1) - 1;
const SmallVector<bool>& autopad = m_geometry->AutoPad();
int autopad_size = (int) autopad.size();
const TensorShape& padShape = m_geometry->LowerPad();
int pad_size = (int) padShape.size();
// For CHW layout: pad index 0 is width, index 1 is height.
if (autopad_size > 0 && autopad[0])
{
this->pad_l_w_ = m_geometry->GetLowerPad(0);
}
else if (pad_size > 0)
{
this->pad_l_w_ = (int) padShape[0];
}
if (autopad_size > 1 && autopad[1])
{
this->pad_l_h_ = m_geometry->GetLowerPad(1);
}
else if (pad_size > 1)
{
this->pad_l_h_ = (int) padShape[1];
}
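// MKL-DNN takes explicit lower (left/top) and upper (right/bottom) padding.
// The upper padding follows from the output-size relation
//   out = (in + pad_l + pad_r - ((k - 1) * (dilate + 1) + 1)) / stride + 1
// solved for pad_r, with dilate_* stored as (dilation - 1) above.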
this->pad_r_h_ = (this->height_out_ - 1) * this->stride_h_ - this->pad_l_h_ - this->height_ +
((this->kernel_h_ - 1) * (this->dilate_h + 1) + 1);
this->pad_r_w_ = (this->width_out_ - 1) * this->stride_w_ - this->pad_l_w_ - this->width_ +
((this->kernel_w_ - 1) * (this->dilate_w + 1) + 1);
}
private:
void InitForward(bool inferenceOnly)
{
auto propagation = (inferenceOnly) ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
if (m_relu)
propagation = mkldnn::prop_kind::forward_inference;
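// Note: in MKL-DNN v0.x forward_scoring is an alias of forward_inference, so
// the fused-ReLU path simply forces an inference-kind descriptor here.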
int32_t g = std::max(this->group_, 1);
int32_t n = this->num_;
int32_t iw = this->width_;
int32_t ih = this->height_;
int32_t ic = this->channels_;
int32_t ow = this->width_out_;
int32_t oh = this->height_out_;
int32_t oc = this->channel_output_;
int32_t kw = this->kernel_w_;
int32_t kh = this->kernel_h_;
mkldnn::memory::dims convolutionStrides{static_cast<int>(this->stride_h_), static_cast<int>(this->stride_w_)};
mkldnn::memory::dims padding_l{this->pad_l_h_, this->pad_l_w_};
mkldnn::memory::dims padding_r{this->pad_r_h_, this->pad_r_w_};
mkldnn::memory::dims dnn_dilate{this->dilate_h, this->dilate_w};
mkldnn::memory::data_type mpcsn = mkldnn::memory::data_type::f32;
mkldnn::memory::format mfmt_any = mkldnn::memory::format::any;
mkldnn::engine cpu_engine = CpuEngine::Instance().get_engine();
mkldnn::memory::dims bottom_tz = {n, ic, ih, iw};
mkldnn::memory::dims bias_tz = {oc};
mkldnn::memory::dims top_tz = {n, oc, oh, ow};
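// Grouped convolution carries an extra leading group dimension in the
// weights (g, oc/g, ic/g, kh, kw); the ungrouped case is plain (oc, ic, kh, kw).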
mkldnn::memory::dims weights_tz =
(g != 1) ? mkldnn::memory::dims{g, oc / g, ic / g, kh, kw} : mkldnn::memory::dims{oc, ic, kh, kw};
mkldnn::memory::desc init_bottom_md({bottom_tz}, mpcsn, mfmt_any);
mkldnn::memory::desc init_bias_md({bias_tz}, mpcsn, mfmt_any);
mkldnn::memory::desc init_top_md({top_tz}, mpcsn, mfmt_any);
mkldnn::memory::desc init_weights_md({weights_tz}, mpcsn, mfmt_any);
// ---- Initialize convolution primitive descriptor
std::shared_ptr<mkldnn::convolution_forward::desc> convFwd_desc;
if (this->m_bias)
{
convFwd_desc.reset(new mkldnn::convolution_forward::desc(
propagation, mkldnn::algorithm::convolution_direct, init_bottom_md, init_weights_md, init_bias_md,
init_top_md, convolutionStrides, dnn_dilate, padding_l, padding_r, mkldnn::padding_kind::zero));
}
else
{
convFwd_desc.reset(new mkldnn::convolution_forward::desc(
propagation, mkldnn::algorithm::convolution_direct, init_bottom_md, init_weights_md, init_top_md,
convolutionStrides, dnn_dilate, padding_l, padding_r, mkldnn::padding_kind::zero));
}
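// Fused convolution + ReLU: append an eltwise ReLU post-op to the primitive
// attributes so the activation is computed inside the convolution primitive,
// avoiding a separate pass over the output.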
if (m_relu)
{
// add fusion for relu
attr_t attr = attr_t(mkldnn::round_mode::round_nearest, 1.0, attr_t::scale_t::policy_t::COMMON);
attr.pops.entry[0].kind = attr_t::post_ops_t::kind_t::RELU;
attr.pops.entry[0].eltwise.alpha = 0.0;
attr.pops.entry[0].eltwise.beta = 0.0;
attr.pops.entry[0].eltwise.scale = 1.0;
attr.pops.len = 1;
attr.mkldnn_attr_create();
convFwd_pd.reset(
new mkldnn::convolution_forward::primitive_desc(*convFwd_desc, attr.mkldnn_attr, cpu_engine));
}
else
{
convFwd_pd.reset(new mkldnn::convolution_forward::primitive_desc(*convFwd_desc, cpu_engine));
}
assert(convFwd_pd);
// ---- Create priv memory primitive descriptors stored as class members -------------
typedef typename mkldnn::memory::primitive_desc MemPD;
std::shared_ptr<MemPD> prv_fwd_bottom_data_memory_pd(new MemPD(convFwd_pd->src_primitive_desc()));
std::shared_ptr<MemPD> prv_fwd_top_data_memory_pd(new MemPD(convFwd_pd->dst_primitive_desc()));
std::shared_ptr<MemPD> prv_fwd_weights_data_memory_pd(new MemPD(convFwd_pd->weights_primitive_desc()));
// ---- Create usr memory primitive descriptors -------------
mkldnn::memory::format mfmt_nchw = mkldnn::memory::format::nchw;
mkldnn::memory::format weights_mfmt = (g != 1) ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw;
std::shared_ptr<MemPD> usr_bottom_data_memory_pd(new MemPD({{bottom_tz}, mpcsn, mfmt_nchw}, cpu_engine));
std::shared_ptr<MemPD> usr_bias_data_memory_pd(
new MemPD({{bias_tz}, mpcsn, mkldnn::memory::format::x}, cpu_engine));
std::shared_ptr<MemPD> usr_top_data_memory_pd(new MemPD({{top_tz}, mpcsn, mfmt_nchw}, cpu_engine));
std::shared_ptr<MemPD> usr_weights_data_memory_pd(new MemPD({{weights_tz}, mpcsn, weights_mfmt}, cpu_engine));
// --- init primitive and prv_memory descriptors ----------------------
fwd_bottom_data.reset(new MKLDNNData<DType>(usr_bottom_data_memory_pd, prv_fwd_bottom_data_memory_pd));
fwd_bottom_data->name = "fwd_bottom_data @ " + this->getName();
fwd_top_data.reset(new MKLDNNData<DType>(usr_top_data_memory_pd, prv_fwd_top_data_memory_pd));
fwd_top_data->name = "fwd_top_data @ " + this->getName();
fwd_weights_data.reset(new MKLDNNData<DType>(usr_weights_data_memory_pd, prv_fwd_weights_data_memory_pd));
fwd_weights_data->name = "fwd_weights_data @ " + this->getName();
if (this->m_bias)
{
std::shared_ptr<MemPD> prv_fwd_bias_data_memory_pd(new MemPD(convFwd_pd->bias_primitive_desc()));
fwd_bias_data.reset(new MKLDNNData<DType>(usr_bias_data_memory_pd, prv_fwd_bias_data_memory_pd));
fwd_bias_data->name = "fwd_bias_data @ " + this->getName();
}
}
public:
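// Forward pass. On first use this derives the geometry and builds the
// primitive descriptor; afterwards it only reorders inputs into MKL-DNN
// private layouts when needed and re-submits the cached primitive. b_same
// tracks whether the underlying buffers changed since the primitive was
// built; if they did, the primitive is recreated against the new memory.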
virtual void Forward(const Mat& in, const Mat& kernel, Mat& out, bool inferenceOnly, Mat* pBias = NULL)
{
DType* data_ptr = mkl_experimental_direct_get(in);
DType* out_ptr = mkl_experimental_direct_get(out);
DType* wmat_ptr = mkl_experimental_direct_get(kernel);
DType* bias_ptr = NULL;
if (pBias != NULL)
bias_ptr = mkl_experimental_direct_get(*pBias);
if (!b_init_conv)
{
this->init_properties((int) in.GetNumCols());
this->b_init_conv = true;
}
bool b_same = true;
if (convFwd_pd == NULL)
{
InitForward(inferenceOnly);
}
// --- init primitive and prv_memory descriptors ---------
fwd_bottom_data_primitive = fwd_bottom_data->get_converted_prv(data_ptr, false, in, &b_same);
fwd_weights_data_primitive = fwd_weights_data->get_converted_prv(wmat_ptr, true, kernel, &b_same);
if (this->m_bias)
{
fwd_bias_data_primitive = fwd_bias_data->get_converted_prv(bias_ptr, true, *pBias, &b_same);
init_gbias.AssignValuesOf(*pBias);
}
fwd_top_data_memory = fwd_top_data->create_output_memory(out_ptr, out, false, &b_same);
if (!b_init_convFwd || !b_same)
{
// each MKL-DNN memory has a dedicated _prv_memory
if (this->m_bias)
{
convFwd.reset(new mkldnn::convolution_forward(*convFwd_pd, *fwd_bottom_data_primitive,
*fwd_weights_data_primitive, *fwd_bias_data_primitive,
*fwd_top_data_memory));
}
else
{
convFwd.reset(new mkldnn::convolution_forward(*convFwd_pd, *fwd_bottom_data_primitive,
*fwd_weights_data_primitive, *fwd_top_data_memory));
}
if (!b_init_convFwd)
b_init_convFwd = true;
}
convFwd.submit();
}
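// Build the backward-data and backward-weights primitive descriptors.
// MKL-DNN requires the forward primitive descriptor as a hint so the
// backward primitives pick memory formats compatible with the forward pass.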
void InitConvolutionBwd()
{
int32_t g = std::max(this->group_, 1);
int32_t n = this->num_;
int32_t iw = this->width_;
int32_t ih = this->height_;
int32_t ic = this->channels_;
int32_t ow = this->width_out_;
int32_t oh = this->height_out_;
int32_t oc = this->channel_output_;
int32_t kw = this->kernel_w_;
int32_t kh = this->kernel_h_;
mkldnn::memory::dims convolutionStrides{this->stride_h_, this->stride_w_};
mkldnn::memory::dims padding_l{this->pad_l_h_, this->pad_l_w_};
mkldnn::memory::dims padding_r{this->pad_r_h_, this->pad_r_w_};
mkldnn::memory::data_type mpcsn = mkldnn::memory::data_type::f32;
mkldnn::memory::format mfmt_any = mkldnn::memory::format::any;
mkldnn::memory::dims bottom_tz = {n, ic, ih, iw};
mkldnn::memory::dims bias_tz = {oc};
mkldnn::memory::dims top_tz = {n, oc, oh, ow};
mkldnn::memory::dims weights_tz =
(g != 1) ? mkldnn::memory::dims{g, oc / g, ic / g, kh, kw} : mkldnn::memory::dims{oc, ic, kh, kw};
mkldnn::memory::desc init_bottom_md({bottom_tz}, mpcsn, mfmt_any);
mkldnn::memory::desc init_bias_md({bias_tz}, mpcsn, mfmt_any);
mkldnn::memory::desc init_top_md({top_tz}, mpcsn, mfmt_any);
mkldnn::memory::desc init_weights_md({weights_tz}, mpcsn, mfmt_any);
// ---- Initialize convolution primitive descriptor -------------
std::shared_ptr<mkldnn::convolution_backward_data::desc> convBwdData_desc;
std::shared_ptr<mkldnn::convolution_backward_weights::desc> convBwdWeights_desc;
if (this->m_bias)
{
convBwdWeights_desc.reset(new mkldnn::convolution_backward_weights::desc(
mkldnn::algorithm::convolution_direct, init_bottom_md, init_weights_md, init_bias_md, init_top_md,
convolutionStrides, padding_l, padding_r, mkldnn::padding_kind::zero));
}
else
{
convBwdWeights_desc.reset(new mkldnn::convolution_backward_weights::desc(
mkldnn::algorithm::convolution_direct, init_bottom_md, init_weights_md, init_top_md, convolutionStrides,
padding_l, padding_r, mkldnn::padding_kind::zero));
}
mkldnn::engine cpu_engine = CpuEngine::Instance().get_engine();
convBwdData_desc.reset(new mkldnn::convolution_backward_data::desc(
mkldnn::algorithm::convolution_direct, init_bottom_md, init_weights_md, init_top_md, convolutionStrides,
padding_l, padding_r, mkldnn::padding_kind::zero));
convBwdData_pd.reset(
new mkldnn::convolution_backward_data::primitive_desc(*convBwdData_desc, cpu_engine, *convFwd_pd));
convBwdWeights_pd.reset(
new mkldnn::convolution_backward_weights::primitive_desc(*convBwdWeights_desc, cpu_engine, *convFwd_pd));
// ---- Create priv memory primitive descriptors stored as class members -------------
typedef typename mkldnn::memory::primitive_desc MemPD;
std::shared_ptr<MemPD> prv_bwdd_bottom_diff_memory_pd(new MemPD(convBwdData_pd->diff_src_primitive_desc()));
std::shared_ptr<MemPD> prv_bwdd_top_diff_memory_pd(new MemPD(convBwdData_pd->diff_dst_primitive_desc()));
std::shared_ptr<MemPD> prv_bwdd_weights_data_memory_pd(new MemPD(convBwdData_pd->weights_primitive_desc()));
std::shared_ptr<MemPD> prv_bwdw_bottom_data_memory_pd(new MemPD(convBwdWeights_pd->src_primitive_desc()));
std::shared_ptr<MemPD> prv_bwdw_top_diff_memory_pd(new MemPD(convBwdWeights_pd->diff_dst_primitive_desc()));
std::shared_ptr<MemPD> prv_bwdw_weights_diff_memory_pd(
new MemPD(convBwdWeights_pd->diff_weights_primitive_desc()));
// ---- Create usr memory primitive descriptors -------------
mkldnn::memory::format mfmt_nchw = mkldnn::memory::format::nchw;
mkldnn::memory::format weights_mfmt = (g != 1) ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw;
// TODO: can the usr memory primitive descriptors be reused for backward?
std::shared_ptr<MemPD> usr_bottom_data_memory_pd(new MemPD({{bottom_tz}, mpcsn, mfmt_nchw}, cpu_engine));
std::shared_ptr<MemPD> usr_bias_data_memory_pd(
new MemPD({{bias_tz}, mpcsn, mkldnn::memory::format::x}, cpu_engine));
std::shared_ptr<MemPD> usr_top_data_memory_pd(new MemPD({{top_tz}, mpcsn, mfmt_nchw}, cpu_engine));
std::shared_ptr<MemPD> usr_weights_data_memory_pd(new MemPD({{weights_tz}, mpcsn, weights_mfmt}, cpu_engine));
// --- init primitive and prv_memory descriptors ----------------------
bwdd_bottom_diff.reset(new MKLDNNData<DType>(usr_bottom_data_memory_pd, prv_bwdd_bottom_diff_memory_pd));
bwdd_bottom_diff->name = "bwdd_bottom_diff @ " + this->getName();
bwdd_bottom_diff_ws.reset(
new MKLDNNData<DType>(usr_bottom_data_memory_pd, prv_bwdd_bottom_diff_memory_pd));
bwdd_bottom_diff_ws->name = "bwdd_bottom_diff_ws @ " + this->getName();
bwdw_bottom_data.reset(new MKLDNNData<DType>(usr_bottom_data_memory_pd, prv_bwdw_bottom_data_memory_pd));
bwdw_bottom_data->name = "bwdw_bottom_data @ " + this->getName();
bwdd_top_diff.reset(new MKLDNNData<DType>(usr_top_data_memory_pd, prv_bwdd_top_diff_memory_pd));
bwdd_top_diff->name = "bwdd_top_diff @ " + this->getName();
bwdw_top_diff.reset(new MKLDNNData<DType>(usr_top_data_memory_pd, prv_bwdw_top_diff_memory_pd));
bwdw_top_diff->name = "bwdw_top_diff @ " + this->getName();
bwdd_weights_data.reset(new MKLDNNData<DType>(usr_weights_data_memory_pd, prv_bwdd_weights_data_memory_pd));
bwdd_weights_data->name = "bwdd_weights_data @ " + this->getName();
bwdw_weights_diff.reset(new MKLDNNData<DType>(usr_weights_data_memory_pd, prv_bwdw_weights_diff_memory_pd));
bwdw_weights_diff->name = "bwdw_weights_diff @ " + this->getName();
bwdw_weights_diff_ws.reset(
new MKLDNNData<DType>(usr_weights_data_memory_pd, prv_bwdw_weights_diff_memory_pd));
bwdw_weights_diff_ws->name = "bwdw_weights_diff_ws @ " + this->getName();
if (this->m_bias)
{
std::shared_ptr<MemPD> prv_bwdw_bias_diff_memory_pd(
new MemPD(convBwdWeights_pd->diff_bias_primitive_desc()));
mkldnn::memory::desc prv_bwd_bias_desc = convBwdWeights_pd->diff_bias_primitive_desc().desc();
bwdw_bias_diff.reset(new MKLDNNData<DType>(usr_bias_data_memory_pd, prv_bwdw_bias_diff_memory_pd));
bwdw_bias_diff->name = "bwdw_bias_diff @ " + this->getName();
}
}
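// Build the ReLU backward primitive descriptor used when the forward pass
// fused ReLU into the convolution. The memory descriptor mirrors src's
// layout (its private MKL-DNN format if present, plain nchw otherwise), and
// a matching forward-training descriptor serves as the required hint.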
void InitReLUBwd(const Mat& src)
{
int32_t n = this->num_;
int32_t iw = this->width_;
int32_t ih = this->height_;
int32_t ic = this->channels_;
DType negative_slope = 0;
void* src_data = const_cast<DType*>(mkl_prv_data<DType>(src));
bool src_is_prv = (src_data != NULL);
mkldnn::engine cpu_engine = CpuEngine::Instance().get_engine();
mkldnn::memory::data_type mpcsn = mkldnn::memory::data_type::f32;
// ---- Initialize memory descriptors -------------
// std::shared_ptr<mkldnn::memory::desc> bottom_diff_md;
std::shared_ptr<mkldnn::memory::desc> top_diff_md;
std::shared_ptr<mkldnn::memory::desc> top_data_md;
std::shared_ptr<mkldnn::memory::primitive_desc> usr_diff_mpd;
std::shared_ptr<mkldnn::memory::primitive_desc> prv_diff_mpd;
if (src_is_prv)
{
std::shared_ptr<MKLDNNMemoryDescriptor<DType>> mem_descr = get_mkldnn_prv_descriptor<DType>(src);
top_diff_md.reset(new mkldnn::memory::desc(mem_descr->prv_memory_pd()->desc()));
usr_diff_mpd = mem_descr->usr_memory_pd();
prv_diff_mpd = mem_descr->prv_memory_pd();
}
else
{
top_diff_md.reset(new mkldnn::memory::desc({{n, ic, ih, iw}}, mpcsn, mkldnn::memory::format::nchw));
usr_diff_mpd.reset(new mkldnn::memory::primitive_desc(*top_diff_md, cpu_engine));
}
top_data_md = top_diff_md;
mkldnn::eltwise_forward::desc fwd_training_desc(mkldnn::prop_kind::forward_training, mkldnn::eltwise_relu,
*top_data_md, negative_slope);
fwd_relu_training_pd.reset(new mkldnn::relu_forward::primitive_desc(fwd_training_desc, cpu_engine));
mkldnn::eltwise_backward::desc reluBwd_desc(mkldnn::eltwise_relu, *top_diff_md, *top_data_md, negative_slope);
bwd_relu_pd.reset(new mkldnn::relu_backward::primitive_desc(reluBwd_desc, cpu_engine, *fwd_relu_training_pd));
bwd_relu_top_diff.reset(new MKLDNNData<DType>(usr_diff_mpd, prv_diff_mpd));
bwd_relu_top_diff->name = "bwd_top_diff @ " + this->getName();
bwd_relu_dst_data.reset(new MKLDNNData<DType>(usr_diff_mpd, prv_diff_mpd));
bwd_relu_dst_data->name = "bwd_bottom_data @ " + this->getName();
}
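// Gradient w.r.t. the input. With accumulateGradient the primitive writes
// into a workspace and the result is then added into grad; otherwise it
// writes the input gradient directly.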
virtual void BackwardData(const Mat& srcGrad, const Mat& kernel, Mat& grad, bool accumulateGradient, Mat& workspace)
{
DType* srcgrad_ptr = mkl_experimental_direct_get(srcGrad);
DType* kernel_ptr = mkl_experimental_direct_get(kernel);
DType* grad_ptr = mkl_experimental_direct_get(grad);
if (!b_init_conv)
{
this->init_properties((int) grad.GetNumCols());
b_init_conv = true;
}
if (convFwd_pd == NULL)
{
this->InitForward(true);
}
bool b_same = true;
if (convBwdData_pd == NULL)
{
this->InitConvolutionBwd();
}
std::shared_ptr<mkldnn::memory> bwdd_top_diff_primitive, bwdd_weights_data_primitive;
std::shared_ptr<mkldnn::memory> bwdd_bottom_diff_memory;
std::shared_ptr<mkldnn::memory> bwdd_bottom_diff_dst;
// --- init primitive and prv_memory descriptors ---------
bwdd_top_diff_primitive = bwdd_top_diff->get_converted_prv(srcgrad_ptr, true, srcGrad, &b_same);
bwdd_weights_data_primitive = bwdd_weights_data->get_converted_prv(kernel_ptr, false, kernel, &b_same);
if (accumulateGradient) {
workspace.Resize(grad);
bwdd_bottom_diff_dst = bwdd_bottom_diff_ws->create_output_memory(workspace.Data(), workspace, false, &b_same);
bwdd_bottom_diff_memory = bwdd_bottom_diff->get_converted_prv(grad_ptr, true, grad, &b_same);
grad_ptr = mkl_experimental_direct_get(grad);
}
else
{
bwdd_bottom_diff_dst = bwdd_bottom_diff_memory = bwdd_bottom_diff->create_output_memory(grad_ptr, grad, false, &b_same);
}
if (!b_init_convBwdData || !b_same)
{
convBwdData.reset(new mkldnn::convolution_backward_data(
*convBwdData_pd, *bwdd_top_diff_primitive, *bwdd_weights_data_primitive, *bwdd_bottom_diff_dst));
if (!b_init_convBwdData)
b_init_convBwdData = true;
}
convBwdData.submit();
if (accumulateGradient)
{
DType * workspace_ptr = mkl_experimental_direct_get(workspace);
grad.MklMem()->template AddTo<DType>(grad_ptr, *workspace.MklMem(), workspace_ptr);
}
}
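// Gradient w.r.t. the kernel (and optionally the bias). If ReLU was fused
// into the forward pass, srcGrad is first run through ReLU backward in place
// so that this pass and BackwardData both see the already-masked gradient.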
void BackwardKernel(const Mat& srcGrad, const Mat& in, const Mat& out, Mat& kernelGrad, bool accumulateGradient, Mat& workspace, Mat* pbiasGrad = NULL)
{
DType* srcgrad_ptr = mkl_experimental_direct_get(srcGrad);
DType* in_ptr = mkl_experimental_direct_get(in);
DType* out_ptr = mkl_experimental_direct_get(out);
DType* kernelgrad_ptr = mkl_experimental_direct_get(kernelGrad);
if (!b_init_conv)
{
this->init_properties((int) srcGrad.GetNumCols());
b_init_conv = true;
}
if (convFwd_pd == NULL)
{
this->InitForward(true);
}
if (convBwdData_pd == NULL)
{
this->InitConvolutionBwd();
}
bool b_same = true;
if (m_relu)
{
// In-place ReLU backward updates srcGrad once, so BackwardData can reuse the already-converted data.
if (bwd_relu_pd == NULL)
{
InitReLUBwd(out);
}
std::shared_ptr<mkldnn::memory> dst_memory, diff_src_memory;
dst_memory = bwd_relu_dst_data->get_converted_prv(out_ptr, false, out, &b_same);
diff_src_memory = bwd_relu_top_diff->get_converted_prv(srcgrad_ptr, false, srcGrad, &b_same);
MKLDNNPrimitive<DType> reluBwd;
reluBwd.reset(new mkldnn::relu_backward(*bwd_relu_pd, *dst_memory, *diff_src_memory, *diff_src_memory));
reluBwd.submit();
}
std::shared_ptr<mkldnn::memory> bwdw_bottom_data_primitive, bwdw_top_diff_primitive;
std::shared_ptr<mkldnn::memory> bwdw_weights_diff_memory, bwdw_bias_diff_memory;
std::shared_ptr<mkldnn::memory> bwdw_weights_diff_ws_memory, bwdw_weights_diff_dst;
bwdw_top_diff_primitive = bwdw_top_diff->get_converted_prv(srcgrad_ptr, true, srcGrad, &b_same);
bwdw_bottom_data_primitive = bwdw_bottom_data->get_converted_prv(in_ptr, false, in, &b_same);
if (accumulateGradient) {
// Make sure the workspace is plain user data of the right shape.
workspace.Resize(kernelGrad);
bwdw_weights_diff_dst = bwdw_weights_diff_ws_memory = bwdw_weights_diff_ws->create_output_memory(workspace.Data(), workspace, false, &b_same);
bwdw_weights_diff_memory = bwdw_weights_diff->get_converted_prv(kernelgrad_ptr, true, kernelGrad, &b_same);
}
else
{
bwdw_weights_diff_dst = bwdw_weights_diff_memory = bwdw_weights_diff->create_output_memory(kernelgrad_ptr, kernelGrad, false, &b_same);
}
if (this->m_bias)
{
DType* gbias_ptr = mkl_experimental_direct_get(*pbiasGrad);
if (gbias_ptr == nullptr)
{
pbiasGrad->AssignValuesOf(init_gbias);
gbias_ptr = mkl_experimental_direct_get(*pbiasGrad);
}
bwdw_bias_diff_memory = bwdw_bias_diff->create_output_memory(gbias_ptr, *pbiasGrad, false, &b_same);
}
if (!b_init_convBwdWeights || !b_same)
{
if (this->m_bias)
{
convBwdWeights.reset(new mkldnn::convolution_backward_weights(
*convBwdWeights_pd, *bwdw_bottom_data_primitive, *bwdw_top_diff_primitive, *bwdw_weights_diff_dst,
*bwdw_bias_diff_memory));
}
else
{
convBwdWeights.reset(new mkldnn::convolution_backward_weights(
*convBwdWeights_pd, *bwdw_bottom_data_primitive, *bwdw_top_diff_primitive, *bwdw_weights_diff_dst));
}
if (!b_init_convBwdWeights)
b_init_convBwdWeights = true;
}
convBwdWeights.submit();
if (accumulateGradient)
{
DType * workspace_ptr = mkl_experimental_direct_get(workspace);
kernelGrad.MklMem()->template AddTo<DType>(kernelgrad_ptr, *workspace.MklMem(), workspace_ptr);
}
}
private:
std::shared_ptr<mkldnn::memory> fwd_bottom_data_primitive, fwd_weights_data_primitive, fwd_bias_data_primitive;
std::shared_ptr<mkldnn::memory> fwd_top_data_memory;
std::shared_ptr<MKLDNNData<DType>> fwd_bottom_data, fwd_top_data, fwd_weights_data, fwd_bias_data,
bwdd_weights_data, bwdw_bottom_data;
std::shared_ptr<MKLDNNData<DType>> bwdd_bottom_diff, bwdd_top_diff, bwdw_top_diff, bwdw_weights_diff,
bwdw_bias_diff;
std::shared_ptr<MKLDNNData<DType> > bwdd_bottom_diff_ws, bwdw_weights_diff_ws;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> convFwd_pd;
MKLDNNPrimitive<DType> convFwd;
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc> convBwdData_pd;
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc> convBwdWeights_pd;
std::shared_ptr<mkldnn::relu_forward::primitive_desc> fwd_relu_training_pd;
std::shared_ptr<mkldnn::relu_backward::primitive_desc> bwd_relu_pd;
std::shared_ptr<MKLDNNData<DType>> bwd_relu_dst_data, bwd_relu_top_diff;
MKLDNNPrimitive<DType> convBwdData, convBwdWeights;
ConvolveGeometryPtr m_geometry;
ImageLayoutKind m_imageLayout;
bool b_init_conv;
bool m_bias;
int dilate_w;
int dilate_h;
Mat init_gbias;
bool m_relu;
bool b_init_convBwdData;
bool b_init_convBwdWeights;
bool b_init_convFwd;
}; // class MKLDNNConvolutionOp
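// A minimal usage sketch (hypothetical caller code; `geometry`, the matrices,
// and the workspace are assumed to be set up by the convolution engine that
// owns this op):
//
//   MKLDNNConvolutionOp<float> conv(geometry, ImageLayoutKind::CHW,
//                                   /*bias=*/true, /*relu=*/false);
//   conv.Forward(in, kernel, out, /*inferenceOnly=*/false, &bias);
//   conv.BackwardData(outGrad, kernel, inGrad,
//                     /*accumulateGradient=*/false, workspace);
//   conv.BackwardKernel(outGrad, in, out, kernelGrad,
//                       /*accumulateGradient=*/false, workspace, &biasGrad);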
} // namespace CNTK
} // namespace MSR
} // namespace Microsoft
#endif
#endif // CNTK_OPERATOR_MKL_DNN_MKLDNN_CONVOLUTION_INL_H_