https://github.com/Microsoft/CNTK
Revision 0fdf321b6f8f6c1c00d8c6458cd3ae55b73540e7 authored by Friedel van Megen on 29 November 2016, 10:17:05 UTC, committed by Friedel van Megen on 29 November 2016, 11:06:26 UTC
1 parent 20dada9
Raw File
Tip revision: 0fdf321b6f8f6c1c00d8c6458cd3ae55b73540e7 authored by Friedel van Megen on 29 November 2016, 10:17:05 UTC
eval changes
Tip revision: 0fdf321
CompositeFunction.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
    /*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);

    static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";

    /*virtual*/ Dictionary CompositeFunction::Serialize() const
    {
        Dictionary dict;

        dict[versionKey] = CurrentVersion();
        dict[typeKey] = s_compositeFunctionTypeValue;
        dict[rootKey] = RootFunction()->Uid();
        dict[nameKey] = Name();
        dict[uidKey] = Uid();

       
        // Find cycles in the graph and "break" them by inserting placeholders.
        // This needs to be done on Save, since here we have easy access to the shape and 
        // dynamic axis info.
        std::unordered_set<FunctionPtr> visitedFunctions;
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> inputs;
        std::unordered_set<std::wstring> inputUids;
        Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) {
                    std::vector<Variable> functionInputs = function->Inputs();
                    for (const auto& input : functionInputs)
                    {
                        auto& uid = input.Uid();
                        if (inputUids.find(uid) != inputUids.end())
                        {
                            continue;
                        }

                        // check if this input corresponds to a cyclic edge in the graph.
                        bool mustBeReplaced = input.IsOutput() && visitedFunctions.find(input.Owner()) != visitedFunctions.end();

                        if (mustBeReplaced)
                        {
                            auto varKind = VariableKind::Placeholder;
                                Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, 
                                             input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
                                inputs.push_back(var);
                                inputUids.insert(uid);
                        }
                        else if (!input.IsOutput())
                        {
                            // leave the input as is.
                            inputs.push_back(input);
                            inputUids.insert(uid);
                        }
                    }
                    visitedFunctions.insert(function);
                    topoSortedPrimitiveFunctions.push_back(function);
                });

        std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));

        assert(topoSortedPrimitiveFunctions.size() == m_allPrimitiveFunctions.size());
        assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());
        
        std::vector<DictionaryValue> inputDictionaries;
        inputDictionaries.reserve(inputs.size());
        inputUids.clear();
        for (const auto& input : inputs)
        {
            if (inputUids.find(input.Uid()) != inputUids.end())
            {
                LogicError("Input uids must be unique");
            }
            inputUids.insert(input.Uid());
            inputDictionaries.push_back(input.Serialize());
        }

        dict[inputsKey] = std::move(inputDictionaries);
       
        std::vector<DictionaryValue>  functionDictionaries;
        std::unordered_set<std::wstring> outputUids;
        for (const auto& primitiveFunciton : topoSortedPrimitiveFunctions)
        {
            for (const auto& output : primitiveFunciton->Outputs())
            {
                if (outputUids.find(output.Uid()) != outputUids.end())
                {
                    LogicError("Output uids of all primitive functions in a function graph must be unique");
                }
                outputUids.insert(primitiveFunciton->Uid());
            }
            functionDictionaries.push_back(primitiveFunciton->Serialize());
        }

        dict[functionsKey] = std::move(functionDictionaries);
        
        // Now, collect and store the internal state for all non-pure (stateful) functions in the graph 
        // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
        Dictionary stateDictionary; 
        for (const auto& kv : m_variableToNodeMap)
        {
            if (kv.second->Is<RngUser>()) 
            {
                auto rng = kv.second->As<RngUser>();
                Dictionary state;
                state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
                state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
                stateDictionary[kv.first.Owner()->Uid()] = state;
            }
        }

        dict[stateKey] = std::move(stateDictionary);

        return dict;
    }

    /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, nameKey, uidKey, inputsKey, functionsKey };
       
        size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& rootUid = dict[rootKey].Value<std::wstring>();
        const auto& name = dict[nameKey].Value<std::wstring>();
        const auto& uid = dict[uidKey].Value<std::wstring>();
        const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();

        std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());

        for (const auto& dictionaryValue : inputs)
        {
            const auto& dictionary = dictionaryValue.Value<Dictionary>();
            const auto& inputVar = Variable::Deserialize(dictionary, device);

            if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
            {
                LogicError("Input uids are not unique (several inputs share '%ls' uid) "
                        "(%s).", inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
            }
            uidToInputMap[inputVar.Uid()] = inputVar;
        }

        Dictionary stateDictionary;
        if (dict.Contains(stateKey))
        {
            stateDictionary = dict[stateKey].Value<Dictionary>();
        }

        const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();

        FunctionPtr root;
        std::unordered_map<Variable, Variable> placeholderReplacements;
        std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
        for (const auto& dictionaryValue : functions)
        {
            root = PrimitiveFunction::Deserialize(dictionaryValue.Value<Dictionary>(), uidToInputMap, device);
            allPrimitiveFunctions.insert(root);

            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
            // Since Combine simply forwards other functions' outputs, all of its outputs
            // should already be in the uidToInputMap.
            auto opType = primitiveFunction->OpType();
            if (opType == PrimitiveOpType::Combine)
            {
                continue;
            }

            if (primitiveFunction->IsStateful())
            {
                if (stateDictionary.Contains(primitiveFunction->Uid()))
                {
                    auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
                    auto seed = state[rngSeedKey].Value<size_t>();
                    auto offset = state[rngOffsetKey].Value<size_t>();
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
                }  
                else if (Internal::GetComputationNetworkTraceLevel() > 0)
                {
                    // TODO: all logging functionality should be refactored to live in a logging utility class. 
                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
                            "when deserializing from a dictionary (version=%zu). "
                            "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
                }
            }

            for (const auto& output : root->Outputs())
            {
                const auto& it = uidToInputMap.find(output.Uid());
                if (it != uidToInputMap.end())
                {
                    if (!it->second.IsPlaceholder())
                    {
                        LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)"
                        "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(),
                        GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
                    }
                    placeholderReplacements[it->second] = output;
                }
                else
                {
                    uidToInputMap[output.Uid()] = output;
                }
            }
        }

        if (root->Uid() != rootUid)
        {
            LogicError("Root UID '%ls' is different from the expected value '%ls'.", root->Uid().c_str(), rootUid.c_str());
        }

        if (placeholderReplacements.size() > 0)
        {
           return CompositeFunction::Create(root->ReplacePlaceholders(placeholderReplacements), name, uid);
        }

        return CompositeFunction::Create(root, name, uid);
    }

    void CompositeFunction::CopyState(const CompositeFunction& source)
    {
        // Create a map with all non-pure (stateful) functions in the function graph.
        auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
            std::map<std::wstring, FunctionPtr> functionMap;
            for (auto funcPtr : allPrimitiveFunctions)
            {
                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
                if (primitiveFunction->IsStateful())
                {
                    functionMap[primitiveFunction->Uid()] = funcPtr;
                }
            }
            return functionMap;
        };

        std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
        std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);

        assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
        if (statefulFunctionsFrom.size() == 0)
        {
            return;
        }

        // Copy state captured in the attributes dictionaries.
        for (const auto& kv : statefulFunctionsFrom)
        {
            statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
        }

        UpdateInternalNetworkState();
    }

    void CompositeFunction::UpdateInternalNetworkState()
    {
        if (!m_computationNetwork)
        {
            return;
        }

        for (const auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
            if (primitiveFunction->IsStateful())
            {
                for (const auto& output : function->Outputs())
                {
                    auto node = m_variableToNodeMap[output];
                    auto attributes = function->Attributes();
                    auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                    node->As<RngUser>()->SetRngState(seed, offset);
                }
            }
        }
    }

    // Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values
    // Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over)
    // and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed
    // when the new CNTK v2 model serialization format is ready.
    /*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*";
    /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis";

    // Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables
    // should have been replaced before performing any Forward compute of 'this' Function.
    /*virtual*/ void CompositeFunction::ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
                                                                   std::unordered_set<const Function*>& visitedFunctions,
                                                                   std::unordered_set<Variable>& replacedPlaceholders)
    {
        RootFunction()->ReplacePlaceholdersInPlace(placeholderReplacements, visitedFunctions, replacedPlaceholders);

        // If any of the placeholders were replaced with Output variables, let's add the graph of function underneath each of those to 'm_allPrimitiveFunctions' set
        for (auto replacedPlaceholder : replacedPlaceholders)
        {
            auto replacingVariable = placeholderReplacements.at(replacedPlaceholder);
            if (replacingVariable.IsOutput())
            {
                auto ownerFunc = replacingVariable.Owner();
                std::unordered_set<FunctionPtr> visitedFunctions2;
                Collect(ownerFunc, visitedFunctions2);

                // Add the newly visited functions to 'm_allPrimitiveFunctions' set
                m_allPrimitiveFunctions.insert(visitedFunctions2.begin(), visitedFunctions2.end());
            }
        }
        std::unordered_map<const Function*, size_t> functionVisitCounts;

        // An arbitrary cap on changing output shape of recurrent nodes, to detect infinite inference loops
        const size_t maxNumValidationPassesAllowed = 25;
        bool recurrentNodeOutputModified = false;
        size_t numValidationPasses = 0;
        do
        {
            recurrentNodeOutputModified = false;
            functionVisitCounts.clear();
            RootFunction()->ValidateOrUpdateOutputs(functionVisitCounts, recurrentNodeOutputModified);
            numValidationPasses++;
        } while (recurrentNodeOutputModified && (numValidationPasses < maxNumValidationPassesAllowed));

        if (numValidationPasses >= maxNumValidationPassesAllowed)
            LogicError("A recurrent node output shape change happened in successive %d validation passes indicating a potential infinite inference loop!", (int)numValidationPasses);
    }

    // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions 
    // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the 
    // top level 'variable'
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
                                                                 Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                 ComputationNetworkBuilder<ElementType>& builder,
                                                                 std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                 std::unordered_map<Variable, bool>& isVariableRootMap)
    {
        auto iter = variableToNodeMap.find(variable);
        if (iter != variableToNodeMap.end())
        {
            isVariableRootMap[variable] = false;
            return iter->second;
        }

        // The DataType, Shape and DynamicAxes of the variable must be known by now
        if (variable.GetDataType() == DataType::Unknown)
            InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().IsUnknown())
            InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().HasInferredDimension())
            InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
            InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
        variableToNodeMap[variable] = nullptr;

        std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
        if (variable.IsParameter() || variable.IsConstant())
        {
            auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());
            computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
            network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
            if (!variable.NeedsGradient())
                computationNodePtr->SetLearningRateMultiplier(0.0);

            NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
            std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();

            if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
                computationNodePtr->Value() = valueMatrix->AsReference();
            else
            {
                Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
                clonedMatrix.AssignValuesOf(*valueMatrix);
                computationNodePtr->Value() = std::move(clonedMatrix);
            }
        }
        else if (variable.IsInput())
        {
            auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());

            // TODO: Input variables currently are required to have the default batch axis
            auto dynamicAxes = variable.DynamicAxes();
            auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
            if (foundDefaultBatchAxis == dynamicAxes.end())
                LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes");

            if (dynamicAxes.back() != Axis::DefaultBatchAxis())
                LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes");

            // TODO: Support inputs with > 1 dynamic axes
            if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2))
                LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported");

            // Construct the dynamic axis name to be used internally for the CNTK InputNodes
            std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);

            if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
                network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});

            if (IsSparseInput(variable))
                computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);
            else
                computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);

            if (variable.NeedsGradient())
            {
                // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default
                // gradients are not computed for Input nodes
                computationNodePtr->SetLearningRateMultiplier(0.00001f);
            }
        }
        else
        {
            assert(variable.IsOutput());
            computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->template As<ComputationNode<ElementType>>()->shared_from_this();
        }

        variableToNodeMap[variable] = computationNodePtr;
        if (isVariableRootMap.find(variable) == isVariableRootMap.end())
        isVariableRootMap[variable] = variable.IsOutput();

        return computationNodePtr;
    }

    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
                                                                               PrimitiveFunction* primitiveFunction,
                                                                               const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap)
    {
        ComputationNodeBasePtr computationNodePtr;

        auto internalNodeName = CNTKInternalNodeNameFromUidAndName(primitiveFunction->Uid(), primitiveFunction->Name());

        auto& functionConfig = primitiveFunction->Attributes();
        auto functionInputs = primitiveFunction->Inputs();
        PrimitiveOpType op = primitiveFunction->OpType();

        switch (op)
        {
        case PrimitiveOpType::Negate:
            computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Sigmoid:
            computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Tanh:
            computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Cos:
            computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Sin:
            computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::ReLU:
            computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Exp:
            computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Log:
            computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Sqrt:
            computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Floor:
            computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Abs:
            computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Reciprocal:
            computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Softmax:
            computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Hardmax:
            computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::TransposeAxes:
        {
            auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
            auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();

            // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
            computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
            break;
        }
        case PrimitiveOpType::Where:
        {
            auto dynamicAxes = variable.DynamicAxes();
            auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
            computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
            break;
        }
        case PrimitiveOpType::Slice:
        {
            auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
            auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
            auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>();

            // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based
            computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis));
            break;
        }
        case PrimitiveOpType::RandomSample:
        {
            auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
            auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
            computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
            break;
        }
        case PrimitiveOpType::RandomSampleInclusionFrequency:
        {
            auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
            auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
            computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
            break;
        }
        case PrimitiveOpType::Dropout:
        {
            auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
            computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
            break;
        }
        case PrimitiveOpType::Reshape:
        {
            auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
            computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(newShape));
            break;
        }
        case PrimitiveOpType::ROIPooling:
        {
            auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
            computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape));
            break;
        }
        case PrimitiveOpType::Pooling:
        {
            PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
            auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
            auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
            auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
            auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
            auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
            computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
            break;
        }
        case PrimitiveOpType::SumAll:
            computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Plus:
            computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::LogPlus:
            computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Minus:
            computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::ElementTimes:
            computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Equal:
            computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::NotEqual:
            computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Less:
            computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::LessEqual:
            computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Greater:
            computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::GreaterEqual:
            computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Times:
        {
            size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
            auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
            computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
            break;
        }
        case PrimitiveOpType::TransposeTimes:
        {
            size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
            computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
            break;
        }
        case PrimitiveOpType::Convolution:
        {
            NDShape outputMapCount, kernelShape;
            std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape());
            auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
            auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
            auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
            auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
            auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
            auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
            auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
            computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples);
            break;
        }
        case PrimitiveOpType::CosDistance:
            computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Logistic:
            computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::SquaredError:
            computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::CrossEntropyWithSoftmax:
            computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::ClassificationError:
            computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::PastValue:
        case PrimitiveOpType::FutureValue:
        {
            Variable inputOperandVar = functionInputs[0];
            Variable initialStateVar = functionInputs[1];

            size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
            if (op == PrimitiveOpType::PastValue)
                computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
            else
                computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);

            break;
        }
        case PrimitiveOpType::ReduceElements:
        {
            auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
            auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
            computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis));
            break;
        }
        case PrimitiveOpType::BatchNormalization:
        {
            auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
            auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
            auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
            auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
            auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();

            computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
            break;
        }
        case PrimitiveOpType::Combine:
            // This operation is just a no-op and is a means to combine multiple functions to create a single Function
            // whose outputs are a union of the outputs of the Functions being combined.
            computationNodePtr = variableToNodeMap[variable];
            break;
        case PrimitiveOpType::PackedIndex:
            computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::GatherPacked:
            computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::ScatterPacked:
            computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Clip:
            computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Select:
            computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        case PrimitiveOpType::Splice:
        {
            Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
            computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
            break;
        }
        case PrimitiveOpType::OptimizedRNNStack:
        {
            auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
            auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
            auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
            auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();

            computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
            break;
        }
        case PrimitiveOpType::ReconcileDynamicAxis:
        {
            computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        }
        case PrimitiveOpType::LogSoftmax:
        {
            //This can be implemented as x => x - ReduceLogSum(x). How to do this here?
            computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        }
        case PrimitiveOpType::Pass:
            computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
            break;
        default:
            LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
            break;
        }

        std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
        for (auto inputNode : inputNodes)
            inputNodesBasePtrs.push_back(inputNode);

        // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
        ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);
        if (computationNodePtr->Is<INumInputs>())
        {
            auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
            if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
                LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs",
                           PrimitiveOpTypeName(op).c_str(),
                           (int)inputNodesBasePtrs.size(),
                           (int)computationNodeExpectedInputCount);
        }

        if (computationNodePtr->Is<RngUser>())
        {
            if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
            {
                auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                unsigned long long offset = 0;
                if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
                {
                    offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                }
                computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
            }
        }

        network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);

        return computationNodePtr;
    }

    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
                                                                               Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                               ComputationNetworkBuilder<ElementType>& builder,
                                                                               std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                               std::unordered_map<Variable, bool>& isVariableRootMap)
    {
        assert(variable.IsOutput());

        Function* function = variable.Owner().get();
        ComputationNodeBasePtr computationNodePtr;
        if (dynamic_cast<PrimitiveFunction*>(function))
        {
            PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
            PrimitiveOpType op = primitiveFunction->OpType();
            auto& functionInputs = primitiveFunction->m_inputs;

            DataType nonConstInputDataType = DataType::Unknown;
            for (auto& inputVar : functionInputs)
            {
                if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown))
                {
                    nonConstInputDataType = inputVar.GetDataType();
                    break;
                }
            }
            
            // Create the nodes corresponding to the inputs
            std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
            for (auto& inputVar : functionInputs)
            {
                // If the inputVar is a constant and not the right DataType let's coerce it to the right type
                if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
                {
                    auto originalConstantValue = Constant(inputVar).Value();
                    auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true);
                    NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true);
                    inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name());
                }

                auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap);
                inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
            }

            computationNodePtr = CreateComputationNode(variable, primitiveFunction, inputNodes, network, variableToNodeMap);
            if (op != PrimitiveOpType::Combine)
            {
                for (auto inputVar : functionInputs)
                    isVariableRootMap[inputVar] = false;
            }
        }
        else
            LogicError("User defined Functions are currently unsupported!");

        return computationNodePtr;
    }

    template <typename ElementType>
    ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set<Variable>& backpropRoots, const std::unordered_set<Variable>& outputs, bool allocateNetworkMatrices)
    {
        if (m_computationNetwork != nullptr)
        {
            // TODO: We should either invalidate and readapt the network if he backpropRoots change compared to what was specified when the network
            // was last constructed, to just recreate a new network.
            // For now just disallow changing the backpropRoots after the network is created
            if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
                LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported");

            // TODO: Support changing the device across different invocations of the forward method on a Function instance
            if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
                LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported");
        }
        else
        {
            m_computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));

            ComputationNetworkBuilder<ElementType> builder(*m_computationNetwork);

            // TODO: We currently only support one backprop root
            if (backpropRoots.size() > 1)
                LogicError("More than one backprop roots is currently unsupported");

            auto placeholders = Placeholders();
            if (!placeholders.empty())
                InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!");

            // Now recursively create the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->Outputs();
            for (auto rootOutput : rootFunctionOutputs)
                GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap);

            // If any of the function outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork
            for (auto rootOutput : rootFunctionOutputs)
            {
                if (!m_isVariableRootMap[rootOutput])
                    m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[rootOutput]);
            }

            // If any of the requested outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork
            for (auto output : outputs)
            {
                if (!m_isVariableRootMap[output])
                    m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[output]);
            }

            m_currentBackpropRoots = backpropRoots;

            // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
            // Now attach those after we have created all ComputationNodes in the network
            for (auto varNodePair : m_variableToNodeMap)
            {
                auto& currentComputationNode = varNodePair.second;
                auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
                auto& currentVar = varNodePair.first;

                if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
                {
                    // This ComputationNode has at least one null input which now needs to be properly attached

                    const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());

                    // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
                    auto inputVars = primitiveFunc->Inputs();
                    ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
                    inputVars.resize(currentComputationNode->GetNumInputs());

                    std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
                    for (auto inputVar : inputVars)
                        inputNodesBasePtrs.push_back(m_variableToNodeMap[inputVar]);

                    currentComputationNode->AttachInputs(inputNodesBasePtrs);
                }
            }

            m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
            m_computationNetwork->CompileNetwork();

            // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
            for (auto varNodePair : m_variableToNodeMap)
            {
                if (varNodePair.first.IsOutput())
                {
                    auto outputVar = varNodePair.first;
                    auto computationNodePtr = m_variableToNodeMap[outputVar];
                    auto outputShape = outputVar.Shape();
                    auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
                    if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) ||
                        ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape))))
                    {
                        LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
                    }
                }
            }

            // Record the timestamps of Parameter values
            assert(m_lastRecordedParameterValueTimeStamps.empty());
            auto functionParameters = Parameters();
            for (auto parameter : functionParameters)
                m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
        }


        if (!m_networkMatricesAllocated && allocateNetworkMatrices)
        {
            ComputationNodeBasePtr backpropRootNode;

            // Now recursively traverse the network in a top-down fashion
            auto rootFunction = RootFunction();
            auto rootFunctionOutputs = rootFunction->Outputs();
            std::vector<ComputationNodeBasePtr> forwardRootNodes;
            for (auto rootOutput : rootFunctionOutputs)
            {
                auto currentRootNode = m_variableToNodeMap[rootOutput];
                forwardRootNodes.push_back(currentRootNode);

                if (m_currentBackpropRoots.find(rootOutput) != m_currentBackpropRoots.end())
                    backpropRootNode = currentRootNode;
            }

            std::vector<ComputationNodeBasePtr> forwardOutputNodes;
            for (auto output : outputs)
            {
                auto currentOutputNode = m_variableToNodeMap[output];
                forwardOutputNodes.push_back(currentOutputNode);

                if (m_currentBackpropRoots.find(output) != m_currentBackpropRoots.end())
                    backpropRootNode = currentOutputNode;
            }

            m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
            m_networkMatricesAllocated = allocateNetworkMatrices;

            m_currentOutputs = outputs;
            m_currentOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
            m_currentOutputs.insert(m_currentBackpropRoots.begin(), m_currentBackpropRoots.end());
        }
        else
        {
            // Make sure the outputs requested are a subset of the outputs we setup the current matrix allocation structure
            // in the cached computation network
            for (auto output : outputs)
            {
                if (m_currentOutputs.find(output) == m_currentOutputs.end())
                    LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported");
            }
        }

        return m_computationNetwork;
    }

    template <typename ElementType>
    /*static*/ std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CompositeFunction::GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value)
    {
        if (var.GetDataType() != value->GetDataType())
            LogicError("The Variable's DataType %s does not match the corresponding Value's DataType %s", DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType()));

        if (AsDataType<ElementType>() != value->GetDataType())
            LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(value->GetDataType()));

        // TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error?
        if (IsSparseInput(var) && !value->IsSparse())
            InvalidArgument("Dense input data supplied for a sparse input Variable");

        if (IsSparseInput(var) && (value->GetStorageFormat() != StorageFormat::SparseCSC))
            InvalidArgument("Sparse Input data must be in SparseCSC format");

        auto varShape = var.Shape();
        auto valueShape = value->Shape();
        if (valueShape.Rank() < varShape.Rank())
            InvalidArgument("Value's rank should be >= the Variable's rank");

        size_t maxAddionalValueAxes = std::max<size_t>(2, var.DynamicAxes().size());
        if (valueShape.Rank() > (varShape.Rank() + maxAddionalValueAxes))
            InvalidArgument("Value rank should be larger than the Variable%S rank at most by number of dynamic axes", ParanthesizedName(var.Name()).c_str());

        if (valueShape.SubShape(0, varShape.Rank()) != varShape)
        {
            InvalidArgument("The %s dimensions of the Value shape %S do not match the shape of the variable %S that it corresponds to!",
                            Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? "trailing" : "leading",
                            AsStringForErrorReporting(valueShape).c_str(),
                            AsStringForErrorReporting(varShape).c_str());
        }

        if (var.DynamicAxes().empty())
            return{ value->Data()->GetMatrix<ElementType>(), nullptr };

        if (var.DynamicAxes().size() > 2)
            LogicError("More than 2 dynamic axis for a variable is currently unsupported");

        auto mask = value->Mask();
        if ((mask != nullptr) && ((varShape.Rank() + mask->Shape().Rank()) != valueShape.Rank()))
            InvalidArgument("Invalid Value object; the sum of the rank of the mask and data does not equal the Variable's rank + number of dynamic axes");

        auto getNumTimeStepsAndSequencesFunc = [](const NDShape& maskShape) {
            size_t maxNumTimeSteps = 1;
            size_t numSequences = 1;
            if (maskShape.Rank() > 0)
                maxNumTimeSteps = maskShape[0];

            if (maskShape.Rank() > 1)
                numSequences = maskShape[1];

            return std::pair<size_t, size_t>(maxNumTimeSteps, numSequences);
        };

        size_t maxNumTimeSteps, numSequences;
        std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(valueShape.SubShape(varShape.Rank()));

        auto getSequenceStartsAndLengthsFunc = [&getNumTimeStepsAndSequencesFunc](const NDMaskPtr& mask, std::vector<ptrdiff_t>& sequenceBeginIndices, std::vector<size_t>& sequenceLengths) {
            auto cpuMask = mask;
            if (mask->Device() != DeviceDescriptor::CPUDevice())
                cpuMask = mask->DeepClone(DeviceDescriptor::CPUDevice());

            const MaskKind* maskBuffer = cpuMask->DataBuffer();
            size_t maxNumTimeSteps, numSequences;
            std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(mask->Shape());

            for (size_t i = 0; i < numSequences; ++i)
            {
                MaskKind firstMaskEntry = maskBuffer[i * maxNumTimeSteps];
                if (firstMaskEntry == MaskKind::SequenceBegin)
                    sequenceBeginIndices[i] = 0;
                else if (firstMaskEntry == MaskKind::Valid)
                    sequenceBeginIndices[i] = Microsoft::MSR::CNTK::SentinelValueIndicatingUnspecifedSequenceBeginIdx;
                else
                    LogicError("The first entry of a mask should be Valid or SequenceBegin");

                size_t currentSequenceLength = 1;
                bool currentSequenceEndAlreadyFound = false;
                for (size_t j = 1; j < maxNumTimeSteps; ++j)
                {
                    if (maskBuffer[(i * maxNumTimeSteps) + j] == MaskKind::Invalid)
                        currentSequenceEndAlreadyFound = true;
                    else
                    {
                        if (currentSequenceEndAlreadyFound)
                            InvalidArgument("Invalid Value object; only trailing steps of a sequence can be masked");

                        currentSequenceLength++;
                    }
                }

                sequenceLengths[i] = currentSequenceLength;
            }
        };

        if ((numSequences == 1) || (maxNumTimeSteps == 1))
        {
            // The data need not be shuffled
            std::shared_ptr<const Matrix<ElementType>> matrixData = value->Data()->GetMatrix<ElementType>(varShape.Rank());
            auto layout = std::make_shared<MBLayout>();
            if (!mask)
            {
                if (maxNumTimeSteps == 1)
                    layout->InitAsFrameMode(numSequences);
                else
                {
                    layout->Init(numSequences, maxNumTimeSteps);
                    layout->AddSequence(0, 0, 0, maxNumTimeSteps);
                }
            }
            else
            {
                layout->Init(numSequences, maxNumTimeSteps);

                std::vector<ptrdiff_t> sequenceBeginIndices(numSequences, 0);
                std::vector<size_t> sequenceLengths(numSequences, maxNumTimeSteps);
                getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths);

                for (size_t i = 0; i < numSequences; ++i)
                    layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]);
            }

            return{ matrixData , layout};
        }
        else
        {
            std::vector<ptrdiff_t> sequenceBeginIndices(numSequences, 0);
            std::vector<size_t> sequenceLengths(numSequences, maxNumTimeSteps);
            if (mask != nullptr)
                getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths);

            bool hasTruncatedSequences = std::find_if(sequenceBeginIndices.begin(), sequenceBeginIndices.end(), [](const int& val) { return (val < 0); }) != sequenceBeginIndices.end();

            auto layout = std::make_shared<MBLayout>();
            std::vector<std::pair<size_t, size_t>> placement;
            if (!hasTruncatedSequences)
            {
                std::vector<MBLayout::SequenceInfo> sequences;
                for (size_t i = 0; i < numSequences; ++i)
                    sequences.push_back({ i, SIZE_MAX, sequenceBeginIndices[i], sequenceLengths[i] });

                std::vector<size_t> rowAllocations;
                layout->InitAsPackedSequences(sequences, placement, rowAllocations);
            }
            else
            {
                layout->Init(numSequences, maxNumTimeSteps);

                // We cannot pack as some of the sequences are truncated and thus all sequences have to be
                // kept in their original parallel streams
                placement.resize(numSequences);
                for (size_t i = 0; i < numSequences; ++i)
                {
                    layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]);

                    // Add the gap if there is one
                    if (sequenceLengths[i] < maxNumTimeSteps)
                        layout->AddSequence(GAP_SEQUENCE_ID, i, sequenceLengths[i], maxNumTimeSteps);

                    placement[i] = std::make_pair(i, 0);
                }
            }

            if (maxNumTimeSteps != layout->GetNumTimeSteps())
                LogicError("The number of time steps in the packed MBLayout does not match the longest sequence's length in the Value object");

            if (numSequences != layout->GetNumSequences())
                LogicError("The number of sequences in the packed MBLayout does not match the sequence count in the Value object");

            // The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps
            // Now generate the gather indices
            auto matrixData = std::make_shared<Matrix<ElementType>>(varShape.TotalSize(),
                                                                    layout->GetNumCols(),
                                                                    AsCNTKImplDeviceId(value->Device()),
                                                                    value->IsSparse() ? MatrixType::SPARSE : MatrixType::DENSE,
                                                                    AsCNTKImplMatrixFormat(value->GetStorageFormat()));

            std::vector<size_t> sequencesShorterThanLongestSequence;
            for (size_t i = 0; i < numSequences; ++i)
                if (sequenceLengths[i] != maxNumTimeSteps)
                    sequencesShorterThanLongestSequence.push_back(i);

            // Set the source location for all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch
            size_t sourceColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1);
            std::vector<ElementType> gatherIndicesVector(layout->GetNumCols(), (ElementType)sourceColIdxForInvalidColumns);
            for (size_t i = 0; i < numSequences; ++i)
            {
                size_t targetParallelStreamIdx = placement[i].first;
                size_t targetStartIdxInParallelStream = placement[i].second;
                for (size_t j = 0; j < sequenceLengths[i]; ++j)
                    gatherIndicesVector[((targetStartIdxInParallelStream + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j);
            }

            auto gatherIdxMatrix = std::make_shared<Matrix<ElementType>>(1, layout->GetNumCols(), gatherIndicesVector.data(), AsCNTKImplDeviceId(value->Device()));
            matrixData->DoGatherColumnsOf(0, *gatherIdxMatrix, *(value->Data()->GetMatrix<ElementType>(varShape.Rank())), 1);
            return{ matrixData, layout };
        }
    }

    template <typename ElementType>
    /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Matrix<ElementType>& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/)
    {
        NDShape valueDataShape = sampleShape;

        size_t maxNumTimeSteps = 1;
        size_t numSequences = 1;
        if (layout != nullptr)
        {
            maxNumTimeSteps = layout->GetNumTimeSteps();
            numSequences = layout->GetNumSequences();
            valueDataShape = valueDataShape.AppendShape({ maxNumTimeSteps, numSequences });
        }

        auto createMaskFunc = [](const MBLayoutPtr& layout, const DeviceDescriptor& device, std::vector<size_t>& sequencesShorterThanLongestSequence) {
            std::vector<bool> sequenceBeginFlags;
            std::vector<size_t> sequenceLengths;
            sequencesShorterThanLongestSequence.clear();

            size_t maxNumTimeSteps = layout->GetNumTimeSteps();
            size_t numSequences = layout->GetNumSequences();
            auto& layoutSequences = layout->GetAllSequences();

            size_t sequenceIdx = 0;
            bool allSequencesStartInThisMB = true;
            bool allSequencesSameLength = true;
            for (auto sequenceInfo : layoutSequences)
            {
                if (sequenceInfo.seqId != GAP_SEQUENCE_ID)
                {
                    auto currentSequenceBeginIdx = std::max<ptrdiff_t>(0, sequenceInfo.tBegin);
                    auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd);
                    auto currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx);
                    auto isCurrentSequenceBeginningInsideThisMB = sequenceInfo.tBegin >= 0;

                    allSequencesStartInThisMB = allSequencesStartInThisMB && isCurrentSequenceBeginningInsideThisMB;
                    allSequencesSameLength = allSequencesSameLength && (currentSequenceLength == maxNumTimeSteps);

                    sequenceBeginFlags.push_back(isCurrentSequenceBeginningInsideThisMB);
                    sequenceLengths.push_back(currentSequenceLength);

                    if (currentSequenceLength != maxNumTimeSteps)
                        sequencesShorterThanLongestSequence.push_back(sequenceIdx);

                    sequenceIdx++;
                }
            }

            if (!allSequencesStartInThisMB && (numSequences != layout->GetNumParallelSequences()))
                LogicError("Cannot create an unpacked Value object from packed data where one or more sequences are truncated");

            bool maskNeeded = !allSequencesSameLength || !allSequencesStartInThisMB;

            NDMaskPtr mask;
            if (maskNeeded)
            {
                mask = MakeSharedObject<NDMask>(NDShape({ maxNumTimeSteps, numSequences }), DeviceDescriptor::CPUDevice());
                for (size_t i = 0; i < numSequences; ++i)
                    if (sequenceBeginFlags[i])
                        mask->MarkSequenceBegin({0, i});

                for (auto shortSequenceIdx : sequencesShorterThanLongestSequence)
                    mask->InvalidateSection({ sequenceLengths[shortSequenceIdx], shortSequenceIdx }, { NDShape::InferredDimension, 1 });
            }

            return mask;
        };

        // No data shuffling needed if no layout or the layout has just one time-step or just one sequence
        std::vector<size_t> sequencesShorterThanLongestSequence;
        if ((maxNumTimeSteps == 1) || (numSequences == 1))
        {
            // Just create a view over the existing matrix itself
            auto tensorView = new TensorView<ElementType>(std::make_shared<Matrix<ElementType>>(matrix.AsReference()), AsTensorViewShape(valueDataShape));
            auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView);
            if (layout == nullptr)
                return MakeSharedObject<Value>(data);
            else
            {
                auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence);
                return MakeSharedObject<Value>(data, mask);
            }
        }

        if (layout->GetNumCols() != matrix.GetNumCols())
            LogicError("Bad MBLayout: The number of columns in the MBLayout does not match the number of columns in the data matrix!");

        // Reshuffle to data to unpack and uninterleave the CNTK form packed data
        // Now generate the scatter indices
        auto shuffledMatrixData = std::make_shared<Matrix<ElementType>>(matrix.GetNumRows(), maxNumTimeSteps * numSequences, matrix.GetDeviceId(), matrix.GetMatrixType(), matrix.GetFormat());
        auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence);

        // Set the target location of all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch
        size_t targetColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1);
        std::vector<ElementType> scatterIndicesVector(layout->GetNumCols(), (ElementType)targetColIdxForInvalidColumns);

        size_t i = 0;
        auto& layoutSequences = layout->GetAllSequences();
        for (auto sequenceInfo : layoutSequences)
        {
            if (sequenceInfo.seqId != GAP_SEQUENCE_ID)
            {
                size_t targetParallelStreamIdx = sequenceInfo.s;
                auto currentSequenceBeginIdx = std::max<ptrdiff_t>(0, sequenceInfo.tBegin);
                auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd);
                size_t currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx);

                for (size_t j = 0; j < currentSequenceLength; ++j)
                    scatterIndicesVector[((currentSequenceBeginIdx + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j);

                i++;
            }
        }

        auto scatterIdxMatrix = std::make_shared<Matrix<ElementType>>(1, layout->GetNumCols(), scatterIndicesVector.data(), matrix.GetDeviceId());
        shuffledMatrixData->DoScatterColumnsOf(0, *scatterIdxMatrix, matrix, 1);

        auto tensorView = new TensorView<ElementType>(shuffledMatrixData, AsTensorViewShape(valueDataShape));
        auto data = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView);
        return MakeSharedObject<Value>(data, mask);
    }

    template <typename ElementType>
    /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Matrix<ElementType>& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/)
    {
        if (var.DynamicAxes().size() > 2)
            LogicError("More than 2 dynamic axis for a variable is currently unsupported");

        if (AsDataType<ElementType>() != var.GetDataType())
            LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(var.GetDataType()));

        if ((layout != nullptr) && (matrix.GetNumRows() != var.Shape().TotalSize()))
            LogicError("Unexpected matrix layout: The number of rows in the matrix does not match the sample size of the Variable");

        return GetValueObjectFromCNTKImplMatrixAndMBLayout(var.Shape(), matrix, layout, readOnly);
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode)
    {
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout;
        auto packedValue = dynamic_cast<PackedValue*>(variableValue.second.get());
        if (packedValue)
            CNTKMatrixAndMBLayout = packedValue->PackedData<ElementType>();
        else
            CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second);

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;

        auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();

        // Switch the node matrix to the right matrix type
        nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
        computationNode->GetMBLayout()->CopyFrom(layout);
    }

    void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
    {
        std::vector<ComputationNodeBasePtr> inputNodes;
        for (auto argumentValuePair : arguments)
        {
            auto argument = argumentValuePair.first;
            auto argumentComputationNode = m_variableToNodeMap[argument];
            assert(argumentComputationNode);
            inputNodes.push_back(argumentComputationNode);

            ValuePtr argumentValue = arguments.at(argument);

            MBLayoutPtr layout;
            switch (argumentValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType()));
                break;
            }
        }

        m_computationNetwork->BumpEvalTimeStamp(inputNodes);
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
    {
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout;
        auto packedValue = dynamic_cast<PackedValue*>(variableGradient.second.get());
        if (packedValue)
            CNTKMatrixAndMBLayout = packedValue->PackedData<ElementType>();
        else
            CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second);

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
        auto nodeLayout = computationNode->GetMBLayout();
        if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
            InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call");
        computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
    }

    // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
    void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto functionOutputs = this->Outputs();
        for (auto gradientVarValuePair : gradients)
        {
            auto outputComputationNode = m_variableToNodeMap[gradientVarValuePair.first];
            ValuePtr gradientValue = gradientVarValuePair.second;

            switch (gradientValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType()));
                break;
            }
        }
    }

    static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr)
    {
        size_t outputValueNumAxes = var.Shape().Rank();

        // Add the batch and dynamic axes if needed
        if (computationNodePtr->GetMBLayout() != nullptr)
            outputValueNumAxes += 2;

        std::vector<size_t> outputShapeDims(outputValueNumAxes);
        for (size_t i = 0; i < var.Shape().Rank(); ++i)
            outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i);

        if (computationNodePtr->GetMBLayout() != nullptr)
        {
            outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps();
            outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences();
        }

        return NDShape(outputShapeDims);
    }

    /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
    {
        auto valueShape = GetValueShape(var, computationNode);
        if (varValue != nullptr)
        {
            // TODO: The shape of the specified output Value object must match the actual output shape
            if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
                InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str());
        }

        ValuePtr nodeValue;
        auto layout = computationNode->GetMBLayout();
        switch (var.GetDataType())
        {
        case DataType::Float:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, matrix, layout);
            break;
        }
        case DataType::Double:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, matrix, layout);
            break;
        }
        default:
            LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType()));
            break;
        }

        if (varValue == nullptr)
            varValue = nodeValue;
        else
            varValue->CopyFrom(*nodeValue);
    }

    void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
    {
        // Now copy the Forward values of output nodes from the network to outputs' Value objects
        for (auto outputVarValuePair : outputs)
            GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap[outputVarValuePair.first], false /*getGradient*/);
    }

    void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto networkInputs = this->Inputs();
        // Now copy the gradient values of input nodes of the network to gradients' Value objects
        for (auto gradientVarValuePair : gradients)
        {
            // Only gradients corresponding to inputs of the network can be obtained
            if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
                InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function");

            // Gradients can only be obtained for parameter variables or input variables that NeedsGradient
            if (!gradientVarValuePair.first.NeedsGradient())
                InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, or an Input Variable with NeedsGradient setting of false");

            auto computationNodePtr = m_variableToNodeMap[gradientVarValuePair.first];

            if (!computationNodePtr->NeedsGradient())
                LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");

            GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
        }
    }

    const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
    {
        assert(output.IsOutput());

        auto iter = m_perOutputVarArgumentDependencies.find(output);
        if (iter != m_perOutputVarArgumentDependencies.end())
            return iter->second;

        auto wrappedComposite = CompositeFunction::Create(output.Owner());
        m_perOutputVarArgumentDependencies[output] = wrappedComposite->Arguments();

        return m_perOutputVarArgumentDependencies[output];
    }

    /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                            std::unordered_map<Variable, ValuePtr>& outputs,
                                                            const DeviceDescriptor& computeDevice,
                                                            const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
    {
        // Validate arguments and outputs
        if (outputs.empty())
            InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!");

        // Make sure that the DataType of the variables and corresponding values match
        // TODO: We need a better way to determine the ElementType for the network
        auto dataType = DataType::Unknown;
        for (auto variableValuePair : arguments)
        {
            if (dataType == DataType::Unknown)
                dataType = variableValuePair.first.GetDataType();
            else if (dataType != variableValuePair.first.GetDataType())
                LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be same");
        }

        if (dataType == DataType::Unknown)
        {
            for (auto variableValuePair : outputs)
            {
                if (dataType == DataType::Unknown)
                    dataType = variableValuePair.first.GetDataType();
            }
        }

        std::unordered_set<Variable> requestedOutputVariables;
        for (auto output : outputs)
            requestedOutputVariables.insert(output.first);

        if (dataType == DataType::Float)
            GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else if (dataType == DataType::Double)
            GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else
            InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));

        std::unordered_set<Variable> functionOutputs(this->Outputs().begin(), this->Outputs().end());
        std::vector<ComputationNodeBasePtr> outputsToEvaluate;
        std::unordered_set<Variable> requiredArguments;
        for (auto outputVarValuePair : outputs)
        {
            auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first);
            requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());

            auto outputComputationNode = m_variableToNodeMap[outputVarValuePair.first];
            outputsToEvaluate.push_back(outputComputationNode);
        }

        // TODO: Avoid copying the data when possible

        // We should have argument values supplied for all required argument dependencies for the requested outputs
        for (auto requiredArgument : requiredArguments)
        {
            if (arguments.find(requiredArgument) == arguments.end())
                InvalidArgument("Function::Forward: Required argument's (%S) value that the requested output(s) depend on has not been provided", requiredArgument.Name().c_str());
        }

        // Feed data into the arguments of the network
        PopulateNetworkInputs(arguments);

        // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input
        // This mask is regerated every minibatch and hence dropout nodes with a non-zero dropout rate must me marked outdated
        // w.r.t. inputs to force evaluation in each minibatch
        list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
        for (auto& nodeIter : dropoutNodes)
            nodeIter->SetEvalTimeStampOutdatedWrtAll();
        
        // Bump the timestamp of the parameter nodes whose values have changed
        for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps)
        {
            auto parameter = paramTimeStampRecord.first;
            auto prevTimeStamp = paramTimeStampRecord.second;
            auto newTimeStamp = parameter.CurrentValueTimeStamp();
            if (newTimeStamp > prevTimeStamp)
            {
                paramTimeStampRecord.second = newTimeStamp;
                m_variableToNodeMap[parameter]->BumpEvalTimeStamp();
            }
        }

        // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
        for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
        {
            if (outputs.find(rootVarForBackprop) == outputs.end())
                outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]);
        }

        // TODO: Verify that values were supplied for all inputs that requested outputs depend on

        ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

        m_computationNetwork->ForwardProp(outputsToEvaluate);

        GetNetworkOutputs(outputs);

        // TODO: How to deal with the specified 'computeDevice'
        Variable evalTimeStampVariable;
        if (arguments.empty())
            evalTimeStampVariable = Inputs()[0];
        else
            evalTimeStampVariable = arguments.begin()->first;

        return (outputsToRetainBackwardStateFor.size() > 0) ? MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, std::make_pair(evalTimeStampVariable, m_variableToNodeMap[evalTimeStampVariable]->GetEvalTimeStamp())) : nullptr;
    }

    /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
                                                 const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                                 std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
    {
        auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
        if (backpropState == nullptr)
            InvalidArgument("Invalid backprop state specified");

        // TODO: Support multiple concurrent backprop states
        if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp())
            LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
                       "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");

        if (rootGradientValues.size() > 1)
            LogicError("Currently gradient backprop from only one of the Function Outputs is supported");

        // TODO: Avoid copying the data when possible

        // Zero all gradients of nodes below the root nodes
        for (auto rootGradientVarValuePair : rootGradientValues)
            m_computationNetwork->ZeroInputGradients(m_variableToNodeMap[rootGradientVarValuePair.first]);

        // Feed data into the arguments of the network
        PopulateNetworkGradients(rootGradientValues);

        // Backpropagate through the network
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);

        auto rootComputationNodePtr = m_variableToNodeMap[rootGradientValues.begin()->first];
        m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);

        GetNetworkGradients(backPropagatedGradientValuesForInputs);

        // TODO: How to deal with the specified 'computeDevice'
    }
}
back to top