// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #pragma once #include "stdafx.h" #include "CNTKLibrary.h" #include "PrimitiveFunction.h" #include "ComputationNetwork.h" #include "BackCompat.h" namespace CNTK { class CNTKBackPropState final : public BackPropState { public: CNTKBackPropState(const FunctionPtr& function, const DeviceDescriptor& computeDevice, const std::unordered_map& backpropRootsForwardTimeStamps) : BackPropState(function, computeDevice), m_backpropRootsForwardTimeStamps(backpropRootsForwardTimeStamps) {} const std::unordered_map& BackpropRootsForwardTimeStamps() const { return m_backpropRootsForwardTimeStamps; } private: std::unordered_map m_backpropRootsForwardTimeStamps; }; typedef std::shared_ptr CNTKBackPropStatePtr; class CompositeFunction; typedef std::shared_ptr CompositeFunctionPtr; /// /// Represents a symbolic computation with zero or more input arguments and one or more outputs. /// Opposed to primitive functions, a composite function is composed of other Function instances whose inputs and outputs are wired together. /// CompositeFunction is also responsible for breaking the loop in case of cyclic graphs - it stores the pointers for to the child primitive /// functions and controls their lifetime. /// CompositeFunction class inherits thus from Function. 
///
class CompositeFunction final : public Function
{
    // NOTE(review): as elsewhere in this copy, template parameter/argument lists appear stripped
    // (e.g. "template friend ...", "template static void ...", "std::unordered_set visitedFunctions",
    // and stray '>' tokens in a few declarations). Restore them from upstream before compiling.

    friend class Function;
    friend class Trainer;
    friend class CompositeMinibatchSource;
    friend class PackedValue;

    // Lets MakeSharedObject reach the private constructor that Create() relies on.
    template friend inline std::shared_ptr MakeSharedObject(CtorArgTypes&& ...ctorArgs);

    friend void Internal::SaveAsLegacyModel(const FunctionPtr& rootFunction, const std::wstring& modelFile);

    friend void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource,
                                                     std::unordered_map>& computedMeanAndInvStdDevs,
                                                     const DeviceDescriptor& device /*= DeviceDescriptor::CPUDevice()*/);

    // Counter that supplies the numeric suffix for NextAutoGeneratedDynamicAxis(); defined out of line.
    static std::atomic s_nextAutoGeneratedDynamicAxis;

    // Value returned by OpName(); defined out of line.
    static const std::wstring CompositeFunctionOpName;

public:
    /// Returns a new Axis named L"autoGeneratedDynamicAxis_" followed by the counter's current value.
    /// The counter is post-incremented on a std::atomic, so each call receives a distinct suffix.
    static Axis NextAutoGeneratedDynamicAxis()
    {
        static const std::wstring s_autoGeneratedDynamicAxisNamePrefix = L"autoGeneratedDynamicAxis_";
        return Axis(s_autoGeneratedDynamicAxisNamePrefix + std::to_wstring(s_nextAutoGeneratedDynamicAxis++));
    }

