//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"
#include "SpecialPurposeNodes.h"
#include "SequenceReshapeNodes.h"
#include "UserDefinedFunction.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
    /*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);

    static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";

    Dictionary CompositeFunction::SerializeBlockComposite() const
    {
        Dictionary dict;
        dict[versionKey] = CurrentVersion();
        dict[typeKey] = s_compositeFunctionTypeValue;
        dict[rootKey] = RootFunction()->Uid();
        if (!Name().empty())
            dict[nameKey] = Name();
        dict[uidKey] = Uid();
        return dict;
    }

    // Copy the internal state from the network into the function graph,
    // specifically from RngUser nodes into the attributes dictionaries of
    // the corresponding stateful primitive functions.
    void CompositeFunction::UpdateInternalState() const
    {
        if (!m_computationNetwork)
            return;

        for (auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            // TODO: same for BatchNorm
            auto& outputs = primitiveFunction->RawOutputs();
            if (outputs.size() != 1)
                LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());

            const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();

            Dictionary state;
            state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
            state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
            primitiveFunction->SetState(state);
        }
    }

    // Generate a dictionary representing the internal (local) state of the function graph.
    Dictionary CompositeFunction::GetInternalState() const
    {
        UpdateInternalState();

        Dictionary stateDictionary;
        for (auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            // TODO: same for BatchNorm
            stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
        }
        return stateDictionary;
    }

    /*virtual*/ Dictionary CompositeFunction::Serialize() const
    {
        UpdateInternalState();

        Dictionary dict = SerializeBlockComposite();

        // Find cycles in the graph and "break" them by inserting placeholders.
        // This needs to be done on Save, since here we have easy access to the shape and
        // dynamic axis info.
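        //
        // Informal sketch of the scheme implemented below: when the pre-order traversal reaches an input
        // that is the output of an already-visited function (e.g. the back-edge of a recurrence built with
        // PastValue/FutureValue), that edge is serialized as a Placeholder variable carrying the output's
        // shape, data type, dynamic axes and uid. On load, Deserialize() records such placeholders in a
        // replacement map, and DeserializeBlockComposite() splices the real outputs back in via
        // ReplacePlaceholders(), restoring the original (cyclic) graph.
        //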
        std::unordered_set<FunctionPtr> visitedFunctions;
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> uniqueInputs;
        std::unordered_set<std::wstring> inputUids;

        std::function<void(const FunctionPtr&)> SerializationTraversalFunc;
        SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
            std::vector<Variable> functionInputs = function->Inputs();
            for (const auto& input : functionInputs)
            {
                auto& uid = input.Uid();
                if (inputUids.find(uid) != inputUids.end())
                    continue;

                // Check if this input corresponds to a cyclic edge in the graph.
                // BUG: A function being visited twice does not indicate a cyclic edge in the graph;
                // it just means there are at least 2 successors in the graph that have the function as input.
                bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());

                if (mustBeReplaced)
                {
                    auto varKind = VariableKind::Placeholder;
                    Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(),
                                 input.DynamicAxes(), input.Name(), uid);
                    uniqueInputs.push_back(var);
                    inputUids.insert(uid);
                }
                else if (!input.IsOutput())
                {
                    // Leave the input as is.
                    uniqueInputs.push_back(input);
                    inputUids.insert(uid);
                }
            }
            visitedFunctions.insert(function);
            topoSortedPrimitiveFunctions.push_back(function);

            // For block functions we need to recursively traverse the underlying composite.
            if (function->IsBlock())
                PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
        };

        PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);

        std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));
        assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());

        std::vector<DictionaryValue> inputDictionaries;
        inputDictionaries.reserve(uniqueInputs.size());
        inputUids.clear();
        for (const auto& input : uniqueInputs)
        {
            if (inputUids.find(input.Uid()) != inputUids.end())
                LogicError("Function '%S' Serialize: Input uids must be unique.", AsString().c_str());
            inputUids.insert(input.Uid());
            inputDictionaries.push_back(input.Serialize());
        }

        dict[inputsKey] = std::move(inputDictionaries);

        std::vector<DictionaryValue> functionDictionaries;
        std::unordered_set<std::wstring> outputUids;
        for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
        {
            for (const auto& output : primitiveFunction->RawOutputs())
            {
                if (outputUids.find(output.Uid()) != outputUids.end())
                    LogicError("Function '%S' Serialize: Output uids of all primitive functions in a function graph must be unique.", AsString().c_str());
                // Record the output's own uid (not the owning function's), so the uniqueness check above is effective.
                outputUids.insert(output.Uid());
            }

            auto functionDict = UDFUtils::IsUDF(primitiveFunction) ?
                UDFUtils::Serialize(primitiveFunction) :
                primitiveFunction->Serialize();

            functionDictionaries.push_back(functionDict);
        }

        dict[functionsKey] = std::move(functionDictionaries);

        return dict;
    }

    /*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
                                                                        const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
                                                                        const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
                                                                        const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
        ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& rootUid = dict[rootKey].Value<std::wstring>();
        std::wstring name = L"";
        if (dict.Contains(nameKey))
            name = dict[nameKey].Value<std::wstring>();
        const auto& uid = dict[uidKey].Value<std::wstring>();

        FunctionPtr root = *std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(),
            [&rootUid](const FunctionPtr& func) { return func->Uid() == rootUid; });

        // Find the subset of placeholder replacements that apply for this composite.
        FunctionPtr composite = CompositeFunction::Create(root, name, uid);
        std::unordered_map<Variable, Variable> placeholderReplacements;
        do
        {
            placeholderReplacements.clear();
            auto compositePlaceholders = composite->Placeholders();
            for (auto placeholder : compositePlaceholders)
            {
                if (allPlaceholderReplacements.find(placeholder) != allPlaceholderReplacements.end())
                    placeholderReplacements.insert({ placeholder, allPlaceholderReplacements.at(placeholder) });
            }

            if (placeholderReplacements.size() > 0)
                composite = composite->ReplacePlaceholders(placeholderReplacements);
        } while (placeholderReplacements.size() > 0);

        return composite;
    }

    /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
        size_t version = ValidateDictionary(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& inputs = dict[inputsKey].Value<std::vector<DictionaryValue>>();

        std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
        for (const auto& dictionaryValue : inputs)
        {
            const auto& dictionary = dictionaryValue.Value<Dictionary>();
            const auto& inputVar = Variable::Deserialize(dictionary, device);

            if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
            {
                CNTK::LogicError("CompositeFunction::Deserialize: Input uids are not unique (several inputs share '%ls' uid) (%s).",
                                 inputVar.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str());
            }
            uidToInputMap[inputVar.Uid()] = inputVar;
        }

        const auto& functions = dict[functionsKey].Value<std::vector<DictionaryValue>>();

        std::unordered_map<Variable, Variable> allPlaceholderReplacements;
        std::unordered_set<FunctionPtr> allPrimitiveFunctions; // This keeps all primitive functions alive until a composite function is created.
        for (const auto& dictionaryValue : functions)
        {
            auto functionDict = dictionaryValue.Value<Dictionary>();
            FunctionPtr root = UDFUtils::IsUDF(functionDict) ?
                UDFUtils::Deserialize(functionDict, uidToInputMap, device) :
                PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
            allPrimitiveFunctions.insert(root);

            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
            if (primitiveFunction != nullptr && primitiveFunction->OpType() == PrimitiveOpType::Combine)
            {
                // Since Combine simply forwards other functions' outputs, all of its outputs
                // should already be in the uidToInputMap.
                continue;
            }

            for (const auto& output : root->RawOutputs())
            {
                const auto& it = uidToInputMap.find(output.Uid());
                if (it != uidToInputMap.end())
                {
                    if (!it->second.IsPlaceholder())
                    {
                        CNTK::LogicError("CompositeFunction::Deserialize: Unexpected variable '%S' instead of a Placeholder (uid = %ls) (%s).",
                                         it->second.AsString().c_str(), it->second.Uid().c_str(), GetVersionsString(s_serializationVersion, version).c_str());
                    }
                    allPlaceholderReplacements[it->second] = output;
                }
                else
                {
                    uidToInputMap[output.Uid()] = output;
                }
            }
        }

        // Starting with serialization version 3, the state is preserved inside the attribute dictionaries of the
        // corresponding primitive functions. Earlier versions have a dedicated key-value pair in the composite function dict.
        if (version < 3)
            RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);

        return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
    }

    void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
    {
        Dictionary stateDictionary;
        if (dict.Contains(stateKey))
            stateDictionary = dict[stateKey].Value<Dictionary>();

        for (auto& function : functions)
        {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (!primitiveFunction || !primitiveFunction->IsStateful())
                continue;

            if (stateDictionary.Contains(primitiveFunction->Uid()))
            {
                auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
                // Add key-value pairs expected by the SetState method to the state dictionary.
                state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
                state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
                primitiveFunction->SetState(state);
            }
            else
            {
                if (GetTraceLevel() >= TraceLevel::Warning)
                {
                    // TODO: all logging functionality should be refactored to live in a logging utility class.
                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
                                    "when deserializing from a dictionary (version=%zu). "
                                    "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
                }

                // Create state from scratch, so that function attributes contain all the required key-value pairs.
                Dictionary state;
                state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed(true);
                state[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
                primitiveFunction->SetState(state);
            }
        }
    }

    void CompositeFunction::CopyState(const CompositeFunction& source)
    {
        // Collect a vector of stateful function uids using a pre-order traversal of the function graph.
        auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<std::wstring> {
            vector<std::wstring> uids;
            PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
                auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(funcPtr.get());
                if (primitiveFunction && primitiveFunction->IsStateful())
                {
                    uids.push_back(funcPtr->Uid());
                }
            }, true);

            return uids;
        };

        auto theirUIDs = collectStatefulFunctionUIDs(source);
        auto ourUIDs = collectStatefulFunctionUIDs(*this);

        if (theirUIDs.size() != ourUIDs.size())
            CNTK::LogicError("Cannot copy internal state: the source and the destination contain a different number of stateful functions.");

        auto state = source.GetInternalState();

        if (theirUIDs == ourUIDs)
        {
            // The uids are identical; no need to remap.
            SetInternalState(state);
            return;
        }

        // Build a map from source-function UIDs to the destination (this) function's UIDs.
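        // Illustrative example (the uids below are hypothetical): with theirUIDs = { L"Dropout4", L"RandomSample7" }
        // and ourUIDs = { L"Dropout11", L"RandomSample15" }, the state stored under L"Dropout4" in the source is
        // re-keyed to L"Dropout11" (and likewise for the second pair) before SetInternalState is applied.
        // This positional remapping assumes the source and destination graphs are structural clones, so the two
        // pre-order traversals above enumerate corresponding stateful functions in the same order.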
map uidMap; for (auto i = 0; i < theirUIDs.size(); i++) uidMap[theirUIDs[i]] = ourUIDs[i]; Dictionary remappedState; for (auto& kv : state) remappedState[uidMap[kv.first]] = kv.second; SetInternalState(remappedState); } void CompositeFunction::SetInternalState(const Dictionary& state) { if (state.Size() == 0) return; for (const auto& function : m_allPrimitiveFunctions) { auto primitiveFunction = dynamic_cast(function.get()); if (!primitiveFunction || !primitiveFunction->IsStateful()) continue; auto functionState = state[primitiveFunction->Uid()].Value(); primitiveFunction->SetState(functionState); if (!m_computationNetwork) continue; auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value(); auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value(); // copy the state directly into the network for (const auto& output : function->RawOutputs()) { auto node = m_variableToNodeMap.at(output); node->As()->SetRngState(seed, offset); } } } template /*static*/ Microsoft::MSR::CNTK::ComputationNodeBasePtr CompositeFunction::CreateLearnableParameterFromVariable(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkBuilder& builder, const NDShape& shape, const std::wstring& name) { switch (variable.GetDataType()) { case DataType::Float: return builder.template TypedCreateLearnableParameter(name, AsTensorShape(shape)); case DataType::Double: return builder.template TypedCreateLearnableParameter(name, AsTensorShape(shape)); case DataType::Float16: return builder.template TypedCreateLearnableParameter(name, AsTensorShape(shape)); default: return builder.CreateLearnableParameter(name, AsTensorShape(shape)); } } /*static*/ void CompositeFunction::CastAssignNodeValue(ComputationNodeBasePtr node, DataType dataType, std::shared_ptr matrix) { switch (dataType) { case DataType::Float: return (dynamic_cast*>(&*node))->Value().CastAssignValuesOf(*matrix); case DataType::Double: return (dynamic_cast*>(&*node))->Value().CastAssignValuesOf(*matrix); case DataType::Float16: return (dynamic_cast*>(&*node))->Value().CastAssignValuesOf(*matrix); default: LogicError("Unsupported data type"); } } // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the // top level 'variable' template /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, const std::unordered_map& fullyDefinedArgumentsMap, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap, const std::unordered_set& inputsToExcludeGradientsFor, bool useMangledNamesForComputationNodes) { auto iter = variableToNodeMap.find(variable); if (iter != variableToNodeMap.end()) { isVariableRootMap[variable] = false; return iter->second; } // The DataType, Shape and DynamicAxes of the variable must be known by now if (variable.GetDataType() == DataType::Unknown) InvalidArgument("Variable '%S' with unknown DataType found when compiling the Function graph.", variable.AsString().c_str()); if (variable.Shape().IsUnknown()) InvalidArgument("Variable '%S' with unknown shape found when compiling the Function graph.", variable.AsString().c_str()); if (variable.DynamicAxes() == Axis::UnknownDynamicAxes()) InvalidArgument("Variable '%S' with unknown dynamic axes found when compiling the Function graph.", variable.AsString().c_str()); // Lets add a 
null entry in the map for this variable, to break infinite recursion when processing recurrent graphs variableToNodeMap[variable] = nullptr; std::shared_ptr computationNodePtr; auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes); if (variable.IsParameter() || variable.IsConstant()) { if (variable.Shape().HasInferredDimension()) InvalidArgument("Parameter or Constant '%S' with unresolved shape %S found when compiling the Function graph.", variable.AsString().c_str(), variable.Shape().AsString().c_str()); computationNodePtr = CreateLearnableParameterFromVariable(variable, builder, variable.Shape(), internalNodeName); network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end())) computationNodePtr->SetLearningRateMultiplier(0.0); NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value(); std::shared_ptr valueMatrix = variable.IsConstant() ? value->GetMatrixBase() : value->GetWritableMatrixBase(); if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId())) { // shallow copy from parameter value to computation node value to link them together switch (variable.GetDataType()) { case DataType::Float: std::dynamic_pointer_cast>(computationNodePtr)->Value() = std::dynamic_pointer_cast>(valueMatrix)->AsReference(); break; case DataType::Double: std::dynamic_pointer_cast>(computationNodePtr)->Value() = std::dynamic_pointer_cast>(valueMatrix)->AsReference(); break; case DataType::Float16: std::dynamic_pointer_cast>(computationNodePtr)->Value() = std::dynamic_pointer_cast>(valueMatrix)->AsReference(); break; default: LogicError("Unsupported data type"); } } else // Constant: if initialized data lives on wrong device, make a copy to the right one (copy is OK since it's constant) { // TODO: the following two lines are a workaround for a bug in the Math library // (AssignValuesOf throws when source and destination matrices reside on different GPU devices). 
// Once this bug is fixed, change to // Matrix clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat()); switch (variable.GetDataType()) { case DataType::Float: { Matrix& nodeValue = dynamic_cast*>(&*computationNodePtr)->Value(); Matrix clonedMatrix(nodeValue.GetNumRows(), nodeValue.GetNumCols(), valueMatrix->GetDeviceId(), nodeValue.GetMatrixType(), nodeValue.GetFormat()); clonedMatrix.CastAssignValuesOf(*valueMatrix); clonedMatrix.TransferToDeviceIfNotThere(network->GetDeviceId(), true); nodeValue = std::move(clonedMatrix); break; } case DataType::Double: { Matrix& nodeValue = dynamic_cast*>(&*computationNodePtr)->Value(); Matrix clonedMatrix(nodeValue.GetNumRows(), nodeValue.GetNumCols(), valueMatrix->GetDeviceId(), nodeValue.GetMatrixType(), nodeValue.GetFormat()); clonedMatrix.CastAssignValuesOf(*valueMatrix); clonedMatrix.TransferToDeviceIfNotThere(network->GetDeviceId(), true); nodeValue = std::move(clonedMatrix); break; } case DataType::Float16: { Matrix& nodeValue = dynamic_cast*>(&*computationNodePtr)->Value(); Matrix clonedMatrix(nodeValue.GetNumRows(), nodeValue.GetNumCols(), valueMatrix->GetDeviceId(), nodeValue.GetMatrixType(), nodeValue.GetFormat()); clonedMatrix.CastAssignValuesOf(*valueMatrix); clonedMatrix.TransferToDeviceIfNotThere(network->GetDeviceId(), true); nodeValue = std::move(clonedMatrix); break; } default: LogicError("Unsupported data type"); } } } else if (variable.IsInput()) { auto fullyDefinedArgumentVar = variable; if (fullyDefinedArgumentVar.Shape().HasFreeDimension() && (fullyDefinedArgumentsMap.find(fullyDefinedArgumentVar) != fullyDefinedArgumentsMap.end())) fullyDefinedArgumentVar = fullyDefinedArgumentsMap.at(fullyDefinedArgumentVar); if (fullyDefinedArgumentVar.Shape().HasUnboundDimension()) InvalidArgument("Input Variable '%S' with unresolved shape %S found when compiling the Function graph.", fullyDefinedArgumentVar.AsString().c_str(), fullyDefinedArgumentVar.Shape().AsString().c_str()); // TODO: Input variables currently are required to have the default batch axis auto dynamicAxes = variable.DynamicAxes(); auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis()); if (IsSparseInput(variable) && (foundDefaultBatchAxis == dynamicAxes.end())) CNTK::LogicError("Sparse Input Variable '%S' found without a DefaultBatchAxis dynamic axis; this is currently unsupported.", variable.AsString().c_str()); if (!dynamicAxes.empty() && (dynamicAxes.back() != Axis::DefaultBatchAxis())) CNTK::LogicError("Input Variable '%S' does not have the DefaultBatchAxis as its last dynamic axis.", variable.AsString().c_str()); // TODO: Support inputs with > 1 dynamic axes if (dynamicAxes.size() > 2) CNTK::LogicError("Input Variable '%S' has %d dynamic axes; currently only inputs with <= 2 dynamic axes are supported.", variable.AsString().c_str(), (int)dynamicAxes.size()); if (!dynamicAxes.empty()) { // Construct the dynamic axis name to be used internally for the CNTK InputNodes std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName)) network->AddNodeToNetAndAttachInputs(New>(network->GetDeviceId(), internalDynamicAxisName), {}); if (IsSparseInput(variable)) computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName); else 
computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName); if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end())) { // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default // gradients are not computed for Input nodes computationNodePtr->SetLearningRateMultiplier(0.00001f); } } else { computationNodePtr = CreateLearnableParameterFromVariable(variable, builder, fullyDefinedArgumentVar.Shape(), internalNodeName); network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end())) computationNodePtr->SetLearningRateMultiplier(0.0); } if (variable.Shape().HasFreeDimension()) computationNodePtr->MarkNeedsDynamicValidation(); } else { assert(variable.IsOutput()); // The primary output of a function is its first output auto primaryOutput = variable.Owner()->RawOutputs()[0]; // If this variable is not the primary output, then we need to look up the primary output // so that the variableToNodeMap will contain a mapping from the primary variable to an OutputMultiplexerNode // which is how we handle multiple outputs in V1. See also the instantiation of OutputMultiplexerNode in this file. if (primaryOutput != variable) GetNode(primaryOutput, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes); auto outputVariableNode = GetOutputVariableNode(variable, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes); // Can be null in case of loops with f.output == f.input. // Such loops cannot be handled, so we leave nullptr as computational node. if (outputVariableNode) computationNodePtr = outputVariableNode->template As()->shared_from_this(); else computationNodePtr = nullptr; } variableToNodeMap[variable] = computationNodePtr; if (isVariableRootMap.find(variable) == isVariableRootMap.end()) isVariableRootMap[variable] = variable.IsOutput(); return computationNodePtr; } /*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive) { Variable mappingVariable = variable; auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr; auto ownerPrimitiveFunc = dynamic_cast(ownerFunc); if (ownerPrimitiveFunc && (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp)) mappingVariable = ownerPrimitiveFunc->Inputs()[0]; if (recursive && (mappingVariable != variable)) return GetMappingForNoOpOutput(mappingVariable); else return mappingVariable; } /*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive) { Variable mappingVariable = variable; auto ownerFunc = variable.IsOutput() ? 
variable.Owner().get() : nullptr; auto ownerPrimitiveFunc = dynamic_cast(ownerFunc); if (ownerPrimitiveFunc) { if (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp) mappingVariable = GetMappingForNoOpOutput(variable); else { auto ownerBlockFunc = dynamic_cast(ownerFunc); if (ownerBlockFunc) mappingVariable = ownerBlockFunc->CompositeOutputsMap().at(variable); } } if (recursive && (mappingVariable != variable)) return GetMappingVariable(mappingVariable); else return mappingVariable; } template /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable, Function* function, const std::vector>& inputNodes, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, std::unordered_map& variableToNodeMap, bool useMangledNamesForComputationNodes) { PrimitiveFunction* primitiveFunction = dynamic_cast(function); if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp)) return variableToNodeMap[GetMappingVariable(variable)]; ComputationNodeBasePtr computationNodePtr; auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name(), useMangledNamesForComputationNodes); std::vector inputNodesBasePtrs; for (auto inputNode : inputNodes) inputNodesBasePtrs.push_back(inputNode); // take the dataType from the first input, if not specified (i.e. placeholder) then use default // node like BatchNormalization may have inputs with different precision, // and that validation is done in specific node constructor DataType inputNodeType = AsDataType(); if (inputNodes.size() > 0) { if (std::dynamic_pointer_cast, ComputationNodeBase>(inputNodes[0])) inputNodeType = DataType::Float; else if (std::dynamic_pointer_cast, ComputationNodeBase>(inputNodes[0])) inputNodeType = DataType::Double; else if (std::dynamic_pointer_cast, ComputationNodeBase>(inputNodes[0])) inputNodeType = DataType::Float16; } #define ASSIGN_NEW_NODE(nodeClass, ...) \ do { \ if (inputNodeType == DataType::Float) \ computationNodePtr = New>(__VA_ARGS__); \ else if (inputNodeType == DataType::Double) \ computationNodePtr = New>(__VA_ARGS__); \ else if (inputNodeType == DataType::Float16) \ computationNodePtr = New>(__VA_ARGS__); \ } while(0) #define ASSIGN_NEW_NODE2(nodeClass, dtype, ...) 
\ do { \ if (inputNodeType == DataType::Float) \ computationNodePtr = New>(__VA_ARGS__); \ else if (inputNodeType == DataType::Double) \ computationNodePtr = New>(__VA_ARGS__); \ else if (inputNodeType == DataType::Float16) \ computationNodePtr = New>(__VA_ARGS__); \ } while(0) auto outputs = function->RawOutputs(); if (variable == outputs[0]) { if (primitiveFunction) { auto functionInputs = function->Inputs(); auto& functionConfig = function->Attributes(); PrimitiveOpType op = primitiveFunction->OpType(); switch (op) { case PrimitiveOpType::Negate: ASSIGN_NEW_NODE(NegateNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sigmoid: ASSIGN_NEW_NODE(SigmoidNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Atanh: ASSIGN_NEW_NODE(AtanhNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Tanh: ASSIGN_NEW_NODE(TanhNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Acos: ASSIGN_NEW_NODE(AcosNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Cos: ASSIGN_NEW_NODE(CosineNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Asin: ASSIGN_NEW_NODE(AsinNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sin: ASSIGN_NEW_NODE(SinNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Atan: ASSIGN_NEW_NODE(AtanNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Tan: ASSIGN_NEW_NODE(TanNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Cosh: ASSIGN_NEW_NODE(CoshNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Asinh: ASSIGN_NEW_NODE(AsinhNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sinh: ASSIGN_NEW_NODE(SinhNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ReLU: ASSIGN_NEW_NODE(RectifiedLinearNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Exp: ASSIGN_NEW_NODE(ExpNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Log: ASSIGN_NEW_NODE(LogNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sqrt: ASSIGN_NEW_NODE(SqrtNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ELU: ASSIGN_NEW_NODE(ExponentialLinearUnitNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Floor: ASSIGN_NEW_NODE(FloorNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Abs: ASSIGN_NEW_NODE(AbsNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Reciprocal: ASSIGN_NEW_NODE(ReciprocalNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Softmax: ASSIGN_NEW_NODE(SoftmaxNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Hardmax: ASSIGN_NEW_NODE(HardmaxNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::StraightThrough: ASSIGN_NEW_NODE(StraightThroughNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::TopK: { auto k = functionConfig[PrimitiveFunction::AttributeNameNumItems].Value(); ASSIGN_NEW_NODE(TopKNode, network->GetDeviceId(), internalNodeName, k); break; } case PrimitiveOpType::StableSigmoid: ASSIGN_NEW_NODE(StableSigmoidNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::TransposeAxes: { if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec)) { auto perm = 
AsVector(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value>()); for (auto& p : perm) p = NormalizeStaticAxis(p, perm.size()); ASSIGN_NEW_NODE(TransposeDimensionsNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(perm)); } else { auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(); auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(); // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based ASSIGN_NEW_NODE(TransposeDimensionsNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2)); } break; } case PrimitiveOpType::Where: { auto dynamicAxes = variable.DynamicAxes(); auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); ASSIGN_NEW_NODE(WhereNode, network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName); break; } case PrimitiveOpType::ToSequence: { auto dynamicAxes = variable.DynamicAxes(); auto internalCNTKDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); ASSIGN_NEW_NODE(ToSequenceNode, network->GetDeviceId(), internalNodeName, internalCNTKDynamicAxisName); break; } case PrimitiveOpType::ToSequenceLike: ASSIGN_NEW_NODE(ToSequenceLikeNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::UnpackSequence: { auto paddingValue = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackPaddingValue].Value(); auto suppressMaskOutput = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackSuppressMaskOutput].Value(); ASSIGN_NEW_NODE(UnpackSequenceNode, network->GetDeviceId(), internalNodeName, paddingValue, suppressMaskOutput); break; } case PrimitiveOpType::Slice: { std::vector axis; std::vector beginIndex, endIndex, strides; if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec) && functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndexVec) && functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndexVec)) { axis = AsVector(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value>()); beginIndex = AsVector(functionConfig[PrimitiveFunction::AttributeNameBeginIndexVec].Value>()); endIndex = AsVector(functionConfig[PrimitiveFunction::AttributeNameEndIndexVec].Value>()); if (functionConfig.Contains(PrimitiveFunction::AttributeNameSliceStridesVec)) strides = AsVector(functionConfig[PrimitiveFunction::AttributeNameSliceStridesVec].Value>()); else strides.resize(axis.size(), 1); } else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis) && functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndex) && functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndex)) { axis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value()); beginIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value()); endIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value()); if (functionConfig.Contains(PrimitiveFunction::AttributeNameSliceStrides)) strides.push_back(functionConfig[PrimitiveFunction::AttributeNameSliceStrides].Value()); else strides.push_back(1); } else { RuntimeError("Failed to create computation node: Slice operation with inconsistent attributes"); } // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based ASSIGN_NEW_NODE(SliceNode, network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis), strides); break; } case PrimitiveOpType::RandomSample: { auto 
numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); ASSIGN_NEW_NODE(RandomSampleNode, network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::RandomSampleInclusionFrequency: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); ASSIGN_NEW_NODE(RandomSampleInclusionFrequencyNode, network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::Dropout: { auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value(); ASSIGN_NEW_NODE(DropoutNode, network->GetDeviceId(), internalNodeName); SMART_NODE_INVOKE(DropoutNode, computationNodePtr, SetDropoutRate, dropoutRate); break; } case PrimitiveOpType::RandomDistribution: { auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value(); auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value(); auto rvtype = functionConfig[PrimitiveFunction::AttributeNameRandomDistributionType].Value(); std::vector randomDistributionArgs; if (functionConfig.Contains(PrimitiveFunction::AttributeNameRandomDistributionArgs)) randomDistributionArgs = AsVector(functionConfig[PrimitiveFunction::AttributeNameRandomDistributionArgs].Value>()); if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewShape)) { auto shape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); ASSIGN_NEW_NODE(RandomDistributionNode, network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs, AsTensorShape(shape)); } else ASSIGN_NEW_NODE(RandomDistributionNode, network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs); SMART_NODE_INVOKE(RandomDistributionNode, computationNodePtr, SetRngState, seed, offset); break; } case PrimitiveOpType::Reshape: { auto beginAxis = Axis(0); auto endAxis = Axis((int)functionInputs[0].Shape().Rank()); if (functionConfig.Contains(PrimitiveFunction::AttributeNameBeginAxis)) beginAxis = functionConfig[PrimitiveFunction::AttributeNameBeginAxis].Value(); if (functionConfig.Contains(PrimitiveFunction::AttributeNameEndAxis)) endAxis = functionConfig[PrimitiveFunction::AttributeNameEndAxis].Value(); auto replacementShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); for (size_t i = 0; i < replacementShape.Rank(); ++i) { if (replacementShape[i] == NDShape::InferredDimension) replacementShape[i] = 0; } ASSIGN_NEW_NODE(ReshapeNode, network->GetDeviceId(), internalNodeName, AsTensorShape(replacementShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis)); break; } case PrimitiveOpType::Squeeze: { auto beginAxis = Axis(0); auto inputShape = functionInputs[0].Shape(); auto endAxis = Axis((int)inputShape.Rank()); auto outputShape = GetSqueezedShape(inputShape, functionConfig); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(outputShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis)); break; } case PrimitiveOpType::ConstantOp: { double fillValue = functionConfig[PrimitiveFunction::AttributeNameFillValue].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, fillValue); break; } case PrimitiveOpType::EyeLikeOp: { bool outputSparse = functionConfig[PrimitiveFunction::AttributeNameOutputSparse].Value(); 
ASSIGN_NEW_NODE(EyeLikeNode, network->GetDeviceId(), internalNodeName, outputSparse); break; } case PrimitiveOpType::ROIPooling: { PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); auto spatialScale = functionConfig[PrimitiveFunction::AttributeNameSpatialScale].Value(); ASSIGN_NEW_NODE(ROIPoolingNode, network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(roiOutputShape), spatialScale); break; } case PrimitiveOpType::Pooling: { PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); auto ceilOutDim = false; auto includePad = false; if (functionConfig.Contains(PrimitiveFunction::AttributeNameCeilOutDim)) { ceilOutDim = functionConfig[PrimitiveFunction::AttributeNameCeilOutDim].Value(); } if (functionConfig.Contains(PrimitiveFunction::AttributeNameIncludePad)) { includePad = functionConfig[PrimitiveFunction::AttributeNameIncludePad].Value(); } ASSIGN_NEW_NODE(PoolingNode, network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutDim, includePad, ImageLayoutKind::CHW); break; } case PrimitiveOpType::Unpooling: { auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value(); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); //We only get here after validation so it is safe to assume unpooling is max ASSIGN_NEW_NODE(MaxUnpoolingNode, network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); break; } case PrimitiveOpType::SumAll: ASSIGN_NEW_NODE(SumElementsNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::OneHot: { auto numClass = functionConfig[PrimitiveFunction::AttributeNameNumClass].Value(); auto is_sparse = functionConfig[PrimitiveFunction::AttributeNameOneHotOutputSparse].Value(); auto axis = functionConfig[PrimitiveFunction::AttributeNameOneHotAxis].Value(); ASSIGN_NEW_NODE(OneHotNode, network->GetDeviceId(), numClass, is_sparse, axis.StaticAxisIndex(), internalNodeName); break; } case PrimitiveOpType::Gather: ASSIGN_NEW_NODE(GatherNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ToBatch: { ASSIGN_NEW_NODE(ToBatchAxisNode, network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::UnpackBatch: { ASSIGN_NEW_NODE(UnpackBatchAxisNode, network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::Plus: ASSIGN_NEW_NODE(PlusNode, 
network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LogPlus: ASSIGN_NEW_NODE(LogPlusNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Pow: ASSIGN_NEW_NODE(PowNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Minus: ASSIGN_NEW_NODE(MinusNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ElementTimes: ASSIGN_NEW_NODE(ElementTimesNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Equal: ASSIGN_NEW_NODE(EqualNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NotEqual: ASSIGN_NEW_NODE(NotEqualNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Less: ASSIGN_NEW_NODE(LessNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LessEqual: ASSIGN_NEW_NODE(LessEqualNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Greater: ASSIGN_NEW_NODE(GreaterNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GreaterEqual: ASSIGN_NEW_NODE(GreaterEqualNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Times: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); ASSIGN_NEW_NODE(TimesNode, network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap); break; } case PrimitiveOpType::TransposeTimes: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); ASSIGN_NEW_NODE(TransposeTimesNode, network->GetDeviceId(), internalNodeName, outputRank); break; } case PrimitiveOpType::Convolution: { auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); NDShape dilation = { 1 }; if (functionConfig.Contains(PrimitiveFunction::AttributeNameDilation)) dilation = functionConfig[PrimitiveFunction::AttributeNameDilation].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); NDShape outputMapCount, kernelShape; std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose); NDShape outputShape = NDShape::Unknown(); if (functionConfig.Contains(PrimitiveFunction::AttributeNameOutputShape)) outputShape = functionConfig[PrimitiveFunction::AttributeNameOutputShape].Value(); auto groups = PrimitiveFunction::convolutionOpDefaultValueForGroups; if (functionConfig.Contains(PrimitiveFunction::AttributeNameGroups)) groups = functionConfig[PrimitiveFunction::AttributeNameGroups].Value(); auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); ASSIGN_NEW_NODE(ConvolutionNode, network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, outputShape.IsUnknown() ? 
TensorShape(0) : AsTensorShape(outputShape), ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation), groups); break; } case PrimitiveOpType::ConvolutionSequenceShape: { auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); NDShape dilation = { 1 }; if (functionConfig.Contains(PrimitiveFunction::AttributeNameDilation)) dilation = functionConfig[PrimitiveFunction::AttributeNameDilation].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); NDShape outputMapCount, kernelShape; std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose); NDShape outputShape = NDShape::Unknown(); if (functionConfig.Contains(PrimitiveFunction::AttributeNameOutputShape)) outputShape = functionConfig[PrimitiveFunction::AttributeNameOutputShape].Value(); auto groups = PrimitiveFunction::convolutionOpDefaultValueForGroups; if (functionConfig.Contains(PrimitiveFunction::AttributeNameGroups)) groups = functionConfig[PrimitiveFunction::AttributeNameGroups].Value(); auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); ASSIGN_NEW_NODE(ConvolutionSequenceShapeNode, network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, outputShape.IsUnknown() ? 
TensorShape(0) : AsTensorShape(outputShape), ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation), groups); break; } case PrimitiveOpType::CosDistance: ASSIGN_NEW_NODE(CosDistanceNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::CosDistanceWithNegativeSamples: ASSIGN_NEW_NODE(CosDistanceWithNegativeSamplesNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Logistic: ASSIGN_NEW_NODE(LogisticNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::SquaredError: ASSIGN_NEW_NODE(SquareErrorNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::CrossEntropyWithSoftmax: ASSIGN_NEW_NODE(CrossEntropyWithSoftmaxNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ClassificationError: ASSIGN_NEW_NODE(ClassificationErrorNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::EditDistanceError: { auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value(); auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value(); auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value(); auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value(); auto tokensToIgnore = AsVector(functionConfig[PrimitiveFunction::AttributeNameTokensToIgnore].Value>()); ASSIGN_NEW_NODE(EditDistanceErrorNode, network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore); break; } case PrimitiveOpType::LatticeSequenceWithSoftmax: { auto symListPath = functionConfig[PrimitiveFunction::AttributeNameSymListPath].Value(); auto phonePath = functionConfig[PrimitiveFunction::AttributeNamePhonePath].Value(); auto stateListPath = functionConfig[PrimitiveFunction::AttributeNameStateListPath].Value(); auto transProbPath = functionConfig[PrimitiveFunction::AttributeNameTransProbPath].Value(); auto latticeConfigPath = functionConfig[PrimitiveFunction::AttributeNameLatticeConfigPath].Value(); auto frameDropThresh = functionConfig[PrimitiveFunction::AttributeNameFrameDropThresh].Value(); auto doReferenceAlign = functionConfig[PrimitiveFunction::AttributeNameDoReferenceAlign].Value(); auto seqGammarUsesMBR = functionConfig[PrimitiveFunction::AttributeNameSeqGammarUsesMBR].Value(); auto seqGammarAMF = functionConfig[PrimitiveFunction::AttributeNameSeqGammarAMF].Value(); auto seqGammarLMF = functionConfig[PrimitiveFunction::AttributeNameSeqGammarLMF].Value(); auto seqGammarBMMIFactor = functionConfig[PrimitiveFunction::AttributeNameSeqGammarBMMIFactor].Value(); auto seqGammarWordPen = functionConfig[PrimitiveFunction::AttributeNameSeqGammarWordPen].Value(); auto hSmoothingWeight = functionConfig[PrimitiveFunction::AttributeNameHSmoothingWeight].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, symListPath, phonePath, stateListPath, transProbPath, latticeConfigPath, hSmoothingWeight, frameDropThresh, doReferenceAlign, seqGammarUsesMBR, seqGammarAMF, seqGammarLMF, seqGammarBMMIFactor, seqGammarWordPen); break; } case PrimitiveOpType::ForwardBackward: { auto delayContraint = functionConfig[PrimitiveFunction::AttributeNameDelayConstraint].Value(); auto blankTokenId = functionConfig[PrimitiveFunction::AttributeNameBlankTokenId].Value(); ASSIGN_NEW_NODE(ForwardBackwardNode, network->GetDeviceId(), internalNodeName, blankTokenId, delayContraint); break; } case PrimitiveOpType::LambdaRank: ASSIGN_NEW_NODE(LambdaRankNode, 
network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NDCG: ASSIGN_NEW_NODE(NDCG1EvalNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::PastValue: case PrimitiveOpType::FutureValue: { Variable inputOperandVar = functionInputs[0]; Variable initialStateVar = functionInputs[1]; size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value(); if (op == PrimitiveOpType::PastValue) ASSIGN_NEW_NODE(PastValueNode, network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); else ASSIGN_NEW_NODE(FutureValueNode, network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); break; } case PrimitiveOpType::ReduceElements: { bool keepDimensions = true; if (functionConfig.Contains(PrimitiveFunction::AttributeNameReductionKeepDimensions)) keepDimensions = functionConfig[PrimitiveFunction::AttributeNameReductionKeepDimensions].Value(); auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value(); std::vector reductionAxis; if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec)) { reductionAxis = AsVector(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value>()); } else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis)) { reductionAxis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value()); } else { RuntimeError("Failed to create computation node': Reduce operation %ls with no '%ls' or '%ls' attributes", PrimitiveOpTypeName(op).c_str(), PrimitiveFunction::AttributeNameAxis.c_str(), PrimitiveFunction::AttributeNameAxisVec.c_str() ); } ASSIGN_NEW_NODE(ReduceElementsNode, network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis), keepDimensions); break; } case PrimitiveOpType::BatchNormalization: { auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value(); auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value(); auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value(); auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value(); bool disableRegularization = false; if (functionConfig.Contains(PrimitiveFunction::AttributeNameDisableRegularization)) { disableRegularization = functionConfig[PrimitiveFunction::AttributeNameDisableRegularization].Value(); } ASSIGN_NEW_NODE(BatchNormalizationNode, network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, disableRegularization, ImageLayoutKind::CHW); break; } case PrimitiveOpType::Combine: // This operation is just a no-op and is a means to combine multiple functions to create a single Function // whose outputs are a union of the outputs of the Functions being combined. 
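                // No new ComputationNode is created for Combine: since its outputs are just the outputs of the
                // combined functions, the node already recorded in variableToNodeMap for this variable is reused.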
computationNodePtr = variableToNodeMap[variable]; break; case PrimitiveOpType::PackedIndex: ASSIGN_NEW_NODE(PackedIndexNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GatherPacked: ASSIGN_NEW_NODE(GatherPackedNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ScatterPacked: ASSIGN_NEW_NODE(ScatterPackedNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Clip: ASSIGN_NEW_NODE(ClipNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Select: ASSIGN_NEW_NODE(IfNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Splice: { Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); ASSIGN_NEW_NODE(RowStackNode, network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis)); break; } case PrimitiveOpType::Pad: { auto head = AsVector(functionConfig[PrimitiveFunction::AttributeNamePaddingHead].Value>()); auto foot = AsVector(functionConfig[PrimitiveFunction::AttributeNamePaddingFoot].Value>()); auto mode = functionConfig[PrimitiveFunction::AttributeNamePaddingMode].Value(); auto constantValue = functionConfig[PrimitiveFunction::AttributeNamePaddingConstantValue].Value(); ASSIGN_NEW_NODE(PaddingNode, network->GetDeviceId(), internalNodeName, head, foot, (PaddingType)mode, constantValue); break; } case PrimitiveOpType::OptimizedRNNStack: { auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); ASSIGN_NEW_NODE(OptimizedRNNStackNode, network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp); break; } case PrimitiveOpType::ReconcileDynamicAxis: { ASSIGN_NEW_NODE(ReconcileDynamicAxisNode, network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::LogSoftmax: { //This can be implemented as x => x - ReduceLogSum(x). How to do this here? ASSIGN_NEW_NODE(LogSoftmaxNode, network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::Pass: ASSIGN_NEW_NODE(PassNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LabelsToGraph: ASSIGN_NEW_NODE(LabelsToGraphNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::StopGradient: ASSIGN_NEW_NODE(StopGradientNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Assign: ASSIGN_NEW_NODE(AssignNode, network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Crop: if (functionInputs.size() == 2) { if (functionConfig.Contains(PrimitiveFunction::AttributeNameOffset)) { // Crop with given offsets. const auto& offsets = AsVector(functionConfig[PrimitiveFunction::AttributeNameOffset].Value>()); if (offsets.size() != 2) { CNTK::LogicError("Vector of crop offsets must have size 2."); } ASSIGN_NEW_NODE(CropNode, offsets[0], offsets[1], network->GetDeviceId(), internalNodeName); } else { // Crop with two inputs and automatic offset computation. ASSIGN_NEW_NODE(CropNode, network->GetDeviceId(), internalNodeName); } } else if (functionInputs.size() == 4) { // Crop with four inputs and automatic offset computation. 
ASSIGN_NEW_NODE(CropNode, network->GetDeviceId(), internalNodeName); } else { CNTK::LogicError("Crop node must have 2 or 4 node inputs."); } break; case PrimitiveOpType::Cast: { DataType outputType = (DataType)functionConfig[PrimitiveFunction::AttributeNameNewDataType].Value(); switch (outputType) { case DataType::Float: ASSIGN_NEW_NODE2(CastNode, float, network->GetDeviceId(), internalNodeName); break; case DataType::Double: ASSIGN_NEW_NODE2(CastNode, double, network->GetDeviceId(), internalNodeName); break; case DataType::Float16: ASSIGN_NEW_NODE2(CastNode, half, network->GetDeviceId(), internalNodeName); break; } break; } case PrimitiveOpType::CustomProxyOp: { ASSIGN_NEW_NODE(CustomProxyOpNode, network->GetDeviceId(), internalNodeName); break; } default: CNTK::LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); break; } // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs); if (computationNodePtr->Is()) { auto computationNodeExpectedInputCount = computationNodePtr->As()->GetExpectedNumInputs(); if (computationNodeExpectedInputCount != inputNodesBasePtrs.size()) CNTK::LogicError("The Primitive Function '%S' has %d inputs while the corresponding ComputationNode expects %d inputs.", function->AsString().c_str(), (int)inputNodesBasePtrs.size(), (int)computationNodeExpectedInputCount); } if (computationNodePtr->Is()) { auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value(); auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value(); computationNodePtr->As()->SetRngState(seed, offset); } } else { ASSIGN_NEW_NODE(UserDefinedV2FunctionNode, network->GetDeviceId(), internalNodeName, function->shared_from_this()); // For user defined functions, we only attach unique inputs in the internal computation network since, the UDF // backward implementations directly compute aggregate gradient values for unique inputs std::vector uniqueInputNodesBasePtrs; for (auto inputNodeBasePtr : inputNodesBasePtrs) { if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end()) uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr); } inputNodesBasePtrs = uniqueInputNodesBasePtrs; } } else { size_t i = 1; while (outputs[i] != variable) i++; assert(i < outputs.size()); ASSIGN_NEW_NODE(OutputMultiplexerNode, network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes), i); inputNodesBasePtrs = { variableToNodeMap[outputs[0]] }; } network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); return computationNodePtr; } template /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, const std::unordered_map& fullyDefinedArgumentsMap, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap, const std::unordered_set& inputsToExcludeGradientsFor, bool useMangledNamesForComputationNodes) { assert(variable.IsOutput()); Function* function = variable.Owner().get(); ComputationNodeBasePtr computationNodePtr; auto& functionInputs = function->m_inputs; DataType nonConstInputDataType = DataType::Unknown; for (auto& inputVar : functionInputs) { if (!inputVar.IsConstant() && 
(inputVar.GetDataType() != DataType::Unknown)) { nonConstInputDataType = inputVar.GetDataType(); break; } } // Create the nodes corresponding to the inputs std::vector> inputNodes; for (auto& inputVar : functionInputs) { // If the inputVar is a constant and not the right DataType let's coerce it to the right type // except for FP16 that mismatch is needed (e.g. BatchNorm stats in FP16 need to be FP32) if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (nonConstInputDataType != DataType::Float16) && (inputVar.GetDataType() != nonConstInputDataType)) inputVar = Constant(inputVar).CloneAs(nonConstInputDataType); auto baseNodePtr = GetNode(inputVar, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes); inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr : nullptr); } BlockFunction* blockFunction = dynamic_cast(function); if (blockFunction) { // For block function, map each argument placeholder of the underlying composite to // the computation node corresponding to the block input that the argument placeholder // of the composite is mapped to. auto compositeArguments = blockFunction->Composite()->Arguments(); for (auto compositeArgument : compositeArguments) variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping()); return GetNode(variable.BlockFunctionVariableMapping(), network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes); } else { switch (variable.GetDataType()) { case DataType::Float: computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap, useMangledNamesForComputationNodes); break; case DataType::Double: computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap, useMangledNamesForComputationNodes); break; case DataType::Float16: computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap, useMangledNamesForComputationNodes); break; default: RuntimeError("Unsupported variable data type for CreateComputationNode"); } } PrimitiveFunction* primitiveFunction = dynamic_cast(function); if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine)) { for (auto inputVar : functionInputs) isVariableRootMap[inputVar] = false; } return computationNodePtr; } std::unordered_set CompositeFunction::NonOwnerPreservingCopy(const std::unordered_set& outputs) { std::unordered_set result; for (auto& o : outputs) { Variable sanitized = o.NonCompositePreservingCopy(); result.insert(sanitized); } return result; } static bool VariableShapeMatchesNodeShape(const NDShape& varShape, const TensorShape& nodeShape) { if (varShape.Rank() == 0) return (nodeShape.GetNumElements() == 1); // Sometimes the nodeShape may have an additional trailing axis with dim==1 due to the lack of support for 0-d tensors in V1 engine. 
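    // The comparison below therefore trims trailing singleton dimensions from the node shape and
    // treats free/inferred dimensions on the Variable side as wildcards that match any concrete node dimension.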
    auto adjustedNodeShape = nodeShape;
    while ((adjustedNodeShape.GetRank() > varShape.Rank()) && (adjustedNodeShape.GetDim(adjustedNodeShape.GetRank() - 1) == 1))
        adjustedNodeShape.TrimRankInPlace(adjustedNodeShape.GetRank() - 1);

    if (!varShape.HasUnboundDimension())
        return (AsNDShape(adjustedNodeShape) == varShape);

    if (varShape.Rank() != adjustedNodeShape.GetRank())
        return false;

    for (size_t i = 0; i < varShape.Rank(); ++i)
    {
        if ((varShape[i] != NDShape::FreeDimension) && (varShape[i] != NDShape::InferredDimension) && (varShape[i] != adjustedNodeShape.GetDim(i)))
            return false;
    }

    return true;
}

template <typename ElementType>
std::pair<ComputationNetworkPtr, std::unordered_map<Variable, ComputationNodeBasePtr>>
CompositeFunction::CreateComputationNetwork(const FunctionPtr& compositeFunction,
                                            const DeviceDescriptor& device,
                                            const std::unordered_set<Variable>& outputs,
                                            const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
                                            const std::unordered_set<Variable>& inputsExcludedFromGradientComputation,
                                            bool useMangledNamesForComputationNodes)
{
    auto computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
    ComputationNetworkBuilder<ElementType> builder(*computationNetwork);

    std::unordered_map<Variable, bool> isVariableRootMap;
    std::unordered_map<Variable, ComputationNodeBasePtr> variableToNodeMap;

    // Now recursively create the network in a top-down fashion
    auto rootFunction = compositeFunction->RootFunction();
    auto rootFunctionOutputs = rootFunction->RawOutputs();
    for (auto rootOutput : rootFunctionOutputs)
        GetNode<ElementType>(rootOutput, computationNetwork, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsExcludedFromGradientComputation, useMangledNamesForComputationNodes);

    // We need to patch the Computation node mappings for the arguments of block functions
    // since for recurrent inputs, the mappings are not fully established the first time
    std::function<void(const Variable&)> PatchBlockArgumentsAndOutputsMapping;
    PatchBlockArgumentsAndOutputsMapping = [&variableToNodeMap, &PatchBlockArgumentsAndOutputsMapping](const Variable& var) {
        if (var.IsOutput())
        {
            BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(var.Owner().get());
            if (blockFunction)
            {
                PostorderTraverseVariables(blockFunction->BlockRoot(), PatchBlockArgumentsAndOutputsMapping);

                auto compositeArguments = blockFunction->Composite()->Arguments();
                for (auto compositeArgument : compositeArguments)
                {
                    auto mappingVarNodeIter = variableToNodeMap.find(compositeArgument.BlockFunctionVariableMapping());
                    if (mappingVarNodeIter != variableToNodeMap.end())
                        variableToNodeMap[compositeArgument] = mappingVarNodeIter->second;
                }

                auto mappingVarNodeIter = variableToNodeMap.find(var.BlockFunctionVariableMapping());
                if (mappingVarNodeIter != variableToNodeMap.end())
                    variableToNodeMap[var] = mappingVarNodeIter->second;
            }
        }
    };
    PostorderTraverseVariables(rootFunction, PatchBlockArgumentsAndOutputsMapping);

    std::function<bool(const Variable&)> IsVariableRoot = [&isVariableRootMap, &IsVariableRoot](const Variable& outputVar) {
        auto mappingVariable = GetMappingVariable(outputVar);
        return (isVariableRootMap[outputVar] &&
                !IsFirstOutputOfMultiOutputFunction(mappingVariable) &&
                ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
    };

    // If any of the function or requested outputs is not a root node, we need to explicitly
    // add it to the 'output' group of the ComputationNetwork
    std::unordered_set<Variable> networkOutputs(outputs);
    networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
    for (auto output : networkOutputs)
    {
        if (!IsVariableRoot(output))
        {
            auto computationNode = variableToNodeMap[output];

            if (!computationNode)
                InvalidArgument("One of the requested outputs '%S' for the Function '%S' forward computation is not part of the graph underlying the Function.",
                                output.AsString().c_str(), compositeFunction->AsString().c_str());

            computationNetwork->AddToNodeGroup(L"output", computationNode);
        }
    }

    // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
    // Now attach those after we have created all ComputationNodes in the network
    for (auto varNodePair : variableToNodeMap)
    {
        auto& currentComputationNode = varNodePair.second;
        if (!currentComputationNode)
            LogicError("Function '%S': No computation node mapping exists for Variable %S.", compositeFunction->AsString().c_str(), varNodePair.first.AsString().c_str());

        auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
        auto& currentVar = varNodePair.first;
        if (!currentVar.IsOutput())
            continue;

        if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
        {
            // This ComputationNode has at least one null input which now needs to be properly attached
            const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());
            if (primitiveFunc == nullptr)
            {
                LogicError("Non-primitive function '%S' cannot be a part of the CNTK recurrent loop.", currentVar.Owner()->Name().c_str());
            }

            // Skip block primitives since they do not directly map to a computation node
            if (primitiveFunc->OpType() == PrimitiveOpType::Block)
                continue;

            // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering
            auto inputVars = primitiveFunc->Inputs();
            ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
            inputVars.resize(currentComputationNode->GetNumInputs());

            std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
            for (auto inputVar : inputVars)
                inputNodesBasePtrs.push_back(variableToNodeMap.at(inputVar));

            currentComputationNode->AttachInputs(inputNodesBasePtrs);
        }
    }

    computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
    computationNetwork->SetTrackGapNans(GetCheckedMode());
    computationNetwork->SetIsV2Library(true);
    computationNetwork->CompileNetwork();

    // Set EvalTimeStamp of all nodes in the network as "outdated" to make sure that all nodes will be evaluated at least once.
    // During CompileNetwork(), nodes in the network might get different timestamp values because other threads could update the global timestamp value.
    // (The global timestamp value is currently shared process-wide, i.e. among all nodes of all networks.) The nodes with a higher timestamp value are
    // thus incorrectly treated as "updated", and their inputs are not further evaluated by ComputationNetwork::PARTraversalFlowControlNode::ForwardProp().
    // This could lead to incorrect results or crash, because the matrix of the input nodes might never be initialized for ForwardProp().
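    // Marking every node's timestamp as outdated here guarantees a full evaluation on the first
    // ForwardProp after compilation, independent of any global timestamp drift described above.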
    computationNetwork->SetEvalTimeStampsOutdatedWithRegardToAll();

    // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
    for (auto varNodePair : variableToNodeMap)
    {
        if (varNodePair.first.IsOutput())
        {
            auto outputVar = varNodePair.first;
            auto computationNodePtr = variableToNodeMap.at(outputVar);
            auto outputShape = outputVar.Shape();
            auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
            if (!VariableShapeMatchesNodeShape(outputShape, computationNodeSampleLayout))
            {
                LogicError("Function '%S': The output Variable '%S' shape '%S' does not match the SampleLayout shape '[%s]' of the corresponding ComputationNode in the network.",
                           compositeFunction->AsString().c_str(), outputVar.AsString().c_str(), outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
            }
        }
    }

    return { computationNetwork, variableToNodeMap };
}

template <typename ElementType>
ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device,
                                                               const std::unordered_set<Variable>& backpropRoots,
                                                               const std::unordered_set<Variable>& outputs,
                                                               const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
                                                               bool allocateNetworkMatrices)
{
    // Let's purge the current computation network and regenerate the network if the CompositeFunction
    // was previously compiled just for evaluation and not for gradient backpropagation.
    if ((m_computationNetwork != nullptr) && (m_currentBackpropRoots.empty() && !backpropRoots.empty()))
        PurgeComputationNetwork();

    if (m_computationNetwork != nullptr)
    {
        // TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network
        // was last constructed, or just recreate a new network.
        // For now just disallow changing the backpropRoots after the network is created
        if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
            LogicError("Function '%S': Changing backprop roots (Current = '%S', New = '%S') across different Forward calls on a CNTK composite Function is currently unsupported.",
                       AsString().c_str(), NamedListString(m_currentBackpropRoots).c_str(), NamedListString(backpropRoots).c_str());

        // TODO: Support changing the device across different invocations of the forward method on a Function instance
        if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
            LogicError("Function '%S': Changing device (Current = '%S', New = '%S') across different Forward calls on a CNTK composite Function is currently unsupported.",
                       AsString().c_str(), AsDeviceDescriptor(m_computationNetwork->GetDeviceId()).AsString().c_str(), device.AsString().c_str());

        if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation))
            LogicError("Function '%S': Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported.",
                       AsString().c_str());

        // Verify if the free dimensions of any of the arguments have changed, and if so, update the corresponding
        // input ComputationNodes and rerun validation on the computation network
        for (auto freeDimensionArgumentMapping : m_fullyDefinedArgumentsMap)
        {
            auto newShape = freeDimensionArgumentMapping.second.Shape();
            auto argumentComputationNode = m_variableToNodeMap[freeDimensionArgumentMapping.first];
            if (AsTensorShape(newShape) != argumentComputationNode->GetSampleLayout())
                argumentComputationNode->SetDims(AsTensorShape(newShape), argumentComputationNode->HasMBLayout());
        }
    }
    else
    {
        auto networkInputs = this->Inputs();
        for (auto inputExcluded : inputsToExcludeGradientsFor)
        {
            // Only inputs of the network can be excluded from gradient computation
            if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end())
                InvalidArgument("Variable '%S' specified for exclusion from gradient computation is not an input of the Function '%S'. "
                                "Only an input of the Function can be explicitly excluded from gradient computation.",
                                inputExcluded.AsString().c_str(), this->AsString().c_str());
        }

        m_inputsExcludedFromGradientComputation = NonOwnerPreservingCopy(inputsToExcludeGradientsFor);
        m_currentBackpropRoots = NonOwnerPreservingCopy(backpropRoots);

        // TODO: We currently only support one backprop root
        if (backpropRoots.size() > 1)
            LogicError("Function '%S': %d backprop roots specified; currently at most one backprop root is supported.", AsString().c_str(), (int)backpropRoots.size());

        auto placeholders = Placeholders();
        if (!placeholders.empty())
            InvalidArgument("%d unbound Placeholder(s) '%S' found in the Function. "
                            "All Placeholders of a Function must be bound (to a variable) before performing a Forward computation.",
                            (int)placeholders.size(), NamedListString(placeholders).c_str());

        // Let's update the composite Function graph's inputs with any inferred dimensions that
        // were determined from the shapes of the supplied data
        auto networkArguments = Arguments();
        for (auto argument : networkArguments)
        {
            if (argument.Shape().HasInferredDimension())
            {
                auto fullyDefinedArgument = m_fullyDefinedArgumentsMap.at(argument);
                for (size_t i = 0; i < argument.Shape().Rank(); ++i)
                    if (argument.Shape()[i] == NDShape::InferredDimension)
                        argument.m_dataFields->m_shape[i] = fullyDefinedArgument.Shape()[i];
            }
        }

        // Run the final validation on the entire network once before constructing/compiling the
        // internal computation network
        ValidateOrUpdateOutputs();

        std::tie(m_computationNetwork, m_variableToNodeMap) =
            CreateComputationNetwork<ElementType>(this->shared_from_this(), device, outputs, m_fullyDefinedArgumentsMap, m_inputsExcludedFromGradientComputation, /*useMangledNamesForComputationNodes =*/ false);

        // Record the timestamps of Parameters and Constants
        assert(m_lastRecordedTimeStamps.empty());
        auto functionParameters = Parameters();
        for (auto parameter : functionParameters)
            m_lastRecordedTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });

        auto functionConstants = Constants();
        for (auto constant : functionConstants)
            m_lastRecordedTimeStamps.insert({ constant, constant.CurrentValueTimeStamp() });

        // Collect parameters and constants being assigned to
        PreorderTraverseFunctions(RootFunction(), [this](const FunctionPtr& function) {
            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
            if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::Assign))
                m_refVariables.insert(primitiveFunction->Inputs()[0]);
        }, /*nestedSearchInsideBlockFunction =*/ true);
    }

    if (!m_networkMatricesAllocated && allocateNetworkMatrices)
    {
        m_allNetworkRoots = m_currentBackpropRoots;
        ComputationNodeBasePtr backpropRootNode;
        if (!m_currentBackpropRoots.empty())
            backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());

        // Now recursively traverse the network in a top-down fashion
        auto rootFunction = RootFunction();
        auto rootFunctionOutputs = rootFunction->RawOutputs();
        m_allNetworkRoots.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());

        std::vector<ComputationNodeBasePtr> forwardRootNodes;
        for (auto rootOutput : rootFunctionOutputs)
            forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));
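        // Also gather the ComputationNodes for the explicitly requested outputs so that AllocateAllMatrices
        // below can set up forward value (and, when a backprop root exists, gradient) buffers for every root
        // that may later be evaluated.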
std::vector forwardOutputNodes; m_allNetworkRoots.insert(outputs.begin(), outputs.end()); for (auto output : outputs) forwardOutputNodes.push_back(m_variableToNodeMap.at(output)); m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode); m_networkMatricesAllocated = allocateNetworkMatrices; } else { // Make sure the outputs requested are a subset of the outputs we setup the current matrix allocation structure // in the cached computation network for (auto output : outputs) { if (m_allNetworkRoots.find(output) == m_allNetworkRoots.end()) LogicError("Function '%S': Requested output '%S' is not part of the list of outputs '%S' that the Function was initially compiled for. " "Changing requested outputs across different Forward calls is currently unsupported.", AsString().c_str(), output.AsString().c_str(), NamedListString(m_allNetworkRoots).c_str()); } } return m_computationNetwork; } template /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map& layoutsPopulated) { NDShape inferredVariableShape; std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject(variableValue.first, variableValue.second, &inferredVariableShape); if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout())) CNTK::LogicError("CompositeFunction::Forward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.", inferredVariableShape.AsString().c_str(), variableValue.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str()); // Switch the node matrix to the right matrix type auto& nodeData = computationNode->As>()->Value(); nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first); auto layout = CNTKMatrixAndMBLayout.second; auto& nodeLayout = computationNode->GetMBLayout(); if ((layout == nullptr) != (nodeLayout == nullptr)) InvalidArgument("The layout of the specified Value for Variable '%S' is incompatible with the layout of the corresponding ComputationNode.", variableValue.first.AsString().c_str()); else if (layout) { if (layoutsPopulated.find(nodeLayout) == layoutsPopulated.end()) { nodeLayout->CopyFrom(layout); layoutsPopulated.insert({ nodeLayout, variableValue.first }); } else { if (*nodeLayout != *layout) InvalidArgument("Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified " "for the Function's arguments '%S' vs. 
'%S', though these arguments have the same dynamic axes '%S'", variableValue.first.AsString().c_str(), layoutsPopulated.at(nodeLayout).AsString().c_str(), DynamicAxesAsString(variableValue.first.DynamicAxes(), Internal::IsReversingTensorShapesInErrorMessagesEnabled()).c_str()); } } } std::unordered_map CompositeFunction::InferFreeDimensionsOfArguments(const std::unordered_map& arguments) { std::unordered_map inferredArgumentDimensions; for (auto argumentValuePair : arguments) { NDShape inferredVarShape; Utils::VerifyVariableValueCompatibility(argumentValuePair.first, argumentValuePair.second, &inferredVarShape); if (inferredVarShape != argumentValuePair.first.Shape()) inferredArgumentDimensions.insert({ argumentValuePair.first , inferredVarShape }); } if (!inferredArgumentDimensions.empty()) { if (m_fullyDefinedArgumentsMap.empty()) { for (auto inferredArgumentShapePair : inferredArgumentDimensions) { auto fullyDefinedArgument = inferredArgumentShapePair.first.Clone(); fullyDefinedArgument.m_dataFields->m_shape = inferredArgumentShapePair.second; m_fullyDefinedArgumentsMap.insert({ inferredArgumentShapePair.first, fullyDefinedArgument }); } if (GetCheckedMode()) m_latestFullyDefinedCompositeForCheckedModeValidation = this->Clone(ParameterCloningMethod::Share, m_fullyDefinedArgumentsMap); } else { bool argumentShapeChangedSinceLastTime = false; for (auto inferredArgumentShapePair : inferredArgumentDimensions) { if (inferredArgumentShapePair.second != m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].Shape()) { argumentShapeChangedSinceLastTime = true; m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].m_dataFields->m_shape = inferredArgumentShapePair.second; } } if (argumentShapeChangedSinceLastTime && m_latestFullyDefinedCompositeForCheckedModeValidation) m_latestFullyDefinedCompositeForCheckedModeValidation->ValidateOrUpdateOutputs(); } } return inferredArgumentDimensions; } void CompositeFunction::PopulateNetworkInputs(const std::unordered_map& arguments) { std::unordered_map layoutsPopulated; std::vector inputNodes; for (auto argumentValuePair : arguments) { auto argument = argumentValuePair.first; auto argumentComputationNode = m_variableToNodeMap.at(argument); assert(argumentComputationNode); inputNodes.push_back(argumentComputationNode); ValuePtr argumentValue = arguments.at(argument); switch (argumentValue->GetDataType()) { case DataType::Float: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode, layoutsPopulated); break; case DataType::Double: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode, layoutsPopulated); break; case DataType::Float16: PopulateComputationNodeValue({ argument, argumentValue }, argumentComputationNode, layoutsPopulated); break; default: LogicError("Function '%S' Forward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(argumentValue->GetDataType())); break; } } m_computationNetwork->BumpEvalTimeStamp(inputNodes); } template /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode) { NDShape inferredVariableShape; std::pair>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject(variableGradient.first, variableGradient.second, &inferredVariableShape); if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout())) CNTK::LogicError("CompositeFunction::Backward: Inferred shape '%S' of Variable '%S' does 
not match the corresponding computation node shape '%s'.", inferredVariableShape.AsString().c_str(), variableGradient.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str()); MBLayoutPtr layout = CNTKMatrixAndMBLayout.second; auto nodeLayout = computationNode->GetMBLayout(); if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout))) InvalidArgument("The layout of the specified gradient Value for Variable '%S' is incompatible with the layout computed during Forward call.", variableGradient.first.AsString().c_str()); computationNode->As>()->AssignGradient(*CNTKMatrixAndMBLayout.first); } // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph void CompositeFunction::PopulateNetworkGradients(const std::unordered_map& gradients) { auto functionOutputs = RawOutputs(); for (auto gradientVarValuePair : gradients) { auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first); ValuePtr gradientValue = gradientVarValuePair.second; switch (gradientValue->GetDataType()) { case DataType::Float: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; case DataType::Double: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; case DataType::Float16: PopulateComputationNodeGradient(gradientVarValuePair, outputComputationNode); break; default: LogicError("Function '%S' Backward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(gradientValue->GetDataType())); break; } } } /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient) { auto varShape = GetVariableShape(var.Shape(), computationNode->GetSampleLayout()); auto valueShape = PackedValue::GetUnpackedShape(varShape, var.DynamicAxes(), computationNode->GetMBLayout()); if (varValue != nullptr) { // TODO: The shape of the specified output Value object must match the actual output shape if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape))) ::InvalidArgument("The shape %S of the specified Value object for Variable '%S' %s does not match the actual shape %S", varValue->Shape().AsString().c_str(), var.AsString().c_str(), getGradient ? "gradient" : "output", valueShape.AsString().c_str()); } ValuePtr nodeValue; auto layout = computationNode->GetMBLayout(); switch (var.GetDataType()) { case DataType::Float: { auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); if (varValue == nullptr) nodeValue = MakeSharedObject(varShape, var.DynamicAxes(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); else nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout(var, computationNode, matrix, layout); break; } case DataType::Double: { auto& matrix = getGradient ? computationNode->As>()->Gradient() : computationNode->As>()->Value(); if (varValue == nullptr) nodeValue = MakeSharedObject(varShape, var.DynamicAxes(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); else nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout(var, computationNode, matrix, layout); break; } case DataType::Float16: { auto& matrix = getGradient ? 
computationNode->As>()->Gradient() : computationNode->As>()->Value(); if (varValue == nullptr) nodeValue = MakeSharedObject(varShape, var.DynamicAxes(), std::make_shared>(matrix.AsReference()), layout, /*readOnly =*/ false); else nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout(var, computationNode, matrix, layout); break; } default: CNTK::LogicError("CompositeFunction::Forward/Backward: Unsupported DataType %s", DataTypeName(var.GetDataType())); break; } if (varValue == nullptr) varValue = nodeValue; else varValue->CopyFrom(*nodeValue); } void CompositeFunction::GetNetworkOutputs(std::unordered_map& outputs) { // Now copy the Forward values of output nodes from the network to outputs' Value objects for (auto& outputVarValuePair : outputs) { auto& valuePtr = outputVarValuePair.second; auto node = m_variableToNodeMap.at(outputVarValuePair.first); bool noValueStrorageProvided = (valuePtr == nullptr); GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/); auto packedVarValue = std::dynamic_pointer_cast(valuePtr); if (noValueStrorageProvided && packedVarValue && packedVarValue->IsPacked()) m_existingNetworkStorageReferences.push_back(packedVarValue); } } void CompositeFunction::GetNetworkGradients(std::unordered_map& gradients) { auto networkInputs = this->Inputs(); // Now copy the gradient values of input nodes of the network to gradients' Value objects for (auto& gradientVarValuePair : gradients) { // Only gradients corresponding to inputs of the network can be obtained if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end()) InvalidArgument("Gradient requested for Variable '%S' which is not a leaf input (Input, Parameter or Constant) of the Function '%S'; this is currently unsupported.", gradientVarValuePair.first.AsString().c_str(), this->AsString().c_str()); // Gradients can only be obtained for parameter variables or input variables that NeedsGradient if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end())) InvalidArgument("Gradient value incorrectly requested for Variable '%S', " "an Output or Constant or Input Variable with NeedsGradient setting of false, or an input for which gradient computation was explicitly excluded.", gradientVarValuePair.first.AsString().c_str()); auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first); if (!computationNodePtr->NeedsGradient()) LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.", AsString().c_str(), gradientVarValuePair.first.AsString().c_str()); auto& valuePtr = gradientVarValuePair.second; bool noValueStrorageProvided = (valuePtr == nullptr); GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/); auto packedVarValue = std::dynamic_pointer_cast(valuePtr); if (noValueStrorageProvided && packedVarValue && packedVarValue->IsPacked()) m_existingNetworkStorageReferences.push_back(packedVarValue); } } const std::vector& CompositeFunction::GetArgumentDependencies(const Variable& output) { if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end()) { auto sanitizedOutput = output.NonCompositePreservingCopy(); if (sanitizedOutput.IsOutput()) m_perOutputVarArgumentDependencies[sanitizedOutput] = 
AsComposite(sanitizedOutput.Owner())->Arguments(); else if (sanitizedOutput.IsParameter() || sanitizedOutput.IsConstant()) m_perOutputVarArgumentDependencies[sanitizedOutput] = {}; else m_perOutputVarArgumentDependencies[sanitizedOutput] = { sanitizedOutput }; } return m_perOutputVarArgumentDependencies[output]; } std::unordered_map CompositeFunction::GetCurrentBackpropRootsTimeStamps() const { std::unordered_map currentBackpropRootsTimeStamps; assert(m_computationNetwork != nullptr); for (auto& backpropRoot : m_currentBackpropRoots) currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp(); return currentBackpropRootsTimeStamps; } /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map& arguments, std::unordered_map& outputs, const DeviceDescriptor& computeDevice, const std::unordered_set& outputsToRetainBackwardStateFor, const std::unordered_set& inputsToExcludeGradientsFor) { // Validate arguments and outputs if (outputs.empty()) InvalidArgument("At least one output has to be specified when calling Forward method of the Function '%S'.", this->AsString().c_str()); // Make sure that the DataType of the variables and corresponding values match // TODO: We need a better way to determine the ElementType for the network auto dataType = DataType::Unknown; for (auto variableValuePair : arguments) { if (dataType == DataType::Unknown) dataType = variableValuePair.first.GetDataType(); else if (dataType != variableValuePair.first.GetDataType()) LogicError("Function '%S' Forward: The DataType of all arguments must be same.", this->AsString().c_str()); } if (dataType == DataType::Unknown) { for (auto variableValuePair : outputs) { if (dataType == DataType::Unknown) dataType = variableValuePair.first.GetDataType(); } } std::unordered_set requestedOutputVariables; for (auto output : outputs) requestedOutputVariables.insert(output.first); std::unordered_set functionOutputs(m_outputs.begin(), m_outputs.end()); std::unordered_set requiredArguments; for (auto outputVariable : requestedOutputVariables) { auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVariable); requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end()); } // We should have argument values supplied for all required argument dependencies for the requested outputs std::vector missingRequiredArguments; std::unordered_map requiredArgumentValues; for (auto requiredArgument : requiredArguments) { auto iter = arguments.find(requiredArgument); if (iter == arguments.end()) missingRequiredArguments.push_back(requiredArgument); else requiredArgumentValues.insert(*iter); } if (!missingRequiredArguments.empty()) { InvalidArgument("Values for %d required arguments '%S', that the requested output(s) '%S' depend on, have not been provided.", (int)missingRequiredArguments.size(), NamedListString(missingRequiredArguments).c_str(), NamedListString(requestedOutputVariables).c_str()); } if (requiredArgumentValues.size() < arguments.size()) fprintf(stderr, "WARNING: Function::Forward provided values for (%d) extra arguments which are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size())); auto inferredArgumentShapes = InferFreeDimensionsOfArguments(requiredArgumentValues); if (dataType == DataType::Float) GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true); else if (dataType == 
DataType::Double) GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true); else if (dataType == DataType::Float16) GetComputationNetwork(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true); else InvalidArgument("Unsupported DataType %s", DataTypeName(dataType)); // Feed data into the arguments of the network // TODO: Avoid copying the data when possible PopulateNetworkInputs(requiredArgumentValues); // Copy all new values for 'dirty' attributes from functions into corresponding network nodes. ApplyAttributeUpdates(); // Bump the timestamp of the parameter nodes whose values have changed for (auto& timeStampRecord : m_lastRecordedTimeStamps) { auto variable = timeStampRecord.first; auto prevTimeStamp = timeStampRecord.second; auto newTimeStamp = variable.CurrentValueTimeStamp(); if (newTimeStamp > prevTimeStamp) { timeStampRecord.second = newTimeStamp; m_variableToNodeMap.at(variable)->BumpEvalTimeStamp(); } } std::vector outputsToEvaluate; for (auto outputVariable : requestedOutputVariables) outputsToEvaluate.push_back(m_variableToNodeMap.at(outputVariable)); // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs' for (auto rootVarForBackprop : outputsToRetainBackwardStateFor) { if (outputs.find(rootVarForBackprop) == outputs.end()) outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop)); } // Reset the timestamps of all backward roots to record an update in one or more inputs for (auto& backpropRoot : m_currentBackpropRoots) m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll(); // Reset the timestamps of all dropout node to force recomputation of the (random) dropout mask. list dropoutNodes = m_computationNetwork->GetNodesWithType(); for (auto& dropout : dropoutNodes) dropout->SetEvalTimeStampOutdatedWrtAll(); // Free any previous references to the matrix storage associated with the outputsToEvaluate ClearExistingOutputOrGradientStorageReferences(); ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training); m_computationNetwork->ForwardProp(outputsToEvaluate); // Call PostForwardAndBackProp after ForwardProp only in evaluation mode. 
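    // When backward state is being retained, the evaluated outputs are remembered in m_currentOutputsToEvaluate
    // instead, and PostForwardAndBackProp runs at the end of the subsequent Backward() call.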
if (outputsToRetainBackwardStateFor.empty()) { m_computationNetwork->PostForwardAndBackProp(outputsToEvaluate); RecordRefVariableUpdates(); } else { m_currentOutputsToEvaluate.clear(); for (auto outputToEvaluate : outputsToEvaluate) m_currentOutputsToEvaluate.push_back(outputToEvaluate); } GetNetworkOutputs(outputs); // TODO: How to deal with the specified 'computeDevice' BackPropStatePtr backpropStatePtr; if (outputsToRetainBackwardStateFor.size() > 0) backpropStatePtr = MakeSharedObject(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps()); return backpropStatePtr; } /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state, const std::unordered_map& rootGradientValues, std::unordered_map& backPropagatedGradientValuesForInputs) { auto backpropState = dynamic_cast(state.get()); if (backpropState == nullptr) InvalidArgument("Function '%S' Backward: Invalid backprop state passed.", AsString().c_str()); if (backPropagatedGradientValuesForInputs.empty()) InvalidArgument("Function '%S' Backward: List of inputs to compute gradients for, must not be empty.", AsString().c_str()); // TODO: Support multiple concurrent backprop states std::unordered_map currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps(); if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps) LogicError("Function '%S' Backward: The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified " "by subsequent Forward calls to the function. This is not a user error but a shortcoming of the current implementation where multiple independent " "backprop states are not simultaneously supported.", AsString().c_str()); if (rootGradientValues.size() > 1) LogicError("Function '%S' Backward: %d root gradient values specified; currently gradient backprop from only one of the Function Outputs is supported.", AsString().c_str(), (int)rootGradientValues.size()); // TODO: Avoid copying the data when possible // Zero all gradients of nodes below the root nodes for (auto rootGradientVarValuePair : rootGradientValues) m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first)); // Feed data into the arguments of the network PopulateNetworkGradients(rootGradientValues); // Backpropagate through the network ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training); auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first); m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true); GetNetworkGradients(backPropagatedGradientValuesForInputs); if (m_currentOutputsToEvaluate.size() > 0) { m_computationNetwork->PostForwardAndBackProp(m_currentOutputsToEvaluate); RecordRefVariableUpdates(); m_currentOutputsToEvaluate.clear(); } // TODO: How to deal with the specified 'computeDevice' } void CompositeFunction::ApplyAttributeUpdates() { // Dropout nodes have an implicit input in the form of the random mask that is applied to its explicit input // This mask is regenerated every minibatch and hence dropout nodes with a non-zero dropout rate must me marked outdated // w.r.t. 
inputs to force evaluation in each minibatch for (auto varNodePair : m_variableToNodeMap) { auto var = varNodePair.first; if (!var.IsOutput()) continue; auto function = var.Owner(); if (function->m_dirtyAttributes.empty()) continue; auto node = varNodePair.second; for (const wstring& attribute : function->m_dirtyAttributes) { if (attribute == PrimitiveFunction::AttributeNameDropoutRate) { auto dropoutRate = function->m_attributes[attribute].Value(); auto dropoutPtr = dynamic_cast(node.get()); assert(dropoutPtr != nullptr); dropoutPtr->SetDropoutRate(dropoutRate); } else if (attribute == PrimitiveFunction::AttributeNameRngSeed) { auto seed = function->m_attributes[PrimitiveFunction::AttributeNameRngSeed].Value(); auto rngUserPtr = dynamic_cast(node.get()); assert(rngUserPtr != nullptr); rngUserPtr->SetRngState(seed); } else { // Should never happen. LogicError("ApplyAttributeUpdates: function '%S' specified an unsupported attribute '%S'.", function->AsString().c_str(), attribute.c_str()); } } function->m_dirtyAttributes.clear(); node->SetEvalTimeStampOutdatedWrtAll(); } } }