//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"

using namespace Microsoft::MSR::CNTK;

namespace CNTK
{
    /*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
    /*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);

    static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";

    /*virtual*/ Dictionary CompositeFunction::Serialize() const
    {
        Dictionary dict;

        dict[versionKey] = CurrentVersion();
        dict[typeKey] = s_compositeFunctionTypeValue;
        dict[rootKey] = RootFunction()->Uid();
        dict[nameKey] = Name();
        dict[uidKey] = Uid();

        // Find cycles in the graph and "break" them by inserting placeholders.
        // This needs to be done on Save, since here we have easy access to the shape and
        // dynamic axis info.
        std::unordered_set<FunctionPtr> visitedFunctions;
        std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
        std::vector<Variable> inputs;
        std::unordered_set<std::wstring> inputUids;
        Traverse(RootFunction(), [&visitedFunctions, &inputs, &topoSortedPrimitiveFunctions, &inputUids](const FunctionPtr& function) {
            std::vector<Variable> functionInputs = function->Inputs();
            for (const auto& input : functionInputs)
            {
                auto& uid = input.Uid();
                if (inputUids.find(uid) != inputUids.end())
                {
                    continue;
                }

                // check if this input corresponds to a cyclic edge in the graph.
                bool mustBeReplaced = input.IsOutput() && visitedFunctions.find(input.Owner()) != visitedFunctions.end();

                if (mustBeReplaced)
                {
                    auto varKind = VariableKind::Placeholder;
                    Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(),
                                 input.DynamicAxes(), input.Name(), uid);
                    inputs.push_back(var);
                    inputUids.insert(uid);
                }
                else if (!input.IsOutput())
                {
                    // leave the input as is.
                    inputs.push_back(input);
                    inputUids.insert(uid);
                }
            }

            visitedFunctions.insert(function);
            topoSortedPrimitiveFunctions.push_back(function);
        });

        std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));

        assert(topoSortedPrimitiveFunctions.size() == m_allPrimitiveFunctions.size());
        assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());

        std::vector<DictionaryValue> inputDictionaries;
        inputDictionaries.reserve(inputs.size());
        inputUids.clear();
        for (const auto& input : inputs)
        {
            if (inputUids.find(input.Uid()) != inputUids.end())
            {
                LogicError("Input uids must be unique");
            }
            inputUids.insert(input.Uid());
            inputDictionaries.push_back(input.Serialize());
        }

        dict[inputsKey] = std::move(inputDictionaries);

        std::vector<DictionaryValue> functionDictionaries;
        std::unordered_set<std::wstring> outputUids;
        for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
        {
            for (const auto& output : primitiveFunction->Outputs())
            {
                if (outputUids.find(output.Uid()) != outputUids.end())
                {
                    LogicError("Output uids of all primitive functions in a function graph must be unique");
                }
                outputUids.insert(output.Uid());
            }
            functionDictionaries.push_back(primitiveFunction->Serialize());
        }

        dict[functionsKey] = std::move(functionDictionaries);

        // Now, collect and store the internal state for all non-pure (stateful) functions in the graph
        // (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
        Dictionary stateDictionary;
        for (const auto& kv : m_variableToNodeMap)
        {
            if (kv.second->Is<RngUser>())
            {
                auto rng = kv.second->As<RngUser>();
                Dictionary state;
                state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
                state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
                stateDictionary[kv.first.Owner()->Uid()] = state;
            }
        }

        dict[stateKey] = std::move(stateDictionary);

        return dict;
    }
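    // Note: the dictionary produced by Serialize() above (and consumed by Deserialize() below) uses the
    // versionKey, typeKey, rootKey, nameKey, uidKey, inputsKey, functionsKey and (optionally) stateKey
    // entries; 'inputs' holds the serialized leaf/placeholder Variables and 'functions' holds the
    // topologically sorted primitive functions of the graph.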
    /*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
    {
        static const std::vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, nameKey, uidKey, inputsKey, functionsKey };

        size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);

        const auto& rootUid = dict[rootKey].Value<std::wstring>();
        const auto& name = dict[nameKey].Value<std::wstring>();
        const auto& uid = dict[uidKey].Value<std::wstring>();
        const auto& inputs = dict[inputsKey].Value<std::vector<DictionaryValue>>();

        std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());

        for (const auto& dictionaryValue : inputs)
        {
            const auto& dictionary = dictionaryValue.Value<Dictionary>();
            const auto& inputVar = Variable::Deserialize(dictionary, device);

            if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
            {
                LogicError("Input uids are not unique (several inputs share '%ls' uid) (%s).",
                           inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
            }

            uidToInputMap[inputVar.Uid()] = inputVar;
        }

        Dictionary stateDictionary;
        if (dict.Contains(stateKey))
        {
            stateDictionary = dict[stateKey].Value<Dictionary>();
        }

        const auto& functions = dict[functionsKey].Value<std::vector<DictionaryValue>>();

        FunctionPtr root;
        std::unordered_map<Variable, Variable> placeholderReplacements;
        std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
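        // Reconstruct the graph bottom-up: every primitive function is deserialized against the variables
        // seen so far, and any placeholder that was inserted at save time to break a cycle is later patched
        // back to the actual output variable once the function producing that output has been recreated.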
        for (const auto& dictionaryValue : functions)
        {
            root = PrimitiveFunction::Deserialize(dictionaryValue.Value<Dictionary>(), uidToInputMap, device);
            allPrimitiveFunctions.insert(root);

            auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());

            // Since Combine simply forwards other functions' outputs, all of its outputs
            // should already be in the uidToInputMap.
            auto opType = primitiveFunction->OpType();
            if (opType == PrimitiveOpType::Combine)
            {
                continue;
            }

            if (primitiveFunction->IsStateful())
            {
                if (stateDictionary.Contains(primitiveFunction->Uid()))
                {
                    auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
                    auto seed = state[rngSeedKey].Value<size_t>();
                    auto offset = state[rngOffsetKey].Value<size_t>();
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
                    primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
                }
                else if (Internal::GetComputationNetworkTraceLevel() > 0)
                {
                    // TODO: all logging functionality should be refactored to live in a logging utility class.
                    fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
                                    "when deserializing from a dictionary (version=%zu). "
                                    "Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
                }
            }

            for (const auto& output : root->Outputs())
            {
                const auto& it = uidToInputMap.find(output.Uid());
                if (it != uidToInputMap.end())
                {
                    if (!it->second.IsPlaceholder())
                    {
                        LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls) (%s).",
                                   VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(),
                                   GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
                    }
                    placeholderReplacements[it->second] = output;
                }
                else
                {
                    uidToInputMap[output.Uid()] = output;
                }
            }
        }

        if (root->Uid() != rootUid)
        {
            LogicError("Root UID '%ls' is different from the expected value '%ls'.", root->Uid().c_str(), rootUid.c_str());
        }

        if (placeholderReplacements.size() > 0)
        {
            return CompositeFunction::Create(root->ReplacePlaceholders(placeholderReplacements), name, uid);
        }

        return CompositeFunction::Create(root, name, uid);
    }

    void CompositeFunction::CopyState(const CompositeFunction& source)
    {
        // Create a map with all non-pure (stateful) functions in the function graph.
        auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
            std::map<std::wstring, FunctionPtr> functionMap;
            for (auto funcPtr : allPrimitiveFunctions)
            {
                auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
                if (primitiveFunction->IsStateful())
                {
                    functionMap[primitiveFunction->Uid()] = funcPtr;
                }
            }
            return functionMap;
        };

        std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
        std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);

        assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());

        if (statefulFunctionsFrom.size() == 0)
        {
            return;
        }

        // Copy state captured in the attributes dictionaries.
        for (const auto& kv : statefulFunctionsFrom)
        {
            statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
        }

        UpdateInternalNetworkState();
    }

    void CompositeFunction::UpdateInternalNetworkState()
    {
        if (!m_computationNetwork)
        {
            return;
        }

        for (const auto& function : m_allPrimitiveFunctions)
        {
            auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
            if (primitiveFunction->IsStateful())
            {
                for (const auto& output : function->Outputs())
                {
                    auto node = m_variableToNodeMap[output];
                    auto attributes = function->Attributes();
                    auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
                    auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
                    node->As<RngUser>()->SetRngState(seed, offset);
                }
            }
        }
    }
    // Names of the dynamic axes in the CNTK engine for some special sets of dynamic axes values
    // Note: The no sequence axis corresponds to a special case where there is no sequence axis (i.e. has been reduced over)
    // and the special name is used to identify this when loading back a model saved in CNTK v1 format. This will not really be needed
    // when the new CNTK v2 model serialization format is ready.
    /*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*";
    /*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis";

    // Replace any PlaceHolder Variables in the graph of Functions underlying 'this' CompositeFunction. All PlaceHolder variables
    // should have been replaced before performing any Forward compute of 'this' Function.
    /*virtual*/ void CompositeFunction::ReplacePlaceholdersInPlace(const std::unordered_map<Variable, Variable>& placeholderReplacements,
                                                                   std::unordered_set<const Function*>& visitedFunctions,
                                                                   std::unordered_set<Variable>& replacedPlaceholders)
    {
        RootFunction()->ReplacePlaceholdersInPlace(placeholderReplacements, visitedFunctions, replacedPlaceholders);

        // If any of the placeholders were replaced with Output variables, let's add the graph of functions underneath
        // each of those to the 'm_allPrimitiveFunctions' set
        for (auto replacedPlaceholder : replacedPlaceholders)
        {
            auto replacingVariable = placeholderReplacements.at(replacedPlaceholder);
            if (replacingVariable.IsOutput())
            {
                auto ownerFunc = replacingVariable.Owner();
                std::unordered_set<FunctionPtr> visitedFunctions2;
                Collect(ownerFunc, visitedFunctions2);

                // Add the newly visited functions to the 'm_allPrimitiveFunctions' set
                m_allPrimitiveFunctions.insert(visitedFunctions2.begin(), visitedFunctions2.end());
            }
        }

        std::unordered_map<const Function*, size_t> functionVisitCounts;

        // An arbitrary cap on changing output shape of recurrent nodes, to detect infinite inference loops
        const size_t maxNumValidationPassesAllowed = 25;
        bool recurrentNodeOutputModified = false;
        size_t numValidationPasses = 0;
        do
        {
            recurrentNodeOutputModified = false;
            functionVisitCounts.clear();
            RootFunction()->ValidateOrUpdateOutputs(functionVisitCounts, recurrentNodeOutputModified);
            numValidationPasses++;
        } while (recurrentNodeOutputModified && (numValidationPasses < maxNumValidationPassesAllowed));

        if (numValidationPasses >= maxNumValidationPassesAllowed)
            LogicError("A recurrent node output shape change happened in %d successive validation passes, indicating a potential infinite inference loop!", (int)numValidationPasses);
    }

    // Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions
    // underlying the specified 'variable' and return the ComputationNode instance that corresponds to the
    // top level 'variable'
    template <typename ElementType>
    /*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
                                                                 Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
                                                                 ComputationNetworkBuilder<ElementType>& builder,
                                                                 std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
                                                                 std::unordered_map<Variable, bool>& isVariableRootMap)
    {
        auto iter = variableToNodeMap.find(variable);
        if (iter != variableToNodeMap.end())
        {
            isVariableRootMap[variable] = false;
            return iter->second;
        }

        // The DataType, Shape and DynamicAxes of the variable must be known by now
        if (variable.GetDataType() == DataType::Unknown)
            InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().IsUnknown())
            InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());

        if (variable.Shape().HasInferredDimension())
            InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!",
ParanthesizedName(variable.Name()).c_str()); if (variable.DynamicAxes() == Axis::UnknownDynamicAxes()) InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str()); // Lets add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs variableToNodeMap[variable] = nullptr; std::shared_ptr> computationNodePtr; if (variable.IsParameter() || variable.IsConstant()) { auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape())); network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later if (!variable.NeedsGradient()) computationNodePtr->SetLearningRateMultiplier(0.0); NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value(); std::shared_ptr> valueMatrix = variable.IsConstant() ? value->GetMatrix() : value->GetWritableMatrix(); if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId())) computationNodePtr->Value() = valueMatrix->AsReference(); else { Matrix clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat()); clonedMatrix.AssignValuesOf(*valueMatrix); computationNodePtr->Value() = std::move(clonedMatrix); } } else if (variable.IsInput()) { auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name()); // TODO: Input variables currently are required to have the default batch axis auto dynamicAxes = variable.DynamicAxes(); auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis()); if (foundDefaultBatchAxis == dynamicAxes.end()) LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes"); if (dynamicAxes.back() != Axis::DefaultBatchAxis()) LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes"); // TODO: Support inputs with > 1 dynamic axes if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2)) LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported"); // Construct the dynamic axis name to be used internally for the CNTK InputNodes std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName)) network->AddNodeToNetAndAttachInputs(New>(network->GetDeviceId(), internalDynamicAxisName), {}); if (IsSparseInput(variable)) computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); else computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName); if (variable.NeedsGradient()) { // Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default // gradients are not computed for Input nodes computationNodePtr->SetLearningRateMultiplier(0.00001f); } } else { assert(variable.IsOutput()); computationNodePtr = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap)->template As>()->shared_from_this(); } variableToNodeMap[variable] = computationNodePtr; if 
(isVariableRootMap.find(variable) == isVariableRootMap.end()) isVariableRootMap[variable] = variable.IsOutput(); return computationNodePtr; } template /*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable, Function* function, const std::vector>>& inputNodes, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, std::unordered_map& variableToNodeMap) { ComputationNodeBasePtr computationNodePtr; auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name()); std::vector inputNodesBasePtrs; for (auto inputNode : inputNodes) inputNodesBasePtrs.push_back(inputNode); PrimitiveFunction* primitiveFunction = dynamic_cast(function); if (primitiveFunction) { auto functionInputs = function->Inputs(); auto& functionConfig = function->Attributes(); PrimitiveOpType op = primitiveFunction->OpType(); switch (op) { case PrimitiveOpType::Negate: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sigmoid: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Tanh: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Cos: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sin: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ReLU: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Exp: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Log: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Sqrt: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Floor: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Abs: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Reciprocal: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Softmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Hardmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::TransposeAxes: { auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value(); auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value(); // The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2)); break; } case PrimitiveOpType::Where: { auto dynamicAxes = variable.DynamicAxes(); auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName); break; } case PrimitiveOpType::Slice: { auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value(); auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value(); // Internal CNTK SliceNode takes 1 based axis indices instead of 0 based computationNodePtr = New>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis)); break; } case 
PrimitiveOpType::RandomSample: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::RandomSampleInclusionFrequency: { auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value(); auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates); break; } case PrimitiveOpType::Dropout: { auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName); computationNodePtr->As>()->SetDropoutRate(dropoutRate); break; } case PrimitiveOpType::Reshape: { auto newShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(newShape)); break; } case PrimitiveOpType::ROIPooling: { auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape)); break; } case PrimitiveOpType::Pooling: { PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value()); auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value(); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW); break; } case PrimitiveOpType::SumAll: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Plus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LogPlus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Minus: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ElementTimes: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Equal: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::NotEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Less: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::LessEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Greater: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GreaterEqual: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Times: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); auto inferInputRankToMap = 
functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap); break; } case PrimitiveOpType::TransposeTimes: { size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, outputRank); break; } case PrimitiveOpType::Convolution: { NDShape outputMapCount, kernelShape; std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape()); auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value(); auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value(); auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value(); auto sharing = AsVector(functionConfig[PrimitiveFunction::AttributeNameSharing].Value>()); auto autoPadding = AsVector(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value>()); auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value(); auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples); break; } case PrimitiveOpType::CosDistance: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Logistic: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::SquaredError: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::CrossEntropyWithSoftmax: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ClassificationError: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::PastValue: case PrimitiveOpType::FutureValue: { Variable inputOperandVar = functionInputs[0]; Variable initialStateVar = functionInputs[1]; size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value(); if (op == PrimitiveOpType::PastValue) computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); else computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset); break; } case PrimitiveOpType::ReduceElements: { auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis)); break; } case PrimitiveOpType::BatchNormalization: { auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value(); auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value(); auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value(); auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value(); auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value(); computationNodePtr = 
New>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW); break; } case PrimitiveOpType::Combine: // This operation is just a no-op and is a means to combine multiple functions to create a single Function // whose outputs are a union of the outputs of the Functions being combined. computationNodePtr = variableToNodeMap[variable]; break; case PrimitiveOpType::PackedIndex: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::GatherPacked: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::ScatterPacked: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Clip: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Select: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; case PrimitiveOpType::Splice: { Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis)); break; } case PrimitiveOpType::OptimizedRNNStack: { auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value(); auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value(); auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value(); auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value(); computationNodePtr = New>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp); break; } case PrimitiveOpType::ReconcileDynamicAxis: { computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::LogSoftmax: { //This can be implemented as x => x - ReduceLogSum(x). How to do this here? 
computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; } case PrimitiveOpType::Pass: computationNodePtr = New>(network->GetDeviceId(), internalNodeName); break; default: LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str()); break; } // Let's reorder inputNodesBasePtrs properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs); if (computationNodePtr->Is()) { auto computationNodeExpectedInputCount = computationNodePtr->As()->GetExpectedNumInputs(); if (computationNodeExpectedInputCount != inputNodesBasePtrs.size()) LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs", PrimitiveOpTypeName(op).c_str(), (int)inputNodesBasePtrs.size(), (int)computationNodeExpectedInputCount); } if (computationNodePtr->Is()) { if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed)) { auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value(); unsigned long long offset = 0; if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset)) { offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value(); } computationNodePtr->As()->SetRngState(seed, offset); } } } else computationNodePtr = New>(network->GetDeviceId(), internalNodeName, function->shared_from_this()); network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs); return computationNodePtr; } template /*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable, Microsoft::MSR::CNTK::ComputationNetworkPtr& network, ComputationNetworkBuilder& builder, std::unordered_map& variableToNodeMap, std::unordered_map& isVariableRootMap) { assert(variable.IsOutput()); Function* function = variable.Owner().get(); ComputationNodeBasePtr computationNodePtr; auto& functionInputs = function->m_inputs; DataType nonConstInputDataType = DataType::Unknown; for (auto& inputVar : functionInputs) { if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown)) { nonConstInputDataType = inputVar.GetDataType(); break; } } // Create the nodes corresponding to the inputs std::vector>> inputNodes; for (auto& inputVar : functionInputs) { // If the inputVar is a constant and not the right DataType let's coerce it to the right type if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType)) { auto originalConstantValue = Constant(inputVar).Value(); auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true); NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true); inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name()); } auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap); inputNodes.push_back((baseNodePtr != nullptr) ? 
baseNodePtr->template As>()->shared_from_this() : nullptr); } computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap); PrimitiveFunction* primitiveFunction = dynamic_cast(function); if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine)) { for (auto inputVar : functionInputs) isVariableRootMap[inputVar] = false; } return computationNodePtr; } template ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device, const std::unordered_set& backpropRoots, const std::unordered_set& outputs, bool allocateNetworkMatrices) { if (m_computationNetwork != nullptr) { // TODO: We should either invalidate and readapt the network if he backpropRoots change compared to what was specified when the network // was last constructed, to just recreate a new network. // For now just disallow changing the backpropRoots after the network is created if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots)) LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported"); // TODO: Support changing the device across different invocations of the forward method on a Function instance if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device) LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported"); } else { m_computationNetwork = std::make_shared(AsCNTKImplDeviceId(device)); ComputationNetworkBuilder builder(*m_computationNetwork); // TODO: We currently only support one backprop root if (backpropRoots.size() > 1) LogicError("More than one backprop roots is currently unsupported"); auto placeholders = Placeholders(); if (!placeholders.empty()) InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!"); // Now recursively create the network in a top-down fashion auto rootFunction = RootFunction(); auto rootFunctionOutputs = rootFunction->Outputs(); for (auto rootOutput : rootFunctionOutputs) GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap); // If any of the function outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork for (auto rootOutput : rootFunctionOutputs) { if (!m_isVariableRootMap[rootOutput]) m_computationNetwork->AddToNodeGroup(L"output", m_variableToNodeMap[rootOutput]); } // If any of the requested outputs is not a root node, we need to explicitly add it to the 'output' group of the ComputationNetwork for (auto output : outputs) { if (!m_isVariableRootMap[output]) { auto computationNode = m_variableToNodeMap[output]; if (!computationNode) InvalidArgument("One of the requested outputs for the Function forward computation is not part of the graph underlying the Function"); m_computationNetwork->AddToNodeGroup(L"output", computationNode); } } m_currentBackpropRoots = backpropRoots; // In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles. 
// Now attach those after we have created all ComputationNodes in the network for (auto varNodePair : m_variableToNodeMap) { auto& currentComputationNode = varNodePair.second; auto& currentComputationNodeInputs = currentComputationNode->GetInputs(); auto& currentVar = varNodePair.first; if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end()) { // This ComputationNode has at least one null input which now needs to be properly attached const PrimitiveFunction* primitiveFunc = dynamic_cast(currentVar.Owner().get()); // Let's reorder properly since the ordering of inputs of CNTK internal ComputationNode may be different from the PrimitiveFunction inputs ordering auto inputVars = primitiveFunc->Inputs(); ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars); inputVars.resize(currentComputationNode->GetNumInputs()); std::vector inputNodesBasePtrs; for (auto inputVar : inputVars) inputNodesBasePtrs.push_back(m_variableToNodeMap[inputVar]); currentComputationNode->AttachInputs(inputNodesBasePtrs); } } m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel()); m_computationNetwork->CompileNetwork(); // Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork for (auto varNodePair : m_variableToNodeMap) { if (varNodePair.first.IsOutput()) { auto outputVar = varNodePair.first; auto computationNodePtr = m_variableToNodeMap[outputVar]; auto outputShape = outputVar.Shape(); auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout(); if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) || ((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape)))) { LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str()); } } } // Record the timestamps of Parameter values assert(m_lastRecordedParameterValueTimeStamps.empty()); auto functionParameters = Parameters(); for (auto parameter : functionParameters) m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() }); } if (!m_networkMatricesAllocated && allocateNetworkMatrices) { ComputationNodeBasePtr backpropRootNode; // Now recursively traverse the network in a top-down fashion auto rootFunction = RootFunction(); auto rootFunctionOutputs = rootFunction->Outputs(); std::vector forwardRootNodes; for (auto rootOutput : rootFunctionOutputs) { auto currentRootNode = m_variableToNodeMap[rootOutput]; forwardRootNodes.push_back(currentRootNode); if (m_currentBackpropRoots.find(rootOutput) != m_currentBackpropRoots.end()) backpropRootNode = currentRootNode; } std::vector forwardOutputNodes; for (auto output : outputs) { auto currentOutputNode = m_variableToNodeMap[output]; forwardOutputNodes.push_back(currentOutputNode); if (m_currentBackpropRoots.find(output) != m_currentBackpropRoots.end()) backpropRootNode = currentOutputNode; } m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode); m_networkMatricesAllocated = allocateNetworkMatrices; m_currentOutputs = outputs; m_currentOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end()); m_currentOutputs.insert(m_currentBackpropRoots.begin(), 
                                    m_currentBackpropRoots.end());
        }
        else
        {
            // Make sure the requested outputs are a subset of the outputs we set up the current matrix allocation structure for
            // in the cached computation network
            for (auto output : outputs)
            {
                if (m_currentOutputs.find(output) == m_currentOutputs.end())
                    LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported");
            }
        }

        return m_computationNetwork;
    }

    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode)
    {
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout;
        auto packedValue = dynamic_cast<PackedValue*>(variableValue.second.get());
        if (packedValue)
            CNTKMatrixAndMBLayout = packedValue->PackedData<ElementType>();
        else
            CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second);

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;

        auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();

        // Switch the node matrix to the right matrix type
        nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
        computationNode->GetMBLayout()->CopyFrom(layout);
    }

    void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
    {
        std::vector<ComputationNodeBasePtr> inputNodes;
        for (auto argumentValuePair : arguments)
        {
            auto argument = argumentValuePair.first;
            auto argumentComputationNode = m_variableToNodeMap[argument];
            assert(argumentComputationNode);
            inputNodes.push_back(argumentComputationNode);

            ValuePtr argumentValue = arguments.at(argument);

            MBLayoutPtr layout;
            switch (argumentValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType()));
                break;
            }
        }

        m_computationNetwork->BumpEvalTimeStamp(inputNodes);
    }
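    // Assign a user-supplied gradient Value to the corresponding ComputationNode. The gradient must use the
    // same minibatch layout that the node's value was computed with during the preceding Forward call.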
    template <typename ElementType>
    /*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
    {
        std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout;
        auto packedValue = dynamic_cast<PackedValue*>(variableGradient.second.get());
        if (packedValue)
            CNTKMatrixAndMBLayout = packedValue->PackedData<ElementType>();
        else
            CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second);

        MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
        auto nodeLayout = computationNode->GetMBLayout();
        if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
            InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call");

        computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
    }

    // Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
    void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto functionOutputs = this->Outputs();
        for (auto gradientVarValuePair : gradients)
        {
            auto outputComputationNode = m_variableToNodeMap[gradientVarValuePair.first];
            ValuePtr gradientValue = gradientVarValuePair.second;

            switch (gradientValue->GetDataType())
            {
            case DataType::Float:
                PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
                break;
            case DataType::Double:
                PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType()));
                break;
            }
        }
    }

    static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr)
    {
        size_t outputValueNumAxes = var.Shape().Rank();

        // Add the batch and dynamic axes if needed
        if (computationNodePtr->GetMBLayout() != nullptr)
            outputValueNumAxes += 2;

        std::vector<size_t> outputShapeDims(outputValueNumAxes);
        for (size_t i = 0; i < var.Shape().Rank(); ++i)
            outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i);

        if (computationNodePtr->GetMBLayout() != nullptr)
        {
            outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps();
            outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences();
        }

        return NDShape(outputShapeDims);
    }
    /*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
    {
        auto valueShape = GetValueShape(var, computationNode);
        if (varValue != nullptr)
        {
            // TODO: The shape of the specified output Value object must match the actual output shape
            if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
                InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S",
                                AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str());
        }

        ValuePtr nodeValue;
        auto layout = computationNode->GetMBLayout();
        switch (var.GetDataType())
        {
        case DataType::Float:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, matrix, layout);
            break;
        }
        case DataType::Double:
        {
            auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
            if (varValue == nullptr)
                nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
            else
                nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, matrix, layout);
            break;
        }
        default:
            LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType()));
            break;
        }

        if (varValue == nullptr)
            varValue = nodeValue;
        else
            varValue->CopyFrom(*nodeValue);
    }

    void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
    {
        // Now copy the Forward values of output nodes from the network to the outputs' Value objects
        for (auto outputVarValuePair : outputs)
            GetNodeOutputOrGradient(outputVarValuePair.first, outputs[outputVarValuePair.first], m_variableToNodeMap[outputVarValuePair.first], false /*getGradient*/);
    }

    void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
    {
        auto networkInputs = this->Inputs();

        // Now copy the gradient values of input nodes of the network to the gradients' Value objects
        for (auto gradientVarValuePair : gradients)
        {
            // Only gradients corresponding to inputs of the network can be obtained
            if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
                InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function");

            // Gradients can only be obtained for parameter variables or input variables that NeedsGradient
            if (!gradientVarValuePair.first.NeedsGradient())
                InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, or an Input Variable with NeedsGradient setting of false");

            auto computationNodePtr = m_variableToNodeMap[gradientVarValuePair.first];

            if (!computationNodePtr->NeedsGradient())
                LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");

            GetNodeOutputOrGradient(gradientVarValuePair.first, gradients[gradientVarValuePair.first], computationNodePtr, true /*getGradient*/);
        }
    }

    const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
    {
        assert(output.IsOutput());

        auto iter = m_perOutputVarArgumentDependencies.find(output);
        if (iter != m_perOutputVarArgumentDependencies.end())
            return iter->second;

        auto wrappedComposite = CompositeFunction::Create(output.Owner());
        m_perOutputVarArgumentDependencies[output] = wrappedComposite->Arguments();

        return m_perOutputVarArgumentDependencies[output];
    }

    std::unordered_map<Variable, int64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
    {
        std::unordered_map<Variable, int64_t> currentBackpropRootsTimeStamps;
        assert(m_computationNetwork != nullptr);

        for (auto& backpropRoot : m_currentBackpropRoots)
            currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp();

        return currentBackpropRootsTimeStamps;
    }
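    // Forward evaluation of the composite: build (or reuse) the internal ComputationNetwork for the element
    // type implied by the supplied Values, feed the argument Values into the network's input nodes, run
    // ForwardProp on the requested output nodes, copy the results back into the outputs map, and return a
    // BackPropState if any outputs were requested to retain state for a subsequent Backward call.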
    /*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
                                                            std::unordered_map<Variable, ValuePtr>& outputs,
                                                            const DeviceDescriptor& computeDevice,
                                                            const std::unordered_set<Variable>& outputsToRetainBackwardStateFor)
    {
        // Validate arguments and outputs
        if (outputs.empty())
            InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!");

        // Make sure that the DataType of the variables and the corresponding values match
        // TODO: We need a better way to determine the ElementType for the network
        auto dataType = DataType::Unknown;
        for (auto variableValuePair : arguments)
        {
            if (dataType == DataType::Unknown)
                dataType = variableValuePair.first.GetDataType();
            else if (dataType != variableValuePair.first.GetDataType())
                LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be the same");
        }

        if (dataType == DataType::Unknown)
        {
            for (auto variableValuePair : outputs)
            {
                if (dataType == DataType::Unknown)
                    dataType = variableValuePair.first.GetDataType();
            }
        }

        std::unordered_set<Variable> requestedOutputVariables;
        for (auto output : outputs)
            requestedOutputVariables.insert(output.first);

        if (dataType == DataType::Float)
            GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else if (dataType == DataType::Double)
            GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, true);
        else
            InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));

        std::unordered_set<Variable> functionOutputs(this->Outputs().begin(), this->Outputs().end());
        std::vector<ComputationNodeBasePtr> outputsToEvaluate;
        std::unordered_set<Variable> requiredArguments;
        for (auto outputVarValuePair : outputs)
        {
            auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first);
            requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());

            auto outputComputationNode = m_variableToNodeMap[outputVarValuePair.first];
            outputsToEvaluate.push_back(outputComputationNode);
        }

        // TODO: Avoid copying the data when possible

        // We should have argument values supplied for all required argument dependencies of the requested outputs
        for (auto requiredArgument : requiredArguments)
        {
            if (arguments.find(requiredArgument) == arguments.end())
                InvalidArgument("Function::Forward: Required argument's (%S) value that the requested output(s) depend on has not been provided", requiredArgument.Name().c_str());
        }

        // Feed data into the arguments of the network
        PopulateNetworkInputs(arguments);

        // Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input.
        // This mask is regenerated every minibatch, and hence dropout nodes with a non-zero dropout rate must be marked outdated
        // w.r.t. their inputs to force evaluation in each minibatch.
        std::list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
        for (auto& nodeIter : dropoutNodes)
            nodeIter->SetEvalTimeStampOutdatedWrtAll();

        // Bump the timestamp of the parameter nodes whose values have changed
        for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps)
        {
            auto parameter = paramTimeStampRecord.first;
            auto prevTimeStamp = paramTimeStampRecord.second;
            auto newTimeStamp = parameter.CurrentValueTimeStamp();
            if (newTimeStamp > prevTimeStamp)
            {
                paramTimeStampRecord.second = newTimeStamp;
                m_variableToNodeMap[parameter]->BumpEvalTimeStamp();
            }
        }

        // The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
        for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
        {
            if (outputs.find(rootVarForBackprop) == outputs.end())
                outputsToEvaluate.push_back(m_variableToNodeMap[rootVarForBackprop]);
        }

        // Reset the timestamps of all backward roots to record an update in one or more inputs
        for (auto& backpropRoot : m_currentBackpropRoots)
            m_variableToNodeMap[backpropRoot]->SetEvalTimeStampOutdatedWrtAll();

        // TODO: Verify that values were supplied for all inputs that the requested outputs depend on

        ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);

        m_computationNetwork->ForwardProp(outputsToEvaluate);

        GetNetworkOutputs(outputs);

        // TODO: How to deal with the specified 'computeDevice'
        Variable evalTimeStampVariable;
        if (arguments.empty())
            evalTimeStampVariable = Inputs()[0];
        else
            evalTimeStampVariable = arguments.begin()->first;

        BackPropStatePtr backpropStatePtr;
        if (outputsToRetainBackwardStateFor.size() > 0)
            backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());

        return backpropStatePtr;
    }
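    // Backpropagation entry point. Note that only a single outstanding backprop state is supported: the state
    // passed in must correspond to the most recent Forward call on this Function (enforced below by comparing
    // the recorded timestamps of the backprop root nodes).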
    /*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
                                                 const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
                                                 std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
    {
        auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
        if (backpropState == nullptr)
            InvalidArgument("Invalid backprop state specified");

        // TODO: Support multiple concurrent backprop states
        std::unordered_map<Variable, int64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
        if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
            LogicError("The specified backprop state cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function. "
                       "This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");

        if (rootGradientValues.size() > 1)
            LogicError("Currently gradient backprop from only one of the Function Outputs is supported");

        // TODO: Avoid copying the data when possible

        // Zero all gradients of nodes below the root nodes
        for (auto rootGradientVarValuePair : rootGradientValues)
            m_computationNetwork->ZeroInputGradients(m_variableToNodeMap[rootGradientVarValuePair.first]);

        // Feed data into the arguments of the network
        PopulateNetworkGradients(rootGradientValues);

        // Backpropagate through the network
        ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);

        auto rootComputationNodePtr = m_variableToNodeMap[rootGradientValues.begin()->first];
        m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);

        GetNetworkGradients(backPropagatedGradientValuesForInputs);

        // TODO: How to deal with the specified 'computeDevice'
    }
}