// Source: https://github.com/Microsoft/CNTK (raw file)
// Tip revision: 5dffd712e05feec7198d7ced6777eeafc602a4a9 authored by REDMOND\sayanpa on 16 October 2017, 20:25:19 UTC
// Added DSSM tutorial
// Tip revision: 5dffd71
// ComputationGraphAlgorithms.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include <algorithm>
#include <cassert>
#include <functional>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <vector>

//
// Header only algorithms for working with execution graphs.
// The functionality is used by the computational network and the code gen evaluation engine.
// Currently this is refactoring of existing legacy code.
// In the future we should consider using Boost::Graph instead, but this will require more testing
// in order not to break current behavior/baselines.
//
namespace CNTK
{
    //
    // Abstract view of a directed graph consumed by the algorithms in this file.
    // A traversal starts at Roots() (usually the nodes without successors) and
    // walks backwards through Predecessors().
    //
    // Requirements on TNode, the vertex type:
    //     - usable as an element/key of std::set, std::vector and std::map, i.e.
    //         copyable and
    //         ordered by operator<;
    //     - a ToString function must exist for it (used to build exception messages on error).
    //
    template<class TNode>
    class DirectedGraph
    {
    public:
        //
        // The nodes that feed directly into 'node'.
        //
        virtual const std::vector<TNode>& Predecessors(const TNode& node) const = 0;

        //
        // Starting points for a traversal of the graph.
        // Typically leaves, although inner nodes may be listed as well.
        //
        virtual const std::vector<TNode>& Roots() const = 0;

        // Virtual so that implementations can be destroyed through a base pointer.
        virtual ~DirectedGraph() = default;
    };

    //
    // Forward declarations of the main algorithms used to define the
    // execution order of a computational network.
    // The implementations are at the end of this file.
    //

    //
    // Returns the nodes reachable from 'startNodes' (by following Predecessors) in post-order:
    // the predecessors of a node are visited first, then the node itself.
    // Start nodes are processed in the given order and each reachable node appears exactly once.
    //
    template<class TNode>
    inline std::list<TNode> PostOrderTraversal(const DirectedGraph<TNode>& graph, const std::vector<TNode>& startNodes);

    //
    // Class representing a strongly connected component: a group of graph nodes,
    // as produced by StrongComponents.
    //
    template<class TNode>
    struct StrongComponent final
    {
        //
        // Creates a component holding 'nodes'.
        // The vector is taken by value and then moved into place. A 'const std::vector<TNode>&&'
        // parameter would silently copy, because std::move of a const object selects the copy
        // constructor. By value, rvalue arguments are moved and lvalue arguments are copied.
        //
        StrongComponent(std::vector<TNode> nodes) :
            m_nodes(std::move(nodes))
        {}

        //
        // Returns the nodes of the component in their current order.
        //
        const std::vector<TNode>& Nodes() const
        {
            return m_nodes;
        }

        //
        // Replaces the order of the nested nodes.
        // 'nodes' must contain exactly the same set of nodes (asserted in debug builds).
        //
        void UpdateNodeOrder(std::vector<TNode>&& nodes)
        {
            assert(std::set<TNode>(m_nodes.begin(), m_nodes.end()) == std::set<TNode>(nodes.begin(), nodes.end()));
            m_nodes = std::move(nodes);
        }

        //
        // Checks if the node belongs to the component (linear search using operator==).
        //
        bool Contains(const TNode& node) const
        {
            return std::find(m_nodes.begin(), m_nodes.end(), node) != m_nodes.end();
        }

    private:
        std::vector<TNode> m_nodes;
    };

    //
    // Returns the strongly connected components (Tarjan's algorithm) of the part of the graph
    // reachable from graph.Roots(). Components consisting of a single node are not reported.
    //
    template<class TNode>
    std::vector<StrongComponent<TNode>> StrongComponents(const DirectedGraph<TNode>& graph);

