https://github.com/Microsoft/CNTK
Raw File
Tip revision: f4087300186676c16dd58d5c753574a27b3e183c authored by KeDengMS on 01 May 2018, 16:59:55 UTC
Merge branch 'master' into kedeng/mkldnn
Tip revision: f408730
EltWiseEngine.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "Matrix.h"
#include "TensorShape.h" // for ImageLayoutKind

namespace Microsoft
{
namespace MSR
{
namespace CNTK
{

// Selects which element-wise engine implementations a factory may use.
// Both `Create` factories in this header accept one of these as their
// `enabledEngines` argument (defaulting to `All`). Non-zero values are
// single bits; bit 0 is not assigned here.
enum class EltWiseEngineKind
{
    None = 0,     // no implementation enabled
    MKLDNN = 0x2, // bit 1 (== 1 << 1): the MKL-DNN implementation
    All = MKLDNN  // union of every implementation listed above (only MKLDNN today)
};
// The unary element-wise operation a UnaryEltWiseEngine is created for
// (passed as `kind` to UnaryEltWiseEngine::Create).
enum class UnaryEltWiseKind
{
    RELU = 0 // ReLU activation
};
// Abstract base for engines that apply a unary element-wise operation
// (chosen via UnaryEltWiseKind) to a tensor described by m_inOutT.
//
// Callers go through the public, non-virtual Forward/Backward and obtain a
// concrete engine from Create(); implementations override the protected,
// pure-virtual *Core hooks and EnsureCompatible(). Forward, Backward and
// Create are defined out of line, so how they dispatch to the hooks is not
// visible from this header.
template <class ElemType>
class MATH_API UnaryEltWiseEngine
{
public:
    using Mat = Matrix<ElemType>;

public:
    virtual ~UnaryEltWiseEngine() {}

    // Forward pass: reads `in` (const), writes `out`.
    // NOTE(review): `inferenceOnly` is passed through from the caller;
    // presumably it lets an engine skip training-only work -- confirm.
    void Forward(const Mat& in, Mat& out, bool inferenceOnly);

    // Backward pass: reads the forward input `in` and incoming gradient
    // `srcGrad` (both const), writes `grad`.
    void Backward(const Mat& in, const Mat& srcGrad, Mat& grad);

    // Factory returning an owning pointer to a concrete engine for `kind`.
    // `enabledEngines` presumably restricts which implementations may be
    // chosen -- the selection logic lives out of line.
    static std::unique_ptr<UnaryEltWiseEngine<ElemType>> Create(
        DEVICEID_TYPE deviceId,
        const TensorShape& inOutT,
        ImageLayoutKind imageLayout,
        UnaryEltWiseKind kind,
        EltWiseEngineKind enabledEngines = EltWiseEngineKind::All);

    // Project macro; by its name it removes copy and move operations.
    DISABLE_COPY_AND_MOVE(UnaryEltWiseEngine);

protected:
    // Only derived engines construct the base; it just records its arguments.
    UnaryEltWiseEngine(DEVICEID_TYPE deviceId, const TensorShape& inOutT, ImageLayoutKind imageLayout)
        : m_deviceId(deviceId),
          m_inOutT(inOutT),
          m_imageLayout(imageLayout)
    {
    }

    // Implementation hooks (declaration order kept: it fixes vtable layout).
    // NOTE(review): EnsureCompatible's contract is defined by the
    // implementations; presumably it validates the stored configuration.
    virtual void EnsureCompatible() = 0;
    virtual void ForwardCore(const Mat& in, Mat& out, bool inferenceOnly) = 0;
    virtual void BackwardCore(const Mat& in, const Mat& srcGrad, Mat& grad) = 0;

protected:
    DEVICEID_TYPE m_deviceId;      // device passed at construction
    TensorShape m_inOutT;          // tensor shape passed at construction (`inOutT`)
    ImageLayoutKind m_imageLayout; // image layout passed at construction
};

// The binary element-wise operation a BinaryEltWiseEngine is created for
// (passed as `kind` to BinaryEltWiseEngine::Create).
enum class BinaryEltWiseKind
{
    PLUS = 0 // element-wise addition
};
// Abstract base for engines that apply a binary element-wise operation
// (chosen via BinaryEltWiseKind) to two input matrices.
//
// Callers go through the public, non-virtual Forward/Backward and obtain a
// concrete engine from Create(); implementations override the protected,
// pure-virtual hooks. Forward, Backward and Create are defined out of line,
// so how they dispatch to the hooks is not visible from this header.
template <class ElemType>
class MATH_API BinaryEltWiseEngine
{
public:
    using Mat = Matrix<ElemType>;

public:
    virtual ~BinaryEltWiseEngine() {}

    // Forward pass over inputs `ina`, `inb`, writing `out`.
    // NOTE(review): `ishape` is presumably the shape of the inputs -- confirm.
    // The inputs are taken by non-const reference; that signature must match
    // the out-of-line definition, so it is left as is.
    void Forward(const TensorShape& ishape, Mat& ina, Mat& inb, Mat& out);

    // Backward pass. NOTE(review): the roles of `in`/`out` are defined by the
    // out-of-line implementation and are not visible here.
    void Backward(Mat& in, Mat& out);

    // Factory returning an owning pointer to a concrete engine for `kind`.
    // `enabledEngines` presumably restricts which implementations may be
    // chosen -- the selection logic lives out of line.
    static std::unique_ptr<BinaryEltWiseEngine<ElemType>> Create(
        DEVICEID_TYPE deviceId,
        BinaryEltWiseKind kind,
        EltWiseEngineKind enabledEngines = EltWiseEngineKind::All);

    // Project macro; by its name it removes copy and move operations.
    DISABLE_COPY_AND_MOVE(BinaryEltWiseEngine);

protected:
    // `explicit`: a bare device id must never implicitly convert into an
    // engine. Derived classes initialise the base with direct-initialisation
    // (`: BinaryEltWiseEngine(deviceId)`), which is unaffected.
    explicit BinaryEltWiseEngine(DEVICEID_TYPE deviceId)
        : m_deviceId(deviceId)
    {
    }

    // Implementation hooks (declaration order kept: it fixes vtable layout).
    // NOTE(review): EnsureCompatible's contract is defined by the
    // implementations; presumably it validates the engine's configuration.
    virtual void EnsureCompatible() = 0;
    virtual void ForwardTwoCore(const TensorShape& ishape, Mat& ina, Mat& inb, Mat& out) = 0;
    virtual void BackwardTwoCore(Mat& in, Mat& out) = 0;

protected:
    DEVICEID_TYPE m_deviceId; // device passed at construction
};
} // namespace CNTK
} // namespace MSR
} // namespace Microsoft
back to top