// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #include "stdafx.h" #include "CNTKLibrary.h" #include "CompositeFunction.h" #include "ComputationNetworkBuilder.h" #include "Utils.h" #include "ComputationNode.h" #include "ReshapingNodes.h" #include "EvaluationNodes.h" #include "TrainingNodes.h" #include "LinearAlgebraNodes.h" #include "InputAndParamNodes.h" #include "NonlinearityNodes.h" #include "RecurrentNodes.h" #include "Serialization.h" #include "Value.h" #include "RNNNodes.h" using namespace Microsoft::MSR::CNTK; namespace CNTK { /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName"; /*static*/ std::atomic CompositeFunction::s_nextAutoGeneratedDynamicAxis(0); static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction"; /*virtual*/ Dictionary CompositeFunction::Serialize() const { Dictionary dict; dict[versionKey] = CurrentVersion(); dict[typeKey] = s_compositeFunctionTypeValue; dict[rootKey] = RootFunction()->Uid(); dict[nameKey] = Name(); dict[uidKey] = Uid(); // Find cycles in the graph and "break" them by inserting placeholders. // This needs to be done on Save, since here we have easy access to the shape and // dynamic axis info. std::unordered_set visitedFunctions; std::vector topoSortedPrimitiveFunctions; std::vector inputs; std::unordered_set inputUids; Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) { std::vector functionInputs = function->Inputs(); for (const auto& input : functionInputs) { auto& uid = input.Uid(); if (inputUids.find(uid) != inputUids.end()) { continue; } // check if this input corresponds to a cyclic edge in the graph. bool mustBeReplaced = input.IsOutput() && visitedFunctions.find(input.Owner()) != visitedFunctions.end(); if (mustBeReplaced) { auto varKind = VariableKind::Placeholder; Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid); inputs.push_back(var); inputUids.insert(uid); } else if (!input.IsOutput()) { // leave the input as is. 
                inputs.push_back(input);
                inputUids.insert(uid);
            }
        }

        visitedFunctions.insert(function);
        topoSortedPrimitiveFunctions.push_back(function);
    });

    std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));

    assert(topoSortedPrimitiveFunctions.size() == m_allPrimitiveFunctions.size());
    assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());

    std::vector<DictionaryValue> inputDictionaries;
    inputDictionaries.reserve(inputs.size());
    inputUids.clear();
    for (const auto& input : inputs)
    {
        if (inputUids.find(input.Uid()) != inputUids.end())
        {
            LogicError("Input uids must be unique");
        }
        inputUids.insert(input.Uid());
        inputDictionaries.push_back(input.Serialize());
    }

    dict[inputsKey] = std::move(inputDictionaries);

    std::vector<DictionaryValue> functionDictionaries;
    std::unordered_set<std::wstring> outputUids;
    for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
    {
        for (const auto& output : primitiveFunction->Outputs())
        {
            if (outputUids.find(output.Uid()) != outputUids.end())
            {
                LogicError("Output uids of all primitive functions in a function graph must be unique");
            }
            outputUids.insert(output.Uid());
        }
        functionDictionaries.push_back(primitiveFunction->Serialize());
    }

    dict[functionsKey] = std::move(functionDictionaries);

    // Now, collect and store the internal state for all non-pure (stateful) functions in the graph
    // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
    Dictionary stateDictionary;
    for (const auto& kv : m_variableToNodeMap)
    {
        if (kv.second->Is<RngUser>())
        {
            auto rng = kv.second->As<RngUser>();
            Dictionary state;
            state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
            state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
            stateDictionary[kv.first.Owner()->Uid()] = state;
        }
    }
    dict[stateKey] = std::move(stateDictionary);

    return dict;
}

/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
{
    static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, nameKey, uidKey, inputsKey, functionsKey };

    size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

    const auto& rootUid = dict[rootKey].Value<std::wstring>();
    const auto& name = dict[nameKey].Value<std::wstring>();
    const auto& uid = dict[uidKey].Value<std::wstring>();
    const auto& inputs = dict[inputsKey].Value<std::vector<DictionaryValue>>();

    std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
    for (const auto& dictionaryValue : inputs)
    {
        const auto& dictionary = dictionaryValue.Value<Dictionary>();
        const auto& inputVar = Variable::Deserialize(dictionary, device);
        if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
        {
            LogicError("Input uids are not unique (several inputs share '%ls' uid) "
                       "(%s).", inputVar.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str());
        }
        uidToInputMap[inputVar.Uid()] = inputVar;
    }

    Dictionary stateDictionary;
    if (dict.Contains(stateKey))
    {
        stateDictionary = dict[stateKey].Value<Dictionary>();
    }

    const auto& functions = dict[functionsKey].Value<std::vector<DictionaryValue>>();

    FunctionPtr root;
    std::unordered_map<Variable, Variable> placeholderReplacements;
    std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
    for (const auto& dictionaryValue : functions)
    {
        root = PrimitiveFunction::Deserialize(dictionaryValue.Value<Dictionary>(), uidToInputMap, device);
        allPrimitiveFunctions.insert(root);

        auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
        // Since Combine simply forwards other functions' outputs, all of its outputs
        // should already be in the uidToInputMap.
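        // (For all other ops, the loop below restores any recorded RNG state and then registers the op's
        //  outputs, so that downstream functions and the placeholders standing in for cyclic edges can be
        //  resolved once the whole graph has been reconstructed.)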
auto opType = primitiveFunction->OpType(); if (opType == PrimitiveOpType::Combine) { continue; } if (primitiveFunction->IsStateful()) { if (stateDictionary.Contains(primitiveFunction->Uid())) { auto state = stateDictionary[primitiveFunction->Uid()].Value(); auto seed = state[rngSeedKey].Value(); auto offset = state[rngOffsetKey].Value(); primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed; primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset; } else if (Internal::GetComputationNetworkTraceLevel() > 0) { // TODO: all logging functionality should be refactored to live in a logging utility class. fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) " "when deserializing from a dictionary (version=%zu). " "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version); } } for (const auto& output : root->Outputs()) { const auto& it = uidToInputMap.find(output.Uid()); if (it != uidToInputMap.end()) { if (!it->second.IsPlaceholder()) { LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)" "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str()); } placeholderReplacements[it->second] = output; } else { uidToInputMap[output.Uid()] = output; } } } if (root->Uid() != rootUid) { LogicError("Root UID '%ls' is different from the expected value '%ls'.", root->Uid().c_str(), rootUid.c_str()); } if (placeholderReplacements.size() > 0) { return CompositeFunction::Create(root->ReplacePlaceholders(placeholderReplacements), name, uid); } return CompositeFunction::Create(root, name, uid); } void CompositeFunction::CopyState(const CompositeFunction& source) { // Create a map with all non-pure (stateful) functions in the function graph. auto collectStatefulFunctions = [](const std::unordered_set& allPrimitiveFunctions) -> std::map { std::map functionMap; for (auto funcPtr : allPrimitiveFunctions) { auto primitiveFunction = dynamic_cast(funcPtr.get()); if (primitiveFunction->IsStateful()) { functionMap[primitiveFunction->Uid()] = funcPtr; } } return functionMap; }; std::map statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions); std::map statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions); assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size()); if (statefulFunctionsFrom.size() == 0) { return; } // Copy state captured in the attributes dictionaries. for (const auto& kv : statefulFunctionsFrom) { statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes(); } UpdateInternalNetworkState(); } void CompositeFunction::UpdateInternalNetworkState() { if (!m_computationNetwork) { return; } for (const auto& function : m_allPrimitiveFunctions) { auto primitiveFunction = dynamic_cast(function.get()); if (primitiveFunction->IsStateful()) { for (const auto& output : function->Outputs()) { auto node = m_variableToNodeMap[output]; auto attributes = function->Attributes(); auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value(); auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value(); node->As()->SetRngState(seed, offset); } } } } // Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values // Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. 
has been reduced over) // and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed // when the new CNTK v2 model serialization format is ready. /*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*"; /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis"; // Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables // should have been replaced before performing any Forward compute of 'this' Function. /*virtual*/ void CompositeFunction::ReplacePlaceholdersInPlace(const std::unordered_map& placeholderReplacements, std::unordered_set& visitedFunctions, std::unordered_set& replacedPlaceholders) { RootFunction()->ReplacePlaceholdersInPlace(placeholderReplacements, visitedFunctions, replacedPlaceholders); // If any of the placeholders were replaced with Output variables, let's add the graph of function underneath each of those to 'm_allPrimitiveFunctions' set for (auto replacedPlaceholder : replacedPlaceholders) { auto replacingVariable = placeholderReplacements.at(replacedPlaceholder); if (replacingVariable.IsOutput()) { auto ownerFunc = replacingVariable.Owner(); std::unordered_set visitedFunctions2; Collect(ownerFunc, visitedFunctions2); // Add the newly visited functions to 'm_allPrimitiveFunctions' set m_allPrimitiveFunctions.insert(visitedFunctions2.begin(), visitedFunctions2.end()); } } std::unordered_map functionVisitCounts; // An arbitrary cap on changing output shape of recurrent nodes, to detect infinite inference loops const size_t maxNumValidationPassesAllowed = 25; bool recurrentNodeOutputModified = false; size_t numValidationPasses = 0; do { recurrentNodeOutputModified = false; functionVisitCounts.clear(); RootFunction()->ValidateOrUpdateOutputs(functionVisitCounts, recurrentNodeOutputModified); numValidationPasses++; } while (recurrentNodeOutputModified && (numValidationPasses < maxNumValidationPassesAllowed)); if (numValidationPasses >= maxNumValidationPassesAllowed) LogicError("A recurrent node output shape change happened in successive %d validation passes indicating a potential infinite inference loop!", (int)numValidationPasses); } // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the // top level 'variable' template /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap) { auto iter = variableToNodeMap.find(variable); if (iter != variableToNodeMap.end()) { isVariableRootMap[variable] = false; return iter->second; } // The DataType, Shape and DynamicAxes of the variable must be known by now if (variable.GetDataType() == DataType::Unknown) InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); if (variable.Shape().IsUnknown()) InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); if (variable.Shape().HasInferredDimension()) InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", 
ParanthesizedName(variable.Name()).c_str()); if (variable.DynamicAxes() == Axis::UnknownDynamicAxes()) InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs variableToNodeMap[variable] = nullptr; std::shared_ptr> computationNodePtr; if (variable.IsParameter() || variable.IsConstant()) { auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape())); network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later if (!variable.NeedsGradient()) computationNodePtr->SetLearningRateMultiplier(0.0); NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value(); std::shared_ptr> valueMatrix = variable.IsConstant() ? value->GetMatrix() : value->GetWritableMatrix(); if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId())) computationNodePtr->Value() = valueMatrix->AsReference(); else { Matrix clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat()); clonedMatrix.AssignValuesOf(*valueMatrix); computationNodePtr->Value() = std::move(clonedMatrix); } } else if (variable.IsInput()) { auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); // TODO: Input variables currently are required to have the default batch axis auto dynamicAxes = variable.DynamicAxes(); auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis()); if (foundDefaultBatchAxis == dynamicAxes.end()) LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes"); if (dynamicAxes.back() != Axis::DefaultBatchAxis()) LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes"); // TODO: Support inputs with > 1 dynamic axes if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2)) LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported"); // Construct the dynamic axis name to be used internally for the CNTK InputNodes std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName)) network->AddNodeToNetAndAttachInputs(New>(network->GetDeviceId(), internalDynamicAxisName), {}); if (IsSparseInput(variable)) computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); else computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); if (variable.NeedsGradient()) { // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default // gradients are not computed for Input nodes computationNodePtr->SetLearningRateMultiplier(0.00001f); } } else { assert(variable.IsOutput()); computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->template As>()->shared_from_this(); } variableToNodeMap[variable] = computationNodePtr; if 
(isVariableRootMap.find(variable) == isVariableRootMap.end()) isVariableRootMap[variable] = variable.IsOutput(); return computationNodePtr; } template /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable, PrimitiveFunction* primitiveFunction, const std::vector>>& inputNodes, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, std::unordered_map& variableToNodeMap) { ComputationNodeBasePtr computationNodePtr; auto internalNodeName = CNTKInternalNodeNameFromUidAndName(primitiveFunction->Uid(), primitiveFunction->Name()); auto& functionConfig = primitiveFunction->Attributes(); auto functionInputs = primitiveFunction->Inputs(); PrimitiveOpType op = primitiveFunction->OpType(); switch (op) { case PrimitiveOpType::Negate: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sigmoid: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Tanh: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Cos: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sin: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ReLU: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Exp: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Log: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sqrt: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Floor: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Abs: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Reciprocal: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Softmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Hardmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::TransposeAxes: { auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(); auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(); // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2)); break; } case PrimitiveOpType::Where: { auto dynamicAxes = variable.DynamicAxes(); auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName); break; } case PrimitiveOpType::Slice: { auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based computationNodePtr = New>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis)); break; } case PrimitiveOpType::RandomSample: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = 
functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::RandomSampleInclusionFrequency: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::Dropout: { auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName); computationNodePtr->As>()->SetDropoutRate(dropoutRate); break; } case PrimitiveOpType::Reshape: { auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(newShape)); break; } case PrimitiveOpType::ROIPooling: { auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape)); break; } case PrimitiveOpType::Pooling: { PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); break; } case PrimitiveOpType::SumAll: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Plus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LogPlus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Minus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ElementTimes: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Equal: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NotEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Less: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LessEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Greater: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GreaterEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Times: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, 
outputRank, inferInputRankToMap); break; } case PrimitiveOpType::TransposeTimes: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank); break; } case PrimitiveOpType::SampledTimes: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank); break; } case PrimitiveOpType::Convolution: { NDShape outputMapCount, kernelShape; std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape()); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples); break; } case PrimitiveOpType::CosDistance: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Logistic: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::SquaredError: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::CrossEntropyWithSoftmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NCECriterion: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ClassificationError: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::PastValue: case PrimitiveOpType::FutureValue: { Variable inputOperandVar = functionInputs[0]; Variable initialStateVar = functionInputs[1]; size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value(); if (op == PrimitiveOpType::PastValue) computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); else computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); break; } case PrimitiveOpType::ReduceElements: { auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)); break; } case PrimitiveOpType::BatchNormalization: { auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value(); auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value(); auto epsilon = 
functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value(); auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW); break; } case PrimitiveOpType::Combine: // This operation is just a no-op and is a means to combine multiple functions to create a single Function // whose outputs are a union of the outputs of the Functions being combined. computationNodePtr = variableToNodeMap[variable]; break; case PrimitiveOpType::PackedIndex: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GatherPacked: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ScatterPacked: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Clip: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Select: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Splice: { Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis)); break; } case PrimitiveOpType::OptimizedRNNStack: { auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp); break; } case PrimitiveOpType::ReconcileDynamicAxis: { computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::LogSoftmax: { //This can be implemented as x => x - ReduceLogSum(x). How to do this here? 
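        // (A rough sketch at the V2 API level -- assuming a reduction over the leading static axis --
        //  would be something like Minus(x, ReduceLogSum(x, Axis(0))); the code below instead maps the op
        //  directly onto the corresponding internal CNTK node.)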
computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::Pass: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; default: LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); break; } std::vector inputNodesBasePtrs; for (auto inputNode : inputNodes) inputNodesBasePtrs.push_back(inputNode); // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs); if (computationNodePtr->Is()) { auto computationNodeExpectedInputCount = computationNodePtr->As()->GetExpectedNumInputs(); if (computationNodeExpectedInputCount != inputNodesBasePtrs.size()) LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs", PrimitiveOpTypeName(op).c_str(), (int)inputNodesBasePtrs.size(), (int)computationNodeExpectedInputCount); } if (computationNodePtr->Is()) { if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed)) { auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value(); unsigned long long offset = 0; if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset)) { offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value(); } computationNodePtr->As()->SetRngState(seed, offset); } } network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); return computationNodePtr; } template /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap) { assert(variable.IsOutput()); Function* function = variable.Owner().get(); ComputationNodeBasePtr computationNodePtr; if (dynamic_cast(function)) { PrimitiveFunction* primitiveFunction = dynamic_cast(function); PrimitiveOpType op = primitiveFunction->OpType(); auto& functionInputs = primitiveFunction->m_inputs; DataType nonConstInputDataType = DataType::Unknown; for (auto& inputVar : functionInputs) { if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown)) { nonConstInputDataType = inputVar.GetDataType(); break; } } // Create the nodes corresponding to the inputs std::vector>> inputNodes; for (auto& inputVar : functionInputs) { // If the inputVar is a constant and not the right DataType let's coerce it to the right type if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType)) { auto originalConstantValue = Constant(inputVar).Value(); auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true); NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true); inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name()); } auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap); inputNodes.push_back((baseNodePtr != nullptr) ? 
baseNodePtr->template As>()->shared_from_this() : nullptr); } computationNodePtr = CreateComputationNode(variable, primitiveFunction, inputNodes, network, variableToNodeMap); if (op != PrimitiveOpType::Combine) { for (auto inputVar : functionInputs) isVariableRootMap[inputVar] = false; } } else LogicError("User defined Functions are currently unsupported!"); return computationNodePtr; } template ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, const std::unordered_set& outputs, bool allocateNetworkMatrices) { if (m_computationNetwork != nullptr) { // TODO: We should either invalidate and readapt the network if he backpropRoots change compared to what was specified when the network // was last constructed, to just recreate a new network. // For now just disallow changing the backpropRoots after the network is created if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots)) LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported"); // TODO: Support changing the device across different invocations of the forward method on a Function instance if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device) LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported"); } else { m_computationNetwork = std::make_shared(AsCNTKImplDeviceId(device)); ComputationNetworkBuilder builder(*m_computationNetwork); // TODO: We currently only support one backprop root if (backpropRoots.size() > 1) LogicError("More than one backprop roots is currently unsupported"); auto placeholders = Placeholders(); if (!placeholders.empty()) InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!"); // Now recursively create the network in a top-down fashion auto rootFunction = RootFunction(); auto rootFunctionOutputs = rootFunction->Outputs(); for (auto rootOutput : rootFunctionOutputs) GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap); // If any of the function outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork for (auto rootOutput : rootFunctionOutputs) { if (!m_isVariableRootMap[rootOutput]) m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[rootOutput]); } // If any of the requested outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork for (auto output : outputs) { if (!m_isVariableRootMap[output]) m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[output]); } m_currentBackpropRoots = backpropRoots; // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles. 
// Now attach those after we have created all ComputationNodes in the network for (auto varNodePair : m_variableToNodeMap) { auto& currentComputationNode = varNodePair.second; auto& currentComputationNodeInputs = currentComputationNode->GetInputs(); auto& currentVar = varNodePair.first; if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end()) { // This ComputationNode has at least one null input which now needs to be properly attached const PrimitiveFunction* primitiveFunc = dynamic_cast(currentVar.Owner().get()); // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering auto inputVars = primitiveFunc->Inputs(); ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars); inputVars.resize(currentComputationNode->GetNumInputs()); std::vector inputNodesBasePtrs; for (auto inputVar : inputVars) inputNodesBasePtrs.push_back(m_variableToNodeMap[inputVar]); currentComputationNode->AttachInputs(inputNodesBasePtrs); } } m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel()); m_computationNetwork->CompileNetwork(); // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork for (auto varNodePair : m_variableToNodeMap) { if (varNodePair.first.IsOutput()) { auto outputVar = varNodePair.first; auto computationNodePtr = m_variableToNodeMap[outputVar]; auto outputShape = outputVar.Shape(); auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout(); if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) || ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape)))) { LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str()); } } } // Record the timestamps of Parameter values assert(m_lastRecordedParameterValueTimeStamps.empty()); auto functionParameters = Parameters(); for (auto parameter : functionParameters) m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() }); } if (!m_networkMatricesAllocated && allocateNetworkMatrices) { ComputationNodeBasePtr backpropRootNode; // Now recursively traverse the network in a top-down fashion auto rootFunction = RootFunction(); auto rootFunctionOutputs = rootFunction->Outputs(); std::vector forwardRootNodes; for (auto rootOutput : rootFunctionOutputs) { auto currentRootNode = m_variableToNodeMap[rootOutput]; forwardRootNodes.push_back(currentRootNode); if (m_currentBackpropRoots.find(rootOutput) != m_currentBackpropRoots.end()) backpropRootNode = currentRootNode; } std::vector forwardOutputNodes; for (auto output : outputs) { auto currentOutputNode = m_variableToNodeMap[output]; forwardOutputNodes.push_back(currentOutputNode); if (m_currentBackpropRoots.find(output) != m_currentBackpropRoots.end()) backpropRootNode = currentOutputNode; } m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode); m_networkMatricesAllocated = allocateNetworkMatrices; m_currentOutputs = outputs; m_currentOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end()); m_currentOutputs.insert(m_currentBackpropRoots.begin(), 
m_currentBackpropRoots.end()); } else { // Make sure the outputs requested are a subset of the outputs we setup the current matrix allocation structure // in the cached computation network for (auto output : outputs) { if (m_currentOutputs.find(output) == m_currentOutputs.end()) LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported"); } } return m_computationNetwork; } template /*static*/ std::pair>, MBLayoutPtr> CompositeFunction::GetCNTKImplMatrixAndMBLayoutFromValueObject(Variable var, const ValuePtr& value) { if (var.GetDataType() != value->GetDataType()) LogicError("The Variable's DataType %s does not match the corresponding Value's DataType %s", DataTypeName(var.GetDataType()), DataTypeName(value->GetDataType())); if (AsDataType() != value->GetDataType()) LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(value->GetDataType())); // TODO: Is supplying dense data for an Input variable tagged as sparse, a fatal error? if (IsSparseInput(var) && !value->IsSparse()) InvalidArgument("Dense input data supplied for a sparse input Variable"); if (IsSparseInput(var) && (value->GetStorageFormat() != StorageFormat::SparseCSC)) InvalidArgument("Sparse Input data must be in SparseCSC format"); auto varShape = var.Shape(); auto valueShape = value->Shape(); if (valueShape.Rank() < varShape.Rank()) InvalidArgument("Value's rank should be >= the Variable's rank"); size_t maxAddionalValueAxes = std::max(2, var.DynamicAxes().size()); if (valueShape.Rank() > (varShape.Rank() + maxAddionalValueAxes)) InvalidArgument("Value rank should be larger than the Variable%S rank at most by number of dynamic axes", ParanthesizedName(var.Name()).c_str()); if (valueShape.SubShape(0, varShape.Rank()) != varShape) { InvalidArgument("The %s dimensions of the Value shape %S do not match the shape of the variable %S that it corresponds to!", Internal::IsReversingTensorShapesInErrorMessagesEnabled() ? 
"trailing" : "leading", AsStringForErrorReporting(valueShape).c_str(), AsStringForErrorReporting(varShape).c_str()); } if (var.DynamicAxes().empty()) return{ value->Data()->GetMatrix(), nullptr }; if (var.DynamicAxes().size() > 2) LogicError("More than 2 dynamic axis for a variable is currently unsupported"); auto mask = value->Mask(); if ((mask != nullptr) && ((varShape.Rank() + mask->Shape().Rank()) != valueShape.Rank())) InvalidArgument("Invalid Value object; the sum of the rank of the mask and data does not equal the Variable's rank + number of dynamic axes"); auto getNumTimeStepsAndSequencesFunc = [](const NDShape& maskShape) { size_t maxNumTimeSteps = 1; size_t numSequences = 1; if (maskShape.Rank() > 0) maxNumTimeSteps = maskShape[0]; if (maskShape.Rank() > 1) numSequences = maskShape[1]; return std::pair(maxNumTimeSteps, numSequences); }; size_t maxNumTimeSteps, numSequences; std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(valueShape.SubShape(varShape.Rank())); auto getSequenceStartsAndLengthsFunc = [&getNumTimeStepsAndSequencesFunc](const NDMaskPtr& mask, std::vector& sequenceBeginIndices, std::vector& sequenceLengths) { auto cpuMask = mask; if (mask->Device() != DeviceDescriptor::CPUDevice()) cpuMask = mask->DeepClone(DeviceDescriptor::CPUDevice()); const MaskKind* maskBuffer = cpuMask->DataBuffer(); size_t maxNumTimeSteps, numSequences; std::tie(maxNumTimeSteps, numSequences) = getNumTimeStepsAndSequencesFunc(mask->Shape()); for (size_t i = 0; i < numSequences; ++i) { MaskKind firstMaskEntry = maskBuffer[i * maxNumTimeSteps]; if (firstMaskEntry == MaskKind::SequenceBegin) sequenceBeginIndices[i] = 0; else if (firstMaskEntry == MaskKind::Valid) sequenceBeginIndices[i] = Microsoft::MSR::CNTK::SentinelValueIndicatingUnspecifedSequenceBeginIdx; else LogicError("The first entry of a mask should be Valid or SequenceBegin"); size_t currentSequenceLength = 1; bool currentSequenceEndAlreadyFound = false; for (size_t j = 1; j < maxNumTimeSteps; ++j) { if (maskBuffer[(i * maxNumTimeSteps) + j] == MaskKind::Invalid) currentSequenceEndAlreadyFound = true; else { if (currentSequenceEndAlreadyFound) InvalidArgument("Invalid Value object; only trailing steps of a sequence can be masked"); currentSequenceLength++; } } sequenceLengths[i] = currentSequenceLength; } }; if ((numSequences == 1) || (maxNumTimeSteps == 1)) { // The data need not be shuffled std::shared_ptr> matrixData = value->Data()->GetMatrix(varShape.Rank()); auto layout = std::make_shared(); if (!mask) { if (maxNumTimeSteps == 1) layout->InitAsFrameMode(numSequences); else { layout->Init(numSequences, maxNumTimeSteps); layout->AddSequence(0, 0, 0, maxNumTimeSteps); } } else { layout->Init(numSequences, maxNumTimeSteps); std::vector sequenceBeginIndices(numSequences, 0); std::vector sequenceLengths(numSequences, maxNumTimeSteps); getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths); for (size_t i = 0; i < numSequences; ++i) layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]); } return{ matrixData , layout}; } else { std::vector sequenceBeginIndices(numSequences, 0); std::vector sequenceLengths(numSequences, maxNumTimeSteps); if (mask != nullptr) getSequenceStartsAndLengthsFunc(mask, sequenceBeginIndices, sequenceLengths); bool hasTruncatedSequences = std::find_if(sequenceBeginIndices.begin(), sequenceBeginIndices.end(), [](const int& val) { return (val < 0); }) != sequenceBeginIndices.end(); auto layout = std::make_shared(); std::vector> placement; if 
(!hasTruncatedSequences) { std::vector sequences; for (size_t i = 0; i < numSequences; ++i) sequences.push_back({ i, SIZE_MAX, sequenceBeginIndices[i], sequenceLengths[i] }); std::vector rowAllocations; layout->InitAsPackedSequences(sequences, placement, rowAllocations); } else { layout->Init(numSequences, maxNumTimeSteps); // We cannot pack as some of the sequences are truncated and thus all sequences have to be // kept in their original parallel streams placement.resize(numSequences); for (size_t i = 0; i < numSequences; ++i) { layout->AddSequence(i, i, sequenceBeginIndices[i], sequenceLengths[i]); // Add the gap if there is one if (sequenceLengths[i] < maxNumTimeSteps) layout->AddSequence(GAP_SEQUENCE_ID, i, sequenceLengths[i], maxNumTimeSteps); placement[i] = std::make_pair(i, 0); } } if (maxNumTimeSteps != layout->GetNumTimeSteps()) LogicError("The number of time steps in the packed MBLayout does not match the longest sequence's length in the Value object"); if (numSequences != layout->GetNumSequences()) LogicError("The number of sequences in the packed MBLayout does not match the sequence count in the Value object"); // The data needs to be rearranged since CNTK requires sequences to be interleaved across timesteps // Now generate the gather indices auto matrixData = std::make_shared>(varShape.TotalSize(), layout->GetNumCols(), AsCNTKImplDeviceId(value->Device()), value->IsSparse() ? MatrixType::SPARSE : MatrixType::DENSE, AsCNTKImplMatrixFormat(value->GetStorageFormat())); std::vector sequencesShorterThanLongestSequence; for (size_t i = 0; i < numSequences; ++i) if (sequenceLengths[i] != maxNumTimeSteps) sequencesShorterThanLongestSequence.push_back(i); // Set the source location for all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch size_t sourceColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 
0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1); std::vector gatherIndicesVector(layout->GetNumCols(), (ElementType)sourceColIdxForInvalidColumns); for (size_t i = 0; i < numSequences; ++i) { size_t targetParallelStreamIdx = placement[i].first; size_t targetStartIdxInParallelStream = placement[i].second; for (size_t j = 0; j < sequenceLengths[i]; ++j) gatherIndicesVector[((targetStartIdxInParallelStream + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j); } auto gatherIdxMatrix = std::make_shared>(1, layout->GetNumCols(), gatherIndicesVector.data(), AsCNTKImplDeviceId(value->Device())); matrixData->DoGatherColumnsOf(0, *gatherIdxMatrix, *(value->Data()->GetMatrix(varShape.Rank())), 1); return{ matrixData, layout }; } } template /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const Matrix& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/) { NDShape valueDataShape = sampleShape; size_t maxNumTimeSteps = 1; size_t numSequences = 1; if (layout != nullptr) { maxNumTimeSteps = layout->GetNumTimeSteps(); numSequences = layout->GetNumSequences(); valueDataShape = valueDataShape.AppendShape({ maxNumTimeSteps, numSequences }); } auto createMaskFunc = [](const MBLayoutPtr& layout, const DeviceDescriptor& device, std::vector& sequencesShorterThanLongestSequence) { std::vector sequenceBeginFlags; std::vector sequenceLengths; sequencesShorterThanLongestSequence.clear(); size_t maxNumTimeSteps = layout->GetNumTimeSteps(); size_t numSequences = layout->GetNumSequences(); auto& layoutSequences = layout->GetAllSequences(); size_t sequenceIdx = 0; bool allSequencesStartInThisMB = true; bool allSequencesSameLength = true; for (auto sequenceInfo : layoutSequences) { if (sequenceInfo.seqId != GAP_SEQUENCE_ID) { auto currentSequenceBeginIdx = std::max(0, sequenceInfo.tBegin); auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd); auto currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx); auto isCurrentSequenceBeginningInsideThisMB = sequenceInfo.tBegin >= 0; allSequencesStartInThisMB = allSequencesStartInThisMB && isCurrentSequenceBeginningInsideThisMB; allSequencesSameLength = allSequencesSameLength && (currentSequenceLength == maxNumTimeSteps); sequenceBeginFlags.push_back(isCurrentSequenceBeginningInsideThisMB); sequenceLengths.push_back(currentSequenceLength); if (currentSequenceLength != maxNumTimeSteps) sequencesShorterThanLongestSequence.push_back(sequenceIdx); sequenceIdx++; } } if (!allSequencesStartInThisMB && (numSequences != layout->GetNumParallelSequences())) LogicError("Cannot create an unpacked Value object from packed data where one or more sequences are truncated"); bool maskNeeded = !allSequencesSameLength || !allSequencesStartInThisMB; NDMaskPtr mask; if (maskNeeded) { mask = MakeSharedObject(NDShape({ maxNumTimeSteps, numSequences }), DeviceDescriptor::CPUDevice()); for (size_t i = 0; i < numSequences; ++i) if (sequenceBeginFlags[i]) mask->MarkSequenceBegin({0, i}); for (auto shortSequenceIdx : sequencesShorterThanLongestSequence) mask->InvalidateSection({ sequenceLengths[shortSequenceIdx], shortSequenceIdx }, { NDShape::InferredDimension, 1 }); } return mask; }; // No data shuffling needed if no layout or the layout has just one time-step or just one sequence std::vector sequencesShorterThanLongestSequence; if ((maxNumTimeSteps == 1) || (numSequences == 1)) { // Just create a view over the 
existing matrix itself auto tensorView = new TensorView(std::make_shared>(matrix.AsReference()), AsTensorViewShape(valueDataShape)); auto data = MakeSharedObject(AsDataType(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(matrix.GetFormat()), valueDataShape, readOnly, tensorView); if (layout == nullptr) return MakeSharedObject(data); else { auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence); return MakeSharedObject(data, mask); } } if (layout->GetNumCols() != matrix.GetNumCols()) LogicError("Bad MBLayout: The number of columns in the MBLayout does not match the number of columns in the data matrix!"); // Reshuffle to data to unpack and uninterleave the CNTK form packed data // Now generate the scatter indices auto shuffledMatrixData = std::make_shared>(matrix.GetNumRows(), maxNumTimeSteps * numSequences, matrix.GetDeviceId(), matrix.GetMatrixType(), matrix.GetFormat()); auto mask = createMaskFunc(layout, AsDeviceDescriptor(matrix.GetDeviceId()), sequencesShorterThanLongestSequence); // Set the target location of all gaps to be the last step of the first sequence that is shorter than the longest sequence in the batch size_t targetColIdxForInvalidColumns = sequencesShorterThanLongestSequence.empty() ? 0 : (((sequencesShorterThanLongestSequence[0] + 1) * maxNumTimeSteps) - 1); std::vector scatterIndicesVector(layout->GetNumCols(), (ElementType)targetColIdxForInvalidColumns); size_t i = 0; auto& layoutSequences = layout->GetAllSequences(); for (auto sequenceInfo : layoutSequences) { if (sequenceInfo.seqId != GAP_SEQUENCE_ID) { size_t targetParallelStreamIdx = sequenceInfo.s; auto currentSequenceBeginIdx = std::max(0, sequenceInfo.tBegin); auto currentSequenceEndIdx = std::min(maxNumTimeSteps, sequenceInfo.tEnd); size_t currentSequenceLength = (currentSequenceEndIdx - currentSequenceBeginIdx); for (size_t j = 0; j < currentSequenceLength; ++j) scatterIndicesVector[((currentSequenceBeginIdx + j) * layout->GetNumParallelSequences()) + targetParallelStreamIdx] = (ElementType)((i * maxNumTimeSteps) + j); i++; } } auto scatterIdxMatrix = std::make_shared>(1, layout->GetNumCols(), scatterIndicesVector.data(), matrix.GetDeviceId()); shuffledMatrixData->DoScatterColumnsOf(0, *scatterIdxMatrix, matrix, 1); auto tensorView = new TensorView(shuffledMatrixData, AsTensorViewShape(valueDataShape)); auto data = MakeSharedObject(AsDataType(), AsDeviceDescriptor(matrix.GetDeviceId()), AsStorageFormat(shuffledMatrixData->GetFormat()), valueDataShape, readOnly, tensorView); return MakeSharedObject(data, mask); } template /*static*/ ValuePtr CompositeFunction::GetValueObjectFromCNTKImplMatrixAndMBLayout(Variable var, const Matrix& matrix, const MBLayoutPtr& layout, bool readOnly /*= true*/) { if (var.DynamicAxes().size() > 2) LogicError("More than 2 dynamic axis for a variable is currently unsupported"); if (AsDataType() != var.GetDataType()) LogicError("The specified ElementType %s does not match the DataType %s", typeid(ElementType).name(), DataTypeName(var.GetDataType())); if ((layout != nullptr) && (matrix.GetNumRows() != var.Shape().TotalSize())) LogicError("Unexpected matrix layout: The number of rows in the matrix does not match the sample size of the Variable"); return GetValueObjectFromCNTKImplMatrixAndMBLayout(var.Shape(), matrix, layout, readOnly); } template /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair& variableValue, ComputationNodeBasePtr& computationNode) { std::pair>, MBLayoutPtr> 
CNTKMatrixAndMBLayout; auto packedValue = dynamic_cast(variableValue.second.get()); if (packedValue) CNTKMatrixAndMBLayout = packedValue->PackedData(); else CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject(variableValue.first, variableValue.second); MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; auto& nodeData = computationNode->As>()->Value(); // Switch the node matrix to the right matrix type nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first); computationNode->GetMBLayout()->CopyFrom(layout); } void CompositeFunction::PopulateNetworkInputs(const std::unordered_map& arguments) { std::vector inputNodes; for (auto argumentValuePair : arguments) { auto argument = argumentValuePair.first; auto argumentComputationNode = m_variableToNodeMap[argument]; assert(argumentComputationNode); inputNodes.push_back(argumentComputationNode); ValuePtr argumentValue = arguments.at(argument); MBLayoutPtr layout; switch (argumentValue->GetDataType()) { case DataType::Float: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode); break; case DataType::Double: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode); break; default: LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType())); break; } } m_computationNetwork->BumpEvalTimeStamp(inputNodes); } template /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode) { std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout; auto packedValue = dynamic_cast(variableGradient.second.get()); if (packedValue) CNTKMatrixAndMBLayout = packedValue->PackedData(); else CNTKMatrixAndMBLayout = GetCNTKImplMatrixAndMBLayoutFromValueObject(variableGradient.first, variableGradient.second); MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; auto nodeLayout = computationNode->GetMBLayout(); if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout))) InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call"); computationNode->As>()->AssignGradient(*CNTKMatrixAndMBLayout.first); } // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph void CompositeFunction::PopulateNetworkGradients(const std::unordered_map& gradients) { auto functionOutputs = this->Outputs(); for (auto gradientVarValuePair : gradients) { auto outputComputationNode = m_variableToNodeMap[gradientVarValuePair.first]; ValuePtr gradientValue = gradientVarValuePair.second; switch (gradientValue->GetDataType()) { case DataType::Float: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; case DataType::Double: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; default: LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType())); break; } } } static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr) { size_t outputValueNumAxes = var.Shape().Rank(); // Add the batch and dynamic axes if needed if (computationNodePtr->GetMBLayout() != nullptr) outputValueNumAxes += 2; std::vector outputShapeDims(outputValueNumAxes); for (size_t i = 0; i < var.Shape().Rank(); ++i) outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i); if (computationNodePtr->GetMBLayout() != nullptr) { 
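        // The two trailing dimensions are the dynamic axes of the Value: the maximum number of time steps
        // and the number of sequences in the minibatch, as reported by the node's MBLayout.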
    static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr)
    {
        size_t outputValueNumAxes = var.Shape().Rank();

        // Add the batch and dynamic axes if needed
        if (computationNodePtr->GetMBLayout() != nullptr)
            outputValueNumAxes += 2;

        std::vector<size_t> outputShapeDims(outputValueNumAxes);
        for (size_t i = 0; i < var.Shape().Rank(); ++i)
            outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i);

        if (computationNodePtr->GetMBLayout() != nullptr)
        {
            outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps();
            outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences();
        }

        return NDShape(outputShapeDims);
    }

    /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
    {
        auto valueShape = GetValueShape(var, computationNode);
        if (varValue != nullptr)
        {
            // TODO: The shape of the specified output Value object must match the actual output shape
            if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
                InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S",
                                AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str());
        }

        ValuePtr nodeValue;
        auto layout = computationNode->GetMBLayout();
        switch (var.GetDataType())
        {
        case DataType::Float:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() :
                                         computationNode->As<ComputationNode<float>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout);
            break;
        }
        case DataType::Double:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() :
                                         computationNode->As<ComputationNode<double>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout);
            break;
        }
        default:
            LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType()));
            break;
        }

        if (varValue == nullptr)
            varValue = nodeValue;
        else
            varValue->CopyFrom(*nodeValue);
    }

    void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
    {
        // Now copy the Forward values of output nodes from the network to outputs' Value objects
        for (auto outputVarValuePair : outputs)
            GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap[outputVarValuePair.first], false /*getGradient*/);
    }

    void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto networkInputs = this->Inputs();
        // Now copy the gradient values of input nodes of the network to gradients' Value objects
        for (auto gradientVarValuePair : gradients)
        {
            // Only gradients corresponding to inputs of the network can be obtained
            if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
                InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function");

            // Gradients can only be obtained for parameter variables or input variables that NeedsGradient
            if (!gradientVarValuePair.first.NeedsGradient())
                InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, or an Input Variable with NeedsGradient setting of false");

            auto computationNodePtr = m_variableToNodeMap[gradientVarValuePair.first];

            if (!computationNodePtr->NeedsGradient())
                LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");

            GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
        }
    }
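    // Illustrative note (added for exposition): GetValueShape appends the two dynamic axes to the sample shape, so an
    // output Variable with sample shape [100] evaluated on a minibatch with 5 time steps and 3 sequences surfaces as a
    // Value of shape [100 x 5 x 3]; when the node has no MBLayout, the Value shape is just the sample shape.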
    const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
    {
        assert(output.IsOutput());

        auto iter = m_perOutputVarArgumentDependencies.find(output);
        if (iter != m_perOutputVarArgumentDependencies.end())
            return iter->second;

        auto wrappedComposite = CompositeFunction::Create(output.Owner());
        m_perOutputVarArgumentDependencies[output] = wrappedComposite->Arguments();

        return m_perOutputVarArgumentDependencies[output];
    }
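    // Illustrative usage sketch for the Forward/Backward pair below (added for exposition; 'loss', 'features',
    // 'featureValue', 'someParameter', 'rootGradientValue' and 'device' are hypothetical names, error handling omitted):
    //
    //   std::unordered_map<Variable, ValuePtr> outputs = { { loss.Output(), nullptr } };
    //   auto backpropState = compositeFunction->Forward({ { features, featureValue } }, outputs, device, { loss.Output() });
    //   std::unordered_map<Variable, ValuePtr> inputGradients = { { someParameter, nullptr } };
    //   compositeFunction->Backward(backpropState, { { loss.Output(), rootGradientValue } }, inputGradients);
    //
    // Passing nullptr Values asks the implementation to allocate the output/gradient Values internally.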
    /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                            std::unordered_map<Variable, ValuePtr>& outputs,
                                                            const DeviceDescriptor& computeDevice,
                                                            const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
    {
        // Validate arguments and outputs
        if (outputs.empty())
            InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!");

        // Make sure that the DataType of the variables and corresponding values match
        // TODO: We need a better way to determine the ElementType for the network
        auto dataType = DataType::Unknown;
        for (auto variableValuePair : arguments)
        {
            if (dataType == DataType::Unknown)
                dataType = variableValuePair.first.GetDataType();
            else if (dataType != variableValuePair.first.GetDataType())
                LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be the same");
        }

        if (dataType == DataType::Unknown)
        {
            for (auto variableValuePair : outputs)
            {
                if (dataType == DataType::Unknown)
                    dataType = variableValuePair.first.GetDataType();
            }
        }

        std::unordered_set<Variable> requestedOutputVariables;
        for (auto output : outputs)
            requestedOutputVariables.insert(output.first);

        if (dataType == DataType::Float)
            GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else if (dataType == DataType::Double)
            GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else
            InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));

        std::unordered_set<Variable> functionOutputs(this->Outputs().begin(), this->Outputs().end());
        std::vector<ComputationNodeBasePtr> outputsToEvaluate;
        std::unordered_set<Variable> requiredArguments;
        for (auto outputVarValuePair : outputs)
        {
            auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first);
            requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());

            auto outputComputationNode = m_variableToNodeMap[outputVarValuePair.first];
            outputsToEvaluate.push_back(outputComputationNode);
        }

        // TODO: Avoid copying the data when possible

        // We should have argument values supplied for all required argument dependencies for the requested outputs
        for (auto requiredArgument : requiredArguments)
        {
            if (arguments.find(requiredArgument) == arguments.end())
                InvalidArgument("Function::Forward: Required argument's (%S) value that the requested output(s) depend on has not been provided", requiredArgument.Name().c_str());
        }

        // Feed data into the arguments of the network
        PopulateNetworkInputs(arguments);

        // Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input.
        // This mask is regenerated every minibatch, hence dropout nodes with a non-zero dropout rate must be marked
        // outdated w.r.t. their inputs to force evaluation in each minibatch.
        list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
        for (auto& nodeIter : dropoutNodes)
            nodeIter->SetEvalTimeStampOutdatedWrtAll();

        // Bump the timestamp of the parameter nodes whose values have changed
        for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps)
        {
            auto parameter = paramTimeStampRecord.first;
            auto prevTimeStamp = paramTimeStampRecord.second;
            auto newTimeStamp = parameter.CurrentValueTimeStamp();
            if (newTimeStamp > prevTimeStamp)
            {
                paramTimeStampRecord.second = newTimeStamp;
                m_variableToNodeMap[parameter]->BumpEvalTimeStamp();
            }
        }

        // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
        for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
        {
            if (outputs.find(rootVarForBackprop) == outputs.end())
                outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]);
        }

        // TODO: Verify that values were supplied for all inputs that requested outputs depend on

        ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

        m_computationNetwork->ForwardProp(outputsToEvaluate);

        GetNetworkOutputs(outputs);

        // TODO: How to deal with the specified 'computeDevice'
        Variable evalTimeStampVariable;
        if (arguments.empty())
            evalTimeStampVariable = Inputs()[0];
        else
            evalTimeStampVariable = arguments.begin()->first;

        return (outputsToRetainBackwardStateFor.size() > 0) ?
            MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, std::make_pair(evalTimeStampVariable, m_variableToNodeMap[evalTimeStampVariable]->GetEvalTimeStamp())) :
            nullptr;
    }

    /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
                                                 const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                                 std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
    {
        auto backpropState = dynamic_cast<CNTKBackPropState*>(state.get());
        if (backpropState == nullptr)
            InvalidArgument("Invalid backprop state specified");

        // TODO: Support multiple concurrent backprop states
        if (backpropState->EvalTimeStamp().second != m_variableToNodeMap[backpropState->EvalTimeStamp().first]->GetEvalTimeStamp())
            LogicError("The specified backprop state cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the Function. "
                       "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");

        if (rootGradientValues.size() > 1)
            LogicError("Currently gradient backprop from only one of the Function Outputs is supported");

        // TODO: Avoid copying the data when possible

        // Zero all gradients of nodes below the root nodes
        for (auto rootGradientVarValuePair : rootGradientValues)
            m_computationNetwork->ZeroInputGradients(m_variableToNodeMap[rootGradientVarValuePair.first]);

        // Feed data into the arguments of the network
        PopulateNetworkGradients(rootGradientValues);

        // Backpropagate through the network
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);

        auto rootComputationNodePtr = m_variableToNodeMap[rootGradientValues.begin()->first];
        m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);

        GetNetworkGradients(backPropagatedGradientValuesForInputs);

        // TODO: How to deal with the specified 'computeDevice'
    }
}