    //
    // Sorts the nodes inside each strongly connected component into evaluation order.
    // 'delay' identifies the delay nodes at which loops are broken.
    // The order is defined as follows:
    //  - take a strongly connected component
    //  - find all its nodes that feed only into delay nodes (within the component); these become the new roots
    //  - perform a topological sort starting at these roots, not descending past delay nodes
    //  - update the component with the reordered list of sorted nodes
    //
    template<class TNode>
    void EvaluationSort(const DirectedGraph<TNode>& graph, std::function<bool(const TNode&)> delay, std::vector<StrongComponent<TNode>>& strongComponents);

    //
    // Sorts all nodes of the graph in the evaluation order given by the root nodes.
    // Strongly connected components should already be sorted with the EvaluationSort function.
    //
    template<class TNode>
    std::vector<TNode> GlobalEvaluationSort(const DirectedGraph<TNode>& graph, const std::vector<StrongComponent<TNode>>& strongComponents);

    //
    // Actual implementation of the above functions.
    //
    namespace Internal
    {
        // Functions from this namespace should not be used directly.

        //
        // Appends to 'result', in post-order, every node reachable from 'node' through
        // Predecessors() that is not yet in 'visited'; every node it reaches is added to 'visited'.
        //
        // Uses an explicit stack instead of recursion so that very deep graphs (long chains of
        // nodes) cannot overflow the call stack. The emitted order is the same as the recursive
        // depth-first formulation: predecessors in order, then the node itself.
        // Declared 'inline' rather than 'static': this is a header, and 'static' would give every
        // translation unit its own private copy.
        //
        template<class TNode>
        inline void PostOrderTraversalImpl(const DirectedGraph<TNode>& graph, const TNode& node, std::set<TNode>& visited, std::list<TNode>& result)
        {
            if (!visited.insert(node).second)
                return;

            // One frame per node whose predecessors are still being explored.
            struct Frame
            {
                TNode current;
                const std::vector<TNode>* predecessors;
                size_t next; // index of the next predecessor to examine
            };

            std::vector<Frame> pending;
            pending.push_back(Frame{ node, &graph.Predecessors(node), 0 });
            while (!pending.empty())
            {
                Frame& top = pending.back();
                if (top.next < top.predecessors->size())
                {
                    // 'predecessor' refers into the graph's storage, not into 'pending', so it stays
                    // valid even if push_back below reallocates 'pending' (which invalidates 'top').
                    const TNode& predecessor = (*top.predecessors)[top.next++];
                    if (visited.insert(predecessor).second)
                        pending.push_back(Frame{ predecessor, &graph.Predecessors(predecessor), 0 });
                }
                else
                {
                    // All predecessors handled: emit the node (post-order).
                    result.push_back(top.current);
                    pending.pop_back();
                }
            }
        }

        //
        // Per-node bookkeeping for Tarjan's strongly connected components algorithm,
        // used by StrongComponentsImpl.
        // The fields correspond to the quantities described at
        // https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
        // (m_index = "index", m_minIndex = "lowlink", m_inStack = "onStack").
        //
        struct StrongComponentNodeState final
        {
            bool m_visited = false; // set once the depth-first search has reached the node
            int m_index = -1;       // order in which the node was discovered (-1 = not yet)
            int m_minIndex = -1;    // smallest m_index known to be reachable from the node
            bool m_inStack = false; // true while the node is on the Tarjan stack
        };

