// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #include "stdafx.h" #include "CNTKLibrary.h" #include "CompositeFunction.h" #include "ComputationNetworkBuilder.h" #include "Utils.h" #include "ComputationNode.h" #include "ReshapingNodes.h" #include "EvaluationNodes.h" #include "TrainingNodes.h" #include "LinearAlgebraNodes.h" #include "InputAndParamNodes.h" #include "NonlinearityNodes.h" #include "RecurrentNodes.h" #include "Serialization.h" #include "Value.h" #include "RNNNodes.h" #include "UserDefinedV2FunctionNode.h" #include "BlockFunction.h" #include "SpecialPurposeNodes.h" using namespace Microsoft::MSR::CNTK; namespace CNTK { /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName"; /*static*/ std::atomic CompositeFunction::s_nextAutoGeneratedDynamicAxis(0); static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction"; Dictionary CompositeFunction::SerializeBlockComposite() const { Dictionary dict; dict[versionKey] = CurrentVersion(); dict[typeKey] = s_compositeFunctionTypeValue; dict[rootKey] = RootFunction()->Uid(); if (!Name().empty()) dict[nameKey] = Name(); dict[uidKey] = Uid(); return dict; } /*virtual*/ Dictionary CompositeFunction::Serialize() const { Dictionary dict = SerializeBlockComposite(); // Find cycles in the graph and "break" them by inserting placeholders. // This needs to be done on Save, since here we have easy access to the shape and // dynamic axis info. std::unordered_set visitedFunctions; std::vector topoSortedPrimitiveFunctions; std::vector uniqueInputs; std::unordered_set inputUids; std::function SerializationTraversalFunc; SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) { std::vector functionInputs = function->Inputs(); for (const auto& input : functionInputs) { auto& uid = input.Uid(); if (inputUids.find(uid) != inputUids.end()) continue; // check if this input corresponds to a cyclic edge in the graph. // BUG: A function being visited twice does not indicate it being a cyclic edge in the graph. // It just means there are at least 2 successors in the graph that have the function as input bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end()); if (mustBeReplaced) { auto varKind = VariableKind::Placeholder; Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid); uniqueInputs.push_back(var); inputUids.insert(uid); } else if (!input.IsOutput()) { // leave the input as is. 
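                // (Non-output variables here are Inputs, Parameters, Constants and Placeholders; they carry
                // their own serializable state and are recorded directly via Variable::Serialize below.)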
                uniqueInputs.push_back(input);
                inputUids.insert(uid);
            }
        }

        visitedFunctions.insert(function);
        topoSortedPrimitiveFunctions.push_back(function);

        // For block functions we need to recursively traverse the underlying composite
        if (function->IsBlock())
            PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
    };

    PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);

    std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));
    assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());

    std::vector<DictionaryValue> inputDictionaries;
    inputDictionaries.reserve(uniqueInputs.size());
    inputUids.clear();
    for (const auto& input : uniqueInputs)
    {
        if (inputUids.find(input.Uid()) != inputUids.end())
            LogicError("Input uids must be unique");

        inputUids.insert(input.Uid());
        inputDictionaries.push_back(input.Serialize());
    }

    dict[inputsKey] = std::move(inputDictionaries);

    std::vector<DictionaryValue> functionDictionaries;
    std::unordered_set<std::wstring> outputUids;
    for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
    {
        for (const auto& output : primitiveFunction->RawOutputs())
        {
            if (outputUids.find(output.Uid()) != outputUids.end())
                LogicError("Output uids of all primitive functions in a function graph must be unique");

            // Record the output's own uid (not the owning function's) so the uniqueness check above is effective.
            outputUids.insert(output.Uid());
        }

        functionDictionaries.push_back(primitiveFunction->Serialize());
    }

    dict[functionsKey] = std::move(functionDictionaries);

    // Now, collect and store the internal state for all non-pure (stateful) functions in the graph
    // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
    Dictionary stateDictionary;
    for (const auto& kv : m_variableToNodeMap)
    {
        if (kv.second->Is<RngUser>() && kv.first.IsOutput())
        {
            // The RNG state should be associated with the actual function that the computation node
            // corresponds to, and not the block primitives that wrap the actual function
            auto ownerFunction = kv.first.Owner().get();
            if (!ownerFunction->IsBlock())
            {
                auto rng = kv.second->As<RngUser>();
                Dictionary state;
                state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
                state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
                stateDictionary[ownerFunction->Uid()] = state;
            }
        }
    }

    dict[stateKey] = std::move(stateDictionary);

    return dict;
}

/*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
                                                                    const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
                                                                    const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
                                                                    const CNTK::DeviceDescriptor& device)
{
    static const std::vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
    ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

    const auto& rootUid = dict[rootKey].Value<std::wstring>();

    std::wstring name = L"";
    if (dict.Contains(nameKey))
        name = dict[nameKey].Value<std::wstring>();

    const auto& uid = dict[uidKey].Value<std::wstring>();

    FunctionPtr root = *std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(),
                                     [&rootUid](const FunctionPtr& func) { return func->Uid() == rootUid; });

    // Find the subset of placeholder replacements that apply for this composite
    FunctionPtr composite = CompositeFunction::Create(root, name, uid);
    std::unordered_map<Variable, Variable> placeholderReplacements;
    do
    {
        placeholderReplacements.clear();
        auto compositePlaceholders = composite->Placeholders();
        for (auto placeholder : compositePlaceholders)
        {
            if (allPlaceholderReplacements.find(placeholder) != allPlaceholderReplacements.end())
                placeholderReplacements.insert({ placeholder, allPlaceholderReplacements.at(placeholder) });
        }

        if
(placeholderReplacements.size() > 0) composite = composite->ReplacePlaceholders(placeholderReplacements); } while (placeholderReplacements.size() > 0); return composite; } /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device) { static const vector s_requiredDictionaryKeys = { inputsKey, functionsKey }; size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion); const auto& inputs = dict[inputsKey].Value>(); std::unordered_map uidToInputMap(inputs.size()); for (const auto& dictionaryValue : inputs) { const auto& dictionary = dictionaryValue.Value(); const auto& inputVar = Variable::Deserialize(dictionary, device); if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end()) { LogicError("Input uids are not unique (several inputs share '%ls' uid) " "(%s).", inputVar.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str()); } uidToInputMap[inputVar.Uid()] = inputVar; } Dictionary stateDictionary; if (dict.Contains(stateKey)) stateDictionary = dict[stateKey].Value(); const auto& functions = dict[functionsKey].Value>(); std::unordered_map allPlaceholderReplacements; std::unordered_set allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created. for (const auto& dictionaryValue : functions) { FunctionPtr root = PrimitiveFunction::Deserialize(dictionaryValue.Value(), uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device); allPrimitiveFunctions.insert(root); auto primitiveFunction = dynamic_cast(root.get()); // Since Combine simply forwards other functions' outputs, all of its outputs // should already be in the uidToInputMap. auto opType = primitiveFunction->OpType(); if (opType == PrimitiveOpType::Combine) continue; if (primitiveFunction->IsStateful()) { if (stateDictionary.Contains(primitiveFunction->Uid())) { auto state = stateDictionary[primitiveFunction->Uid()].Value(); auto seed = state[rngSeedKey].Value(); auto offset = state[rngOffsetKey].Value(); primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed; primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset; } else if (Internal::GetComputationNetworkTraceLevel() > 0) { // TODO: all logging functionality should be refactored to live in a logging utility class. fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) " "when deserializing from a dictionary (version=%zu). " "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version); } } for (const auto& output : root->RawOutputs()) { const auto& it = uidToInputMap.find(output.Uid()); if (it != uidToInputMap.end()) { if (!it->second.IsPlaceholder()) { LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)" "(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str()); } allPlaceholderReplacements[it->second] = output; } else { uidToInputMap[output.Uid()] = output; } } } return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device); } void CompositeFunction::CopyState(const CompositeFunction& source) { // Create a map with all non-pure (stateful) functions in the function graph. 
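    // "Stateful" here means functions whose computation depends on RNG state (e.g. Dropout, RandomSample);
    // keying the map by Uid lets the state of the source graph be matched with the corresponding functions
    // of this graph (e.g. after cloning), so their attribute dictionaries can be copied over.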
    auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr>
    {
        std::map<std::wstring, FunctionPtr> functionMap;
        for (auto funcPtr : allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(funcPtr.get());
            if (primitiveFunction->IsStateful())
            {
                functionMap[primitiveFunction->Uid()] = funcPtr;
            }
        }
        return functionMap;
    };

    std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
    std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);

    assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
    if (statefulFunctionsFrom.size() == 0)
    {
        return;
    }

    // Copy state captured in the attributes dictionaries.
    for (const auto& kv : statefulFunctionsFrom)
    {
        statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
    }

    UpdateInternalNetworkState();
}

void CompositeFunction::UpdateInternalNetworkState()
{
    if (!m_computationNetwork)
    {
        return;
    }

    for (const auto& function : m_allPrimitiveFunctions)
    {
        auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
        if (primitiveFunction->IsStateful())
        {
            for (const auto& output : function->RawOutputs())
            {
                auto node = m_variableToNodeMap.at(output);
                auto attributes = function->Attributes();
                auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                node->As<RngUser>()->SetRngState(seed, offset);
            }
        }
    }
}

// Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values
// Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over)
// and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed
// when the new CNTK v2 model serialization format is ready.
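// Roughly (an informal sketch of the intent, not the exact mapping logic): an Input whose only dynamic axis
// is the batch axis uses InternalNoSequenceAxisName, while the default (sequence, batch) pair of dynamic axes
// uses InternalDefaultDynamicAxisName; InternalDynamicAxisNameFromDynamicAxes() derives these names when
// InputNodes are created in GetNode() below.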
/*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*"; /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis"; // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the // top level 'variable' template /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap, const std::unordered_set& inputsToExcludeGradientsFor) { auto iter = variableToNodeMap.find(variable); if (iter != variableToNodeMap.end()) { isVariableRootMap[variable] = false; return iter->second; } // The DataType, Shape and DynamicAxes of the variable must be known by now if (variable.GetDataType() == DataType::Unknown) InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); if (variable.Shape().IsUnknown()) InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); if (variable.Shape().HasInferredDimension()) InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); if (variable.DynamicAxes() == Axis::UnknownDynamicAxes()) InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs variableToNodeMap[variable] = nullptr; std::shared_ptr> computationNodePtr; if (variable.IsParameter() || variable.IsConstant()) { auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape())); network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end())) computationNodePtr->SetLearningRateMultiplier(0.0); NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value(); std::shared_ptr> valueMatrix = variable.IsConstant() ? 
value->GetMatrix() : value->GetWritableMatrix(); if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId())) computationNodePtr->Value() = valueMatrix->AsReference(); else // Constant: if initialized data lives on wrong device, make a copy to the right one (copy is OK since it's constant) { Matrix clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat()); clonedMatrix.AssignValuesOf(*valueMatrix); computationNodePtr->Value() = std::move(clonedMatrix); } } else if (variable.IsInput()) { auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); // TODO: Input variables currently are required to have the default batch axis auto dynamicAxes = variable.DynamicAxes(); auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis()); if (foundDefaultBatchAxis == dynamicAxes.end()) LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes"); if (dynamicAxes.back() != Axis::DefaultBatchAxis()) LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes"); // TODO: Support inputs with > 1 dynamic axes if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2)) LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported"); // Construct the dynamic axis name to be used internally for the CNTK InputNodes std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName)) network->AddNodeToNetAndAttachInputs(New>(network->GetDeviceId(), internalDynamicAxisName), {}); if (IsSparseInput(variable)) computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); else computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end())) { // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default // gradients are not computed for Input nodes computationNodePtr->SetLearningRateMultiplier(0.00001f); } } else { assert(variable.IsOutput()); auto outputVariableNode = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor); // Can be null in case of loops with f.output == f.input. // Such loops cannot be handled, so we leave nullptr as computational node. if (outputVariableNode) computationNodePtr = outputVariableNode->template As>()->shared_from_this(); else computationNodePtr = nullptr; } variableToNodeMap[variable] = computationNodePtr; if (isVariableRootMap.find(variable) == isVariableRootMap.end()) isVariableRootMap[variable] = variable.IsOutput(); return computationNodePtr; } /*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive) { Variable mappingVariable = variable; auto ownerFunc = variable.IsOutput() ? 
variable.Owner().get() : nullptr; auto ownerPrimitiveFunc = dynamic_cast(ownerFunc); if (ownerPrimitiveFunc && (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp)) mappingVariable = ownerPrimitiveFunc->Inputs()[0]; if (recursive && (mappingVariable != variable)) return GetMappingForNoOpOutput(mappingVariable); else return mappingVariable; } /*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive) { Variable mappingVariable = variable; auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr; auto ownerPrimitiveFunc = dynamic_cast(ownerFunc); if (ownerPrimitiveFunc) { if (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp) mappingVariable = GetMappingForNoOpOutput(variable); else { auto ownerBlockFunc = dynamic_cast(ownerFunc); if (ownerBlockFunc) mappingVariable = ownerBlockFunc->CompositeOutputsMap().at(variable); } } if (recursive && (mappingVariable != variable)) return GetMappingVariable(mappingVariable); else return mappingVariable; } template /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable, Function* function, const std::vector>>& inputNodes, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, std::unordered_map& variableToNodeMap) { PrimitiveFunction* primitiveFunction = dynamic_cast(function); if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp)) return variableToNodeMap[GetMappingVariable(variable)]; ComputationNodeBasePtr computationNodePtr; auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name()); std::vector inputNodesBasePtrs; for (auto inputNode : inputNodes) inputNodesBasePtrs.push_back(inputNode); if (primitiveFunction) { auto functionInputs = function->Inputs(); auto& functionConfig = function->Attributes(); PrimitiveOpType op = primitiveFunction->OpType(); switch (op) { case PrimitiveOpType::Negate: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sigmoid: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Tanh: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Cos: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sin: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ReLU: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Exp: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Log: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sqrt: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Floor: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Abs: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Reciprocal: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Softmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Hardmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::TransposeAxes: { auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(); auto axis2 = 
functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(); // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2)); break; } case PrimitiveOpType::Where: { auto dynamicAxes = variable.DynamicAxes(); auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName); break; } case PrimitiveOpType::Slice: { auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based computationNodePtr = New>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis)); break; } case PrimitiveOpType::RandomSample: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::RandomSampleInclusionFrequency: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::Dropout: { auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName); computationNodePtr->As>()->SetDropoutRate(dropoutRate); break; } case PrimitiveOpType::Reshape: { computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(primitiveFunction->RawOutputs()[0].Shape())); break; } case PrimitiveOpType::ROIPooling: { auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape)); break; } case PrimitiveOpType::Pooling: { PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); break; } case PrimitiveOpType::Unpooling: { auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value(); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto 
upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); //We only get here after validation so it is safe to assume unpooling is max computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); break; } case PrimitiveOpType::SumAll: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Plus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LogPlus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Minus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ElementTimes: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Equal: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NotEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Less: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LessEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Greater: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GreaterEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Times: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap); break; } case PrimitiveOpType::TransposeTimes: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank); break; } case PrimitiveOpType::Convolution: { NDShape outputMapCount, kernelShape; std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape()); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples); break; } case PrimitiveOpType::CosDistance: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Logistic: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); 
break; case PrimitiveOpType::SquaredError: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::CrossEntropyWithSoftmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ClassificationError: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::EditDistanceError: { auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value(); auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value(); auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value(); auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value(); auto tokensToIgnore = AsVector(functionConfig[PrimitiveFunction::AttributeNameTokensToIgnore].Value>()); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore); break; } case PrimitiveOpType::LambdaRank: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NDCG: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::PastValue: case PrimitiveOpType::FutureValue: { Variable inputOperandVar = functionInputs[0]; Variable initialStateVar = functionInputs[1]; size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value(); if (op == PrimitiveOpType::PastValue) computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); else computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); break; } case PrimitiveOpType::ReduceElements: { auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)); break; } case PrimitiveOpType::BatchNormalization: { auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value(); auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value(); auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value(); auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW); break; } case PrimitiveOpType::Combine: // This operation is just a no-op and is a means to combine multiple functions to create a single Function // whose outputs are a union of the outputs of the Functions being combined. 
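            // Consequently no new ComputationNode is created here; the node already created for the
            // forwarded output variable is simply reused.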
computationNodePtr = variableToNodeMap[variable]; break; case PrimitiveOpType::PackedIndex: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GatherPacked: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ScatterPacked: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Clip: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Select: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Splice: { Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis)); break; } case PrimitiveOpType::OptimizedRNNStack: { auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp); break; } case PrimitiveOpType::ReconcileDynamicAxis: { computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::LogSoftmax: { //This can be implemented as x => x - ReduceLogSum(x). How to do this here? computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::Pass: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LabelsToGraph: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::StopGradient: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; default: LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); break; } // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs); if (computationNodePtr->Is()) { auto computationNodeExpectedInputCount = computationNodePtr->As()->GetExpectedNumInputs(); if (computationNodeExpectedInputCount != inputNodesBasePtrs.size()) LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs", PrimitiveOpTypeName(op).c_str(), (int)inputNodesBasePtrs.size(), (int)computationNodeExpectedInputCount); } if (computationNodePtr->Is()) { if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed)) { auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value(); uint64_t offset = 0; if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset)) { offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value(); } computationNodePtr->As()->SetRngState(seed, offset); } } } else { computationNodePtr = New>(network->GetDeviceId(), internalNodeName, function->shared_from_this()); // For user defined functions, we only attach unique inputs in the internal computation network since, the UDF // backward implementations directly compute aggregate gradient values for unique inputs std::vector uniqueInputNodesBasePtrs; for (auto inputNodeBasePtr : 
inputNodesBasePtrs) { if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end()) uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr); } inputNodesBasePtrs = uniqueInputNodesBasePtrs; } network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); return computationNodePtr; } template /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap, const std::unordered_set& inputsToExcludeGradientsFor) { assert(variable.IsOutput()); Function* function = variable.Owner().get(); ComputationNodeBasePtr computationNodePtr; auto& functionInputs = function->m_inputs; DataType nonConstInputDataType = DataType::Unknown; for (auto& inputVar : functionInputs) { if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown)) { nonConstInputDataType = inputVar.GetDataType(); break; } } // Create the nodes corresponding to the inputs std::vector>> inputNodes; for (auto& inputVar : functionInputs) { // If the inputVar is a constant and not the right DataType let's coerce it to the right type if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType)) { auto originalConstantValue = Constant(inputVar).Value(); auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true); NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true); inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name()); } auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor); inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As>()->shared_from_this() : nullptr); } BlockFunction* blockFunction = dynamic_cast(function); if (blockFunction) { // For block function, map each argument placeholder of the underlying composite to // the computation node corresponding to the block input that the argument placeholder // of the composite is mapped to. 
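        // The block's own outputs are resolved the same way: instead of creating a node for the block itself,
        // GetNode() is invoked on the composite output that variable.BlockFunctionVariableMapping() points to.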
auto compositeArguments = blockFunction->Composite()->Arguments(); for (auto compositeArgument : compositeArguments) variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping()); return GetNode(variable.BlockFunctionVariableMapping(), network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor); } else computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap); PrimitiveFunction* primitiveFunction = dynamic_cast(function); if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine)) { for (auto inputVar : functionInputs) isVariableRootMap[inputVar] = false; } return computationNodePtr; } std::unordered_set CompositeFunction::NonOwnerPreservingCopy(const std::unordered_set& outputs) { std::unordered_set result; for (auto& o : outputs) { Variable sanitized = o.NonCompositePreservingCopy(); result.insert(sanitized); } return result; } template ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, const std::unordered_set& outputs, const std::unordered_set& inputsToExcludeGradientsFor, bool allocateNetworkMatrices) { if (m_computationNetwork != nullptr) { // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network // was last constructed, to just recreate a new network. // For now just disallow changing the backpropRoots after the network is created if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots)) LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported"); // TODO: Support changing the device across different invocations of the forward method on a Function instance if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device) LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported"); if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation)) LogicError("Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported"); } else { m_computationNetwork = std::make_shared(AsCNTKImplDeviceId(device)); auto networkInputs = this->Inputs(); for (auto inputExcluded : inputsToExcludeGradientsFor) { // Only inputs of the network can be excluded from gradient computation if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end()) InvalidArgument("Function::Forward: Only inputs of a Function can be excluded from gradient computation"); } m_inputsExcludedFromGradientComputation = NonOwnerPreservingCopy(inputsToExcludeGradientsFor); ComputationNetworkBuilder builder(*m_computationNetwork); // TODO: We currently only support one backprop root if (backpropRoots.size() > 1) LogicError("More than one backprop roots is currently unsupported"); auto placeholders = Placeholders(); if (!placeholders.empty()) InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!"); // Now recursively create the network in a top-down fashion auto rootFunction = RootFunction(); auto rootFunctionOutputs = rootFunction->RawOutputs(); for (auto rootOutput : rootFunctionOutputs) GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, 
                    m_isVariableRootMap, m_inputsExcludedFromGradientComputation);

        // We need to patch the Computation node mappings for the arguments of block functions
        // since for recurrent inputs, the mappings are not fully established the first time
        std::function<void(const FunctionPtr&)> PatchBlockArgumentsMapping;
        PatchBlockArgumentsMapping = [this, &PatchBlockArgumentsMapping](const FunctionPtr& function)
        {
            BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function.get());
            if (blockFunction)
            {
                auto compositeArguments = blockFunction->Composite()->Arguments();
                for (auto compositeArgument : compositeArguments)
                    m_variableToNodeMap[compositeArgument] = m_variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());

                PreorderTraverseFunctions(function->BlockRoot(), PatchBlockArgumentsMapping);
            }
        };
        PreorderTraverseFunctions(rootFunction, PatchBlockArgumentsMapping);

        std::function<bool(const Variable&)> IsVariableRoot;
        IsVariableRoot = [this, &IsVariableRoot](const Variable& outputVar)
        {
            auto mappingVariable = GetMappingVariable(outputVar);
            return (m_isVariableRootMap[outputVar] && ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
        };

        // If any of the function or requested outputs is not a root node, we need to explicitly
        // add it to the 'output' group of the ComputationNetwork
        std::unordered_set<Variable> networkOutputs(outputs);
        networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
        for (auto output : networkOutputs)
        {
            if (!IsVariableRoot(output))
            {
                auto computationNode = m_variableToNodeMap[output];

                if (!computationNode)
                    InvalidArgument("One of the requested outputs for the Function forward computation is not part of the graph underlying the Function");

                m_computationNetwork->AddToNodeGroup(L"output", computationNode);
            }
        }

        m_currentBackpropRoots = NonOwnerPreservingCopy(backpropRoots);

        // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
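        // (GetNode() inserts a null placeholder entry into m_variableToNodeMap to break infinite recursion
        // on recurrent graphs, which is what leaves some node inputs as nullptr until this fix-up pass.)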
// Now attach those after we have created all ComputationNodes in the network for (auto varNodePair : m_variableToNodeMap) { auto& currentComputationNode = varNodePair.second; if (!currentComputationNode) LogicError("No computation node mapping exists for Variable %S", varNodePair.first.Name().c_str()); auto& currentComputationNodeInputs = currentComputationNode->GetInputs(); auto& currentVar = varNodePair.first; if (!currentVar.IsOutput()) continue; if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end()) { // This ComputationNode has at least one null input which now needs to be properly attached const PrimitiveFunction* primitiveFunc = dynamic_cast(currentVar.Owner().get()); // Skip block primitives since they do not directly map to a computation node if (primitiveFunc->OpType() == PrimitiveOpType::Block) continue; // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering auto inputVars = primitiveFunc->Inputs(); ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars); inputVars.resize(currentComputationNode->GetNumInputs()); std::vector inputNodesBasePtrs; for (auto inputVar : inputVars) inputNodesBasePtrs.push_back(m_variableToNodeMap.at(inputVar)); currentComputationNode->AttachInputs(inputNodesBasePtrs); } } m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel()); m_computationNetwork->CompileNetwork(); // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork for (auto varNodePair : m_variableToNodeMap) { if (varNodePair.first.IsOutput()) { auto outputVar = varNodePair.first; auto computationNodePtr = m_variableToNodeMap.at(outputVar); auto outputShape = outputVar.Shape(); auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout(); if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) || ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape)))) { LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str()); } } } // Record the timestamps of Parameter values assert(m_lastRecordedParameterValueTimeStamps.empty()); auto functionParameters = Parameters(); for (auto parameter : functionParameters) m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() }); } if (!m_networkMatricesAllocated && allocateNetworkMatrices) { ComputationNodeBasePtr backpropRootNode; if (!m_currentBackpropRoots.empty()) backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin()); // Now recursively traverse the network in a top-down fashion auto rootFunction = RootFunction(); auto rootFunctionOutputs = rootFunction->RawOutputs(); std::vector forwardRootNodes; for (auto rootOutput : rootFunctionOutputs) forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput)); std::vector forwardOutputNodes; for (auto output : outputs) forwardOutputNodes.push_back(m_variableToNodeMap.at(output)); m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode); m_networkMatricesAllocated = allocateNetworkMatrices; std::unordered_set allNetworkRoots = { backpropRootNode }; 
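            // Remember every root (forward roots, requested outputs and the backprop root) in global
            // evaluation order, so that later Forward calls can verify the requested outputs fit the
            // matrix-allocation structure set up here.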
allNetworkRoots.insert(forwardRootNodes.begin(), forwardRootNodes.end()); allNetworkRoots.insert(forwardOutputNodes.begin(), forwardOutputNodes.end()); m_allNetworkRootsInGlobalEvalOrder = m_computationNetwork->SortByGlobalEvalOrder(allNetworkRoots); } else { // Make sure the outputs requested are a subset of the outputs we setup the current matrix allocation structure // in the cached computation network for (auto output : outputs) { auto computationNode = m_variableToNodeMap.at(output); if (std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), computationNode) == m_allNetworkRootsInGlobalEvalOrder.end()) LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported"); } } return m_computationNetwork; } template /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map& layoutsPopulated) { if (!computationNode->Is>()) LogicError("CompositeFunction::Forward: Illegal to populate value of computation node type other than InputValueBase!"); std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject(variableValue.first, variableValue.second); // Switch the node matrix to the right matrix type auto& nodeData = computationNode->As>()->Value(); nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first); auto layout = CNTKMatrixAndMBLayout.second; auto& nodeLayout = computationNode->GetMBLayout(); if (layoutsPopulated.find(nodeLayout) == layoutsPopulated.end()) { nodeLayout->CopyFrom(layout); layoutsPopulated.insert({ nodeLayout, variableValue.first }); } else { if (*nodeLayout != *layout) InvalidArgument("Function::Forward: Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified for 2 of the Function's argument ('%S', '%S') having same dynamic axes", variableValue.first.Name().c_str(), layoutsPopulated.at(nodeLayout).Name().c_str()); } } void CompositeFunction::PopulateNetworkInputs(const std::unordered_map& arguments) { std::unordered_map layoutsPopulated; std::vector inputNodes; for (auto argumentValuePair : arguments) { auto argument = argumentValuePair.first; auto argumentComputationNode = m_variableToNodeMap.at(argument); assert(argumentComputationNode); inputNodes.push_back(argumentComputationNode); ValuePtr argumentValue = arguments.at(argument); switch (argumentValue->GetDataType()) { case DataType::Float: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode, layoutsPopulated); break; case DataType::Double: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode, layoutsPopulated); break; default: LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType())); break; } } m_computationNetwork->BumpEvalTimeStamp(inputNodes); } template /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode) { std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject(variableGradient.first, variableGradient.second); MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; auto nodeLayout = computationNode->GetMBLayout(); if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout))) InvalidArgument("The layout of the specified gradient Value is incompatible with 
the layout of the corresponding Variable computed during Forward call"); computationNode->As>()->AssignGradient(*CNTKMatrixAndMBLayout.first); } // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph void CompositeFunction::PopulateNetworkGradients(const std::unordered_map& gradients) { auto functionOutputs = RawOutputs(); for (auto gradientVarValuePair : gradients) { auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first); ValuePtr gradientValue = gradientVarValuePair.second; switch (gradientValue->GetDataType()) { case DataType::Float: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; case DataType::Double: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; default: LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType())); break; } } } static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr) { size_t outputValueNumAxes = var.Shape().Rank(); // Add the batch and dynamic axes if needed if (computationNodePtr->GetMBLayout() != nullptr) outputValueNumAxes += 2; std::vector outputShapeDims(outputValueNumAxes); for (size_t i = 0; i < var.Shape().Rank(); ++i) outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i); if (computationNodePtr->GetMBLayout() != nullptr) { outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps(); outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences(); } return NDShape(outputShapeDims); } /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient) { auto valueShape = GetValueShape(var, computationNode); if (varValue != nullptr) { // TODO: The shape of the specified output Value object must match the actual output shape if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape))) InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str()); } ValuePtr nodeValue; auto layout = computationNode->GetMBLayout(); switch (var.GetDataType()) { case DataType::Float: { auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); if (varValue == nullptr) nodeValue = MakeSharedObject(var.Shape(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); else nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout); break; } case DataType::Double: { auto& matrix = getGradient ? 
computationNode->As>()->Gradient() : computationNode->As>()->Value(); if (varValue == nullptr) nodeValue = MakeSharedObject(var.Shape(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); else nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout(var, matrix, layout); break; } default: LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType())); break; } if (varValue == nullptr) varValue = nodeValue; else varValue->CopyFrom(*nodeValue); } void CompositeFunction::GetNetworkOutputs(std::unordered_map& outputs) { // Now copy the Forward values of output nodes from the network to outputs' Value objects for (auto outputVarValuePair : outputs) GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/); } void CompositeFunction::GetNetworkGradients(std::unordered_map& gradients) { auto networkInputs = this->Inputs(); // Now copy the gradient values of input nodes of the network to gradients' Value objects for (auto gradientVarValuePair : gradients) { // Only gradients corresponding to inputs of the network can be obtained if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end()) InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function"); // Gradients can only be obtained for parameter variables or input variables that NeedsGradient if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end())) InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, an Input Variable with NeedsGradient setting of false, or an input for which gradient computation was explicitly excluded"); auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first); if (!computationNodePtr->NeedsGradient()) LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false"); GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/); } } const std::vector& CompositeFunction::GetArgumentDependencies(const Variable& output) { if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end()) { auto sanitizedOutput = output.NonCompositePreservingCopy(); if (sanitizedOutput.IsOutput()) m_perOutputVarArgumentDependencies[sanitizedOutput] = AsComposite(sanitizedOutput.Owner())->Arguments(); else m_perOutputVarArgumentDependencies[sanitizedOutput] = { sanitizedOutput }; } return m_perOutputVarArgumentDependencies[output]; } std::unordered_map CompositeFunction::GetCurrentBackpropRootsTimeStamps() const { std::unordered_map currentBackpropRootsTimeStamps; assert(m_computationNetwork != nullptr); for (auto& backpropRoot : m_currentBackpropRoots) currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp(); return currentBackpropRootsTimeStamps; } /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map& arguments, std::unordered_map& outputs, const DeviceDescriptor& computeDevice, const std::unordered_set& outputsToRetainBackwardStateFor, const std::unordered_set& inputsToExcludeGradientsFor) { // Validate arguments and outputs if (outputs.empty()) InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!"); // 
Make sure that the DataType of the variables and corresponding values match // TODO: We need a better way to determine the ElementType for the network auto dataType = DataType::Unknown; for (auto variableValuePair : arguments) { if (dataType == DataType::Unknown) dataType = variableValuePair.first.GetDataType(); else if (dataType != variableValuePair.first.GetDataType()) LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be same"); } if (dataType == DataType::Unknown) { for (auto variableValuePair : outputs) { if (dataType == DataType::Unknown) dataType = variableValuePair.first.GetDataType(); } } std::unordered_set requestedOutputVariables; for (auto output : outputs) requestedOutputVariables.insert(output.first); if (dataType == DataType::Float) GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true); else if (dataType == DataType::Double) GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true); else InvalidArgument("Unsupported DataType %s", DataTypeName(dataType)); std::unordered_set functionOutputs(m_outputs.begin(), m_outputs.end()); std::vector outputsToEvaluate; std::unordered_set requiredArguments; for (auto outputVariable : requestedOutputVariables) { auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVariable); requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end()); auto outputComputationNode = m_variableToNodeMap.at(outputVariable); outputsToEvaluate.push_back(outputComputationNode); } // We should have argument values supplied for all required argument dependencies for the requested outputs std::vector missingRequiredArguments; std::unordered_map requiredArgumentValues; for (auto requiredArgument : requiredArguments) { auto iter = arguments.find(requiredArgument); if (iter == arguments.end()) missingRequiredArguments.push_back(requiredArgument); else requiredArgumentValues.insert(*iter); } if (!missingRequiredArguments.empty()) { std::wstring missingRequiredArgumentNames = NamedListString(missingRequiredArguments); InvalidArgument("Function::Forward: Values for %d required arguments (%S), that the requested output(s) depend on, have not been provided", (int)missingRequiredArguments.size(), missingRequiredArgumentNames.c_str()); } if (requiredArgumentValues.size() < arguments.size()) fprintf(stderr, "WARNING: Function::Forward provided values for (%d) extra arguments which are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size())); // Feed data into the arguments of the network // TODO: Avoid copying the data when possible PopulateNetworkInputs(requiredArgumentValues); // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input // This mask is regenerated every minibatch and hence dropout nodes with a non-zero dropout rate must me marked outdated // w.r.t. 
inputs to force evaluation in each minibatch list dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode)); for (auto& nodeIter : dropoutNodes) nodeIter->SetEvalTimeStampOutdatedWrtAll(); // Bump the timestamp of the parameter nodes whose values have changed for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps) { auto parameter = paramTimeStampRecord.first; auto prevTimeStamp = paramTimeStampRecord.second; auto newTimeStamp = parameter.CurrentValueTimeStamp(); if (newTimeStamp > prevTimeStamp) { paramTimeStampRecord.second = newTimeStamp; m_variableToNodeMap.at(parameter)->BumpEvalTimeStamp(); } } // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs' for (auto rootVarForBackprop : outputsToRetainBackwardStateFor) { if (outputs.find(rootVarForBackprop) == outputs.end()) outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop)); } // Reset the timestamps of all backward roots to record an update in one or more inputs for (auto& backpropRoot : m_currentBackpropRoots) m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll(); // TODO: Verify that values were supplied for all inputs that requested outputs depend on ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training); // We may have to include additional nodes in the ForwardProp to align with how the memory sharing structure is setup // We need to include all roots that lie earlier in the global eval order than the actual outputs we are interested // in evaluation. // TODO: This may incur additonal compute costs in some rare scenarios. We need to come up with a better way to handle this. outputsToEvaluate = m_computationNetwork->SortByGlobalEvalOrder(outputsToEvaluate); auto lastOutputInEvalOrder = outputsToEvaluate.back(); auto iterEndRootInEvalOrder = std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), lastOutputInEvalOrder) + 1; auto augmentedOutputsToEvaluate = std::vector(m_allNetworkRootsInGlobalEvalOrder.begin(), iterEndRootInEvalOrder); m_computationNetwork->ForwardProp(augmentedOutputsToEvaluate); GetNetworkOutputs(outputs); // TODO: How to deal with the specified 'computeDevice' Variable evalTimeStampVariable; if (requiredArgumentValues.empty()) evalTimeStampVariable = Inputs()[0]; else evalTimeStampVariable = requiredArgumentValues.begin()->first; BackPropStatePtr backpropStatePtr; if (outputsToRetainBackwardStateFor.size() > 0) backpropStatePtr = MakeSharedObject(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps()); return backpropStatePtr; } /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state, const std::unordered_map& rootGradientValues, std::unordered_map& backPropagatedGradientValuesForInputs) { auto backpropState = dynamic_cast(state.get()); if (backpropState == nullptr) InvalidArgument("Invalid backprop state specified"); // TODO: Support multiple concurrent backprop states std::unordered_map currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps(); if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps) LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function." 
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported"); if (rootGradientValues.size() > 1) LogicError("Currently gradient backprop from only one of the Function Outputs is supported"); // TODO: Avoid copying the data when possible // Zero all gradients of nodes below the root nodes for (auto rootGradientVarValuePair : rootGradientValues) m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first)); // Feed data into the arguments of the network PopulateNetworkGradients(rootGradientValues); // Backpropagate through the network ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training); auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first); m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true); GetNetworkGradients(backPropagatedGradientValuesForInputs); // TODO: How to deal with the specified 'computeDevice' } }