//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once

#include "Basics.h"
#include "ComputationNode.h"
#include "Matrix.h"
#include "TensorView.h"
#include "RNNCommon.h"

#include <unordered_set>
#include <map>
#include <string>
#include <vector>
#include <stdexcept>
#include <list>
#include <memory>
#include <algorithm>
#include <utility>
#include <assert.h>

namespace Microsoft { namespace MSR { namespace CNTK {

// -----------------------------------------------------------------------
// OptimizedRNNStack (weights, data)
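//
// Stack of recurrent layers (recurrentOp such as lstm or gru, numLayers layers of hiddenSize
// units, optionally bidirectional, as captured in RnnAttributes) evaluated by an optimized
// backend such as cuDNN. Input 0 holds the weights, input 1 the sequence data; sequences are
// transposed and packed into the layout the backend expects before evaluation
// (see PackSequencesForCuDNN/UnpackSequencesFromCuDNN below).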
// -----------------------------------------------------------------------

template <class ElemType>
class OptimizedRNNStackNode : public ComputationNode<ElemType>, public NumInputs<2>
{
    typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
    static const std::wstring TypeName() { return L"OptimizedRNNStack"; }
    using Base::OperationName;

public:
    OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const std::wstring& name);
    OptimizedRNNStackNode(const ScriptableObjects::IConfigRecordPtr configp);
    OptimizedRNNStackNode(DEVICEID_TYPE deviceId, const std::wstring& name, bool bidirectional, size_t numLayers, size_t hiddenSize, const std::wstring& recurrentOp);

    virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override;
    virtual void Save(File& fstream) const;
    virtual void Load(File& fstream, size_t modelVersion) override;

public:
    virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override;
    virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override;
    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override;

    // request matrices needed to do node function value evaluation
    virtual void RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) override
    {
        Base::RequestMatricesBeforeForwardProp(matrixPool);
        RequestMatrixFromPool(m_transposedInput, matrixPool);
        RequestMatrixFromPool(m_transposedOutput, matrixPool);
        RequestMatrixFromPool(m_reserve, matrixPool);
        RequestMatrixFromPool(m_workspace, matrixPool);
        RequestMatrixFromPool(m_packingIndex, matrixPool);
    }

    // request matrices needed to do node derivative value evaluation
    virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) override
    {
        Base::RequestMatricesBeforeBackprop(matrixPool);
        RequestMatrixFromPool(m_transposedDInput, matrixPool);
        RequestMatrixFromPool(m_transposedDOutput, matrixPool);
    }

    // release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
    virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) override
    {
        Base::ReleaseMatricesAfterBackprop(matrixPool);
        ReleaseMatrixToPool(m_transposedInput, matrixPool);
        ReleaseMatrixToPool(m_transposedOutput, matrixPool);
        ReleaseMatrixToPool(m_transposedDInput, matrixPool);
        ReleaseMatrixToPool(m_transposedDOutput, matrixPool);
        ReleaseMatrixToPool(m_reserve, matrixPool);
        ReleaseMatrixToPool(m_workspace, matrixPool);
        ReleaseMatrixToPool(m_packingIndex, matrixPool);
    }

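    // Memory-optimization hints for backprop: this node's output value is not needed to compute
    // any input gradient, and of the inputs only input 0's value is needed.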
    virtual bool OutputUsedInComputingInputNodesGradients() const override { return false; }
    virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override { return 0 == childIndex; }
    RnnAttributes Attributes() const { return m_rnnAttributes; }

protected:
    bool m_BackwardDataCalledYet;                      // set once the data-gradient step of backprop has run for the current minibatch
    TensorShape shapeXT;                               // shape of the transposed input
    TensorShape shapeYT;                               // shape of the transposed output
    shared_ptr<Matrix<ElemType>> m_transposedInput;    // input rearranged into the layout the optimized RNN implementation expects
    shared_ptr<Matrix<ElemType>> m_transposedOutput;   // output in that layout, before being transposed back
    shared_ptr<Matrix<ElemType>> m_transposedDInput;   // gradient w.r.t. the transposed input
    shared_ptr<Matrix<ElemType>> m_transposedDOutput;  // gradient w.r.t. the transposed output
    shared_ptr<Matrix<ElemType>> m_workspace;          // scratch space used by the RNN implementation
    shared_ptr<Matrix<ElemType>> m_reserve;            // state saved during the forward pass for reuse in backprop
    shared_ptr<Matrix<ElemType>> m_packingIndex;       // index map used to pack/unpack variable-length sequences (see PackSequencesForCuDNN)

private:
    // Transposes matX (interpreted with shape shapeX) into matY; shapeY receives the transposed shape.
    void TransposeHelper(const MatrixBasePtr matX, const TensorShape &shapeX, MatrixBasePtr matY, TensorShape &shapeY);

    // Pack variable-length sequences from the minibatch into the dense layout expected by cuDNN,
    // and unpack the result back; numSequencesForFrame receives the number of sequences present in each frame.
    void PackSequencesForCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst, vector<size_t>& numSequencesForFrame);
    void UnpackSequencesFromCuDNN(const Matrix<ElemType>& src, Matrix<ElemType>& dst);

    RnnAttributes m_rnnAttributes; // RNN configuration: recurrent op, number of layers, hidden size, directionality

    bool m_legacySwapInputsPending = false; // to support an internal legacy version
};
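
// A minimal construction sketch (hypothetical -- wiring the node's two inputs and adding it to a
// ComputationNetwork is done by the surrounding network-building code; the recurrentOp string is
// assumed here to be one understood by RnnAttributes, e.g. L"lstm"):
//
//   auto rnn = std::make_shared<OptimizedRNNStackNode<float>>(
//       deviceId, L"rnn", /*bidirectional=*/true, /*numLayers=*/2, /*hiddenSize=*/512, L"lstm");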

}}}