https://github.com/Microsoft/CNTK
Tip revision: 2528557bd1e538a0b35f5335a5ba124b78dc47a9 authored by Jian Jiao on 04 October 2017, 18:20:45 UTC
add forward
CompositeFunction.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"
#include "SpecialPurposeNodes.h"
#include "SequenceReshapeNodes.h"
#include "UserDefinedFunction.h"
using namespace Microsoft::MSR::CNTK;
namespace CNTK
{
/*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
/*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);
static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";
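// Serialize only the composite-level metadata (version, type, root function uid,
// optional name, and uid); the inputs and primitive functions themselves are
// appended by Serialize() below.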
Dictionary CompositeFunction::SerializeBlockComposite() const
{
Dictionary dict;
dict[versionKey] = CurrentVersion();
dict[typeKey] = s_compositeFunctionTypeValue;
dict[rootKey] = RootFunction()->Uid();
if (!Name().empty())
dict[nameKey] = Name();
dict[uidKey] = Uid();
return dict;
}
// Copy the internal state from the network into the function graph,
// specifically from RngUser nodes into the attributes dictionaries of
// the corresponding stateful primitive functions.
void CompositeFunction::UpdateInternalState() const
{
if (!m_computationNetwork)
return;
for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
// TODO: same for BatchNorm
auto& outputs = primitiveFunction->RawOutputs();
if (outputs.size() != 1)
LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());
const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();
Dictionary state;
state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
primitiveFunction->SetState(state);
}
}
// Generate a dictionary representing the internal (local) state of the function graph.
Dictionary CompositeFunction::GetInternalState() const
{
UpdateInternalState();
Dictionary stateDictionary;
for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
// TODO: same for BatchNorm
stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
}
return stateDictionary;
}
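// Serialize the full composite: the metadata from SerializeBlockComposite(),
// the unique inputs (with cyclic edges broken by placeholders), and all
// primitive functions in reverse topological order.
//
// A minimal round-trip sketch (assuming the public Function::Deserialize
// overload declared in CNTKLibrary.h; illustration only):
//     FunctionPtr f = ...;                                 // some composite
//     Dictionary d = f->Serialize();
//     FunctionPtr g = Function::Deserialize(d, DeviceDescriptor::CPUDevice());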
/*virtual*/ Dictionary CompositeFunction::Serialize() const
{
UpdateInternalState();
Dictionary dict = SerializeBlockComposite();
// Find cycles in the graph and "break" them by inserting placeholders.
// This needs to be done on Save, since here we have easy access to the shape and
// dynamic axis info.
std::unordered_set<FunctionPtr> visitedFunctions;
std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
std::vector<Variable> uniqueInputs;
std::unordered_set<std::wstring> inputUids;
std::function<void(const FunctionPtr& function)> SerializationTraversalFunc;
SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
std::vector<Variable> functionInputs = function->Inputs();
for (const auto& input : functionInputs)
{
auto& uid = input.Uid();
if (inputUids.find(uid) != inputUids.end())
continue;
// check if this input corresponds to a cyclic edge in the graph.
// BUG: A function being visited twice does not necessarily indicate a cyclic edge in the graph;
// it may simply mean that at least two successors in the graph have the function as an input.
bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());
if (mustBeReplaced)
{
auto varKind = VariableKind::Placeholder;
Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
uniqueInputs.push_back(var);
inputUids.insert(uid);
}
else if (!input.IsOutput())
{
// leave the input as is.
uniqueInputs.push_back(input);
inputUids.insert(uid);
}
}
visitedFunctions.insert(function);
topoSortedPrimitiveFunctions.push_back(function);
// For block functions we need to recursively traverse the underlying composite
if (function->IsBlock())
PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
};
PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);
std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));
assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());
std::vector<DictionaryValue> inputDictionaries;
inputDictionaries.reserve(uniqueInputs.size());
inputUids.clear();
for (const auto& input : uniqueInputs)
{
if (inputUids.find(input.Uid()) != inputUids.end())
LogicError("Function '%S' Serialize: Input uids must be unique.", AsString().c_str());
inputUids.insert(input.Uid());
inputDictionaries.push_back(input.Serialize());
}
dict[inputsKey] = std::move(inputDictionaries);
std::vector<DictionaryValue> functionDictionaries;
std::unordered_set<std::wstring> outputUids;
for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
{
for (const auto& output : primitiveFunction->RawOutputs())
{
if (outputUids.find(output.Uid()) != outputUids.end())
LogicError("Function '%S' Serialize: Output uids of all primitive functions in a function graph must be unique", AsString().c_str());
outputUids.insert(output.Uid());
}
auto functionDict = UDFUtils::IsUDF(primitiveFunction) ? UDFUtils::Serialize(primitiveFunction) : primitiveFunction->Serialize();
functionDictionaries.push_back(functionDict);
}
dict[functionsKey] = std::move(functionDictionaries);
return dict;
}
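// Deserialize the composite-level metadata and stitch the already deserialized
// primitive functions back together, applying the relevant subset of
// placeholder replacements until a fixed point is reached.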
/*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);
const auto& rootUid = dict[rootKey].Value<std::wstring>();
std::wstring name = L"";
if (dict.Contains(nameKey))
name = dict[nameKey].Value<std::wstring>();
const auto& uid = dict[uidKey].Value<std::wstring>();
FunctionPtr root = *std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(), [&rootUid](const FunctionPtr& func) {
return func->Uid() == rootUid;
});
// Find the subset of placeholder replacements that apply for this composite
FunctionPtr composite = CompositeFunction::Create(root, name, uid);
std::unordered_map<Variable, Variable> placeholderReplacements;
do
{
placeholderReplacements.clear();
auto compositePlaceholders = composite->Placeholders();
for (auto placeholder : compositePlaceholders)
{
if (allPlaceholderReplacements.find(placeholder) != allPlaceholderReplacements.end())
placeholderReplacements.insert({ placeholder, allPlaceholderReplacements.at(placeholder) });
}
if (placeholderReplacements.size() > 0)
composite = composite->ReplacePlaceholders(placeholderReplacements);
} while (placeholderReplacements.size() > 0);
return composite;
}
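// Reconstruct a composite from a dictionary produced by Serialize(): first
// deserialize the inputs, then the primitive functions (recording outputs that
// must replace previously created placeholders), and finally the block
// composite itself.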
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);
const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();
std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
for (const auto& dictionaryValue : inputs)
{
const auto& dictionary = dictionaryValue.Value<Dictionary>();
const auto& inputVar = Variable::Deserialize(dictionary, device);
if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
{
CNTK::LogicError("CompositeFunction::Deserialize: Input uids are not unique (several inputs share '%ls' uid) (%s).",
inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
}
uidToInputMap[inputVar.Uid()] = inputVar;
}
const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();
std::unordered_map<Variable, Variable> allPlaceholderReplacements;
std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
for (const auto& dictionaryValue : functions)
{
auto functionDict = dictionaryValue.Value<Dictionary>();
FunctionPtr root = UDFUtils::IsUDF(functionDict) ?
UDFUtils::Deserialize(functionDict, uidToInputMap, device) :
PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
allPrimitiveFunctions.insert(root);
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
if (primitiveFunction != nullptr && primitiveFunction->OpType() == PrimitiveOpType::Combine)
{
// Since Combine simply forwards other functions' outputs, all of its outputs
// should already be in the uidToInputMap.
continue;
}
for (const auto& output : root->RawOutputs())
{
const auto& it = uidToInputMap.find(output.Uid());
if (it != uidToInputMap.end())
{
if (!it->second.IsPlaceholder())
{
CNTK::LogicError("CompositeFunction::Deserialize: Unexpected variable '%S' instead of a Placeholder (uid = %ls) (%s).",
it->second.AsString().c_str(), it->second.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
}
allPlaceholderReplacements[it->second] = output;
}
else
{
uidToInputMap[output.Uid()] = output;
}
}
}
// Starting with serialization version 3, the state is preserved inside the attribute dictionaries of the
// corresponding primitive functions. Earlier versions used a dedicated key-value pair in the composite function dictionary.
if (version < 3)
RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);
return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
}
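// For dictionaries with serialization version < 3, restore the RNG state of
// stateful primitive functions from the dedicated state dictionary (or
// regenerate it if missing).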
void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
{
Dictionary stateDictionary;
if (dict.Contains(stateKey))
stateDictionary = dict[stateKey].Value<Dictionary>();
for (auto& function : functions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
if (stateDictionary.Contains(primitiveFunction->Uid()))
{
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
// Add key-value pairs expected by the SetState method to the state dictionary.
state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
primitiveFunction->SetState(state);
}
else
{
if (GetTraceLevel() >= TraceLevel::Warning)
{
// TODO: all logging functionality should be refactored to live in a logging utility class.
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
"when deserializing from a dictionary (version=%zu). "
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
}
// Create state from scratch, so that function attributes contain all the required key-value pairs.
Dictionary state;
state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed(true);
state[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
primitiveFunction->SetState(state);
}
}
}
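// Copy the RNG state of all stateful functions from 'source' into this
// function graph, remapping by traversal order when the two graphs use
// different UIDs.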
void CompositeFunction::CopyState(const CompositeFunction& source)
{
// Collect a vector of stateful function uids using a pre-order traversal of the function graph.
auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
vector<wstring> uids;
PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
if (primitiveFunction && primitiveFunction->IsStateful())
{
uids.push_back(funcPtr->Uid());
}
}, true);
return uids;
};
auto theirUIDs = collectStatefulFunctionUIDs(source);
auto ourUIDs = collectStatefulFunctionUIDs(*this);
if (theirUIDs.size() != ourUIDs.size())
CNTK::LogicError("Cannot copy internal state, the source and the destination contain different number of stateful functions.");
auto state = source.GetInternalState();
if (theirUIDs == ourUIDs)
{
// uids are identical, no need to remap.
SetInternalState(state);
return;
}
// Build a map from the source function UIDs to the destination (this) function UIDs.
map<wstring, wstring> uidMap;
for (auto i = 0; i < theirUIDs.size(); i++)
uidMap[theirUIDs[i]] = ourUIDs[i];
Dictionary remappedState;
for (auto& kv : state)
remappedState[uidMap[kv.first]] = kv.second;
SetInternalState(remappedState);
}
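// Apply a previously captured state dictionary to the stateful primitive
// functions of this graph and, if a computation network has already been
// compiled, propagate the RNG state into the corresponding network nodes.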
void CompositeFunction::SetInternalState(const Dictionary& state)
{
if (state.Size() == 0)
return;
for (const auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
primitiveFunction->SetState(functionState);
if (!m_computationNetwork)
continue;
auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
// copy the state directly into the network
for (const auto& output : function->RawOutputs())
{
auto node = m_variableToNodeMap.at(output);
node->As<RngUser>()->SetRngState(seed, offset);
}
}
}
// Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions
// underlying the specified 'variable' and return the ComputationNode instance that corresponds to the
// top level 'variable'
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool useMangledNamesForComputationNodes)
{
auto iter = variableToNodeMap.find(variable);
if (iter != variableToNodeMap.end())
{
isVariableRootMap[variable] = false;
return iter->second;
}
// The DataType, Shape and DynamicAxes of the variable must be known by now
if (variable.GetDataType() == DataType::Unknown)
InvalidArgument("Variable '%S' with unknown DataType found when compiling the Function graph.", variable.AsString().c_str());
if (variable.Shape().IsUnknown())
InvalidArgument("Variable '%S' with unknown shape found when compiling the Function graph.", variable.AsString().c_str());
if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
InvalidArgument("Variable '%S' with unknown dynamic axes found when compiling the Function graph.", variable.AsString().c_str());
// Let's add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
variableToNodeMap[variable] = nullptr;
std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes);
if (variable.IsParameter() || variable.IsConstant())
{
if (variable.Shape().HasInferredDimension())
InvalidArgument("Parameter or Constant '%S' with unresolved shape %S found when compiling the Function graph.", variable.AsString().c_str(), variable.Shape().AsString().c_str());
computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
computationNodePtr->SetLearningRateMultiplier(0.0);
NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
computationNodePtr->Value() = valueMatrix->AsReference();
else // Constant: if the initialized data lives on the wrong device, make a copy on the right one (copying is OK since it's constant)
{
// TODO: the following two lines are a workaround for a bug in the Math library
// (AssignValuesOf throws when source and destination matrices reside on different GPU devices).
// Once this bug is fixed, change to
// Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
Matrix<ElementType> clonedMatrix(network->GetDeviceId());
clonedMatrix.SwitchToMatrixType(valueMatrix->GetMatrixType(), valueMatrix->GetFormat(), false);
clonedMatrix.AssignValuesOf(*valueMatrix);
computationNodePtr->Value() = std::move(clonedMatrix);
}
}
else if (variable.IsInput())
{
auto fullyDefinedArgumentVar = variable;
if (fullyDefinedArgumentVar.Shape().HasFreeDimension() && (fullyDefinedArgumentsMap.find(fullyDefinedArgumentVar) != fullyDefinedArgumentsMap.end()))
fullyDefinedArgumentVar = fullyDefinedArgumentsMap.at(fullyDefinedArgumentVar);
if (fullyDefinedArgumentVar.Shape().HasUnboundDimension())
InvalidArgument("Input Variable '%S' with unresolved shape %S found when compiling the Function graph.", fullyDefinedArgumentVar.AsString().c_str(), fullyDefinedArgumentVar.Shape().AsString().c_str());
// TODO: Input variables currently are required to have the default batch axis
auto dynamicAxes = variable.DynamicAxes();
auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
if (IsSparseInput(variable) && (foundDefaultBatchAxis == dynamicAxes.end()))
CNTK::LogicError("Sparse Input Variable '%S' found without a DefaultBatchAxis dynamic axis; this is currently unsupported.", variable.AsString().c_str());
if (!dynamicAxes.empty() && (dynamicAxes.back() != Axis::DefaultBatchAxis()))
CNTK::LogicError("Input Variable '%S' does not have the DefaultBatchAxis as its last dynamic axis.", variable.AsString().c_str());
// TODO: Support inputs with > 1 dynamic axes
if (dynamicAxes.size() > 2)
CNTK::LogicError("Input Variable '%S' has %d dynamic axes; currently only inputs with <= 2 dynamic axes are supported.",
variable.AsString().c_str(), (int)dynamicAxes.size());
if (!dynamicAxes.empty())
{
// Construct the dynamic axis name to be used internally for the CNTK InputNodes
std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});
if (IsSparseInput(variable))
computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName);
else
computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName);
if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end()))
{
// Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default
// gradients are not computed for Input nodes
computationNodePtr->SetLearningRateMultiplier(0.00001f);
}
}
else
{
computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()));
network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
computationNodePtr->SetLearningRateMultiplier(0.0);
}
if (variable.Shape().HasFreeDimension())
computationNodePtr->MarkNeedsDynamicValidation();
}
else
{
assert(variable.IsOutput());
auto outputVariableNode = GetOutputVariableNode(variable, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
// Can be null in the case of loops with f.output == f.input.
// Such loops cannot be handled, so we leave nullptr as the computation node.
if (outputVariableNode)
computationNodePtr = outputVariableNode->template As<ComputationNode<ElementType>>()->shared_from_this();
else
computationNodePtr = nullptr;
}
variableToNodeMap[variable] = computationNodePtr;
if (isVariableRootMap.find(variable) == isVariableRootMap.end())
isVariableRootMap[variable] = variable.IsOutput();
return computationNodePtr;
}
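// If 'variable' is the output of a NoOp primitive function, return the
// variable it forwards (optionally following chains of NoOps); otherwise
// return 'variable' itself.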
/*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive)
{
Variable mappingVariable = variable;
auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
if (ownerPrimitiveFunc && (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp))
mappingVariable = ownerPrimitiveFunc->Inputs()[0];
if (recursive && (mappingVariable != variable))
return GetMappingForNoOpOutput(mappingVariable);
else
return mappingVariable;
}
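// Resolve 'variable' to the variable it is ultimately mapped to: through NoOp
// forwarding for primitive functions and through the composite outputs map for
// block functions.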
/*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive)
{
Variable mappingVariable = variable;
auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
if (ownerPrimitiveFunc)
{
if (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp)
mappingVariable = GetMappingForNoOpOutput(variable);
else
{
auto ownerBlockFunc = dynamic_cast<BlockFunction*>(ownerFunc);
if (ownerBlockFunc)
mappingVariable = ownerBlockFunc->CompositeOutputsMap().at(variable);
}
}
if (recursive && (mappingVariable != variable))
return GetMappingVariable(mappingVariable);
else
return mappingVariable;
}
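// Create the internal CNTK ComputationNode corresponding to 'variable' and the
// primitive op of its owning 'function', dispatching on PrimitiveOpType and
// translating the function attributes into node constructor arguments.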
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
Function* function,
const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
bool useMangledNamesForComputationNodes)
{
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
return variableToNodeMap[GetMappingVariable(variable)];
ComputationNodeBasePtr computationNodePtr;
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name(), useMangledNamesForComputationNodes);
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputNode : inputNodes)
inputNodesBasePtrs.push_back(inputNode);
auto outputs = function->RawOutputs();
if (variable == outputs[0])
{
if (primitiveFunction)
{
auto functionInputs = function->Inputs();
auto& functionConfig = function->Attributes();
PrimitiveOpType op = primitiveFunction->OpType();
switch (op)
{
case PrimitiveOpType::Negate:
computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sigmoid:
computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Tanh:
computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Acos:
computationNodePtr = New<AcosNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Cos:
computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Asin:
computationNodePtr = New<AsinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sin:
computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Cosh:
computationNodePtr = New<CoshNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sinh:
computationNodePtr = New<SinhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ReLU:
computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Exp:
computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Log:
computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sqrt:
computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ELU:
computationNodePtr = New<ExponentialLinearUnitNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Floor:
computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Abs:
computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Reciprocal:
computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Softmax:
computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Hardmax:
computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::StableSigmoid:
computationNodePtr = New<StableSigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::TransposeAxes:
{
if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec))
{
auto perm = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
for (auto& p : perm)
p = NormalizeStaticAxis(p, perm.size());
computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(perm));
}
else
{
auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();
// The axis ids passed to the internal CNTK TransposeDimensionsNode are 1-based instead of 0-based
computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
}
break;
}
case PrimitiveOpType::Where:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
break;
}
case PrimitiveOpType::ToSequence:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
computationNodePtr = New<ToSequenceNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKDynamicAxisName);
break;
}
case PrimitiveOpType::ToSequenceLike:
computationNodePtr = New<ToSequenceLikeNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::UnpackSequence:
{
auto paddingValue = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackPaddingValue].Value<double>();
auto suppressMaskOutput = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackSuppressMaskOutput].Value<bool>();
computationNodePtr = New<UnpackSequenceNode<ElementType>>(network->GetDeviceId(), internalNodeName, (ElementType)paddingValue, suppressMaskOutput);
break;
}
case PrimitiveOpType::Slice:
{
std::vector<Axis> axis;
std::vector<int> beginIndex, endIndex, strides;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndexVec) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndexVec))
{
axis = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
beginIndex = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameBeginIndexVec].Value<std::vector<DictionaryValue>>());
endIndex = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameEndIndexVec].Value<std::vector<DictionaryValue>>());
if (functionConfig.Contains(PrimitiveFunction::AttributeNameSliceStridesVec))
strides = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameSliceStridesVec].Value<std::vector<DictionaryValue>>());
else
strides.resize(axis.size(), 1);
}
else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndex) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndex))
{
axis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>());
beginIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>());
endIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>());
if (functionConfig.Contains(PrimitiveFunction::AttributeNameSliceStrides))
strides.push_back(functionConfig[PrimitiveFunction::AttributeNameSliceStrides].Value<int>());
else
strides.push_back(1);
}
else
{
RuntimeError("Failed to create computation node: Slice operation with inconsistent attributes");
}
// The internal CNTK SliceNode takes 1-based axis indices instead of 0-based
computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis), strides);
break;
}
case PrimitiveOpType::RandomSample:
{
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::RandomSampleInclusionFrequency:
{
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::Dropout:
{
auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
break;
}
case PrimitiveOpType::RandomDistribution:
{
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
auto rvtype = functionConfig[PrimitiveFunction::AttributeNameRandomDistributionType].Value<std::wstring>();
std::vector<double> randomDistributionArgs;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRandomDistributionArgs))
randomDistributionArgs = AsVector<double>(functionConfig[PrimitiveFunction::AttributeNameRandomDistributionArgs].Value<std::vector<DictionaryValue>>());
if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewShape))
{
auto shape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
computationNodePtr = New<RandomDistributionNode<ElementType>>(network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs, AsTensorShape(shape));
}
else
computationNodePtr = New<RandomDistributionNode<ElementType>>(network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs);
computationNodePtr->As<RandomDistributionNode<ElementType>>()->SetRngState(seed, offset);
break;
}
case PrimitiveOpType::Reshape:
{
auto beginAxis = Axis(0);
auto endAxis = Axis((int)functionInputs[0].Shape().Rank());
if (functionConfig.Contains(PrimitiveFunction::AttributeNameBeginAxis))
beginAxis = functionConfig[PrimitiveFunction::AttributeNameBeginAxis].Value<Axis>();
if (functionConfig.Contains(PrimitiveFunction::AttributeNameEndAxis))
endAxis = functionConfig[PrimitiveFunction::AttributeNameEndAxis].Value<Axis>();
auto replacementShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
for (size_t i = 0; i < replacementShape.Rank(); ++i)
{
if (replacementShape[i] == NDShape::InferredDimension)
replacementShape[i] = 0;
}
computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(replacementShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
break;
}
case PrimitiveOpType::ROIPooling:
{
PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
auto spatialScale = functionConfig[PrimitiveFunction::AttributeNameSpatialScale].Value<double>();
computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(roiOutputShape), spatialScale);
break;
}
case PrimitiveOpType::Pooling:
{
PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
auto poolingWindowShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto ceilOutDim = false;
auto includePad = false;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameCeilOutDim))
{
ceilOutDim = functionConfig[PrimitiveFunction::AttributeNameCeilOutDim].Value<bool>();
}
if (functionConfig.Contains(PrimitiveFunction::AttributeNameIncludePad))
{
includePad = functionConfig[PrimitiveFunction::AttributeNameIncludePad].Value<bool>();
}
computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutDim, includePad, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Unpooling:
{
auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
// We only get here after validation, so it is safe to assume the unpooling type is max
computationNodePtr = New<MaxUnpoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::SumAll:
computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::OneHot:
{
auto numClass = functionConfig[PrimitiveFunction::AttributeNameNumClass].Value<size_t>();
auto is_sparse = functionConfig[PrimitiveFunction::AttributeNameOneHotOutputSparse].Value<bool>();
auto axis = functionConfig[PrimitiveFunction::AttributeNameOneHotAxis].Value<Axis>();
computationNodePtr = New<OneHotNode<ElementType>>(network->GetDeviceId(), numClass, is_sparse, axis.StaticAxisIndex(), internalNodeName);
break;
}
case PrimitiveOpType::Gather:
computationNodePtr = New<GatherNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ToBatch:
{
computationNodePtr = New<ToBatchAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::UnpackBatch:
{
computationNodePtr = New<UnpackBatchAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::Plus:
computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LogPlus:
computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Pow:
computationNodePtr = New<PowNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Minus:
computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ElementTimes:
computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Equal:
computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NotEqual:
computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Less:
computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LessEqual:
computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Greater:
computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GreaterEqual:
computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Times:
{
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
break;
}
case PrimitiveOpType::TransposeTimes:
{
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
break;
}
case PrimitiveOpType::Convolution:
{
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
NDShape dilation = { 1 };
if (functionConfig.Contains(PrimitiveFunction::AttributeNameDilation))
dilation = functionConfig[PrimitiveFunction::AttributeNameDilation].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
NDShape outputMapCount, kernelShape;
std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
NDShape outputShape = NDShape::Unknown();
if (functionConfig.Contains(PrimitiveFunction::AttributeNameOutputShape))
outputShape = functionConfig[PrimitiveFunction::AttributeNameOutputShape].Value<NDShape>();
auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName,
AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
ImageLayoutKind::CHW, maxTempMemSizeInSamples, AsTensorShape(dilation));
break;
}
case PrimitiveOpType::CosDistance:
computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CosDistanceWithNegativeSamples:
computationNodePtr = New<CosDistanceWithNegativeSamplesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Logistic:
computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::SquaredError:
computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CrossEntropyWithSoftmax:
computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ClassificationError:
computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::EditDistanceError:
{
auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value<float>();
auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value<float>();
auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value<float>();
auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value<bool>();
auto tokensToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNameTokensToIgnore].Value<std::vector<DictionaryValue>>());
computationNodePtr = New<EditDistanceErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore);
break;
}
case PrimitiveOpType::ForwardBackward:
{
auto delayConstraint = functionConfig[PrimitiveFunction::AttributeNameDelayConstraint].Value<int>();
auto blankTokenId = functionConfig[PrimitiveFunction::AttributeNameBlankTokenId].Value<size_t>();
computationNodePtr = New<ForwardBackwardNode<ElementType>>(network->GetDeviceId(), internalNodeName, blankTokenId, delayConstraint);
break;
}
case PrimitiveOpType::LambdaRank:
computationNodePtr = New<LambdaRankNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NDCG:
computationNodePtr = New<NDCG1EvalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
{
Variable inputOperandVar = functionInputs[0];
Variable initialStateVar = functionInputs[1];
size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
if (op == PrimitiveOpType::PastValue)
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
else
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
break;
}
case PrimitiveOpType::ReduceElements:
{
bool keepDimensions = true;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameReductionKeepDimensions))
keepDimensions = functionConfig[PrimitiveFunction::AttributeNameReductionKeepDimensions].Value<bool>();
auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
std::vector<Axis> reductionAxis;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec))
{
reductionAxis = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
}
else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis))
{
reductionAxis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>());
}
else
{
RuntimeError("Failed to create computation node': Reduce operation %ls with no '%ls' or '%ls' attributes",
PrimitiveOpTypeName(op).c_str(),
PrimitiveFunction::AttributeNameAxis.c_str(),
PrimitiveFunction::AttributeNameAxisVec.c_str()
);
}
computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis), keepDimensions);
break;
}
case PrimitiveOpType::BatchNormalization:
{
auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();
computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Combine:
// This operation is just a no-op and is a means to combine multiple functions to create a single Function
// whose outputs are a union of the outputs of the Functions being combined.
computationNodePtr = variableToNodeMap[variable];
break;
case PrimitiveOpType::PackedIndex:
computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GatherPacked:
computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ScatterPacked:
computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Clip:
computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Select:
computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Splice:
{
Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
break;
}
case PrimitiveOpType::Pad:
{
auto head = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNamePaddingHead].Value<std::vector<DictionaryValue>>());
auto foot = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNamePaddingFoot].Value<std::vector<DictionaryValue>>());
auto mode = functionConfig[PrimitiveFunction::AttributeNamePaddingMode].Value<size_t>();
auto constantValue = functionConfig[PrimitiveFunction::AttributeNamePaddingConstantValue].Value<double>();
computationNodePtr = New<PaddingNode<ElementType>>(network->GetDeviceId(), internalNodeName, head, foot, (PaddingType)mode, (ElementType)constantValue);
break;
}
case PrimitiveOpType::OptimizedRNNStack:
{
auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();
computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
break;
}
case PrimitiveOpType::ReconcileDynamicAxis:
{
computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::LogSoftmax:
{
// This can be implemented as x => x - ReduceLogSum(x). How to do this here?
computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::Pass:
computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LabelsToGraph:
computationNodePtr = New<LabelsToGraphNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::StopGradient:
computationNodePtr = New<StopGradientNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Assign:
computationNodePtr = New<AssignNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Crop:
if (functionInputs.size() == 2)
{
if (functionConfig.Contains(PrimitiveFunction::AttributeNameOffset))
{
// Crop with given offsets.
const auto& offsets = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNameOffset].Value<std::vector<DictionaryValue>>());
if (offsets.size() != 2)
{
CNTK::LogicError("Vector of crop offsets must have size 2.");
}
computationNodePtr = New<CropNode<ElementType>>(offsets[0], offsets[1], network->GetDeviceId(), internalNodeName);
}
else
{
// Crop with two inputs and automatic offset computation.
computationNodePtr = New<CropNode<ElementType>>(network->GetDeviceId(), internalNodeName);
}
}
else if (functionInputs.size() == 4)
{
// Crop with four inputs and automatic offset computation.
computationNodePtr = New<CropNode<ElementType>>(network->GetDeviceId(), internalNodeName);
}
else
{
CNTK::LogicError("Crop node must have 2 or 4 node inputs.");
}
break;
default:
CNTK::LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
break;
}
// Let's reorder inputNodesBasePtrs properly, since the ordering of inputs of the internal CNTK ComputationNode may differ from the PrimitiveFunction inputs ordering
ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);
if (computationNodePtr->Is<INumInputs>())
{
auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
CNTK::LogicError("The Primitive Function '%S' has %d inputs while the corresponding ComputationNode expects %d inputs.",
function->AsString().c_str(),
(int)inputNodesBasePtrs.size(),
(int)computationNodeExpectedInputCount);
}
if (computationNodePtr->Is<RngUser>())
{
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
}
}
else
{
computationNodePtr = New<UserDefinedV2FunctionNode<ElementType>>(network->GetDeviceId(), internalNodeName, function->shared_from_this());
// For user defined functions, we only attach unique inputs in the internal computation network, since the UDF
// backward implementations directly compute aggregate gradient values for unique inputs
std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
for (auto inputNodeBasePtr : inputNodesBasePtrs)
{
if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
}
inputNodesBasePtrs = uniqueInputNodesBasePtrs;
}
}
else
{
size_t i = 1;
while ((i < outputs.size()) && (outputs[i] != variable))
    i++;
assert(i < outputs.size());
computationNodePtr = New<OutputMultiplexerNode<ElementType>>(network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes), i);
inputNodesBasePtrs = { variableToNodeMap[outputs[0]] };
}
network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
return computationNodePtr;
}
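// Create (or retrieve) the ComputationNode for an output variable: first build
// the nodes for the owning function's inputs, then either recurse into the
// underlying composite (for block functions) or create the node directly.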
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool useMangledNamesForComputationNodes)
{
assert(variable.IsOutput());
Function* function = variable.Owner().get();
ComputationNodeBasePtr computationNodePtr;
auto& functionInputs = function->m_inputs;
DataType nonConstInputDataType = DataType::Unknown;
for (auto& inputVar : functionInputs)
{
if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown))
{
nonConstInputDataType = inputVar.GetDataType();
break;
}
}
// Create the nodes corresponding to the inputs
std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
for (auto& inputVar : functionInputs)
{
// If the inputVar is a constant and not of the right DataType, coerce it to the right type
if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
inputVar = Constant(inputVar).CloneAs(nonConstInputDataType);
auto baseNodePtr = GetNode(inputVar, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
}
BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function);
if (blockFunction)
{
// For block function, map each argument placeholder of the underlying composite to
// the computation node corresponding to the block input that the argument placeholder
// of the composite is mapped to.
auto compositeArguments = blockFunction->Composite()->Arguments();
for (auto compositeArgument : compositeArguments)
variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());
return GetNode(variable.BlockFunctionVariableMapping(), network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
}
else
computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap, useMangledNamesForComputationNodes);
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine))
{
for (auto inputVar : functionInputs)
isVariableRootMap[inputVar] = false;
}
return computationNodePtr;
}
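// Return copies of the given variables that do not hold a reference to their
// owning composite, so that the returned set does not artificially extend the
// composite's lifetime.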
std::unordered_set<Variable> CompositeFunction::NonOwnerPreservingCopy(const std::unordered_set<Variable>& outputs)
{
std::unordered_set<Variable> result;
for (auto& o : outputs)
{
Variable sanitized = o.NonCompositePreservingCopy();
result.insert(sanitized);
}
return result;
}
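// Check whether a V2 variable shape is compatible with an internal V1 node
// shape, tolerating trailing singleton axes and free/inferred dimensions.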
static bool VariableShapeMatchesNodeShape(const NDShape& varShape, const TensorShape& nodeShape)
{
if (varShape.Rank() == 0)
return (nodeShape.GetNumElements() == 1);
// Sometimes the nodeShape may have an additional trailing axis with dim==1 due to the lack of support for 0-d tensors in the V1 engine.
auto adjustedNodeShape = nodeShape;
while ((adjustedNodeShape.GetRank() > varShape.Rank()) && (adjustedNodeShape.GetDim(adjustedNodeShape.GetRank() - 1) == 1))
adjustedNodeShape.TrimRankInPlace(adjustedNodeShape.GetRank() - 1);
if (!varShape.HasUnboundDimension())
return (AsNDShape(adjustedNodeShape) == varShape);
if (varShape.Rank() != adjustedNodeShape.GetRank())
return false;
for (size_t i = 0; i < varShape.Rank(); ++i)
{
if ((varShape[i] != NDShape::FreeDimension) && (varShape[i] != NDShape::InferredDimension) && (varShape[i] != adjustedNodeShape.GetDim(i)))
return false;
}
return true;
}
template <typename ElementType>
std::pair<ComputationNetworkPtr, std::unordered_map<Variable, ComputationNodeBasePtr>>
CompositeFunction::CreateComputationNetwork(const FunctionPtr& compositeFunction,
const DeviceDescriptor& device,
const std::unordered_set<Variable>& outputs,
const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
const std::unordered_set<Variable>& inputsExcludedFromGradientComputation,
bool useMangledNamesForComputationNodes)
{
auto computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
ComputationNetworkBuilder<ElementType> builder(*computationNetwork);
std::unordered_map<Variable, bool> isVariableRootMap;
std::unordered_map<Variable, ComputationNodeBasePtr> variableToNodeMap;
// Now recursively create the network in a top-down fashion
auto rootFunction = compositeFunction->RootFunction();
auto rootFunctionOutputs = rootFunction->RawOutputs();
for (auto rootOutput : rootFunctionOutputs)
GetNode(rootOutput, computationNetwork, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsExcludedFromGradientComputation, useMangledNamesForComputationNodes);
// We need to patch the ComputationNode mappings for the arguments of block functions,
// since for recurrent inputs the mappings are not fully established on the first pass
std::function<void(const Variable&)> PatchBlockArgumentsAndOutputsMapping;
PatchBlockArgumentsAndOutputsMapping = [&variableToNodeMap, &PatchBlockArgumentsAndOutputsMapping](const Variable& var) {
if (var.IsOutput())
{
BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(var.Owner().get());
if (blockFunction)
{
PostorderTraverseVariables(blockFunction->BlockRoot(), PatchBlockArgumentsAndOutputsMapping);
auto compositeArguments = blockFunction->Composite()->Arguments();
for (auto compositeArgument : compositeArguments)
{
auto mappingVarNodeIter = variableToNodeMap.find(compositeArgument.BlockFunctionVariableMapping());
if (mappingVarNodeIter != variableToNodeMap.end())
variableToNodeMap[compositeArgument] = mappingVarNodeIter->second;
}
auto mappingVarNodeIter = variableToNodeMap.find(var.BlockFunctionVariableMapping());
if (mappingVarNodeIter != variableToNodeMap.end())
variableToNodeMap[var] = mappingVarNodeIter->second;
}
}
};
PostorderTraverseVariables(rootFunction, PatchBlockArgumentsAndOutputsMapping);
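// A Variable counts as a root only if nothing consumed it as an input, it is not merely the
// first output of a multi-output function, and the Variable it maps to (for block outputs)
// is itself a root.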
std::function<bool(const Variable&)> IsVariableRoot = [&isVariableRootMap, &IsVariableRoot](const Variable& outputVar) {
auto mappingVariable = GetMappingVariable(outputVar);
return (isVariableRootMap[outputVar] && !IsFirstOutputOfMultiOutputFunction(mappingVariable) && ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
};
// If any of the Function's outputs or the requested outputs is not a root node, we need to explicitly
// add it to the 'output' group of the ComputationNetwork
std::unordered_set<Variable> networkOutputs(outputs);
networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
for (auto output : networkOutputs)
{
if (!IsVariableRoot(output))
{
auto computationNode = variableToNodeMap[output];
if (!computationNode)
InvalidArgument("One of the requested outputs '%S' for the Function '%S' forward computation is not part of the graph underlying the Function.",
output.AsString().c_str(), compositeFunction->AsString().c_str());
computationNetwork->AddToNodeGroup(L"output", computationNode);
}
}
// In case of recurrence, the inputs of some ComputationNodes are not attached due to cycles;
// attach those now that all ComputationNodes in the network have been created
for (auto varNodePair : variableToNodeMap)
{
auto& currentComputationNode = varNodePair.second;
if (!currentComputationNode)
LogicError("Function '%S': No computation node mapping exists for Variable %S.", compositeFunction->AsString().c_str(), varNodePair.first.AsString().c_str());
auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
auto& currentVar = varNodePair.first;
if (!currentVar.IsOutput())
continue;
if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
{
// This ComputationNode has at least one null input which now needs to be properly attached
const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());
if (!primitiveFunc)
LogicError("Function '%S': Owner of output Variable '%S' with unattached inputs is not a PrimitiveFunction.",
compositeFunction->AsString().c_str(), currentVar.AsString().c_str());
// Skip block primitives since they do not directly map to a computation node
if (primitiveFunc->OpType() == PrimitiveOpType::Block)
continue;
// Reorder the inputs, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
auto inputVars = primitiveFunc->Inputs();
ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
inputVars.resize(currentComputationNode->GetNumInputs());
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputVar : inputVars)
inputNodesBasePtrs.push_back(variableToNodeMap.at(inputVar));
currentComputationNode->AttachInputs(inputNodesBasePtrs);
}
}
computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
computationNetwork->SetTrackGapNans(GetCheckedMode());
computationNetwork->SetIsV2Library(true);
computationNetwork->CompileNetwork();
// Set EvalTimeStamp of all nodes in the network as "outdated" to make sure that all nodes will be evaluated at least once.
// During CompileNetwork(), nodes in the network might get different timestamp values because other threads could update the global timestamp value.
// (The global timestamp value is currently shared process-wide, i.e. among all nodes of all networks.) The nodes with a higher timestamp value are
// thus incorrectly treated as "updated", and their inputs are not further evaluated by ComputationNetwork::PARTraversalFlowControlNode::ForwardProp().
// This could lead to incorrect results or crash, because the matrix of the input nodes might never be initialized for ForwardProp().
computationNetwork->SetEvalTimeStampsOutdatedWithRegardToAll();
// Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
for (auto varNodePair : variableToNodeMap)
{
if (varNodePair.first.IsOutput())
{
auto outputVar = varNodePair.first;
auto computationNodePtr = variableToNodeMap.at(outputVar);
auto outputShape = outputVar.Shape();
auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
if (!VariableShapeMatchesNodeShape(outputShape, computationNodeSampleLayout))
{
LogicError("Function '%S': The output Variable '%S' shape '%S' does not match the SampleLayout shape '[%s]' of the corresponding ComputationNode in the network.",
compositeFunction->AsString().c_str(), outputVar.AsString().c_str(), outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
}
}
}
return { computationNetwork, variableToNodeMap };
}
template <typename ElementType>
ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device,
const std::unordered_set<Variable>& backpropRoots,
const std::unordered_set<Variable>& outputs,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool allocateNetworkMatrices)
{
// Let's purge the current computation network and regenerate it if the CompositeFunction
// was previously compiled just for evaluation and not for gradient backpropagation.
if ((m_computationNetwork != nullptr) && (m_currentBackpropRoots.empty() && !backpropRoots.empty()))
PurgeComputationNetwork();
if (m_computationNetwork != nullptr)
{
// TODO: We should either invalidate and readapt the network if the backpropRoots change compared to what was specified when the network
// was last constructed, or just recreate a new network.
// For now just disallow changing the backpropRoots after the network is created
if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
LogicError("Function '%S': Changing backprop roots (Current = '%S', New = '%S') across different Forward calls on a CNTK composite Function is currently unsupported.",
AsString().c_str(), NamedListString(m_currentBackpropRoots).c_str(), NamedListString(backpropRoots).c_str());
// TODO: Support changing the device across different invocations of the forward method on a Function instance
if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
LogicError("Function '%S': Changing device (Current = '%S', New = %S') across different Forward calls on a CNTK composite Function is currently unsupported.",
AsString().c_str(), AsDeviceDescriptor(m_computationNetwork->GetDeviceId()).AsString().c_str(), device.AsString().c_str());
if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation))
LogicError("Function '%S': Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported.", AsString().c_str());
// Check whether the free dimensions of any of the arguments have changed, and if so, update the corresponding
// input ComputationNodes and rerun validation on the computation network
for (auto freeDimensionArgumentMapping : m_fullyDefinedArgumentsMap)
{
auto newShape = freeDimensionArgumentMapping.second.Shape();
auto argumentComputationNode = m_variableToNodeMap[freeDimensionArgumentMapping.first];
if (AsTensorShape(newShape) != argumentComputationNode->GetSampleLayout())
argumentComputationNode->SetDims(AsTensorShape(newShape), argumentComputationNode->HasMBLayout());
}
}
else
{
auto networkInputs = this->Inputs();
for (auto inputExcluded : inputsToExcludeGradientsFor)
{
// Only inputs of the network can be excluded from gradient computation
if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end())
InvalidArgument("Variable '%S' specified for exclusion from gradient computation is not an input of the Function '%S'. "
"Only an input of the Function can be explicitly excluded from gradient computation.",
inputExcluded.AsString().c_str(), this->AsString().c_str());
}
m_inputsExcludedFromGradientComputation = NonOwnerPreservingCopy(inputsToExcludeGradientsFor);
m_currentBackpropRoots = NonOwnerPreservingCopy(backpropRoots);
// TODO: We currently only support one backprop root
if (backpropRoots.size() > 1)
LogicError("Function '%S': %d backprop roots specified; currently at most one backprop root is supported.", AsString().c_str(), (int)backpropRoots.size());
auto placeholders = Placeholders();
if (!placeholders.empty())
InvalidArgument("%d unbound Placeholder(s) '%S' found in the Function. "
"All Placeholders of a Function must be bound (to a variable) before performing a Forward computation.",
(int)placeholders.size(), NamedListString(placeholders).c_str());
// Let's update the composite Function graph's inputs with any inferred dimensions that
// were determined from the shapes of the supplied data
auto networkArguments = Arguments();
for (auto argument : networkArguments)
{
if (argument.Shape().HasInferredDimension())
{
auto fullyDefinedArgument = m_fullyDefinedArgumentsMap.at(argument);
for (size_t i = 0; i < argument.Shape().Rank(); ++i)
if (argument.Shape()[i] == NDShape::InferredDimension)
argument.m_dataFields->m_shape[i] = fullyDefinedArgument.Shape()[i];
}
}
// Run the final validation on the entire network once before constructing/compiling the
// internal computation network
ValidateOrUpdateOutputs();
std::tie(m_computationNetwork, m_variableToNodeMap) = CreateComputationNetwork<ElementType>(this->shared_from_this(), device, outputs, m_fullyDefinedArgumentsMap, m_inputsExcludedFromGradientComputation, /*useMangledNamesForComputationNodes =*/ false);
// Record the timestamps of Parameters and Constants
assert(m_lastRecordedTimeStamps.empty());
auto functionParameters = Parameters();
for (auto parameter : functionParameters)
m_lastRecordedTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
auto functionConstants = Constants();
for (auto constant : functionConstants)
m_lastRecordedTimeStamps.insert({ constant, constant.CurrentValueTimeStamp() });
// Collect parameters and constants being assigned to
PreorderTraverseFunctions(RootFunction(), [this](const FunctionPtr& function) {
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::Assign))
m_refVariables.insert(primitiveFunction->Inputs()[0]);
}, /*nestedSearchInsideBlockFunction =*/ true);
}
if (!m_networkMatricesAllocated && allocateNetworkMatrices)
{
m_allNetworkRoots = m_currentBackpropRoots;
ComputationNodeBasePtr backpropRootNode;
if (!m_currentBackpropRoots.empty())
backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());
// Now recursively traverse the network in a top-down fashion
auto rootFunction = RootFunction();
auto rootFunctionOutputs = rootFunction->RawOutputs();
m_allNetworkRoots.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
std::vector<ComputationNodeBasePtr> forwardRootNodes;
for (auto rootOutput : rootFunctionOutputs)
forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));
std::vector<ComputationNodeBasePtr> forwardOutputNodes;
m_allNetworkRoots.insert(outputs.begin(), outputs.end());
for (auto output : outputs)
forwardOutputNodes.push_back(m_variableToNodeMap.at(output));
m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
m_networkMatricesAllocated = allocateNetworkMatrices;
}
else
{
// Make sure the requested outputs are a subset of the outputs for which we set up the current
// matrix allocation structure in the cached computation network
for (auto output : outputs)
{
if (m_allNetworkRoots.find(output) == m_allNetworkRoots.end())
LogicError("Function '%S': Requested output '%S' is not part of the list of outputs '%S' that the Function was initially compiled for. "
"Changing requested outputs across different Forward calls is currently unsupported.",
AsString().c_str(), output.AsString().c_str(), NamedListString(m_allNetworkRoots).c_str());
}
}
return m_computationNetwork;
}
template <typename ElementType>
/*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map<MBLayoutPtr, Variable>& layoutsPopulated)
{
NDShape inferredVariableShape;
std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second, &inferredVariableShape);
if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout()))
CNTK::LogicError("CompositeFunction::Forward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.",
inferredVariableShape.AsString().c_str(), variableValue.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str());
// Assign the supplied values to the node's value matrix; AssignValuesOf also switches the node matrix to the right matrix type if needed
auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
auto layout = CNTKMatrixAndMBLayout.second;
auto& nodeLayout = computationNode->GetMBLayout();
if ((layout == nullptr) != (nodeLayout == nullptr))
InvalidArgument("The layout of the specified Value for Variable '%S' is incompatible with the layout of the corresponding ComputationNode.", variableValue.first.AsString().c_str());
else if (layout)
{
if (layoutsPopulated.find(nodeLayout) == layoutsPopulated.end())
{
nodeLayout->CopyFrom(layout);
layoutsPopulated.insert({ nodeLayout, variableValue.first });
}
else
{
if (*nodeLayout != *layout)
InvalidArgument("Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified "
"for the Function's arguments '%S' vs. '%S', though these arguments have the same dynamic axes '%S'",
variableValue.first.AsString().c_str(), layoutsPopulated.at(nodeLayout).AsString().c_str(), DynamicAxesAsString(variableValue.first.DynamicAxes(), Internal::IsReversingTensorShapesInErrorMessagesEnabled()).c_str());
}
}
}
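// Infer concrete extents for the free dimensions of the arguments from the shapes of the
// supplied Values; the inferred shapes are cached in m_fullyDefinedArgumentsMap, and in
// checked mode a shared-parameter clone of the composite is (re)validated against them.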
std::unordered_map<Variable, NDShape> CompositeFunction::InferFreeDimensionsOfArguments(const std::unordered_map<Variable, ValuePtr>& arguments)
{
std::unordered_map<Variable, NDShape> inferredArgumentDimensions;
for (auto argumentValuePair : arguments)
{
NDShape inferredVarShape;
Utils::VerifyVariableValueCompatibility(argumentValuePair.first, argumentValuePair.second, &inferredVarShape);
if (inferredVarShape != argumentValuePair.first.Shape())
inferredArgumentDimensions.insert({ argumentValuePair.first , inferredVarShape });
}
if (!inferredArgumentDimensions.empty())
{
if (m_fullyDefinedArgumentsMap.empty())
{
for (auto inferredArgumentShapePair : inferredArgumentDimensions)
{
auto fullyDefinedArgument = inferredArgumentShapePair.first.Clone();
fullyDefinedArgument.m_dataFields->m_shape = inferredArgumentShapePair.second;
m_fullyDefinedArgumentsMap.insert({ inferredArgumentShapePair.first, fullyDefinedArgument });
}
if (GetCheckedMode())
m_latestFullyDefinedCompositeForCheckedModeValidation = this->Clone(ParameterCloningMethod::Share, m_fullyDefinedArgumentsMap);
}
else
{
bool argumentShapeChangedSinceLastTime = false;
for (auto inferredArgumentShapePair : inferredArgumentDimensions)
{
if (inferredArgumentShapePair.second != m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].Shape())
{
argumentShapeChangedSinceLastTime = true;
m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].m_dataFields->m_shape = inferredArgumentShapePair.second;
}
}
if (argumentShapeChangedSinceLastTime && m_latestFullyDefinedCompositeForCheckedModeValidation)
m_latestFullyDefinedCompositeForCheckedModeValidation->ValidateOrUpdateOutputs();
}
}
return inferredArgumentDimensions;
}
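// Copy the supplied argument Values into the corresponding input nodes of the network and
// bump the nodes' eval timestamps so dependents are recomputed on the next ForwardProp.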
void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
{
std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
std::vector<ComputationNodeBasePtr> inputNodes;
for (auto argumentValuePair : arguments)
{
auto argument = argumentValuePair.first;
auto argumentComputationNode = m_variableToNodeMap.at(argument);
assert(argumentComputationNode);
inputNodes.push_back(argumentComputationNode);
ValuePtr argumentValue = arguments.at(argument);
switch (argumentValue->GetDataType())
{
case DataType::Float:
PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
break;
case DataType::Double:
PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
break;
default:
LogicError("Function '%S' Forward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(argumentValue->GetDataType()));
break;
}
}
m_computationNetwork->BumpEvalTimeStamp(inputNodes);
}
template <typename ElementType>
/*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
{
NDShape inferredVariableShape;
std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second, &inferredVariableShape);
if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout()))
CNTK::LogicError("CompositeFunction::Backward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.",
inferredVariableShape.AsString().c_str(), variableGradient.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str());
MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
auto nodeLayout = computationNode->GetMBLayout();
if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
InvalidArgument("The layout of the specified gradient Value for Variable '%S' is incompatible with the layout computed during Forward call.",
variableGradient.first.AsString().c_str());
computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
}
// Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
{
auto functionOutputs = RawOutputs();
for (auto gradientVarValuePair : gradients)
{
auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first);
ValuePtr gradientValue = gradientVarValuePair.second;
switch (gradientValue->GetDataType())
{
case DataType::Float:
PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
break;
case DataType::Double:
PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
break;
default:
LogicError("Function '%S' Backward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(gradientValue->GetDataType()));
break;
}
}
}
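// Read the Value or Gradient matrix of 'computationNode' into 'varValue'. If no storage was
// supplied (varValue == nullptr), the node's matrix is wrapped in a PackedValue without
// copying; otherwise the data is copied into the supplied Value object.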
/*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
{
auto varShape = GetVariableShape(var.Shape(), computationNode->GetSampleLayout());
auto valueShape = PackedValue::GetUnpackedShape(varShape, var.DynamicAxes(), computationNode->GetMBLayout());
if (varValue != nullptr)
{
// TODO: The shape of the specified output Value object must match the actual output shape
if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
::InvalidArgument("The shape %S of the specified Value object for Variable '%S' %s does not match the actual shape %S",
varValue->Shape().AsString().c_str(), var.AsString().c_str(), getGradient ? "gradient" : "output", valueShape.AsString().c_str());
}
ValuePtr nodeValue;
auto layout = computationNode->GetMBLayout();
switch (var.GetDataType())
{
case DataType::Float:
{
auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
if (varValue == nullptr)
nodeValue = MakeSharedObject<PackedValue>(varShape, var.DynamicAxes(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
else
nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, computationNode, matrix, layout);
break;
}
case DataType::Double:
{
auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
if (varValue == nullptr)
nodeValue = MakeSharedObject<PackedValue>(varShape, var.DynamicAxes(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
else
nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, computationNode, matrix, layout);
break;
}
default:
CNTK::LogicError("CompositeFunction::Forward/Backward: Unsupported DataType %s", DataTypeName(var.GetDataType()));
break;
}
if (varValue == nullptr)
varValue = nodeValue;
else
varValue->CopyFrom(*nodeValue);
}
void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
{
// Now copy the Forward values of output nodes from the network to outputs' Value objects
for (auto& outputVarValuePair : outputs)
{
auto& valuePtr = outputVarValuePair.second;
auto node = m_variableToNodeMap.at(outputVarValuePair.first);
bool noValueStorageProvided = (valuePtr == nullptr);
GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/);
auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
m_existingNetworkStorageReferences.push_back(packedVarValue);
}
}
void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
{
auto networkInputs = this->Inputs();
// Now copy the gradient values of input nodes of the network to gradients' Value objects
for (auto& gradientVarValuePair : gradients)
{
// Only gradients corresponding to inputs of the network can be obtained
if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
InvalidArgument("Gradient requested for Variable '%S' which is not a leaf input (Input, Parameter or Constant) of the Function '%S'; this is currently unsupported.",
gradientVarValuePair.first.AsString().c_str(), this->AsString().c_str());
// Gradients can only be obtained for parameter variables or input variables that NeedsGradient
if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end()))
InvalidArgument("Gradient value incorrectly requested for Variable '%S', "
"an Output or Constant or Input Variable with NeedsGradient setting of false, or an input for which gradient computation was explicitly excluded.",
gradientVarValuePair.first.AsString().c_str());
auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first);
if (!computationNodePtr->NeedsGradient())
LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.",
AsString().c_str(), gradientVarValuePair.first.AsString().c_str());
auto& valuePtr = gradientVarValuePair.second;
bool noValueStorageProvided = (valuePtr == nullptr);
GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/);
auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
m_existingNetworkStorageReferences.push_back(packedVarValue);
}
}
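// Compute, and cache, the set of arguments the given Variable transitively depends on: an
// Output depends on all Arguments of its owner's composite, a Parameter/Constant on nothing,
// and any other leaf Variable only on itself.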
const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
{
if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end())
{
auto sanitizedOutput = output.NonCompositePreservingCopy();
if (sanitizedOutput.IsOutput())
m_perOutputVarArgumentDependencies[sanitizedOutput] = AsComposite(sanitizedOutput.Owner())->Arguments();
else if (sanitizedOutput.IsParameter() || sanitizedOutput.IsConstant())
m_perOutputVarArgumentDependencies[sanitizedOutput] = {};
else
m_perOutputVarArgumentDependencies[sanitizedOutput] = { sanitizedOutput };
}
return m_perOutputVarArgumentDependencies[output];
}
std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
{
std::unordered_map<Variable, uint64_t> currentBackpropRootsTimeStamps;
assert(m_computationNetwork != nullptr);
for (auto& backpropRoot : m_currentBackpropRoots)
currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp();
return currentBackpropRootsTimeStamps;
}
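// Usage sketch for the Forward/Backward pair below (illustrative only; 'z', 'input',
// 'inputValue', 'rootGradient' and 'device' are hypothetical placeholders, not identifiers
// from this file):
//   FunctionPtr z = ...; // some composite Function
//   std::unordered_map<Variable, ValuePtr> outputs = { { z->Output(), nullptr } };
//   auto state = z->Forward({ { input, inputValue } }, outputs, device, { z->Output() });
//   std::unordered_map<Variable, ValuePtr> grads = { { input, nullptr } };
//   z->Backward(state, { { z->Output(), rootGradient } }, grads);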
/*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
std::unordered_map<Variable, ValuePtr>& outputs,
const DeviceDescriptor& computeDevice,
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
{
// Validate arguments and outputs
if (outputs.empty())
InvalidArgument("At least one output has to be specified when calling Forward method of the Function '%S'.", this->AsString().c_str());
// Make sure that the DataType of the variables and corresponding values match
// TODO: We need a better way to determine the ElementType for the network
auto dataType = DataType::Unknown;
for (auto variableValuePair : arguments)
{
if (dataType == DataType::Unknown)
dataType = variableValuePair.first.GetDataType();
else if (dataType != variableValuePair.first.GetDataType())
LogicError("Function '%S' Forward: The DataType of all arguments must be same.", this->AsString().c_str());
}
if (dataType == DataType::Unknown)
{
for (auto variableValuePair : outputs)
{
if (dataType == DataType::Unknown)
dataType = variableValuePair.first.GetDataType();
}
}
std::unordered_set<Variable> requestedOutputVariables;
for (auto output : outputs)
requestedOutputVariables.insert(output.first);
std::unordered_set<Variable> functionOutputs(m_outputs.begin(), m_outputs.end());
std::unordered_set<Variable> requiredArguments;
for (auto outputVariable : requestedOutputVariables)
{
auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVariable);
requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());
}
// We should have argument values supplied for all required argument dependencies for the requested outputs
std::vector<Variable> missingRequiredArguments;
std::unordered_map<Variable, ValuePtr> requiredArgumentValues;
for (auto requiredArgument : requiredArguments)
{
auto iter = arguments.find(requiredArgument);
if (iter == arguments.end())
missingRequiredArguments.push_back(requiredArgument);
else
requiredArgumentValues.insert(*iter);
}
if (!missingRequiredArguments.empty())
{
InvalidArgument("Values for %d required arguments '%S', that the requested output(s) '%S' depend on, have not been provided.",
(int)missingRequiredArguments.size(), NamedListString(missingRequiredArguments).c_str(), NamedListString(requestedOutputVariables).c_str());
}
if (requiredArgumentValues.size() < arguments.size())
fprintf(stderr, "WARNING: Function::Forward provided values for (%d) extra arguments which are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size()));
auto inferredArgumentShapes = InferFreeDimensionsOfArguments(requiredArgumentValues);
if (dataType == DataType::Float)
GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
else if (dataType == DataType::Double)
GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
else
InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));
// Feed data into the arguments of the network
// TODO: Avoid copying the data when possible
PopulateNetworkInputs(requiredArgumentValues);
// Copy all new values for 'dirty' attributes from functions into corresponding network nodes.
ApplyAttributeUpdates();
// Bump the timestamp of the parameter nodes whose values have changed
for (auto& timeStampRecord : m_lastRecordedTimeStamps)
{
auto variable = timeStampRecord.first;
auto prevTimeStamp = timeStampRecord.second;
auto newTimeStamp = variable.CurrentValueTimeStamp();
if (newTimeStamp > prevTimeStamp)
{
timeStampRecord.second = newTimeStamp;
m_variableToNodeMap.at(variable)->BumpEvalTimeStamp();
}
}
std::vector<ComputationNodeBasePtr> outputsToEvaluate;
for (auto outputVariable : requestedOutputVariables)
outputsToEvaluate.push_back(m_variableToNodeMap.at(outputVariable));
// The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
{
if (outputs.find(rootVarForBackprop) == outputs.end())
outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop));
}
// Reset the timestamps of all backward roots to record an update in one or more inputs
for (auto& backpropRoot : m_currentBackpropRoots)
m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();
// Reset the timestamps of all dropout nodes to force recomputation of the (random) dropout mask.
list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType<DropoutNodeBase>();
for (auto& dropout : dropoutNodes)
dropout->SetEvalTimeStampOutdatedWrtAll();
// Free any previous references to the matrix storage associated with the outputsToEvaluate
ClearExistingOutputOrGradientStorageReferences();
ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);
m_computationNetwork->ForwardProp(outputsToEvaluate);
// Call PostForwardAndBackProp after ForwardProp only in evaluation mode.
if (outputsToRetainBackwardStateFor.empty())
{
m_computationNetwork->PostForwardAndBackProp(outputsToEvaluate);
RecordRefVariableUpdates();
}
else
{
m_currentOutputsToEvaluate.clear();
for (auto outputToEvaluate : outputsToEvaluate)
m_currentOutputsToEvaluate.push_back(outputToEvaluate);
}
GetNetworkOutputs(outputs);
// TODO: How to deal with the specified 'computeDevice'
BackPropStatePtr backpropStatePtr;
if (outputsToRetainBackwardStateFor.size() > 0)
backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());
return backpropStatePtr;
}
/*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
{
auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
if (backpropState == nullptr)
InvalidArgument("Function '%S' Backward: Invalid backprop state passed.", AsString().c_str());
if (backPropagatedGradientValuesForInputs.empty())
InvalidArgument("Function '%S' Backward: List of inputs to compute gradients for, must not be empty.", AsString().c_str());
// TODO: Support multiple concurrent backprop states
std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
LogicError("Function '%S' Backward: The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified "
"by subsequent Forward calls to the function. This is not a user error but a shortcoming of the current implementation where multiple independent "
"backprop states are not simultaneously supported.", AsString().c_str());
if (rootGradientValues.size() > 1)
LogicError("Function '%S' Backward: %d root gradient values specified; currently gradient backprop from only one of the Function Outputs is supported.",
AsString().c_str(), (int)rootGradientValues.size());
// TODO: Avoid copying the data when possible
// Zero all gradients of nodes below the root nodes
for (auto rootGradientVarValuePair : rootGradientValues)
m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first));
// Feed data into the arguments of the network
PopulateNetworkGradients(rootGradientValues);
// Backpropagate through the network
ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);
auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first);
m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);
GetNetworkGradients(backPropagatedGradientValuesForInputs);
if (m_currentOutputsToEvaluate.size() > 0)
{
m_computationNetwork->PostForwardAndBackProp(m_currentOutputsToEvaluate);
RecordRefVariableUpdates();
m_currentOutputsToEvaluate.clear();
}
// TODO: How to deal with the specified 'computeDevice'
}
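// Propagate attribute values that were changed on the V2 Functions since the last Forward
// (currently the dropout rate and the RNG seed) into the corresponding ComputationNodes.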
void CompositeFunction::ApplyAttributeUpdates()
{
// Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input.
// This mask is regenerated every minibatch, and hence dropout nodes with a non-zero dropout rate must be marked outdated
// w.r.t. their inputs to force evaluation in each minibatch.
for (auto varNodePair : m_variableToNodeMap)
{
auto var = varNodePair.first;
if (!var.IsOutput())
continue;
auto function = var.Owner();
if (function->m_dirtyAttributes.empty())
continue;
auto node = varNodePair.second;
for (const wstring& attribute : function->m_dirtyAttributes)
{
if (attribute == PrimitiveFunction::AttributeNameDropoutRate)
{
auto dropoutRate = function->m_attributes[attribute].Value<double>();
auto dropoutPtr = dynamic_cast<DropoutNodeBase*>(node.get());
assert(dropoutPtr != nullptr);
dropoutPtr->SetDropoutRate(dropoutRate);
}
else if (attribute == PrimitiveFunction::AttributeNameRngSeed)
{
auto seed = function->m_attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto rngUserPtr = dynamic_cast<RngUser*>(node.get());
assert(rngUserPtr != nullptr);
rngUserPtr->SetRngState(seed);
}
else
{
// Should never happen.
LogicError("ApplyAttributeUpdates: function '%S' specified an unsupported attribute '%S'.",
function->AsString().c_str(), attribute.c_str());
}
}
function->m_dirtyAttributes.clear();
node->SetEvalTimeStampOutdatedWrtAll();
}
}
}