        //
        // Recursive implementation of Tarjan's algorithm for finding all strongly connected
        // components reachable from 'node'.
        //   graph            - the graph being searched (walked through Predecessors()).
        //   node             - the node to explore; must not have been visited yet.
        //   nodeStack        - Tarjan's stack of nodes whose component has not been closed yet.
        //   index            - next discovery index to hand out; incremented for every visited node.
        //   state            - per-node bookkeeping, created on demand through operator[].
        //   strongComponents - receives every closed component that has more than one node.
        //
        template<class TNode>
        void StrongComponentsImpl(
            const DirectedGraph<TNode>& graph,
            const TNode& node,
            std::stack<TNode>& nodeStack,
            size_t& index,
            std::map<TNode, Internal::StrongComponentNodeState>& state,
            std::vector<StrongComponent<TNode>>& strongComponents)
        {
            // Look the node up once. std::map never invalidates references on insertion, so this
            // reference stays valid while the recursive calls below add entries for new nodes.
            auto& nodeState = state[node];
            assert(!nodeState.m_visited);

            // set the index (in order of visitation)
            // Each node is assigned a unique integer m_index, which numbers the nodes consecutively in the order in which they are discovered.
            // The state stores indices as int, hence the explicit conversion.
            nodeState.m_index = static_cast<int>(index);
            nodeState.m_minIndex = nodeState.m_index;
            index++;

            nodeState.m_visited = true;

            // The nodes are placed on the stack in the order in which they are visited.
            // When the depth-first search recursively explores a node 'node' and its descendants,
            // those nodes are not all necessarily popped from the stack when this recursive call returns.
            // The crucial invariant property is that a node remains on the stack after exploration if and only if it has a path to some node earlier on the stack.
            // At the end of the call that explores 'node' and its descendants, we know whether 'node' itself has a path to any node earlier on the stack.
            // If so, the call returns, leaving 'node' on the stack to preserve the stack invariant.
            // If not, then 'node' must be the root of its strongly connected component, which consists of 'node' together with any later nodes on the stack
            // (such nodes all have paths back to 'node' but not to any earlier node,
            // because if they had paths to earlier nodes then 'node' would also have paths to earlier nodes which is false).
            // This entire component is then popped from the stack and returned, again preserving the invariant. [Wikipedia]
            nodeStack.push(node);
            nodeState.m_inStack = true;

            // set m_minIndex to min over m_minIndex of children
            // m_minIndex (lowlink in Tarjan's notation) represents (roughly speaking) the smallest index of any node known to be reachable from 'node', including 'node' itself. [Wikipedia]
            for (const auto& predecessor : graph.Predecessors(node))
            {
                // Also stable across the recursive call (see above).
                auto& predecessorState = state[predecessor];
                if (!predecessorState.m_visited)
                {
                    // predecessor has not yet been visited; recurse on it
                    StrongComponentsImpl(graph, predecessor, nodeStack, index, state, strongComponents);
                    nodeState.m_minIndex = std::min(nodeState.m_minIndex, predecessorState.m_minIndex);
                }
                else if (predecessorState.m_inStack)
                {
                    // predecessor is on the stack and hence in the current SCC
                    // NOTE! This line is actually different from the BS algorithm
                    nodeState.m_minIndex = std::min(nodeState.m_minIndex, predecessorState.m_index);
                }
            }

            // if 'node' is a root node, then we closed a loop.
            // 'node' must be left on the stack if m_minIndex < m_index,
            // whereas it must be removed as the root of a strongly connected component if m_minIndex == m_index.
            // m_minIndex is computed during the depth-first search from 'node' (above), as this finds the nodes that are reachable from 'node'. [Wikipedia]
            assert(nodeState.m_minIndex <= nodeState.m_index);
            if (nodeState.m_minIndex == nodeState.m_index) // m_minIndex is still equal to m_index, as we set it at the start of this function: we closed a loop
            {
                // gather the list of all nodes in this loop
                std::vector<TNode> nestedNodes;

                for (;;)
                {
                    TNode current = nodeStack.top();
                    nodeStack.pop();

                    state[current].m_inStack = false;
                    nestedNodes.push_back(current);

                    if (current == node) // hit our starting point: done
                        break;
                }

                // not a real loop. In degenerate situation it could be that the delay
                // feeds directly into itself though, but then its still just returns the same value
                // so can be evaluated in a topological sort order.
                if (nestedNodes.size() <= 1)
                    return;

                strongComponents.emplace_back(std::move(nestedNodes));
            }
        }