public:
    /// Creates a CompositeFunction whose root is 'rootFunction'.
    /// Every function found by Collect() from 'rootFunction' is moved into the new composite, which holds
    /// strong references to them (m_allPrimitiveFunctions); the composite's outputs are initialized before returning.
    static CompositeFunctionPtr Create(const FunctionPtr& rootFunction, const std::wstring& name = L"", const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
    {
        std::unordered_set visitedFunctions;

        // Call Collect to get the set of all functions in the graph
        Collect(rootFunction, visitedFunctions);

        auto composite = MakeSharedObject(rootFunction, std::move(visitedFunctions), name, uid);

        // Initialize the outputs
        composite->InitOutputs();
        return composite;
    }

    // Forward pass of the composite; defined out of line.
    // NOTE(review): parameter roles are inferred from their names -- 'arguments' supplies input values,
    // 'outputs' receives computed values, 'outputsToRetainBackwardStateFor' selects the backprop roots and
    // 'inputsToExcludeGradientsFor' lists inputs needing no gradient; confirm against the definition.
    BackPropStatePtr Forward(const std::unordered_map& arguments,
                             std::unordered_map& outputs,
                             const DeviceDescriptor& computeDevice,
                             const std::unordered_set& outputsToRetainBackwardStateFor,
                             const std::unordered_set& inputsToExcludeGradientsFor);

    // The Function-level Forward override is unsupported on a composite: it unconditionally hits
    // NOT_IMPLEMENTED. Use the overload above, which takes the argument map explicitly.
    virtual BackPropStatePtr Forward(const std::vector& /*inputValues*/,
                                     std::unordered_map& /*outputs*/,
                                     const DeviceDescriptor& /*computeDevice*/,
                                     const std::unordered_set& /*outputsToRetainBackwardStateFor*/) override
    {
        NOT_IMPLEMENTED;
    }

    // A composite's outputs are its root function's outputs, copied into 'outputs'.
    void InferOutputs(std::vector& outputs) override
    {
        auto& inferred = m_rootFunction->InitOutputs();
        outputs.assign(inferred.begin(), inferred.end());
    }

    // Backward pass; defined out of line.
    virtual void Backward(const BackPropStatePtr& state,
                          const std::unordered_map& rootGradientValues,
                          std::unordered_map& backPropagatedGradientValuesForInputs) override;

    // Serialization entry points; defined out of line.
    Dictionary SerializeBlockComposite() const;

    virtual Dictionary Serialize() const override;

    // Returns s_serializationVersion (its version history is recorded at the end of this class).
    virtual size_t CurrentVersion() const override { return s_serializationVersion; }

    static FunctionPtr DeserializeBlockComposite(const Dictionary& dict,
                                                 const std::unordered_set& allPrimitiveFunctions,
                                                 const std::unordered_map& allPlaceholderReplacements,
                                                 const CNTK::DeviceDescriptor& device);

    static FunctionPtr Deserialize(const Dictionary& dictionary, const CNTK::DeviceDescriptor& device);

    virtual const std::wstring& OpName() const override
    {
        return CompositeFunctionOpName;
    }

    // Convenience overload: traverses the graph under 'rootFunction' starting from an empty visited set.
    template static void PreorderTraverseVariables(const FunctionPtr& rootFunction, const FunctionType& functor, bool pythonOperandOrder = false)
    {
        std::unordered_set visitedFunctions;
        PreorderTraverseVariables(rootFunction, visitedFunctions, functor, pythonOperandOrder);
    }

    // Recursively traverses the Function graph underlying the 'rootFunction' invoking the provided functor for all visited nodes in the graph.
    // Per function (pre-order): record it in 'visitedFunctions', call 'functor' on each of its outputs, then on
    // each input returned by Inputs(pythonOperandOrder), recursing into the owner of any Output input whose
    // owner has not been visited yet.
    // Each function is expanded at most once, but 'functor' may receive the same Variable several times
    // (an output is reported by its owner and again as an input of every consumer), so callers that need
    // unique variables must deduplicate themselves -- as DetermineInputs does.
    // NOTE(review): recursion depth grows with graph depth; very deep graphs could exhaust the stack.
    // NOTE(review): 'auto rootFunctionOutputs' copies the result of InitOutputs() (InferOutputs binds the same
    // call with auto&); a const reference may suffice -- confirm 'functor' cannot invalidate it.
    template static void PreorderTraverseVariables(const FunctionPtr& rootFunction, std::unordered_set& visitedFunctions, const FunctionType& functor, bool pythonOperandOrder = false)
    {
        visitedFunctions.insert(rootFunction);
        auto rootFunctionOutputs = rootFunction->InitOutputs();
        for (const auto& rootOutput : rootFunctionOutputs)
            functor(rootOutput);

        auto rootFunctionInputs = rootFunction->Inputs(pythonOperandOrder);
        for (const auto& rootInput : rootFunctionInputs)
        {
            functor(rootInput);
            if (rootInput.IsOutput() && visitedFunctions.find(rootInput.Owner()) == visitedFunctions.end())
            {
                const auto& function = rootInput.Owner();
                PreorderTraverseVariables(function, visitedFunctions, functor, pythonOperandOrder);
            }
        }
    }

private:
    // Override invoked after PlaceHolder variables in the graph of 'this' CompositeFunction have been replaced
    // (the caller is not visible here). It performs no replacement itself: for each replaced placeholder whose
    // replacement is an Output variable, it collects the graph under that variable's owner and adds those
    // functions to 'm_allPrimitiveFunctions', so 'this' composite keeps them alive.
    // All PlaceHolder variables should have been replaced before performing any Forward compute of 'this' Function.
    virtual void OnPlaceholdersReplaced(const std::unordered_map& placeholderReplacements,
                                        std::unordered_set& replacedPlaceholders) override
    {
        // If any of the placeholders were replaced with Output variables, let's add the graph of function underneath
        // each of those to 'm_allPrimitiveFunctions' set
        for (auto replacedPlaceholder : replacedPlaceholders)
        {
            // at() throws std::out_of_range if a replaced placeholder has no entry in 'placeholderReplacements'.
            auto replacingVariable = placeholderReplacements.at(replacedPlaceholder);
            if (replacingVariable.IsOutput())
            {
                auto ownerFunc = replacingVariable.Owner();
                std::unordered_set visitedFunctions2;
                Collect(ownerFunc, visitedFunctions2);

                // Add the newly visited functions to 'm_allPrimitiveFunctions' set
                m_allPrimitiveFunctions.insert(visitedFunctions2.begin(), visitedFunctions2.end());
            }
        }
    }

    // Private; Create() reaches it through MakeSharedObject (a friend). Takes ownership of the pre-collected
    // set of graph functions. The Function base receives an empty input list ('{}'), an empty Dictionary,
    // 'rootFunction', 'name' and 'uid'. Network matrices start out unallocated.
    CompositeFunction(const FunctionPtr& rootFunction, std::unordered_set&& allPrimitiveFunctions, const std::wstring& name, const std::wstring& uid = Internal::GenerateUid(L"CompositeFunction"))
        : Function({}, Dictionary(), rootFunction, name, uid),
          m_allPrimitiveFunctions(std::move(allPrimitiveFunctions)),
          m_networkMatricesAllocated(false)
    {}

    // Leaves (non-Output variables) of the graph under RootFunction(), deduplicated, in first-visit order.
    std::vector DetermineInputs(bool pythonOperandOrder = false) const
    {
        const auto& root = RootFunction();
        std::unordered_set visitedFunctions;
        return DetermineInputs(root, visitedFunctions, pythonOperandOrder);
    }

    // Recursively traverses the Function graph and populates the provided set of functions.
    // The work is delegated to PreorderTraverseFunctions (declared elsewhere) with a no-op functor.
    // NOTE(review): this relies on PreorderTraverseFunctions inserting every reachable function into
    // 'functions' as its visited set -- confirm against its definition.
    static void Collect(const FunctionPtr& rootFunction, std::unordered_set& functions)
    {
        // Call PreorderTraverseFunctions to get the set of all functions in the graph
        PreorderTraverseFunctions(rootFunction, functions, [](const FunctionPtr& f){});
    }

    // Recursively traverses the Function graph underlying the 'rootFunction' to determine all the leaves (aka inputs) of the graph.
    // A leaf is any Variable whose IsOutput() is false; each is reported once, in the order the pre-order
    // traversal first meets it. 'visitedFunctions' is updated by the traversal.
    static std::vector DetermineInputs(const FunctionPtr& rootFunction, std::unordered_set& visitedFunctions, bool pythonOperandOrder = false)
    {
        // NOTE(review): 'functions' is never used (and lacks the std:: qualifier); it can be removed.
        vector functions;
        std::vector inputs;
        std::unordered_set uniqueInputs;
        PreorderTraverseVariables(rootFunction, visitedFunctions, [&inputs, &uniqueInputs](const Variable& var) {
            // The set filters repeats reported by the traversal; the vector preserves discovery order.
            if (!var.IsOutput() && uniqueInputs.find(var) == uniqueInputs.end())
            {
                inputs.push_back(var);
                uniqueInputs.insert(var);
            }
        }, pythonOperandOrder);

        return inputs;
    }

    // If the network is already created, copy internal state over from the functions in the graph into the underlying network.
    void UpdateInternalNetworkState();

    // Copy state info from source function graph into 'this' function graph.
    void CopyState(const CompositeFunction& source);

    // Variable-mapping helpers; defined out of line.
    static Variable GetMappingForNoOpOutput(const Variable& variable, bool recursive = false);
    static Variable GetMappingVariable(const Variable& variable, bool recursive = false);

    // Obtains the ComputationNetwork for 'this' composite; defined out of line.
    // NOTE(review): presumably builds it on first use and caches it in m_computationNetwork -- confirm.
    template Microsoft::MSR::CNTK::ComputationNetworkPtr GetComputationNetwork(const DeviceDescriptor& device,
                                                                              const std::unordered_set& backpropRoots,
                                                                              const std::unordered_set& outputs,
                                                                              const std::unordered_set& inputsToExcludeGradientsFor,
                                                                              bool allocateNetworkMatrices);

    // Helpers that create/look up ComputationNodes for Variables while building the network; defined out of line.
    template static Microsoft::MSR::CNTK::ComputationNodeBasePtr CreateComputationNode(const Variable& variable,
                                                                                      Function* function,
                                                                                      const std::vector>>& inputNodes,
                                                                                      Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                                      std::unordered_map& variableToNodeMap);

    template static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetOutputVariableNode(const Variable& variable,
                                                                                      Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                                      Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder,
                                                                                      std::unordered_map& variableToNodeMap,
                                                                                      std::unordered_map& isVariableRootMap,
                                                                                      const std::unordered_set& inputsToExcludeGradientsFor);

    template static Microsoft::MSR::CNTK::ComputationNodeBasePtr GetNode(const Variable& variable,
                                                                        Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                        Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder,
                                                                        std::unordered_map& variableToNodeMap,
                                                                        std::unordered_map& isVariableRootMap,
                                                                        const std::unordered_set& inputsToExcludeGradientsFor);

    // Judging by their names, these move values/gradients between Value objects and network nodes;
    // all are defined out of line.
    template static void PopulateComputationNodeValue(const std::pair& variableValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, std::unordered_map< Microsoft::MSR::CNTK::MBLayoutPtr, Variable>& layoutsPopulated);
    void PopulateNetworkInputs(const std::unordered_map& arguments);

    template static void PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode);
    void PopulateNetworkGradients(const std::unordered_map& gradients);

    static void GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient);
    void GetNetworkOutputs(std::unordered_map& outputs);
    void GetNetworkGradients(std::unordered_map& gradients);

    // Remove cyclic references for composite nodes
    static std::unordered_set NonOwnerPreservingCopy(const std::unordered_set& outputs);

    // NOTE(review): presumably served from m_perOutputVarArgumentDependencies -- confirm in the definition.
    const std::vector& GetArgumentDependencies(const Variable& output);

    // NOTE(review): presumably returns the current time stamps of m_currentBackpropRoots -- confirm.
    std::unordered_map GetCurrentBackpropRootsTimeStamps() const;

