https://github.com/Microsoft/CNTK
Raw File
Tip revision: b8544c177606ce9e10fc20a9a1d0d08ff2d6d7ac authored by Mark Hillebrand on 20 January 2017, 11:24:32 UTC
Bump NuGet package versions
Tip revision: b8544c1
CompositeFunction.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    // Op name constant for composite functions (its uses are outside this view).
    /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
    // Counter for auto-generated dynamic axes, starting at 0. Presumably used to give such axes
    // unique names; its uses are outside this view. Declared std::atomic.
    /*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);

    // Value stored under typeKey by SerializeBlockComposite() and expected by ValidateDictionary
    // when a composite function is deserialized.
    static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";

    // Writes only the identifying metadata of this composite:
    //  - the serialization version,
    //  - the type tag,
    //  - the uid of the root function,
    //  - the composite's own name and uid.
    // The graph contents (inputs, primitive functions, RNG state) are appended by Serialize().
    Dictionary CompositeFunction::SerializeBlockComposite() const
    {
        Dictionary serialized;

        serialized[versionKey] = CurrentVersion();
        serialized[typeKey] = s_compositeFunctionTypeValue;

        const auto rootFunction = RootFunction();
        serialized[rootKey] = rootFunction->Uid();

        serialized[nameKey] = Name();
        serialized[uidKey] = Uid();

        return serialized;
    }

    // Serializes the whole composite graph into a Dictionary.
    //
    // Besides the metadata written by SerializeBlockComposite(), the result contains:
    //  - inputsKey:    the serialized non-output inputs of the graph (deduplicated by uid), plus a
    //                  Placeholder for every output-variable input whose owner function had already
    //                  been visited when the input was encountered (see "mustBeReplaced" below);
    //  - functionsKey: the serialized primitive functions (including those inside block functions),
    //                  in reversed traversal order, so that the root function comes last;
    //  - stateKey:     RNG seed/offset, keyed by function uid, for every non-block function whose
    //                  output maps to an RngUser node in m_variableToNodeMap.
    //
    // Raises LogicError if the collected input uids are not unique.
    /*virtual*/ Dictionary CompositeFunction::Serialize() const
    {
        Dictionary dict = SerializeBlockComposite();
       
        // Find cycles in the graph and "break" them by inserting placeholders.
        // This needs to be done on Save, since here we have easy access to the shape and 
        // dynamic axis info.
        std::unordered_set<FunctionPtr> visitedFunctions;
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> uniqueInputs;
        std::unordered_set<std::wstring> inputUids;
        std::function<void(const FunctionPtr& function)> SerializationTraversalFunc;
        // Called by PreorderTraverseFunctions for each function it visits. Collects the function's
        // inputs and records the function, in visit order, in topoSortedPrimitiveFunctions.
        SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
            std::vector<Variable> functionInputs = function->Inputs();
            for (const auto& input : functionInputs)
            {
                auto& uid = input.Uid();
                // Already recorded (as a leaf input or as a placeholder).
                if (inputUids.find(uid) != inputUids.end())
                    continue;

                // check if this input corresponds to a cyclic edge in the graph.
                // BUG: An owner that was already visited does not by itself indicate a cyclic edge in
                // the graph. It may just mean that at least 2 successors in the graph have the function
                // as input. Such inputs are nevertheless replaced by placeholders here.
                bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());

                if (mustBeReplaced)
                {
                    // Placeholder carrying the same uid, shape, type, sparsity, dynamic axes and name
                    // as the output it stands in for.
                    auto varKind = VariableKind::Placeholder;
                    Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
                    uniqueInputs.push_back(var);
                    inputUids.insert(uid);
                }
                else if (!input.IsOutput())
                {
                    // leave the input as is.
                    uniqueInputs.push_back(input);
                    inputUids.insert(uid);
                }
                // Remaining case: an output of a not-yet-visited owner. It is not recorded as an input.
                // Presumably it is recreated when its owner function is serialized/deserialized.
            }
            visitedFunctions.insert(function);
            topoSortedPrimitiveFunctions.push_back(function);

            // For block functions we need to recursively traverse the underlying composite
            if (function->IsBlock())
                PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
        };

        PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);

        // Reverse the visit order so that the root function is serialized last.
        std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));

        assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());
        
        // Serialize the collected inputs, re-checking that their uids are unique
        // (inputUids is reused for this check).
        std::vector<DictionaryValue> inputDictionaries;
        inputDictionaries.reserve(uniqueInputs.size());
        inputUids.clear();
        for (const auto& input : uniqueInputs)
        {
            if (inputUids.find(input.Uid()) != inputUids.end())
                LogicError("Input uids must be unique");

            inputUids.insert(input.Uid());
            inputDictionaries.push_back(input.Serialize());
        }

        dict[inputsKey] = std::move(inputDictionaries);
       
        std::vector<DictionaryValue>  functionDictionaries;
        std::unordered_set<std::wstring> outputUids;
        for (const auto& primitiveFunciton : topoSortedPrimitiveFunctions)
        {
            for (const auto& output : primitiveFunciton->Outputs())
            {
                // NOTE(review): the lookup below uses the OUTPUT's uid, but the insert records the
                // owning FUNCTION's uid. The duplicate check can therefore only fire if an output uid
                // happens to equal a function uid. Inserting output.Uid() instead is not obviously
                // safe: Combine simply forwards other functions' outputs, so its Outputs() would
                // presumably collide with those of its inputs' owners. Confirm the intended invariant
                // before changing this.
                if (outputUids.find(output.Uid()) != outputUids.end())
                    LogicError("Output uids of all primitive functions in a function graph must be unique");

                outputUids.insert(primitiveFunciton->Uid());
            }

            functionDictionaries.push_back(primitiveFunciton->Serialize());
        }

        dict[functionsKey] = std::move(functionDictionaries);
        
        // Now, collect and store the internal state for all non-pure (stateful) functions in the graph 
        // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
        // Only variables that already have a computation node in m_variableToNodeMap contribute.
        Dictionary stateDictionary; 
        for (const auto& kv : m_variableToNodeMap)
        {
            if (kv.second->Is<RngUser>() && kv.first.IsOutput())
            {
                // The RNG state should be associated with the actual function that the computation node
                // corresponds to, and not the block primitives that wrap the actual function
                auto ownerFunction = kv.first.Owner().get();
                if (!ownerFunction->IsBlock())
                {
                    auto rng = kv.second->As<RngUser>();
                    Dictionary state;
                    state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
                    state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
                    stateDictionary[ownerFunction->Uid()] = state;
                }
            }
        }

        dict[stateKey] = std::move(stateDictionary);

        return dict;
    }

    // Reconstructs a composite function from the metadata written by SerializeBlockComposite().
    //
    // dict                       - must contain typeKey, rootKey, nameKey and uidKey (validated).
    // allPrimitiveFunctions      - the already-deserialized primitive functions. The one whose uid
    //                              equals dict[rootKey] becomes the root of the composite.
    // allPlaceholderReplacements - placeholder -> replacement mapping. The entries that match the
    //                              composite's placeholders are applied until none match any more.
    // device                     - not used by this function.
    //
    // Raises LogicError if no function in allPrimitiveFunctions has the serialized root uid.
    // Previously this dereferenced the end iterator, which is undefined behavior.
    /*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
                                                                        const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
                                                                        const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
                                                                        const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, nameKey, uidKey };
        ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& rootUid = dict[rootKey].Value<std::wstring>();
        const auto& name = dict[nameKey].Value<std::wstring>();
        const auto& uid = dict[uidKey].Value<std::wstring>();

        auto rootIter = std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(), [&rootUid](const FunctionPtr& func) {
            return func->Uid() == rootUid;
        });

        // A missing root means the dictionary is inconsistent with the deserialized functions.
        // Fail loudly instead of dereferencing end().
        if (rootIter == allPrimitiveFunctions.end())
            LogicError("Root function (uid = '%ls') of the composite function (uid = '%ls') was not found among the deserialized primitive functions.", rootUid.c_str(), uid.c_str());

        FunctionPtr root = *rootIter;

        FunctionPtr composite = CompositeFunction::Create(root, name, uid);

        // Find the subset of placeholder replacements that apply for this composite. Keep repeating
        // until a pass finds nothing to replace. Presumably this is needed because a replacement can
        // bring in graph parts that have placeholders of their own.
        std::unordered_map<Variable, Variable> placeholderReplacements;
        do
        {
            placeholderReplacements.clear();
            auto compositePlaceholders = composite->Placeholders();
            for (const auto& placeholder : compositePlaceholders)
            {
                auto replacementIter = allPlaceholderReplacements.find(placeholder);
                if (replacementIter != allPlaceholderReplacements.end())
                    placeholderReplacements.insert({ placeholder, replacementIter->second });
            }

            if (!placeholderReplacements.empty())
                composite = composite->ReplacePlaceholders(placeholderReplacements);

        } while (!placeholderReplacements.empty());

        return composite;
    }

    // Reconstructs a composite function from a Dictionary produced by Serialize().
    //
    // Steps:
    //  1. Deserialize the serialized inputs (leaf variables and cycle-breaking placeholders) into a
    //     uid -> Variable map.
    //  2. Deserialize the primitive functions in serialized order. Each function's outputs are added
    //     to that map so later functions can refer to them. An output whose uid is already mapped
    //     (which must be a Placeholder) is instead recorded as that placeholder's replacement.
    //  3. Along the way, restore the RNG seed/offset attributes of stateful functions from the
    //     optional stateKey entry.
    //  4. DeserializeBlockComposite() builds the composite and applies the replacements.
    //
    // Raises LogicError on duplicate input uids, or when a function output collides with an input
    // that is not a Placeholder.
    /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };

        size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();
        std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
        for (const auto& dictionaryValue : inputs)
        {
            const auto& dictionary = dictionaryValue.Value<Dictionary>();
            const auto& inputVar = Variable::Deserialize(dictionary, device);

            if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
            {
                LogicError("Input uids are not unique (several inputs share '%ls' uid) "
                           "(%s).", inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
            }
            uidToInputMap[inputVar.Uid()] = inputVar;
        }

        // The RNG state is optional: dictionaries without stateKey are still accepted.
        Dictionary stateDictionary;
        if (dict.Contains(stateKey))
            stateDictionary = dict[stateKey].Value<Dictionary>();

        const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

        std::unordered_map<Variable, Variable> allPlaceholderReplacements;
        std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
        for (const auto& dictionaryValue : functions)
        {
            FunctionPtr root = PrimitiveFunction::Deserialize(dictionaryValue.Value<Dictionary>(), uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
            allPrimitiveFunctions.insert(root);

            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());

            // Since Combine simply forwards other functions' outputs, all of its outputs
            // should already be in the uidToInputMap.
            auto opType = primitiveFunction->OpType();
            if (opType == PrimitiveOpType::Combine)
                continue;

            if (primitiveFunction->IsStateful())
            {
                if (stateDictionary.Contains(primitiveFunction->Uid()))
                {
                    auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
                    auto seed = state[rngSeedKey].Value<size_t>();
                    auto offset = state[rngOffsetKey].Value<size_t>();
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
                }
                else if (Internal::GetComputationNetworkTraceLevel() > 0)
                {
                    // TODO: all logging functionality should be refactored to live in a logging utility class.
                    // The trailing newline keeps consecutive warnings (and any later stderr
                    // output) from running together on one line.
                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
                            "when deserializing from a dictionary (version=%zu). "
                            "Reproducibility not guaranteed.\n", primitiveFunction->OpName().c_str(), version);
                }
            }

            for (const auto& output : root->Outputs())
            {
                const auto& it = uidToInputMap.find(output.Uid());
                if (it != uidToInputMap.end())
                {
                    // Serialize() only emits an output's uid as an input when it stood in for that
                    // output as a Placeholder. Anything else means the dictionary is inconsistent.
                    if (!it->second.IsPlaceholder())
                    {
                        LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)"
                        "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(),
                        GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
                    }
                    allPlaceholderReplacements[it->second] = output;
                }
                else
                {
                    uidToInputMap[output.Uid()] = output;
                }
            }
        }

        return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
    }

    void CompositeFunction::CopyState(const CompositeFunction& source)
    {
        // Create a map with all non-pure (stateful) functions in the function graph.
        auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
            std::map<std::wstring, FunctionPtr> functionMap;
            for (auto funcPtr : allPrimitiveFunctions)
            {
                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
                if (primitiveFunction->IsStateful())
                {
                    functionMap[primitiveFunction->Uid()] = funcPtr;
                }
            }
            return functionMap;
        };

        std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
        std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);

        assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
        if (statefulFunctionsFrom.size() == 0)
        {
            return;
        }

        // Copy state captured in the attributes dictionaries.
        for (const auto& kv : statefulFunctionsFrom)
        {
            statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
        }

        UpdateInternalNetworkState();
    }

    // Pushes the RNG seed/offset attributes of every stateful primitive function into the
    // computation node (accessed as RngUser) of each of its outputs.
    // Does nothing if m_computationNetwork has not been created.
    // Throws, via unordered_map::at, if an output has no entry in m_variableToNodeMap.
    void CompositeFunction::UpdateInternalNetworkState()
    {
        if (!m_computationNetwork)
            return;

        for (const auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
            if (!primitiveFunction->IsStateful())
                continue;

            // Seed and offset are per function, not per output. Read them once instead of copying
            // the attribute dictionary for every output.
            auto attributes = function->Attributes();
            auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
            auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();

            for (const auto& output : function->Outputs())
            {
                auto node = m_variableToNodeMap.at(output);
                node->As<RngUser>()->SetRngState(seed, offset);
            }
        }
    }

    // Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values.
    // InternalDefaultDynamicAxisName ("*") is the internal name for the default set of dynamic axes.
    // Presumably that is the default sequence + batch axes; its uses are outside this view.
    // InternalNoSequenceAxisName covers the special case where there is no sequence axis, i.e. it
    // has been reduced over. The special name is used to identify this case when loading back a
    // model saved in the CNTK v1 format. It will not really be needed once the new CNTK v2 model
    // serialization format is ready.
    /*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*";
    /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis";

    // Recursively creates the sub-network of ComputationNode instances corresponding to the graph of
    // Functions underlying the specified 'variable'. Returns the ComputationNode instance for the
    // top-level 'variable'.
    //
    // variable          - the variable to translate. Its DataType, Shape (with no inferred dimensions)
    //                     and DynamicAxes must be known, otherwise InvalidArgument is raised.
    // network           - the computation network the nodes are created for.
    // builder           - creates the learnable-parameter and input nodes.
    // variableToNodeMap - memo of already-translated variables. A variable that is still being
    //                     translated is mapped to nullptr, so recurrent (cyclic) graphs terminate.
    // isVariableRootMap - records whether each variable can be a root of the graph. An output
    //                     variable is marked true when first translated, unless already present.
    //                     Every later lookup of a variable marks it false.
    //
    // Output variables are delegated to GetOutputVariableNode.
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
                                                                 Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                 ComputationNetworkBuilder<ElementType>& builder,
                                                                 std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                 std::unordered_map<Variable, bool>& isVariableRootMap)
    {
        // Already translated, or currently being translated further up the recursion (in which case
        // the mapped node is the nullptr sentinel). Being referenced again means it is not a root.
        auto iter = variableToNodeMap.find(variable);
        if (iter != variableToNodeMap.end())
        {
            isVariableRootMap[variable] = false;
            return iter->second;
        }

        // The DataType, Shape and DynamicAxes of the variable must be known by now
        if (variable.GetDataType() == DataType::Unknown)
            InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().IsUnknown())
            InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().HasInferredDimension())
            InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
            InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        // Let's add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
        variableToNodeMap[variable] = nullptr;

        std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
        if (variable.IsParameter() || variable.IsConstant())
        {
            auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());
            computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
            network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
            // Variables that don't need gradients get a learning-rate multiplier of 0.
            if (!variable.NeedsGradient())
                computationNodePtr->SetLearningRateMultiplier(0.0);

            NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
            std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();

            // Parameters, and constants already on the network's device, share the variable's
            // matrix by reference. Constants on another device get a copy on the network's device.
            if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
                computationNodePtr->Value() = valueMatrix->AsReference();
            else
            {
                Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
                clonedMatrix.AssignValuesOf(*valueMatrix);
                computationNodePtr->Value() = std::move(clonedMatrix);
            }
        }
        else if (variable.IsInput())
        {
            auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());

            // TODO: Input variables currently are required to have the default batch axis
            auto dynamicAxes = variable.DynamicAxes();
            auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
            if (foundDefaultBatchAxis == dynamicAxes.end())
                LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes");

            // dynamicAxes is non-empty here (it contains DefaultBatchAxis), so back() is safe.
            if (dynamicAxes.back() != Axis::DefaultBatchAxis())
                LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes");

            // TODO: Support inputs with > 2 dynamic axes
            if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2))
                LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported");

            // Construct the dynamic axis name to be used internally for the CNTK InputNodes
            std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);

            // Add the DynamicAxisNode to the network the first time this axis name is needed.
            if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
                network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

            if (IsSparseInput(variable))
                computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);
            else
                computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);

            if (variable.NeedsGradient())
            {
                // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default
                // gradients are not computed for Input nodes
                computationNodePtr->SetLearningRateMultiplier(0.00001f);
            }
        }
        else
        {
            assert(variable.IsOutput());
            computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->template As<ComputationNode<ElementType>>()->shared_from_this();
        }

        // Replace the nullptr sentinel with the real node. Do not overwrite a root flag already set
        // by a recursive lookup above: such a lookup means something references this variable.
        variableToNodeMap[variable] = computationNodePtr;
        if (isVariableRootMap.find(variable) == isVariableRootMap.end())
        isVariableRootMap[variable] = variable.IsOutput();

        return computationNodePtr;
    }

    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
                                                                               Function* function,
                                                                               const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap)
    {
        ComputationNodeBasePtr computationNodePtr;

        auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name());

        std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
        for (auto inputNode : inputNodes)
            inputNodesBasePtrs.push_back(inputNode);

        PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
        if (primitiveFunction)
        {
            auto functionInputs = function->Inputs();
            auto& functionConfig = function->Attributes();
            PrimitiveOpType op = primitiveFunction->OpType();

            switch (op)
            {
            case PrimitiveOpType::Negate:
                computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Sigmoid:
                computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Tanh:
                computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Cos:
                computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Sin:
                computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ReLU:
                computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Exp:
                computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Log:
                computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Sqrt:
                computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Floor:
                computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Abs:
                computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Reciprocal:
                computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Softmax:
                computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Hardmax:
                computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::TransposeAxes:
            {
                auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
                auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();

                // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
                computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
                break;
            }
            case PrimitiveOpType::Where:
            {
                auto dynamicAxes = variable.DynamicAxes();
                auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
                computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
                break;
            }
            case PrimitiveOpType::Slice:
            {
                auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
                auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>();

                // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based
                computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis));
                break;
            }
            case PrimitiveOpType::RandomSample:
            {
                auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
                auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
                computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                break;
            }
            case PrimitiveOpType::RandomSampleInclusionFrequency:
            {
                auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
                auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
                computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
                break;
            }
            case PrimitiveOpType::Dropout:
            {
                auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
                computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
                break;
            }
            case PrimitiveOpType::Reshape:
            {
                computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(primitiveFunction->Output().Shape()));
                break;
            }
            case PrimitiveOpType::ROIPooling:
            {
                auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
                computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape));
                break;
            }
            case PrimitiveOpType::Pooling:
            {
                PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
                auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
                auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
                break;
            }
            case PrimitiveOpType::Unpooling:
            {
                auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value<NDShape>();
                auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                //We only get here after validation so it is safe to assume unpooling is max
                computationNodePtr = New<MaxUnpoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
                break;
            }
            case PrimitiveOpType::SumAll:
                computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Plus:
                computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::LogPlus:
                computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Minus:
                computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ElementTimes:
                computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Equal:
                computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::NotEqual:
                computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Less:
                computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::LessEqual:
                computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Greater:
                computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::GreaterEqual:
                computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Times:
            {
                size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
                auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
                computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
                break;
            }
            case PrimitiveOpType::TransposeTimes:
            {
                size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
                computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
                break;
            }
            case PrimitiveOpType::Convolution:
            {
                NDShape outputMapCount, kernelShape;
                std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape());
                auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
                auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
                auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
                auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
                auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
                auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
                auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
                computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples);
                break;
            }
            case PrimitiveOpType::CosDistance:
                computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Logistic:
                computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::SquaredError:
                computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::CrossEntropyWithSoftmax:
                computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ClassificationError:
                computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::LambdaRank:
                computationNodePtr = New<LambdaRankNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::NDCG:
                computationNodePtr = New<NDCG1EvalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::PastValue:
            case PrimitiveOpType::FutureValue:
            {
                Variable inputOperandVar = functionInputs[0];
                Variable initialStateVar = functionInputs[1];

                size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
                if (op == PrimitiveOpType::PastValue)
                    computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
                else
                    computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);

                break;
            }
            case PrimitiveOpType::ReduceElements:
            {
                auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
                computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis));
                break;
            }
            case PrimitiveOpType::BatchNormalization:
            {
                auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
                auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
                auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
                auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
                auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();

                computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
                break;
            }
            case PrimitiveOpType::Combine:
                // This operation is just a no-op and is a means to combine multiple functions to create a single Function
                // whose outputs are a union of the outputs of the Functions being combined.
                computationNodePtr = variableToNodeMap[variable];
                break;
            case PrimitiveOpType::PackedIndex:
                computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::GatherPacked:
                computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::ScatterPacked:
                computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Clip:
                computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Select:
                computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            case PrimitiveOpType::Splice:
            {
                Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
                computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
                break;
            }
            case PrimitiveOpType::OptimizedRNNStack:
            {
                auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
                auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
                auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
                auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();

                computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
                break;
            }
            case PrimitiveOpType::ReconcileDynamicAxis:
            {
                computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            }
            case PrimitiveOpType::LogSoftmax:
            {
                //This can be implemented as x => x - ReduceLogSum(x). How to do this here?
                computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            }
            case PrimitiveOpType::Pass:
                computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
                break;
            default:
                LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
                break;
            }

            // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
            ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);

            if (computationNodePtr->Is<INumInputs>())
            {
                auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
                if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
                    LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs",
                    PrimitiveOpTypeName(op).c_str(),
                    (int)inputNodesBasePtrs.size(),
                    (int)computationNodeExpectedInputCount);
            }

            if (computationNodePtr->Is<RngUser>())
            {
                if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
                {
                    auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    uint64_t offset = 0;
                    if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
                    {
                        offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                    }
                    computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
                }
            }
        }
        else
        {
            computationNodePtr = New<UserDefinedV2FunctionNode<ElementType>>(network->GetDeviceId(), internalNodeName, function->shared_from_this());

            // For user defined functions, we only attach unique inputs in the internal computation network since, the UDF 
            // backward implementations directly compute aggregate gradient values for unique inputs
            std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
            for (auto inputNodeBasePtr : inputNodesBasePtrs)
            {
                if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
                    uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
            }

            inputNodesBasePtrs = uniqueInputNodesBasePtrs;
        }

        network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);

        return computationNodePtr;
    }

    // Returns the ComputationNode that implements the output Variable 'variable'.
    // For a BlockFunction owner, no new node is created: the node of the block's underlying composite output
    // (variable.BlockFunctionVariableMapping()) is returned. For any other owner, a node is created via CreateComputationNode.
    // Inputs of the owning Function are resolved (recursively) through GetNode first.
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               ComputationNetworkBuilder<ElementType>& builder,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                               std::unordered_map<Variable, bool>& isVariableRootMap)
    {
        assert(variable.IsOutput());

        Function* function = variable.Owner().get();
        ComputationNodeBasePtr computationNodePtr;
        // NOTE: reference to the owner's own input list; the coercion loop below overwrites entries in place.
        auto& functionInputs = function->m_inputs;

        // Data type of the first non-constant input whose data type is known (Unknown if there is none).
        DataType nonConstInputDataType = DataType::Unknown;
        for (auto& inputVar : functionInputs)
        {
            if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown))
            {
                nonConstInputDataType = inputVar.GetDataType();
                break;
            }
        }
            
        // Create the nodes corresponding to the inputs
        std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
        for (auto& inputVar : functionInputs)
        {
            // If the inputVar is a constant and not the right DataType let's coerce it to the right type
            // (the value is converted on the CPU and then copied back to the original device; the replacement
            // Constant is written back into the owner's m_inputs through the reference above)
            if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
            {
                auto originalConstantValue = Constant(inputVar).Value();
                auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true);
                NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true);
                inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name());
            }

            // GetNode may return null; a null entry is kept so positions still line up with functionInputs
            // (presumably for recurrent inputs whose node is not created yet — TODO confirm).
            auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap);
            inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
        }

        BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function);
        if (blockFunction)
        {
            // For block function, map each argument placeholder of the underlying composite to
            // the computation node corresponding to the block input that the argument placeholder
            // of the composite is mapped to.
            auto compositeArguments = blockFunction->Composite()->Arguments();
            for (auto compositeArgument : compositeArguments)
                variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());

            // Early return: the block's inputs are not marked as non-root in isVariableRootMap by this path.
            return GetNode(variable.BlockFunctionVariableMapping(), network, builder, variableToNodeMap, isVariableRootMap);
        }
        else
            computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap);

        // Inputs consumed by a real operation are no longer network roots; a Combine merely forwards its
        // inputs as outputs, so its inputs keep their root status.
        PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
        if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine))
        {
            for (auto inputVar : functionInputs)
                isVariableRootMap[inputVar] = false;
        }

        return computationNodePtr;
    }

    // Returns the internal ComputationNetwork implementing this composite Function on 'device'. The network is built
    // on the first call and cached in m_computationNetwork; later calls reuse it.
    //  - backpropRoots: Variables that gradients will be backpropagated from. At most one is currently supported, and
    //    the set may not change once the network has been built (an empty set on a later call is accepted).
    //  - outputs: Variables whose values are requested in addition to the root Function's outputs.
    //  - allocateNetworkMatrices: when true and the matrices were not allocated yet, allocates the network's
    //    matrices for the forward roots/outputs and the backprop root.
    // Raises LogicError/InvalidArgument for unbound placeholders and for changing the device, the backprop roots,
    // or the requested outputs across calls.
    template <typename ElementType>
    ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set<Variable>& backpropRoots, const std::unordered_set<Variable>& outputs, bool allocateNetworkMatrices)
    {
        if (m_computationNetwork != nullptr)
        {
            // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network
            // was last constructed, or just recreate a new network.
            // For now just disallow changing the backpropRoots after the network is created
            if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
                LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported");

            // TODO: Support changing the device across different invocations of the forward method on a Function instance
            if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
                LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported");
        }
        else
        {
            // Validate the request BEFORE creating m_computationNetwork: once that member is non-null every later call
            // takes the cached-network branch above. If these checks failed after creation, an empty network would stay
            // cached and a retry (e.g. after binding the remaining placeholders) would fail.

            // TODO: We currently only support one backprop root
            if (backpropRoots.size() > 1)
                LogicError("More than one backprop roots is currently unsupported");

            auto placeholders = Placeholders();
            if (!placeholders.empty())
                InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!");

            m_computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));

            ComputationNetworkBuilder<ElementType> builder(*m_computationNetwork);

            // Now recursively create the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->Outputs();
            for (const auto& rootOutput : rootFunctionOutputs)
                GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap);

            // We need to patch the Computation node mappings for the arguments of block functions 
            // since for recurrent inputs, the mappings are not fully established the first time
            std::function<void(const FunctionPtr&)> PatchBlockArgumentsMapping;
            PatchBlockArgumentsMapping = [this, &PatchBlockArgumentsMapping](const FunctionPtr& function) {
                BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function.get());
                if (blockFunction)
                {
                    auto compositeArguments = blockFunction->Composite()->Arguments();
                    for (auto compositeArgument : compositeArguments)
                        m_variableToNodeMap[compositeArgument] = m_variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());

                    PreorderTraverseFunctions(function->BlockRoot(), PatchBlockArgumentsMapping);
                }
            };
            PreorderTraverseFunctions(rootFunction, PatchBlockArgumentsMapping);

            // A Variable is a root if no other node consumes it; for outputs of block Functions the
            // corresponding output inside the block's composite must be a root as well.
            std::function<bool(const Variable&)> IsVariableRoot;
            IsVariableRoot = [this, &IsVariableRoot](const Variable& outputVar) {
                auto ownerFunc = outputVar.IsOutput() ? outputVar.Owner().get() : nullptr;
                auto ownerBlockFunc = dynamic_cast<BlockFunction*>(ownerFunc);
                return (m_isVariableRootMap[outputVar] && (!ownerBlockFunc || IsVariableRoot(ownerBlockFunc->CompositeOutputsMap().at(outputVar))));
            };

            // If any of the function or requested outputs is not a root node, we need to explicitly
            // add it to the 'output' group of the ComputationNetwork
            std::unordered_set<Variable> networkOutputs(outputs);
            networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
            for (const auto& output : networkOutputs)
            {
                if (!IsVariableRoot(output))
                {
                    // Use find() rather than operator[] so a missing output does not insert a null mapping
                    auto computationNodeIter = m_variableToNodeMap.find(output);
                    if ((computationNodeIter == m_variableToNodeMap.end()) || !computationNodeIter->second)
                        InvalidArgument("One of the requested outputs for the Function forward computation is not part of the graph underlying the Function");

                    m_computationNetwork->AddToNodeGroup(L"output", computationNodeIter->second);
                }
            }

            m_currentBackpropRoots = backpropRoots;

            // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
            // Now attach those after we have created all ComputationNodes in the network
            for (const auto& varNodePair : m_variableToNodeMap)
            {
                const auto& currentComputationNode = varNodePair.second;
                if (!currentComputationNode)
                    LogicError("No computation node mapping exists for Variable %S", varNodePair.first.Name().c_str());

                auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
                const auto& currentVar = varNodePair.first;
                if (!currentVar.IsOutput())
                    continue;

                if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
                {
                    // This ComputationNode has at least one null input which now needs to be properly attached

                    const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());

                    // The input reordering below is only defined for primitive ops; fail clearly instead of dereferencing null
                    if (primitiveFunc == nullptr)
                        LogicError("Cannot attach the pending inputs of the ComputationNode for Variable %S since its owner Function is not a primitive Function", currentVar.Name().c_str());

                    // Skip block primitives since they do not directly map to a computation node
                    if (primitiveFunc->OpType() == PrimitiveOpType::Block)
                        continue;

                    // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
                    auto inputVars = primitiveFunc->Inputs();
                    ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
                    inputVars.resize(currentComputationNode->GetNumInputs());

                    std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
                    for (const auto& inputVar : inputVars)
                        inputNodesBasePtrs.push_back(m_variableToNodeMap.at(inputVar));

                    currentComputationNode->AttachInputs(inputNodesBasePtrs);
                }
            }

            m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
            m_computationNetwork->CompileNetwork();

            // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
            for (const auto& varNodePair : m_variableToNodeMap)
            {
                if (!varNodePair.first.IsOutput())
                    continue;

                const auto& outputVar = varNodePair.first;
                const auto& computationNodePtr = varNodePair.second;
                auto outputShape = outputVar.Shape();
                auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();

                // A scalar (rank-0) output matches a rank-0 layout or one whose leading dimension is 1;
                // check the rank before indexing so a rank-0 layout is not read out of range.
                const bool scalarMismatch = (outputShape.Rank() == 0) && (computationNodeSampleLayout.GetRank() > 0) && (computationNodeSampleLayout[0] != 1);
                const bool tensorMismatch = (outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape));
                if (scalarMismatch || tensorMismatch)
                {
                    LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
                }
            }

            // Record the timestamps of Parameter values
            assert(m_lastRecordedParameterValueTimeStamps.empty());
            auto functionParameters = Parameters();
            for (auto parameter : functionParameters)
                m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
        }


        if (!m_networkMatricesAllocated && allocateNetworkMatrices)
        {
            ComputationNodeBasePtr backpropRootNode;
            if (!m_currentBackpropRoots.empty())
                backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());

            // Now recursively traverse the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->Outputs();
            std::vector<ComputationNodeBasePtr> forwardRootNodes;
            for (const auto& rootOutput : rootFunctionOutputs)
                forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));

            std::vector<ComputationNodeBasePtr> forwardOutputNodes;
            for (const auto& output : outputs)
                forwardOutputNodes.push_back(m_variableToNodeMap.at(output));

            m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
            m_networkMatricesAllocated = allocateNetworkMatrices;

            // backpropRootNode is null when no backprop root was requested; keep it out of the root set so
            // only real nodes are ordered.
            std::unordered_set<ComputationNodeBasePtr> allNetworkRoots(forwardRootNodes.begin(), forwardRootNodes.end());
            allNetworkRoots.insert(forwardOutputNodes.begin(), forwardOutputNodes.end());
            if (backpropRootNode)
                allNetworkRoots.insert(backpropRootNode);
            m_allNetworkRootsInGlobalEvalOrder = m_computationNetwork->SortByGlobalEvalOrder(allNetworkRoots);

            m_currentOutputs = outputs;
            m_currentOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
            m_currentOutputs.insert(m_currentBackpropRoots.begin(), m_currentBackpropRoots.end());
        }
        else
        {
            // Make sure the outputs requested are a subset of the outputs we setup the current matrix allocation structure
            // in the cached computation network
            for (const auto& output : outputs)
            {
                if (m_currentOutputs.find(output) == m_currentOutputs.end())
                    LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported");
            }
        }

        return m_computationNetwork;
    }

    // Copies the data of the Value supplied for an argument Variable into the input ComputationNode 'computationNode'
    // (which must be an InputValueBase node) and populates the node's minibatch layout.
    // Several inputs may share one MBLayout object. 'layoutsPopulated' records which layouts were already filled in
    // during this Forward call, and by which argument, so later arguments sharing that layout are checked for
    // consistency rather than overwriting it.
    // Raises InvalidArgument if the Value's layout presence or content is incompatible with the node's layout.
    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map<MBLayoutPtr, Variable>& layoutsPopulated)
    {
        if (!computationNode->Is<InputValueBase<ElementType>>())
            LogicError("CompositeFunction::Forward: Illegal to populate value of computation node type other than InputValueBase!");

        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout;
        auto packedValue = dynamic_cast<PackedValue*>(variableValue.second.get());
        if (packedValue)
            CNTKMatrixAndMBLayout = packedValue->PackedData<ElementType>();
        else
            CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second);

        // Switch the node matrix to the right matrix type
        auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
        nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);

        const auto& layout = CNTKMatrixAndMBLayout.second;
        auto& nodeLayout = computationNode->GetMBLayout();

        // Both layouts must be present or both absent. This mirrors the check in PopulateComputationNodeGradient
        // and avoids dereferencing a null MBLayoutPtr below.
        if ((layout == nullptr) != (nodeLayout == nullptr))
            InvalidArgument("Function::Forward: The minibatch layout of the data specified for argument '%S' is incompatible with the corresponding input of the Function (exactly one of them has no minibatch layout)", variableValue.first.Name().c_str());

        // A node without a minibatch layout has nothing more to populate
        if (nodeLayout == nullptr)
            return;

        auto populatedIter = layoutsPopulated.find(nodeLayout);
        if (populatedIter == layoutsPopulated.end())
        {
            nodeLayout->CopyFrom(layout);
            layoutsPopulated.insert({ nodeLayout, variableValue.first });
        }
        else
        {
            if (*nodeLayout != *layout)
                InvalidArgument("Function::Forward: Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified for 2 of the Function's argument ('%S', '%S') having same dynamic axes", variableValue.first.Name().c_str(), populatedIter->second.Name().c_str());
        }
    }

    void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
    {
        std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
        std::vector<ComputationNodeBasePtr> inputNodes;
        for (auto argumentValuePair : arguments)
        {
            auto argument = argumentValuePair.first;
            auto argumentComputationNode = m_variableToNodeMap.at(argument);
            assert(argumentComputationNode);
            inputNodes.push_back(argumentComputationNode);

            ValuePtr argumentValue = arguments.at(argument);
            switch (argumentValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
                break;
            case DataType::Double:
                PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType()));
                break;
            }
        }

        m_computationNetwork->BumpEvalTimeStamp(inputNodes);
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
    {
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout;
        auto packedValue = dynamic_cast<PackedValue*>(variableGradient.second.get());
        if (packedValue)
            CNTKMatrixAndMBLayout = packedValue->PackedData<ElementType>();
        else
            CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second);

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
        auto nodeLayout = computationNode->GetMBLayout();
        if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
            InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call");
        computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
    }

    // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
    void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto functionOutputs = this->Outputs();
        for (auto gradientVarValuePair : gradients)
        {
            auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first);
            ValuePtr gradientValue = gradientVarValuePair.second;

            switch (gradientValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType()));
                break;
            }
        }
    }

    // Computes the full shape of the value held by 'computationNodePtr' for Variable 'var'.
    // The leading axes come from the node's sample layout, one per axis of var's static shape.
    // When the node has a minibatch layout, two trailing axes follow, sized by the layout's
    // GetNumTimeSteps() and GetNumSequences() respectively.
    static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr)
    {
        const auto& mbLayout = computationNodePtr->GetMBLayout();
        const bool hasDynamicAxes = (mbLayout != nullptr);
        const size_t sampleRank = var.Shape().Rank();

        std::vector<size_t> dims;
        dims.reserve(sampleRank + (hasDynamicAxes ? 2 : 0));

        const auto& sampleLayout = computationNodePtr->GetSampleLayout();
        for (size_t axis = 0; axis < sampleRank; ++axis)
            dims.push_back(sampleLayout.GetDim(axis));

        if (hasDynamicAxes)
        {
            dims.push_back(mbLayout->GetNumTimeSteps());
            dims.push_back(mbLayout->GetNumSequences());
        }

        return NDShape(dims);
    }

    // Publishes either the forward Value or the Gradient matrix of 'computationNode' (selected by 'getGradient')
    // through 'varValue':
    //  - if 'varValue' is null on entry, it is set to a new PackedValue built over matrix.AsReference()
    //    of the node's matrix;
    //  - otherwise the caller-supplied Value must have a compatible shape, and its contents are overwritten
    //    (CopyFrom) with a Value created from the node's matrix and minibatch layout.
    /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
    {
        const NDShape expectedShape = GetValueShape(var, computationNode);
        const bool allocateNewValue = (varValue == nullptr);

        // TODO: The shape of the specified output Value object must match the actual output shape
        if (!allocateNewValue && (varValue->Shape() != expectedShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(expectedShape)))
            InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(expectedShape).c_str());

        auto mbLayout = computationNode->GetMBLayout();
        ValuePtr sourceValue;
        switch (var.GetDataType())
        {
        case DataType::Float:
        {
            auto typedNode = computationNode->As<ComputationNode<float>>();
            auto& sourceMatrix = getGradient ? typedNode->Gradient() : typedNode->Value();
            if (allocateNewValue)
                sourceValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(sourceMatrix.AsReference()), mbLayout, /*readOnly =*/ false);
            else
                sourceValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, sourceMatrix, mbLayout);
            break;
        }
        case DataType::Double:
        {
            auto typedNode = computationNode->As<ComputationNode<double>>();
            auto& sourceMatrix = getGradient ? typedNode->Gradient() : typedNode->Value();
            if (allocateNewValue)
                sourceValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(sourceMatrix.AsReference()), mbLayout, /*readOnly =*/ false);
            else
                sourceValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, sourceMatrix, mbLayout);
            break;
        }
        default:
            LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType()));
            break;
        }

        if (!allocateNewValue)
            varValue->CopyFrom(*sourceValue);
        else
            varValue = sourceValue;
    }

    // Copies the Forward values of the requested output nodes from the network into the outputs' Value objects.
    // A null Value slot in 'outputs' is filled in by GetNodeOutputOrGradient; a non-null one is overwritten in place.
    void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
    {
        // Iterate by reference and hand the map's own ValuePtr slot to the callee: this avoids copying each
        // (Variable, ValuePtr) pair and the redundant second hash lookup 'outputs[key]' of the previous version.
        for (auto& outputVarValuePair : outputs)
            GetNodeOutputOrGradient(outputVarValuePair.first, outputVarValuePair.second, m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/);
    }

    void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto networkInputs = this->Inputs();
        // Now copy the gradient values of input nodes of the network to gradients' Value objects
        for (auto gradientVarValuePair : gradients)
        {
            // Only gradients corresponding to inputs of the network can be obtained
            if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
                InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function");

            // Gradients can only be obtained for parameter variables or input variables that NeedsGradient
            if (!gradientVarValuePair.first.NeedsGradient())
                InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, or an Input Variable with NeedsGradient setting of false");

            auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first);

            if (!computationNodePtr->NeedsGradient())
                LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");

            GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
        }
    }

    // Returns the arguments (as reported by Arguments() of a composite wrapping output.Owner()) that 'output'
    // depends on. Results are cached per output Variable; the returned reference points into that cache.
    const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
    {
        assert(output.IsOutput());

        auto iter = m_perOutputVarArgumentDependencies.find(output);
        if (iter != m_perOutputVarArgumentDependencies.end())
            return iter->second;

        // The dependency list is fully computed (as the emplace argument) BEFORE the cache is touched. The previous
        // 'map[output] = f();' left the order of operator[] and f() unspecified before C++17, so an exception thrown
        // by Create/Arguments could leave a permanently cached empty list. A single emplace also replaces the
        // two extra hash lookups of 'map[output] = ...; return map[output];'.
        auto wrappedComposite = CompositeFunction::Create(output.Owner());
        return m_perOutputVarArgumentDependencies.emplace(output, wrappedComposite->Arguments()).first->second;
    }

    // Snapshots, for every Variable in m_currentBackpropRoots, the evaluation time stamp of the computation
    // node it maps to. Backward compares such a snapshot against the one stored in the backprop state.
    std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
    {
        assert(m_computationNetwork != nullptr);

        std::unordered_map<Variable, uint64_t> timeStampsByRoot;
        for (const auto& rootVariable : m_currentBackpropRoots)
        {
            const auto& rootNode = m_variableToNodeMap.at(rootVariable);
            timeStampsByRoot[rootVariable] = rootNode->GetEvalTimeStamp();
        }

        return timeStampsByRoot;
    }

    // Evaluates the requested 'outputs' given values for 'arguments'.
    // - A null Value slot in 'outputs' is filled with a new Value; a non-null slot is overwritten in place.
    // - Every argument that a requested output depends on must have a value in 'arguments'.
    // - If 'outputsToRetainBackwardStateFor' is non-empty, the network runs in training mode and a backprop state
    //   is returned for a later Backward call; otherwise inference mode is used and a null state is returned.
    /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                            std::unordered_map<Variable, ValuePtr>& outputs,
                                                            const DeviceDescriptor& computeDevice,
                                                            const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
    {
        // Validate arguments and outputs
        if (outputs.empty())
            InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!");

        // Make sure that the DataType of the variables and corresponding values match
        // TODO: We need a better way to determine the ElementType for the network
        auto dataType = DataType::Unknown;
        for (const auto& variableValuePair : arguments)
        {
            if (dataType == DataType::Unknown)
                dataType = variableValuePair.first.GetDataType();
            else if (dataType != variableValuePair.first.GetDataType())
                LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be same");
        }

        // Otherwise fall back to the first requested output (in map iteration order) with a known DataType
        if (dataType == DataType::Unknown)
        {
            for (const auto& variableValuePair : outputs)
            {
                dataType = variableValuePair.first.GetDataType();
                if (dataType != DataType::Unknown)
                    break;
            }
        }

        std::unordered_set<Variable> requestedOutputVariables;
        for (const auto& output : outputs)
            requestedOutputVariables.insert(output.first);

        if (dataType == DataType::Float)
            GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else if (dataType == DataType::Double)
            GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else
            InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));

        std::vector<ComputationNodeBasePtr> outputsToEvaluate;
        std::unordered_set<Variable> requiredArguments;
        for (const auto& outputVarValuePair : outputs)
        {
            const auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first);
            requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());

            outputsToEvaluate.push_back(m_variableToNodeMap.at(outputVarValuePair.first));
        }

        // We should have argument values supplied for all required argument dependencies for the requested outputs
        std::vector<Variable> missingRequiredArguments;
        for (const auto& requiredArgument : requiredArguments)
        {
            if (arguments.find(requiredArgument) == arguments.end())
                missingRequiredArguments.push_back(requiredArgument);
        }

        if (!missingRequiredArguments.empty())
        {
            std::wstring missingRequiredArgumentNames = NamedListString(missingRequiredArguments);
            // Report the number of missing arguments (not the length of the joined name string)
            InvalidArgument("Function::Forward: Values for %d required arguments (%S), that the requested output(s) depend on, have not been provided", (int)missingRequiredArguments.size(), missingRequiredArgumentNames.c_str());
        }

        // Feed data into the arguments of the network
        // TODO: Avoid copying the data when possible
        PopulateNetworkInputs(arguments);

        // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input
        // This mask is regenerated every minibatch and hence dropout nodes with a non-zero dropout rate must be marked outdated
        // w.r.t. inputs to force evaluation in each minibatch
        auto dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
        for (auto& nodeIter : dropoutNodes)
            nodeIter->SetEvalTimeStampOutdatedWrtAll();

        // Bump the timestamp of the parameter nodes whose values have changed
        for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps)
        {
            auto parameter = paramTimeStampRecord.first;
            auto prevTimeStamp = paramTimeStampRecord.second;
            auto newTimeStamp = parameter.CurrentValueTimeStamp();
            if (newTimeStamp > prevTimeStamp)
            {
                paramTimeStampRecord.second = newTimeStamp;
                m_variableToNodeMap.at(parameter)->BumpEvalTimeStamp();
            }
        }

        // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
        for (const auto& rootVarForBackprop : outputsToRetainBackwardStateFor)
        {
            if (outputs.find(rootVarForBackprop) == outputs.end())
                outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop));
        }

        // Reset the timestamps of all backward roots to record an update in one or more inputs
        for (auto& backpropRoot : m_currentBackpropRoots)
            m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();

        // Inference mode unless backward state has to be retained for at least one output
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

        // We may have to include additional nodes in the ForwardProp to align with how the memory sharing structure is setup
        // We need to include all roots that lie earlier in the global eval order than the actual outputs we are interested
        // in evaluation.
        // TODO: This may incur additional compute costs in some rare scenarios. We need to come up with a better way to handle this.
        outputsToEvaluate = m_computationNetwork->SortByGlobalEvalOrder(outputsToEvaluate);
        auto lastOutputInEvalOrder = outputsToEvaluate.back();
        auto iterLastRootInEvalOrder = std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), lastOutputInEvalOrder);
        // Advancing past end() below would be undefined behavior; fail loudly instead
        if (iterLastRootInEvalOrder == m_allNetworkRootsInGlobalEvalOrder.end())
            LogicError("CompositeFunction::Forward: The last output node in global evaluation order was not found among the roots of the network");

        auto augmentedOutputsToEvaluate = std::vector<ComputationNodeBasePtr>(m_allNetworkRootsInGlobalEvalOrder.begin(), iterLastRootInEvalOrder + 1);
        m_computationNetwork->ForwardProp(augmentedOutputsToEvaluate);

        GetNetworkOutputs(outputs);

        // TODO: How to deal with the specified 'computeDevice'

        BackPropStatePtr backpropStatePtr;
        if (!outputsToRetainBackwardStateFor.empty())
            backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());

        return backpropStatePtr;
    }

    /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
                                                 const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                                 std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
    {
        auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
        if (backpropState == nullptr)
            InvalidArgument("Invalid backprop state specified");

        // TODO: Support multiple concurrent backprop states
        std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
        if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
            LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
                       "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");

        if (rootGradientValues.size() > 1)
            LogicError("Currently gradient backprop from only one of the Function Outputs is supported");

        // TODO: Avoid copying the data when possible

        // Zero all gradients of nodes below the root nodes
        for (auto rootGradientVarValuePair : rootGradientValues)
            m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first));

        // Feed data into the arguments of the network
        PopulateNetworkGradients(rootGradientValues);

        // Backpropagate through the network
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);

        auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first);
        m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);

        GetNetworkGradients(backPropagatedGradientValuesForInputs);

        // TODO: How to deal with the specified 'computeDevice'
    }
}
back to top