        //
        // Helper function for EvaluationSort of nodes inside connected components.
        // Creates the processing order within a recurrent loop: a depth-first post-order walk from
        // 'node' over Predecessors() that belong to 'component'. The walk does not descend past nodes
        // for which 'delay' returns true. Each node is appended to 'result' after its predecessors.
        //   visited                - nodes already placed in 'result' or in progress (shared across roots).
        //   nodesOnThePathFromRoot - nodes on the current DFS path; used to detect loops without a delay.
        //   delay                  - taken by const reference so the std::function is not copied
        //                            on every recursive call.
        // Reports a loop that contains no delay node through LogicError.
        //
        template<class TNode>
        void LoopEvaluationSort(std::set<TNode>& visited,
            std::set<TNode>& nodesOnThePathFromRoot,
            std::vector<TNode>& result,
            TNode node,
            const DirectedGraph<TNode>& graph,
            const StrongComponent<TNode>& component,
            const std::function<bool(const TNode&)>& delay)
        {
            if (!visited.insert(node).second)
            {
                // Already seen. If it is still on the current path, we went around a cycle
                // without passing a delay node, so the loop cannot be unrolled.
                if (nodesOnThePathFromRoot.find(node) != nodesOnThePathFromRoot.end())
                    LogicError("Node %ls is part of an infinite loop that cannot be unrolled.", ToString(node).c_str());
                return;
            }

            nodesOnThePathFromRoot.insert(node);

            // Recurse if not a delay, stop when see a recurrence.
            if (!delay(node))
            {
                for (const auto& p : graph.Predecessors(node))
                {
                    if (component.Contains(p))
                        LoopEvaluationSort(visited, nodesOnThePathFromRoot, result, p, graph, component, delay);
                }
            }

            nodesOnThePathFromRoot.erase(node);
            result.push_back(node);
        }
    }

    //
    // Collects, in post-order, every node reachable from 'startNodes' through Predecessors().
    // The predecessors of a node are visited before the node itself, so for an acyclic graph every
    // node appears after all of its predecessors. Start nodes are processed in the given order and
    // each reachable node is emitted exactly once.
    // For more information please see the forward declaration at the beginning of the file.
    //
    template<class TNode>
    inline std::list<TNode> PostOrderTraversal(const DirectedGraph<TNode>& graph, const std::vector<TNode>& startNodes)
    {
        std::set<TNode> alreadySeen;
        std::list<TNode> ordered;
        for (const TNode& start : startNodes)
        {
            // Nodes already emitted from an earlier start node are skipped by the helper.
            Internal::PostOrderTraversalImpl(graph, start, alreadySeen, ordered);
        }
        return ordered;
    }

    //
    // Returns the strongly connected components (with more than one node) of the part of
    // 'graph' reachable from graph.Roots(), using Tarjan's algorithm.
    //
    template<class TNode>
    std::vector<StrongComponent<TNode>> StrongComponents(const DirectedGraph<TNode>& graph)
    {
        std::map<TNode, Internal::StrongComponentNodeState> state;
        std::vector<StrongComponent<TNode>> result;
        std::stack<TNode> nodeStack;
        size_t index = 0;
        for (const auto& root : graph.Roots())
        {
            if (state[root].m_visited)
                continue;

            // Qualified explicitly. Unqualified, the call resolved only through argument-dependent
            // lookup on the Internal::StrongComponentNodeState template argument of 'state'.
            // That is fragile and could become ambiguous with a same-named function in TNode's namespace.
            Internal::StrongComponentsImpl(graph, root, nodeStack, index, state, result);
        }
        return result;
    }

