https://github.com/Microsoft/CNTK
Tip revision: 4419c2b48d055af4810fab27f2441bb91b22b45f authored by Binbin Zhang on 04 June 2018, 03:50:57 UTC
add bidirectional FSMN node and make it work in NDL and add FSMN CPU forward
CompositeFunction.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"
#include "SpecialPurposeNodes.h"
#include "SequenceReshapeNodes.h"
#include "UserDefinedFunction.h"
using namespace Microsoft::MSR::CNTK;
namespace CNTK
{
/*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
/*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);
static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";
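// Serializes only the composite's own metadata (version, type, root function Uid,
// optional name, and the composite's Uid). The full Serialize() below extends this
// dictionary with the serialized inputs and primitive functions of the graph.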
Dictionary CompositeFunction::SerializeBlockComposite() const
{
Dictionary dict;
dict[versionKey] = CurrentVersion();
dict[typeKey] = s_compositeFunctionTypeValue;
dict[rootKey] = RootFunction()->Uid();
if (!Name().empty())
dict[nameKey] = Name();
dict[uidKey] = Uid();
return dict;
}
// Copy the internal state from the network into the function graph,
// specifically from RngUser nodes into the attribute dictionaries of
// the corresponding stateful primitive functions.
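// This is invoked from GetInternalState() and Serialize() below, so that checkpoints
// capture the current RNG position of stateful nodes (e.g. Dropout) and random
// behavior can be reproduced after a restore.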
void CompositeFunction::UpdateInternalState() const
{
if (!m_computationNetwork)
return;
for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
// TODO: same for BatchNorm
auto& outputs = primitiveFunction->RawOutputs();
if (outputs.size() != 1)
LogicError("Function '%S' UpdateInternalState: a stateful primitive function must have a single output.", AsString().c_str());
const auto& rng = m_variableToNodeMap.at(outputs[0])->As<RngUser>();
Dictionary state;
state[PrimitiveFunction::AttributeNameRngSeed] = static_cast<size_t>(rng->GetRngSeed());
state[PrimitiveFunction::AttributeNameRngOffset] = static_cast<size_t>(rng->GetRngOffset());
primitiveFunction->SetState(state);
}
}
// Generate a dictionary representing the internal (local) state of the function graph.
Dictionary CompositeFunction::GetInternalState() const
{
UpdateInternalState();
Dictionary stateDictionary;
for (auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
// TODO: same for BatchNorm
stateDictionary[primitiveFunction->Uid()] = primitiveFunction->GetState();
}
return stateDictionary;
}
/*virtual*/ Dictionary CompositeFunction::Serialize() const
{
UpdateInternalState();
Dictionary dict = SerializeBlockComposite();
// Find cycles in the graph and "break" them by inserting placeholders.
// This needs to be done on Save, since here we have easy access to the shape and
// dynamic axis info.
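// E.g. (illustrative) for a recurrence z = Plus(PastValue(z), x), the edge feeding z
// back into PastValue is serialized as a Placeholder that reuses the output's Uid;
// DeserializeBlockComposite() later re-binds such placeholders to the actual outputs
// via the placeholder-replacements map.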
std::unordered_set<FunctionPtr> visitedFunctions;
std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
std::vector<Variable> uniqueInputs;
std::unordered_set<std::wstring> inputUids;
std::function<void(const FunctionPtr& function)> SerializationTraversalFunc;
SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
std::vector<Variable> functionInputs = function->Inputs();
for (const auto& input : functionInputs)
{
auto& uid = input.Uid();
if (inputUids.find(uid) != inputUids.end())
continue;
// check if this input corresponds to a cyclic edge in the graph.
// BUG: A function being visited twice does not necessarily indicate a cyclic edge in the graph.
// It may just mean that at least 2 successors in the graph have the function's output as an input.
bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());
if (mustBeReplaced)
{
auto varKind = VariableKind::Placeholder;
Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
uniqueInputs.push_back(var);
inputUids.insert(uid);
}
else if (!input.IsOutput())
{
// leave the input as is.
uniqueInputs.push_back(input);
inputUids.insert(uid);
}
}
visitedFunctions.insert(function);
topoSortedPrimitiveFunctions.push_back(function);
// For block functions we need to recursively traverse the underlying composite
if (function->IsBlock())
PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
};
PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);
std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));
assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());
std::vector<DictionaryValue> inputDictionaries;
inputDictionaries.reserve(uniqueInputs.size());
inputUids.clear();
for (const auto& input : uniqueInputs)
{
if (inputUids.find(input.Uid()) != inputUids.end())
LogicError("Function '%S' Serialize: Input uids must be unique.", AsString().c_str());
inputUids.insert(input.Uid());
inputDictionaries.push_back(input.Serialize());
}
dict[inputsKey] = std::move(inputDictionaries);
std::vector<DictionaryValue> functionDictionaries;
std::unordered_set<std::wstring> outputUids;
for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
{
for (const auto& output : primitiveFunction->RawOutputs())
{
if (outputUids.find(output.Uid()) != outputUids.end())
LogicError("Function '%S' Serialize: Output uids of all primitive functions in a function graph must be unique", AsString().c_str());
outputUids.insert(output.Uid());
}
auto functionDict = UDFUtils::IsUDF(primitiveFunction) ? UDFUtils::Serialize(primitiveFunction) : primitiveFunction->Serialize();
functionDictionaries.push_back(functionDict);
}
dict[functionsKey] = std::move(functionDictionaries);
return dict;
}
/*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);
const auto& rootUid = dict[rootKey].Value<std::wstring>();
std::wstring name = L"";
if (dict.Contains(nameKey))
name = dict[nameKey].Value<std::wstring>();
const auto& uid = dict[uidKey].Value<std::wstring>();
FunctionPtr root = *std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(), [&rootUid](const FunctionPtr& func) {
return func->Uid() == rootUid;
});
// Find the subset of placeholder replacements that apply for this composite
FunctionPtr composite = CompositeFunction::Create(root, name, uid);
std::unordered_map<Variable, Variable> placeholderReplacements;
do
{
placeholderReplacements.clear();
auto compositePlaceholders = composite->Placeholders();
for (auto placeholder : compositePlaceholders)
{
if (allPlaceholderReplacements.find(placeholder) != allPlaceholderReplacements.end())
placeholderReplacements.insert({ placeholder, allPlaceholderReplacements.at(placeholder) });
}
if (placeholderReplacements.size() > 0)
composite = composite->ReplacePlaceholders(placeholderReplacements);
} while (placeholderReplacements.size() > 0);
return composite;
}
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);
const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();
std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
for (const auto& dictionaryValue : inputs)
{
const auto& dictionary = dictionaryValue.Value<Dictionary>();
const auto& inputVar = Variable::Deserialize(dictionary, device);
if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
{
CNTK::LogicError("CompositeFunction::Deserialize: Input uids are not unique (several inputs share '%ls' uid) (%s).",
inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
}
uidToInputMap[inputVar.Uid()] = inputVar;
}
const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();
std::unordered_map<Variable, Variable> allPlaceholderReplacements;
std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
for (const auto& dictionaryValue : functions)
{
auto functionDict = dictionaryValue.Value<Dictionary>();
FunctionPtr root = UDFUtils::IsUDF(functionDict) ?
UDFUtils::Deserialize(functionDict, uidToInputMap, device) :
PrimitiveFunction::Deserialize(functionDict, uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
allPrimitiveFunctions.insert(root);
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
if (primitiveFunction != nullptr && primitiveFunction->OpType() == PrimitiveOpType::Combine)
{
// Since Combine simply forwards other functions' outputs, all of its outputs
// should already be in the uidToInputMap.
continue;
}
for (const auto& output : root->RawOutputs())
{
const auto& it = uidToInputMap.find(output.Uid());
if (it != uidToInputMap.end())
{
if (!it->second.IsPlaceholder())
{
CNTK::LogicError("CompositeFunction::Deserialize: Unexpected variable '%S' instead of a Placeholder (uid = %ls) (%s).",
it->second.AsString().c_str(), it->second.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
}
allPlaceholderReplacements[it->second] = output;
}
else
{
uidToInputMap[output.Uid()] = output;
}
}
}
// Starting with serialization version 3, the state is preserved inside the attribute dictionaries of the
// corresponding primitive functions. Earlier versions used a dedicated key-value pair in the composite function dictionary.
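// A pre-version-3 dictionary has roughly the following layout (illustrative; key
// names as consumed by RestoreStatefulFunctions below):
//   dict[stateKey] = { functionUid : { rngSeedKey : seed, rngOffsetKey : offset }, ... }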
if (version < 3)
RestoreStatefulFunctions(version, dict, allPrimitiveFunctions);
return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
}
void CompositeFunction::RestoreStatefulFunctions(size_t version, const Dictionary& dict, std::unordered_set<FunctionPtr> functions)
{
Dictionary stateDictionary;
if (dict.Contains(stateKey))
stateDictionary = dict[stateKey].Value<Dictionary>();
for (auto& function : functions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
if (stateDictionary.Contains(primitiveFunction->Uid()))
{
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
// Add key-value pairs expected by the SetState method to the state dictionary.
state[PrimitiveFunction::AttributeNameRngSeed] = state[rngSeedKey].Value<size_t>();
state[PrimitiveFunction::AttributeNameRngOffset] = state[rngOffsetKey].Value<size_t>();
primitiveFunction->SetState(state);
}
else
{
if (GetTraceLevel() >= TraceLevel::Warning)
{
// TODO: all logging functionality should be refactored to live in a logging utility class.
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
"when deserializing from a dictionary (version=%zu). "
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
}
// Create state from scratch, so that function attributes contain all the required key-value pairs.
Dictionary state;
state[PrimitiveFunction::AttributeNameRngSeed] = Internal::GenerateRandomSeed(true);
state[PrimitiveFunction::AttributeNameRngOffset] = size_t(0);
primitiveFunction->SetState(state);
}
}
}
void CompositeFunction::CopyState(const CompositeFunction& source)
{
// Collect a vector of stateful function uids using a pre-order traversal of the function graph.
auto collectStatefulFunctionUIDs = [](const Function& function) -> vector<wstring> {
vector<wstring> uids;
PreorderTraverseFunctions(function.RootFunction(), [&uids](const FunctionPtr& funcPtr) {
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
if (primitiveFunction && primitiveFunction->IsStateful())
{
uids.push_back(funcPtr->Uid());
}
}, true);
return uids;
};
auto theirUIDs = collectStatefulFunctionUIDs(source);
auto ourUIDs = collectStatefulFunctionUIDs(*this);
if (theirUIDs.size() != ourUIDs.size())
CNTK::LogicError("Cannot copy internal state, the source and the destination contain a different number of stateful functions.");
auto state = source.GetInternalState();
if (theirUIDs == ourUIDs)
{
// uids are identical, no need to remap.
SetInternalState(state);
return;
}
// build a map from the source function UIDs to the destination (this) function UIDs.
map<wstring, wstring> uidMap;
for (size_t i = 0; i < theirUIDs.size(); i++)
uidMap[theirUIDs[i]] = ourUIDs[i];
Dictionary remappedState;
for (auto& kv : state)
remappedState[uidMap[kv.first]] = kv.second;
SetInternalState(remappedState);
}
void CompositeFunction::SetInternalState(const Dictionary& state)
{
if (state.Size() == 0)
return;
for (const auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (!primitiveFunction || !primitiveFunction->IsStateful())
continue;
auto functionState = state[primitiveFunction->Uid()].Value<Dictionary>();
primitiveFunction->SetState(functionState);
if (!m_computationNetwork)
continue;
auto seed = functionState[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionState[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
// copy the state directly into the network
for (const auto& output : function->RawOutputs())
{
auto node = m_variableToNodeMap.at(output);
node->As<RngUser>()->SetRngState(seed, offset);
}
}
}
// Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions
// underlying the specified 'variable', and return the ComputationNode instance that corresponds to the
// top-level 'variable'.
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool useMangledNamesForComputationNodes)
{
auto iter = variableToNodeMap.find(variable);
if (iter != variableToNodeMap.end())
{
isVariableRootMap[variable] = false;
return iter->second;
}
// The DataType, Shape and DynamicAxes of the variable must be known by now
if (variable.GetDataType() == DataType::Unknown)
InvalidArgument("Variable '%S' with unknown DataType found when compiling the Function graph.", variable.AsString().c_str());
if (variable.Shape().IsUnknown())
InvalidArgument("Variable '%S' with unknown shape found when compiling the Function graph.", variable.AsString().c_str());
if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
InvalidArgument("Variable '%S' with unknown dynamic axes found when compiling the Function graph.", variable.AsString().c_str());
// Let's add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
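// E.g. (illustrative) for z = Plus(PastValue(z), x), visiting z's inputs eventually
// requests the node for z itself; the null entry terminates that recursion, and the
// missing input is attached later in CreateComputationNetwork() once all nodes exist.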
variableToNodeMap[variable] = nullptr;
std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes);
if (variable.IsParameter() || variable.IsConstant())
{
if (variable.Shape().HasInferredDimension())
InvalidArgument("Parameter or Constant '%S' with unresolved shape %S found when compiling the Function graph.", variable.AsString().c_str(), variable.Shape().AsString().c_str());
computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
computationNodePtr->SetLearningRateMultiplier(0.0);
NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
computationNodePtr->Value() = valueMatrix->AsReference();
else // Constant: if the initialized data lives on the wrong device, make a copy to the right one (a copy is OK since it's constant)
{
// TODO: the following two lines are a workaround for a bug in the Math library
// (AssignValuesOf throws when source and destination matrices reside on different GPU devices).
// Once this bug is fixed, change to
// Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
Matrix<ElementType> clonedMatrix(network->GetDeviceId());
clonedMatrix.SwitchToMatrixType(valueMatrix->GetMatrixType(), valueMatrix->GetFormat(), false);
clonedMatrix.AssignValuesOf(*valueMatrix);
computationNodePtr->Value() = std::move(clonedMatrix);
}
}
else if (variable.IsInput())
{
auto fullyDefinedArgumentVar = variable;
if (fullyDefinedArgumentVar.Shape().HasFreeDimension() && (fullyDefinedArgumentsMap.find(fullyDefinedArgumentVar) != fullyDefinedArgumentsMap.end()))
fullyDefinedArgumentVar = fullyDefinedArgumentsMap.at(fullyDefinedArgumentVar);
if (fullyDefinedArgumentVar.Shape().HasUnboundDimension())
InvalidArgument("Input Variable '%S' with unresolved shape %S found when compiling the Function graph.", fullyDefinedArgumentVar.AsString().c_str(), fullyDefinedArgumentVar.Shape().AsString().c_str());
// TODO: Input variables currently are required to have the default batch axis
auto dynamicAxes = variable.DynamicAxes();
auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
if (IsSparseInput(variable) && (foundDefaultBatchAxis == dynamicAxes.end()))
CNTK::LogicError("Sparse Input Variable '%S' found without a DefaultBatchAxis dynamic axis; this is currently unsupported.", variable.AsString().c_str());
if (!dynamicAxes.empty() && (dynamicAxes.back() != Axis::DefaultBatchAxis()))
CNTK::LogicError("Input Variable '%S' does not have the DefaultBatchAxis as its last dynamic axis.", variable.AsString().c_str());
// TODO: Support inputs with > 1 dynamic axes
if (dynamicAxes.size() > 2)
CNTK::LogicError("Input Variable '%S' has %d dynamic axes; currently only inputs with <= 2 dynamic axes are supported.",
variable.AsString().c_str(), (int)dynamicAxes.size());
if (!dynamicAxes.empty())
{
// Construct the dynamic axis name to be used internally for the CNTK InputNodes
std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});
if (IsSparseInput(variable))
computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName);
else
computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()), internalDynamicAxisName);
if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end()))
{
// Set a dummy learning rate multiplier to force gradient computation for the input computation node, since by default
// gradients are not computed for Input nodes
computationNodePtr->SetLearningRateMultiplier(0.00001f);
}
}
else
{
computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(fullyDefinedArgumentVar.Shape()));
network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
computationNodePtr->SetLearningRateMultiplier(0.0);
}
if (variable.Shape().HasFreeDimension())
computationNodePtr->MarkNeedsDynamicValidation();
}
else
{
assert(variable.IsOutput());
auto outputVariableNode = GetOutputVariableNode(variable, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
// Can be null in the case of loops with f.output == f.input.
// Such loops cannot be handled, so we leave nullptr as the computation node.
if (outputVariableNode)
computationNodePtr = outputVariableNode->template As<ComputationNode<ElementType>>()->shared_from_this();
else
computationNodePtr = nullptr;
}
variableToNodeMap[variable] = computationNodePtr;
if (isVariableRootMap.find(variable) == isVariableRootMap.end())
isVariableRootMap[variable] = variable.IsOutput();
return computationNodePtr;
}
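// The two helpers below map an output Variable to the Variable that actually carries
// its value: GetMappingForNoOpOutput() follows a NoOp to its single input, and
// GetMappingVariable() additionally maps a block function's output to the corresponding
// output of the block's underlying composite. With 'recursive', chains of such
// mappings are followed to their end.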
/*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive)
{
Variable mappingVariable = variable;
auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
if (ownerPrimitiveFunc && (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp))
mappingVariable = ownerPrimitiveFunc->Inputs()[0];
if (recursive && (mappingVariable != variable))
return GetMappingForNoOpOutput(mappingVariable);
else
return mappingVariable;
}
/*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive)
{
Variable mappingVariable = variable;
auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
if (ownerPrimitiveFunc)
{
if (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp)
mappingVariable = GetMappingForNoOpOutput(variable);
else
{
auto ownerBlockFunc = dynamic_cast<BlockFunction*>(ownerFunc);
if (ownerBlockFunc)
mappingVariable = ownerBlockFunc->CompositeOutputsMap().at(variable);
}
}
if (recursive && (mappingVariable != variable))
return GetMappingVariable(mappingVariable);
else
return mappingVariable;
}
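// Creates the internal ComputationNode corresponding to 'variable' (one of the outputs
// of 'function'), translating the PrimitiveOpType and the attributes stored in the
// function's config dictionary into the matching V1 node type. Secondary outputs of
// multi-output functions get an OutputMultiplexerNode, and user-defined functions are
// wrapped in a UserDefinedV2FunctionNode.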
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
Function* function,
const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
bool useMangledNamesForComputationNodes)
{
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
return variableToNodeMap[GetMappingVariable(variable)];
ComputationNodeBasePtr computationNodePtr;
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name(), useMangledNamesForComputationNodes);
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputNode : inputNodes)
inputNodesBasePtrs.push_back(inputNode);
auto outputs = function->RawOutputs();
if (variable == outputs[0])
{
if (primitiveFunction)
{
auto functionInputs = function->Inputs();
auto& functionConfig = function->Attributes();
PrimitiveOpType op = primitiveFunction->OpType();
switch (op)
{
case PrimitiveOpType::Negate:
computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sigmoid:
computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Tanh:
computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Cos:
computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sin:
computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ReLU:
computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Exp:
computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Log:
computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sqrt:
computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ELU:
computationNodePtr = New<ExponentialLinearUnitNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Floor:
computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Abs:
computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Reciprocal:
computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Softmax:
computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Hardmax:
computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::StableSigmoid:
computationNodePtr = New<StableSigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::TransposeAxes:
{
if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec))
{
auto perm = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
for (auto& p : perm)
p = NormalizeStaticAxis(p, perm.size());
computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(perm));
}
else
{
auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();
// The axis ids passed to the internal CNTK TransposeDimensionsNode are 1-based instead of 0-based
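// (e.g. the V2 Axis(0) maps to internal axis index 1)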
computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
}
break;
}
case PrimitiveOpType::Where:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
break;
}
case PrimitiveOpType::ToSequence:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
computationNodePtr = New<ToSequenceNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKDynamicAxisName);
break;
}
case PrimitiveOpType::ToSequenceLike:
computationNodePtr = New<ToSequenceLikeNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::UnpackSequence:
{
auto paddingValue = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackPaddingValue].Value<double>();
auto suppressMaskOutput = functionConfig[PrimitiveFunction::AttributeNameSequenceUnpackSuppressMaskOutput].Value<bool>();
computationNodePtr = New<UnpackSequenceNode<ElementType>>(network->GetDeviceId(), internalNodeName, (ElementType)paddingValue, suppressMaskOutput);
break;
}
case PrimitiveOpType::Slice:
{
std::vector<Axis> axis;
std::vector<int> beginIndex, endIndex;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxisVec) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndexVec) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndexVec))
{
axis = AsVector<Axis>(functionConfig[PrimitiveFunction::AttributeNameAxisVec].Value<std::vector<DictionaryValue>>());
beginIndex = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameBeginIndexVec].Value<std::vector<DictionaryValue>>());
endIndex = AsVector<int>(functionConfig[PrimitiveFunction::AttributeNameEndIndexVec].Value<std::vector<DictionaryValue>>());
}
else if (functionConfig.Contains(PrimitiveFunction::AttributeNameAxis) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameBeginIndex) &&
functionConfig.Contains(PrimitiveFunction::AttributeNameEndIndex))
{
axis.push_back(functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>());
beginIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>());
endIndex.push_back(functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>());
}
else
{
RuntimeError("Failed to create computation node: Slice operation with inconsistent attributes");
}
// The internal CNTK SliceNode takes 1-based axis indices instead of 0-based
computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis));
break;
}
case PrimitiveOpType::RandomSample:
{
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::RandomSampleInclusionFrequency:
{
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::Dropout:
{
auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
break;
}
case PrimitiveOpType::RandomDistribution:
{
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
auto rvtype = functionConfig[PrimitiveFunction::AttributeNameRandomDistributionType].Value<std::wstring>();
std::vector<double> randomDistributionArgs;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRandomDistributionArgs))
randomDistributionArgs = AsVector<double>(functionConfig[PrimitiveFunction::AttributeNameRandomDistributionArgs].Value<std::vector<DictionaryValue>>());
if (functionConfig.Contains(PrimitiveFunction::AttributeNameNewShape))
{
auto shape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
computationNodePtr = New<RandomDistributionNode<ElementType>>(network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs, AsTensorShape(shape));
}
else
computationNodePtr = New<RandomDistributionNode<ElementType>>(network->GetDeviceId(), internalNodeName, rvtype, randomDistributionArgs);
computationNodePtr->As<RandomDistributionNode<ElementType>>()->SetRngState(seed, offset);
break;
}
case PrimitiveOpType::Reshape:
{
auto beginAxis = Axis(0);
auto endAxis = Axis((int)functionInputs[0].Shape().Rank());
if (functionConfig.Contains(PrimitiveFunction::AttributeNameBeginAxis))
beginAxis = functionConfig[PrimitiveFunction::AttributeNameBeginAxis].Value<Axis>();
if (functionConfig.Contains(PrimitiveFunction::AttributeNameEndAxis))
endAxis = functionConfig[PrimitiveFunction::AttributeNameEndAxis].Value<Axis>();
auto replacementShape = functionConfig[PrimitiveFunction::AttributeNameNewShape].Value<NDShape>();
for (size_t i = 0; i < replacementShape.Rank(); ++i)
{
if (replacementShape[i] == NDShape::InferredDimension)
replacementShape[i] = 0;
}
computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(replacementShape), AsCNTKInternalAxisIdx(beginAxis), AsCNTKInternalAxisIdx(endAxis));
break;
}
case PrimitiveOpType::ROIPooling:
{
auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape));
break;
}
case PrimitiveOpType::Pooling:
{
PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto ceilOutDim = false;
auto includePad = false;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameCeilOutDim))
{
ceilOutDim = functionConfig[PrimitiveFunction::AttributeNameCeilOutDim].Value<bool>();
}
if (functionConfig.Contains(PrimitiveFunction::AttributeNameIncludePad))
{
includePad = functionConfig[PrimitiveFunction::AttributeNameIncludePad].Value<bool>();
}
computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ceilOutDim, includePad, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Unpooling:
{
auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
// We only get here after validation, so it is safe to assume the unpooling type is max
computationNodePtr = New<MaxUnpoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::SumAll:
computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::OneHot:
{
auto numClass = functionConfig[PrimitiveFunction::AttributeNameNumClass].Value<size_t>();
auto is_sparse = functionConfig[PrimitiveFunction::AttributeNameOneHotOutputSparse].Value<bool>();
auto axis = functionConfig[PrimitiveFunction::AttributeNameOneHotAxis].Value<Axis>();
computationNodePtr = New<OneHotNode<ElementType>>(network->GetDeviceId(), numClass, is_sparse, axis.StaticAxisIndex(), internalNodeName);
break;
}
case PrimitiveOpType::Gather:
computationNodePtr = New<GatherNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Plus:
computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LogPlus:
computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Pow:
computationNodePtr = New<PowNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Minus:
computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ElementTimes:
computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Equal:
computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NotEqual:
computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Less:
computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LessEqual:
computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Greater:
computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GreaterEqual:
computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Times:
{
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
break;
}
case PrimitiveOpType::TransposeTimes:
{
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
break;
}
case PrimitiveOpType::Convolution:
{
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
NDShape outputMapCount, kernelShape;
std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape(), transpose);
NDShape outputShape = NDShape::Unknown;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameOutputShape))
outputShape = functionConfig[PrimitiveFunction::AttributeNameOutputShape].Value<NDShape>();
auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName,
AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides),
sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose,
outputShape.IsUnknown() ? TensorShape(0) : AsTensorShape(outputShape),
ImageLayoutKind::CHW, maxTempMemSizeInSamples);
break;
}
case PrimitiveOpType::CosDistance:
computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CosDistanceWithNegativeSamples:
computationNodePtr = New<CosDistanceWithNegativeSamplesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Logistic:
computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::SquaredError:
computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CrossEntropyWithSoftmax:
computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ClassificationError:
computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::EditDistanceError:
{
auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value<float>();
auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value<float>();
auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value<float>();
auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value<bool>();
auto tokensToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNameTokensToIgnore].Value<std::vector<DictionaryValue>>());
computationNodePtr = New<EditDistanceErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName, subPen, delPen, insPen, squashInputs, tokensToIgnore);
break;
}
case PrimitiveOpType::ForwardBackward:
{
auto delayConstraint = functionConfig[PrimitiveFunction::AttributeNameDelayConstraint].Value<int>();
auto blankTokenId = functionConfig[PrimitiveFunction::AttributeNameBlankTokenId].Value<size_t>();
computationNodePtr = New<ForwardBackwardNode<ElementType>>(network->GetDeviceId(), internalNodeName, blankTokenId, delayConstraint);
break;
}
case PrimitiveOpType::LambdaRank:
computationNodePtr = New<LambdaRankNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NDCG:
computationNodePtr = New<NDCG1EvalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
{
Variable inputOperandVar = functionInputs[0];
Variable initialStateVar = functionInputs[1];
size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
if (op == PrimitiveOpType::PastValue)
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
else
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
break;
}
case PrimitiveOpType::ReduceElements:
{
bool keepDimensions = true;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameReductionKeepDimensions))
keepDimensions = functionConfig[PrimitiveFunction::AttributeNameReductionKeepDimensions].Value<bool>();
auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis), keepDimensions);
break;
}
case PrimitiveOpType::BatchNormalization:
{
auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();
computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Combine:
// This operation is just a no-op and is a means to combine multiple functions to create a single Function
// whose outputs are the union of the outputs of the Functions being combined.
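// E.g. (illustrative) Combine({plus, minus}) exposes both outputs without creating a
// new node; we simply reuse the node already mapped to this output variable.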
computationNodePtr = variableToNodeMap[variable];
break;
case PrimitiveOpType::PackedIndex:
computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GatherPacked:
computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ScatterPacked:
computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Clip:
computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Select:
computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Splice:
{
Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
break;
}
case PrimitiveOpType::OptimizedRNNStack:
{
auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();
computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
break;
}
case PrimitiveOpType::ReconcileDynamicAxis:
{
computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::LogSoftmax:
{
// This can be implemented as x => x - ReduceLogSum(x). How to do this here?
computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::Pass:
computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LabelsToGraph:
computationNodePtr = New<LabelsToGraphNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::StopGradient:
computationNodePtr = New<StopGradientNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Assign:
computationNodePtr = New<AssignNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
default:
CNTK::LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
break;
}
// Let's reorder inputNodesBasePtrs properly, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction input ordering
ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);
if (computationNodePtr->Is<INumInputs>())
{
auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
CNTK::LogicError("The Primitive Function '%S' has %d inputs while the corresponding ComputationNode expects %d inputs.",
function->AsString().c_str(),
(int)inputNodesBasePtrs.size(),
(int)computationNodeExpectedInputCount);
}
if (computationNodePtr->Is<RngUser>())
{
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
}
}
else
{
computationNodePtr = New<UserDefinedV2FunctionNode<ElementType>>(network->GetDeviceId(), internalNodeName, function->shared_from_this());
// For user defined functions, we only attach unique inputs in the internal computation network since the UDF
// backward implementations directly compute aggregate gradient values for unique inputs
std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
for (auto inputNodeBasePtr : inputNodesBasePtrs)
{
if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
}
inputNodesBasePtrs = uniqueInputNodesBasePtrs;
}
}
else
{
size_t i = 1;
while ((i < outputs.size()) && (outputs[i] != variable)) i++;
assert(i < outputs.size());
computationNodePtr = New<OutputMultiplexerNode<ElementType>>(network->GetDeviceId(), CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name(), useMangledNamesForComputationNodes), i);
inputNodesBasePtrs = { variableToNodeMap[outputs[0]] };
}
network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
return computationNodePtr;
}
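// Creates (or retrieves) the ComputationNode for an output Variable: first creates
// nodes for all inputs of the owning Function, then either recurses into a block
// function's underlying composite or delegates to CreateComputationNode() above.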
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool useMangledNamesForComputationNodes)
{
assert(variable.IsOutput());
Function* function = variable.Owner().get();
ComputationNodeBasePtr computationNodePtr;
auto& functionInputs = function->m_inputs;
DataType nonConstInputDataType = DataType::Unknown;
for (auto& inputVar : functionInputs)
{
if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown))
{
nonConstInputDataType = inputVar.GetDataType();
break;
}
}
// Create the nodes corresponding to the inputs
std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
for (auto& inputVar : functionInputs)
{
// If the inputVar is a constant and not of the right DataType, let's coerce it to the right type
if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
inputVar = Constant(inputVar).CloneAs(nonConstInputDataType);
auto baseNodePtr = GetNode(inputVar, network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
}
BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function);
if (blockFunction)
{
// For a block function, map each argument placeholder of the underlying composite to
// the computation node corresponding to the block input that the argument placeholder
// of the composite is mapped to.
auto compositeArguments = blockFunction->Composite()->Arguments();
for (auto compositeArgument : compositeArguments)
variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());
return GetNode(variable.BlockFunctionVariableMapping(), network, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor, useMangledNamesForComputationNodes);
}
else
computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap, useMangledNamesForComputationNodes);
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine))
{
for (auto inputVar : functionInputs)
isVariableRootMap[inputVar] = false;
}
return computationNodePtr;
}
std::unordered_set<Variable> CompositeFunction::NonOwnerPreservingCopy(const std::unordered_set<Variable>& outputs)
{
std::unordered_set<Variable> result;
for (auto& o : outputs)
{
Variable sanitized = o.NonCompositePreservingCopy();
result.insert(sanitized);
}
return result;
}
static bool VariableShapeMatchesNodeShape(const NDShape& varShape, const TensorShape& nodeShape)
{
if (varShape.Rank() == 0)
return (nodeShape.GetNumElements() == 1);
// Sometimes the nodeShape may have an additional trailing axis with dim==1, due to the lack of support for 0-d tensors in the V1 engine.
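// E.g. (illustrative) a Variable of shape [3 x 4] is considered to match a node
// TensorShape of [3 x 4 x 1].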
auto adjustedNodeShape = nodeShape;
while ((adjustedNodeShape.GetRank() > varShape.Rank()) && (adjustedNodeShape.GetDim(adjustedNodeShape.GetRank() - 1) == 1))
adjustedNodeShape.TrimRankInPlace(adjustedNodeShape.GetRank() - 1);
if (!varShape.HasUnboundDimension())
return (AsNDShape(adjustedNodeShape) == varShape);
if (varShape.Rank() != adjustedNodeShape.GetRank())
return false;
for (size_t i = 0; i < varShape.Rank(); ++i)
{
if ((varShape[i] != NDShape::FreeDimension) && (varShape[i] != NDShape::InferredDimension) && (varShape[i] != adjustedNodeShape.GetDim(i)))
return false;
}
return true;
}
template <typename ElementType>
std::pair<ComputationNetworkPtr, std::unordered_map<Variable, ComputationNodeBasePtr>>
CompositeFunction::CreateComputationNetwork(const FunctionPtr& compositeFunction,
const DeviceDescriptor& device,
const std::unordered_set<Variable>& outputs,
const std::unordered_map<Variable, Variable>& fullyDefinedArgumentsMap,
const std::unordered_set<Variable>& inputsExcludedFromGradientComputation,
bool useMangledNamesForComputationNodes)
{
auto computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
ComputationNetworkBuilder<ElementType> builder(*computationNetwork);
std::unordered_map<Variable, bool> isVariableRootMap;
std::unordered_map<Variable, ComputationNodeBasePtr> variableToNodeMap;
// Now recursively create the network in a top-down fashion
auto rootFunction = compositeFunction->RootFunction();
auto rootFunctionOutputs = rootFunction->RawOutputs();
for (auto rootOutput : rootFunctionOutputs)
GetNode(rootOutput, computationNetwork, builder, fullyDefinedArgumentsMap, variableToNodeMap, isVariableRootMap, inputsExcludedFromGradientComputation, useMangledNamesForComputationNodes);
// We need to patch the ComputationNode mappings for the arguments of block functions,
// since for recurrent inputs the mappings are not fully established on the first pass
std::function<void(const Variable&)> PatchBlockArgumentsAndOutputsMapping;
PatchBlockArgumentsAndOutputsMapping = [&variableToNodeMap, &PatchBlockArgumentsAndOutputsMapping](const Variable& var) {
if (var.IsOutput())
{
BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(var.Owner().get());
if (blockFunction)
{
PostorderTraverseVariables(blockFunction->BlockRoot(), PatchBlockArgumentsAndOutputsMapping);
auto compositeArguments = blockFunction->Composite()->Arguments();
for (auto compositeArgument : compositeArguments)
{
auto mappingVarNodeIter = variableToNodeMap.find(compositeArgument.BlockFunctionVariableMapping());
if (mappingVarNodeIter != variableToNodeMap.end())
variableToNodeMap[compositeArgument] = mappingVarNodeIter->second;
}
auto mappingVarNodeIter = variableToNodeMap.find(var.BlockFunctionVariableMapping());
if (mappingVarNodeIter != variableToNodeMap.end())
variableToNodeMap[var] = mappingVarNodeIter->second;
}
}
};
PostorderTraverseVariables(rootFunction, PatchBlockArgumentsAndOutputsMapping);
std::function<bool(const Variable&)> IsVariableRoot = [&isVariableRootMap, &IsVariableRoot](const Variable& outputVar) {
auto mappingVariable = GetMappingVariable(outputVar);
return (isVariableRootMap[outputVar] && !IsFirstOutputOfMultiOutputFunction(mappingVariable) && ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
};
// If any of the function outputs or requested outputs is not a root node, we need to explicitly
// add it to the 'output' group of the ComputationNetwork
std::unordered_set<Variable> networkOutputs(outputs);
networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
for (auto output : networkOutputs)
{
if (!IsVariableRoot(output))
{
auto computationNode = variableToNodeMap[output];
if (!computationNode)
InvalidArgument("One of the requested outputs '%S' for the Function '%S' forward computation is not part of the graph underlying the Function.",
output.AsString().c_str(), compositeFunction->AsString().c_str());
computationNetwork->AddToNodeGroup(L"output", computationNode);
}
}
// In case of recurrence, the inputs of some ComputationNodes were left unattached due to cycles.
// Attach those now that all ComputationNodes in the network have been created
for (auto varNodePair : variableToNodeMap)
{
auto& currentComputationNode = varNodePair.second;
if (!currentComputationNode)
LogicError("Function '%S': No computation node mapping exists for Variable %S.", compositeFunction->AsString().c_str(), varNodePair.first.AsString().c_str());
auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
auto& currentVar = varNodePair.first;
if (!currentVar.IsOutput())
continue;
if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
{
// This ComputationNode has at least one null input which now needs to be properly attached
const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());
// Skip block primitives since they do not directly map to a computation node;
// also guard against a failed cast to avoid dereferencing a null pointer
if (!primitiveFunc || (primitiveFunc->OpType() == PrimitiveOpType::Block))
continue;
// Reorder the inputs, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
auto inputVars = primitiveFunc->Inputs();
ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
inputVars.resize(currentComputationNode->GetNumInputs());
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputVar : inputVars)
inputNodesBasePtrs.push_back(variableToNodeMap.at(inputVar));
currentComputationNode->AttachInputs(inputNodesBasePtrs);
}
}
computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
computationNetwork->SetTrackGapNans(GetCheckedMode());
computationNetwork->SetIsV2Library(true);
computationNetwork->CompileNetwork();
// Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
for (auto varNodePair : variableToNodeMap)
{
if (varNodePair.first.IsOutput())
{
auto outputVar = varNodePair.first;
auto computationNodePtr = variableToNodeMap.at(outputVar);
auto outputShape = outputVar.Shape();
auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
if (!VariableShapeMatchesNodeShape(outputShape, computationNodeSampleLayout))
{
LogicError("Function '%S': The output Variable '%S' shape '%S' does not match the SampleLayout shape '[%s]' of the corresponding ComputationNode in the network.",
compositeFunction->AsString().c_str(), outputVar.AsString().c_str(), outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
}
}
}
return { computationNetwork, variableToNodeMap };
}
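// Create, or reuse and revalidate, the ComputationNetwork backing this CompositeFunction for the
// specified device, backprop roots and outputs. The network is cached across Forward calls;
// changing the backprop roots, device or gradient-excluded inputs afterwards is unsupported.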
template <typename ElementType>
ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device,
const std::unordered_set<Variable>& backpropRoots,
const std::unordered_set<Variable>& outputs,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool allocateNetworkMatrices)
{
// Purge the current computation network and regenerate it if the CompositeFunction
// was previously compiled just for evaluation and gradient backpropagation is now required.
if ((m_computationNetwork != nullptr) && (m_currentBackpropRoots.empty() && !backpropRoots.empty()))
PurgeComputationNetwork();
if (m_computationNetwork != nullptr)
{
// TODO: We should either invalidate and re-adapt the network when the backpropRoots change
// from what was specified when the network was last constructed, or just recreate the network.
// For now, simply disallow changing the backpropRoots after the network is created
if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
LogicError("Function '%S': Changing backprop roots (Current = '%S', New = '%S') across different Forward calls on a CNTK composite Function is currently unsupported.",
AsString().c_str(), NamedListString(m_currentBackpropRoots).c_str(), NamedListString(backpropRoots).c_str());
// TODO: Support changing the device across different invocations of the forward method on a Function instance
if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
LogicError("Function '%S': Changing device (Current = '%S', New = %S') across different Forward calls on a CNTK composite Function is currently unsupported.",
AsString().c_str(), AsDeviceDescriptor(m_computationNetwork->GetDeviceId()).AsString().c_str(), device.AsString().c_str());
if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation))
LogicError("Function '%S': Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported.", AsString().c_str());
// Verify if the free dimensions of any of the arguments have changed, and if so, update the corresponding
// input ComputationNodes and rerun validation on the computation network
for (auto freeDimensionArgumentMapping : m_fullyDefinedArgumentsMap)
{
auto newShape = freeDimensionArgumentMapping.second.Shape();
auto argumentComputationNode = m_variableToNodeMap[freeDimensionArgumentMapping.first];
if (AsTensorShape(newShape) != argumentComputationNode->GetSampleLayout())
argumentComputationNode->SetDims(AsTensorShape(newShape), argumentComputationNode->HasMBLayout());
}
}
else
{
auto networkInputs = this->Inputs();
for (auto inputExcluded : inputsToExcludeGradientsFor)
{
// Only inputs of the network can be excluded from gradient computation
if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end())
InvalidArgument("Variable '%S' specified for exclusion from gradient computation is not an input of the Function '%S'. "
"Only an input of the Function can be explicitly excluded from gradient computation.",
inputExcluded.AsString().c_str(), this->AsString().c_str());
}
m_inputsExcludedFromGradientComputation = NonOwnerPreservingCopy(inputsToExcludeGradientsFor);
m_currentBackpropRoots = NonOwnerPreservingCopy(backpropRoots);
// TODO: We currently only support one backprop root
if (backpropRoots.size() > 1)
LogicError("Function '%S': %d backprop roots specified; currently at most one backprop root is supported.", AsString().c_str(), (int)backpropRoots.size());
auto placeholders = Placeholders();
if (!placeholders.empty())
InvalidArgument("%d unbound Placeholder(s) '%S' found in the Function. "
"All Placeholders of a Function must be bound (to a variable) before performing a Forward computation.",
(int)placeholders.size(), NamedListString(placeholders).c_str());
// Let's update the composite Function graph's inputs with any inferred dimensions that
// were determined from the shapes of the supplied data
auto networkArguments = Arguments();
for (auto argument : networkArguments)
{
if (argument.Shape().HasInferredDimension())
{
auto fullyDefinedArgument = m_fullyDefinedArgumentsMap.at(argument);
for (size_t i = 0; i < argument.Shape().Rank(); ++i)
if (argument.Shape()[i] == NDShape::InferredDimension)
argument.m_dataFields->m_shape[i] = fullyDefinedArgument.Shape()[i];
}
}
// Run the final validation on the entire network once before constructing/compiling the
// internal computation network
ValidateOrUpdateOutputs();
std::tie(m_computationNetwork, m_variableToNodeMap) = CreateComputationNetwork<ElementType>(this->shared_from_this(), device, outputs, m_fullyDefinedArgumentsMap, m_inputsExcludedFromGradientComputation, /*useMangledNamesForComputationNodes =*/ false);
// Record the timestamps of Parameters and Constants
assert(m_lastRecordedTimeStamps.empty());
auto functionParameters = Parameters();
for (auto parameter : functionParameters)
m_lastRecordedTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
auto functionConstants = Constants();
for (auto constant : functionConstants)
m_lastRecordedTimeStamps.insert({ constant, constant.CurrentValueTimeStamp() });
// Collect parameters and constants being assigned to
PreorderTraverseFunctions(RootFunction(), [this](const FunctionPtr& function) {
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(function.get());
if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::Assign))
m_refVariables.insert(primitiveFunction->Inputs()[0]);
}, /*nestedSearchInsideBlockFunction =*/ true);
}
if (!m_networkMatricesAllocated && allocateNetworkMatrices)
{
m_allNetworkRoots = m_currentBackpropRoots;
ComputationNodeBasePtr backpropRootNode;
if (!m_currentBackpropRoots.empty())
backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());
// Collect the root Function's outputs and their corresponding ComputationNodes as forward roots
auto rootFunction = RootFunction();
auto rootFunctionOutputs = rootFunction->RawOutputs();
m_allNetworkRoots.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
std::vector<ComputationNodeBasePtr> forwardRootNodes;
for (auto rootOutput : rootFunctionOutputs)
forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));
std::vector<ComputationNodeBasePtr> forwardOutputNodes;
m_allNetworkRoots.insert(outputs.begin(), outputs.end());
for (auto output : outputs)
forwardOutputNodes.push_back(m_variableToNodeMap.at(output));
m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
m_networkMatricesAllocated = allocateNetworkMatrices;
}
else
{
// Make sure the requested outputs are a subset of the outputs for which we set up the matrix
// allocation structures in the cached computation network
for (auto output : outputs)
{
if (m_allNetworkRoots.find(output) == m_allNetworkRoots.end())
LogicError("Function '%S': Requested output '%S' is not part of the list of outputs '%S' that the Function was initially compiled for. "
"Changing requested outputs across different Forward calls is currently unsupported.",
AsString().c_str(), output.AsString().c_str(), NamedListString(m_allNetworkRoots).c_str());
}
}
return m_computationNetwork;
}
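// Copy the data from the supplied Value object into the Value matrix and MBLayout of the
// ComputationNode corresponding to the given Variable, verifying that shapes and minibatch
// layouts are compatible.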
template <typename ElementType>
/*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map<MBLayoutPtr, Variable>& layoutsPopulated)
{
NDShape inferredVariableShape;
std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second, &inferredVariableShape);
if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout()))
CNTK::LogicError("CompositeFunction::Forward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.",
inferredVariableShape.AsString().c_str(), variableValue.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str());
// Assign the supplied data to the node's Value matrix, switching the node matrix to the right matrix type if needed
auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
auto layout = CNTKMatrixAndMBLayout.second;
auto& nodeLayout = computationNode->GetMBLayout();
if ((layout == nullptr) != (nodeLayout == nullptr))
InvalidArgument("The layout of the specified Value for Variable '%S' is incompatible with the layout of the corresponding ComputationNode.", variableValue.first.AsString().c_str());
else if (layout)
{
if (layoutsPopulated.find(nodeLayout) == layoutsPopulated.end())
{
nodeLayout->CopyFrom(layout);
layoutsPopulated.insert({ nodeLayout, variableValue.first });
}
else
{
if (*nodeLayout != *layout)
InvalidArgument("Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified "
"for the Function's arguments '%S' vs. '%S', though these arguments have the same dynamic axes '%S'",
variableValue.first.AsString().c_str(), layoutsPopulated.at(nodeLayout).AsString().c_str(), DynamicAxesAsString(variableValue.first.DynamicAxes(), Internal::IsReversingTensorShapesInErrorMessagesEnabled()).c_str());
}
}
}
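// Infer concrete values for any free dimensions in the arguments' shapes from the data supplied
// for them; returns the arguments whose inferred shape differs from their declared shape.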
std::unordered_map<Variable, NDShape> CompositeFunction::InferFreeDimensionsOfArguments(const std::unordered_map<Variable, ValuePtr>& arguments)
{
std::unordered_map<Variable, NDShape> inferredArgumentDimensions;
for (auto argumentValuePair : arguments)
{
NDShape inferredVarShape;
Utils::VerifyVariableValueCompatibility(argumentValuePair.first, argumentValuePair.second, &inferredVarShape);
if (inferredVarShape != argumentValuePair.first.Shape())
inferredArgumentDimensions.insert({ argumentValuePair.first , inferredVarShape });
}
if (!inferredArgumentDimensions.empty())
{
if (m_fullyDefinedArgumentsMap.empty())
{
for (auto inferredArgumentShapePair : inferredArgumentDimensions)
{
auto fullyDefinedArgument = inferredArgumentShapePair.first.Clone();
fullyDefinedArgument.m_dataFields->m_shape = inferredArgumentShapePair.second;
m_fullyDefinedArgumentsMap.insert({ inferredArgumentShapePair.first, fullyDefinedArgument });
}
if (GetCheckedMode())
m_latestFullyDefinedCompositeForCheckedModeValidation = this->Clone(ParameterCloningMethod::Share, m_fullyDefinedArgumentsMap);
}
else
{
bool argumentShapeChangedSinceLastTime = false;
for (auto inferredArgumentShapePair : inferredArgumentDimensions)
{
if (inferredArgumentShapePair.second != m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].Shape())
{
argumentShapeChangedSinceLastTime = true;
m_fullyDefinedArgumentsMap[inferredArgumentShapePair.first].m_dataFields->m_shape = inferredArgumentShapePair.second;
}
}
if (argumentShapeChangedSinceLastTime && m_latestFullyDefinedCompositeForCheckedModeValidation)
m_latestFullyDefinedCompositeForCheckedModeValidation->ValidateOrUpdateOutputs();
}
}
return inferredArgumentDimensions;
}
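// Feed the supplied argument Values into the corresponding input nodes of the computation network
// and bump their eval timestamps so the new data is used in the next ForwardProp.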
void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
{
std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
std::vector<ComputationNodeBasePtr> inputNodes;
for (auto argumentValuePair : arguments)
{
auto argument = argumentValuePair.first;
auto argumentComputationNode = m_variableToNodeMap.at(argument);
assert(argumentComputationNode);
inputNodes.push_back(argumentComputationNode);
ValuePtr argumentValue = arguments.at(argument);
switch (argumentValue->GetDataType())
{
case DataType::Float:
PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
break;
case DataType::Double:
PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
break;
default:
LogicError("Function '%S' Forward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(argumentValue->GetDataType()));
break;
}
}
m_computationNetwork->BumpEvalTimeStamp(inputNodes);
}
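// Assign a supplied root gradient Value to the Gradient matrix of the corresponding
// ComputationNode, verifying that its shape and minibatch layout match those computed during
// the preceding Forward call.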
template <typename ElementType>
/*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
{
NDShape inferredVariableShape;
std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second, &inferredVariableShape);
if (!VariableShapeMatchesNodeShape(inferredVariableShape, computationNode->GetSampleLayout()))
CNTK::LogicError("CompositeFunction::Backward: Inferred shape '%S' of Variable '%S' does not match the corresponding computation node shape '%s'.",
inferredVariableShape.AsString().c_str(), variableGradient.first.AsString().c_str(), ((std::string)computationNode->GetSampleLayout()).c_str());
MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
auto nodeLayout = computationNode->GetMBLayout();
if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
InvalidArgument("The layout of the specified gradient Value for Variable '%S' is incompatible with the layout computed during Forward call.",
variableGradient.first.AsString().c_str());
computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
}
// Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
{
auto functionOutputs = RawOutputs();
for (auto gradientVarValuePair : gradients)
{
auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first);
ValuePtr gradientValue = gradientVarValuePair.second;
switch (gradientValue->GetDataType())
{
case DataType::Float:
PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
break;
case DataType::Double:
PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
break;
default:
LogicError("Function '%S' Backward: Unsupported DataType %s.", AsString().c_str(), DataTypeName(gradientValue->GetDataType()));
break;
}
}
}
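// Read back either the Value or the Gradient matrix of 'computationNode' into 'varValue'.
// If no Value object was supplied, a PackedValue aliasing the node's internal storage is
// returned; otherwise the data is copied into the supplied Value object.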
/*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
{
auto varShape = GetVariableShape(var.Shape(), computationNode->GetSampleLayout());
auto valueShape = PackedValue::GetUnpackedShape(varShape, var.DynamicAxes(), computationNode->GetMBLayout());
if (varValue != nullptr)
{
// TODO: The shape of the specified output Value object must match the actual output shape
if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
::InvalidArgument("The shape %S of the specified %s Value object for Variable '%S' does not match the actual shape %S.",
varValue->Shape().AsString().c_str(), getGradient ? "gradient" : "output", var.AsString().c_str(), valueShape.AsString().c_str());
}
ValuePtr nodeValue;
auto layout = computationNode->GetMBLayout();
switch (var.GetDataType())
{
case DataType::Float:
{
auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
if (varValue == nullptr)
nodeValue = MakeSharedObject<PackedValue>(varShape, var.DynamicAxes(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
else
nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, computationNode, matrix, layout);
break;
}
case DataType::Double:
{
auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
if (varValue == nullptr)
nodeValue = MakeSharedObject<PackedValue>(varShape, var.DynamicAxes(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
else
nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, computationNode, matrix, layout);
break;
}
default:
CNTK::LogicError("CompositeFunction::Forward/Backward: Unsupported DataType %s", DataTypeName(var.GetDataType()));
break;
}
if (varValue == nullptr)
varValue = nodeValue;
else
varValue->CopyFrom(*nodeValue);
}
void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
{
// Now copy the Forward values of output nodes from the network to outputs' Value objects
for (auto& outputVarValuePair : outputs)
{
auto& valuePtr = outputVarValuePair.second;
auto node = m_variableToNodeMap.at(outputVarValuePair.first);
bool noValueStorageProvided = (valuePtr == nullptr);
GetNodeOutputOrGradient(outputVarValuePair.first, valuePtr, node, false /*getGradient*/);
auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
m_existingNetworkStorageReferences.push_back(packedVarValue);
}
}
void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
{
auto networkInputs = this->Inputs();
// Now copy the gradient values of input nodes of the network to gradients' Value objects
for (auto& gradientVarValuePair : gradients)
{
// Only gradients corresponding to inputs of the network can be obtained
if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
InvalidArgument("Gradient requested for Variable '%S' which is not a leaf input (Input, Parameter or Constant) of the Function '%S'; this is currently unsupported.",
gradientVarValuePair.first.AsString().c_str(), this->AsString().c_str());
// Gradients can only be obtained for parameter variables or input variables that NeedsGradient
if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end()))
InvalidArgument("Gradient value incorrectly requested for Variable '%S', "
"an Output or Constant or Input Variable with NeedsGradient setting of false, or an input for which gradient computation was explicitly excluded.",
gradientVarValuePair.first.AsString().c_str());
auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first);
if (!computationNodePtr->NeedsGradient())
LogicError("Function '%S': Backpropagated gradient value cannot be read from a Variable '%S' whose ComputationNode has NeedsGradient set to false.",
AsString().c_str(), gradientVarValuePair.first.AsString().c_str());
auto& valuePtr = gradientVarValuePair.second;
bool noValueStorageProvided = (valuePtr == nullptr);
GetNodeOutputOrGradient(gradientVarValuePair.first, valuePtr, computationNodePtr, true /*getGradient*/);
auto packedVarValue = std::dynamic_pointer_cast<PackedValue>(valuePtr);
if (noValueStorageProvided && packedVarValue && packedVarValue->IsPacked())
m_existingNetworkStorageReferences.push_back(packedVarValue);
}
}
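// Compute, and cache, the set of arguments that the given output Variable transitively depends
// on; Parameters and Constants have no argument dependencies.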
const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
{
if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end())
{
auto sanitizedOutput = output.NonCompositePreservingCopy();
if (sanitizedOutput.IsOutput())
m_perOutputVarArgumentDependencies[sanitizedOutput] = AsComposite(sanitizedOutput.Owner())->Arguments();
else if (sanitizedOutput.IsParameter() || sanitizedOutput.IsConstant())
m_perOutputVarArgumentDependencies[sanitizedOutput] = {};
else
m_perOutputVarArgumentDependencies[sanitizedOutput] = { sanitizedOutput };
}
return m_perOutputVarArgumentDependencies[output];
}
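// Snapshot the eval timestamps of the current backprop root nodes; Forward stores these in the
// returned BackPropState and Backward checks them to detect intervening Forward calls.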
std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
{
std::unordered_map<Variable, uint64_t> currentBackpropRootsTimeStamps;
assert(m_computationNetwork != nullptr);
for (auto& backpropRoot : m_currentBackpropRoots)
currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp();
return currentBackpropRootsTimeStamps;
}
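// Run a forward pass: validate the requested outputs and supplied argument values, create or
// reuse the computation network, populate its inputs, evaluate the requested output nodes, and
// copy their values into 'outputs'. Returns a BackPropState when any outputs are requested to
// retain state for a subsequent Backward call.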
/*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
std::unordered_map<Variable, ValuePtr>& outputs,
const DeviceDescriptor& computeDevice,
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
{
// Validate arguments and outputs
if (outputs.empty())
InvalidArgument("At least one output has to be specified when calling Forward method of the Function '%S'.", this->AsString().c_str());
// Make sure that the DataType of the variables and corresponding values match
// TODO: We need a better way to determine the ElementType for the network
auto dataType = DataType::Unknown;
for (auto variableValuePair : arguments)
{
if (dataType == DataType::Unknown)
dataType = variableValuePair.first.GetDataType();
else if (dataType != variableValuePair.first.GetDataType())
LogicError("Function '%S' Forward: The DataType of all arguments must be same.", this->AsString().c_str());
}
if (dataType == DataType::Unknown)
{
for (auto variableValuePair : outputs)
{
if (dataType == DataType::Unknown)
dataType = variableValuePair.first.GetDataType();
}
}
std::unordered_set<Variable> requestedOutputVariables;
for (auto output : outputs)
requestedOutputVariables.insert(output.first);
std::unordered_set<Variable> functionOutputs(m_outputs.begin(), m_outputs.end());
std::unordered_set<Variable> requiredArguments;
for (auto outputVariable : requestedOutputVariables)
{
auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVariable);
requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());
}
// We should have argument values supplied for all required argument dependencies for the requested outputs
std::vector<Variable> missingRequiredArguments;
std::unordered_map<Variable, ValuePtr> requiredArgumentValues;
for (auto requiredArgument : requiredArguments)
{
auto iter = arguments.find(requiredArgument);
if (iter == arguments.end())
missingRequiredArguments.push_back(requiredArgument);
else
requiredArgumentValues.insert(*iter);
}
if (!missingRequiredArguments.empty())
{
InvalidArgument("Values for %d required arguments '%S', that the requested output(s) '%S' depend on, have not been provided.",
(int)missingRequiredArguments.size(), NamedListString(missingRequiredArguments).c_str(), NamedListString(requestedOutputVariables).c_str());
}
if (requiredArgumentValues.size() < arguments.size())
fprintf(stderr, "WARNING: Function::Forward was provided values for %d extra arguments that are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size()));
auto inferredArgumentShapes = InferFreeDimensionsOfArguments(requiredArgumentValues);
if (dataType == DataType::Float)
GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
else if (dataType == DataType::Double)
GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
else
InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));
// Feed data into the arguments of the network
// TODO: Avoid copying the data when possible
PopulateNetworkInputs(requiredArgumentValues);
// Copy all new values for 'dirty' attributes from functions into corresponding network nodes.
ApplyAttributeUpdates();
// Bump the timestamp of the parameter nodes whose values have changed
for (auto& timeStampRecord : m_lastRecordedTimeStamps)
{
auto variable = timeStampRecord.first;
auto prevTimeStamp = timeStampRecord.second;
auto newTimeStamp = variable.CurrentValueTimeStamp();
if (newTimeStamp > prevTimeStamp)
{
timeStampRecord.second = newTimeStamp;
m_variableToNodeMap.at(variable)->BumpEvalTimeStamp();
}
}
std::vector<ComputationNodeBasePtr> outputsToEvaluate;
for (auto outputVariable : requestedOutputVariables)
outputsToEvaluate.push_back(m_variableToNodeMap.at(outputVariable));
// The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
{
if (outputs.find(rootVarForBackprop) == outputs.end())
outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop));
}
// Reset the timestamps of all backward roots to record an update in one or more inputs
for (auto& backpropRoot : m_currentBackpropRoots)
m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();
// Reset the timestamps of all dropout nodes to force recomputation of the (random) dropout mask.
list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType<DropoutNodeBase>();
for (auto& dropout : dropoutNodes)
dropout->SetEvalTimeStampOutdatedWrtAll();
// Free any previous references to the matrix storage associated with the outputsToEvaluate
ClearExistingOutputOrGradientStorageReferences();
ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);
m_computationNetwork->ForwardProp(outputsToEvaluate);
// Call PostForwardAndBackProp after ForwardProp only in evaluation mode.
if (outputsToRetainBackwardStateFor.empty())
{
m_computationNetwork->PostForwardAndBackProp(outputsToEvaluate);
RecordRefVariableUpdates();
}
else
{
m_currentOutputsToEvaluate.clear();
for (auto outputToEvaluate : outputsToEvaluate)
m_currentOutputsToEvaluate.push_back(outputToEvaluate);
}
GetNetworkOutputs(outputs);
// TODO: How to deal with the specified 'computeDevice'
BackPropStatePtr backpropStatePtr;
if (outputsToRetainBackwardStateFor.size() > 0)
backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());
return backpropStatePtr;
}
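// Run a backward pass for the given backprop state: seed the root gradients, backpropagate
// through the network, and copy the computed gradients for the requested inputs into
// 'backPropagatedGradientValuesForInputs'.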
/*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
{
auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
if (backpropState == nullptr)
InvalidArgument("Function '%S' Backward: Invalid backprop state passed.", AsString().c_str());
if (backPropagatedGradientValuesForInputs.empty())
InvalidArgument("Function '%S' Backward: List of inputs to compute gradients for, must not be empty.", AsString().c_str());
// TODO: Support multiple concurrent backprop states
std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
LogicError("Function '%S' Backward: The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified "
"by subsequent Forward calls to the function. This is not a user error but a shortcoming of the current implementation where multiple independent "
"backprop states are not simultaneously supported.", AsString().c_str());
if (rootGradientValues.size() > 1)
LogicError("Function '%S' Backward: %d root gradient values specified; currently gradient backprop from only one of the Function Outputs is supported.",
AsString().c_str(), (int)rootGradientValues.size());
// TODO: Avoid copying the data when possible
// Zero all gradients of nodes below the root nodes
for (auto rootGradientVarValuePair : rootGradientValues)
m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first));
// Feed data into the arguments of the network
PopulateNetworkGradients(rootGradientValues);
// Backpropagate through the network
ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);
auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first);
m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);
GetNetworkGradients(backPropagatedGradientValuesForInputs);
if (m_currentOutputsToEvaluate.size() > 0)
{
m_computationNetwork->PostForwardAndBackProp(m_currentOutputsToEvaluate);
RecordRefVariableUpdates();
m_currentOutputsToEvaluate.clear();
}
// TODO: How to deal with the specified 'computeDevice'
}
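// Push attribute values that were changed on the Function objects (e.g. dropout rate, RNG seed)
// down to the corresponding ComputationNodes and mark those nodes outdated so they re-evaluate.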
void CompositeFunction::ApplyAttributeUpdates()
{
// Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input.
// This mask is regenerated every minibatch, so dropout nodes with a non-zero dropout rate must be marked outdated
// w.r.t. their inputs to force re-evaluation in each minibatch
for (auto varNodePair : m_variableToNodeMap)
{
auto var = varNodePair.first;
if (!var.IsOutput())
continue;
auto function = var.Owner();
if (function->m_dirtyAttributes.empty())
continue;
auto node = varNodePair.second;
for (const wstring& attribute : function->m_dirtyAttributes)
{
if (attribute == PrimitiveFunction::AttributeNameDropoutRate)
{
auto dropoutRate = function->m_attributes[attribute].Value<double>();
auto dropoutPtr = dynamic_cast<DropoutNodeBase*>(node.get());
assert(dropoutPtr != nullptr);
dropoutPtr->SetDropoutRate(dropoutRate);
}
else if (attribute == PrimitiveFunction::AttributeNameRngSeed)
{
auto seed = function->m_attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto rngUserPtr = dynamic_cast<RngUser*>(node.get());
assert(rngUserPtr != nullptr);
rngUserPtr->SetRngState(seed);
}
else
{
// Should never happen.
LogicError("ApplyAttributeUpdates: function '%S' specified an unsupported attribute '%S'.",
function->AsString().c_str(), attribute.c_str());
}
}
function->m_dirtyAttributes.clear();
node->SetEvalTimeStampOutdatedWrtAll();
}
}
}