https://github.com/Microsoft/CNTK
Raw File
Tip revision: 5d14950684bda6f8ea64984b52dd9fbd6387e77f authored by Eldar Akchurin on 01 March 2017, 14:17:46 UTC
Implementation of sparse labels
Tip revision: 5d14950
CompositeFunction.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"
#include "SpecialPurposeNodes.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    // Op name reported for CompositeFunction instances.
    /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName{ L"CompositeFunctionOpName" };
    // Shared counter, presumably used to build unique names for auto-generated dynamic axes (TODO confirm).
    /*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis{ 0 };

    // Type tag written under typeKey by SerializeBlockComposite() and checked by ValidateDictionary on load.
    static const std::wstring s_compositeFunctionTypeValue{ L"CompositeFunction" };

    // Builds the "header" part of a serialized composite: version, type tag, uid of the root
    // function, the optional name and the composite's own uid. Inputs, functions and state are
    // appended separately by Serialize().
    Dictionary CompositeFunction::SerializeBlockComposite() const
    {
        Dictionary serialized;
        serialized[versionKey] = CurrentVersion();
        serialized[typeKey] = s_compositeFunctionTypeValue;
        serialized[rootKey] = RootFunction()->Uid();

        // The name is optional; only emit it when one was given.
        const auto& compositeName = Name();
        if (!compositeName.empty())
            serialized[nameKey] = compositeName;

        serialized[uidKey] = Uid();
        return serialized;
    }

    // Serializes the whole composite into a Dictionary:
    //   - the header produced by SerializeBlockComposite();
    //   - inputsKey:    the unique leaf inputs of the graph, plus placeholders that stand in for
    //                   outputs whose owner had already been visited (see the traversal below);
    //   - functionsKey: every primitive function (including those inside block functions) in the
    //                   reverse of the order the traversal visited them;
    //   - stateKey:     RNG seed/offset for each node in m_variableToNodeMap that derives from RngUser,
    //                   keyed by the uid of the (non-block) function owning that node's variable.
    // Throws LogicError if input uids or primitive-function output uids are not unique.
    /*virtual*/ Dictionary CompositeFunction::Serialize() const
    {
        Dictionary dict = SerializeBlockComposite();

        // Find cycles in the graph and "break" them by inserting placeholders.
        // This needs to be done on Save, since here we have easy access to the shape and
        // dynamic axis info.
        std::unordered_set<FunctionPtr> visitedFunctions;
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> uniqueInputs;
        std::unordered_set<std::wstring> inputUids;
        std::function<void(const FunctionPtr& function)> SerializationTraversalFunc;
        SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
            std::vector<Variable> functionInputs = function->Inputs();
            for (const auto& input : functionInputs)
            {
                const auto& uid = input.Uid();
                if (inputUids.find(uid) != inputUids.end())
                    continue;

                // check if this input corresponds to a cyclic edge in the graph.
                // BUG: A function being visited twice does not indicate it being a cyclic edge in the graph.
                // It just means there are at least 2 successors in the graph that have the function as input
                bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());

                if (mustBeReplaced)
                {
                    // Record a placeholder carrying the same uid, shape, type and axes; Deserialize
                    // later replaces it with the real output that has this uid.
                    auto varKind = VariableKind::Placeholder;
                    Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
                    uniqueInputs.push_back(var);
                    inputUids.insert(uid);
                }
                else if (!input.IsOutput())
                {
                    // leave the input as is.
                    uniqueInputs.push_back(input);
                    inputUids.insert(uid);
                }
            }
            visitedFunctions.insert(function);
            topoSortedPrimitiveFunctions.push_back(function);

            // For block functions we need to recursively traverse the underlying composite
            if (function->IsBlock())
                PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
        };

        PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);

        std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));

        assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());

        std::vector<DictionaryValue> inputDictionaries;
        inputDictionaries.reserve(uniqueInputs.size());
        inputUids.clear();
        for (const auto& input : uniqueInputs)
        {
            if (inputUids.find(input.Uid()) != inputUids.end())
                LogicError("Input uids must be unique");

            inputUids.insert(input.Uid());
            inputDictionaries.push_back(input.Serialize());
        }

        dict[inputsKey] = std::move(inputDictionaries);

        std::vector<DictionaryValue> functionDictionaries;
        functionDictionaries.reserve(topoSortedPrimitiveFunctions.size());
        std::unordered_set<std::wstring> outputUids;
        for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
        {
            for (const auto& output : primitiveFunction->RawOutputs())
            {
                if (outputUids.find(output.Uid()) != outputUids.end())
                    LogicError("Output uids of all primitive functions in a function graph must be unique");

                // Record the output's own uid. Inserting the owning function's uid here (as before)
                // meant the duplicate-output check above could never trigger.
                outputUids.insert(output.Uid());
            }

            functionDictionaries.push_back(primitiveFunction->Serialize());
        }

        dict[functionsKey] = std::move(functionDictionaries);

        // Now, collect and store the internal state for all non-pure (stateful) functions in the graph
        // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
        Dictionary stateDictionary;
        for (const auto& kv : m_variableToNodeMap)
        {
            if (kv.second->Is<RngUser>() && kv.first.IsOutput())
            {
                // The RNG state should be associated with the actual function that the computation node
                // corresponds to, and not the block primitives that wrap the actual function
                auto ownerFunction = kv.first.Owner().get();
                if (!ownerFunction->IsBlock())
                {
                    auto rng = kv.second->As<RngUser>();
                    Dictionary state;
                    state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
                    state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
                    stateDictionary[ownerFunction->Uid()] = state;
                }
            }
        }

        dict[stateKey] = std::move(stateDictionary);

        return dict;
    }

    // Rebuilds a composite function from its serialized header (typeKey, rootKey, uidKey and optional
    // nameKey). The root is looked up by uid among the already-deserialized primitive functions.
    // Any placeholder of the resulting composite that has an entry in allPlaceholderReplacements is
    // replaced; this repeats until no replaceable placeholders remain, because a replacement may
    // itself bring new placeholders into the composite.
    // Throws LogicError if the root uid is not among allPrimitiveFunctions.
    // 'device' is currently unused here.
    /*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
                                                                        const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
                                                                        const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
                                                                        const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
        ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& rootUid = dict[rootKey].Value<std::wstring>();
        std::wstring name = L"";
        if (dict.Contains(nameKey))
            name = dict[nameKey].Value<std::wstring>();
        const auto& uid = dict[uidKey].Value<std::wstring>();

        auto rootIter = std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(), [&rootUid](const FunctionPtr& func) {
            return func->Uid() == rootUid;
        });

        // Dereferencing end() would be undefined behavior; report corrupt/inconsistent input instead.
        if (rootIter == allPrimitiveFunctions.end())
            LogicError("Root function (uid = '%ls') of the composite function (uid = '%ls') was not found among the deserialized primitive functions.",
                       rootUid.c_str(), uid.c_str());

        FunctionPtr root = *rootIter;

        // Find the subset of placeholder replacements that apply for this composite
        FunctionPtr composite = CompositeFunction::Create(root, name, uid);
        std::unordered_map<Variable, Variable> placeholderReplacements;
        do
        {
            placeholderReplacements.clear();
            auto compositePlaceholders = composite->Placeholders();
            for (const auto& placeholder : compositePlaceholders)
            {
                auto replacementIter = allPlaceholderReplacements.find(placeholder);
                if (replacementIter != allPlaceholderReplacements.end())
                    placeholderReplacements.insert({ placeholder, replacementIter->second });
            }

            if (!placeholderReplacements.empty())
                composite = composite->ReplacePlaceholders(placeholderReplacements);

        } while (!placeholderReplacements.empty());

        return composite;
    }

    // Reconstructs a composite function from a dictionary produced by Serialize():
    //   1. deserialize the leaf inputs (inputsKey) into a uid -> Variable map;
    //   2. deserialize the primitive functions (functionsKey) in stored order. RNG seed/offset are
    //      restored from stateKey for stateful functions. Each function's outputs are registered in
    //      the uid map so later functions can consume them. An output whose uid matches an
    //      already-deserialized Placeholder input is recorded as that placeholder's replacement;
    //   3. build the composite and apply those replacements (DeserializeBlockComposite).
    // Throws LogicError on duplicate input uids, when a function is not a PrimitiveFunction, or when
    // an output uid collides with a non-placeholder input.
    /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };

        size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();
        std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
        for (const auto& dictionaryValue : inputs)
        {
            const auto& dictionary = dictionaryValue.Value<Dictionary>();
            const auto& inputVar = Variable::Deserialize(dictionary, device);

            if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
            {
                LogicError("Input uids are not unique (several inputs share '%ls' uid) "
                           "(%s).", inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
            }
            uidToInputMap[inputVar.Uid()] = inputVar;
        }

        // Older dictionaries may have no state section at all.
        Dictionary stateDictionary;
        if (dict.Contains(stateKey))
            stateDictionary = dict[stateKey].Value<Dictionary>();

        const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

        std::unordered_map<Variable, Variable> allPlaceholderReplacements;
        std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
        for (const auto& dictionaryValue : functions)
        {
            FunctionPtr root = PrimitiveFunction::Deserialize(dictionaryValue.Value<Dictionary>(), uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
            allPrimitiveFunctions.insert(root);

            // Everything below dereferences primitiveFunction; fail loudly rather than invoke
            // undefined behavior if the deserializer ever yields something else.
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
            if (primitiveFunction == nullptr)
            {
                LogicError("Deserialized function (uid = %ls) is not a primitive function "
                           "(%s).", root ? root->Uid().c_str() : L"<null>",
                           GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
            }

            // Since Combine simply forwards other functions' outputs, all of its outputs
            // should already be in the uidToInputMap.
            auto opType = primitiveFunction->OpType();
            if (opType == PrimitiveOpType::Combine)
                continue;

            if (primitiveFunction->IsStateful())
            {
                if (stateDictionary.Contains(primitiveFunction->Uid()))
                {
                    const auto& state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
                    auto seed = state[rngSeedKey].Value<size_t>();
                    auto offset = state[rngOffsetKey].Value<size_t>();
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
                }
                else if (Internal::GetComputationNetworkTraceLevel() > 0)
                {
                    // TODO: all logging functionality should be refactored to live in a logging utility class.
                    // The trailing newline keeps subsequent stderr output off this line.
                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
                            "when deserializing from a dictionary (version=%zu). "
                            "Reproducibility not guaranteed.\n", primitiveFunction->OpName().c_str(), version);
                }
            }

            for (const auto& output : root->RawOutputs())
            {
                auto it = uidToInputMap.find(output.Uid());
                if (it != uidToInputMap.end())
                {
                    // A previously seen input with this uid must be a placeholder that stood in for
                    // this output; queue it for replacement.
                    if (!it->second.IsPlaceholder())
                    {
                        LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls) "
                        "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(),
                        GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
                    }
                    allPlaceholderReplacements[it->second] = output;
                }
                else
                {
                    uidToInputMap[output.Uid()] = output;
                }
            }
        }

        return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
    }

    void CompositeFunction::CopyState(const CompositeFunction& source)
    {
        // Create a map with all non-pure (stateful) functions in the function graph.
        auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
            std::map<std::wstring, FunctionPtr> functionMap;
            for (auto funcPtr : allPrimitiveFunctions)
            {
                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
                if (primitiveFunction->IsStateful())
                {
                    functionMap[primitiveFunction->Uid()] = funcPtr;
                }
            }
            return functionMap;
        };

        std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
        std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);

        assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
        if (statefulFunctionsFrom.size() == 0)
        {
            return;
        }

        // Copy state captured in the attributes dictionaries.
        for (const auto& kv : statefulFunctionsFrom)
        {
            statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
        }

        UpdateInternalNetworkState();
    }

    // Pushes the RNG seed/offset stored in each stateful primitive function's attributes into the
    // RngUser computation node(s) mapped from that function's outputs. Does nothing until the
    // internal computation network exists.
    // NOTE(review): m_variableToNodeMap.at() throws if an output has no node; assumes every output
    // of a stateful function was compiled into the network.
    void CompositeFunction::UpdateInternalNetworkState()
    {
        if (!m_computationNetwork)
            return;

        for (const auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
            if (!primitiveFunction->IsStateful())
                continue;

            const auto& outputs = function->RawOutputs();
            if (outputs.empty())
                continue;

            // The state is per function, not per output: read it once instead of copying the whole
            // attribute dictionary for every output.
            const auto& attributes = function->Attributes();
            auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
            auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();

            for (const auto& output : outputs)
            {
                const auto& node = m_variableToNodeMap.at(output);
                node->As<RngUser>()->SetRngState(seed, offset);
            }
        }
    }

    // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions
    // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the
    // top level 'variable'.
    //
    // variableToNodeMap           - memo of already-created nodes. A variable found here is returned as-is.
    //                               A nullptr entry is stored while a variable is being built, to stop
    //                               infinite recursion on recurrent graphs.
    // isVariableRootMap           - set to false when a variable is reached again through the memo. On
    //                               first creation it is set to variable.IsOutput(), unless already recorded.
    // inputsToExcludeGradientsFor - Parameters/Constants in this set get learning-rate multiplier 0; Inputs
    //                               in this set do not get the dummy multiplier that forces gradient computation.
    // Throws InvalidArgument if the data type, shape or dynamic axes are still unknown (or the shape has an
    // inferred dimension), and LogicError for Input variables whose dynamic axes are not supported.
    // May return nullptr for outputs that GetOutputVariableNode could not map (see the comment below).
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
                                                                 Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                 ComputationNetworkBuilder<ElementType>& builder,
                                                                 std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                 std::unordered_map<Variable, bool>& isVariableRootMap,
                                                                 const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
    {
        // Already visited: reuse the memoized node (possibly the nullptr sentinel if we are inside a cycle)
        // and mark the variable as not being a root, since it is reached again from elsewhere in the graph.
        auto iter = variableToNodeMap.find(variable);
        if (iter != variableToNodeMap.end())
        {
            isVariableRootMap[variable] = false;
            return iter->second;
        }

        // The DataType, Shape and DynamicAxes of the variable must be known by now
        if (variable.GetDataType() == DataType::Unknown)
            InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().IsUnknown())
            InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().HasInferredDimension())
            InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
            InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
        variableToNodeMap[variable] = nullptr;

        std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
        if (variable.IsParameter() || variable.IsConstant())
        {
            // Parameters and Constants both become LearnableParameter nodes; Constants (and parameters
            // that need no gradient or are explicitly excluded) get a zero learning-rate multiplier.
            auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());
            computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
            network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
            if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
                computationNodePtr->SetLearningRateMultiplier(0.0);

            NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
            std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();

            // Parameters (and Constants already on the network's device) share the value's storage via
            // AsReference, so the node and the V2 variable see the same data.
            if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
                computationNodePtr->Value() = valueMatrix->AsReference();
            else // Constant: if initialized data lives on wrong device, make a copy to the right one (copy is OK since it's constant)
            {
                Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
                clonedMatrix.AssignValuesOf(*valueMatrix);
                computationNodePtr->Value() = std::move(clonedMatrix);
            }
        }
        else if (variable.IsInput())
        {
            auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());

            // TODO: Input variables currently are required to have the default batch axis
            auto dynamicAxes = variable.DynamicAxes();
            auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
            if (foundDefaultBatchAxis == dynamicAxes.end())
                LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes");

            // Safe to call back(): an empty axis list was already rejected by the check above.
            if (dynamicAxes.back() != Axis::DefaultBatchAxis())
                LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes");

            // TODO: Support inputs with > 1 dynamic axes
            // NOTE(review): the size() < 1 branch is unreachable here for the same reason as above.
            if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2))
                LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported");

            // Construct the dynamic axis name to be used internally for the CNTK InputNodes
            std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);

            // Create the DynamicAxisNode only once per distinct internal axis name.
            if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
                network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

            if (IsSparseInput(variable))
                computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);
            else
                computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);

            if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end()))
            {
                // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default
                // gradients are not computed for Input nodes
                computationNodePtr->SetLearningRateMultiplier(0.00001f);
            }
        }
        else
        {
            // Output (the only remaining kind here): build the node of the owning function, which
            // recursively builds its inputs.
            assert(variable.IsOutput());
            auto outputVariableNode = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor);
            // Can be null in case of loops with f.output == f.input.
            // Such loops cannot be handled, so we leave nullptr as computational node.
            if (outputVariableNode)
                computationNodePtr = outputVariableNode->template As<ComputationNode<ElementType>>()->shared_from_this();
            else
                computationNodePtr = nullptr;
        }

        // Replace the nullptr sentinel with the real node; record root status only if no earlier
        // visit (e.g. during the recursion above) has already decided it.
        variableToNodeMap[variable] = computationNodePtr;
        if (isVariableRootMap.find(variable) == isVariableRootMap.end())
            isVariableRootMap[variable] = variable.IsOutput();

        return computationNodePtr;
    }

    // If 'variable' is the output of a NoOp primitive function, returns that function's first input
    // (the variable it forwards); otherwise returns 'variable' itself. When 'recursive' is set and a
    // remapping happened, the lookup is repeated on the forwarded variable.
    // NOTE(review): the recursive call passes no 'recursive' argument and so relies on the declared
    // default for that parameter; confirm whether following the whole NoOp chain was intended.
    /*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive)
    {
        if (!variable.IsOutput())
            return variable;

        FunctionPtr owner = variable.Owner();
        auto noOpOwner = dynamic_cast<PrimitiveFunction*>(owner.get());
        if ((noOpOwner == nullptr) || (noOpOwner->OpType() != PrimitiveOpType::NoOp))
            return variable;

        Variable forwarded = noOpOwner->Inputs()[0];
        if (recursive && (forwarded != variable))
            return GetMappingForNoOpOutput(forwarded);

        return forwarded;
    }

    // Maps an output variable to the variable that actually produces its value:
    //   - output of a NoOp primitive  -> GetMappingForNoOpOutput(variable);
    //   - output of a BlockFunction   -> the corresponding output of the block's inner composite
    //                                    (via CompositeOutputsMap(); throws std::out_of_range if absent);
    //   - anything else               -> the variable itself.
    // When 'recursive' is set and a remapping happened, the lookup is repeated on the result
    // (that call relies on the declared default for 'recursive').
    /*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive)
    {
        Variable resolved = variable;

        if (variable.IsOutput())
        {
            FunctionPtr owner = variable.Owner();
            auto ownerPrimitive = dynamic_cast<PrimitiveFunction*>(owner.get());
            if (ownerPrimitive != nullptr)
            {
                if (ownerPrimitive->OpType() == PrimitiveOpType::NoOp)
                {
                    resolved = GetMappingForNoOpOutput(variable);
                }
                else if (auto ownerBlock = dynamic_cast<BlockFunction*>(owner.get()))
                {
                    resolved = ownerBlock->CompositeOutputsMap().at(variable);
                }
            }
        }

        return (recursive && (resolved != variable)) ? GetMappingVariable(resolved) : resolved;
    }

    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
                                                                               Function* function,
                                                                               const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap)
    {
        PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
        if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
            return variableToNodeMap[GetMappingVariable(variable)];

        ComputationNodeBasePtr computationNodePtr;

        auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name());

        std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
        for (auto inputNode : inputNodes)
            inputNodesBasePtrs.push_back(inputNode);

        if (primitiveFunction)
        {
            auto functionInputs = function->Inputs();
            auto& functionConfig = function->Attributes();
            PrimitiveOpType op = primitiveFunction->OpType();

            switch (op)
            {
            case PrimitiveOpType::Negate:
                computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Sigmoid:
                computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Tanh:
                computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Cos:
                computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Sin:
                computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ReLU:
                computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Exp:
                computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Log:
                computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Sqrt:
                computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ELU:
                computationNodePtr = New<ExponentialLinearUnitNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Floor:
                computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Abs:
                computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Reciprocal:
                computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Softmax:
                computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Hardmax:
                computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::TransposeAxes:
            {
                auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
                auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();

                // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
                computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
                break;
            }
            case PrimitiveOpType::Where:
            {
                auto dynamicAxes = variable.DynamicAxes();
                auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
                computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
                break;
            }
            case PrimitiveOpType::Slice:
            {
                auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
                auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>();

                // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based
                computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis));
                break;
            }
            case PrimitiveOpType::RandomSample:
            {
                auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
                auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
                computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                break;
            }
            case PrimitiveOpType::RandomSampleInclusionFrequency:
            {
                auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
                auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
                computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                break;
            }
            case PrimitiveOpType::Dropout:
            {
                auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
                computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
                break;
            }
            case PrimitiveOpType::Reshape:
            {
                computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(primitiveFunction->RawOutputs()[0].Shape()));
                break;
            }
            case PrimitiveOpType::ROIPooling:
            {
                auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
                computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape));
                break;
            }
            case PrimitiveOpType::Pooling:
            {
                PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
                auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
                auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
                break;
            }
            case PrimitiveOpType::Unpooling:
            {
                auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value<NDShape>();
                auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                //We only get here after validation so it is safe to assume unpooling is max
                computationNodePtr = New<MaxUnpoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
                break;
            }
            case PrimitiveOpType::SumAll:
                computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Plus:
                computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::LogPlus:
                computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Minus:
                computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ElementTimes:
                computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Equal:
                computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::NotEqual:
                computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Less:
                computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::LessEqual:
                computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Greater:
                computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::GreaterEqual:
                computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Times:
            {
                size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
                auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
                computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
                break;
            }
            case PrimitiveOpType::TransposeTimes:
            {
                size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
                computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
                break;
            }
            case PrimitiveOpType::Convolution:
            {
                NDShape outputMapCount, kernelShape;
                std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape());
                auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
                auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
                auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
                computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples);
                break;
            }
            case PrimitiveOpType::CosDistance:
                computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Logistic:
                computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::SquaredError:
                computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::CrossEntropyWithSoftmax:
                computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ClassificationError:
                computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::EditDistanceError:
            {
                auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value<float>();
                auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value<float>();
                auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value<float>();
                auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value<bool>();
                auto tokensToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNameTokensToIgnore].Value<std::vector<DictionaryValue>>());
                computationNodePtr = New<EditDistanceErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore);
                break;
            }
            case PrimitiveOpType::ForwardBackward:
            {
                auto delayContraint = functionConfig[PrimitiveFunction::AttributeNameDelayConstraint].Value<int>();
                auto blankTokenId = functionConfig[PrimitiveFunction::AttributeNameBlankTokenId].Value<size_t>();
                computationNodePtr = New<ForwardBackwardNode<ElementType>>(network->GetDeviceId(), internalNodeName, blankTokenId, delayContraint);
                break;
            }
            case PrimitiveOpType::LambdaRank:
                computationNodePtr = New<LambdaRankNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::NDCG:
                computationNodePtr = New<NDCG1EvalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::PastValue:
            case PrimitiveOpType::FutureValue:
            {
                Variable inputOperandVar = functionInputs[0];
                Variable initialStateVar = functionInputs[1];

                size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
                if (op == PrimitiveOpType::PastValue)
                    computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
                else
                    computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);

                break;
            }
            case PrimitiveOpType::ReduceElements:
            {
                auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
                computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis));
                break;
            }
            case PrimitiveOpType::BatchNormalization:
            {
                auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
                auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
                auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
                auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
                auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();

                computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
                break;
            }
            case PrimitiveOpType::Combine:
                // This operation is just a no-op and is a means to combine multiple functions to create a single Function
                // whose outputs are a union of the outputs of the Functions being combined.
                computationNodePtr = variableToNodeMap[variable];
                break;
            case PrimitiveOpType::PackedIndex:
                computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::GatherPacked:
                computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ScatterPacked:
                computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Clip:
                computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Select:
                computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Splice:
            {
                Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
                break;
            }
            case PrimitiveOpType::OptimizedRNNStack:
            {
                auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
                auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
                auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
                auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();

                computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
                break;
            }
            case PrimitiveOpType::ReconcileDynamicAxis:
            {
                computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            }
            case PrimitiveOpType::LogSoftmax:
            {
                //This can be implemented as x => x - ReduceLogSum(x). How to do this here?
                computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            }
            case PrimitiveOpType::Pass:
                computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::LabelsToGraph:
                computationNodePtr = New<LabelsToGraphNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::StopGradient:
                computationNodePtr = New<StopGradientNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            default:
                LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
                break;
            }

            // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
            ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);

            if (computationNodePtr->Is<INumInputs>())
            {
                auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
                if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
                    LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs",
                    PrimitiveOpTypeName(op).c_str(),
                    (int)inputNodesBasePtrs.size(),
                    (int)computationNodeExpectedInputCount);
            }

            if (computationNodePtr->Is<RngUser>())
            {
                if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
                {
                    auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    uint64_t offset = 0;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
                    {
                        offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                    }
                    computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
                }
            }
        }
        else
        {
            auto outputs = function->RawOutputs();
            if (variable == outputs[0])
            {
                computationNodePtr = New<UserDefinedV2FunctionNode<ElementType>>(network->GetDeviceId(), internalNodeName, function->shared_from_this());

                // For user defined functions, we only attach unique inputs in the internal computation network since, the UDF 
                // backward implementations directly compute aggregate gradient values for unique inputs
                std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
                for (auto inputNodeBasePtr : inputNodesBasePtrs)
                {
                    if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
                        uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
                }

                inputNodesBasePtrs = uniqueInputNodesBasePtrs;
            }
            else
            {
                size_t i = 1;
                while (outputs[i] != variable) i++;
                assert(i < outputs.size());

                computationNodePtr = New<SelectUserDefinedV2FunctionOutputNode<ElementType>>(network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()), i);
                inputNodesBasePtrs = { variableToNodeMap[outputs[0]] };
            }
        }

        network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
        return computationNodePtr;
    }

    // Creates (or resolves) the ComputationNode that computes the output Variable `variable`.
    //
    // The owner Function's inputs are converted to nodes first (via GetNode). Any Constant input whose DataType differs
    // from that of the first non-constant input with a known DataType is replaced, in the owner's own input list, by a
    // Constant holding the value converted to that DataType.
    // Block functions get no ComputationNode of their own: the argument placeholders of the block's composite are mapped
    // to the nodes of the block inputs they stand for, and the node of the output's mapping inside the block is returned.
    // For every other Function the node is created by CreateComputationNode and, unless the op is Combine, each input
    // is recorded in isVariableRootMap as not being a root of the network.
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               ComputationNetworkBuilder<ElementType>& builder,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                               std::unordered_map<Variable, bool>& isVariableRootMap,
                                                                               const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
    {
        assert(variable.IsOutput());

        Function* owner = variable.Owner().get();

        // Held by reference on purpose: mismatched constant inputs are replaced in place below.
        auto& ownerInputs = owner->m_inputs;

        // Constants get coerced to the DataType of the first non-constant input whose type is known.
        DataType targetDataType = DataType::Unknown;
        for (const auto& candidate : ownerInputs)
        {
            if (candidate.IsConstant() || (candidate.GetDataType() == DataType::Unknown))
                continue;

            targetDataType = candidate.GetDataType();
            break;
        }

        std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
        for (auto& input : ownerInputs)
        {
            const bool needsCoercion = input.IsConstant() &&
                                       (targetDataType != DataType::Unknown) &&
                                       (input.GetDataType() != targetDataType);
            if (needsCoercion)
            {
                // Convert a CPU copy, then clone the result back onto the original value's device with its read-only flag.
                auto sourceValue = Constant(input).Value();
                auto sourceValueOnCPU = sourceValue->DeepClone(DeviceDescriptor::CPUDevice(), true);
                NDArrayViewPtr convertedValue = CloneAsDataType(sourceValueOnCPU, targetDataType, true);
                input = Constant(convertedValue->DeepClone(sourceValue->Device(), sourceValue->IsReadOnly()), input.Name());
            }

            auto inputNode = GetNode(input, network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor);
            if (inputNode == nullptr)
                inputNodes.push_back(nullptr);
            else
                inputNodes.push_back(inputNode->template As<ComputationNode<ElementType>>()->shared_from_this());
        }

        if (auto block = dynamic_cast<BlockFunction*>(owner))
        {
            // Route each argument placeholder of the underlying composite to the node of the block input it maps to,
            // then resolve the output through its mapping inside the block.
            auto blockArguments = block->Composite()->Arguments();
            for (auto& blockArgument : blockArguments)
                variableToNodeMap[blockArgument] = variableToNodeMap.at(blockArgument.BlockFunctionVariableMapping());

            return GetNode(variable.BlockFunctionVariableMapping(), network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor);
        }

        ComputationNodeBasePtr outputNode = CreateComputationNode(variable, owner, inputNodes, network, variableToNodeMap);

        // Combine is a no-op grouping of outputs, so it does not mark its inputs as non-roots.
        auto primitive = dynamic_cast<PrimitiveFunction*>(owner);
        const bool isCombine = (primitive != nullptr) && (primitive->OpType() == PrimitiveOpType::Combine);
        if (!isCombine)
        {
            for (const auto& input : ownerInputs)
                isVariableRootMap[input] = false;
        }

        return outputNode;
    }

    std::unordered_set<Variable> CompositeFunction::NonOwnerPreservingCopy(const std::unordered_set<Variable>& outputs)
    {
        std::unordered_set<Variable> result;
        for (auto& o : outputs)
        {
            Variable sanitized = o.NonCompositePreservingCopy();
            result.insert(sanitized);
        }

        return result;
    }

    // Returns the internal ComputationNetwork that evaluates this composite Function, creating and caching it on the
    // first call.
    //
    // First call: validates the arguments, builds the network top-down from the root Function's outputs, attaches the
    // inputs that could not be attached during the recursive build (recurrent loops), compiles the network, verifies
    // that each output Variable's shape matches its node's sample layout, and records the Parameter value timestamps.
    // Later calls: reuse the cached network; a LogicError is raised if the device differs, or (when backpropRoots is
    // non-empty) if the backprop roots or the set of excluded inputs differ from those the network was built with.
    //
    // backpropRoots               - at most one Variable to backpropagate from; empty when only forward is needed.
    // outputs                     - Variables whose values the caller wants; those that are not roots of the network
    //                               are explicitly added to its "output" node group.
    // inputsToExcludeGradientsFor - inputs of this Function for which no gradient is to be computed.
    // allocateNetworkMatrices     - when true and matrices are not yet allocated, allocate them for the forward roots,
    //                               the requested outputs and the backprop root.
    //
    // If building the network throws, all cached construction state is discarded before rethrowing. A later call
    // therefore rebuilds from scratch instead of reusing a partially-built (e.g. uncompiled) network.
    template <typename ElementType>
    ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device,
                                                                   const std::unordered_set<Variable>& backpropRoots,
                                                                   const std::unordered_set<Variable>& outputs,
                                                                   const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
                                                                   bool allocateNetworkMatrices)
    {
        if (m_computationNetwork != nullptr)
        {
            // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network
            // was last constructed, or just recreate a new network.
            // For now just disallow changing the backpropRoots after the network is created
            if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
                LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported");

            // TODO: Support changing the device across different invocations of the forward method on a Function instance
            if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
                LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported");
            
            if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation))
                LogicError("Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported");
        }
        else
        {
            // Validate the arguments before creating any cached state. A non-null m_computationNetwork is what marks the
            // network as built, so it must not be set by a call that is going to be rejected.
            auto networkInputs = this->Inputs();
            for (const auto& inputExcluded : inputsToExcludeGradientsFor)
            {
                // Only inputs of the network can be excluded from gradient computation
                if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end())
                    InvalidArgument("Function::Forward: Only inputs of a Function can be excluded from gradient computation");
            }

            // TODO: We currently only support one backprop root
            if (backpropRoots.size() > 1)
                LogicError("More than one backprop roots is currently unsupported");

            auto placeholders = Placeholders();
            if (!placeholders.empty())
                InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!");

            try
            {
                m_computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
                m_inputsExcludedFromGradientComputation = NonOwnerPreservingCopy(inputsToExcludeGradientsFor);

                ComputationNetworkBuilder<ElementType> builder(*m_computationNetwork);

                // Now recursively create the network in a top-down fashion
                auto rootFunction = RootFunction();
                auto rootFunctionOutputs = rootFunction->RawOutputs();
                for (auto rootOutput : rootFunctionOutputs)
                    GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap, m_inputsExcludedFromGradientComputation);

                // We need to patch the Computation node mappings for the arguments of block functions 
                // since for recurrent inputs, the mappings are not fully established the first time
                std::function<void(const FunctionPtr&)> PatchBlockArgumentsMapping;
                PatchBlockArgumentsMapping = [this, &PatchBlockArgumentsMapping](const FunctionPtr& function) {
                    BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function.get());
                    if (blockFunction)
                    {
                        auto compositeArguments = blockFunction->Composite()->Arguments();
                        for (auto compositeArgument : compositeArguments)
                            m_variableToNodeMap[compositeArgument] = m_variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());

                        PreorderTraverseFunctions(function->BlockRoot(), PatchBlockArgumentsMapping);
                    }
                };
                PreorderTraverseFunctions(rootFunction, PatchBlockArgumentsMapping);

                std::function<bool(const Variable&)> IsVariableRoot = [this, &IsVariableRoot](const Variable& outputVar) {
                    auto mappingVariable = GetMappingVariable(outputVar);
                    return (m_isVariableRootMap[outputVar] && !IsFirstOutputOfMultiOutputUDF(mappingVariable) && ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
                };

                // If any of the function or requested outputs is not a root node, we need to explicitly
                // add it to the 'output' group of the ComputationNetwork
                std::unordered_set<Variable> networkOutputs(outputs);
                networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
                for (auto output : networkOutputs)
                {
                    if (!IsVariableRoot(output))
                    {
                        // find() rather than operator[], so that an unknown output does not leave a null entry in the map
                        auto nodeIter = m_variableToNodeMap.find(output);
                        if ((nodeIter == m_variableToNodeMap.end()) || !nodeIter->second)
                            InvalidArgument("One of the requested outputs for the Function forward computation is not part of the graph underlying the Function");

                        m_computationNetwork->AddToNodeGroup(L"output", nodeIter->second);
                    }
                }

                m_currentBackpropRoots = NonOwnerPreservingCopy(backpropRoots);

                // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
                // Now attach those after we have created all ComputationNodes in the network
                for (const auto& varNodePair : m_variableToNodeMap)
                {
                    const auto& currentVar = varNodePair.first;
                    const auto& currentComputationNode = varNodePair.second;
                    if (!currentComputationNode)
                        LogicError("No computation node mapping exists for Variable %S", currentVar.Name().c_str());

                    if (!currentVar.IsOutput())
                        continue;

                    auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
                    if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
                    {
                        // This ComputationNode has at least one null input which now needs to be properly attached
                        const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());

                        // Input reordering below needs the primitive op type; fail loudly instead of dereferencing null
                        if (!primitiveFunc)
                            LogicError("Cannot attach the inputs of the ComputationNode for Variable %S since its owner is not a primitive Function", currentVar.Name().c_str());

                        // Skip block primitives since they do not directly map to a computation node
                        if (primitiveFunc->OpType() == PrimitiveOpType::Block)
                            continue;

                        // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
                        auto inputVars = primitiveFunc->Inputs();
                        ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
                        inputVars.resize(currentComputationNode->GetNumInputs());

                        std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
                        for (auto inputVar : inputVars)
                            inputNodesBasePtrs.push_back(m_variableToNodeMap.at(inputVar));

                        currentComputationNode->AttachInputs(inputNodesBasePtrs);
                    }
                }

                m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
                m_computationNetwork->CompileNetwork();

                // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
                for (const auto& varNodePair : m_variableToNodeMap)
                {
                    const auto& outputVar = varNodePair.first;
                    if (!outputVar.IsOutput())
                        continue;

                    auto outputShape = outputVar.Shape();
                    auto computationNodeSampleLayout = varNodePair.second->GetSampleLayout();
                    if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) ||
                        ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape))))
                    {
                        LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
                    }
                }

                // Record the timestamps of Parameter values
                assert(m_lastRecordedParameterValueTimeStamps.empty());
                auto functionParameters = Parameters();
                for (auto parameter : functionParameters)
                    m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
            }
            catch (...)
            {
                // Discard everything built so far. Otherwise the non-null m_computationNetwork would make the next call
                // treat a partially-built network as ready to use.
                m_computationNetwork = nullptr;
                m_variableToNodeMap.clear();
                m_isVariableRootMap.clear();
                m_currentBackpropRoots.clear();
                m_inputsExcludedFromGradientComputation.clear();
                m_lastRecordedParameterValueTimeStamps.clear();
                throw;
            }
        }

        if (!m_networkMatricesAllocated && allocateNetworkMatrices)
        {
            ComputationNodeBasePtr backpropRootNode;
            if (!m_currentBackpropRoots.empty())
                backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());

            // Now recursively traverse the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->RawOutputs();
            std::vector<ComputationNodeBasePtr> forwardRootNodes;
            for (auto rootOutput : rootFunctionOutputs)
                forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));

            std::vector<ComputationNodeBasePtr> forwardOutputNodes;
            for (auto output : outputs)
                forwardOutputNodes.push_back(m_variableToNodeMap.at(output));

            m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
            m_networkMatricesAllocated = allocateNetworkMatrices;

            // Only add the backprop root when one exists, so no null node is handed to SortByGlobalEvalOrder
            std::unordered_set<ComputationNodeBasePtr> allNetworkRoots(forwardRootNodes.begin(), forwardRootNodes.end());
            allNetworkRoots.insert(forwardOutputNodes.begin(), forwardOutputNodes.end());
            if (backpropRootNode)
                allNetworkRoots.insert(backpropRootNode);
            m_allNetworkRootsInGlobalEvalOrder = m_computationNetwork->SortByGlobalEvalOrder(allNetworkRoots);
        }
        else if (m_networkMatricesAllocated)
        {
            // Make sure the outputs requested are a subset of the outputs we setup the current matrix allocation structure
            // in the cached computation network. If matrices were never allocated there is no such structure to check against.
            for (auto output : outputs)
            {
                auto computationNode = m_variableToNodeMap.at(output);
                if (std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), computationNode) == m_allNetworkRootsInGlobalEvalOrder.end())
                    LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported");
            }
        }

        return m_computationNetwork;
    }

    // Copies the Value supplied for an argument Variable into the matrix of its InputValueBase
    // ComputationNode and populates the node's MBLayout from the Value's layout.
    //
    // Several input nodes may share one MBLayout object. 'layoutsPopulated' maps every layout already
    // populated during the current call to the argument that populated it. A later argument sharing
    // that layout is therefore checked for consistency instead of overwriting it.
    //
    // Raises LogicError if 'computationNode' is not an InputValueBase node. Raises InvalidArgument if
    // the Value's layout presence does not match the node's, or if it differs from a layout already
    // populated for another argument with the same dynamic axes.
    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map<MBLayoutPtr, Variable>& layoutsPopulated)
    {
        if (!computationNode->Is<InputValueBase<ElementType>>())
            LogicError("CompositeFunction::Forward: Illegal to populate value of computation node type other than InputValueBase!");

        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second);

        auto layout = CNTKMatrixAndMBLayout.second;
        auto& nodeLayout = computationNode->GetMBLayout();

        // Both layouts are dereferenced below. Reject a Value whose layout presence does not match the
        // node's, instead of dereferencing a null MBLayoutPtr.
        if ((layout == nullptr) != (nodeLayout == nullptr))
            InvalidArgument("Function::Forward: The minibatch layout of the Value specified for argument '%S' is incompatible with the dynamic axes of the corresponding Variable", variableValue.first.Name().c_str());

        // Switch the node matrix to the right matrix type
        auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
        nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);

        // Both layouts are null here, so there is no minibatch layout to populate or compare
        if (nodeLayout == nullptr)
            return;

        auto populatedIter = layoutsPopulated.find(nodeLayout);
        if (populatedIter == layoutsPopulated.end())
        {
            nodeLayout->CopyFrom(layout);
            layoutsPopulated.insert({ nodeLayout, variableValue.first });
        }
        else if (*nodeLayout != *layout)
        {
            InvalidArgument("Function::Forward: Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified for 2 of the Function's argument ('%S', '%S') having same dynamic axes", variableValue.first.Name().c_str(), populatedIter->second.Name().c_str());
        }
    }

    void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
    {
        std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
        std::vector<ComputationNodeBasePtr> inputNodes;
        for (auto argumentValuePair : arguments)
        {
            auto argument = argumentValuePair.first;
            auto argumentComputationNode = m_variableToNodeMap.at(argument);
            assert(argumentComputationNode);
            inputNodes.push_back(argumentComputationNode);

            ValuePtr argumentValue = arguments.at(argument);
            switch (argumentValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
                break;
            case DataType::Double:
                PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType()));
                break;
            }
        }

        m_computationNetwork->BumpEvalTimeStamp(inputNodes);
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
    {
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second);

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
        auto nodeLayout = computationNode->GetMBLayout();
        if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
            InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call");
        computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
    }

    // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph.
    // Each entry maps a Variable to the gradient Value used to seed its ComputationNode, dispatching on
    // the Value's DataType. Raises LogicError for a DataType that is neither Float nor Double.
    void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
    {
        for (const auto& gradientVarValuePair : gradients)
        {
            auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first);
            const auto& gradientValue = gradientVarValuePair.second;

            switch (gradientValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType()));
                break;
            }
        }
    }

    // Computes the shape of a Value holding the data of 'computationNodePtr' for variable 'var'.
    // The first var.Shape().Rank() dimensions come from the node's sample layout. When the node has an
    // MBLayout, two more dimensions follow: the number of time steps, then the number of sequences.
    static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr)
    {
        const auto sampleRank = var.Shape().Rank();
        const auto& sampleLayout = computationNodePtr->GetSampleLayout();
        const auto& minibatchLayout = computationNodePtr->GetMBLayout();

        std::vector<size_t> dims;
        dims.reserve(sampleRank + 2);
        for (size_t axis = 0; axis < sampleRank; ++axis)
            dims.push_back(sampleLayout.GetDim(axis));

        // Sequence (time) axis followed by batch axis
        if (minibatchLayout != nullptr)
        {
            dims.push_back(minibatchLayout->GetNumTimeSteps());
            dims.push_back(minibatchLayout->GetNumSequences());
        }

        return NDShape(dims);
    }

    // Fetches the forward value (getGradient == false) or the gradient (getGradient == true) held by
    // 'computationNode' for variable 'var'.
    //  - If 'varValue' is null, it is set to a new PackedValue over Matrix::AsReference() of the node's
    //    matrix, carrying the node's MBLayout.
    //  - Otherwise the node's data is copied into the existing 'varValue'. Its shape must match the
    //    node's actual shape, either directly or after AsTensorShape conversion; if not, InvalidArgument
    //    is raised.
    // Raises LogicError for a DataType that is neither Float nor Double.
    /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
    {
        auto actualShape = GetValueShape(var, computationNode);
        const bool hasPreallocatedValue = (varValue != nullptr);

        if (hasPreallocatedValue)
        {
            const bool shapeMismatch = (varValue->Shape() != actualShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(actualShape));
            if (shapeMismatch)
                InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(actualShape).c_str());
        }

        auto mbLayout = computationNode->GetMBLayout();
        ValuePtr fetchedValue;
        const auto dataType = var.GetDataType();
        if (dataType == DataType::Float)
        {
            auto typedNode = computationNode->As<ComputationNode<float>>();
            auto& nodeMatrix = getGradient ? typedNode->Gradient() : typedNode->Value();
            if (hasPreallocatedValue)
                fetchedValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, nodeMatrix, mbLayout);
            else
                fetchedValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(nodeMatrix.AsReference()), mbLayout, /*readOnly =*/ false);
        }
        else if (dataType == DataType::Double)
        {
            auto typedNode = computationNode->As<ComputationNode<double>>();
            auto& nodeMatrix = getGradient ? typedNode->Gradient() : typedNode->Value();
            if (hasPreallocatedValue)
                fetchedValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, nodeMatrix, mbLayout);
            else
                fetchedValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(nodeMatrix.AsReference()), mbLayout, /*readOnly =*/ false);
        }
        else
        {
            LogicError("Unsupported DataType %s", DataTypeName(dataType));
        }

        if (hasPreallocatedValue)
            varValue->CopyFrom(*fetchedValue);
        else
            varValue = fetchedValue;
    }

    // Copies the Forward values of the network's output nodes into the Value objects of 'outputs'.
    // For an entry whose ValuePtr is null, a new Value over the node's matrix is stored in the entry.
    // Otherwise the data is copied into the existing Value.
    void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
    {
        // Iterate by reference so each entry's ValuePtr is filled in place, without copying the entry
        // and then looking it up again via operator[]
        for (auto& outputVarValuePair : outputs)
            GetNodeOutputOrGradient(outputVarValuePair.first, outputVarValuePair.second, m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/);
    }

    void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto networkInputs = this->Inputs();
        // Now copy the gradient values of input nodes of the network to gradients' Value objects
        for (auto gradientVarValuePair : gradients)
        {
            // Only gradients corresponding to inputs of the network can be obtained
            if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
                InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function");

            // Gradients can only be obtained for parameter variables or input variables that NeedsGradient
            if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end()))
                InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, an Input Variable with NeedsGradient setting of false, or an input for which gradient computation was explicitly excluded");

            auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first);

            if (!computationNodePtr->NeedsGradient())
                LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");

            GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
        }
    }

    const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
    {
        if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end())
        {
            auto sanitizedOutput = output.NonCompositePreservingCopy();

            if (sanitizedOutput.IsOutput())
                m_perOutputVarArgumentDependencies[sanitizedOutput] = AsComposite(sanitizedOutput.Owner())->Arguments();
            else
                m_perOutputVarArgumentDependencies[sanitizedOutput] = { sanitizedOutput };
        }

        return m_perOutputVarArgumentDependencies[output];
    }

    std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
    {
        std::unordered_map<Variable, uint64_t> currentBackpropRootsTimeStamps;
        assert(m_computationNetwork != nullptr);

        for (auto& backpropRoot : m_currentBackpropRoots)
            currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp();

        return currentBackpropRootsTimeStamps;
    }

    // Evaluates the requested 'outputs' of this composite Function for the given 'arguments'.
    //
    // 'arguments'                       : Values for argument Variables. A value must be supplied for every argument
    //                                     the requested outputs depend on (InvalidArgument otherwise). Extra values
    //                                     are ignored, with a warning printed to stderr.
    // 'outputs'                         : The output Variables to compute. Null ValuePtr entries are set to new Values;
    //                                     non-null entries receive a copy of the computed data.
    // 'computeDevice'                   : Passed to GetComputationNetwork and to the returned CNTKBackPropState.
    // 'outputsToRetainBackwardStateFor' : Roots a later Backward call may backpropagate from. They are evaluated
    //                                     as well, and a non-empty set runs the network in training mode.
    // 'inputsToExcludeGradientsFor'     : Passed to GetComputationNetwork.
    //
    // Returns a BackPropState for Backward when 'outputsToRetainBackwardStateFor' is non-empty, otherwise nullptr.
    /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                            std::unordered_map<Variable, ValuePtr>& outputs,
                                                            const DeviceDescriptor& computeDevice,
                                                            const std::unordered_set<Variable>& outputsToRetainBackwardStateFor,
                                                            const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
    {
        // Validate arguments and outputs
        if (outputs.empty())
            InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!");

        // The network's element type is taken from the arguments, which must all agree. With no arguments
        // it comes from the first requested output (in iteration order) whose DataType is known.
        // TODO: We need a better way to determine the ElementType for the network
        auto dataType = DataType::Unknown;
        for (const auto& variableValuePair : arguments)
        {
            const auto argumentDataType = variableValuePair.first.GetDataType();
            if (dataType == DataType::Unknown)
                dataType = argumentDataType;
            else if (dataType != argumentDataType)
                LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be same");
        }

        if (dataType == DataType::Unknown)
        {
            for (const auto& variableValuePair : outputs)
            {
                dataType = variableValuePair.first.GetDataType();
                if (dataType != DataType::Unknown)
                    break;
            }
        }

        std::unordered_set<Variable> requestedOutputVariables;
        for (const auto& outputVarValuePair : outputs)
            requestedOutputVariables.insert(outputVarValuePair.first);

        if (dataType == DataType::Float)
            GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
        else if (dataType == DataType::Double)
            GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
        else
            InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));

        // Collect the nodes to evaluate and the arguments the requested outputs depend on
        std::vector<ComputationNodeBasePtr> outputsToEvaluate;
        std::unordered_set<Variable> requiredArguments;
        for (const auto& outputVariable : requestedOutputVariables)
        {
            const auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVariable);
            requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());
            outputsToEvaluate.push_back(m_variableToNodeMap.at(outputVariable));
        }

        // We should have argument values supplied for all required argument dependencies for the requested outputs
        std::vector<Variable> missingRequiredArguments;
        std::unordered_map<Variable, ValuePtr> requiredArgumentValues;
        for (const auto& requiredArgument : requiredArguments)
        {
            auto iter = arguments.find(requiredArgument);
            if (iter == arguments.end())
                missingRequiredArguments.push_back(requiredArgument);
            else
                requiredArgumentValues.insert(*iter);
        }

        if (!missingRequiredArguments.empty())
        {
            std::wstring missingRequiredArgumentNames = NamedListString(missingRequiredArguments);
            InvalidArgument("Function::Forward: Values for %d required arguments (%S), that the requested output(s) depend on, have not been provided", (int)missingRequiredArguments.size(), missingRequiredArgumentNames.c_str());
        }

        if (requiredArgumentValues.size() < arguments.size())
            fprintf(stderr, "WARNING: Function::Forward provided values for (%d) extra arguments which are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size()));

        // Feed data into the arguments of the network
        // TODO: Avoid copying the data when possible
        PopulateNetworkInputs(requiredArgumentValues);

        // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input
        // This mask is regenerated every minibatch and hence dropout nodes with a non-zero dropout rate must me marked outdated
        // w.r.t. inputs to force evaluation in each minibatch
        list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
        for (auto& nodeIter : dropoutNodes)
            nodeIter->SetEvalTimeStampOutdatedWrtAll();

        // Bump the timestamp of the parameter nodes whose values have changed
        for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps)
        {
            auto parameter = paramTimeStampRecord.first;
            auto prevTimeStamp = paramTimeStampRecord.second;
            auto newTimeStamp = parameter.CurrentValueTimeStamp();
            if (newTimeStamp > prevTimeStamp)
            {
                paramTimeStampRecord.second = newTimeStamp;
                m_variableToNodeMap.at(parameter)->BumpEvalTimeStamp();
            }
        }

        // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
        for (const auto& rootVarForBackprop : outputsToRetainBackwardStateFor)
        {
            if (outputs.find(rootVarForBackprop) == outputs.end())
                outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop));
        }

        // Reset the timestamps of all backward roots to record an update in one or more inputs
        for (const auto& backpropRoot : m_currentBackpropRoots)
            m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();

        ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

        // We may have to include additional nodes in the ForwardProp to align with how the memory sharing structure is setup
        // We need to include all roots that lie earlier in the global eval order than the actual outputs we are interested
        // in evaluation.
        // TODO: This may incur additonal compute costs in some rare scenarios. We need to come up with a better way to handle this.
        outputsToEvaluate = m_computationNetwork->SortByGlobalEvalOrder(outputsToEvaluate);
        auto lastOutputInEvalOrder = outputsToEvaluate.back();
        auto iterLastRootInEvalOrder = std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), lastOutputInEvalOrder);

        // Advancing past end() would be undefined behavior, so fail explicitly when the node to evaluate
        // is not one of the roots the network's matrices were allocated for
        if (iterLastRootInEvalOrder == m_allNetworkRootsInGlobalEvalOrder.end())
            LogicError("CompositeFunction::Forward: A requested output or backprop root is not among the network roots set up for the cached computation network; changing them across Forward calls is currently unsupported");

        auto augmentedOutputsToEvaluate = std::vector<ComputationNodeBasePtr>(m_allNetworkRootsInGlobalEvalOrder.begin(), iterLastRootInEvalOrder + 1);
        m_computationNetwork->ForwardProp(augmentedOutputsToEvaluate);

        GetNetworkOutputs(outputs);

        // TODO: How to deal with the specified 'computeDevice'
        BackPropStatePtr backpropStatePtr;
        if (!outputsToRetainBackwardStateFor.empty())
            backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());

        return backpropStatePtr;
    }

    /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
                                                 const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                                 std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
    {
        auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
        if (backpropState == nullptr)
            InvalidArgument("Invalid backprop state specified");

        // TODO: Support multiple concurrent backprop states
        std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
        if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
            LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
                       "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");

        if (rootGradientValues.size() > 1)
            LogicError("Currently gradient backprop from only one of the Function Outputs is supported");

        // TODO: Avoid copying the data when possible

        // Zero all gradients of nodes below the root nodes
        for (auto rootGradientVarValuePair : rootGradientValues)
            m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first));

        // Feed data into the arguments of the network
        PopulateNetworkGradients(rootGradientValues);

        // Backpropagate through the network
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);

        auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first);
        m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);

        GetNetworkGradients(backPropagatedGradientValuesForInputs);

        // TODO: How to deal with the specified 'computeDevice'
    }
}
back to top