    //
    // Sorts nodes inside strongly connected components according to their evaluation order,
    // breaking loops at the delay nodes.
    //
    // Used algorithm goes as follows:
    //  - take a connected component
    //  - find all its nodes that feed only into delay nodes, these nodes become new roots
    //  - perform the topological sort starting at these roots and breaking at delay nodes
    //  - continue the sort from any node the roots did not reach. Such nodes exist only when the
    //    component contains a loop without a delay node, which the traversal reports via LogicError.
    //    Without this step those nodes would silently disappear from the component in release builds.
    //  - update the component with the reordered list of sorted nodes
    //
    template<class TNode>
    inline void EvaluationSort(const DirectedGraph<TNode>& graph, std::function<bool(const TNode&)> delay, std::vector<StrongComponent<TNode>>& strongComponents)
    {
        for (auto& component : strongComponents)
        {
            // Get all nodes that only have a delay child, these
            // will become new roots for evaluation.
            // 'nestedNodes' stays valid until UpdateNodeOrder at the end of this iteration.
            const auto& nestedNodes = component.Nodes();
            std::set<TNode> newRoots(nestedNodes.begin(), nestedNodes.end());
            for (const auto& node : nestedNodes)
            {
                if (delay(node))
                    continue;

                for (const auto& predecessor : graph.Predecessors(node))
                {
                    if (component.Contains(predecessor))
                        newRoots.erase(predecessor);
                }
            }

            // Perform the topological sort stopping at delay nodes
            // to break the loops.
            std::vector<TNode> reordered;
            reordered.reserve(nestedNodes.size());

            std::set<TNode> visited;
            auto sortFrom = [&](const TNode& start)
            {
                if (visited.find(start) != visited.end())
                    return;

                std::set<TNode> checkInfinity;
                Internal::LoopEvaluationSort(visited, checkInfinity, reordered, start, graph, component, delay);
            };

            for (const auto& root : newRoots)
                sortFrom(root);

            // A node not reached from any root implies a cycle of non-delay nodes. Traversing from
            // such nodes runs into LoopEvaluationSort's infinite-loop check instead of dropping them.
            // For valid components this loop does nothing: every node was already visited.
            for (const auto& node : nestedNodes)
                sortFrom(node);

            assert(reordered.size() == nestedNodes.size());

            // Update the component.
            component.UpdateNodeOrder(std::move(reordered));
        }
    }

    //
    // Sorts all nodes of the graph in the evaluation order given by the root nodes.
    // Strongly connected components should be already sorted using EvaluationSort function.
    //
    // Nodes outside any component keep their post-order position (traversal from graph.Roots()).
    // All nodes of a component are emitted together, in the component's own order, at the
    // position of the component's last node in that post-order.
    //
    template<class TNode>
    inline std::vector<TNode> GlobalEvaluationSort(const DirectedGraph<TNode>& graph, const std::vector<StrongComponent<TNode>>& strongComponents)
    {
        auto nodes = PostOrderTraversal(graph, graph.Roots());
        if (strongComponents.empty())
            return std::vector<TNode>(nodes.begin(), nodes.end());

        // Map every component node to the index of its component, so that each lookup below is
        // O(log n) rather than a linear scan through all components' nodes.
        // emplace keeps the first mapping if a node were listed in several components. That
        // matches a first-match search over the components in order.
        std::map<TNode, size_t> nodeToComponent;

        // Number of nodes of each component not yet encountered in the post-order.
        std::vector<size_t> remaining(strongComponents.size());
        for (size_t i = 0; i < strongComponents.size(); ++i)
        {
            const auto& componentNodes = strongComponents[i].Nodes();
            remaining[i] = componentNodes.size();
            for (const auto& node : componentNodes)
                nodeToComponent.emplace(node, i);
        }

        // Strong components should already be sorted in a proper evaluation order.
        // The whole strong component gets evaluated on its last node position in the global
        // topological order list('nodes').
        std::vector<TNode> result;
        result.reserve(nodes.size());
        for (const auto& node : nodes)
        {
            auto found = nodeToComponent.find(node);
            if (found == nodeToComponent.end())
            {
                result.push_back(node);
                continue;
            }

            // Check whether this is the last node of the component in the global topological
            // sort order. If that is the case, insert all nodes of the component.
            const size_t componentIndex = found->second;
            assert(remaining[componentIndex] > 0);
            if (--remaining[componentIndex] == 0)
            {
                const auto& componentNodes = strongComponents[componentIndex].Nodes();
                result.insert(result.end(), componentNodes.begin(), componentNodes.end());
            }
        }
        return result;
    }
}
// back to top