//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms -- add this at the top of all CPP files that give "function or variable may be unsafe" warnings

#include "Basics.h"
#include "ComputationNode.h"
#include "ComputationNetwork.h"
#include "RecurrentNodes.h"
#include "InputAndParamNodes.h"
#include "LinearAlgebraNodes.h"
#include "SpecialPurposeNodes.h"

#include <string>
#include <vector>
#include <list>
#include <set>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <algorithm>

using namespace std;

namespace Microsoft { namespace MSR { namespace CNTK {

// This source file contains methods related to evaluation (forward prop, backprop), network validation, and matrix memory allocation (memory sharing).

// -----------------------------------------------------------------------
// forward and backward propagation
// -----------------------------------------------------------------------

// MAIN ENTRY POINT for evaluating one minibatch (forward prop)
// This calls ForwardProp() on all nodes in order of data flow through the network.
// By default, the network is applied concurrently on all frames in a minibatch in parallel (PAR mode, a "map" operation).
// Recurrent loops must be treated differently:
//  - a recurrent loop is the loop of nodes that make up the computation for one time step (e.g. Times -> Plus -> Sigmoid -> Delay)
//  - these must be executed frame by frame (SEQuentially) rather than as a map
//  - such a loop is treated as if it were a little nested network; this is done inside SEQTraversalFlowControlNodes
//  - these little nested networks are defined in the execution network in the form of nested sentinel nodes of type SEQTraversalFlowControlNode
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode)
{
    VerifyIsCompiled("ForwardProp");

    // traverse all nodes in the pre-determined evaluation order
    GetNestedNetwork(rootNode)->ForwardProp(FrameRange(nullptr));
}

void ComputationNetwork::PostForwardAndBackProp(const ComputationNodeBasePtr rootNode)
{
    VerifyIsCompiled("PostForwardAndBackProp");

    // traverse all nodes in the pre-determined evaluation order
    GetNestedNetwork(rootNode)->PostForwardAndBackProp();
}

// set the gradient matrix of a (root) node to 1.0
// Returns false if the node is not a ComputationNode<ElemType>; see Backprop() below for intended use.
template <class ElemType>
static bool SetRootGradientToScalarOne(ComputationNodeBasePtr nodep)
{
    auto node = dynamic_pointer_cast<ComputationNode<ElemType>>(nodep);
    bool hasMatchingType = (node != nullptr);
    if (hasMatchingType)
    {
        // reset the root gradient to 1
        node->ResetGradient(1);
    }
    return hasMatchingType;
}

// MAIN ENTRY POINT for evaluation followed by gradient computation (forward prop then back prop)
// The typical calling pattern is:
//  - ForwardProp() for eval nodes
//  - ForwardProp() for the training criterion (which will reuse computation results from the previous step)
//  - Backprop() for the training criterion
void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // training criterion to compute the gradients for
{
    if (!Environment().IsTraining())
        LogicError("Backprop: Requires the network to be in training mode.");

    // initialize root gradient with a scalar value of 1.0
    if (!SetRootGradientToScalarOne<float>(rootNode) && !SetRootGradientToScalarOne<double>(rootNode))
        LogicError("Backprop: Training criterion is neither ComputationNode<float> nor ComputationNode<double>.");

    // reset all gradients below rootNode to zero (actually, internally, this is lazy, but we don't care here)
    ZeroInputGradients(rootNode);

    // backpropagate through the network
    GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
}

void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
{
    if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
        fprintf(stderr, "FormNestedNetwork: WARNING: Was called twice for %ls %ls operation\n", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());

    m_nestedNetworks[rootNode] = make_shared<PARTraversalFlowControlNode>(m_allSEQNodes, GetEvalOrder(rootNode));
}

ComputationNodeBasePtr ComputationNetwork::GetNestedNetwork(const ComputationNodeBasePtr& rootNode)
{
    if (m_nestedNetworks.find(rootNode) == m_nestedNetworks.end())
        LogicError("GetNestedNetwork: Called without prior call to FormNestedNetwork() for %ls %ls operation", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());
    return m_nestedNetworks[rootNode];
}
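// Illustrative sketch (not part of the build): the calling pattern described above, assuming
// 'net' is a compiled ComputationNetwork and 'evalNode'/'criterionNode' are root nodes obtained
// e.g. from EvaluationNodes() and FinalCriterionNodes(). The names in this sketch are
// hypothetical; only ForwardProp() and Backprop() are the entry points defined in this file.
#if 0
static void ExampleTrainingStep(ComputationNetwork& net,
                                const ComputationNodeBasePtr& evalNode,
                                const ComputationNodeBasePtr& criterionNode)
{
    net.ForwardProp(evalNode);      // forward prop for the eval metric
    net.ForwardProp(criterionNode); // reuses sub-results already computed above (time-stamp mechanism)
    net.Backprop(criterionNode);    // sets the root gradient to 1 and back-propagates to all learnable parameters
}
#endif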
// -----------------------------------------------------------------------
// PARTraversalFlowControlNode methods -- implements PAR traversal
//
// This implements an outer loop over non-recurrent nodes, where each node can be
// executed in PAR mode; that is, all samples are independent and allow for
// concurrent computation in bulk CUDA launches.
// -----------------------------------------------------------------------

static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient);

ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(const std::vector<shared_ptr<SEQTraversalFlowControlNode>>& recurrentInfo, const std::list<ComputationNodeBasePtr>& allNodes /*must be in eval order*/)
{
    // traverse the network in evaluation order and create a new list that replaces all recurrence by a SEQTraversalFlowControlNode
    set<shared_ptr<SEQTraversalFlowControlNode>> loopsSeen; // for consistency check only
    for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end();)
    {
        shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(recurrentInfo, *nodeIter); // check if this node participates in a recurrent loop
        if (recInfo) // node is part of a SEQ loop: gather all of them. The nodes must be consecutive in 'allNodes'
        {
            // instead of the node itself, include the sentinel SEQTraversalFlowControlNode in our list
            m_nestedNodes.push_back(recInfo);
            // and verify that we only encountered the loop once (all nodes should have been consecutive)
            if (!loopsSeen.insert(recInfo).second)
                LogicError("PARTraversalFlowControlNode: members of loop %ls are not consecutive in node list.", recInfo->NodeName().c_str());
            // consume all nodes that are part of the same loop (they are all consecutive)
            while (nodeIter != allNodes.end() && (*nodeIter)->IsPartOfLoop() && FindInRecurrentLoops(recurrentInfo, *nodeIter) == recInfo)
                nodeIter++;
        }
        else // regular top-level node (non-looping, PAR)
        {
            m_nestedNodes.push_back(*nodeIter);
            nodeIter++; // and consume this node
        }
    }
}

/*static*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const ComputationNodeBasePtr& node, const FrameRange& fr)
{
    if (node->IsOutOfDateWrtInputs())
    {
        node->BeginForwardProp();
        node->BeginTiming(false /*backward*/);
        node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
        node->EndTiming(false /*backward*/);
        node->EndForwardProp();
        node->BumpEvalTimeStamp();

        // Extreme Tracing, part 1/4
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode())
            DumpNode(node, /*dumpGradient=*/false);
    }
}

/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange& fr) /*override*/
{
    for (auto& node : m_nestedNodes)
        ForwardProp(node, fr);
}

/*static*/ void ComputationNetwork::PARTraversalFlowControlNode::PostForwardAndBackProp(const ComputationNodeBasePtr& node)
{
    node->PostForwardAndBackProp();
}

/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::PostForwardAndBackProp() /*override*/
{
    for (auto& node : m_nestedNodes)
        PostForwardAndBackProp(node);
}

/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{
    childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
    // process nodes in pre-determined order
    for (auto pnode = m_nestedNodes.rbegin(); pnode != m_nestedNodes.rend(); pnode++) // iterate backwards over evaluation order
    {
        auto& node = *pnode;
        node->BeginBackprop();
        node->BeginTiming(true /*backward*/);
        node->Backprop(fr.WithLayout(node->GetMBLayout()), true /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
        node->EndTiming(true /*backward*/);
        node->EndBackprop();

        // Extreme Tracing, part 2/4
        // Note: the cast below assumes single-precision (float) nodes; this NaN check is debugging code.
        if (node->IsParameterUpdateRequired() && node->NeedsGradient() &&
            dynamic_pointer_cast<ComputationNode<float>>(node)->Gradient().HasNan("Gradient/UpdateWeights(): "))
            DumpNode(node, /*dumpGradient=*/true);
    }
}

/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{
}
template <class ElemType>
bool TypedDumpNode(shared_ptr<ComputationNode<ElemType>> node, bool dumpGradient)
{
    if (!node)
        return false;
    let dataPtr = dumpGradient ? node->GradientPtr() : node->ValuePtr();
    if (!dataPtr)
        return true; // e.g. SEQ sentinel node
    bool concise = false;
    /* chrono::milliseconds ms = chrono::duration_cast<chrono::milliseconds>(chrono::system_clock::now().time_since_epoch());
       string file_path = "D:\\users\\vadimma\\CNTK_MLF\\CNTK\\Tests\\EndToEndTests\\Speech\\Data\\" + to_string(ms.count()) + ".txt";
       FILE* pFile;
       pFile = fopen(file_path.c_str(), "r"); */
    fprintf(stderr, "Dump --> %s%s\n", node->FormatOperationPrototype("").c_str(), dumpGradient ? " Grad" : "");
    node->WriteMinibatchWithFormatting(stderr, FrameRange(), SIZE_MAX, SIZE_MAX, false /*transpose*/, /*isCategoryLabel=*/false, /*isSparse=*/false, std::vector<std::string>(),
                                       "" /*sequenceSeparator*/, " " /*sequencePrologue*/, "\n" /*sequenceEpilogue*/, " " /*elementSeparator*/, "\n " /*sampleSeparator*/,
                                       "%13.10f" /*valueFormatString*/, dumpGradient, concise);
    /* fclose(pFile); */
    return true;
}

// helper for logging. Returns false if it was not able to dump
static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
{
    let nodef = dynamic_pointer_cast<ComputationNode<float>>(nodep);
    if (nodef)
        return TypedDumpNode(nodef, dumpGradient);
    let noded = dynamic_pointer_cast<ComputationNode<double>>(nodep);
    if (noded)
        return TypedDumpNode(noded, dumpGradient);
    let nodeh = dynamic_pointer_cast<ComputationNode<half>>(nodep);
    if (nodeh)
        return TypedDumpNode(nodeh, dumpGradient);
    return false;
}
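// Conceptual sketch (not part of the build) of what the SEQ traversal implemented in the next
// section does: instead of running each node once over all frames (PAR), a recurrent loop is run
// frame by frame, with every node of the loop executed per time step. 'ExampleSeqForward' and
// 'loopNodes' are hypothetical names used only for this illustration.
#if 0
static void ExampleSeqForward(const std::vector<ComputationNodeBasePtr>& loopNodes /*in eval order*/,
                              const MBLayoutPtr& layout, int steppingDirection)
{
    FrameRangeIteration range(layout, steppingDirection);
    for (auto t = range.begin(); t != range.end(); t++) // outer loop over time steps
        for (auto& node : loopNodes)                    // inner loop over the little nested network
            node->ForwardProp(t);                       // e.g. Times -> Plus -> Sigmoid -> Delay
}
#endif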
// -----------------------------------------------------------------------
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)
//
// While PAR mode processes all samples in the MB independently, and thus in
// PARallel, SEQ mode is to honor sequential dependencies. As such, it
// unrolls the loop over time steps and runs the network once per time step.
// -----------------------------------------------------------------------

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::BeginForwardProp() /*override*/
{
    // take the opportunity to check that the layout is shared by all nodes in the loop
    // TODO: we should do this in a constructor.
    for (auto& node : m_nestedNodes)
    {
        if (node->GetMBLayout() != GetMBLayout())
            LogicError("Evaluate: All nodes inside a recurrent loop must have a layout that is identical; mismatch found for nodes '%ls' (%ls) vs. '%ls' (%ls)",
                       node->NodeName().c_str(), node->GetMBLayoutAxisString().c_str(),
                       m_nestedNodes[0]->NodeName().c_str(), m_nestedNodes[0]->GetMBLayoutAxisString().c_str());
    }

    // tell all that the loop is about to commence
    for (auto& node : m_nestedNodes)
        node->BeginForwardProp();
}

// evaluation of a SEQTraversalFlowControlNode FlowControlNode
// This evaluates all nodes in this FlowControlNode in SEQ mode: process the loop frame by frame in a nested loop.
// This is where the time axis changes.
// TODO: Once we do nested loops, the FrameRange argument to this will refer to the outer loop.
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ForwardProp(const FrameRange&) /*override*/
{
    // get layout associated with this loop
    // All nodes share the same layout.
    assert(GetMBLayout() == m_nestedNodes[0]->GetMBLayout());

    // for every time step run through all nodes in this particular loop (treat the loop like a little ComputationNetwork)
    // Note: Currently, this is limited to linear-time loops. But nothing stops the iteration below from, e.g., being a 2D iteration over an image,
    // if we implement an according FrameRangeIteration.
    FrameRangeIteration range(GetMBLayout(), m_steppingDirection);
    for (auto t = range.begin(); t != range.end(); t++)
    {
        for (auto& node : m_nestedNodes)
        {
            node->BeginTiming(false /*backward*/);
            node->ForwardProp(t);
            node->EndTiming(false /*backward*/);
            node->BumpEvalTimeStamp();
        }
    }

    // Extreme Tracing, part 3/4
    for (auto& node : m_nestedNodes)
    {
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode())
        {
            DumpNode(node, /*dumpGradient=*/false);
        }
    }
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndForwardProp() /*override*/
{
    // tell all that the loop is done --e.g. PastValueNode will capture its state for BPTT processing
    for (auto& node : m_nestedNodes)
        node->EndForwardProp();
}

// called before the first iteration step of ComputeGradient()
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::BeginBackprop() /*override*/
{
    for (auto& node2 : m_nestedNodes)
        node2->BeginBackprop();
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::Backprop(const FrameRange&, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{
    childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
    const auto& recurrentNodes = m_nestedNodes; // BUGBUG: -ForForward?? Does this mean we can remove non-ForForward?
    auto pMBLayout = recurrentNodes[0]->GetMBLayout();
    FrameRangeIteration range(pMBLayout, m_steppingDirection);
    for (auto t = range.rbegin(); t != range.rend(); t++) // note: reverse iteration
    {
        for (auto nodeIter2 = recurrentNodes.rbegin(); nodeIter2 != recurrentNodes.rend(); ++nodeIter2)
        {
            auto& node2 = *nodeIter2;
            node2->BeginTiming(true /*backward*/);
            node2->Backprop(t, true /*childrenInThisLoop*/, false /*childrenInOuterLoop*/);
            node2->EndTiming(true /*backward*/);
            // The above flags tell Backprop() to skip back-propagation from inside a node into
            // a node that is outside the loop, which is done later in EndBackprop() in PAR mode.
        }
    }

    // Extreme Tracing, part 4/4
    for (auto& node : m_nestedNodes)
    {
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode() && node->NeedsGradient())
        {
            DumpNode(node, /*dumpGradient=*/true);
        }
    }
}

// called after the last iteration step of ComputeGradient()
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndBackprop() /*override*/
{
    // The following loop handles the case that a node inside the loop back-propagates a gradient into a node outside of the loop.
    // For efficiency, we perform this outside the loop in PAR mode. E.g., in one LSTM speech setup, we measured 12..14% overall speed-up.
    for (auto nodeIter2 = m_nestedNodes.rbegin(); nodeIter2 != m_nestedNodes.rend(); ++nodeIter2)
    {
        auto& node2 = *nodeIter2;
        node2->Backprop(FrameRange(m_nestedNodes[0]->GetMBLayout()), false /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
    }

    // tell all nodes we are done for this iteration
    for (auto& node2 : m_nestedNodes)
        node2->EndBackprop();
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) /*override*/
{
    for (auto& nodeLoopIter : m_nestedNodes)
        nodeLoopIter->RequestMatricesBeforeForwardProp(matrixPool);
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
    // TODO: should we deallocate in opposite order?
    for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
    {
        (*nodeIter)->AllocateGradientMatricesForInputs(matrixPool);
    }
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
{
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{
    for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
    {
        if ((*nodeIter)->NeedsGradient())
            (*nodeIter)->ReleaseMatricesAfterBackprop(matrixPool);
    }
}

// find out whether a node is part of a recurrent loop
// If found, return a pointer to the SEQTraversalFlowControlNode that holds the nodes of this loop.
/*static*/ shared_ptr<ComputationNetwork::SEQTraversalFlowControlNode> ComputationNetwork::FindInRecurrentLoops(const std::vector<shared_ptr<SEQTraversalFlowControlNode>>& recurrentInfo, const ComputationNodeBasePtr& node)
{
    // look in all recurrent loops of the network
    // TODO: Check for IsPartOfLoop(). Also, why not store the loop id in the node for direct lookup?
    for (auto& iter : recurrentInfo)
    {
        if (std::find(iter->m_nestedNodes.begin(), iter->m_nestedNodes.end(), node) != iter->m_nestedNodes.end()) // TODO: should this loop be a method of SEQTraversalFlowControlNode?
            return iter;
    }
    return nullptr; // not part of a recurrent loop
}

// check if any of the nodes in the recurrence IsOutOfDateWrtInputs(), with the exception of delay nodes, for which this check would fail and must therefore be skipped
// TODO: Would it be sufficient to check against our own time stamp, so that we can use a unified time-stamping mechanism? Then we'd not need this special check for delayed nodes; just check all inputs against our own time stamp.
bool ComputationNetwork::SEQTraversalFlowControlNode::IsOutOfDateWrtInputs() const
{
    for (auto& ptr : m_nestedNodes)
    {
        if (ptr->IsOutOfDateWrtInputs() &&
            ptr->OperationName() != OperationNameOf(PastValueNode) &&
            ptr->OperationName() != OperationNameOf(FutureValueNode)) // TODO: when ShiftNode lands, check this as well. Ideally just test whether ptr is an IRecurrentNode
        {
            return true;
        }
    }
    return false;
}

// TODO: do this on PARTraversalFlowControlNode
void ComputationNetwork::ResetEvalTimeStamps()
{
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        nodeIter->second->ResetEvalTimeStamp();
}

// Set the EvalTimeStamp of all nodes as outdated, so that each node will be evaluated at least once.
// The ResetEvalTimeStamps() above cannot do the work, since it only (re)sets the node to the current
// global timestamp, which could be updated by other threads, so that the nodes of the network might
// have different timestamps and the nodes with a higher timestamp are not treated as "outdated".
void ComputationNetwork::SetEvalTimeStampsOutdatedWithRegardToAll()
{
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        nodeIter->second->SetEvalTimeStampOutdatedWrtAll();
}

/*static*/ void ComputationNetwork::BumpEvalTimeStamp(const vector<ComputationNodeBasePtr>& nodes)
{
    for (size_t i = 0; i < nodes.size(); i++)
        nodes[i]->BumpEvalTimeStamp();
}
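// Illustrative sketch (not part of the build) of the time-stamp mechanism used by the lazy
// evaluation above: a node only needs recomputation if one of its inputs carries a newer eval
// time stamp, and BumpEvalTimeStamp() marks a node as freshly computed. 'GetEvalTimeStamp()' is a
// hypothetical accessor name used only for this illustration.
#if 0
static bool ExampleIsOutOfDateWrtInputs(const ComputationNodeBasePtr& node)
{
    for (size_t i = 0; i < node->GetNumInputs(); i++)
        if (node->Input(i)->GetEvalTimeStamp() > node->GetEvalTimeStamp())
            return true; // an input was recomputed after this node; re-evaluate
    return false;
}
#endif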
// for debugging
void ComputationNetwork::PrintComputationTree(const ComputationNodeBasePtr& rootNode, const bool forwardCompute, const bool printMatrices)
{
    auto nodes = GetEvalOrder(rootNode); // note: don't take a reference, since we reverse() below
    if (forwardCompute)
    {
        fprintf(stderr, "\n\nPrinting forward-computation node order ... \n");
    }
    else
    {
        fprintf(stderr, "\n\nPrinting gradient-computation node order ... \n");
        nodes.reverse();
    }

    if (nodes.size() == 0)
        fprintf(stderr, "\n(empty)\n");
    else
    {
        for (const auto& node : nodes)
            node->PrintSelf(printMatrices);
    }
}

// -----------------------------------------------------------------------
// preparation of network
// -----------------------------------------------------------------------

// called by model editing operations, such as DeleteNode(); and by RebuildNetwork()
// These invalidate any post-processed structures. If they are accessed, we will fail.
void ComputationNetwork::InvalidateCompiledNetwork()
{
    m_isCompiled = false;
    m_allSEQNodes.clear();
    m_evalOrders.clear();
    m_nestedNetworks.clear();
    m_inputValues.clear();
    m_learnableParameters.clear();
}

// verify that the network has undergone CompileNetwork()
void ComputationNetwork::VerifyIsCompiled(const char* where) const
{
    if (!IsCompiled())
        LogicError("%s: A compiled network was expected.", where);
}

// CompileNetwork() -- bring network into executable state
// Call this after creation, load, and any modification.
// This method sets up all members that are cleared in InvalidateCompiledNetwork().
// TODO: This is in a somewhat partial state in that we now have a global eval order (keyed by a nullptr), but don't use it yet.
void ComputationNetwork::CompileNetwork()
{
    if (TraceLevel() > 0)
        fprintf(stderr, "\nPost-processing network...\n");

    // We may only get here if the network is not compiled yet. We could now verify each member to be virgin.
    // Or just invalidate it again, which is easier and safer.
    InvalidateCompiledNetwork();

    // all steps below have to be repeated for all root nodes (= nodes without parents and PreComputeNodes)
    DetermineSetOfAllRoots();

    if (TraceLevel() > 0)
    {
        fprintf(stderr, "\n%d roots:\n", (int)m_allRoots.size());
        for (const auto& root : m_allRoots)
            fprintf(stderr, "\t%ls = %ls()\n", root->NodeName().c_str(), root->OperationName().c_str());
    }

    // Note: Steps below are loops over root nodes. We will gradually push those loops through to the functions,
    // to reduce redundant operation on shared portions of the network.

    // STEP: Create a depth-first tree-traversal order through the complete graph.
    // TODO: Do not cache this before reordering; get the list & pass it to FormRecurrentLoops() which reorders it, then store it (such that GetEvalOrder(nullptr) is always valid w.r.t. loops).
    FormEvalOrder(nullptr);

    // STEP: Form the m_inputValues and m_learnableParameters sets for the entire network.
    // Needed for ResetMBLayouts() below.
    // TODO: Move this further down; or decide whether the 'nullptr' version is needed, other than ResetMBLayouts() which could use the global order and filter by itself.
    CollectInputAndLearnableParameters(nullptr);

    // STEP: Establish time-axis relationships.
    // This sets all MBLayout pointers of Input nodes according to the user spec of time axes.
    // TODO: Don't use m_inputValues, traverse ourselves, to remove the dependency on FormEvalOrder().
    ResetMBLayouts();

    // STEP: Discover nested loops.
    FormRecurrentLoops();

    // STEP: Create loop-corrected depth-first traversals and cached input/parameter sets for every actual root node.
    for (auto& root : m_allRoots)
    {
        FormEvalOrder(root);
        CollectInputAndLearnableParameters(root);
    }

    // STEP: Form nested structure of PAR and SEQ traversal nodes.
    for (auto& node : m_allRoots)
        FormNestedNetwork(node);

    // STEP: Infer node dimensions.
    ValidateNetwork();

    // STEP: Optimize the network.
    // :)

    // STEP: Some final details.
    ResetEvalTimeStamps(); // invalidate all m_value fields. Really belongs into StartEvaluateMinibatchLoop()

    if (TraceLevel() > 0)
        fprintf(stderr, "\nPost-processing network complete.\n\n");
    m_isCompiled = true;
}
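// Illustrative sketch (not part of the build) of the root-finding step implemented by
// DetermineSetOfAllRoots() below, reduced to plain strings: roots are the nodes that no other
// node consumes, united with the explicitly tagged criterion/eval/output nodes. The function
// name and string-based types are hypothetical; only the set algebra mirrors the real code.
#if 0
static std::set<std::string> ExampleFindRoots(const std::set<std::string>& allNodes,
                                              const std::set<std::string>& referencedNodes,
                                              const std::set<std::string>& taggedRoots)
{
    std::set<std::string> unreferenced, roots;
    std::set_difference(allNodes.begin(), allNodes.end(), referencedNodes.begin(), referencedNodes.end(),
                        std::inserter(unreferenced, unreferenced.end()));
    std::set_union(unreferenced.begin(), unreferenced.end(), taggedRoots.begin(), taggedRoots.end(),
                   std::inserter(roots, roots.end()));
    return roots;
}
#endif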
// determine the set of all root nodes
// Roots are nodes that ForwardProp() may be called for.
//  - training criterion, eval criteria
//  - outputs
//  - PreComputeNodes
// The result is stored in m_allRoots.
// BUGBUG: In the current implementation, outputs that are also inputs to others must be specified explicitly, e.g. by a tag.
void ComputationNetwork::DetermineSetOfAllRoots()
{
    // start with all non-referenced nodes
    set<ComputationNodeBasePtr> allNodes, referencedNodes;
    for (const auto& iter : m_nameToNodeMap)
    {
        auto node = iter.second;
        allNodes.insert(node);
        for (size_t i = 0; i < node->GetNumInputs(); i++)
        {
            auto input = node->Input(i);
            if (!input) // this may be the result of an incorrect MEL operation
            {
                InvalidArgument("DetermineSetOfAllRoots: Input %d of %ls %ls operation is not connected, network is malformed.",
                                (int) i, node->NodeName().c_str(), node->OperationName().c_str());
            }
            referencedNodes.insert(input);
        }
    }
    set<ComputationNodeBasePtr> unreferencedNodes;
    set_difference(allNodes.begin(), allNodes.end(), referencedNodes.begin(), referencedNodes.end(), inserter(unreferencedNodes, unreferencedNodes.end()));

    // add in all explicitly specified nodes.
    // TODO: This is not ideal. We will also need on-demand compilation, to allow any node to be used as an output after the fact.
    set<ComputationNodeBasePtr> allKnownRoots;
    for (const auto& node : FinalCriterionNodes())
        allKnownRoots.insert(node);
    for (const auto& node : EvaluationNodes())
        allKnownRoots.insert(node);
    for (const auto& node : OutputNodes())
        allKnownRoots.insert(node);
    for (const auto& iter : m_nameToNodeMap) // PreComputeNodes
    {
        auto node = iter.second;
        if (node->RequiresPreCompute())
            allKnownRoots.insert(node);
    }

    // set m_allRoots to include both non-referenced nodes and also all explicitly specified roots
    m_allRoots.clear();
    set_union(unreferencedNodes.begin(), unreferencedNodes.end(), allKnownRoots.begin(), allKnownRoots.end(), inserter(m_allRoots, m_allRoots.end()));

    // and bring the roots into a well-defined order
    // I did observe a different order depending on the complexity of non-Node BrainScript expressions.
    sort(m_allRoots.begin(), m_allRoots.end(), [](const ComputationNodeBasePtr& a, const ComputationNodeBasePtr& b)
         {
             return a->NodeName() < b->NodeName();
         });
}

// initial setup of MBLayout pointers
//  - link all input nodes to one or more MBLayouts
//  - reset all others to nullptr, in expectation of a ValidateNetwork() pass
void ComputationNetwork::ResetMBLayouts()
{
    // reset to a well-defined MBLayout (any meaningful layout should do here)
    // Note that Validate is never called during operation. Any actual computation will lead to the MBLayout being set.
    m_pMBLayoutOfNetwork->Init(1, 0);

    // first reset all
    for (const auto& node : GetAllNodesForRoot(nullptr))
        node->LinkToMBLayout(nullptr);

    // DynamicAxis nodes are (apart from the soon-to-be-deprecated network-wide MBLayout) the main holders of MBLayouts. Initialize them.
    // The only other instances are nodes that change the MBLayout, like WhereNode.
    for (auto node : GetNodesWithType(L"DynamicAxis"))
        node->LinkToMBLayout(make_shared<MBLayout>(1, 0, node->GetName()));

    // This is now initialized inside of the Input nodes, with the proper connections.
    for (auto node : InputNodes(nullptr))
    {
        // TODO: use if (!Is<ITakesDynamicAxis>(node))...
        auto n = dynamic_pointer_cast<ITakesDynamicAxis>(node);
        if (!n)
            LogicError("Expected %ls to implement ITakesDynamicAxis, but it doesn't.", node->NodeDescription().c_str());
        std::wstring axisName = n->GetRequestedDynamicAxis();

        if (axisName == L"")
        {
            // Legacy behavior: One shared MBLayout
            // TODO: Remove m_pMBLayoutOfNetwork altogether. See issue 358.
            node->LinkToMBLayout(m_pMBLayoutOfNetwork);
        }
        else
        {
            auto axisNode = GetNodeFromName(axisName);
            if (!axisNode)
                RuntimeError("%ls: Can't find node '%ls' for retrieving dynamic axis.", node->NodeDescription().c_str(), axisName.c_str());
            // For now we require the node to be a DynamicAxisNode, though we could derive the same from other nodes. This would involve
            // more dependencies on the order in which things are evaluated, though.
            if (axisNode->OperationName() != L"DynamicAxis")
                RuntimeError("%ls: dynamicAxis argument must be of type DynamicAxis(), but got %ls.", node->NodeDescription().c_str(), axisNode->NodeDescription().c_str());
            if (!axisNode->HasMBLayout())
                LogicError("%ls: Expected %ls to have an MBLayout, but it doesn't.", node->NodeDescription().c_str(), axisNode->NodeDescription().c_str());
            node->LinkToMBLayout(axisNode->GetMBLayout());
        }
    }
}

// -----------------------------------------------------------------------
// validation
// -----------------------------------------------------------------------

// validate the sub-network needed to evaluate a specific output node
// This calls Validate() on every node in evaluation order (allowing to propagate things forwards through the net).
// This is called lazily, but only once per node until the next ClearCache().
// MBLayout links are expected to have been set up already for inputs, and reset to nullptr for all other nodes.
void ComputationNetwork::ValidateNetwork()
{
    // we call all nodes' Validate() in order to validate, that is, set up the MBLayout and FunctionValues dimensions
    // A problem is that recurrent loops may require partial validation.
    // Nodes validated on partial input (i.e. some children not yet validated) will be revisited.
    const auto& nodes = GetEvalOrder(nullptr);

    for (auto& node : nodes)
    {
        node->m_visited = false;
        node->m_needsGradient = node->IsParameterUpdateRequired(); // these get propagated upwards in the following
    }

    // loop and validate until we are done
    // steps:
    //  - validate (not final)          // not final means no dimension checks
    //    Keep going through the list until all nodes have been validated and all inputs have been validated as well.
    //  - validate (final)              // final means consistency checks
    //    Fail if anything changes during this stage.
    size_t pass = 1;
    size_t toValidate = nodes.size();
    while (toValidate > 0)
    {
        if (TraceLevel() > 0)
            fprintf(stderr, "\nValidating network. %d nodes to process in pass %d.\n\n", (int) toValidate, (int) pass);
        toValidate = ValidateNodes(nodes, /*isFirstPass=*/pass == 1, false /*isFinalValidationPass*/);
        pass++;
    }
    if (TraceLevel() > 0)
        fprintf(stderr, "\nValidating network, final pass.\n\n");
    toValidate = ValidateNodes(nodes, /*isFirstPass=*/pass == 1, true /*isFinalValidationPass*/);
    if (toValidate != 0)
        LogicError("ValidateSubNetwork: ValidateNodes(true) unexpectedly returned with work left to do.");

    // propagate some info to SEQTraversalFlowControlNode
    // TODO: In the future we should validate not on the flat list but on the PARTraversalFlowControlNode structure. Then this will be unnecessary.
    for (auto& recInfo : m_allSEQNodes)
    {
        auto& node = recInfo->m_sourceNode;
        recInfo->m_needsGradient = node->m_needsGradient;
        recInfo->LinkToMBLayout(node->GetMBLayout());
    }

    for (auto& node : nodes)
    {
        // nodes must output non-zero dimensional data, otherwise assume user error
        if (!node->m_needsDynamicValidation && node->GetSampleLayout().GetNumElements() == 0)
            RuntimeError("%ls operation has 0 elements", node->NodeName().c_str());
    }
    if (TraceLevel() > 0)
        fprintf(stderr, "\n\n");

    // logging the non-default-layout nodes
    vector<ComputationNodeBasePtr> nonDefaultNodes;
    for (auto node : nodes)
    {
        if (!(node->GetMBLayout() == m_pMBLayoutOfNetwork))
            nonDefaultNodes.push_back(node);
    }
#if 0 // this message is no longer necessary
    if (TraceLevel() > 0 && !nonDefaultNodes.empty())
    {
        fprintf(stderr, "%d out of %d nodes do not share the minibatch layout with the input data.\n", (int)nonDefaultNodes.size(), (int)nodes.size());
        // for (auto node : nonDefaultNodes)
        //     fprintf(stderr, " %ls\n", node->NodeName().c_str());
        // fprintf(stderr, "\n\n");
    }
#endif
}

// helper to discover dimension changes
static pair<TensorShape, bool> GetDims(const ComputationNodeBasePtr& node)
{
    return make_pair(node->GetSampleLayout(), node->HasMBLayout());
}

bool ComputationNetwork::ValidateNode(ComputationNodeBasePtr node, bool isFinalValidationPass) const
{
    const auto& children = node->GetInputs();

    // keep state
    MBLayoutPtr oldMBLayoutPtr = node->GetMBLayout();
    auto dim = GetDims(node);
    vector<pair<TensorShape, bool>> childDims;
    for (auto& child : children)
        childDims.push_back(GetDims(child));
    auto sampleLayout = node->GetSampleLayout();

    // also take the opportunity to propagate m_needsGradient and m_needsDynamicValidation
    auto nodeNeedsDynamicValidation = node->NeedsDynamicValidation();
    node->m_needsDynamicValidation |= node->ForceDynamicValidation();
    auto needsGradient = node->m_needsGradient;
    for (auto& child : children) // TODO: do we need a check that this is stable if isFinalValidationPass?
    {
        // check if this is a StopGradientNode. For this node it is ok to not backprop the gradient.
        if (node->OperationName() != OperationNameOf(StopGradientNode))
            node->m_needsGradient |= child->m_needsGradient;

        node->m_needsDynamicValidation |= child->m_needsDynamicValidation;
    }

    // We do call validate(final) as many times as needed, since stuff may have changed underneath.
    node->Validate(isFinalValidationPass && !node->m_needsDynamicValidation /*final*/); // all nodes have been visited: do verification instead of just inference

    // check state --a node will be valid if all nodes have been visited and the node has not been updated
    bool unchanged = true;
    unchanged &= (oldMBLayoutPtr == node->GetMBLayout());
    unchanged &= (dim == GetDims(node));
    vector<pair<TensorShape, bool>> newChildDims;
    for (auto& child : children)
        newChildDims.push_back(GetDims(child));
    unchanged &= (childDims == newChildDims);
    unchanged &= (sampleLayout == node->GetSampleLayout());
    unchanged &= (needsGradient == node->m_needsGradient);
    unchanged &= (nodeNeedsDynamicValidation == node->m_needsDynamicValidation);
    return !unchanged;
}

// perform one pass of validation over the topologically-sorted node set
// returns how many nodes either could not yet be validated or have changed and thus must be redone
size_t ComputationNetwork::ValidateNodes(list<ComputationNodeBasePtr> nodes, bool isFirstPass, bool isFinalValidationPass)
{
    size_t todo = 0;
    for (auto& node : nodes)
    {
        const auto& children = node->GetInputs();
        const bool isLeaf = node->IsLeaf();
        // only validate a node if it has at least one child
        bool hasVisitedChild = false;
        bool allChildrenVisited = true;
        for (auto& child : children)
        {
            hasVisitedChild |= child->m_visited; // if there is not a single visited child then there is no point in validating
            allChildrenVisited &= child->m_visited;

            // Make sure we don't use DynamicAxis in places it was not designed for.
            // This is a stop-gap. We need a more coherent concept for passing of shapes.
            if (child->OperationName() == L"DynamicAxis")
                RuntimeError("%ls: Cannot be used as input to another node. It can only be used on the 'dynamicAxis' property of an Input node.", child->NodeDescription().c_str());
        }
        // if there is not at least one visited child
        bool valid = false;
        if (hasVisitedChild || isLeaf) // got at least one child: it makes sense to call Validate()
        {
            string prevPrototype = node->FormatOperationPrototype("");
            bool unchanged;
            try
            {
                unchanged = !ValidateNode(node, isFinalValidationPass);
                string updatedPrototype = node->FormatOperationPrototype("");
#if 0           // print prototype in final validation pass. Problematic for tracking down validation errors in loops.
                unchanged;
                if (isFinalValidationPass)
#else           // print prototype upon every change (useful for debugging)
                if (isFirstPass || !unchanged || prevPrototype != updatedPrototype)
#endif
                    if (TraceLevel() > 0)
                        fprintf(stderr, "Validating --> %s\n", updatedPrototype.c_str());
            }
            catch (...)
            // if validation failed then print the prototype anyway so one can see the input args
            {
                fprintf(stderr, "Validating --> %s FAILED\n", prevPrototype.c_str());
                throw;
            }
            node->m_visited = true;
            // print the new type
            // sanity checks
            if (isFinalValidationPass && !unchanged)
                LogicError("ValidateSubNetwork: %ls %ls operation changed during final validation.", node->NodeName().c_str(), node->OperationName().c_str());
            if (isFinalValidationPass && !allChildrenVisited)
                LogicError("ValidateSubNetwork: %ls %ls operation in final validation although not all children were visited?", node->NodeName().c_str(), node->OperationName().c_str());
            // if all children are valid then this node is valid
            valid = (allChildrenVisited && unchanged) || isLeaf;
        }
        // count those that we need to redo
        if (!valid)
            todo++;
    }
    return todo;
}

// -----------------------------------------------------------------------
// memory allocation
// -----------------------------------------------------------------------

// mark nodes that are purely induced by parameters as non-sharable and create space for the value if null
void ComputationNetwork::MarkValueNonSharableNodes()
{
    const auto& nodes = GetEvalOrder(nullptr);
    std::map<wstring, bool> allLeafDescendentsAreParametersOrPreComputeNodes;
    std::list<ComputationNodeBasePtr> allLearnableParameters = GetNodesWithType(OperationNameOf(LearnableParameter));
    // note: we cannot use m_learnableParameters because we need all parameter nodes, regardless of whether they require updates or not
    std::list<ComputationNodeBasePtr> allPreComputeNodes;
    for (const auto& node : nodes)
    {
        if (node->Is<IPreComputeNode>())
            allPreComputeNodes.push_back(node);
    }

    for (auto& node : nodes)
    {
        auto inputs = node->GetInputs();

        // Mark the UserDefinedV2FunctionNode and all its inputs as ValueNonSharable, since
        // the inputs and outputs of a UDF may be externally preserved by the UDF implementation
        // for backpropagation, and thus reusing them within the network is not possible, as
        // we do not control when the user actually releases the input/output Matrices that
        // they may have held in the backprop state returned from the UDF's forward pass.
        bool isUserDefinedV2FunctionNode = (node->OperationName() == L"UserDefinedV2Function");
        if (isUserDefinedV2FunctionNode)
        {
            node->MarkValueNonSharable();
            for (auto input : inputs)
                input->MarkValueNonSharable();
        }

        wstring myname = node->NodeName();
        bool allParametersOrPreComputeNodes = true;
        if (inputs.size()) // we don't do the check for leaf nodes, because all the possible leaf nodes (input/parameter/precompute nodes) are marked as non-sharable already
        {
            if (std::find(allPreComputeNodes.begin(), allPreComputeNodes.end(), node) == allPreComputeNodes.end())
            {
                for (auto input : inputs)
                {
                    const auto& inputName = input->NodeName();
                    if (allLeafDescendentsAreParametersOrPreComputeNodes.find(inputName) == allLeafDescendentsAreParametersOrPreComputeNodes.end())
                    {
                        // not found, which means it is a leaf node (we are traversing in eval order)
                        assert(input->IsLeaf() || input->IsPartOfLoop());
                        if (std::find(allLearnableParameters.begin(), allLearnableParameters.end(), input) != allLearnableParameters.end())
                        {
                            allLeafDescendentsAreParametersOrPreComputeNodes[inputName] = true;
                        }
                        else
                        {
                            allParametersOrPreComputeNodes = false;
                            allLeafDescendentsAreParametersOrPreComputeNodes[inputName] = false;
                            break;
                        }
                    }
                    else
                    {
                        if (allLeafDescendentsAreParametersOrPreComputeNodes[inputName] == false)
                        {
                            allParametersOrPreComputeNodes = false;
                            break;
                        }
                    }
                }
            }
            allLeafDescendentsAreParametersOrPreComputeNodes[myname] = allParametersOrPreComputeNodes;
            if (allParametersOrPreComputeNodes)
                node->MarkValueNonSharable();
        }
    }
}

// From the set of candidate nodes, extract all nodes which are used as accumulator nodes.
set<ComputationNodeBasePtr> ComputationNetwork::ExtractNodesWhichAccumulateResult(set<ComputationNodeBasePtr> candidates)
{
    const auto& nodes = GetEvalOrder(nullptr);

    // Set of nodes whose leaf descendants are learnable parameters only. Initially, we add learnable parameter nodes to
    // the set. Later, we add all nodes whose children are all from this list.
    auto allLearnableParameters = GetNodesWithType(OperationNameOf(LearnableParameter));
    set<ComputationNodeBasePtr> allLeafDescendantsAreLearnableParameters(allLearnableParameters.begin(), allLearnableParameters.end());

    // Set of nodes that accumulate samples from the input.
    // We initially add all epoch accumulator nodes to this list. Later, we add all nodes that have at least one child
    // node from this list and all other child nodes from the list above (whose leaf descendants are learnable
    // parameters only). A combination of an accumulator node with another accumulator node, or with nodes whose leaf
    // descendants are learnable parameters, is also an accumulator node.
    auto allEpochAccumulatorNodes = GetNodesWithType(OperationNameOf(EpochAccumulatorNode));
    set<ComputationNodeBasePtr> accumulatorNodes(allEpochAccumulatorNodes.begin(), allEpochAccumulatorNodes.end());

    if (!accumulatorNodes.empty())
    {
        for (auto& node : nodes)
        {
            auto inputs = node->GetInputs();
            bool hasAccumulatorInput = false;
            bool areAllLeafDescendantsLearnableNodes = true;
            // Indicates that the node shouldn't be added to the set of nodes with learnable parameter descendants, nor to the
            // set of accumulator nodes.
            bool skipNode = false;
            for (auto input : inputs)
            {
                bool hasAllLearnableParameterDescendents =
                    allLeafDescendantsAreLearnableParameters.find(input) != allLeafDescendantsAreLearnableParameters.end();
                bool isAccumulatorNode = accumulatorNodes.find(input) != accumulatorNodes.end();
                if (!hasAllLearnableParameterDescendents)
                    areAllLeafDescendantsLearnableNodes = false;
                if (isAccumulatorNode)
                    hasAccumulatorInput = true;
                if (!isAccumulatorNode && !hasAllLearnableParameterDescendents)
                {
                    skipNode = true;
                    break;
                }
            }
            if (skipNode)
                continue;

            if (areAllLeafDescendantsLearnableNodes)
                allLeafDescendantsAreLearnableParameters.insert(node);
            if (hasAccumulatorInput)
                accumulatorNodes.insert(node);
        }
    }

    // Extract all candidate nodes that appear in the set of accumulator nodes.
    set<ComputationNodeBasePtr> intersection;
    set_intersection(accumulatorNodes.begin(), accumulatorNodes.end(), candidates.begin(), candidates.end(),
                     inserter(intersection, intersection.begin()));
    return intersection;
}

// print memory-sharing information to the log
void ComputationNetwork::PrintMemorySharingStructure(const vector<ComputationNodeBasePtr>& nodes)
{
    map<const MatrixBase*, set<wstring>> memSharingStructure;
    size_t numMatrices = 0;
    for (const auto& node : nodes)
    {
        set<pair<const MatrixBase*, wstring>> matrixInfo = node->GetMatrixInfo();
        for (const auto& item : matrixInfo) // {value} or {value, gradient}
        {
            memSharingStructure[item.first].insert(item.second);
            numMatrices++;
        }
    }

    // count shared/unshared
    size_t numShared = 0;
    size_t numUnshared = 0;
    for (const auto& item : memSharingStructure)
    {
        if (item.second.size() < 2) // unshared matrices
            numUnshared++;
        else // shared matrices
            numShared++;
    }

    fprintf(stderr, "\nMemory Sharing: Out of %d matrices, %d are shared as %d, and %d are not shared.\n", (int)numMatrices, (int)(numMatrices - numUnshared), (int)numShared, (int)numUnshared);
    fprintf(stderr, "\nHere are the ones that share memory:\n");
    for (const auto& item : memSharingStructure)
    {
        if (item.second.size() >= 2)
        {
            // Format:
            // { node1
            //   node2 }
            // { node3
            //   node4
            //   node5 }
            const char* delim = "\t{ ";
            for (const auto& memShareInfo : item.second)
            {
                fprintf(stderr, "%s%ls", delim, memShareInfo.c_str());
                delim = "\n\t ";
            }
            fprintf(stderr, " }\n");
        }
    }
    fprintf(stderr, "\nHere are the ones that don't share memory:\n");
    for (const auto& item : memSharingStructure)
    {
        if (item.second.size() < 2)
        {
            fprintf(stderr, "\t{%ls}\n", item.second.begin()->c_str());
        }
    }
    fprintf(stderr, "\n");
}
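// Conceptual sketch (not part of the build) of the matrix-sharing scheme driven by the function
// below: walking the forward evaluation order, each node requests its value matrix from a shared
// pool and, once all of its consumers have run, releases it back so a later node can reuse the
// same memory. 'ExampleForwardAllocation' and 'remainingParents' are hypothetical names; the real
// bookkeeping lives in AllocateAllMatrices() and ReleaseMatricesAfterEvalForChildren().
#if 0
static void ExampleForwardAllocation(const std::list<ComputationNodeBasePtr>& evalOrder,
                                     MatrixPool& pool,
                                     std::unordered_map<ComputationNodeBasePtr, size_t>& remainingParents)
{
    for (auto& node : evalOrder)
    {
        node->RequestMatricesBeforeForwardProp(pool);     // grab (possibly recycled) storage for the value
        for (size_t i = 0; i < node->GetNumInputs(); i++) // this node has now consumed its inputs' values
        {
            auto input = node->GetInputs()[i];
            if (remainingParents[input] > 0 && --remainingParents[input] == 0)
                input->ReleaseMatricesAfterForwardProp(pool); // last consumer: hand the storage back to the pool
        }
    }
}
#endif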
// this function will need to be called before actual validation and execution to
// predetermine how to share matrices to reduce memory usage.
// TODO: find a simple topological order and allocateEvalMatrices on that order directly,
// without passing in eval, out, and train nodes.
void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBasePtr>& evalRootNodes, const std::vector<ComputationNodeBasePtr>& outValueRootNodes, ComputationNodeBasePtr trainRootNode)
{
    if (AreMatricesAllocated())
        return;

    // Allocate memory for forward/backward computation
    if (TraceLevel() > 0)
        fprintf(stderr, "\n\nAllocating matrices for forward and/or backward propagation.\n");

    VerifyIsCompiled("AllocateAllMatrices");

    std::vector<ComputationNodeBasePtr> forwardPropRoots;
    forwardPropRoots.insert(forwardPropRoots.end(), evalRootNodes.begin(), evalRootNodes.end());
    forwardPropRoots.insert(forwardPropRoots.end(), outValueRootNodes.begin(), outValueRootNodes.end());
    if (trainRootNode != nullptr)
        forwardPropRoots.push_back(trainRootNode);

    // Mark all the eval, output and criterion roots as non-shareable
    for (auto& rootNode : forwardPropRoots)
        rootNode->MarkValueNonSharable();

    // Due to special topology, if a node is solely induced by parameters, its function value should not be shared
    MarkValueNonSharableNodes();

    bool performingBackPropagation = (trainRootNode != nullptr);

    // Create a composite eval order with the specified nodes as roots.
    // For each node determine its parents and whether the output of the
    // node is needed during back propagation.
    std::unordered_map<ComputationNodeBasePtr, bool> outputValueNeededDuringBackProp;
    std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>> parentsMap;
    std::unordered_set<ComputationNodeBasePtr> uniqueForwardPropEvalNodes;
    for (auto& rootNode : forwardPropRoots)
    {
        for (const auto& node : GetEvalOrder(rootNode))
        {
            if (uniqueForwardPropEvalNodes.find(node) == uniqueForwardPropEvalNodes.end())
                uniqueForwardPropEvalNodes.insert(node);

            for (int i = 0; i < node->GetNumInputs(); i++)
            {
                ComputationNodeBasePtr input = node->GetInputs()[i];
                parentsMap[input].insert(node);

                if (performingBackPropagation)
                {
                    if (outputValueNeededDuringBackProp.find(input) == outputValueNeededDuringBackProp.end())
                        outputValueNeededDuringBackProp[input] = input->NeedsGradient() && input->OutputUsedInComputingInputNodesGradients();

                    outputValueNeededDuringBackProp[input] |= (node->NeedsGradient() && node->InputUsedInComputingInputNodesGradients(i));
                }
                else
                {
                    outputValueNeededDuringBackProp[input] = false;
                }
            }
        }
    }

    // gradient reuse maps
    std::unordered_map<ComputationNodeBase*, std::unordered_set<ComputationNodeBase*>> gradientReuseChildrenMap;
    std::unordered_map<ComputationNodeBase*, ComputationNodeBase*> gradientReuseParentMap;
    for (auto& keyValue : parentsMap)
    {
        // Indicate on the node that its parent overwrites its gradient if the node is not part of a loop
        // and has exactly one parent that implements the gradient overwrite optimization
        if (Globals::ShouldOptimizeGradientAccumulation() && !keyValue.first->IsPartOfLoop() && (keyValue.second.size() == 1))
        {
            auto parent = *keyValue.second.begin();
            auto opt = parent->ImplementsGradientOptimization(keyValue.first.get());
            if (opt != ParentGradientOptimization::None && trainRootNode != parent)
            {
                // We cannot enable the gradient overwrite/reuse optimization if this node's (lone) parent
                // has this same node as multiple of its inputs since, in that case, the
                // gradients will flow back from multiple paths of the same parent into the input
                auto& allInputsOfParent = parent->GetInputs();
                if (std::count(allInputsOfParent.begin(), allInputsOfParent.end(), keyValue.first) <= 1)
                {
                    auto child = keyValue.first;
                    child->SetParentGradientOptimization(opt);
                    if (opt == ParentGradientOptimization::Reuse)
                    {
                        gradientReuseChildrenMap[&*parent].insert(&*child);
                        if (gradientReuseParentMap.find(&*child) != gradientReuseParentMap.end())
                            LogicError("Already has a gradient reuse parent.");
                        gradientReuseParentMap[&*child] = &*parent;
                    }
                }
            }
        }
    }

    m_matrixPool.Reset();

    TravserseInSortedGlobalEvalOrder(forwardPropRoots,
        [&outputValueNeededDuringBackProp, &parentsMap, this](const ComputationNodeBasePtr& node)
        {
            if (node->Is<SEQTraversalFlowControlNode>())
            {
                auto seqTraversalFlowControlNode = node->As<SEQTraversalFlowControlNode>();
                for (auto& loopNode : seqTraversalFlowControlNode->m_nestedNodes)
                    loopNode->SetOutputNeededDuringBackprop(outputValueNeededDuringBackProp[loopNode]);

                seqTraversalFlowControlNode->RequestMatricesBeforeForwardProp(m_matrixPool);

                for (auto& loopNode : seqTraversalFlowControlNode->m_nestedNodes)
                    ReleaseMatricesAfterEvalForChildren(loopNode, parentsMap);
            }
            else
            {
                node->SetOutputNeededDuringBackprop(outputValueNeededDuringBackProp[node]);
                node->RequestMatricesBeforeForwardProp(m_matrixPool);

                // we only release matrices for the children since the root node's information will be used
                // and should not be shared with others
                ReleaseMatricesAfterEvalForChildren(node, parentsMap);
            }
        });

    if (trainRootNode != nullptr)
    {
        const std::list<ComputationNodeBasePtr>& backPropNodes = GetEvalOrder(trainRootNode);

        // compact the alias map for cases like s = a + b + c + d
        std::unordered_map<ComputationNodeBase*, std::unordered_set<ComputationNodeBase*>> compactGradientAliasMap;
        std::unordered_map<ComputationNodeBase*, ComputationNodeBase*> compactGradientAliasRootMap;
        for (const auto& gradientReuseKeyValue : gradientReuseChildrenMap)
        {
            // keep searching the parent until reaching the root
            auto parent = gradientReuseKeyValue.first;
            auto parentIter = gradientReuseParentMap.find(parent);
            while (parentIter != gradientReuseParentMap.end())
            {
                parent = parentIter->second;
                parentIter = gradientReuseParentMap.find(parent);
            }

            // add children to the alias group under the root
            auto children = gradientReuseKeyValue.second;
            compactGradientAliasMap[parent].insert(children.begin(), children.end());
            for (const auto& child : children)
            {
                if (compactGradientAliasRootMap.find(child) != compactGradientAliasRootMap.end())
                    LogicError("One node cannot be in two alias groups.");
                compactGradientAliasRootMap[child] = parent;
            }

            // and add the root itself to the alias group
            compactGradientAliasMap[parent].insert(parent);
            compactGradientAliasRootMap[parent] = parent;
        }

        // print the memory aliasing info
        if (TraceLevel() > 0 && compactGradientAliasRootMap.size() > 0)
        {
            fprintf(stderr, "\nGradient Memory Aliasing: %d are aliased.\n", (int)compactGradientAliasRootMap.size());
            for (const auto pair : compactGradientAliasRootMap)
            {
                auto child = (const ComputationNodeBase*)pair.first;
                auto parent = (const ComputationNodeBase*)pair.second;
                if (child != parent)
                    fprintf(stderr, "\t%S (gradient) reuses %S (gradient)\n", child->GetName().c_str(), parent->GetName().c_str());
            }
        }
        m_matrixPool.SetAliasInfo(compactGradientAliasMap, compactGradientAliasRootMap);

        // now, simulate the gradient computation order to determine how to allocate matrices
        set<ComputationNodeBasePtr> completedGradient;

        // we need to call it here since we always compute gradients for children, and the root node is not a child of any other node
        trainRootNode->RequestMatricesBeforeBackprop(m_matrixPool);

        for (auto iter = backPropNodes.rbegin(); iter != backPropNodes.rend(); iter++) // for gradient computation, traverse in reverse order
        {
            auto n = *iter;
            if (n->IsPartOfLoop())
            {
                std::vector<ComputationNodeBasePtr> recurrentNodes;
                shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(m_allSEQNodes, n);
                if (completedGradient.insert(recInfo).second)
                {
                    // SEQ mode: allocate all in the loop first, then deallocate again
                    // TODO: next step: use PARTraversalFlowControlNode::AllocateGradientMatricesForInputs() and ReleaseMatricesAfterBackprop()...
                    // BUGBUG: naw, ^^ would not work! Wrong order! Need to rethink this. Need to make AllocateEvalMatrices() and AllocateGradientMatrices() the virtual functions.
                    recInfo->AllocateGradientMatricesForInputs(m_matrixPool); // loops are computed sample by sample, so we have to allocate them all
                    recInfo->ReleaseMatricesAfterBackprop(m_matrixPool);
                }
            }
            else
            {
                // PAR mode: we can allocate and immediately deallocate one by one
                n->AllocateGradientMatricesForInputs(m_matrixPool);
                // The root node's information will be used and should not be shared with others; also, it is small (1x1)
                if ((n != trainRootNode) && n->NeedsGradient())
                    n->ReleaseMatricesAfterBackprop(m_matrixPool);
            }
        }
    }

    m_matrixPool.OptimizedMemoryAllocation();
    m_areMatricesAllocated = true;

    // TODO: At the time of AllocateAllMatrices we don't know the minibatch size. In theory one may allocate memory again once we start to receive
    // data from the reader (and the minibatch size is known). For some problems, the minibatch size can change constantly, and there needs to be a
    // tradeoff in deciding how frequently to run optimized memory allocation. For now, we do it only once at the very beginning, for speed concerns.

    // TODO: when some matrices are sparse, the memory size request may be wrong. One may need to call OptimizedMemoryAllocation() again later
    // if the requests of sparse allocation and release are re-processed correctly. Future work.

    // print the memory sharing structure
    if (TraceLevel() > 0)
        PrintMemorySharingStructure(GetAllNodes());
}

void ComputationNetwork::ReleaseMatricesAfterEvalForChildren(ComputationNodeBasePtr n, std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>>& parentsMap)
{
    for (int i = 0; i < n->GetNumInputs(); i++)
    {
        ComputationNodeBasePtr pNode = n->GetInputs()[i];
        if (!parentsMap[pNode].empty())
        {
            parentsMap[pNode].erase(n);
            if (parentsMap[pNode].empty())
                pNode->ReleaseMatricesAfterForwardProp(m_matrixPool);
        }
    }
}

}}}