// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // // CNTKEval.h - Include file for the CNTK Evaluation DLL // // NOTICE: This interface is a public interface for evaluating models in CNTK. // Changes to this interface may affect other projects, such as Argon and LatGen, // and therefore need to be communicated with such groups. // #pragma once #include #include #include #include "Eval.h" #include "EvalReader.h" #include "EvalWriter.h" #include "ComputationNetwork.h" namespace Microsoft { namespace MSR { namespace CNTK { template class CNTKEvalBase : public IEvaluateModelBase { protected: typedef shared_ptr> ComputationNodePtr; ConfigParameters m_config; ComputationNetworkPtr m_net; // constructor CNTKEvalBase() : m_net(nullptr) { } public: // CreateNetwork - create a network based on the network description // networkDescription - network description virtual void CreateNetwork(const std::string& networkDescription); virtual void Init(const std::string& config); virtual void Destroy(); }; // ------------------------------------------------------------------------ // Basic interface // ------------------------------------------------------------------------ template class CNTKEval : public CNTKEvalBase, public IEvaluateModel { EvalReader* m_reader; EvalWriter* m_writer; std::map m_dimensions; size_t m_start; public: CNTKEval() : CNTKEvalBase(), m_reader(nullptr), m_writer(nullptr) {} virtual void GetNodeDimensions(std::map& dimensions, NodeGroup nodeGroup); virtual void StartEvaluateMinibatchLoop(const std::wstring& outputNodeName); virtual void Evaluate(std::map*>& inputs, std::map*>& outputs); virtual void Evaluate(std::map*>& outputs); virtual void Destroy() override; virtual void CreateNetwork(const std::string& networkDescription) override { CNTKEvalBase::CreateNetwork(networkDescription); } virtual void Init(const std::string& config) override { CNTKEvalBase::Init(config); m_start = 0; } virtual void ResetState() override { m_start = 1 - m_start; } }; // ------------------------------------------------------------------------ // Extended interface // ------------------------------------------------------------------------ template class CNTKEvalExtended : public CNTKEvalBase, public IEvaluateModelExtended { virtual VariableSchema GetOutputSchema() const override; virtual void StartForwardEvaluation(std::vector outputs) override; virtual VariableSchema GetInputSchema() const override; virtual void ForwardPass(const Variables& inputs, Variables& output) override; virtual void Destroy() override; virtual void CreateNetwork(const std::string& networkDescription) override { CNTKEvalBase::CreateNetwork(networkDescription); } virtual void Init(const std::string& config) override { CNTKEvalBase::Init(config); } private: static VariableLayout ToVariableLayout(const ComputationNodeBasePtr n); std::vector m_outputNodes; std::shared_ptr m_scopedNetworkOperationMode; std::vector m_inputNodes; StreamMinibatchInputs m_inputMatrices; }; } } }