private:
    // Set of all primitive functions in the graph underlying 'this' Function. Also keeps the primitive Function objects alive
    // by holding strong references to them
    std::unordered_set m_allPrimitiveFunctions;

    // A map from Variable objects to ComputationNode objects in the ComputationNetwork instance that implements 'this' Composite Function
    std::unordered_map m_variableToNodeMap;

    // A map that tells whether a Variable in the graph underlying 'this' Function is a root of the graph
    std::unordered_map m_isVariableRootMap;

    // The ComputationNetwork instance that implements 'this' Composite Function.
    Microsoft::MSR::CNTK::ComputationNetworkPtr m_computationNetwork;

    // The backpropRoots specified in the most recent 'Forward' call on 'this' Function.
    // This indicates for which of its roots 'this' Function has retained the intermediate
    // state from the previous Forward call that is needed to backpropagate gradients in
    // the next 'Backward' call.
    std::unordered_set m_currentBackpropRoots;

    std::unordered_map> m_perOutputVarArgumentDependencies;

    // Initialized to false by the constructor.
    bool m_networkMatricesAllocated;

    std::vector m_allNetworkRootsInGlobalEvalOrder;

    std::unordered_map m_lastRecordedParameterValueTimeStamps;

    std::unordered_set m_inputsExcludedFromGradientComputation;

    // Version history:
    // 1 -- initial version.
    // 2 -- add support for stateful functions (with corresponding nodes inheriting from RngUser).
    static const size_t s_serializationVersion = 2;
};
} // namespace CNTK