// Revision e9396480025b9ca457d26b6f33dd07c474c6aa04 authored by liqunfu on 31 March 2020, 15:55:14 UTC, committed by GitHub on 31 March 2020, 15:55:14 UTC
// 1 parent e1467a7
// Raw File
// CNTKEval.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// CNTKEval.h - Include file for the CNTK Evaluation DLL
// 
// NOTICE: This is a public interface for evaluating models in CNTK.
//         Changes to this interface may affect other projects, such as Argon and LatGen,
//         and therefore need to be communicated to those groups.
//
#pragma once

#include <string>
#include <map>
#include <vector>

#include "Eval.h"
#include "EvalReader.h"
#include "EvalWriter.h"

#include "ComputationNetwork.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// CNTKEvalBase - state and entry points shared by the evaluation classes declared in this header.
// It holds the configuration object and the network pointer. The constructor is protected, so
// instances can only be created through a derived class. The member functions are declared here
// and defined outside this header.
template <typename ElemType>
class CNTKEvalBase : public IEvaluateModelBase<ElemType>
{
protected:
    using ComputationNodePtr = shared_ptr<ComputationNode<ElemType>>;

    ConfigParameters m_config;   // configuration object; not touched by any code in this header
    ComputationNetworkPtr m_net; // network pointer; null after construction

    // Construction leaves the object without a network.
    CNTKEvalBase()
        : m_net(nullptr)
    {
    }

public:
    // CreateNetwork - create a network based on the network description
    // networkDescription - network description
    virtual void CreateNetwork(const std::string& networkDescription);

    // Init - takes a configuration string.
    // NOTE(review): presumably this populates m_config; confirm against the definition.
    virtual void Init(const std::string& config);

    // Destroy - tear-down entry point; see the definition for what is released.
    virtual void Destroy();
};

// ------------------------------------------------------------------------
// Basic interface
// ------------------------------------------------------------------------
// CNTKEval - basic evaluation interface.
// Combines the shared CNTKEvalBase state with the IEvaluateModel interface. Inputs and outputs are
// exchanged as maps from a wide-string name to a pointer to a vector of ElemType.
template <typename ElemType>
class CNTKEval : public CNTKEvalBase<ElemType>, public IEvaluateModel<ElemType>
{
    EvalReader<ElemType>* m_reader;              // null after construction; assigned outside this header
    EvalWriter<ElemType>* m_writer;              // null after construction; assigned outside this header
    std::map<std::wstring, size_t> m_dimensions; // name -> size map; filled outside this header
    size_t m_start;                              // 0/1 flag: 0 after construction and Init(), flipped by ResetState()
public:
    // m_start is initialized here as well as in Init(): ResetState() reads it, and without this a
    // call to ResetState() before Init() would read an indeterminate value (undefined behaviour).
    CNTKEval() : CNTKEvalBase<ElemType>(), m_reader(nullptr), m_writer(nullptr), m_start(0) {}

    // GetNodeDimensions - presumably reports the dimensions of the nodes in nodeGroup through the
    // 'dimensions' map (TODO confirm against the definition).
    virtual void GetNodeDimensions(std::map<std::wstring, size_t>& dimensions, NodeGroup nodeGroup);

    // StartEvaluateMinibatchLoop - takes the name of an output node; defined outside this header.
    virtual void StartEvaluateMinibatchLoop(const std::wstring& outputNodeName);

    // Evaluate - overload taking both an input map and an output map (name -> vector pointer).
    virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& inputs, std::map<std::wstring, std::vector<ElemType>*>& outputs);

    // Evaluate - overload taking only an output map (name -> vector pointer).
    virtual void Evaluate(std::map<std::wstring, std::vector<ElemType>*>& outputs);

    // Destroy - tear-down entry point; defined outside this header.
    virtual void Destroy() override;

    // CreateNetwork - forwards to the CNTKEvalBase implementation.
    virtual void CreateNetwork(const std::string& networkDescription) override
    {
        CNTKEvalBase<ElemType>::CreateNetwork(networkDescription);
    }
    
    // Init - forwards to the CNTKEvalBase implementation, then resets m_start to 0.
    virtual void Init(const std::string& config) override
    {
        CNTKEvalBase<ElemType>::Init(config);
        m_start = 0;
    }

    // ResetState - flips m_start between 0 and 1.
    // NOTE(review): how m_start is consumed is not visible in this header; confirm in the implementation.
    virtual void ResetState() override
    {
        m_start = 1 - m_start;
    }
};



// ------------------------------------------------------------------------
// Extended interface
// ------------------------------------------------------------------------
// CNTKEvalExtended - extended evaluation interface.
// Combines the shared CNTKEvalBase state with the IEvaluateModelExtended interface. Values are
// exchanged through Values<ElemType> / ValueRefs<ElemType> containers rather than name-keyed maps.
template <typename ElemType>
class CNTKEvalExtended : public CNTKEvalBase<ElemType>, public IEvaluateModelExtended<ElemType>
{
public:
    // A new object is marked as not started.
    CNTKEvalExtended()
        : CNTKEvalBase<ElemType>(), m_started(false)
    {
    }

    // GetOutputSchema - returns a VariableSchema; presumably describes the network outputs (TODO confirm).
    virtual VariableSchema GetOutputSchema() const override;

    // StartForwardEvaluation - takes the list of output names; defined outside this header.
    virtual void StartForwardEvaluation(const std::vector<wstring>& outputs) override;

    // GetInputSchema - returns a VariableSchema; presumably describes the network inputs (TODO confirm).
    virtual VariableSchema GetInputSchema() const override;

    // ForwardPass - four overloads: Values or ValueRefs containers, each with or without an
    // explicit resetRNN flag.
    virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output) override;

    virtual void ForwardPass(const Values<ElemType>& inputs, Values<ElemType>& output, bool resetRNN) override;

    virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output) override;

    virtual void ForwardPass(const ValueRefs<ElemType>& inputs, ValueRefs<ElemType>& output, bool resetRNN) override;

    // Destroy - tear-down entry point; defined outside this header.
    virtual void Destroy() override;

    // CreateNetwork - forwards to the CNTKEvalBase implementation.
    virtual void CreateNetwork(const std::string& networkDescription) override
    {
        CNTKEvalBase<ElemType>::CreateNetwork(networkDescription);
    }

    // Init - forwards to the CNTKEvalBase implementation.
    virtual void Init(const std::string& config) override
    {
        CNTKEvalBase<ElemType>::Init(config);
    }

private:
    // ToVariableLayout - builds a VariableLayout from a computation node pointer.
    static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n);

    std::vector<ComputationNodeBasePtr> m_outputNodes;
    std::shared_ptr<ScopedNetworkOperationMode> m_scopedNetworkOperationMode;
    std::vector<ComputationNodeBasePtr> m_inputNodes;
    StreamMinibatchInputs m_inputMatrices;
    bool m_started; // false after construction; set elsewhere (presumably by StartForwardEvaluation - TODO confirm)

    // ForwardPassT - helper templated on the value container type.
    // NOTE(review): presumably the shared implementation behind the ForwardPass overloads; confirm in the definition.
    template <template <typename> class ValueContainer>
    void ForwardPassT(const std::vector<ValueBuffer<ElemType, ValueContainer>>& inputs,
                      std::vector<ValueBuffer<ElemType, ValueContainer>>& outputs,
                      bool resetRNN);
};
} } }
// back to top