https://github.com/Microsoft/CNTK
ComputationNetworkEvaluation.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms  --add this at the top of all CPP files that give "function or variable may be unsafe" warnings

#include "Basics.h"
#include "ComputationNode.h"
#include "ComputationNetwork.h"
#include "RecurrentNodes.h"
#include "InputAndParamNodes.h"
#include "LinearAlgebraNodes.h"
#include <string>
#include <vector>
#include <list>
#include <set>
#include <algorithm>
#include <map>

using namespace std;

namespace Microsoft { namespace MSR { namespace CNTK {

// This source file contains methods related to evaluation (forward prop, backprop), network validation, and matrix memory allocation (memory sharing).

// -----------------------------------------------------------------------
// forward and backward propagation
// -----------------------------------------------------------------------

// MAIN ENTRY POINT for evaluating one minibatch (forward prop)
// This calls ForwardProp() on all nodes in order of data flow through the network.
// By default, the network is applied to all frames in a minibatch in parallel (PAR mode, a "map" operation)
// Recurrent loops must be treated differently:
//  - a recurrent loop is the loop of nodes that make up computation for one time step (e.g. Times -> Plus -> Sigmoid -> Delay)
//  - these must be executed frame by frame (SEQuential) rather than as a map
//  - such a loop is treated as if it were a little nested network; this is done inside SEQTraversalFlowControlNodes
//  - these little nested networks are defined in the execution network in the form of nested sentinel nodes of type SEQTraversalFlowControlNode
void ComputationNetwork::ForwardProp(const ComputationNodeBasePtr rootNode)
{
    VerifyIsCompiled("ForwardProp");

    // traverse all nodes in the pre-determined evaluation order
    GetNestedNetwork(rootNode)->ForwardProp(FrameRange(nullptr));
}

void ComputationNetwork::PostForwardAndBackProp(const ComputationNodeBasePtr rootNode)
{
    VerifyIsCompiled("PostForwardAndBackProp");

    // traverse all nodes in the pre-determined evaluation order
    GetNestedNetwork(rootNode)->PostForwardAndBackProp();
}

// set the gradient matrix of a (root) node to 1.0
// Returns false if the node is not a ComputationNode<ElemType>; see Backprop() below for intended use.
template <class ElemType>
static bool SetRootGradientToScalarOne(ComputationNodeBasePtr nodep)
{
    auto node = dynamic_pointer_cast<ComputationNode<ElemType>>(nodep);
    bool hasMatchingType = (node != nullptr);
    if (hasMatchingType)
    {
        // reset the root gradient to 1
        node->ResetGradient(1);
    }
    return hasMatchingType;
}

// MAIN ENTRY POINT for evaluation followed by gradient computation (forward prop then back prop)
// The typical calling pattern is:
//  - ForwardProp() for eval nodes
//  - ForwardProp() for the training criterion (which will reuse computation results from the previous step)
//  - Backprop() for the training criterion
void ComputationNetwork::Backprop(const ComputationNodeBasePtr rootNode) // training criterion to compute the gradients for
{
    if (!Environment().IsTraining())
        LogicError("Backprop: Requires network is to be in training mode.");

    // initialize root gradient with a scalar value of 1.0
    if (!SetRootGradientToScalarOne<float>(rootNode) && !SetRootGradientToScalarOne<double>(rootNode))
        LogicError("Backprop: Training criterion is neither ComputationNode<float> nor ComputationNode<double>.");

    // reset all gradients below rootNode to zero (actually, internally, this is lazy, but we don't care here)
    ZeroInputGradients(rootNode);

    // backpropagate through the network
    GetNestedNetwork(rootNode)->Backprop(FrameRange(nullptr), true, true);
}

void ComputationNetwork::FormNestedNetwork(const ComputationNodeBasePtr& rootNode)
{
    if (m_nestedNetworks.find(rootNode) != m_nestedNetworks.end())
        fprintf(stderr, "FormNestedNetwork: WARNING: Was called twice for %ls %ls operation\n", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());

    m_nestedNetworks[rootNode] = make_shared<PARTraversalFlowControlNode>(m_allSEQNodes, GetEvalOrder(rootNode));
}

ComputationNodeBasePtr ComputationNetwork::GetNestedNetwork(const ComputationNodeBasePtr& rootNode)
{
    if (m_nestedNetworks.find(rootNode) == m_nestedNetworks.end())
        LogicError("GetNestedNetwork: Called without prior call to FormNestedNetwork() for %ls %ls operation", rootNode->NodeName().c_str(), rootNode->OperationName().c_str());
    return m_nestedNetworks[rootNode];
}

// -----------------------------------------------------------------------
// PARTraversalFlowControlNode methods -- implements PAR traversal
//
// This implements an outer loop over non-recurrent nodes, where each node can be
// executed in PAR mode; that is, all samples are independent and allow for
// concurrent computation in bulk CUDA launches.
// -----------------------------------------------------------------------

static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient);

ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(const std::vector<shared_ptr<SEQTraversalFlowControlNode>>& recurrentInfo, const std::list<ComputationNodeBasePtr>& allNodes /*must be in eval order*/)
{
    // traverse the network in evaluation order and create a new list that replaces all recurrence by a SEQTraversalFlowControlNode
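    // Example: an eval order [W, Times, Plus, Sigmoid, Delay, CE] in which {Times, Plus, Sigmoid, Delay} form one
    // recurrent loop becomes the nested node list [W, SEQTraversalFlowControlNode{Times, Plus, Sigmoid, Delay}, CE].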
    set<shared_ptr<IComputationNode>> loopsSeen; // for consistency check only
    for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end();)
    {
        shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(recurrentInfo, *nodeIter); // check if this node participates in a recurrent loop
        if (recInfo)                                                                                      // node is part of a SEQ loop: gather all of them. The nodes must be consecutive in 'allNodes'
        {
            // instead of the node itself, include the sentinel SEQTraversalFlowControlNode in our list
            m_nestedNodes.push_back(recInfo);

            // and verify that we only encountered the loop once (all nodes should have been consecutive)
            if (!loopsSeen.insert(recInfo).second)
                LogicError("PARTraversalFlowControlNode: members of loop %ls are not consecutive in node list.", recInfo->NodeName().c_str());

            // consume all nodes that are part of the same loop (they are all consecutive)
            while (nodeIter != allNodes.end() && (*nodeIter)->IsPartOfLoop() && FindInRecurrentLoops(recurrentInfo, *nodeIter) == recInfo)
                nodeIter++;
        }
        else // regular top-level node (non-looping, PAR)
        {
            m_nestedNodes.push_back(*nodeIter);
            nodeIter++; // and consume this node
        }
    }
}
/*static*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const ComputationNodeBasePtr& node, const FrameRange& fr)
{
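    // A node is only recomputed if its value is out of date w.r.t. the time stamps of its inputs.
    // This way, ForwardProp() can be called for several roots per minibatch (e.g. eval nodes, then the criterion)
    // while shared sub-expressions are computed only once.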
    if (node->IsOutOfDateWrtInputs())
    {
        node->BeginForwardProp();
        node->BeginTiming(false /*backward*/);
        node->ForwardProp(fr.WithLayout(node->GetMBLayout()));
        node->EndTiming(false /*backward*/);
        node->EndForwardProp();

        node->BumpEvalTimeStamp();

        // Extreme Tracing, part 1/4
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode())
            DumpNode(node, /*dumpGradient=*/false);
    }
}


/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ForwardProp(const FrameRange& fr) /*override*/
{
    for (auto& node : m_nestedNodes)
        ForwardProp(node, fr);
}

/*static*/ void ComputationNetwork::PARTraversalFlowControlNode::PostForwardAndBackProp(const ComputationNodeBasePtr& node)
{
    node->PostForwardAndBackProp();
}

/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::PostForwardAndBackProp() /*override*/
{
    for (auto& node : m_nestedNodes)
        PostForwardAndBackProp(node);
}

/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::Backprop(const FrameRange& fr, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{
    childrenInThisLoop, childrenInOuterLoop; // TODO: think through what these mean when coming from PAR mode
    // process nodes in pre-determined order
    for (auto pnode = m_nestedNodes.rbegin(); pnode != m_nestedNodes.rend(); pnode++) // iterate backwards over evaluation order
    {
        auto& node = *pnode;

        node->BeginBackprop();
        node->BeginTiming(true /*backward*/);
        node->Backprop(fr.WithLayout(node->GetMBLayout()), true /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
        node->EndTiming(true /*backward*/);
        node->EndBackprop();

        // Extreme Tracing, part 2/4
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode() && node->NeedsGradient())
            DumpNode(node, /*dumpGradient=*/true);
    }
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::PARTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{
}

template<typename ElemType>
bool TypedDumpNode(shared_ptr<ComputationNode<ElemType>> node, bool dumpGradient)
{
    if (!node)
        return false;
    let dataPtr = dumpGradient ? node->GradientPtr() : node->ValuePtr();
    if (!dataPtr)
        return true; // e.g. SEQ sentinel node

    bool concise = !(node->Environment().IsLogLevelNodeTrace());

    fprintf(stderr, "Dump --> %s%s\n", node->FormatOperationPrototype("").c_str(), dumpGradient ? " Grad" : "");
    node->WriteMinibatchWithFormatting(stderr, FrameRange(), SIZE_MAX, SIZE_MAX, false/*transpose*/, /*isCategoryLabel=*/false, /*isSparse=*/false, std::vector<std::string>(),
        ""/*sequenceSeparator*/, "  "/*sequencePrologue*/, "\n"/*sequenceEpilogue*/, " "/*elementSeparator*/, "\n  "/*sampleSeparator*/,
        "%13.10f"/*valueFormatString*/, dumpGradient, concise);
    return true;
}

// helper for logging. Returns false if it was not able to dump
static bool DumpNode(ComputationNodeBasePtr nodep, bool dumpGradient)
{
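    // try each supported element type in turn; the first dynamic_pointer_cast that succeeds determines ElemType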
    let nodef = dynamic_pointer_cast<ComputationNode<float>>(nodep);
    if (nodef) return TypedDumpNode<float>(nodef, dumpGradient);
    let noded = dynamic_pointer_cast<ComputationNode<double>>(nodep);
    if (noded) return TypedDumpNode<double>(noded, dumpGradient);
    let nodeh = dynamic_pointer_cast<ComputationNode<half>>(nodep);
    if (nodeh) return TypedDumpNode<half>(nodeh, dumpGradient);
    return false;
}

// -----------------------------------------------------------------------
// SEQTraversalFlowControlNode methods -- implements SEQ traversal (loop unrolling)
//
// While PAR mode processes all samples in the MB independently, and thus in
// PARallel, SEQ mode honors sequential dependencies. As such, it
// unrolls the loop over time steps and runs the network once per time step.
// -----------------------------------------------------------------------

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::BeginForwardProp() /*override*/
{
    // take the opportunity to check that layout is shared by all nodes in the loop
    // TODO: we should do this in a constructor.
    for (auto& node : m_nestedNodes)
    {
        if (node->GetMBLayout() != GetMBLayout())
            LogicError("Evaluate: All nodes inside a recurrent loop must have a layout that is identical; mismatch found for nodes '%ls' (%ls) vs. '%ls' (%ls)",
                       node            ->NodeName().c_str(), node            ->GetMBLayoutAxisString().c_str(),
                       m_nestedNodes[0]->NodeName().c_str(), m_nestedNodes[0]->GetMBLayoutAxisString().c_str());
    }

    // tell all that loop is about to commence
    for (auto& node : m_nestedNodes)
        node->BeginForwardProp();
}

// evaluation of a SEQTraversalFlowControlNode FlowControlNode
// This evaluates all nodes in this FlowControlNode in SEQ mode: process the loop frame by frame in a nested loop.
// This is where the time axis changes.
// TODO: Once we do nested loops, then the FrameRange argument to this will refer to the outer loop.
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ForwardProp(const FrameRange&) /*override*/
{
    // get layout associated with this loop
    // All nodes share the same layout.
    assert(GetMBLayout() == m_nestedNodes[0]->GetMBLayout());

    // for every time step run through all nodes in this particular loop (treat the loop like a little ComputationNetwork)
    // Note: Currently, this is limited to linear-time loops. But nothing stops the iteration below from being, e.g., a 2D iteration over an image
    // if we implement an according FrameRangeIteration.
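    // The stepping direction selects whether time is traversed forward or backward; it is determined by the kind of
    // delay node that drives this loop (e.g. PastValueNode vs. FutureValueNode).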
    FrameRangeIteration range(GetMBLayout(), m_steppingDirection);
    for (auto t = range.begin(); t != range.end(); t++)
    {
        for (auto& node : m_nestedNodes)
        {
            node->BeginTiming(false /*backward*/);
            node->ForwardProp(t);
            node->EndTiming(false /*backward*/);
            node->BumpEvalTimeStamp();
        }
    }

    // Extreme Tracing, part 3/4
    for (auto& node : m_nestedNodes)
    {
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode())
        {
            DumpNode(node, /*dumpGradient=*/false);
        }
    }
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndForwardProp() /*override*/
{
    // tell all that loop is done  --e.g. PastValueNode will capture its state for BPTT processing
    for (auto& node : m_nestedNodes)
        node->EndForwardProp();
}

// called before first iteration step of ComputeGradient()
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::BeginBackprop() /*override*/
{
    for (auto& node2 : m_nestedNodes)
        node2->BeginBackprop();
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::Backprop(const FrameRange&, bool childrenInThisLoop, bool childrenInOuterLoop) /*override*/
{
    childrenInThisLoop, childrenInOuterLoop;    // TODO: think through what these mean when coming from PAR mode
    const auto& recurrentNodes = m_nestedNodes; // BUGBUG: -ForForward?? Does this mean we can remove non-ForForward?
    auto pMBLayout = recurrentNodes[0]->GetMBLayout();
    FrameRangeIteration range(pMBLayout, m_steppingDirection);
    for (auto t = range.rbegin(); t != range.rend(); t++) // note: reverse iteration
    {
        for (auto nodeIter2 = recurrentNodes.rbegin(); nodeIter2 != recurrentNodes.rend(); ++nodeIter2)
        {
            auto& node2 = *nodeIter2;
            node2->BeginTiming(true /*backward*/);
            node2->Backprop(t, true /*childrenInThisLoop*/, false /*childrenInOuterLoop*/);
            node2->EndTiming(true /*backward*/);
            // The above flags tell Backprop() to skip back-propagation from inside a node into
            // a node that is outside the loop, which is done later in EndBackprop() in PAR mode.
        }
    }

    // Extreme Tracing, part 4/4
    for (auto& node : m_nestedNodes)
    {
        if (node->HasEnvironmentPtr() && node->Environment().ShouldDumpNode() && node->NeedsGradient())
        {
            DumpNode(node, /*dumpGradient=*/true);
        }
    }
}

// called after last iteration step of ComputeGradient()
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::EndBackprop() /*override*/
{
    // The following loop handles the case that a node inside the loop back-propagates a gradient into a node outside of the loop.
    // For efficiency, we perform this outside the loop in PAR mode. E.g., in one LSTM speech setup, we measured 12..14% overall speed-up.
    for (auto nodeIter2 = m_nestedNodes.rbegin(); nodeIter2 != m_nestedNodes.rend(); ++nodeIter2)
    {
        auto& node2 = *nodeIter2;
        node2->Backprop(FrameRange(m_nestedNodes[0]->GetMBLayout()), false /*childrenInThisLoop*/, true /*childrenInOuterLoop*/);
    }

    // tell all nodes we are done for this iteration
    for (auto& node2 : m_nestedNodes)
        node2->EndBackprop();
}

/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeForwardProp(MatrixPool& matrixPool) /*override*/
{
    for (auto& nodeLoopIter : m_nestedNodes)
        nodeLoopIter->RequestMatricesBeforeForwardProp(matrixPool);
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ReleaseMatricesAfterForwardProp(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::AllocateGradientMatricesForInputs(MatrixPool& matrixPool) /*override*/
{
    // TODO: should we deallocate in opposite order?
    for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
    {
        (*nodeIter)->AllocateGradientMatricesForInputs(matrixPool);
    }
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::RequestMatricesBeforeBackprop(MatrixPool& matrixPool) /*override*/
{
}
/*virtual*/ void ComputationNetwork::SEQTraversalFlowControlNode::ReleaseMatricesAfterBackprop(MatrixPool& matrixPool) /*override*/
{
    for (auto nodeIter = m_nestedNodes.rbegin(); nodeIter != m_nestedNodes.rend(); ++nodeIter)
    {
        if ((*nodeIter)->NeedsGradient())
            (*nodeIter)->ReleaseMatricesAfterBackprop(matrixPool);
    }
}

// find whether a node is part of a recurrent loop
// If so, return a pointer to the SEQTraversalFlowControlNode representing that loop; otherwise return nullptr.
/*static*/ shared_ptr<ComputationNetwork::SEQTraversalFlowControlNode> ComputationNetwork::FindInRecurrentLoops(const std::vector<std::shared_ptr<SEQTraversalFlowControlNode>>& recurrentInfo, const ComputationNodeBasePtr& node)
{
    // look in all recurrent loops of the network
    // TODO: Check for IsPartOfLoop(). Also why not store the loop id in the node for direct lookup?
    for (auto& iter : recurrentInfo)
    {
        if (std::find(iter->m_nestedNodes.begin(), iter->m_nestedNodes.end(), node) != iter->m_nestedNodes.end()) // TODO: should this loop need to be a method of SEQTraversalFlowControlNode?
            return iter;
    }
    return nullptr; // not part of a recurrent loop
}

// check whether any of the nodes in the recurrence is IsOutOfDateWrtInputs(), with the exception of delay nodes, for which this check would fail and which must therefore be skipped
// TODO: Would it be sufficient to check against our own time stamp, so that we can use a unified time-stamping mechanism? Then we'd not need this special check for delayed nodes; just check all inputs against our own time stamp.
bool ComputationNetwork::SEQTraversalFlowControlNode::IsOutOfDateWrtInputs() const
{
    for (auto& ptr : m_nestedNodes)
    {
        if (ptr->IsOutOfDateWrtInputs() &&
            ptr->OperationName() != OperationNameOf(PastValueNode) &&
            ptr->OperationName() != OperationNameOf(FutureValueNode))
            // TODO: when ShiftNode lands, check this as well. Ideally just test whether ptr is a IRecurrentNode
        {
            return true;
        }
    }
    return false;
}

// TODO: do this on PARTraversalFlowControlNode
void ComputationNetwork::ResetEvalTimeStamps()
{
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        nodeIter->second->ResetEvalTimeStamp();
}

// Set EvalTimeStamp of all nodes as outdated, so that each node will be evaluated at least once.
// The ResetEvalTimeStamps() above cannot do the work, since it only (re)sets the node to the current
// global timestamp, which could be updated by other threads, so that the nodes of the network might
// have different timestamps, and nodes with a higher timestamp would not be treated as "outdated".
void ComputationNetwork::SetEvalTimeStampsOutdatedWithRegardToAll()
{
    for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
        nodeIter->second->SetEvalTimeStampOutdatedWrtAll();
}

/*static*/ void ComputationNetwork::BumpEvalTimeStamp(const vector<ComputationNodeBasePtr>& nodes)
{
    for (size_t i = 0; i < nodes.size(); i++)
        nodes[i]->BumpEvalTimeStamp();
}

// for debugging
void ComputationNetwork::PrintComputationTree(const ComputationNodeBasePtr& rootNode,
                                              const bool forwardCompute,
                                              const bool printMatrices)
{
    auto nodes = GetEvalOrder(rootNode); // note: don't take a reference, since we reverse() below
    if (forwardCompute)
    {
        fprintf(stderr, "\n\nPrinting forward-computation node order ... \n");
    }
    else
    {
        fprintf(stderr, "\n\nPrinting gradient-computation node order ... \n");
        nodes.reverse();
    }

    if (nodes.size() == 0)
        fprintf(stderr, "\n(empty)\n");
    else
    {
        for (const auto& node : nodes)
            node->PrintSelf(printMatrices);
    }
}

// -----------------------------------------------------------------------
// preparation of network
// -----------------------------------------------------------------------

// called by model editing operations, such as DeleteNode(); and by RebuildNetwork()
// These invalidate any post-processed structures. If such structures are accessed afterwards, we will fail.
void ComputationNetwork::InvalidateCompiledNetwork()
{
    m_isCompiled = false;
    m_allSEQNodes.clear();
    m_evalOrders.clear();
    m_nestedNetworks.clear();
    m_inputValues.clear();
    m_learnableParameters.clear();
}

// verify that network has undergone CompileNetwork()
void ComputationNetwork::VerifyIsCompiled(const char* where) const
{
    if (!IsCompiled())
        LogicError("%s: A compiled network was expected.", where);
}

// CompileNetwork() -- bring network into executable state
// Call this after creation, load, and any modification.
// This method sets up all members that are cleared in InvalidateCompiledNetwork();
// TODO: This is in a somewhat partial state in that we now have a global eval order (keyed by a nullptr), but don't use it yet.
void ComputationNetwork::CompileNetwork()
{
    if (TraceLevel() > 0)
        fprintf(stderr, "\nPost-processing network...\n");

    // We may only get here if the network is not compiled (!IsCompiled()). We could now verify each member to be virgin.
    // Or just invalidate it again, which is easier and safer.
    InvalidateCompiledNetwork();

    // all steps below have to be repeated for all root nodes (=nodes without parents and PreComputeNodes)
    DetermineSetOfAllRoots();

    if (TraceLevel() > 0)
    {
        fprintf(stderr, "\n%d roots:\n", (int)m_allRoots.size());
        for (const auto& root : m_allRoots)
            fprintf(stderr, "\t%ls = %ls()\n", root->NodeName().c_str(), root->OperationName().c_str());
    }

    // Note: Steps below are loops over root nodes. We will gradually push those loops through to the functions,
    //       to reduce redundant operation on shared portions of the network.

    // STEP: Create a depth-first tree-traversal order through complete graph.
    // TODO: Do not cache this before reordering; get list & pass to FormRecurrentLoops() which reorders it, then store it (such that GetEvalOrder(nullptr) is always valid w.r.t. loops).
    FormEvalOrder(nullptr);

    // STEP: Form the m_inputValues and m_learnableParameters sets for the entire network.
    // Needed for ResetMBLayouts() below.
    // TODO: Move this further down; or decide whether the 'nullptr' version is needed, other than ResetMBLayouts() which could use the global order and filter by itself.
    CollectInputAndLearnableParameters(nullptr);

    // STEP: Establish time-axis relationships.
    // This sets all MBLayout pointers of Input nodes according to user spec of time axes.
    // TODO: Don't use m_inputValues, traverse ourselves, to remove dependency on FormEvalOrder().
    ResetMBLayouts();

    // STEP: Discover nested loops.
    FormRecurrentLoops();

    // STEP: Create loop-corrected depth-first traversals and cached input/parameter sets for every actual root node.
    for (auto& root : m_allRoots)
    {
        FormEvalOrder(root);
        CollectInputAndLearnableParameters(root);
    }

    // STEP: Form nested structure of PAR and SEQ traversal nodes.
    for (auto& node : m_allRoots)
        FormNestedNetwork(node);

    // STEP: Infer node dimensions.
    ValidateNetwork();

    // STEP: Optimize the network.
    // :)

    // STEP: Some final details.
    ResetEvalTimeStamps(); // invalidate all m_value fields. Really belongs into StartEvaluateMinibatchLoop()

    if (TraceLevel() > 0)
        fprintf(stderr, "\nPost-processing network complete.\n\n");

    m_isCompiled = true;
}

// determine the set of all root nodes
// Roots are nodes that ForwardProp() may be called for.
//  - training criterion, eval criteria
//  - outputs
//  - PreComputeNodes
// Result is stored in m_allRoots.
// BUGBUG: In the current implementation, outputs that are also inputs to others must be specified explicitly e.g. by a tag.
void ComputationNetwork::DetermineSetOfAllRoots()
{
    // start with all non-referenced nodes
    set<ComputationNodeBasePtr> allNodes, referencedNodes;
    for (const auto& iter : m_nameToNodeMap)
    {
        auto node = iter.second;
        allNodes.insert(node);
        for (size_t i = 0; i < node->GetNumInputs(); i++)
        {
            auto input = node->Input(i);
            if (!input) // this may be the result of an incorrect MEL operation
            {
                InvalidArgument("DetermineSetOfAllRoots: Input %d of %ls %ls operation is not connected, network is malformed.",
                                (int) i, node->NodeName().c_str(), node->OperationName().c_str());
            }
            referencedNodes.insert(input);
        }
    }
    set<ComputationNodeBasePtr> unreferencedNodes;
    set_difference(allNodes.begin(), allNodes.end(), referencedNodes.begin(), referencedNodes.end(), inserter(unreferencedNodes, unreferencedNodes.end()));

    // add in all explicitly specified nodes.
    // TODO: This is not ideal. We will also need on-demand compilation, to allow any node to be used as an output after the fact.
    set<ComputationNodeBasePtr> allKnownRoots;
    for (const auto& node : FinalCriterionNodes())
        allKnownRoots.insert(node);
    for (const auto& node : EvaluationNodes())
        allKnownRoots.insert(node);
    for (const auto& node : OutputNodes())
        allKnownRoots.insert(node);
    for (const auto& iter : m_nameToNodeMap) // PreComputeNodes
    {
        auto node = iter.second;
        if (node->RequiresPreCompute())
            allKnownRoots.insert(node);
    }

    // set m_allRoots to include both non-referenced nodes and also all explicitly specified roots
    m_allRoots.clear();
    set_union(unreferencedNodes.begin(), unreferencedNodes.end(), allKnownRoots.begin(), allKnownRoots.end(), inserter(m_allRoots, m_allRoots.end()));

    // and bring the roots into a well-defined order
    // I did observe a different order depending on the complexity of non-Node BrainScript expressions.
    sort(m_allRoots.begin(), m_allRoots.end(),[](const ComputationNodeBasePtr& a, const ComputationNodeBasePtr& b)
    {
        return a->NodeName() < b->NodeName();
    });
}

// initial setup of MBLayout pointers
//  - link all input nodes to one or more MBLayouts
//  - reset all others to nullptr, in expectation of a ValidateNetwork() pass
void ComputationNetwork::ResetMBLayouts()
{
    // reset to a well-defined MBLayout (any meaningful layout should do here)
    // Note that Validate is never called during operation. Any actual computation will cause the MBLayout to be set.
    m_pMBLayoutOfNetwork->Init(1, 0);

    // first reset all
    for (const auto& node : GetAllNodesForRoot(nullptr))
        node->LinkToMBLayout(nullptr);

    // DynamicAxis nodes are (apart from the soon-to-be-deprecated network-wide MBLayout) the main holders of MBLayouts. Initialize them.
    // The only other instances are nodes that change the MBLayout, like WhereNode. 
    for (auto node : GetNodesWithType(L"DynamicAxis"))
        node->LinkToMBLayout(make_shared<MBLayout>(1, 0, node->GetName()));

    // This is now initialized inside of the Input nodes, with the proper connections.
    for (auto node : InputNodes(nullptr))
    {
        // TODO: use if (!Is<ITakesDynamicAxis>(node))...
        auto n = dynamic_pointer_cast<ITakesDynamicAxis>(node);
        if (!n)
            LogicError("Expected %ls to implement ITakesDynamicAxis, but it doesn't.", node->NodeDescription().c_str());
        std::wstring axisName = n->GetRequestedDynamicAxis();

        if (axisName == L"")
        {
            // Legacy behavior: One shared MBLayout
            // TODO Remove m_pMBLayoutOfNetwork altogether. See issue 358.
            node->LinkToMBLayout(m_pMBLayoutOfNetwork);
        }
        else
        {
            auto axisNode = GetNodeFromName(axisName);

            if (!axisNode)
                RuntimeError("%ls: Can't find node '%ls' for retrieving dynamic axis.", node->NodeDescription().c_str(), axisName.c_str());

            // For now we require the node to be a DynamicAxisNode, though we could derive the same from other nodes. This would involve
            // more dependencies on the order in which things are evaluated, though.
            if (axisNode->OperationName() != L"DynamicAxis")
                RuntimeError("%ls: dynamicAxis argument must be of type DynamicAxis(), but got %ls.", node->NodeDescription().c_str(), axisNode->NodeDescription().c_str());
            if (!axisNode->HasMBLayout())
                LogicError("%ls: Expected %ls to have MBLayout, but it doesn't.", node->NodeDescription().c_str(), axisNode->NodeDescription().c_str());
            node->LinkToMBLayout(axisNode->GetMBLayout());
        }
    }
}

// -----------------------------------------------------------------------
// validation
// -----------------------------------------------------------------------

// validate the sub-network needed to evaluate a specific output node
// This calls Validate() on every node in evaluation order (allowing information to propagate forward through the net).
// This is called lazily but once only per node until next ClearCache().
// MBLayout links are expected to have been set up already for inputs, and reset to nullptr for all other nodes.
void ComputationNetwork::ValidateNetwork()
{
    // we call Validate() on all nodes in order to validate them, that is, to set up the MBLayout and the FunctionValues dimensions
    // A problem is that recurrent loops may require partial validation.
    // Nodes validated on partial input (i.e. some children not yet validated) will be revisited.
    const auto& nodes = GetEvalOrder(nullptr);

    for (auto& node : nodes)
    {
        node->m_visited = false;
        node->m_needsGradient = node->IsParameterUpdateRequired(); // these get propagated upwards in the following
    }

    // loop and validate until we are done
    // steps:
    //  - validate (not final)          // not final means no dimension checks
    //    Keep going through the list until all nodes have been validated and all inputs have been validated as well.
    //  - validate (final)              // final means consistency checks
    //    Fail if any change during this stage.
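    //    Multiple non-final passes are needed because nodes inside recurrent loops may be validated before all of
    //    their inputs are; shape and layout inference therefore proceeds incrementally until a fixed point is reached.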
    size_t pass = 1;
    size_t toValidate = nodes.size();
    while (toValidate > 0)
    {
        if (TraceLevel() > 0)
            fprintf(stderr, "\nValidating network. %d nodes to process in pass %d.\n\n", (int) toValidate, (int) pass);
        toValidate = ValidateNodes(nodes, /*isFirstPass=*/pass == 1, false /*isFinalValidationPass*/);
        pass++;
    }
    if (TraceLevel() > 0)
        fprintf(stderr, "\nValidating network, final pass.\n\n");
    toValidate = ValidateNodes(nodes, /*isFirstPass=*/pass == 1, true /*isFinalValidationPass*/);
    if (toValidate != 0)
        LogicError("ValidateSubNetwork: ValidateNodes(true) unexpectedly returned with work left to do.");

    // propagate some info to SEQTraversalFlowControlNode
    // TODO: In the future we should validate not on the flat list but the PARTraversalFlowControlNode structure. Then this will be unnecessary.
    for (auto& recInfo : m_allSEQNodes)
    {
        auto& node = recInfo->m_sourceNode;
        recInfo->m_needsGradient = node->m_needsGradient;
        recInfo->LinkToMBLayout(node->GetMBLayout());
    }

    for (auto& node : nodes)
    {
        // nodes must output non-zero dimensional data, otherwise assume user error
        if (!node->m_needsDynamicValidation && node->GetSampleLayout().GetNumElements() == 0)
            RuntimeError("%ls operation has 0 elements", node->NodeName().c_str());
    }
    if (TraceLevel() > 0)
        fprintf(stderr, "\n\n");

    // logging the non-default-layout nodes
    vector<ComputationNodeBasePtr> nonDefaultNodes;
    for (auto node : nodes)
    {
        if (!(node->GetMBLayout() == m_pMBLayoutOfNetwork))
            nonDefaultNodes.push_back(node);
    }
#if 0 // this message is no longer necessary
    if (TraceLevel() > 0 && !nonDefaultNodes.empty())
    {
        fprintf(stderr, "%d out of %d nodes do not share the minibatch layout with the input data.\n", (int)nonDefaultNodes.size(), (int)nodes.size());
        // for (auto node : nonDefaultNodes)
        //    fprintf(stderr, "    %ls\n", node->NodeName().c_str());
        // fprintf(stderr, "\n\n");
    }
#endif
}

// helper to discover dimension changes
static pair<TensorShape, bool> GetDims(const ComputationNodeBasePtr& node)
{
    return make_pair(node->GetSampleLayout(), node->HasMBLayout());
}

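// validate a single node and report whether anything about it changed (MBLayout, sample layout, dimensions, m_needsGradient, ...)
// A return value of true means the node changed during this call, so another validation pass over the network is needed.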
bool ComputationNetwork::ValidateNode(ComputationNodeBasePtr node, bool isFinalValidationPass) const
{
    const auto& children = node->GetInputs();

    // keep state
    MBLayoutPtr oldMBLayoutPtr = node->GetMBLayout();
    auto dim = GetDims(node);
    vector<pair<TensorShape, bool>> childDims;
    for (auto& child : children)
        childDims.push_back(GetDims(child));
    auto sampleLayout = node->GetSampleLayout();

    // also take the opportunity to propagate m_needsGradient and m_needsDynamicValidation
    auto nodeNeedsDynamicValidation = node->NeedsDynamicValidation();
    node->m_needsDynamicValidation |= node->ForceDynamicValidation();
    auto needsGradient = node->m_needsGradient;
    for (auto& child : children) // TODO: do we need a check that this is stable if isFinalValidationPass?
    {
        node->m_needsGradient |= child->m_needsGradient;
        node->m_needsDynamicValidation |= child->m_needsDynamicValidation;
    }

    // We do call validate(final) as many times as needed, since stuff may have changed underneath.
    node->Validate(isFinalValidationPass && !node->m_needsDynamicValidation /*final*/); // all nodes have been visited: do verification instead of just inference

    // check state --node will be valid if all nodes have been visited and node has not been updated
    bool unchanged = true;
    unchanged &= (oldMBLayoutPtr == node->GetMBLayout());
    unchanged &= (dim == GetDims(node));
    vector<pair<TensorShape, bool>> newChildDims;
    for (auto& child : children)
        newChildDims.push_back(GetDims(child));
    unchanged &= (childDims == newChildDims);
    unchanged &= (sampleLayout == node->GetSampleLayout());
    unchanged &= (needsGradient == node->m_needsGradient);
    unchanged &= (nodeNeedsDynamicValidation == node->m_needsDynamicValidation);
    return !unchanged;
}

// perform one pass of validation over the topologically-sorted node set
// returns how many nodes either could not be validated yet or have changed and thus must be redone
size_t ComputationNetwork::ValidateNodes(list<ComputationNodeBasePtr> nodes, bool isFirstPass, bool isFinalValidationPass)
{
    size_t todo = 0;
    for (auto& node : nodes)
    {
        const auto& children = node->GetInputs();
        const bool isLeaf = node->IsLeaf();
        // only validate a node once at least one of its children has been visited (leaf nodes are always validated)
        bool hasVisitedChild = false;
        bool allChildrenVisited = true;
        for (auto& child : children)
        {
            hasVisitedChild |= child->m_visited; // if not a single visited child then no point in validating
            allChildrenVisited &= child->m_visited;

            // Make sure we don't use DynamicAxis in places where it was not designed for.
            // This is a stop-gap. We need a more coherent concept for passing of shapes.
            if (child->OperationName() == L"DynamicAxis")
                RuntimeError("%ls: Cannot be used as input to another node. It can only be used on the 'dynamicAxis' property of an Input node.", child->NodeDescription().c_str());
        }

        bool valid = false; // remains false if there is not yet at least one visited child
        if (hasVisitedChild || isLeaf) // got at least one visited child (or a leaf): it makes sense to call Validate()
        {
            string prevPrototype = node->FormatOperationPrototype("");
            bool unchanged;
            try
            {
                unchanged = !ValidateNode(node, isFinalValidationPass);
                string updatedPrototype = node->FormatOperationPrototype("");
#if 0           // print prototype in final validation pass. Problematic for tracking down validation errors in loops.
                unchanged;
                if (isFinalValidationPass)
#else           // print prototype upon every change (useful for debugging)
                if (isFirstPass || !unchanged || prevPrototype != updatedPrototype)
#endif
                    if (TraceLevel() > 0)
                        fprintf(stderr, "Validating --> %s\n", updatedPrototype.c_str());
            }
            catch (...) // if validation failed then print the prototype anyway so one can see the input args
            {
                fprintf(stderr, "Validating --> %s FAILED\n", prevPrototype.c_str());
                throw;
            }
            node->m_visited = true;
            // print the new type
            // sanity checks
            if (isFinalValidationPass && !unchanged)
                LogicError("ValidateSubNetwork: %ls %ls operation changed during final validation.", node->NodeName().c_str(), node->OperationName().c_str());
            if (isFinalValidationPass && !allChildrenVisited)
                LogicError("ValidateSubNetwork: %ls %ls operation in final validation although not all children were visited?", node->NodeName().c_str(), node->OperationName().c_str());
            // if all children valid then
            valid = (allChildrenVisited && unchanged) || isLeaf;
        }
        // count those that we need to redo
        if (!valid)
            todo++;
    }
    return todo;
}

// -----------------------------------------------------------------------
// memory allocation
// -----------------------------------------------------------------------
// mark nodes that are purely induced by parameters as non-sharable and create space for value if null
void ComputationNetwork::MarkValueNonSharableNodes()
{
    const auto& nodes = GetEvalOrder(nullptr);
    std::map<wstring, bool> allLeafDescendentsAreParametersOrPreComputeNodes;
    std::list<ComputationNodeBasePtr> allLearnableParameters = GetNodesWithType(OperationNameOf(LearnableParameter));
    // note: we cannot use m_learnableParameters because we need all parameter nodes, regardless of whether they require updates or not
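    // allLeafDescendentsAreParametersOrPreComputeNodes records, per node name, whether all leaf descendants of that node
    // are LearnableParameters (PreCompute nodes are treated as such leaves regardless of their own inputs). Since we walk
    // the nodes in eval order, a node's inputs have already been classified by the time the node itself is processed.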

    std::list<ComputationNodeBasePtr> allPreComputeNodes;
    for (const auto& node : nodes)
    {
        if (node->Is<IPreComputeNode>())
            allPreComputeNodes.push_back(node);
    }

    for (auto& node : nodes)
    {
        auto inputs = node->GetInputs();

        // Mark the UserDefinedV2FunctionNode and all its inputs as ValueNonSharable, since
        // the inputs and outputs of a UDF may be externally preserved by the UDF implementation
        // for backpropagation. Reusing them within the network is thus not possible, as
        // we do not control when the user actually releases the input/output Matrices that
        // they may have held in the backprop state returned from the UDF's forward pass.
        bool isUserDefinedV2FunctionNode = (node->OperationName() == L"UserDefinedV2Function");
        if (isUserDefinedV2FunctionNode)
        {
            node->MarkValueNonSharable();
            for (auto input : inputs)
                input->MarkValueNonSharable();
        }

        wstring myname = node->NodeName();
        bool allParametersOrPreComputeNodes = true;

        if (inputs.size()) // we don't do the check for leaf nodes, because all possible leaf nodes (input/parameter/precompute nodes) are already marked as non-sharable
        {
            if (std::find(allPreComputeNodes.begin(), allPreComputeNodes.end(), node) == allPreComputeNodes.end())
            {
                for (auto input : inputs)
                {
                    const auto& inputName = input->NodeName();
                    if (allLeafDescendentsAreParametersOrPreComputeNodes.find(inputName) == allLeafDescendentsAreParametersOrPreComputeNodes.end())
                    {
                        // not found, which means it is a leaf node (we are traversing in eval order)
                        assert(input->IsLeaf() || input->IsPartOfLoop());
                        if (std::find(allLearnableParameters.begin(), allLearnableParameters.end(), input) != allLearnableParameters.end())
                        {
                            allLeafDescendentsAreParametersOrPreComputeNodes[inputName] = true;
                        }
                        else
                        {
                            allParametersOrPreComputeNodes = false;
                            allLeafDescendentsAreParametersOrPreComputeNodes[inputName] = false;
                            break;
                        }
                    }
                    else
                    {
                        if (allLeafDescendentsAreParametersOrPreComputeNodes[inputName] == false)
                        {
                            allParametersOrPreComputeNodes = false;
                            break;
                        }
                    }
                }
            }

            allLeafDescendentsAreParametersOrPreComputeNodes[myname] = allParametersOrPreComputeNodes;
            if (allParametersOrPreComputeNodes)
                node->MarkValueNonSharable();
        }
    }
}

// From the set of nodes extract all nodes which are used as accumulator nodes.
set<ComputationNodeBasePtr> ComputationNetwork::ExtractNodesWhichAccumulateResult(set<ComputationNodeBasePtr> candidates)
{
    const auto& nodes = GetEvalOrder(nullptr);

    // Set of nodes whose leaf descendants are learnable parameters only. Initially, we add the learnable parameter nodes
    // to the set. Later, we add all nodes whose children are all contained in this set.
    auto allLearnableParameters = GetNodesWithType(OperationNameOf(LearnableParameter));
    set<ComputationNodeBasePtr> allLeafDescendantsAreLearnableParameters(allLearnableParameters.begin(),
                                                                         allLearnableParameters.end());

    // Set of nodes that accumulate samples from the input.
    // We initially add all epoch accumulator nodes to this list. Later, we add all nodes that have at least one child
    // node from this list and all other child nodes from the list above (whose leaf descendants are learnable
    // parameters only). A combination of an accumulator node with another accumulator node, or with nodes whose leaf
    // descendants are learnable parameters, is itself an accumulator node.
    auto allEpochAccumulatorNodes = GetNodesWithType(OperationNameOf(EpochAccumulatorNode));
    set<ComputationNodeBasePtr> accumulatorNodes(allEpochAccumulatorNodes.begin(), allEpochAccumulatorNodes.end());

    if (!accumulatorNodes.empty())
    {
        for (auto& node : nodes)
        {
            auto inputs = node->GetInputs();
            bool hasAccumulatorInput = false;
            bool areAllLeafDescendantsLearnableNodes = true;
            // Indicates that node shouldn't be added to set of nodes with learnable parameter descendants, nor to the
            // set of accumulator nodes.
            bool skipNode = false;
            for (auto input : inputs)
            {
                bool hasAllLearnableParameterDescendents = allLeafDescendantsAreLearnableParameters.find(input) !=
                                                           allLeafDescendantsAreLearnableParameters.end();
                bool isAccumulatorNode = accumulatorNodes.find(input) != accumulatorNodes.end();
                if (!hasAllLearnableParameterDescendents)
                    areAllLeafDescendantsLearnableNodes = false;
                if (isAccumulatorNode)
                    hasAccumulatorInput = true;
                if (!isAccumulatorNode && !hasAllLearnableParameterDescendents)
                {
                    skipNode = true;
                    break;
                }
            }
            if (skipNode) continue;
            if (areAllLeafDescendantsLearnableNodes)
                allLeafDescendantsAreLearnableParameters.insert(node);
            if (hasAccumulatorInput)
                accumulatorNodes.insert(node);
        }
    }

    // Extract all candidate nodes that appear in set of accumulator nodes.
    set<ComputationNodeBasePtr> intersection;
    set_intersection(accumulatorNodes.begin(), accumulatorNodes.end(),
                     candidates.begin(), candidates.end(),
                     inserter(intersection, intersection.begin()));

    return intersection;
}

// print memory-sharing information to log
void ComputationNetwork::PrintMemorySharingStructure(const vector<ComputationNodeBasePtr>& nodes)
{
    map <const MatrixBase*, set<wstring>> memSharingStructure;
    size_t numMatrices = 0;
    for (const auto& node : nodes)
    {
        set<pair<const MatrixBase*, wstring>> matrixInfo = node->GetMatrixInfo();
        for (const auto& item : matrixInfo) // {value} or {value, gradient}
        {
            memSharingStructure[item.first].insert(item.second);
            numMatrices++;
        }
    }

    // count shared/unshared
    size_t numShared = 0;
    size_t numUnshared = 0;
    for (const auto& item : memSharingStructure)
    {
        if (item.second.size() < 2) // unshared matrices
            numUnshared++;
        else                        // shared matrices
            numShared++;
    }

    fprintf(stderr, "\nMemory Sharing: Out of %d matrices, %d are shared as %d, and %d are not shared.\n", (int)numMatrices, (int)(numMatrices - numUnshared), (int)numShared, (int)numUnshared);

    fprintf(stderr, "\nHere are the ones that share memory:\n"); 
    for (const auto& item : memSharingStructure)
    {
        if (item.second.size() >= 2)
        {
            // Format:
            // { node1
            //   node2 }
            // { node3
            //   node4
            //   node5 }
            const char* delim = "\t{ ";
            for (const auto& memShareInfo : item.second)
            {
                fprintf(stderr, "%s%ls", delim, memShareInfo.c_str());
                delim = "\n\t  ";
            }
            fprintf(stderr, " }\n");
        }
    }
    fprintf(stderr, "\nHere are the ones that don't share memory:\n");
    for (const auto& item : memSharingStructure)
    {
        if (item.second.size() < 2)
        {
            fprintf(stderr, "\t{%ls}\n", item.second.begin()->c_str()); 
        }
    }
    fprintf(stderr, "\n");
}


// this function will need to be called before actual validation and execution to
// predetermine how to share matrices to reduce memory usage.
// TODO: find a simple topological order and allocateEvalMatrices on that order directly
// without passing in eval, out, and train nodes.
void ComputationNetwork::AllocateAllMatrices(const std::vector<ComputationNodeBasePtr>& evalRootNodes,
                                             const std::vector<ComputationNodeBasePtr>& outValueRootNodes,
                                             ComputationNodeBasePtr trainRootNode)
{
    if (AreMatricesAllocated())
        return;

    // Allocate memory for forward/backward computation
    if (TraceLevel() > 0)
        fprintf(stderr, "\n\nAllocating matrices for forward and/or backward propagation.\n");

    VerifyIsCompiled("AllocateAllMatrices");

    std::vector<ComputationNodeBasePtr> forwardPropRoots;
    forwardPropRoots.insert(forwardPropRoots.end(), evalRootNodes.begin(), evalRootNodes.end());
    forwardPropRoots.insert(forwardPropRoots.end(), outValueRootNodes.begin(), outValueRootNodes.end());
    if (trainRootNode != nullptr)
        forwardPropRoots.push_back(trainRootNode);

    // Mark all the eval, output and criterion roots as non-shareable
    for (auto& rootNode : forwardPropRoots)
        rootNode->MarkValueNonSharable();

    // Due to special topology, if a node is solely induced by parameters, its function value should not be shared
    MarkValueNonSharableNodes();

    bool performingBackPropagation = (trainRootNode != nullptr);

    // Create a composite Eval order with the specified nodes as roots
    // For each node determine parents and whether the output of the
    // node is needed during back propagation
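    // outputValueNeededDuringBackProp records whether a node's Value must be kept alive for the backward pass, i.e.
    // whether its own gradient computation or that of one of its parents reads it. parentsMap records, for each node,
    // the set of parents that have not yet consumed its Value; it drives the release of Value matrices below.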
    std::unordered_map<ComputationNodeBasePtr, bool> outputValueNeededDuringBackProp;
    std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>> parentsMap;
    std::unordered_set<ComputationNodeBasePtr> uniqueForwardPropEvalNodes;
    for (auto& rootNode : forwardPropRoots)
    {
        for (const auto& node : GetEvalOrder(rootNode))
        {
            if (uniqueForwardPropEvalNodes.find(node) == uniqueForwardPropEvalNodes.end())
                uniqueForwardPropEvalNodes.insert(node);

            for (int i = 0; i < node->GetNumInputs(); i++)
            {
                ComputationNodeBasePtr input = node->GetInputs()[i];
                parentsMap[input].insert(node);

                if (performingBackPropagation)
                {
                    if (outputValueNeededDuringBackProp.find(input) == outputValueNeededDuringBackProp.end())
                        outputValueNeededDuringBackProp[input] = input->NeedsGradient() && input->OutputUsedInComputingInputNodesGradients();

                    outputValueNeededDuringBackProp[input] |= (node->NeedsGradient() && node->InputUsedInComputingInputNodesGradients(i));
                }
                else
                    outputValueNeededDuringBackProp[input] = false;
            }
        }
    }

    // gradient reuse maps
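    // These record parent/child pairs for which the child's gradient buffer may alias the parent's gradient
    // (ParentGradientOptimization::Reuse); the alias groups are later compacted and handed to the MatrixPool.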
    std::unordered_map<MatrixPool::AliasNodePtr, std::unordered_set<MatrixPool::AliasNodePtr>> gradientReuseChildrenMap;
    std::unordered_map<MatrixPool::AliasNodePtr, MatrixPool::AliasNodePtr> gradientReuseParentMap;
    for (auto& keyValue : parentsMap)
    {
        // Indicate on the node that its parent overwrites its gradient if the node is not part of a loop
        // and has exactly one parent that implements the gradient overwrite optimization
        if (Globals::ShouldOptimizeGradientAccumulation() &&
            !keyValue.first->IsPartOfLoop() &&
            (keyValue.second.size() == 1))
        {
            auto parent = *keyValue.second.begin();
            auto opt = parent->ImplementsGradientOptimization(keyValue.first.get());
            if (opt != ParentGradientOptimization::None && trainRootNode != parent)
            {
                // We cannot enable the gradient overwrite/reuse optimization if this node's (lone) parent
                // has this same node as multiple of its inputs, since in that case the
                // gradients will flow back from multiple paths of the same parent into the input

                auto& allInputsOfParent = parent->GetInputs();
                if (std::count(allInputsOfParent.begin(), allInputsOfParent.end(), keyValue.first) <= 1)
                {
                    auto child = keyValue.first;
                    child->SetParentGradientOptimization(opt);
                    if (opt == ParentGradientOptimization::Reuse)
                    {
                        gradientReuseChildrenMap[&*parent].insert(&*child);
                        if (gradientReuseParentMap.find(&*child) != gradientReuseParentMap.end())
                            LogicError("Already has a gradient reuse parent.");
                        gradientReuseParentMap[&*child] = &*parent;
                    }
                }
            }
        }
    }

    m_matrixPool.Reset();

    TravserseInSortedGlobalEvalOrder(forwardPropRoots, [&outputValueNeededDuringBackProp, &parentsMap, this](const ComputationNodeBasePtr& node) {
        if (node->Is<SEQTraversalFlowControlNode>())
        {
            auto seqTraversalFlowControlNode = node->As<SEQTraversalFlowControlNode>();
            for (auto& loopNode : seqTraversalFlowControlNode->m_nestedNodes)
                loopNode->SetOutputNeededDuringBackprop(outputValueNeededDuringBackProp[loopNode]);

            seqTraversalFlowControlNode->RequestMatricesBeforeForwardProp(m_matrixPool);

            for (auto& loopNode : seqTraversalFlowControlNode->m_nestedNodes)
                ReleaseMatricesAfterEvalForChildren(loopNode, parentsMap);
        }
        else
        {
            node->SetOutputNeededDuringBackprop(outputValueNeededDuringBackProp[node]);
            node->RequestMatricesBeforeForwardProp(m_matrixPool);
            // we only release matrices for the children since the root node's information will be used
            // and should not be shared with others
            ReleaseMatricesAfterEvalForChildren(node, parentsMap);
        }
    });

    if (trainRootNode != nullptr)
    {
        const std::list<ComputationNodeBasePtr>& backPropNodes = GetEvalOrder(trainRootNode);

        // compact the alias map for cases like s = a + b + c + d

        std::unordered_map<MatrixPool::AliasNodePtr, std::unordered_set<MatrixPool::AliasNodePtr>> compactGradientAliasMap;
        std::unordered_map<MatrixPool::AliasNodePtr, MatrixPool::AliasNodePtr> compactGradientAliasRootMap;
        for (const auto& gradientReuseKeyValue : gradientReuseChildrenMap)
        {
            // keep searching parent until reaching root

            auto parent = gradientReuseKeyValue.first;
            auto parentIter = gradientReuseParentMap.find(parent);
            while (parentIter != gradientReuseParentMap.end())
            {
                parent = parentIter->second;
                parentIter = gradientReuseParentMap.find(parent);
            }

            // add children to the alias group under the root

            auto children = gradientReuseKeyValue.second;
            compactGradientAliasMap[parent].insert(children.begin(), children.end());

            for (const auto& child : children)
            {
                if (compactGradientAliasRootMap.find(child) != compactGradientAliasRootMap.end())
                    LogicError("one node cannot be in two alias group");

                compactGradientAliasRootMap[child] = parent;
            }

            // and add root itself to the alias group

            compactGradientAliasMap[parent].insert(parent);
            compactGradientAliasRootMap[parent] = parent;
        }

        // print the memory aliasing info
        if (TraceLevel() > 0 && compactGradientAliasRootMap.size() > 0)
        {
            fprintf(stderr, "\nGradient Memory Aliasing: %d are aliased.\n", (int)compactGradientAliasRootMap.size());
            for (const auto pair : compactGradientAliasRootMap)
            {
                auto child = (const ComputationNodeBase*)pair.first;
                auto parent = (const ComputationNodeBase*)pair.second;
                if (child != parent)
                    fprintf(stderr, "\t%S (gradient) reuses %S (gradient)\n", child->GetName().c_str(), parent->GetName().c_str());
            }
        }

        m_matrixPool.SetAliasInfo(compactGradientAliasMap, compactGradientAliasRootMap);

        // now, simulate the gradient computation order to determine how to allocate matrices
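        // completedGradient tracks which SEQ loops have been handled already, since all member nodes of a loop
        // map to the same recInfo and its matrices must be requested/released only once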
        set<ComputationNodeBasePtr> completedGradient;

        // we need to call it here since we always compute gradients for children, and the root node is not a child of any other node
        trainRootNode->RequestMatricesBeforeBackprop(m_matrixPool);

        for (auto iter = backPropNodes.rbegin(); iter != backPropNodes.rend(); iter++) // for gradient computation, traverse in reverse order
        {
            auto n = *iter;
            if (n->IsPartOfLoop())
            {
                std::vector<ComputationNodeBasePtr> recurrentNodes;
                shared_ptr<SEQTraversalFlowControlNode> recInfo = FindInRecurrentLoops(m_allSEQNodes, n);
                if (completedGradient.insert(recInfo).second)
                {
                    // SEQ mode: allocate all in loop first, then deallocate again
                    // TODO: next step: use PARTraversalFlowControlNode::AllocateGradientMatricesForInputs() and ReleaseMatricesAfterBackprop()...
                    // BUGBUG: naw, ^^ would not work! Wrong order! Need to rethink this. Need to make AllocateEvalMatrices() and AllocateGradientMatrices() the virtual functions.
                    recInfo->AllocateGradientMatricesForInputs(m_matrixPool);
                    // Loops are computed sample by sample so we have to allocate them all
                    recInfo->ReleaseMatricesAfterBackprop(m_matrixPool);
                }
            }
            else
            {
                // PAR mode: we can allocate and immediately deallocate one by one
                n->AllocateGradientMatricesForInputs(m_matrixPool);
                // The root node's information will be used and should not be shared with others; also, it is small (1x1)
                if ((n != trainRootNode) && n->NeedsGradient())
                    n->ReleaseMatricesAfterBackprop(m_matrixPool);
            }
        }
    }

    m_matrixPool.OptimizedMemoryAllocation(); 
    m_areMatricesAllocated = true;

    // TO DO: At the time of AllocateAllMatrices we don't know the minibatch size. In theory one may allocate memory again once we start to receive
    // data from the reader (and the minibatch size is known). For some problems, minibatch size can change constantly, and there needs to be a 
    // tradeoff in deciding how frequent to run optimized memory allocation. For now, we do it only once at the very beginning for speed concerns. 

    // TO DO: when some matrices are sparse, the memory size request may be wrong. One may need to call OptimizedMemoryAllocation later again 
    // if the requests of sparse allocation and release are re-processed correctly. Future work. 

    // print the memory sharing structure
    if (TraceLevel() > 0)
        PrintMemorySharingStructure(GetAllNodes());
}

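// Helper used during the forward-allocation simulation above: remove n from the pending-parents set of each of its inputs.
// Once the last parent of an input has consumed its Value, the input's Value matrix may be returned to the pool
// (ReleaseMatricesAfterForwardProp() itself decides, e.g. based on sharability and whether the value is still needed for backprop).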
void ComputationNetwork::ReleaseMatricesAfterEvalForChildren(ComputationNodeBasePtr n, std::unordered_map<ComputationNodeBasePtr, std::unordered_set<ComputationNodeBasePtr>>& parentsMap)
{
    for (int i = 0; i < n->GetNumInputs(); i++)
    {
        ComputationNodeBasePtr pNode = n->GetInputs()[i];
        if (!parentsMap[pNode].empty())
        {
            parentsMap[pNode].erase(n);
            if (parentsMap[pNode].empty())
                pNode->ReleaseMatricesAfterForwardProp(m_matrixPool);
        }
    }
}

}}}