// CompositeFunction.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "CompositeFunction.h"
#include "ComputationNetworkBuilder.h"
#include "Utils.h"
#include "ComputationNode.h"
#include "ReshapingNodes.h"
#include "EvaluationNodes.h"
#include "TrainingNodes.h"
#include "LinearAlgebraNodes.h"
#include "InputAndParamNodes.h"
#include "NonlinearityNodes.h"
#include "RecurrentNodes.h"
#include "Serialization.h"
#include "Value.h"
#include "RNNNodes.h"
#include "UserDefinedV2FunctionNode.h"
#include "BlockFunction.h"
using namespace Microsoft::MSR::CNTK;
namespace CNTK
{
/*static*/ const std::wstring CompositeFunction::CompositeFunctionOpName = L"CompositeFunctionOpName";
/*static*/ std::atomic<unsigned int> CompositeFunction::s_nextAutoGeneratedDynamicAxis(0);
static const std::wstring s_compositeFunctionTypeValue = L"CompositeFunction";
Dictionary CompositeFunction::SerializeBlockComposite() const
{
Dictionary dict;
dict[versionKey] = CurrentVersion();
dict[typeKey] = s_compositeFunctionTypeValue;
dict[rootKey] = RootFunction()->Uid();
if (!Name().empty())
dict[nameKey] = Name();
dict[uidKey] = Uid();
return dict;
}
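// Serialize() below extends the block-composite dictionary produced above with the full graph
// content. The resulting Dictionary layout (as built by the code below) is roughly:
//   versionKey          -> serialization version
//   typeKey             -> "CompositeFunction"
//   rootKey             -> uid of the root function
//   uidKey / nameKey    -> identity of this composite
//   inputsKey           -> serialized non-output input Variables (plus placeholders for cyclic edges)
//   functionsKey        -> topologically sorted, serialized primitive functions
//   stateKey            -> per-function RNG state (seed/offset) for stateful primitives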
/*virtual*/ Dictionary CompositeFunction::Serialize() const
{
Dictionary dict = SerializeBlockComposite();
// Find cycles in the graph and "break" them by inserting placeholders.
// This needs to be done on Save, since here we have easy access to the shape and
// dynamic axis info.
std::unordered_set<FunctionPtr> visitedFunctions;
std::vector<FunctionPtr> topoSortedPrimitiveFunctions;
std::vector<Variable> uniqueInputs;
std::unordered_set<std::wstring> inputUids;
std::function<void(const FunctionPtr& function)> SerializationTraversalFunc;
SerializationTraversalFunc = [&visitedFunctions, &uniqueInputs, &topoSortedPrimitiveFunctions, &inputUids, &SerializationTraversalFunc](const FunctionPtr& function) {
std::vector<Variable> functionInputs = function->Inputs();
for (const auto& input : functionInputs)
{
auto& uid = input.Uid();
if (inputUids.find(uid) != inputUids.end())
continue;
// Check whether this input corresponds to a cyclic edge in the graph.
// BUG: A function being visited twice does not necessarily indicate a cyclic edge;
// it may simply mean that at least two successors in the graph take that function's output as input.
bool mustBeReplaced = input.IsOutput() && (visitedFunctions.find(input.Owner()) != visitedFunctions.end());
if (mustBeReplaced)
{
auto varKind = VariableKind::Placeholder;
Variable var(input.Shape(), varKind, input.GetDataType(), nullptr, input.IsSparse(), input.DynamicAxes(), input.Name(), uid);
uniqueInputs.push_back(var);
inputUids.insert(uid);
}
else if (!input.IsOutput())
{
// leave the input as is.
uniqueInputs.push_back(input);
inputUids.insert(uid);
}
}
visitedFunctions.insert(function);
topoSortedPrimitiveFunctions.push_back(function);
// For block functions we need to recursively traverse the underlying composite
if (function->IsBlock())
PreorderTraverseFunctions(function->BlockRoot(), SerializationTraversalFunc);
};
PreorderTraverseFunctions(RootFunction(), SerializationTraversalFunc);
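// The preorder traversal visits the root first and its producers afterwards; reversing the
// visitation order therefore yields a topologically sorted list with the root function at the
// back, which the assert below relies on.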
std::reverse(std::begin(topoSortedPrimitiveFunctions), std::end(topoSortedPrimitiveFunctions));
assert(topoSortedPrimitiveFunctions.back()->Uid() == RootFunction()->Uid());
std::vector<DictionaryValue> inputDictionaries;
inputDictionaries.reserve(uniqueInputs.size());
inputUids.clear();
for (const auto& input : uniqueInputs)
{
if (inputUids.find(input.Uid()) != inputUids.end())
LogicError("Input uids must be unique");
inputUids.insert(input.Uid());
inputDictionaries.push_back(input.Serialize());
}
dict[inputsKey] = std::move(inputDictionaries);
std::vector<DictionaryValue> functionDictionaries;
std::unordered_set<std::wstring> outputUids;
for (const auto& primitiveFunction : topoSortedPrimitiveFunctions)
{
for (const auto& output : primitiveFunction->RawOutputs())
{
if (outputUids.find(output.Uid()) != outputUids.end())
LogicError("Output uids of all primitive functions in a function graph must be unique");
outputUids.insert(output.Uid());
}
functionDictionaries.push_back(primitiveFunction->Serialize());
}
dict[functionsKey] = std::move(functionDictionaries);
// Now, collect and store the internal state for all non-pure (stateful) functions in the graph
// (with the corresponding nodes that subclass from RngUser: Dropout, RandomSample, etc).
Dictionary stateDictionary;
for (const auto& kv : m_variableToNodeMap)
{
if (kv.second->Is<RngUser>() && kv.first.IsOutput())
{
// The RNG state should be associated with the actual function that the computation node
// corresponds to, and not the block primitives that wrap the actual function
auto ownerFunction = kv.first.Owner().get();
if (!ownerFunction->IsBlock())
{
auto rng = kv.second->As<RngUser>();
Dictionary state;
state[rngSeedKey] = static_cast<size_t>(rng->GetRngSeed());
state[rngOffsetKey] = static_cast<size_t>(rng->GetRngOffset());
stateDictionary[ownerFunction->Uid()] = state;
}
}
}
dict[stateKey] = std::move(stateDictionary);
return dict;
}
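// A minimal round-trip sketch (assuming a valid composite FunctionPtr 'model'):
//
//   Dictionary dict = model->Serialize();
//   FunctionPtr restored = CompositeFunction::Deserialize(dict, DeviceDescriptor::CPUDevice());
//
// Deserialize (below) rebuilds the inputs, the primitive functions and the RNG state recorded here.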
/*static*/ FunctionPtr CompositeFunction::DeserializeBlockComposite(const Dictionary& dict,
const std::unordered_set<FunctionPtr>& allPrimitiveFunctions,
const std::unordered_map<Variable, Variable>& allPlaceholderReplacements,
const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { typeKey, rootKey, uidKey };
ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);
const auto& rootUid = dict[rootKey].Value<std::wstring>();
std::wstring name = L"";
if (dict.Contains(nameKey))
name = dict[nameKey].Value<std::wstring>();
const auto& uid = dict[uidKey].Value<std::wstring>();
FunctionPtr root = *std::find_if(allPrimitiveFunctions.begin(), allPrimitiveFunctions.end(), [&rootUid](const FunctionPtr& func) {
return func->Uid() == rootUid;
});
// Find the subset of placeholder replacements that apply for this composite
FunctionPtr composite = CompositeFunction::Create(root, name, uid);
std::unordered_map<Variable, Variable> placeholderReplacements;
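// Iterate to a fixed point: ReplacePlaceholders can surface additional placeholders
// (e.g. from nested composites), so keep substituting until no applicable replacements remain.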
do
{
placeholderReplacements.clear();
auto compositePlaceholders = composite->Placeholders();
for (auto placeholder : compositePlaceholders)
{
if (allPlaceholderReplacements.find(placeholder) != allPlaceholderReplacements.end())
placeholderReplacements.insert({ placeholder, allPlaceholderReplacements.at(placeholder) });
}
if (placeholderReplacements.size() > 0)
composite = composite->ReplacePlaceholders(placeholderReplacements);
} while (placeholderReplacements.size() > 0);
return composite;
}
/*static*/ FunctionPtr CompositeFunction::Deserialize(const Dictionary& dict, const CNTK::DeviceDescriptor& device)
{
static const vector<std::wstring> s_requiredDictionaryKeys = { inputsKey, functionsKey };
size_t version = ValidateDictionary<CompositeFunction>(dict, s_requiredDictionaryKeys, s_compositeFunctionTypeValue, s_serializationVersion);
const auto& inputs = dict[inputsKey].Value<vector<DictionaryValue>>();
std::unordered_map<std::wstring, Variable> uidToInputMap(inputs.size());
for (const auto& dictionaryValue : inputs)
{
const auto& dictionary = dictionaryValue.Value<Dictionary>();
const auto& inputVar = Variable::Deserialize(dictionary, device);
if (uidToInputMap.find(inputVar.Uid()) != uidToInputMap.end())
{
LogicError("Input uids are not unique (several inputs share '%ls' uid) "
"(%s).", inputVar.Uid().c_str(), GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
}
uidToInputMap[inputVar.Uid()] = inputVar;
}
Dictionary stateDictionary;
if (dict.Contains(stateKey))
stateDictionary = dict[stateKey].Value<Dictionary>();
const auto& functions = dict[functionsKey].Value<vector<DictionaryValue>>();
std::unordered_map<Variable, Variable> allPlaceholderReplacements;
std::unordered_set<FunctionPtr> allPrimitiveFunctions; // this keeps all primitive functions alive until a composite function is created.
for (const auto& dictionaryValue : functions)
{
FunctionPtr root = PrimitiveFunction::Deserialize(dictionaryValue.Value<Dictionary>(), uidToInputMap, allPrimitiveFunctions, allPlaceholderReplacements, device);
allPrimitiveFunctions.insert(root);
auto primitiveFunction = dynamic_cast<PrimitiveFunction*>(root.get());
// Since Combine simply forwards other functions' outputs, all of its outputs
// should already be in the uidToInputMap.
auto opType = primitiveFunction->OpType();
if (opType == PrimitiveOpType::Combine)
continue;
if (primitiveFunction->IsStateful())
{
if (stateDictionary.Contains(primitiveFunction->Uid()))
{
auto state = stateDictionary[primitiveFunction->Uid()].Value<Dictionary>();
auto seed = state[rngSeedKey].Value<size_t>();
auto offset = state[rngOffsetKey].Value<size_t>();
primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngSeed] = seed;
primitiveFunction->m_attributes[PrimitiveFunction::AttributeNameRngOffset] = offset;
}
else if (Internal::GetComputationNetworkTraceLevel() > 0)
{
// TODO: all logging functionality should be refactored to live in a logging utility class.
fprintf(stderr, "WARNING: no state information found for the stateful function (%ls) "
"when deserializing from a dictionary (version=%zu). "
"Reproducibility not guaranteed.", primitiveFunction->OpName().c_str(), version);
}
}
for (const auto& output : root->RawOutputs())
{
const auto& it = uidToInputMap.find(output.Uid());
if (it != uidToInputMap.end())
{
if (!it->second.IsPlaceholder())
{
LogicError("Unexpected variable type %ls instead of a Placeholder for input %ls variable (uid = %ls)"
"(%s).", VariableKindName(it->second.Kind()), it->second.Name().c_str(), it->second.Uid().c_str(),
GetVersionsString<CompositeFunction>(s_serializationVersion, version).c_str());
}
allPlaceholderReplacements[it->second] = output;
}
else
{
uidToInputMap[output.Uid()] = output;
}
}
}
return DeserializeBlockComposite(dict, allPrimitiveFunctions, allPlaceholderReplacements, device);
}
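// CopyState transfers the RNG seed/offset attributes of stateful primitives (Dropout,
// RandomSample, etc.) from 'source' into this composite, so that e.g. a cloned or reloaded
// model continues its random sequences identically to the source.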
void CompositeFunction::CopyState(const CompositeFunction& source)
{
// Create a map with all non-pure (stateful) functions in the function graph.
auto collectStatefulFunctions = [](const std::unordered_set<FunctionPtr>& allPrimitiveFunctions) -> std::map<std::wstring, FunctionPtr> {
std::map<std::wstring, FunctionPtr> functionMap;
for (auto funcPtr : allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(funcPtr.get());
if (primitiveFunction->IsStateful())
{
functionMap[primitiveFunction->Uid()] = funcPtr;
}
}
return functionMap;
};
std::map<std::wstring, FunctionPtr> statefulFunctionsTo = collectStatefulFunctions(m_allPrimitiveFunctions);
std::map<std::wstring, FunctionPtr> statefulFunctionsFrom = collectStatefulFunctions(source.m_allPrimitiveFunctions);
assert(statefulFunctionsTo.size() == statefulFunctionsFrom.size());
if (statefulFunctionsFrom.size() == 0)
{
return;
}
// Copy state captured in the attributes dictionaries.
for (const auto& kv : statefulFunctionsFrom)
{
statefulFunctionsTo[kv.first]->m_attributes = kv.second->Attributes();
}
UpdateInternalNetworkState();
}
void CompositeFunction::UpdateInternalNetworkState()
{
if (!m_computationNetwork)
{
return;
}
for (const auto& function : m_allPrimitiveFunctions)
{
auto primitiveFunction = dynamic_cast<const PrimitiveFunction*>(function.get());
if (primitiveFunction->IsStateful())
{
for (const auto& output : function->RawOutputs())
{
auto node = m_variableToNodeMap.at(output);
auto attributes = function->Attributes();
auto seed = attributes[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
auto offset = attributes[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
node->As<RngUser>()->SetRngState(seed, offset);
}
}
}
}
// Names of the dynamic axes in the CNTK engine for some special sets of dynamic axis values.
// Note: InternalNoSequenceAxisName corresponds to the special case where there is no sequence axis (i.e. it has been reduced over),
// and the special name is used to identify this case when loading back a model saved in the CNTK v1 format. This will no longer
// be needed once the new CNTK v2 model serialization format is ready.
/*static*/ const std::wstring CompositeFunction::InternalDefaultDynamicAxisName = L"*";
/*static*/ const std::wstring CompositeFunction::InternalNoSequenceAxisName = L"__noSequenceAxis";
// Recursively create a sub-network of ComputationNode instances corresponding to the graph of Functions
// underlying the specified 'variable' and return the ComputationNode instance that corresponds to the
// top level 'variable'
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
{
auto iter = variableToNodeMap.find(variable);
if (iter != variableToNodeMap.end())
{
isVariableRootMap[variable] = false;
return iter->second;
}
// The DataType, Shape and DynamicAxes of the variable must be known by now
if (variable.GetDataType() == DataType::Unknown)
InvalidArgument("Variable%S with unknown DataType detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
if (variable.Shape().IsUnknown())
InvalidArgument("Variable%S with unknown shape detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
if (variable.Shape().HasInferredDimension())
InvalidArgument("Variable%S with InferredDimension for at least one axis in its shape, detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
if (variable.DynamicAxes() == Axis::UnknownDynamicAxes())
InvalidArgument("Variable%S with unknown dynamic axes detected when compiling the Function graph!", ParanthesizedName(variable.Name()).c_str());
// Let's add a null entry in the map for this variable, to break infinite recursion when processing recurrent graphs
variableToNodeMap[variable] = nullptr;
std::shared_ptr<ComputationNode<ElementType>> computationNodePtr;
if (variable.IsParameter() || variable.IsConstant())
{
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());
computationNodePtr = builder.CreateLearnableParameter(internalNodeName, AsTensorShape(variable.Shape()));
network->InitLearnableParameters(computationNodePtr, L"fixedValue", 0); // must call this to follow protocol; can overwrite later
if (!variable.NeedsGradient() || (inputsToExcludeGradientsFor.find(variable) != inputsToExcludeGradientsFor.end()))
computationNodePtr->SetLearningRateMultiplier(0.0);
NDArrayViewPtr value = variable.IsConstant() ? Constant(variable).Value() : Parameter(variable).Value();
std::shared_ptr<const Matrix<ElementType>> valueMatrix = variable.IsConstant() ? value->GetMatrix<ElementType>() : value->GetWritableMatrix<ElementType>();
if (variable.IsParameter() || (valueMatrix->GetDeviceId() == network->GetDeviceId()))
computationNodePtr->Value() = valueMatrix->AsReference();
else // Constant: if initialized data lives on wrong device, make a copy to the right one (copy is OK since it's constant)
{
Matrix<ElementType> clonedMatrix(valueMatrix->GetNumRows(), valueMatrix->GetNumCols(), network->GetDeviceId(), valueMatrix->GetMatrixType(), valueMatrix->GetFormat());
clonedMatrix.AssignValuesOf(*valueMatrix);
computationNodePtr->Value() = std::move(clonedMatrix);
}
}
else if (variable.IsInput())
{
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(variable.Uid(), variable.Name());
// TODO: Input variables currently are required to have the default batch axis
auto dynamicAxes = variable.DynamicAxes();
auto foundDefaultBatchAxis = std::find(dynamicAxes.begin(), dynamicAxes.end(), Axis::DefaultBatchAxis());
if (foundDefaultBatchAxis == dynamicAxes.end())
LogicError("Currently Input Variables are required to have the DefaultBatchAxis as one of their dynamic axes");
if (dynamicAxes.back() != Axis::DefaultBatchAxis())
LogicError("Currently Input Variables are required to have the DefaultBatchAxis as their last dynamic axes");
// TODO: Support inputs with > 1 dynamic axes
if ((dynamicAxes.size() < 1) || (dynamicAxes.size() > 2))
LogicError("Currently only Input variables with 1 or 2 dynamic axis are supported");
// Construct the dynamic axis name to be used internally for the CNTK InputNodes
std::wstring internalDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
if (!internalDynamicAxisName.empty() && !network->NodeNameExists(internalDynamicAxisName))
network->AddNodeToNetAndAttachInputs(New<DynamicAxisNode<ElementType>>(network->GetDeviceId(), internalDynamicAxisName), {});
if (IsSparseInput(variable))
computationNodePtr = builder.CreateSparseInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);
else
computationNodePtr = builder.CreateInputNode(internalNodeName, AsTensorShape(variable.Shape()), internalDynamicAxisName);
if (variable.NeedsGradient() && (inputsToExcludeGradientsFor.find(variable) == inputsToExcludeGradientsFor.end()))
{
// Set a dummy learning rate multiplier to force gradient computation for the input computation node since by default
// gradients are not computed for Input nodes
computationNodePtr->SetLearningRateMultiplier(0.00001f);
}
}
else
{
assert(variable.IsOutput());
auto outputVariableNode = GetOutputVariableNode(variable, network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor);
// Can be null in case of loops with f.output == f.input.
// Such loops cannot be handled, so we leave nullptr as the computation node.
if (outputVariableNode)
computationNodePtr = outputVariableNode->template As<ComputationNode<ElementType>>()->shared_from_this();
else
computationNodePtr = nullptr;
}
variableToNodeMap[variable] = computationNodePtr;
if (isVariableRootMap.find(variable) == isVariableRootMap.end())
isVariableRootMap[variable] = variable.IsOutput();
return computationNodePtr;
}
/*static*/ Variable CompositeFunction::GetMappingForNoOpOutput(const Variable& variable, bool recursive)
{
Variable mappingVariable = variable;
auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
if (ownerPrimitiveFunc && (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp))
mappingVariable = ownerPrimitiveFunc->Inputs()[0];
if (recursive && (mappingVariable != variable))
return GetMappingForNoOpOutput(mappingVariable);
else
return mappingVariable;
}
/*static*/ Variable CompositeFunction::GetMappingVariable(const Variable& variable, bool recursive)
{
Variable mappingVariable = variable;
auto ownerFunc = variable.IsOutput() ? variable.Owner().get() : nullptr;
auto ownerPrimitiveFunc = dynamic_cast<PrimitiveFunction*>(ownerFunc);
if (ownerPrimitiveFunc)
{
if (ownerPrimitiveFunc->OpType() == PrimitiveOpType::NoOp)
mappingVariable = GetMappingForNoOpOutput(variable);
else
{
auto ownerBlockFunc = dynamic_cast<BlockFunction*>(ownerFunc);
if (ownerBlockFunc)
mappingVariable = ownerBlockFunc->CompositeOutputsMap().at(variable);
}
}
if (recursive && (mappingVariable != variable))
return GetMappingVariable(mappingVariable);
else
return mappingVariable;
}
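// Example (derived from the logic above): if 'out' is the output of a NoOp primitive whose input
// is itself the output of another NoOp over x, GetMappingVariable(out, /*recursive =*/ true)
// resolves through the whole NoOp chain and returns x; for a BlockFunction output it resolves
// via the block's CompositeOutputsMap instead.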
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::CreateComputationNode(const Variable& variable,
Function* function,
const std::vector<std::shared_ptr<ComputationNode<ElementType>>>& inputNodes,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap)
{
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (primitiveFunction && (primitiveFunction->OpType() == PrimitiveOpType::NoOp))
return variableToNodeMap[GetMappingVariable(variable)];
ComputationNodeBasePtr computationNodePtr;
auto internalNodeName = CNTKInternalNodeNameFromUidAndName(function->Uid(), function->Name());
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputNode : inputNodes)
inputNodesBasePtrs.push_back(inputNode);
if (primitiveFunction)
{
auto functionInputs = function->Inputs();
auto& functionConfig = function->Attributes();
PrimitiveOpType op = primitiveFunction->OpType();
switch (op)
{
case PrimitiveOpType::Negate:
computationNodePtr = New<NegateNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sigmoid:
computationNodePtr = New<SigmoidNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Tanh:
computationNodePtr = New<TanhNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Cos:
computationNodePtr = New<CosineNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sin:
computationNodePtr = New<SinNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ReLU:
computationNodePtr = New<RectifiedLinearNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Exp:
computationNodePtr = New<ExpNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Log:
computationNodePtr = New<LogNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Sqrt:
computationNodePtr = New<SqrtNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Floor:
computationNodePtr = New<FloorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Abs:
computationNodePtr = New<AbsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Reciprocal:
computationNodePtr = New<ReciprocalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Softmax:
computationNodePtr = New<SoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Hardmax:
computationNodePtr = New<HardmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::TransposeAxes:
{
auto axis1 = functionConfig[PrimitiveFunction::AttributeNameAxis1].Value<Axis>();
auto axis2 = functionConfig[PrimitiveFunction::AttributeNameAxis2].Value<Axis>();
// The axis ids passed to the internal CNTK TransposeDimensionsNode are 1 based instead of 0 based
computationNodePtr = New<TransposeDimensionsNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(axis1), AsCNTKInternalAxisIdx(axis2));
break;
}
case PrimitiveOpType::Where:
{
auto dynamicAxes = variable.DynamicAxes();
auto internalCNTKWhereNodeDynamicAxisName = InternalDynamicAxisNameFromDynamicAxes(dynamicAxes);
computationNodePtr = New<WhereNode<ElementType>>(network->GetDeviceId(), internalNodeName, internalCNTKWhereNodeDynamicAxisName);
break;
}
case PrimitiveOpType::Slice:
{
auto axis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
auto beginIndex = functionConfig[PrimitiveFunction::AttributeNameBeginIndex].Value<int>();
auto endIndex = functionConfig[PrimitiveFunction::AttributeNameEndIndex].Value<int>();
// Internal CNTK SliceNode takes 1 based axis indices instead of 0 based
computationNodePtr = New<SliceNode<ElementType>>(network->GetDeviceId(), internalNodeName, beginIndex, endIndex, AsCNTKInternalAxisIdx(axis));
break;
}
case PrimitiveOpType::RandomSample:
{
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
computationNodePtr = New<RandomSampleNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::RandomSampleInclusionFrequency:
{
auto numSamples = functionConfig[PrimitiveFunction::AttributeNameNumSamples].Value<size_t>();
auto allowDuplicates = functionConfig[PrimitiveFunction::AttributeNameAllowDuplicates].Value<bool>();
computationNodePtr = New<RandomSampleInclusionFrequencyNode<ElementType>>(network->GetDeviceId(), internalNodeName, numSamples, allowDuplicates);
break;
}
case PrimitiveOpType::Dropout:
{
auto dropoutRate = functionConfig[PrimitiveFunction::AttributeNameDropoutRate].Value<double>();
computationNodePtr = New<DropoutNode<ElementType>>(network->GetDeviceId(), internalNodeName);
computationNodePtr->As<DropoutNode<ElementType>>()->SetDropoutRate(dropoutRate);
break;
}
case PrimitiveOpType::Reshape:
{
computationNodePtr = New<ReshapeNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(primitiveFunction->RawOutputs()[0].Shape()));
break;
}
case PrimitiveOpType::ROIPooling:
{
auto roiOutputShape = functionConfig[PrimitiveFunction::AttributeNameROIOutputShape].Value<NDShape>();
computationNodePtr = New<ROIPoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(roiOutputShape));
break;
}
case PrimitiveOpType::Pooling:
{
PoolingType poolingType = (PoolingType)(functionConfig[PrimitiveFunction::AttributeNamePoolingType].Value<size_t>());
auto poolingWindowsShape = functionConfig[PrimitiveFunction::AttributeNamePoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
computationNodePtr = New<PoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKPoolKind(poolingType), AsTensorShape(poolingWindowsShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Unpooling:
{
auto unpoolingWindowShape = functionConfig[PrimitiveFunction::AttributeNameUnpoolingWindowShape].Value<NDShape>();
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
// We only get here after validation, so it is safe to assume the unpooling type is max
computationNodePtr = New<MaxUnpoolingNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(unpoolingWindowShape), AsTensorShape(strides), autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::SumAll:
computationNodePtr = New<SumElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Plus:
computationNodePtr = New<PlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LogPlus:
computationNodePtr = New<LogPlusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Minus:
computationNodePtr = New<MinusNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ElementTimes:
computationNodePtr = New<ElementTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Equal:
computationNodePtr = New<EqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NotEqual:
computationNodePtr = New<NotEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Less:
computationNodePtr = New<LessNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::LessEqual:
computationNodePtr = New<LessEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Greater:
computationNodePtr = New<GreaterNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GreaterEqual:
computationNodePtr = New<GreaterEqualNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Times:
{
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
auto inferInputRankToMap = functionConfig[PrimitiveFunction::AttributeNameInferInputRankToMap].Value<int>();
computationNodePtr = New<TimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank, inferInputRankToMap);
break;
}
case PrimitiveOpType::TransposeTimes:
{
size_t outputRank = functionConfig[PrimitiveFunction::AttributeNameOutputRank].Value<size_t>();
computationNodePtr = New<TransposeTimesNode<ElementType>>(network->GetDeviceId(), internalNodeName, outputRank);
break;
}
case PrimitiveOpType::Convolution:
{
NDShape outputMapCount, kernelShape;
std::tie(outputMapCount, kernelShape) = GetConvolutionOutputMapCountAndKernelShape(functionInputs[0].Shape(), functionInputs[1].Shape());
auto strides = functionConfig[PrimitiveFunction::AttributeNameStrides].Value<NDShape>();
auto lowerPad = functionConfig[PrimitiveFunction::AttributeNameLowerPad].Value<NDShape>();
auto upperPad = functionConfig[PrimitiveFunction::AttributeNameUpperPad].Value<NDShape>();
auto sharing = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameSharing].Value<std::vector<DictionaryValue>>());
auto autoPadding = AsVector<bool>(functionConfig[PrimitiveFunction::AttributeNameAutoPadding].Value<std::vector<DictionaryValue>>());
auto transpose = functionConfig[PrimitiveFunction::AttributeNameTranspose].Value<bool>();
auto maxTempMemSizeInSamples = functionConfig[PrimitiveFunction::AttributeNameMaxTempMemSizeInSamples].Value<size_t>();
computationNodePtr = New<ConvolutionNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(kernelShape), AsTensorShape(outputMapCount), AsTensorShape(strides), sharing, autoPadding, AsTensorShape(lowerPad), AsTensorShape(upperPad), transpose, ImageLayoutKind::CHW, maxTempMemSizeInSamples);
break;
}
case PrimitiveOpType::CosDistance:
computationNodePtr = New<CosDistanceNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Logistic:
computationNodePtr = New<LogisticNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::SquaredError:
computationNodePtr = New<SquareErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::CrossEntropyWithSoftmax:
computationNodePtr = New<CrossEntropyWithSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ClassificationError:
computationNodePtr = New<ClassificationErrorNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::EditDistanceError:
{
auto subPen = functionConfig[PrimitiveFunction::AttributeNameSubstitutionPenalty].Value<float>();
auto delPen = functionConfig[PrimitiveFunction::AttributeNameDeletionPenalty].Value<float>();
auto insPen = functionConfig[PrimitiveFunction::AttributeNameInsertionPenalty].Value<float>();
auto squashInputs = functionConfig[PrimitiveFunction::AttributeNameSquashInputs].Value<bool>();
auto samplesToIgnore = AsVector<size_t>(functionConfig[PrimitiveFunction::AttributeNameSamplesToIgnore].Value<std::vector<DictionaryValue>>());
computationNodePtr = New<EditDistanceErrorNode<ElementType>>(network->GetDeviceId(), subPen, delPen, insPen, squashInputs, samplesToIgnore, internalNodeName);
break;
}
case PrimitiveOpType::LambdaRank:
computationNodePtr = New<LambdaRankNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::NDCG:
computationNodePtr = New<NDCG1EvalNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::PastValue:
case PrimitiveOpType::FutureValue:
{
Variable inputOperandVar = functionInputs[0];
Variable initialStateVar = functionInputs[1];
size_t offset = primitiveFunction->Attributes()[PrimitiveFunction::AttributeNameOffset].Value<size_t>();
if (op == PrimitiveOpType::PastValue)
computationNodePtr = New<PastValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
else
computationNodePtr = New<FutureValueNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsTensorShape(inputOperandVar.Shape()), offset);
break;
}
case PrimitiveOpType::ReduceElements:
{
auto reductionAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
auto reductionOpName = functionConfig[PrimitiveFunction::AttributeNameReductionOpName].Value<std::wstring>();
computationNodePtr = New<ReduceElementsNode<ElementType>>(network->GetDeviceId(), internalNodeName, reductionOpName, AsCNTKInternalAxisIdx(reductionAxis));
break;
}
case PrimitiveOpType::BatchNormalization:
{
auto spatial = functionConfig[PrimitiveFunction::AttributeNameSpatial].Value<bool>();
auto normalizationTimeConstant = functionConfig[PrimitiveFunction::AttributeNameNormalizationTimeConstant].Value<double>();
auto blendTimeConstant = functionConfig[PrimitiveFunction::AttributeNameBlendTimeConstant].Value<double>();
auto epsilon = functionConfig[PrimitiveFunction::AttributeNameEpsilon].Value<double>();
auto useCuDNNEngine = functionConfig[PrimitiveFunction::AttributeNameUseCuDNNEngine].Value<bool>();
computationNodePtr = New<BatchNormalizationNode<ElementType>>(network->GetDeviceId(), internalNodeName, spatial, normalizationTimeConstant, blendTimeConstant, epsilon, !useCuDNNEngine, ImageLayoutKind::CHW);
break;
}
case PrimitiveOpType::Combine:
// This operation is just a no-op and is a means to combine multiple functions to create a single Function
// whose outputs are a union of the outputs of the Functions being combined.
computationNodePtr = variableToNodeMap[variable];
break;
case PrimitiveOpType::PackedIndex:
computationNodePtr = New<PackedIndexNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::GatherPacked:
computationNodePtr = New<GatherPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::ScatterPacked:
computationNodePtr = New<ScatterPackedNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Clip:
computationNodePtr = New<ClipNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Select:
computationNodePtr = New<IfNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
case PrimitiveOpType::Splice:
{
Axis spliceAxis = functionConfig[PrimitiveFunction::AttributeNameAxis].Value<Axis>();
computationNodePtr = New<RowStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, AsCNTKInternalAxisIdx(spliceAxis));
break;
}
case PrimitiveOpType::OptimizedRNNStack:
{
auto bidirectional = functionConfig[PrimitiveFunction::AttributeNameBidirectional].Value<bool>();
auto numLayers = functionConfig[PrimitiveFunction::AttributeNameNumLayers].Value<size_t>();
auto hiddenSize = functionConfig[PrimitiveFunction::AttributeNameHiddenSize].Value<size_t>();
auto recurrentOp = functionConfig[PrimitiveFunction::AttributeNameRecurrentOp].Value<std::wstring>();
computationNodePtr = New<OptimizedRNNStackNode<ElementType>>(network->GetDeviceId(), internalNodeName, bidirectional, numLayers, hiddenSize, recurrentOp);
break;
}
case PrimitiveOpType::ReconcileDynamicAxis:
{
computationNodePtr = New<ReconcileDynamicAxisNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::LogSoftmax:
{
// TODO: This could be implemented as x => x - ReduceLogSum(x). How to do that here?
computationNodePtr = New<LogSoftmaxNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
}
case PrimitiveOpType::Pass:
computationNodePtr = New<PassNode<ElementType>>(network->GetDeviceId(), internalNodeName);
break;
default:
LogicError("Specified op %S not yet supported", PrimitiveOpTypeName(op).c_str());
break;
}
// Reorder inputNodesBasePtrs, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
ReorderAsCNTKComputationNodeInputs(op, inputNodesBasePtrs);
if (computationNodePtr->Is<INumInputs>())
{
auto computationNodeExpectedInputCount = computationNodePtr->As<INumInputs>()->GetExpectedNumInputs();
if (computationNodeExpectedInputCount != inputNodesBasePtrs.size())
LogicError("Input count mismatch: The Primitive function for op %S has %d inputs while the corresponding ComputationNode has %d inputs",
PrimitiveOpTypeName(op).c_str(),
(int)inputNodesBasePtrs.size(),
(int)computationNodeExpectedInputCount);
}
if (computationNodePtr->Is<RngUser>())
{
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngSeed))
{
auto seed = functionConfig[PrimitiveFunction::AttributeNameRngSeed].Value<size_t>();
uint64_t offset = 0;
if (functionConfig.Contains(PrimitiveFunction::AttributeNameRngOffset))
{
offset = functionConfig[PrimitiveFunction::AttributeNameRngOffset].Value<size_t>();
}
computationNodePtr->As<RngUser>()->SetRngState(seed, offset);
}
}
}
else
{
computationNodePtr = New<UserDefinedV2FunctionNode<ElementType>>(network->GetDeviceId(), internalNodeName, function->shared_from_this());
// For user-defined functions, we only attach unique inputs in the internal computation network, since the UDF
// backward implementations directly compute aggregate gradient values for unique inputs
std::vector<ComputationNodeBasePtr> uniqueInputNodesBasePtrs;
for (auto inputNodeBasePtr : inputNodesBasePtrs)
{
if (std::find(uniqueInputNodesBasePtrs.begin(), uniqueInputNodesBasePtrs.end(), inputNodeBasePtr) == uniqueInputNodesBasePtrs.end())
uniqueInputNodesBasePtrs.push_back(inputNodeBasePtr);
}
inputNodesBasePtrs = uniqueInputNodesBasePtrs;
}
network->AddNodeToNetAndAttachInputs(computationNodePtr, inputNodesBasePtrs);
return computationNodePtr;
}
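// Extending the op switch above follows a fixed pattern; a hypothetical sketch for a new unary
// op 'Foo' backed by a (hypothetical) FooNode would be:
//
//   case PrimitiveOpType::Foo:
//       computationNodePtr = New<FooNode<ElementType>>(network->GetDeviceId(), internalNodeName);
//       break;
//
// with any attributes read from functionConfig before constructing the node, as the
// parameterized cases (e.g. Slice, Pooling) do.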
template <typename ElementType>
/*static*/ ComputationNodeBasePtr CompositeFunction::GetOutputVariableNode(const Variable& variable,
Microsoft::MSR::CNTK::ComputationNetworkPtr& network,
ComputationNetworkBuilder<ElementType>& builder,
std::unordered_map<Variable, ComputationNodeBasePtr>& variableToNodeMap,
std::unordered_map<Variable, bool>& isVariableRootMap,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
{
assert(variable.IsOutput());
Function* function = variable.Owner().get();
ComputationNodeBasePtr computationNodePtr;
auto& functionInputs = function->m_inputs;
DataType nonConstInputDataType = DataType::Unknown;
for (auto& inputVar : functionInputs)
{
if (!inputVar.IsConstant() && (inputVar.GetDataType() != DataType::Unknown))
{
nonConstInputDataType = inputVar.GetDataType();
break;
}
}
// Create the nodes corresponding to the inputs
std::vector<std::shared_ptr<ComputationNode<ElementType>>> inputNodes;
for (auto& inputVar : functionInputs)
{
// If the inputVar is a constant with the wrong DataType, coerce it to the right type
if (inputVar.IsConstant() && (nonConstInputDataType != DataType::Unknown) && (inputVar.GetDataType() != nonConstInputDataType))
{
auto originalConstantValue = Constant(inputVar).Value();
auto constantValueCPU = originalConstantValue->DeepClone(DeviceDescriptor::CPUDevice(), true);
NDArrayViewPtr newConstantValue = CloneAsDataType(constantValueCPU, nonConstInputDataType, true);
inputVar = Constant(newConstantValue->DeepClone(originalConstantValue->Device(), originalConstantValue->IsReadOnly()), inputVar.Name());
}
auto baseNodePtr = GetNode(inputVar, network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor);
inputNodes.push_back((baseNodePtr != nullptr) ? baseNodePtr->template As<ComputationNode<ElementType>>()->shared_from_this() : nullptr);
}
BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function);
if (blockFunction)
{
// For block function, map each argument placeholder of the underlying composite to
// the computation node corresponding to the block input that the argument placeholder
// of the composite is mapped to.
auto compositeArguments = blockFunction->Composite()->Arguments();
for (auto compositeArgument : compositeArguments)
variableToNodeMap[compositeArgument] = variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());
return GetNode(variable.BlockFunctionVariableMapping(), network, builder, variableToNodeMap, isVariableRootMap, inputsToExcludeGradientsFor);
}
else
computationNodePtr = CreateComputationNode(variable, function, inputNodes, network, variableToNodeMap);
PrimitiveFunction* primitiveFunction = dynamic_cast<PrimitiveFunction*>(function);
if (!primitiveFunction || (primitiveFunction->OpType() != PrimitiveOpType::Combine))
{
for (auto inputVar : functionInputs)
isVariableRootMap[inputVar] = false;
}
return computationNodePtr;
}
template <typename ElementType>
ComputationNetworkPtr CompositeFunction::GetComputationNetwork(const DeviceDescriptor& device,
const std::unordered_set<Variable>& backpropRoots,
const std::unordered_set<Variable>& outputs,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor,
bool allocateNetworkMatrices)
{
if (m_computationNetwork != nullptr)
{
// TODO: We should either invalidate and re-adapt the network if the backpropRoots change compared to what was specified when the network
// was last constructed, or just recreate a new network.
// For now, just disallow changing the backpropRoots after the network is created
if (!backpropRoots.empty() && (m_currentBackpropRoots != backpropRoots))
LogicError("Changing backprop roots across different Forward calls on a CNTK composite Function is currently unsupported");
// TODO: Support changing the device across different invocations of the forward method on a Function instance
if (AsDeviceDescriptor(m_computationNetwork->GetDeviceId()) != device)
LogicError("Changing device across different Forward calls on a CNTK composite Function is currently unsupported");
if (!backpropRoots.empty() && (inputsToExcludeGradientsFor != m_inputsExcludedFromGradientComputation))
LogicError("Changing the set of inputs to exclude from gradient computation, across different Forward calls on a CNTK composite Function, is currently unsupported");
}
else
{
m_computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
auto networkInputs = this->Inputs();
for (auto inputExcluded : inputsToExcludeGradientsFor)
{
// Only inputs of the network can be excluded from gradient computation
if (std::find(networkInputs.begin(), networkInputs.end(), inputExcluded) == networkInputs.end())
InvalidArgument("Function::Forward: Only inputs of a Function can be excluded from gradient computation");
}
m_inputsExcludedFromGradientComputation = inputsToExcludeGradientsFor;
ComputationNetworkBuilder<ElementType> builder(*m_computationNetwork);
// TODO: We currently only support one backprop root
if (backpropRoots.size() > 1)
LogicError("More than one backprop roots is currently unsupported");
auto placeholders = Placeholders();
if (!placeholders.empty())
InvalidArgument("All placeholders of a Function must be bound before performing a Forward computation on the Function!");
// Now recursively create the network in a top-down fashion
auto rootFunction = RootFunction();
auto rootFunctionOutputs = rootFunction->RawOutputs();
for (auto rootOutput : rootFunctionOutputs)
GetNode(rootOutput, m_computationNetwork, builder, m_variableToNodeMap, m_isVariableRootMap, m_inputsExcludedFromGradientComputation);
// We need to patch the ComputationNode mappings for the arguments of block functions,
// since for recurrent inputs the mappings are not fully established on the first pass
std::function<void(const FunctionPtr&)> PatchBlockArgumentsMapping;
PatchBlockArgumentsMapping = [this, &PatchBlockArgumentsMapping](const FunctionPtr& function) {
BlockFunction* blockFunction = dynamic_cast<BlockFunction*>(function.get());
if (blockFunction)
{
auto compositeArguments = blockFunction->Composite()->Arguments();
for (auto compositeArgument : compositeArguments)
m_variableToNodeMap[compositeArgument] = m_variableToNodeMap.at(compositeArgument.BlockFunctionVariableMapping());
PreorderTraverseFunctions(function->BlockRoot(), PatchBlockArgumentsMapping);
}
};
PreorderTraverseFunctions(rootFunction, PatchBlockArgumentsMapping);
std::function<bool(const Variable&)> IsVariableRoot;
IsVariableRoot = [this, &IsVariableRoot](const Variable& outputVar) {
auto mappingVariable = GetMappingVariable(outputVar);
return (m_isVariableRootMap[outputVar] && ((mappingVariable == outputVar) || IsVariableRoot(mappingVariable)));
};
// If any of the function's outputs or the requested outputs is not a root node, we need to
// explicitly add it to the 'output' group of the ComputationNetwork
std::unordered_set<Variable> networkOutputs(outputs);
networkOutputs.insert(rootFunctionOutputs.begin(), rootFunctionOutputs.end());
for (auto output : networkOutputs)
{
if (!IsVariableRoot(output))
{
auto computationNode = m_variableToNodeMap[output];
if (!computationNode)
InvalidArgument("One of the requested outputs for the Function forward computation is not part of the graph underlying the Function");
m_computationNetwork->AddToNodeGroup(L"output", computationNode);
}
}
m_currentBackpropRoots = backpropRoots;
// In case of recurrence, the inputs of some of the ComputationNodes are not attached due to cycles.
// Now attach those after we have created all ComputationNodes in the network
for (auto varNodePair : m_variableToNodeMap)
{
auto& currentComputationNode = varNodePair.second;
if (!currentComputationNode)
LogicError("No computation node mapping exists for Variable %S", varNodePair.first.Name().c_str());
auto& currentComputationNodeInputs = currentComputationNode->GetInputs();
auto& currentVar = varNodePair.first;
if (!currentVar.IsOutput())
continue;
if (std::find(currentComputationNodeInputs.begin(), currentComputationNodeInputs.end(), nullptr) != currentComputationNodeInputs.end())
{
// This ComputationNode has at least one null input which now needs to be properly attached
const PrimitiveFunction* primitiveFunc = dynamic_cast<const PrimitiveFunction*>(currentVar.Owner().get());
// Skip block primitives since they do not directly map to a computation node
if (primitiveFunc->OpType() == PrimitiveOpType::Block)
continue;
// Reorder properly, since the input ordering of the internal CNTK ComputationNode may differ from the PrimitiveFunction's input ordering
auto inputVars = primitiveFunc->Inputs();
ReorderAsCNTKComputationNodeInputs(primitiveFunc->OpType(), inputVars);
inputVars.resize(currentComputationNode->GetNumInputs());
std::vector<ComputationNodeBasePtr> inputNodesBasePtrs;
for (auto inputVar : inputVars)
inputNodesBasePtrs.push_back(m_variableToNodeMap.at(inputVar));
currentComputationNode->AttachInputs(inputNodesBasePtrs);
}
}
m_computationNetwork->SetTraceLevel(Internal::GetComputationNetworkTraceLevel());
m_computationNetwork->CompileNetwork();
// Verify that the shapes of the output Variables that we computed match the corresponding nodes in the ComputationNetwork
for (auto varNodePair : m_variableToNodeMap)
{
if (varNodePair.first.IsOutput())
{
auto outputVar = varNodePair.first;
auto computationNodePtr = m_variableToNodeMap.at(outputVar);
auto outputShape = outputVar.Shape();
auto computationNodeSampleLayout = computationNodePtr->GetSampleLayout();
if (((outputShape.Rank() == 0) && (computationNodeSampleLayout[0] != 1)) ||
((outputShape.Rank() != 0) && (computationNodeSampleLayout != AsTensorViewShape(outputShape)) && (computationNodeSampleLayout != AsTensorShape(outputShape))))
{
LogicError("The output Variable shape %S does not match the SampleLayout shape %s of the corresponding ComputationNode in the network", outputShape.AsString().c_str(), ((std::string)computationNodeSampleLayout).c_str());
}
}
}
// Record the timestamps of Parameter values
assert(m_lastRecordedParameterValueTimeStamps.empty());
auto functionParameters = Parameters();
for (auto parameter : functionParameters)
m_lastRecordedParameterValueTimeStamps.insert({ parameter, parameter.CurrentValueTimeStamp() });
}
if (!m_networkMatricesAllocated && allocateNetworkMatrices)
{
ComputationNodeBasePtr backpropRootNode;
if (!m_currentBackpropRoots.empty())
backpropRootNode = m_variableToNodeMap.at(*m_currentBackpropRoots.begin());
// Now recursively traverse the network in a top-down fashion
auto rootFunction = RootFunction();
auto rootFunctionOutputs = rootFunction->RawOutputs();
std::vector<ComputationNodeBasePtr> forwardRootNodes;
for (auto rootOutput : rootFunctionOutputs)
forwardRootNodes.push_back(m_variableToNodeMap.at(rootOutput));
std::vector<ComputationNodeBasePtr> forwardOutputNodes;
for (auto output : outputs)
forwardOutputNodes.push_back(m_variableToNodeMap.at(output));
m_computationNetwork->AllocateAllMatrices(forwardRootNodes, forwardOutputNodes, backpropRootNode);
m_networkMatricesAllocated = allocateNetworkMatrices;
std::unordered_set<ComputationNodeBasePtr> allNetworkRoots = { backpropRootNode };
allNetworkRoots.insert(forwardRootNodes.begin(), forwardRootNodes.end());
allNetworkRoots.insert(forwardOutputNodes.begin(), forwardOutputNodes.end());
m_allNetworkRootsInGlobalEvalOrder = m_computationNetwork->SortByGlobalEvalOrder(allNetworkRoots);
}
else
{
// Make sure the requested outputs are a subset of the outputs for which we set up the current
// matrix allocation structure in the cached computation network
for (auto output : outputs)
{
auto computationNode = m_variableToNodeMap.at(output);
if (std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), computationNode) == m_allNetworkRootsInGlobalEvalOrder.end())
LogicError("Changing requested outputs across different Forward calls on a CNTK composite Function is currently unsupported");
}
}
return m_computationNetwork;
}
template <typename ElementType>
/*static*/ void CompositeFunction::PopulateComputationNodeValue(const std::pair<Variable, ValuePtr>& variableValue, ComputationNodeBasePtr& computationNode, std::unordered_map<MBLayoutPtr, Variable>& layoutsPopulated)
{
if (!computationNode->Is<InputValueBase<ElementType>>())
LogicError("CompositeFunction::Forward: Illegal to populate value of computation node type other than InputValueBase!");
std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableValue.first, variableValue.second);
// Switch the node matrix to the right matrix type
auto& nodeData = computationNode->As<ComputationNode<ElementType>>()->Value();
nodeData.AssignValuesOf(*CNTKMatrixAndMBLayout.first);
auto layout = CNTKMatrixAndMBLayout.second;
auto& nodeLayout = computationNode->GetMBLayout();
if (layoutsPopulated.find(nodeLayout) == layoutsPopulated.end())
{
nodeLayout->CopyFrom(layout);
layoutsPopulated.insert({ nodeLayout, variableValue.first });
}
else
{
if (*nodeLayout != *layout)
InvalidArgument("Function::Forward: Different minibatch layouts detected (difference in sequence lengths or count or start flags) in data specified for 2 of the Function's argument ('%S', '%S') having same dynamic axes", variableValue.first.Name().c_str(), layoutsPopulated.at(nodeLayout).Name().c_str());
}
}
void CompositeFunction::PopulateNetworkInputs(const std::unordered_map<Variable, ValuePtr>& arguments)
{
std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
std::vector<ComputationNodeBasePtr> inputNodes;
for (auto argumentValuePair : arguments)
{
auto argument = argumentValuePair.first;
auto argumentComputationNode = m_variableToNodeMap.at(argument);
assert(argumentComputationNode);
inputNodes.push_back(argumentComputationNode);
ValuePtr argumentValue = arguments.at(argument);
switch (argumentValue->GetDataType())
{
case DataType::Float:
PopulateComputationNodeValue<float>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
break;
case DataType::Double:
PopulateComputationNodeValue<double>({ argument, argumentValue }, argumentComputationNode, layoutsPopulated);
break;
default:
LogicError("Unsupported DataType %s", DataTypeName(argumentValue->GetDataType()));
break;
}
}
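// Bump the eval timestamps of the populated input nodes so that downstream nodes recompute
// their cached forward values on the next evaluation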
m_computationNetwork->BumpEvalTimeStamp(inputNodes);
}
template <typename ElementType>
/*static*/ void CompositeFunction::PopulateComputationNodeGradient(const std::pair<Variable, ValuePtr>& variableGradient, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode)
{
std::pair<std::shared_ptr<const Matrix<ElementType>>, MBLayoutPtr> CNTKMatrixAndMBLayout = Utils::GetCNTKImplMatrixAndMBLayoutFromValueObject<ElementType>(variableGradient.first, variableGradient.second);
MBLayoutPtr layout = CNTKMatrixAndMBLayout.second;
auto nodeLayout = computationNode->GetMBLayout();
if (((layout == nullptr) != (nodeLayout == nullptr)) || ((layout != nullptr) && (*layout != *nodeLayout)))
InvalidArgument("The layout of the specified gradient Value is incompatible with the layout of the corresponding Variable computed during Forward call");
computationNode->As<ComputationNode<ElementType>>()->AssignGradient(*CNTKMatrixAndMBLayout.first);
}
// Assign the supplied gradients corresponding to the root(s) of the network to be backpropagated through the graph
void CompositeFunction::PopulateNetworkGradients(const std::unordered_map<Variable, ValuePtr>& gradients)
{
auto functionOutputs = RawOutputs();
for (auto gradientVarValuePair : gradients)
{
auto outputComputationNode = m_variableToNodeMap.at(gradientVarValuePair.first);
ValuePtr gradientValue = gradientVarValuePair.second;
switch (gradientValue->GetDataType())
{
case DataType::Float:
PopulateComputationNodeGradient<float>(gradientVarValuePair, outputComputationNode);
break;
case DataType::Double:
PopulateComputationNodeGradient<double>(gradientVarValuePair, outputComputationNode);
break;
default:
LogicError("Unsupported DataType %s", DataTypeName(gradientValue->GetDataType()));
break;
}
}
}
static NDShape GetValueShape(const Variable& var, const ComputationNodeBasePtr& computationNodePtr)
{
size_t outputValueNumAxes = var.Shape().Rank();
// Add the dynamic (sequence and batch) axes if needed
if (computationNodePtr->GetMBLayout() != nullptr)
outputValueNumAxes += 2;
std::vector<size_t> outputShapeDims(outputValueNumAxes);
for (size_t i = 0; i < var.Shape().Rank(); ++i)
outputShapeDims[i] = computationNodePtr->GetSampleLayout().GetDim(i);
if (computationNodePtr->GetMBLayout() != nullptr)
{
outputShapeDims[var.Shape().Rank()] = computationNodePtr->GetMBLayout()->GetNumTimeSteps();
outputShapeDims[var.Shape().Rank() + 1] = computationNodePtr->GetMBLayout()->GetNumSequences();
}
return NDShape(outputShapeDims);
}
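// For example (values illustrative): a Variable of shape [ 512 ] whose node has a minibatch
// layout with 20 time steps and 32 sequences yields a value NDShape of [ 512 x 20 x 32 ];
// without a minibatch layout the value shape is just [ 512 ].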
/*static*/ void CompositeFunction::GetNodeOutputOrGradient(Variable var, ValuePtr& varValue, Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, bool getGradient)
{
auto valueShape = GetValueShape(var, computationNode);
if (varValue != nullptr)
{
// TODO: The shape of the specified output Value object must match the actual output shape
if ((varValue->Shape() != valueShape) && (AsTensorShape(varValue->Shape()) != AsTensorShape(valueShape)))
InvalidArgument("The shape %S of the specified Value object for %s does not match the actual shape %S", AsStringForErrorReporting(varValue->Shape()).c_str(), getGradient ? "gradient" : "output", AsStringForErrorReporting(valueShape).c_str());
}
ValuePtr nodeValue;
auto layout = computationNode->GetMBLayout();
switch (var.GetDataType())
{
case DataType::Float:
{
auto& matrix = getGradient ? computationNode->As<ComputationNode<float>>()->Gradient() : computationNode->As<ComputationNode<float>>()->Value();
if (varValue == nullptr)
nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<float>>(matrix.AsReference()), layout, /*readOnly =*/ false);
else
nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<float>(var, matrix, layout);
break;
}
case DataType::Double:
{
auto& matrix = getGradient ? computationNode->As<ComputationNode<double>>()->Gradient() : computationNode->As<ComputationNode<double>>()->Value();
if (varValue == nullptr)
nodeValue = MakeSharedObject<PackedValue>(var.Shape(), std::make_shared<Matrix<double>>(matrix.AsReference()), layout, /*readOnly =*/ false);
else
nodeValue = Utils::GetValueObjectFromCNTKImplMatrixAndMBLayout<double>(var, matrix, layout);
break;
}
default:
LogicError("Unsupported DataType %s", DataTypeName(var.GetDataType()));
break;
}
if (varValue == nullptr)
varValue = nodeValue;
else
varValue->CopyFrom(*nodeValue);
}
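// Note the two paths above: a null 'varValue' receives a PackedValue that merely aliases the node's
// internal matrix (no copy), whereas a caller-supplied 'varValue' gets the data converted and copied
// into it. A minimal illustrative call (hypothetical names 'outputVar' and 'node'):
//
//     ValuePtr v; // null => receive an aliasing PackedValue
//     GetNodeOutputOrGradient(outputVar, v, node, /*getGradient =*/ false);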
void CompositeFunction::GetNetworkOutputs(std::unordered_map<Variable, ValuePtr>& outputs)
{
// Now copy the Forward values of output nodes from the network to outputs' Value objects
for (auto& outputVarValuePair : outputs)
GetNodeOutputOrGradient(outputVarValuePair.first, outputVarValuePair.second, m_variableToNodeMap.at(outputVarValuePair.first), false /*getGradient*/);
}
void CompositeFunction::GetNetworkGradients(std::unordered_map<Variable, ValuePtr>& gradients)
{
auto networkInputs = this->Inputs();
// Now copy the gradient values of input nodes of the network to gradients' Value objects
for (auto& gradientVarValuePair : gradients)
{
// Only gradients corresponding to inputs of the network can be obtained
if (std::find(networkInputs.begin(), networkInputs.end(), gradientVarValuePair.first) == networkInputs.end())
InvalidArgument("Backpropagated gradient values can only be obtained for inputs of a Function");
// Gradients can only be obtained for variables that need gradients and were not explicitly excluded from gradient computation
if (!gradientVarValuePair.first.NeedsGradient() || (m_inputsExcludedFromGradientComputation.find(gradientVarValuePair.first) != m_inputsExcludedFromGradientComputation.end()))
InvalidArgument("Gradient value incorrectly requested for an Output or Constant Variable, an Input Variable with NeedsGradient setting of false, or an input for which gradient computation was explicitly excluded");
auto computationNodePtr = m_variableToNodeMap.at(gradientVarValuePair.first);
if (!computationNodePtr->NeedsGradient())
LogicError("Backpropagated gradient value cannot be read from a ComputationNode that has NeedsGradient set to false");
GetNodeOutputOrGradient(gradientVarValuePair.first, gradientVarValuePair.second, computationNodePtr, true /*getGradient*/);
}
}
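// Illustrative sketch only (hypothetical names): gradients may be requested for Parameters or for
// Inputs created with needsGradient = true; a Constant or an explicitly excluded input trips the checks above.
//
//     auto x = InputVariable({ inputDim }, DataType::Float, /*needsGradient =*/ true, L"x");
//     std::unordered_map<Variable, ValuePtr> inputGradients = { { x, nullptr } };
//     func->Backward(state, rootGradients, inputGradients); // fills inputGradients[x]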
const std::vector<Variable>& CompositeFunction::GetArgumentDependencies(const Variable& output)
{
if (m_perOutputVarArgumentDependencies.find(output) == m_perOutputVarArgumentDependencies.end())
{
if (output.IsOutput())
m_perOutputVarArgumentDependencies[output] = AsComposite(output.Owner())->Arguments();
else
m_perOutputVarArgumentDependencies[output] = { output };
}
return m_perOutputVarArgumentDependencies[output];
}
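// For example, for z = Plus(Times(W, x), b) with Parameters W, b and Input x, the argument
// dependencies of z's output are just { x }: Arguments() excludes Parameters and Constants, so
// only data-bearing inputs need to be supplied to Forward for this output.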
std::unordered_map<Variable, uint64_t> CompositeFunction::GetCurrentBackpropRootsTimeStamps() const
{
std::unordered_map<Variable, uint64_t> currentBackpropRootsTimeStamps;
assert(m_computationNetwork != nullptr);
for (auto& backpropRoot : m_currentBackpropRoots)
currentBackpropRootsTimeStamps[backpropRoot] = m_variableToNodeMap.at(backpropRoot)->GetEvalTimeStamp();
return currentBackpropRootsTimeStamps;
}
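// These timestamps are snapshotted into the CNTKBackPropState returned by Forward; Backward compares
// the snapshot against the current values to detect (and reject) an intervening Forward call that
// overwrote the forward values needed for backpropagation.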
/*virtual*/ BackPropStatePtr CompositeFunction::Forward(const std::unordered_map<Variable, ValuePtr>& arguments,
std::unordered_map<Variable, ValuePtr>& outputs,
const DeviceDescriptor& computeDevice,
const std::unordered_set<Variable>& outputsToRetainBackwardStateFor,
const std::unordered_set<Variable>& inputsToExcludeGradientsFor)
{
// Validate arguments and outputs
if (outputs.empty())
InvalidArgument("CompositeFunction::Forward: At least one output has to be specified!");
// Make sure that the DataType of the variables and corresponding values match
// TODO: We need a better way to determine the ElementType for the network
auto dataType = DataType::Unknown;
for (auto variableValuePair : arguments)
{
if (dataType == DataType::Unknown)
dataType = variableValuePair.first.GetDataType();
else if (dataType != variableValuePair.first.GetDataType())
LogicError("CompositeFunction::Forward: The DataType of all arguments of the Function must be same");
}
if (dataType == DataType::Unknown)
{
for (auto variableValuePair : outputs)
{
if (dataType == DataType::Unknown)
dataType = variableValuePair.first.GetDataType();
}
}
std::unordered_set<Variable> requestedOutputVariables;
for (auto output : outputs)
requestedOutputVariables.insert(output.first);
if (dataType == DataType::Float)
GetComputationNetwork<float>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
else if (dataType == DataType::Double)
GetComputationNetwork<double>(computeDevice, outputsToRetainBackwardStateFor, requestedOutputVariables, inputsToExcludeGradientsFor, true);
else
InvalidArgument("Unsupported DataType %s", DataTypeName(dataType));
std::unordered_set<Variable> functionOutputs(m_outputs.begin(), m_outputs.end());
std::vector<ComputationNodeBasePtr> outputsToEvaluate;
std::unordered_set<Variable> requiredArguments;
for (auto outputVarValuePair : outputs)
{
auto& requiredArgumentsForCurrentOutput = GetArgumentDependencies(outputVarValuePair.first);
requiredArguments.insert(requiredArgumentsForCurrentOutput.begin(), requiredArgumentsForCurrentOutput.end());
auto outputComputationNode = m_variableToNodeMap.at(outputVarValuePair.first);
outputsToEvaluate.push_back(outputComputationNode);
}
// We should have argument values supplied for all required argument dependencies for the requested outputs
std::vector<Variable> missingRequiredArguments;
std::unordered_map<Variable, ValuePtr> requiredArgumentValues;
for (auto requiredArgument : requiredArguments)
{
auto iter = arguments.find(requiredArgument);
if (iter == arguments.end())
missingRequiredArguments.push_back(requiredArgument);
else
requiredArgumentValues.insert(*iter);
}
if (!missingRequiredArguments.empty())
{
std::wstring missingRequiredArgumentNames = NamedListString(missingRequiredArguments);
InvalidArgument("Function::Forward: Values for %d required arguments (%S), that the requested output(s) depend on, have not been provided", (int)missingRequiredArguments.size(), missingRequiredArgumentNames.c_str());
}
if (requiredArgumentValues.size() < arguments.size())
fprintf(stderr, "WARNING: Function::Forward provided values for (%d) extra arguments which are not required for evaluating the specified Function outputs!\n", (int)(arguments.size() - requiredArgumentValues.size()));
// Feed data into the arguments of the network
// TODO: Avoid copying the data when possible
PopulateNetworkInputs(requiredArgumentValues);
// Dropout nodes have an implicit input in the form of the random mask that is applied to their explicit input.
// This mask is regenerated every minibatch, and hence dropout nodes with a non-zero dropout rate must be marked
// outdated w.r.t. their inputs to force re-evaluation in each minibatch.
std::list<ComputationNodeBasePtr> dropoutNodes = m_computationNetwork->GetNodesWithType(OperationNameOf(DropoutNode));
for (auto& dropoutNode : dropoutNodes)
dropoutNode->SetEvalTimeStampOutdatedWrtAll();
// Bump the timestamp of the parameter nodes whose values have changed
for (auto& paramTimeStampRecord : m_lastRecordedParameterValueTimeStamps)
{
auto parameter = paramTimeStampRecord.first;
auto prevTimeStamp = paramTimeStampRecord.second;
auto newTimeStamp = parameter.CurrentValueTimeStamp();
if (newTimeStamp > prevTimeStamp)
{
paramTimeStampRecord.second = newTimeStamp;
m_variableToNodeMap.at(parameter)->BumpEvalTimeStamp();
}
}
// The 'outputsToRetainBackwardStateFor' nodes also need to be evaluated if not already specified in 'outputs'
for (auto rootVarForBackprop : outputsToRetainBackwardStateFor)
{
if (outputs.find(rootVarForBackprop) == outputs.end())
outputsToEvaluate.push_back(m_variableToNodeMap.at(rootVarForBackprop));
}
// Mark all backprop roots outdated w.r.t. all their inputs so that they are re-evaluated with the newly supplied values
for (auto& backpropRoot : m_currentBackpropRoots)
m_variableToNodeMap.at(backpropRoot)->SetEvalTimeStampOutdatedWrtAll();
// TODO: Verify that values were supplied for all inputs that requested outputs depend on
ScopedNetworkOperationMode modeGuard(m_computationNetwork, outputsToRetainBackwardStateFor.empty() ? NetworkOperationMode::inferring : NetworkOperationMode::training);
// We may have to include additional nodes in the ForwardProp to align with how the memory sharing structure is set up;
// we need to include all roots that lie earlier in the global eval order than the actual outputs we are interested
// in evaluating.
// TODO: This may incur additional compute costs in some rare scenarios. We need to come up with a better way to handle this.
outputsToEvaluate = m_computationNetwork->SortByGlobalEvalOrder(outputsToEvaluate);
auto lastOutputInEvalOrder = outputsToEvaluate.back();
auto iterEndRootInEvalOrder = std::find(m_allNetworkRootsInGlobalEvalOrder.begin(), m_allNetworkRootsInGlobalEvalOrder.end(), lastOutputInEvalOrder) + 1;
auto augmentedOutputsToEvaluate = std::vector<ComputationNodeBasePtr>(m_allNetworkRootsInGlobalEvalOrder.begin(), iterEndRootInEvalOrder);
m_computationNetwork->ForwardProp(augmentedOutputsToEvaluate);
GetNetworkOutputs(outputs);
// TODO: How to deal with the specified 'computeDevice'
BackPropStatePtr backpropStatePtr;
if (!outputsToRetainBackwardStateFor.empty())
backpropStatePtr = MakeSharedObject<CNTKBackPropState>(this->shared_from_this(), computeDevice, GetCurrentBackpropRootsTimeStamps());
return backpropStatePtr;
}
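// Illustrative usage sketch only (hypothetical names 'func', 'x', 'xValue', 'device'):
//
//     std::unordered_map<Variable, ValuePtr> outputs = { { func->Output(), nullptr } };
//     func->Forward({ { x, xValue } }, outputs, device);                  // inference: no state retained
//     auto state = func->Forward({ { x, xValue } }, outputs, device,
//                                { func->Output() });                     // training: state retained for Backward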
/*virtual*/ void CompositeFunction::Backward(const BackPropStatePtr& state,
const std::unordered_map<Variable, ValuePtr>& rootGradientValues,
std::unordered_map<Variable, ValuePtr>& backPropagatedGradientValuesForInputs)
{
auto backpropState = dynamic_cast<const CNTKBackPropState*>(state.get());
if (backpropState == nullptr)
InvalidArgument("Invalid backprop state specified");
// TODO: Support multiple concurrent backprop states
std::unordered_map<Variable, uint64_t> currentBackpropRootTimeStamps = GetCurrentBackpropRootsTimeStamps();
if (backpropState->BackpropRootsForwardTimeStamps() != currentBackpropRootTimeStamps)
LogicError("The specified backprop state specified cannot be used for backpropagation as the Function's internal state was modified by subsequent Forward calls to the function."
"This is not a user error but a shortcoming of the current implementation where multiple independent backprop states are not simultaneously supported");
if (rootGradientValues.size() > 1)
LogicError("Currently gradient backprop from only one of the Function Outputs is supported");
// TODO: Avoid copying the data when possible
// Zero all gradients of nodes below the root nodes
for (auto rootGradientVarValuePair : rootGradientValues)
m_computationNetwork->ZeroInputGradients(m_variableToNodeMap.at(rootGradientVarValuePair.first));
// Feed the supplied root gradient values into the network
PopulateNetworkGradients(rootGradientValues);
// Backpropagate through the network
ScopedNetworkOperationMode modeGuard(m_computationNetwork, NetworkOperationMode::training);
auto rootComputationNodePtr = m_variableToNodeMap.at(rootGradientValues.begin()->first);
m_computationNetwork->GetNestedNetwork(rootComputationNodePtr)->Backprop(FrameRange(nullptr), true, true);
GetNetworkGradients(backPropagatedGradientValuesForInputs);
// TODO: How to deal with the specified 'computeDevice'
}
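// Illustrative Forward/Backward round trip (hypothetical names; 'onesLikeLoss' is a Value of ones
// shaped like the loss output, as sketched after PopulateNetworkGradients above):
//
//     auto state = func->Forward({ { x, xValue } }, outputs, device, { lossVar });
//     std::unordered_map<Variable, ValuePtr> rootGradients = { { lossVar, onesLikeLoss } };
//     std::unordered_map<Variable, ValuePtr> parameterGradients = { { w, nullptr } };
//     func->Backward(state, rootGradients, parameterGradients);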
}