//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"
#include "SpecialPurposeNodes.h"
#include "SequenceReshapeNodes.h"
#include "UserDefinedFunction.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
    /*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);

    static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";

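    // Serialize the composite's own metadata (version, type, root uid, name and uid) into a dictionary.
    // The graph's inputs and primitive functions are serialized separately by Serialize().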
    Dictionary CompositeFunction::SerializeBlockComposite() const
    {
        Dictionary dict;

        dict[versionKey] = CurrentVersion();
        dict[typeKey] = s_compositeFunctionTypeValue;
        dict[rootKey] = RootFunction()->Uid();
        if (!Name().empty())
            dict[nameKey] = Name();
        dict[uidKey] = Uid();

        return dict;
    }

    // Copy the internal state from the network into the function graph,
    // specifically from RngUser nodes into the attribute dictionaries of
    // the corresponding stateful primitive functions.
    void CompositeFunction::UpdateInternalState() const
    {
        if (!m_computationNetwork)
            return;

        for (auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            // TODO: same for BatchNorm

            auto& outputs = primitiveFunction->RawOutputs();
            if (outputs.size() != 1)
                LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());

            const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();

            Dictionary state;
            state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
            state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
            primitiveFunction->SetState(state);
        }
    }

    // Generate a dictionary representing the internal (local) state of the function graph.
    Dictionary CompositeFunction::GetInternalState() const
    {
        UpdateInternalState();

        Dictionary stateDictionary;
        for (auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            // TODO: same for BatchNorm

            stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
        }
        return stateDictionary;
    }

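    // Serialize the entire function graph into a dictionary: the composite metadata, the list of unique
    // leaf inputs (with cyclic edges replaced by placeholders) and the topologically sorted primitive functions.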
    /*virtual*/ Dictionary CompositeFunction::Serialize() const
    {
        UpdateInternalState();

        Dictionary dict = SerializeBlockComposite();
       
        // Find cycles in the graph and "break" them by inserting placeholders.
        // This needs to be done on Save, since here we have easy access to the shape and 
        // dynamic axis info.
        std::unordered_set<FunctionPtr> visitedFunctions;
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> uniqueInputs;
        std::unordered_set<std::wstring> inputUids;
        std::function<void(const FunctionPtr& function)> SerializationTraversalFunc;
        SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
            std::vector<Variable> functionInputs = function->Inputs();
            for (const auto& input : functionInputs)
            {
                auto& uid = input.Uid();
                if (inputUids.find(uid) != inputUids.end())
                    continue;

                // check if this input corresponds to a cyclic edge in the graph.
                // BUG: A function being visited twice does not indicate a cyclic edge in the graph;
                // it just means that at least two successors in the graph have the function as an input.
                bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());

                if (mustBeReplaced)
                {
                    auto varKind = VariableKind::Placeholder;
                    Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
                    uniqueInputs.push_back(var);
                    inputUids.insert(uid);
                }
                else if (!input.IsOutput())
                {
                    // leave the input as is.
                    uniqueInputs.push_back(input);
                    inputUids.insert(uid);
                }
            }
            visitedFunctions.insert(function);
            topoSortedPrimitiveFunctions.push_back(function);

            // For block functions we need to recursively traverse the underlying composite
            if (function->IsBlock())
                PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
        };

        PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);

        std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));

        assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());
        
        std::vector<DictionaryValue> inputDictionaries;
        inputDictionaries.reserve(uniqueInputs.size());
        inputUids.clear();
        for (const auto& input : uniqueInputs)
        {
            if (inputUids.find(input.Uid()) != inputUids.end())
                LogicError("Function '%S' Serialize: Input uids must be unique.", AsString().c_str());

            inputUids.insert(input.Uid());
            inputDictionaries.push_back(input.Serialize());
        }

        dict[inputsKey] = std::move(inputDictionaries);
       
        std::vector<DictionaryValue>  functionDictionaries;
        std::unordered_set<std::wstring> outputUids;
        for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
        {
            for (const auto& output : primitiveFunction->RawOutputs())
            {
                if (outputUids.find(output.Uid()) != outputUids.end())
                    LogicError("Function '%S' Serialize: Output uids of all primitive functions in a function graph must be unique", AsString().c_str());

                outputUids.insert(primitiveFunction->Uid());
            }

            auto functionDict = UDFUtils::IsUDF(primitiveFunction) ? UDFUtils::Serialize(primitiveFunction) : primitiveFunction->Serialize();
            functionDictionaries.push_back(functionDict);
        }

        dict[functionsKey] = std::move(functionDictionaries);

        return dict;
    }

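    // Reconstruct a composite function from its serialized metadata, using the already deserialized set of
    // primitive functions and the applicable subset of placeholder replacements.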
    /*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
                                                                        const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
                                                                        const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
                                                                        const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
        ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& rootUid = dict[rootKey].Value<std::wstring>();
        std::wstring name = L"";
        if (dict.Contains(nameKey))
            name = dict[nameKey].Value<std::wstring>();
        const auto& uid = dict[uidKey].Value<std::wstring>();

        FunctionPtr root = *std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(), [&rootUid](const FunctionPtr& func) {
            return func->Uid() == rootUid;
        });

        // Find the subset of placeholder replacements that apply for this composite
        FunctionPtr composite = CompositeFunction::Create(root, name, uid);
        std::unordered_map<Variable, Variable> placeholderReplacements;
        do
        {
            placeholderReplacements.clear();
            auto compositePlaceholders = composite->Placeholders();
            for (auto placeholder : compositePlaceholders)
            {
                if (allPlaceholderReplacements.find(placeholder) != allPlaceholderReplacements.end())
                    placeholderReplacements.insert({ placeholder, allPlaceholderReplacements.at(placeholder) });
            }

            if (placeholderReplacements.size() > 0)
                composite = composite->ReplacePlaceholders(placeholderReplacements);

        } while (placeholderReplacements.size() > 0);

        return composite;
    }

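    // Deserialize the full function graph: first the leaf inputs, then the primitive functions (wiring their
    // inputs up by uid), and finally the block composite with all placeholder replacements applied.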
    /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
       
        size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();
        std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
        for (const auto& dictionaryValue : inputs)
        {
            const auto& dictionary = dictionaryValue.Value<Dictionary>();
            const auto& inputVar = Variable::Deserialize(dictionary, device);

            if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
            {
                CNTK::LogicError("CompositeFunction::Deserialize: Input uids are not unique (several inputs share '%ls' uid) (%s).",
                                 inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
            }

            uidToInputMap[inputVar.Uid()] = inputVar;
        }

        const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

        std::unordered_map<Variable, Variable> allPlaceholderReplacements;
        std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
        for (const auto& dictionaryValue : functions)
        {
            auto functionDict = dictionaryValue.Value<Dictionary>();
            FunctionPtr root = UDFUtils::IsUDF(functionDict) ?
                UDFUtils::Deserialize(functionDict, uidToInputMap, device) :
                PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
            allPrimitiveFunctions.insert(root);

            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());

            if (primitiveFunction != nullptr && primitiveFunction->OpType() == PrimitiveOpType::Combine) 
            {
                // Since Combine simply forwards other functions' outputs, all of its outputs
                // should already be in the uidToInputMap.
                continue;
            }

            for (const auto& output : root->RawOutputs())
            {
                const auto& it = uidToInputMap.find(output.Uid());
                if (it != uidToInputMap.end())
                {
                    if (!it->second.IsPlaceholder())
                    {
                        CNTK::LogicError("CompositeFunction::Deserialize: Unexpected variable '%S' instead of a Placeholder (uid = %ls) (%s).",
                                         it->second.AsString().c_str(), it->second.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
                    }
                    allPlaceholderReplacements[it->second] = output;
                }
                else
                {
                    uidToInputMap[output.Uid()] = output;
                }
            }
        }

        // Starting with serialization version 3, the state is preserved inside the attribute dictionaries of the
        // corresponding primitive functions. Earlier versions have a dedicated key-value pair in the composite function dict.
        if (version < 3)
            RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);

        return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
    }

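    // For older serialization versions (< 3), restore the RNG state of stateful primitive functions from the
    // composite's dedicated state dictionary; fall back to a freshly generated seed when no state is found.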
    void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions) 
    {
        Dictionary stateDictionary;
        if (dict.Contains(stateKey))
            stateDictionary = dict[stateKey].Value<Dictionary>();

        for (auto& function : functions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            if (stateDictionary.Contains(primitiveFunction->Uid()))
            {
                auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
                // Add key-value pairs expected by the SetState method to the state dictionary.
                state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
                state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
                primitiveFunction->SetState(state);
            }
            else 
            {
                if (GetTraceLevel() >= TraceLevel::Warning) 
                {
                    // TODO: all logging functionality should be refactored to live in a logging utility class. 
                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
                        "when deserializing from a dictionary (version=%zu). "
                        "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
                }

                // Create state from scratch, so that function attributes contain all the required key-value pairs.
                Dictionary state;
                state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed(true);
                state[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
                primitiveFunction->SetState(state);
            }
        }
    }

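    // Copy the RNG state of all stateful functions from 'source' into this composite, remapping uids by
    // position (pre-order) when the two graphs use different uids.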
    void CompositeFunction::CopyState(const CompositeFunction& source)
    {
        // Collect a vector of stateful function uids using a pre-order traversal of the function graph.
        auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
            vector<wstring> uids;
            PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
                if (primitiveFunction && primitiveFunction->IsStateful())
                {
                    uids.push_back(funcPtr->Uid());
                }
            }, true);

            return uids;
        };

        auto theirUIDs = collectStatefulFunctionUIDs(source);
        auto ourUIDs = collectStatefulFunctionUIDs(*this);

        if (theirUIDs.size() != ourUIDs.size())
            CNTK::LogicError("Cannot copy internal state, the source and the destination contain different number of stateful functions.");

        auto state = source.GetInternalState();

        if (theirUIDs == ourUIDs)
        {
            // uids are identical, no need to remap.
            SetInternalState(state);
            return;
        }
        
        // Build a map from source function UIDs to the destination (this) function UIDs.
        map<wstring, wstring> uidMap;
        for (auto i = 0; i < theirUIDs.size(); i++)
            uidMap[theirUIDs[i]] = ourUIDs[i];
        
        Dictionary remappedState;
        for (auto& kv : state)
            remappedState[uidMap[kv.first]] = kv.second;

        SetInternalState(remappedState);
    }

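    // Apply the given state dictionary to all stateful primitive functions and, if a computation network has
    // already been created, propagate the RNG seed/offset directly into the corresponding network nodes.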
    void CompositeFunction::SetInternalState(const Dictionary& state)
    {
        if (state.Size() == 0)
            return;

        for (const auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
            
            primitiveFunction->SetState(functionState);

            if (!m_computationNetwork)
                continue;

            auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
            auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();

            // copy the state directly into the network
            for (const auto& output : function->RawOutputs())
            {
                auto node = m_variableToNodeMap.at(output);
                node->As<RngUser>()->SetRngState(seed, offset);
            }
        }
    }


    // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions
    // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the
    // top-level 'variable'.
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
                                                                 Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                 ComputationNetworkBuilder<ElementType>& builder,
                                                                 const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
                                                                 std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                 std::unordered_map<Variable, bool>& isVariableRootMap,
                                                                 const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
                                                                 bool useMangledNamesForComputationNodes)
    {
        auto iter = variableToNodeMap.find(variable);
        if (iter != variableToNodeMap.end())
        {
            isVariableRootMap[variable] = false;
            return iter->second;
        }

        // The DataType, Shape and DynamicAxes of the variable must be known by now
        if (variable.GetDataType() == DataType::Unknown)
            InvalidArgument("Variable '%S' with unknown DataType found when compiling the Function graph.", variable.AsString().c_str());

        if (variable.Shape().IsUnknown())
            InvalidArgument("Variable '%S' with unknown shape found when compiling the Function graph.", variable.AsString().c_str());

        if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
            InvalidArgument("Variable '%S' with unknown dynamic axes found when compiling the Function graph.", variable.AsString().c_str());

        // Let's add a null entry to the map for this variable, to break infinite recursion when processing recurrent graphs.
        variableToNodeMap[variable] = nullptr;

        std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
        auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes);
        if (variable.IsParameter() || variable.IsConstant())
        {
            if (variable.Shape().HasInferredDimension())
                InvalidArgument("Parameter or Constant '%S' with unresolved shape %S found when compiling the Function graph.", variable.AsString().c_str(), variable.Shape().AsString().c_str());

            computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
            network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
            if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
                computationNodePtr->SetLearningRateMultiplier(0.0);

            NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
            std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();

            if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
                computationNodePtr->Value() = valueMatrix->AsReference();
            else // Constant: if initialized data lives on wrong device, make a copy to the right one (copy is OK since it's constant)
            {
                // TODO: the following two lines are a workaround for a bug in the Math library
                // (AssignValuesOf throws when source and destination matrices reside on different GPU devices).
                // Once this bug is fixed, change to 
                // Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
                Matrix<ElementType> clonedMatrix(network->GetDeviceId());
                clonedMatrix.SwitchToMatrixType(valueMatrix->GetMatrixType(), valueMatrix->GetFormat(), false);
                clonedMatrix.AssignValuesOf(*valueMatrix);
                computationNodePtr->Value() = std::move(clonedMatrix);
            }
        }
        else if (variable.IsInput())
        {
            auto fullyDefinedArgumentVar = variable;
            if (fullyDefinedArgumentVar.Shape().HasFreeDimension() && (fullyDefinedArgumentsMap.find(fullyDefinedArgumentVar) != fullyDefinedArgumentsMap.end()))
                fullyDefinedArgumentVar = fullyDefinedArgumentsMap.at(fullyDefinedArgumentVar);

            if (fullyDefinedArgumentVar.Shape().HasUnboundDimension())
                InvalidArgument("Input Variable '%S' with unresolved shape %S found when compiling the Function graph.", fullyDefinedArgumentVar.AsString().c_str(), fullyDefinedArgumentVar.Shape().AsString().c_str());

            // TODO: Input variables currently are required to have the default batch axis
            auto dynamicAxes = variable.DynamicAxes();
            auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
            if (IsSparseInput(variable) && (foundDefaultBatchAxis == dynamicAxes.end()))
                CNTK::LogicError("Sparse Input Variable '%S' found without a DefaultBatchAxis dynamic axis; this is currently unsupported.", variable.AsString().c_str());

            if (!dynamicAxes.empty() && (dynamicAxes.back() != Axis::DefaultBatchAxis()))
                CNTK::LogicError("Input Variable '%S' does not have the DefaultBatchAxis as its last dynamic axis.", variable.AsString().c_str());

            // TODO: Support inputs with > 1 dynamic axes
            if (dynamicAxes.size() > 2)
                CNTK::LogicError("Input Variable '%S' has %d dynamic axes; currently only inputs with <= 2 dynamic axes are supported.",
                                 variable.AsString().c_str(), (int)dynamicAxes.size());

            if (!dynamicAxes.empty())
            {
                // Construct the dynamic axis name to be used internally for the CNTK InputNodes
                std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);

                if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
                    network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

                if (IsSparseInput(variable))
                    computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName);
                else
                    computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName);

                if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end()))
                {
                    // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default
                    // gradients are not computed for Input nodes
                    computationNodePtr->SetLearningRateMultiplier(0.00001f);
                }
            }
            else
            {
                computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()));
                network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
                if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
                    computationNodePtr->SetLearningRateMultiplier(0.0);
            }

            if (variable.Shape().HasFreeDimension())
                computationNodePtr->MarkNeedsDynamicValidation();
        }
        else
        {
            assert(variable.IsOutput());
            auto outputVariableNode = GetOutputVariableNode(variable, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
            // Can be null in case of loops with f.output == f.input.
            // Such loops cannot be handled, so we leave nullptr as the computation node.
            if (outputVariableNode)
                computationNodePtr = outputVariableNode->template As<ComputationNode<ElementType>>()->shared_from_this();
            else
                computationNodePtr = nullptr;
        }

        variableToNodeMap[variable] = computationNodePtr;
        if (isVariableRootMap.find(variable) == isVariableRootMap.end())
            isVariableRootMap[variable] = variable.IsOutput();

        return computationNodePtr;
    }

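    // For an output of a NoOp function, return the variable it forwards (optionally following chains of NoOps);
    // for any other variable, return the variable itself.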
    /*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive)
    {
        Variable mappingVariable = variable;

        auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
        auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
        if (ownerPrimitiveFunc && (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp))
            mappingVariable = ownerPrimitiveFunc->Inputs()[0];

        if (recursive && (mappingVariable != variable))
            return GetMappingForNoOpOutput(mappingVariable);
        else
            return mappingVariable;
    }

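    // Resolve a variable to the underlying variable it maps to: NoOp outputs map to their input, and block
    // function outputs map to the corresponding outputs of the block's underlying composite.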
    /*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive)
    {
        Variable mappingVariable = variable;

        auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
        auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
        if (ownerPrimitiveFunc)
        {
            if (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp)
                mappingVariable = GetMappingForNoOpOutput(variable);
            else
            {
                auto ownerBlockFunc = dynamic_cast<BlockFunction*>(ownerFunc);
                if (ownerBlockFunc)
                    mappingVariable = ownerBlockFunc->CompositeOutputsMap().at(variable);
            }
        }

        if (recursive && (mappingVariable != variable))
            return GetMappingVariable(mappingVariable);
        else
            return mappingVariable;
    }

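    // Create the internal ComputationNode corresponding to the primitive function that produces 'variable',
    // given the already created computation nodes for the function's inputs.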
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
                                                                               Function* function,
                                                                               const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                               bool useMangledNamesForComputationNodes)
    {
        PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
        if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
            return variableToNodeMap[GetMappingVariable(variable)];

        ComputationNodeBasePtr computationNodePtr;

        auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name(), useMangledNamesForComputationNodes);

        std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
        for (auto inputNode : inputNodes)
            inputNodesBasePtrs.push_back(inputNode);

        auto outputs = function->RawOutputs();
        if (variable == outputs[0])
        {
            if (primitiveFunction)
            {
                auto functionInputs = function->Inputs();
                auto& functionConfig = function->Attributes();
                PrimitiveOpType op = primitiveFunction->OpType();

                switch (op)
                {
                case PrimitiveOpType::Negate:
                    computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sigmoid:
                    computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Tanh:
                    computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Acos:
                    computationNodePtr = New<AcosNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Cos:
                    computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Asin:
                    computationNodePtr = New<AsinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sin:
                    computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Cosh:
                    computationNodePtr = New<CoshNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sinh:
                    computationNodePtr = New<SinhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ReLU:
                    computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Exp:
                    computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Log:
                    computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Sqrt:
                    computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ELU:
                    computationNodePtr = New<ExponentialLinearUnitNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Floor:
                    computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Abs:
                    computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Reciprocal:
                    computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Softmax:
                    computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Hardmax:
                    computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::StableSigmoid:
                    computationNodePtr = New<StableSigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::TransposeAxes:
                {
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec))
                    {
                        auto perm = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
                        for (auto& p : perm)
                            p = NormalizeStaticAxis(p, perm.size());
                        computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(perm));
                    }
                    else
                    {
                        auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
                        auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();

                        // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1-based instead of 0-based
                        computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
                    }
                    break;
                }
                case PrimitiveOpType::Where:
                {
                    auto dynamicAxes = variable.DynamicAxes();
                    auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
                    computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
                    break;
                }
                case PrimitiveOpType::ToSequence:
                {
                    auto dynamicAxes = variable.DynamicAxes();
                    auto internalCNTKDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
                    computationNodePtr = New<ToSequenceNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKDynamicAxisName);
                    break;
                }
                case PrimitiveOpType::ToSequenceLike:
                    computationNodePtr = New<ToSequenceLikeNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::UnpackSequence:
                {
                    auto paddingValue = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackPaddingValue].Value<double>();
                    auto suppressMaskOutput = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackSuppressMaskOutput].Value<bool>();
                    computationNodePtr = New<UnpackSequenceNode<ElementType>>(network->GetDeviceId(), internalNodeName, (ElementType)paddingValue, suppressMaskOutput);
                    break;
                }
                case PrimitiveOpType::Slice:
                {
                    std::vector<Axis> axis;
                    std::vector<int> beginIndex, endIndex, strides;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec) &&
                        functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndexVec) &&
                        functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndexVec))
                    {
                        axis = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
                        beginIndex = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameBeginIndexVec].Value<std::vector<DictionaryValue>>());
                        endIndex = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameEndIndexVec].Value<std::vector<DictionaryValue>>());
                        if (functionConfig.Contains(PrimitiveFunction::AttributeNameSliceStridesVec))
                            strides = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameSliceStridesVec].Value<std::vector<DictionaryValue>>());
                        else
                            strides.resize(axis.size(), 1);
                    }
                    else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis) &&
                        functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndex) &&
                        functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndex))
                    {
                        axis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>());
                        beginIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>());
                        endIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>());
                        if (functionConfig.Contains(PrimitiveFunction::AttributeNameSliceStrides))
                            strides.push_back(functionConfig[PrimitiveFunction::AttributeNameSliceStrides].Value<int>());
                        else
                            strides.push_back(1);
                    }
                    else
                    {
                        RuntimeError("Failed to create computation node: Slice operation with inconsistent attributes");
                    }
                    // Internal CNTK SliceNode takes 1-based axis indices instead of 0-based
                    computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis), strides);
                    break;
                }
                case PrimitiveOpType::RandomSample:
                {
                    auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
                    auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
                    computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                    break;
                }
                case PrimitiveOpType::RandomSampleInclusionFrequency:
                {
                    auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
                    auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
                    computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                    break;
                }
                case PrimitiveOpType::Dropout:
                {
                    auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
                    computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
                    break;
                }
                case PrimitiveOpType::RandomDistribution:
                {
                    auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                    auto rvtype = functionConfig[PrimitiveFunction::AttributeNameRandomDistributionType].Value<std::wstring>();

                    std::vector<double> randomDistributionArgs;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameRandomDistributionArgs))
                        randomDistributionArgs = AsVector<double>(functionConfig[PrimitiveFunction::AttributeNameRandomDistributionArgs].Value<std::vector<DictionaryValue>>());

                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewShape))
                    {
                        auto shape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
                        computationNodePtr = New<RandomDistributionNode<ElementType>>(network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs, AsTensorShape(shape));
                    }
                    else
                        computationNodePtr = New<RandomDistributionNode<ElementType>>(network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs);
                    computationNodePtr->As<RandomDistributionNode<ElementType>>()->SetRngState(seed, offset);
                    break;
                }
                case PrimitiveOpType::Reshape:
                {
                    auto beginAxis = Axis(0);
                    auto endAxis = Axis((int)functionInputs[0].Shape().Rank());
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameBeginAxis))
                        beginAxis = functionConfig[PrimitiveFunction::AttributeNameBeginAxis].Value<Axis>();

                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameEndAxis))
                        endAxis = functionConfig[PrimitiveFunction::AttributeNameEndAxis].Value<Axis>();

                    auto replacementShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
                    for (size_t i = 0; i < replacementShape.Rank(); ++i)
                    {
                        if (replacementShape[i] == NDShape::InferredDimension)
                            replacementShape[i] = 0;
                    }

                    computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(replacementShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
                    break;
                }
                case PrimitiveOpType::ROIPooling:
                {
                    PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
                    auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
                    auto spatialScale = functionConfig[PrimitiveFunction::AttributeNameSpatialScale].Value<double>();
                    computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(roiOutputShape), spatialScale);
                    break;
                }
                case PrimitiveOpType::Pooling:
                {
                    PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
                    auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
                    auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    auto ceilOutDim = false;
                    auto includePad = false;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameCeilOutDim))
                    {
                        ceilOutDim = functionConfig[PrimitiveFunction::AttributeNameCeilOutDim].Value<bool>();
                    }
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameIncludePad))
                    {
                        includePad = functionConfig[PrimitiveFunction::AttributeNameIncludePad].Value<bool>();
                    }
                    computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutDim, includePad, ImageLayoutKind::CHW);
                    break;
                }
                case PrimitiveOpType::Unpooling:
                {
                    auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value<NDShape>();
                    auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    // We only get here after validation, so it is safe to assume the unpooling kind is max.
                    computationNodePtr = New<MaxUnpoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
                    break;
                }
                case PrimitiveOpType::SumAll:
                    computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::OneHot:
                {
                    auto numClass = functionConfig[PrimitiveFunction::AttributeNameNumClass].Value<size_t>();
                    auto is_sparse = functionConfig[PrimitiveFunction::AttributeNameOneHotOutputSparse].Value<bool>();
                    auto axis = functionConfig[PrimitiveFunction::AttributeNameOneHotAxis].Value<Axis>();
                    computationNodePtr = New<OneHotNode<ElementType>>(network->GetDeviceId(), numClass, is_sparse, axis.StaticAxisIndex(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::Gather:
                    computationNodePtr = New<GatherNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ToBatch:
                {
                    computationNodePtr = New<ToBatchAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::UnpackBatch:
                {
                    computationNodePtr = New<UnpackBatchAixsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::Plus:
                    computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::LogPlus:
                    computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Pow:
                    computationNodePtr = New<PowNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Minus:
                    computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ElementTimes:
                    computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Equal:
                    computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::NotEqual:
                    computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Less:
                    computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::LessEqual:
                    computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Greater:
                    computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::GreaterEqual:
                    computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Times:
                {
                    size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
                    auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
                    computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
                    break;
                }
                case PrimitiveOpType::TransposeTimes:
                {
                    size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
                    computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
                    break;
                }
                case PrimitiveOpType::Convolution:
                {
                    auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                    auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                    auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                    auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
                    auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                    auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
                    NDShape outputMapCount, kernelShape;
                    std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
                    NDShape outputShape = NDShape::Unknown();
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameOutputShape))
                        outputShape = functionConfig[PrimitiveFunction::AttributeNameOutputShape].Value<NDShape>();
                    auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
                    computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName,
                        AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
                        sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
                        outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
                        ImageLayoutKind::CHW, maxTempMemSizeInSamples);
                    break;
                }
                case PrimitiveOpType::CosDistance:
                    computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::CosDistanceWithNegativeSamples:
                    computationNodePtr = New<CosDistanceWithNegativeSamplesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Logistic:
                    computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::SquaredError:
                    computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::CrossEntropyWithSoftmax:
                    computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ClassificationError:
                    computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::EditDistanceError:
                {
                    auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value<float>();
                    auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value<float>();
                    auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value<float>();
                    auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value<bool>();
                    auto tokensToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNameTokensToIgnore].Value<std::vector<DictionaryValue>>());
                    computationNodePtr = New<EditDistanceErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore);
                    break;
                }
                case PrimitiveOpType::ForwardBackward:
                {
                    auto delayConstraint = functionConfig[PrimitiveFunction::AttributeNameDelayConstraint].Value<int>();
                    auto blankTokenId = functionConfig[PrimitiveFunction::AttributeNameBlankTokenId].Value<size_t>();
                    computationNodePtr = New<ForwardBackwardNode<ElementType>>(network->GetDeviceId(), internalNodeName, blankTokenId, delayConstraint);
                    break;
                }
                case PrimitiveOpType::LambdaRank:
                    computationNodePtr = New<LambdaRankNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::NDCG:
                    computationNodePtr = New<NDCG1EvalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::PastValue:
                case PrimitiveOpType::FutureValue:
                {
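                    // The offset attribute specifies the number of time steps by which the input is delayed (PastValue) or advanced (FutureValue)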
                    Variable inputOperandVar = functionInputs[0];
                    Variable initialStateVar = functionInputs[1];

                    size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
                    if (op == PrimitiveOpType::PastValue)
                        computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
                    else
                        computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);

                    break;
                }
                case PrimitiveOpType::ReduceElements:
                {
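                    // Read the reduction op name and the reduction axes (either a vector of axes or a single axis) from the function attributes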
                    bool keepDimensions = true;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameReductionKeepDimensions))
                        keepDimensions = functionConfig[PrimitiveFunction::AttributeNameReductionKeepDimensions].Value<bool>();
                    auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
                    std::vector<Axis> reductionAxis;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec))
                    {
                        reductionAxis = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
                    }
                    else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis))
                    {
                        reductionAxis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>());
                    }
                    else
                    {
                        RuntimeError("Failed to create computation node': Reduce operation %ls with no '%ls' or  '%ls' attributes",
                            PrimitiveOpTypeName(op).c_str(),
                            PrimitiveFunction::AttributeNameAxis.c_str(),
                            PrimitiveFunction::AttributeNameAxisVec.c_str()
                        );

                    } 
                    computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis), keepDimensions);
                    break;
                }
                case PrimitiveOpType::BatchNormalization:
                {
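                    // Read the batch-normalization configuration (spatial mode, time constants, epsilon, engine selection) from the function attributes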
                    auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
                    auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
                    auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
                    auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
                    auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();

                    computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
                    break;
                }
                case PrimitiveOpType::Combine:
                    // This operation is just a no-op and is a means to combine multiple functions to create a single Function
                    // whose outputs are a union of the outputs of the Functions being combined.
                    computationNodePtr = variableToNodeMap[variable];
                    break;
                case PrimitiveOpType::PackedIndex:
                    computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::GatherPacked:
                    computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::ScatterPacked:
                    computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Clip:
                    computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Select:
                    computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Splice:
                {
                    Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                    computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
                    break;
                }
                case PrimitiveOpType::Pad:
                {
                    auto head = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNamePaddingHead].Value<std::vector<DictionaryValue>>());
                    auto foot = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNamePaddingFoot].Value<std::vector<DictionaryValue>>());
                    auto mode = functionConfig[PrimitiveFunction::AttributeNamePaddingMode].Value<size_t>();
                    auto constantValue = functionConfig[PrimitiveFunction::AttributeNamePaddingConstantValue].Value<double>();
                    computationNodePtr = New<PaddingNode<ElementType>>(network->GetDeviceId(), internalNodeName, head, foot, (PaddingType)mode, (ElementType)constantValue);
                    break;
                }
                case PrimitiveOpType::OptimizedRNNStack:
                {
                    auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
                    auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
                    auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
                    auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();

                    computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
                    break;
                }
                case PrimitiveOpType::ReconcileDynamicAxis:
                {
                    computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::LogSoftmax:
                {
                    // TODO: This could also be implemented as x => x - ReduceLogSum(x); for now we use the dedicated LogSoftmaxNode.
                    computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                }
                case PrimitiveOpType::Pass:
                    computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::LabelsToGraph:
                    computationNodePtr = New<LabelsToGraphNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::StopGradient:
                    computationNodePtr = New<StopGradientNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                case PrimitiveOpType::Assign:
                    computationNodePtr = New<AssignNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                    break;
                default:
                    CNTK::LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
                    break;
                }

                // Reorder inputNodesBasePtrs since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
                ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);

                if (computationNodePtr->Is<INumInputs>())
                {
                    auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
                    if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
                        CNTK::LogicError("The Primitive Function '%S' has %d inputs while the corresponding ComputationNode expects %d inputs.",
                            function->AsString().c_str(),
                            (int)inputNodesBasePtrs.size(),
                            (int)computationNodeExpectedInputCount);
                }

                if (computationNodePtr->Is<RngUser>())
                {
                    auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                    computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
                }
            }
            else
            {
                computationNodePtr = New<UserDefinedV2FunctionNode<ElementType>>(network->GetDeviceId(), internalNodeName, function->shared_from_this());

                // For user-defined functions, we only attach unique inputs in the internal computation network since the UDF
                // backward implementations directly compute aggregate gradient values for unique inputs
                std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
                for (auto inputNodeBasePtr : inputNodesBasePtrs)
                {
                    if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
                        uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
                }

                inputNodesBasePtrs = uniqueInputNodesBasePtrs;
            }
        }
        else
        {
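            // 'variable' is a secondary output of a multi-output function: find its index among the owner's
            // outputs (index 0 is represented by the primary ComputationNode itself) and create an
            // OutputMultiplexerNode that selects that output from the primary node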
            size_t i = 1;
            while ((i < outputs.size()) && (outputs[i] != variable)) i++;
            assert(i < outputs.size());

            computationNodePtr = New<OutputMultiplexerNode<ElementType>>(network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes), i);
            inputNodesBasePtrs = { variableToNodeMap[outputs[0]] };
        }

        network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
        return computationNodePtr;
    }

    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               ComputationNetworkBuilder<ElementType>& builder,
                                                                               const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                               std::unordered_map<Variable, bool>& isVariableRootMap,
                                                                               const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
                                                                               bool useMangledNamesForComputationNodes)
    {
        assert(variable.IsOutput());

        Function* function = variable.Owner().get();
        ComputationNodeBasePtr computationNodePtr;
        auto& functionInputs = function->m_inputs;

        DataType nonConstInputDataType = DataType::Unknown;
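        // Determine the DataType of the first non-Constant input with a known DataType; Constant inputs
        // of a different DataType are coerced to this type below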
        for (auto& inputVar : functionInputs)
        {
            if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown))
            {
                nonConstInputDataType = inputVar.GetDataType();
                break;
            }
        }
            
        // Create the nodes corresponding to the inputs
        std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
        for (auto& inputVar : functionInputs)
        {
            // If the inputVar is a Constant with a DataType different from the non-Constant inputs, coerce it to the right type
            if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
                inputVar = Constant(inputVar).CloneAs(nonConstInputDataType);

            auto baseNodePtr = GetNode(inputVar, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
            inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
        }

        BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function);
        if (blockFunction)
        {
            // For block function, map each argument placeholder of the underlying composite to
            // the computation node corresponding to the block input that the argument placeholder
            // of the composite is mapped to.
            auto compositeArguments = blockFunction->Composite()->Arguments();
            for (auto compositeArgument : compositeArguments)
                variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());

            return GetNode(variable.BlockFunctionVariableMapping(), network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
        }
        else
            computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap, useMangledNamesForComputationNodes);

        PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
        if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine))
        {
            for (auto inputVar : functionInputs)
                isVariableRootMap[inputVar] = false;
        }

        return computationNodePtr;
    }

    std::unordered_set<Variable> CompositeFunction::NonOwnerPreservingCopy(const std::unordered_set<Variable>& outputs)
    {
        std::unordered_set<Variable> result;
        for (auto& o : outputs)
        {
            Variable sanitized = o.NonCompositePreservingCopy();
            result.insert(sanitized);
        }

        return result;
    }

    static bool VariableShapeMatchesNodeShape(const NDShape& varShape, const TensorShape& nodeShape) 
    {
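        // A rank-0 (scalar) Variable matches any node shape with exactly one element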
        if (varShape.Rank() == 0)
            return (nodeShape.GetNumElements() == 1);

        // Sometimes the nodeShape may have an additional trailing axis with dim==1 due to the lack of support for 0-d tensors in the V1 engine.
        auto adjustedNodeShape = nodeShape;
        while ((adjustedNodeShape.GetRank() > varShape.Rank()) && (adjustedNodeShape.GetDim(adjustedNodeShape.GetRank() - 1) == 1))
            adjustedNodeShape.TrimRankInPlace(adjustedNodeShape.GetRank() - 1);

        if (!varShape.HasUnboundDimension())
            return (AsNDShape(adjustedNodeShape) == varShape);

        if (varShape.Rank() != adjustedNodeShape.GetRank())
            return false;

        for (size_t i = 0; i < varShape.Rank(); ++i)
        {
            if ((varShape[i] != NDShape::FreeDimension) && (varShape[i] != NDShape::InferredDimension) && (varShape[i] != adjustedNodeShape.GetDim(i)))
                return false;
        }

        return true;
    }

    template <typename ElementType>
    std::pair<ComputationNetworkPtr, std::unordered_map<Variable, ComputationNodeBasePtr>>
        CompositeFunction::CreateComputationNetwork(const FunctionPtr& compositeFunction,
                                                    const DeviceDescriptor& device,
                                                    const std::unordered_set<Variable>& outputs,
                                                    const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
                                                    const std::unordered_set<Variable>& inputsExcludedFromGradientComputation,
                                                    bool useMangledNamesForComputationNodes)
    {
        auto computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
        ComputationNetworkBuilder<ElementType> builder(*computationNetwork);

        std::unordered_map<Variable, bool> isVariableRootMap;
        std::unordered_map<Variable, ComputationNodeBasePtr> variableToNodeMap;

        // Now recursively create the network in a top-down fashion
        auto rootFunction = compositeFunction->RootFunction();
        auto rootFunctionOutputs = rootFunction->RawOutputs();
        for (auto rootOutput : rootFunctionOutputs)
            GetNode(rootOutput, computationNetwork, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsExcludedFromGradientComputation, useMangledNamesForComputationNodes);

        // We need to patch the ComputationNode mappings for the arguments of block functions
        // since, for recurrent inputs, the mappings are not fully established on the first pass
        std::function<void(const Variable&)> PatchBlockArgumentsAndOutputsMapping;
        PatchBlockArgumentsAndOutputsMapping = [&variableToNodeMap, &PatchBlockArgumentsAndOutputsMapping](const Variable& var) {
            if (var.IsOutput())
            {
                BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(var.Owner().get());
                if (blockFunction)
                {
                    PostorderTraverseVariables(blockFunction->BlockRoot(), PatchBlockArgumentsAndOutputsMapping);

                    auto compositeArguments = blockFunction->Composite()->Arguments();
                    for (auto compositeArgument : compositeArguments)
                    {
                        auto mappingVarNodeIter = variableToNodeMap.find(compositeArgument.BlockFunctionVariableMapping());
                        if (mappingVarNodeIter != variableToNodeMap.end())
                            variableToNodeMap[compositeArgument] = mappingVarNodeIter->second;
                    }

                    auto mappingVarNodeIter = variableToNodeMap.find(var.BlockFunctionVariableMapping());
                    if (mappingVarNodeIter != variableToNodeMap.end())
                        variableToNodeMap[var] = mappingVarNodeIter->second;
                }
            }
        };
        PostorderTraverseVariables(rootFunction, PatchBlockArgumentsAndOutputsMapping);

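        // An output Variable is a root of the network only if it was marked as a root during node creation,
        // its mapping variable is not the first output of a multi-output function, and (when the mapping
        // differs from the variable itself) the mapping variable is also a root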
        std::function<bool(const Variable&)> IsVariableRoot = [&isVariableRootMap, &IsVariableRoot](const Variable& outputVar) {
            auto mappingVariable = GetMappingVariable(outputVar);
            return (isVariableRootMap[outputVar] && !IsFirstOutputOfMultiOutputFunction(mappingVariable) && ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
        };

        // If any of the Function's outputs or the requested outputs is not a root node, we need to explicitly
        // add it to the 'output' node group of the ComputationNetwork
        std::unordered_set<Variable> networkOutputs(outputs);
        networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
        for (auto output : networkOutputs)
        {
            if (!IsVariableRoot(output))
            {
                auto computationNode = variableToNodeMap[output];

                if (!computationNode)
                    InvalidArgument("One of the requested outputs '%S' for the Function '%S' forward computation is not part of the graph underlying the Function.",
                                    output.AsString().c_str(), compositeFunction->AsString().c_str());

                computationNetwork->AddToNodeGroup(L"output", computationNode);
            }
        }

        // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
        // Now attach those after we have created all ComputationNodes in the network
        for (auto varNodePair : variableToNodeMap)
        {
            auto& currentComputationNode = varNodePair.second;
            if (!currentComputationNode)
                LogicError("Function '%S': No computation node mapping exists for Variable %S.", compositeFunction->AsString().c_str(), varNodePair.first.AsString().c_str());

            auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
            auto& currentVar = varNodePair.first;
            if (!currentVar.IsOutput())
                continue;

            if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
            {
                // This ComputationNode has at least one null input which now needs to be properly attached
                const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());

                // Skip block primitives since they do not directly map to a computation node
                if (primitiveFunc->OpType() == PrimitiveOpType::Block)
                    continue;

                // Reorder the inputs since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
                auto inputVars = primitiveFunc->Inputs();
                ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
                inputVars.resize(currentComputationNode->GetNumInputs());

                std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
                for (auto inputVar : inputVars)
                    inputNodesBasePtrs.push_back(variableToNodeMap.at(inputVar));

                currentComputationNode->AttachInputs(inputNodesBasePtrs);
            }
        }

        computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
        computationNetwork->SetTrackGapNans(GetCheckedMode());
        computationNetwork->SetIsV2Library(true);
        computationNetwork->CompileNetwork();
        // Set EvalTimeStamp of all nodes in the network as "outdated" to make sure that all nodes will be evaluated at least once.
        // During CompileNetwork(), nodes in the network might get different timestamp values because other threads could update the global timestamp value.
        // (The global timestamp value is currently shared process-wide, i.e. among all nodes of all networks.) The nodes with a higher timestamp value are
        // thus incorrectly treated as "updated", and their inputs are not further evaluated by ComputationNetwork::PARTraversalFlowControlNode::ForwardProp().
        // This could lead to incorrect results or crash, because the matrix of the input nodes might never be initialized for ForwardProp().
        computationNetwork->SetEvalTimeStampsOutdatedWithRegardToAll();

        // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
        for (auto varNodePair : variableToNodeMap)
        {
            if (varNodePair.first.IsOutput())
            {
                auto outputVar = varNodePair.first;
                auto computationNodePtr = variableToNodeMap.at(outputVar);
                auto outputShape = outputVar.Shape();
                auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
                if (!VariableShapeMatchesNodeShape(outputShape, computationNodeSampleLayout))
                {
                    LogicError("Function '%S': The output Variable '%S' shape '%S' does not match the SampleLayout shape '[%s]' of the corresponding ComputationNode in the network.",
                               compositeFunction->AsString().c_str(), outputVar.AsString().c_str(), outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
                }
            }
        }

        return { computationNetwork, variableToNodeMap };
    }

    template <typename ElementType>
    ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device,
                                                                   const std::unordered_set<Variable>& backpropRoots,
                                                                   const std::unordered_set<Variable>& outputs,
                                                                   const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
                                                                   bool allocateNetworkMatrices)
    {
        // Let's purge the current computation network and regenerate it if the CompositeFunction
        // was previously compiled only for evaluation and not for gradient backpropagation.
        if ((m_computationNetwork != nullptr) && (m_currentBackpropRoots.empty() && !backpropRoots.empty()))
            PurgeComputationNetwork();

        if (m_computationNetwork != nullptr)
        {
            // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network
            // was last constructed, or just recreate a new network.
            // For now, just disallow changing the backpropRoots after the network is created.
            if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
                LogicError("Function '%S': Changing backprop roots (Current = '%S', New = '%S') across different Forward calls on a CNTK composite Function is currently unsupported.",
                            AsString().c_str(), NamedListString(m_currentBackpropRoots).c_str(), NamedListString(backpropRoots).c_str());

            // TODO: Support changing the device across different invocations of the forward method on a Function instance
            if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
                LogicError("Function '%S': Changing device (Current = '%S', New = %S') across different Forward calls on a CNTK composite Function is currently unsupported.",
                            AsString().c_str(), AsDeviceDescriptor(m_computationNetwork->GetDeviceId()).AsString().c_str(), device.AsString().c_str());
            
            if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation))
                LogicError("Function '%S': Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported.", AsString().c_str());

            // Check whether the free dimensions of any of the arguments have changed, and if so, update the corresponding
            // input ComputationNodes and rerun validation on the computation network
            for (auto freeDimensionArgumentMapping : m_fullyDefinedArgumentsMap)
            {
                auto newShape = freeDimensionArgumentMapping.second.Shape();
                auto argumentComputationNode = m_variableToNodeMap[freeDimensionArgumentMapping.first];
                if (AsTensorShape(newShape) != argumentComputationNode->GetSampleLayout())
                    argumentComputationNode->SetDims(AsTensorShape(newShape), argumentComputationNode->HasMBLayout());
            }
        }
        else
        {
            auto networkInputs = this->Inputs();
            for (auto inputExcluded : inputsToExcludeGradientsFor)
            {
                // Only inputs of the network can be excluded from gradient computation
                if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end())
                    InvalidArgument("Variable '%S' specified for exclusion from gradient computation is not an input of the Function '%S'. "
                                    "Only an input of the Function can be explicitly excluded from gradient computation.",
                                    inputExcluded.AsString().c_str(), this->AsString().c_str());
            }

            m_inputsExcludedFromGradientComputation = NonOwnerPreservingCopy(inputsToExcludeGradientsFor);
            m_currentBackpropRoots = NonOwnerPreservingCopy(backpropRoots);

            // TODO: We currently only support one backprop root
            if (backpropRoots.size() > 1)
                LogicError("Function '%S': %d backprop roots specified; currently at most one backprop root is supported.", AsString().c_str(), (int)backpropRoots.size());

            auto placeholders = Placeholders();
            if (!placeholders.empty())
                InvalidArgument("%d unbound Placeholder(s) '%S' found in the Function. "
                    "All Placeholders of a Function must be bound (to a variable) before performing a Forward computation.",
                    (int)placeholders.size(), NamedListString(placeholders).c_str());

            // Let's update the composite Function graph's inputs with any inferred dimensions that
            // were determined from the shapes of the supplied data
            auto networkArguments = Arguments();
            for (auto argument : networkArguments)
            {
                if (argument.Shape().HasInferredDimension())
                {
                    auto fullyDefinedArgument = m_fullyDefinedArgumentsMap.at(argument);
                    for (size_t i = 0; i < argument.Shape().Rank(); ++i)
                        if (argument.Shape()[i] == NDShape::InferredDimension)
                            argument.m_dataFields->m_shape[i] = fullyDefinedArgument.Shape()[i];
                }
            }

            // Run the final validation on the entire network once before constructing/compiling the
            // internal computation network
            ValidateOrUpdateOutputs();

            std::tie(m_computationNetwork, m_variableToNodeMap) = CreateComputationNetwork<ElementType>(this->shared_from_this(), device, outputs, m_fullyDefinedArgumentsMap, m_inputsExcludedFromGradientComputation, /*useMangledNamesForComputationNodes =*/ false);

            // Record the timestamps of Parameters and Constants
            assert(m_lastRecordedTimeStamps.empty());
            auto functionParameters = Parameters();
            for (auto parameter : functionParameters)
                m_lastRecordedTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
            auto functionConstants = Constants();
            for (auto constant : functionConstants)
                m_lastRecordedTimeStamps.insert({ constant, constant.CurrentValueTimeStamp() });

            // Collect parameters and constants being assigned to
            PreorderTraverseFunctions(RootFunction(), [this](const FunctionPtr& function) {
                auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
                if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::Assign))
                    m_refVariables.insert(primitiveFunction->Inputs()[0]);
            }, /*nestedSearchInsideBlockFunction =*/ true);
        }

        if (!m_networkMatricesAllocated && allocateNetworkMatrices)
        {
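            // First-time matrix allocation: gather the forward root nodes (the root Function's outputs), the nodes for the
            // explicitly requested outputs, and the backprop root node (if any), and let the network allocate matrices for them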
            m_allNetworkRoots = m_currentBackpropRoots;
            ComputationNodeBasePtr backpropRootNode;
            if (!m_currentBackpropRoots.empty())
                backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());

            // Now recursively traverse the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->RawOutputs();
            m_allNetworkRoots.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
            std::vector<ComputationNodeBasePtr> forwardRootNodes;
            for (auto rootOutput : rootFunctionOutputs)
                forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));

            std::vector<ComputationNodeBasePtr> forwardOutputNodes;
            m_allNetworkRoots.insert(outputs.begin(), outputs.end());
            for (auto output : outputs)
                forwardOutputNodes.push_back(m_variableToNodeMap.at(output));

            m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
            m_networkMatricesAllocated = allocateNetworkMatrices;
        }
        else
        {
            // Make sure the requested outputs are a subset of the outputs that the matrix allocation structure
            // in the cached computation network was set up for
            for (auto output : outputs)
            {
                if (m_allNetworkRoots.find(output) == m_allNetworkRoots.end())
                    LogicError("Function '%S': Requested output '%S' is not part of the list of outputs '%S' that the Function was initially compiled for. "
                                "Changing requested outputs across different Forward calls is currently unsupported.",
                                AsString().c_str(), output.AsString().c_str(), NamedListString(m_allNetworkRoots).c_str());
            }
        }

        return m_computationNetwork;
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map<MBLayoutPtr, Variable>& layoutsPopulated)
    {
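        // Convert the supplied Value object into the internal CNTK matrix and MBLayout representation and
        // validate its inferred shape against the node's sample layout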
        NDShape inferredVariableShape;
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second, &inferredVariableShape);
        if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout()))
            CNTK::LogicError("CompositeFunction::Forward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.", 
                             inferredVariableShape.AsString().c_str(), variableValue.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str());

        // Switch the node matrix to the right matrix type
        auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
        nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);

        auto layout = CNTKMatrixAndMBLayout.second;
        auto& nodeLayout = computationNode->GetMBLayout();
        if ((layout == nullptr) != (nodeLayout == nullptr))
            InvalidArgument("The layout of the specified Value for Variable '%S' is incompatible with the layout of the corresponding ComputationNode.", variableValue.first.AsString().c_str());
        else if (layout)
        {
            if (layoutsPopulated.find(nodeLayout) == layoutsPopulated.end())
            {
                nodeLayout->CopyFrom(layout);
                layoutsPopulated.insert({ nodeLayout, variableValue.first });
            }
            else
            {
                if (*nodeLayout != *layout)
                    InvalidArgument("Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified "
                                    "for the Function's arguments '%S' vs. '%S', though these arguments have the same dynamic axes '%S'",
                                     variableValue.first.AsString().c_str(), layoutsPopulated.at(nodeLayout).AsString().c_str(), DynamicAxesAsString(variableValue.first.DynamicAxes(), Internal::IsReversingTensorShapesInErrorMessagesEnabled()).c_str());
            }
        }
    }

    std::unordered_map<Variable, NDShape> CompositeFunction::InferFreeDimensionsOfArguments(const std::unordered_map<Variable, ValuePtr>& arguments)
    {
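        // For each argument, infer the concrete shape implied by the supplied Value; arguments whose inferred
        // shape differs from their declared shape (i.e. with free dimensions) are collected below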
        std::unordered_map<Variable, NDShape> inferredArgumentDimensions;
        for (auto argumentValuePair : arguments)
        {
            NDShape inferredVarShape;
            Utils::VerifyVariableValueCompatibility(argumentValuePair.first, argumentValuePair.second, &inferredVarShape);
            if (inferredVarShape != argumentValuePair.first.Shape())
                inferredArgumentDimensions.insert({ argumentValuePair.first , inferredVarShape });
        }

        if (!inferredArgumentDimensions.empty())
        {
            if (m_fullyDefinedArgumentsMap.empty())
            {
                for (auto inferredArgumentShapePair : inferredArgumentDimensions)
                {
                    auto fullyDefinedArgument = inferredArgumentShapePair.first.Clone();
                    fullyDefinedArgument.m_dataFields->m_shape = inferredArgumentShapePair.second;
                    m_fullyDefinedArgumentsMap.insert({ inferredArgumentShapePair.first, fullyDefinedArgument });
                }

                if (GetCheckedMode())
                    m_latestFullyDefinedCompositeForCheckedModeValidation = this->Clone(ParameterCloningMethod::Share, m_fullyDefinedArgumentsMap);
            }
            else
            {
                bool argumentShapeChangedSinceLastTime = false;
                for (auto inferredArgumentShapePair : inferredArgumentDimensions)
                {
                    if (inferredArgumentShapePair.second != m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].Shape())
                    {
                        argumentShapeChangedSinceLastTime = true;
                        m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].m_dataFields->m_shape = inferredArgumentShapePair.second;
                    }
                }

                if (argumentShapeChangedSinceLastTime && m_latestFullyDefinedCompositeForCheckedModeValidation)
                    m_latestFullyDefinedCompositeForCheckedModeValidation->ValidateOrUpdateOutputs();
            }
        }

        return inferredArgumentDimensions;
    }

    void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
    {
        std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
        std::vector<ComputationNodeBasePtr> inputNodes;
        for (auto argumentValuePair : arguments)
        {
            auto argument = argumentValuePair.first;
            auto argumentComputationNode = m_variableToNodeMap.at(argument);
            assert(argumentComputationNode);
            inputNodes.push_back(argumentComputationNode);

            ValuePtr argumentValue = arguments.at(argument);
            switch (argumentValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
                break;
            case DataType::Double:
                PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
                break;
            default:
                LogicError("Function '%S' Forward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(argumentValue->GetDataType()));
                break;
            }
        }

        m_computationNetwork->BumpEvalTimeStamp(inputNodes);
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
    {
        NDShape inferredVariableShape;
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second, &inferredVariableShape);
        if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout()))
            CNTK::LogicError("CompositeFunction::Backward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.",
                             inferredVariableShape.AsString().c_str(), variableGradient.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str());

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
        auto nodeLayout = computationNode->GetMBLayout();
        if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
            InvalidArgument("The layout of the specified gradient Value for Variable '%S' is incompatible with the layout computed during Forward call.",
                            variableGradient.first.AsString().c_str());
        computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
    }

    // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
    void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto functionOutputs = RawOutputs();
        for (auto gradientVarValuePair : gradients)
        {
            auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first);
            ValuePtr gradientValue = gradientVarValuePair.second;

            switch (gradientValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
                break;
            default:
                LogicError("Function '%S' Backward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(gradientValue->GetDataType()));
                break;
            }
        }
    }

    /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
    {
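        // Compute the expected Value shape from the Variable's shape, its dynamic axes and the node's MBLayout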
        auto varShape = GetVariableShape(var.Shape(), computationNode->GetSampleLayout());
        auto valueShape = PackedValue::GetUnpackedShape(varShape, var.DynamicAxes(), computationNode->GetMBLayout());
        if (varValue != nullptr)
        {
            // TODO: The shape of the specified output Value object must match the actual output shape
            if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
                ::InvalidArgument("The shape %S of the specified Value object for Variable '%S' %s does not match the actual shape %S",
                                  varValue->Shape().AsString().c_str(), var.AsString().c_str(), getGradient ? "gradient" : "output", valueShape.AsString().c_str());
        }

        ValuePtr nodeValue;
        auto layout = computationNode->GetMBLayout();
        switch (var.GetDataType())
        {
        case DataType::Float:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(varShape, var.DynamicAxes(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, computationNode, matrix, layout);
            break;
        }
        case DataType::Double:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(varShape, var.DynamicAxes(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, computationNode, matrix, layout);
            break;
        }
        default:
            CNTK::LogicError("CompositeFunction::Forward/Backward: Unsupported DataType %s", DataTypeName(var.GetDataType()));
            break;
        }

        if (varValue == nullptr)
            varValue = nodeValue;
        else
            varValue->CopyFrom(*nodeValue);
    }

    void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
    {
        // Now copy the Forward values of output nodes from the network to outputs' Value objects
        for (auto& outputVarValuePair : outputs)
        {
            auto& valuePtr = outputVarValuePair.second;
            auto node = m_variableToNodeMap.at(outputVarValuePair.first);
            bool noValueStorageProvided = (valuePtr == nullptr);
            GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/);

            auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
            if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
                m_existingNetworkStorageReferences.push_back(packedVarValue);
        }
    }

    void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto networkInputs = this->Inputs();
        // Now copy the gradient values of input nodes of the network to gradients' Value objects
        for (auto& gradientVarValuePair : gradients)
        {
            // Only gradients corresponding to inputs of the network can be obtained
            if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
                InvalidArgument("Gradient requested for Variable '%S' which is not a leaf input (Input, Parameter or Constant) of the Function '%S'; this is currently unsupported.",
                                gradientVarValuePair.first.AsString().c_str(), this->AsString().c_str());

            // Gradients can only be obtained for Parameter Variables or Input Variables that have NeedsGradient set to true
            if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end()))
                InvalidArgument("Gradient value incorrectly requested for Variable '%S', "
                                "which is an Output, a Constant, an Input Variable with NeedsGradient set to false, or an input for which gradient computation was explicitly excluded.",
                                gradientVarValuePair.first.AsString().c_str());

            auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first);

            if (!computationNodePtr->NeedsGradient())
                LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.",
                            AsString().c_str(), gradientVarValuePair.first.AsString().c_str());

            auto& valuePtr = gradientVarValuePair.second;
            bool noValueStorageProvided = (valuePtr == nullptr);
            GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/);

            auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
            if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
                m_existingNetworkStorageReferences.push_back(packedVarValue);
        }
    }

    const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
    {
        if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end())
        {
            auto sanitizedOutput = output.NonCompositePreservingCopy();

            if (sanitizedOutput.IsOutput())
                m_perOutputVarArgumentDependencies[sanitizedOutput] = AsComposite(sanitizedOutput.Owner())->Arguments();
            else if (sanitizedOutput.IsParameter() || sanitizedOutput.IsConstant())
                m_perOutputVarArgumentDependencies[sanitizedOutput] = {};
            else
                m_perOutputVarArgumentDependencies[sanitizedOutput] = { sanitizedOutput };
        }

        return m_perOutputVarArgumentDependencies[output];
    }

    std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
    {
        std::unordered_map<Variable, uint64_t> currentBackpropRootsTimeStamps;
        assert(m_computationNetwork != nullptr);

        for (auto& backpropRoot : m_currentBackpropRoots)
            currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp();

        return currentBackpropRootsTimeStamps;
    }

    /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                            std::unordered_map<Variable, ValuePtr>& outputs,
                                                            const DeviceDescriptor& computeDevice,
                                                            const std::unordered_set<Variable>& outputsToRetainBackwardStateFor,
                                                            const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
    {
        // Validate arguments and outputs
        if (outputs.empty())
            InvalidArgument("At least one output has to be specified when calling Forward method of the Function '%S'.", this->AsString().c_str());

        // Make sure that the DataType of the variables and corresponding values match
        // TODO: We need a better way to determine the ElementType for the network
        auto dataType = DataType::Unknown;
        for (auto variableValuePair : arguments)
        {
            if (dataType == DataType::Unknown)
                dataType = variableValuePair.first.GetDataType();
            else if (dataType != variableValuePair.first.GetDataType())
                LogicError("Function '%S' Forward: The DataType of all arguments must be same.", this->AsString().c_str());
        }

        if (dataType == DataType::Unknown)
        {
            for (auto variableValuePair : outputs)
            {
                if (dataType == DataType::Unknown)
                    dataType = variableValuePair.first.GetDataType();
            }
        }

        std::unordered_set<Variable> requestedOutputVariables;
        for (auto output : outputs)
            requestedOutputVariables.insert(output.first);

        std::unordered_set<Variable> functionOutputs(m_outputs.begin(), m_outputs.end());
        std::unordered_set<Variable> requiredArguments;
        for (auto outputVariable : requestedOutputVariables)
        {
            auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVariable);
            requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());
        }

        // We should have argument values supplied for all required argument dependencies for the requested outputs
        std::vector<Variable> missingRequiredArguments;
        std::unordered_map<Variable, ValuePtr> requiredArgumentValues;
        for (auto requiredArgument : requiredArguments)
        {
            auto iter = arguments.find(requiredArgument);
            if (iter == arguments.end())
                missingRequiredArguments.push_back(requiredArgument);
            else
                requiredArgumentValues.insert(*iter);
        }

        if (!missingRequiredArguments.empty())
        {
            InvalidArgument("Values for %d required arguments '%S', that the requested output(s) '%S' depend on, have not been provided.",
                            (int)missingRequiredArguments.size(), NamedListString(missingRequiredArguments).c_str(), NamedListString(requestedOutputVariables).c_str());
        }

        if (requiredArgumentValues.size() < arguments.size())
            fprintf(stderr, "WARNING: Function::Forward was provided values for %d extra arguments that are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size()));

        auto inferredArgumentShapes = InferFreeDimensionsOfArguments(requiredArgumentValues);

        if (dataType == DataType::Float)
            GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
        else if (dataType == DataType::Double)
            GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
        else
            InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));

        // Feed data into the arguments of the network
        // TODO: Avoid copying the data when possible
        PopulateNetworkInputs(requiredArgumentValues);

        // Copy all new values for 'dirty' attributes from functions into corresponding network nodes.
        ApplyAttributeUpdates();

        // Bump the timestamp of the parameter nodes whose values have changed
        for (auto& timeStampRecord : m_lastRecordedTimeStamps)
        {
            auto variable = timeStampRecord.first;
            auto prevTimeStamp = timeStampRecord.second;
            auto newTimeStamp = variable.CurrentValueTimeStamp();
            if (newTimeStamp > prevTimeStamp)
            {
                timeStampRecord.second = newTimeStamp;
                m_variableToNodeMap.at(variable)->BumpEvalTimeStamp();
            }
        }

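        // Collect the ComputationNodes for the requested outputs (plus, below, any backprop roots not already included);
        // these are the roots evaluated by ForwardProp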
        std::vector<ComputationNodeBasePtr> outputsToEvaluate;
        for (auto outputVariable : requestedOutputVariables)
            outputsToEvaluate.push_back(m_variableToNodeMap.at(outputVariable));

        // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
        for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
        {
            if (outputs.find(rootVarForBackprop) == outputs.end())
                outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop));
        }

        // Reset the timestamps of all backward roots to record an update in one or more inputs
        for (auto& backpropRoot : m_currentBackpropRoots)
            m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();

        // Reset the timestamps of all dropout nodes to force recomputation of the (random) dropout mask.
        list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType<DropoutNodeBase>();
        for (auto& dropout : dropoutNodes)
            dropout->SetEvalTimeStampOutdatedWrtAll();

        // Free any previous references to the matrix storage associated with the outputsToEvaluate
        ClearExistingOutputOrGradientStorageReferences();

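        // Run the network in training mode if backward state needs to be retained for any outputs, otherwise in inference mode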
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

        m_computationNetwork->ForwardProp(outputsToEvaluate);

        // Call PostForwardAndBackProp after ForwardProp only in evaluation mode.
        if (outputsToRetainBackwardStateFor.empty())
        {
            m_computationNetwork->PostForwardAndBackProp(outputsToEvaluate);
            RecordRefVariableUpdates();
        }
        else
        {
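            // Training was requested (backward state must be retained): defer PostForwardAndBackProp until the matching
            // Backward call and remember the evaluated roots for that purpose.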
            m_currentOutputsToEvaluate.clear();
            for (auto outputToEvaluate : outputsToEvaluate)
                m_currentOutputsToEvaluate.push_back(outputToEvaluate);
        }

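        // Copy the computed output values from the network into the caller-supplied 'outputs' map.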
        GetNetworkOutputs(outputs);
        // TODO: How to deal with the specified 'computeDevice'
        BackPropStatePtr backpropStatePtr;
        if (outputsToRetainBackwardStateFor.size() > 0)
            backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());

        return backpropStatePtr;
    }

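    // Backpropagate the supplied root gradient values through the network and populate gradients for the requested inputs.
    // The passed-in 'state' must stem from the most recent Forward call on this Function (enforced via the timestamp check below).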
    /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
                                                 const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                                 std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
    {
        auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
        if (backpropState == nullptr)
            InvalidArgument("Function '%S' Backward: Invalid backprop state passed.", AsString().c_str());

        if (backPropagatedGradientValuesForInputs.empty())
            InvalidArgument("Function '%S' Backward: List of inputs to compute gradients for, must not be empty.", AsString().c_str());

        // TODO: Support multiple concurrent backprop states
        std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
        if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
            LogicError("Function '%S' Backward: The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified "
                        "by subsequent Forward calls to the function. This is not a user error but a shortcoming of the current implementation where multiple independent "
                        "backprop states are not simultaneously supported.", AsString().c_str());

        if (rootGradientValues.size() > 1)
            LogicError("Function '%S' Backward: %d root gradient values specified; currently gradient backprop from only one of the Function Outputs is supported.",
                        AsString().c_str(), (int)rootGradientValues.size());

        // TODO: Avoid copying the data when possible

        // Zero all gradients of nodes below the root nodes
        for (const auto& rootGradientVarValuePair : rootGradientValues)
            m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first));

        // Feed data into the arguments of the network
        PopulateNetworkGradients(rootGradientValues);

        // Backpropagate through the network
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);

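        // Backprop starts from the nested network rooted at the (single) specified root; FrameRange(nullptr) selects the entire minibatch rather than a single time step.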
        auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first);
        m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);

        GetNetworkGradients(backPropagatedGradientValuesForInputs);

        if (m_currentOutputsToEvaluate.size() > 0)
        {
            m_computationNetwork->PostForwardAndBackProp(m_currentOutputsToEvaluate);
            RecordRefVariableUpdates();
            m_currentOutputsToEvaluate.clear();
        }

        // TODO: How to deal with the specified 'computeDevice'
    }

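    // Propagate attribute values that were changed on the Function objects (currently the dropout rate and the RNG seed)
    // to the corresponding computation network nodes, and mark those nodes as outdated so they get re-evaluated.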
    void CompositeFunction::ApplyAttributeUpdates()
    {
        // Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input.
        // This mask is regenerated every minibatch, and hence dropout nodes with a non-zero dropout rate must be marked outdated
        // w.r.t. their inputs to force re-evaluation in each minibatch.
        for (auto varNodePair : m_variableToNodeMap)
        {
            auto var = varNodePair.first;
            if (!var.IsOutput())
                continue;

            auto function = var.Owner();

            if (function->m_dirtyAttributes.empty())
                continue;

            auto node = varNodePair.second;

            for (const wstring& attribute : function->m_dirtyAttributes)
            {
                if (attribute == PrimitiveFunction::AttributeNameDropoutRate)
                {
                    auto dropoutRate = function->m_attributes[attribute].Value<double>();
                    auto dropoutPtr = dynamic_cast<DropoutNodeBase*>(node.get());
                    assert(dropoutPtr != nullptr);
                    dropoutPtr->SetDropoutRate(dropoutRate);
                }
                else if (attribute == PrimitiveFunction::AttributeNameRngSeed) 
                {
                    auto seed = function->m_attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    auto rngUserPtr = dynamic_cast<RngUser*>(node.get());
                    assert(rngUserPtr != nullptr);
                    rngUserPtr->SetRngState(seed);
                }
                else 
                {
                    // Should never happen.
                    LogicError("ApplyAttributeUpdates: function '%S' specified an unsupported attribute '%S'.",
                        function->AsString().c_str(), attribute.c_str());
                }
            }

            function->m_dirtyAttributes.clear();
            node->SetEvalTimeStampOutdatedWrtAll();
        }
    }
}