// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // // ComputationNetworkBuilder -- helper class for constructing ComputationNetworks and ComputationNodes from C++ (internal and external) // #define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings #include "Basics.h" #include "ComputationNetworkBuilder.h" #include "ComputationNode.h" #include "InputAndParamNodes.h" #include "LinearAlgebraNodes.h" #include "NonlinearityNodes.h" #include "ConvolutionalNodes.h" #include "RecurrentNodes.h" #include "ReshapingNodes.h" #include "TrainingCriterionNodes.h" #include "CompositeComputationNodes.h" #include "EvaluationCriterionNodes.h" #include "EsotericNodes.h" #include namespace Microsoft { namespace MSR { namespace CNTK { using namespace std; // create a new node of a type given as a string, with var args so that this can be used at multiple places template static shared_ptr> CreateStandardNode(const std::wstring& nodeType, _Types&&... _Args) { // please keep this table sorted if (nodeType == OperationNameOf(CRFNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode)) return New>(forward<_Types>(_Args)...); #ifdef ENABLE_BROADCASTING_ELEMENTTIMES else if (nodeType == L"ColumnElementTimes") return New>(forward<_Types>(_Args)...); #else else if (nodeType == OperationNameOf(ColumnElementTimesNode)) return New>(forward<_Types>(_Args)...); #endif else if (nodeType == OperationNameOf(CosDistanceNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(CosDistanceWithNegativeSamplesNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(CosineNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(CrossEntropyNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(CrossEntropyWithSoftmaxNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SequenceWithSoftmaxNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(DiagonalNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(DiagTimesNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(DropoutNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(DummyCriterionNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ElementTimesNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ErrorPredictionNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ExpNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(FutureValueNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(GMMLogLikelihoodNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(HardmaxNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(InvStdDevNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(KhatriRaoProductNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(LSTMNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(LogNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(LogSoftmaxNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(LookupTableNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(MatrixL1RegNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(MatrixL2RegNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(MeanNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(MinusNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(NegateNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(NoiseContrastiveEstimationNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(PairNetworkNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ParallelNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(PastValueNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(PerDimMeanVarDeNormalizationNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(PlusNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ReconcileMBLayoutNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(RectifiedLinearNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ReshapeNode)) return New>(forward<_Types>(_Args)...); #ifdef ENABLE_BROADCASTING_ELEMENTTIMES else if (nodeType == L"RowElementTimes") return New>(forward<_Types>(_Args)...); #else else if (nodeType == OperationNameOf(RowElementTimesNode)) return New>(forward<_Types>(_Args)...); #endif else if (nodeType == OperationNameOf(RowRepeatNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(RowSliceNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(RowStackNode)) return New>(forward<_Types>(_Args)...); #ifdef ENABLE_BROADCASTING_ELEMENTTIMES else if (nodeType == L"Scale") return New>(forward<_Types>(_Args)...); #else else if (nodeType == OperationNameOf(ScaleNode)) return New>(forward<_Types>(_Args)...); #endif else if (nodeType == OperationNameOf(SequenceDecoderNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ShiftNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SigmoidNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SoftmaxNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SquareErrorNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(LogisticNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(StrideTimesNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SumColumnElementsNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SumElementsNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(TanhNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(TimeReverseNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(TimesNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(TransposeNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(TransposeTimesNode)) return New>(forward<_Types>(_Args)...); // old names we also support else if (nodeType == L"Delay") return New>(forward<_Types>(_Args)...); else if (nodeType == L"PerDimMeanVarNormalizationNode") return New>(forward<_Types>(_Args)...); else if (nodeType == L"PerDimMeanVarNormalizationNode") return New>(forward<_Types>(_Args)...); #if 1 else if (nodeType == OperationNameOf(DeprecatedReshapeNode)) return New>(forward<_Types>(_Args)...); #endif else InvalidArgument("Attempted to instantiate undefined operation %ls.", nodeType.c_str()); } // create a new node of a type given as a string, with var args so that this can be used at multiple places // This function is used for loading, while the above is used for creating standard-type networks. template static shared_ptr> CreateNode(const std::wstring& nodeType, _Types&&... _Args) { // check more types if (nodeType == OperationNameOf(AveragePoolingNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(BatchNormalizationNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(ConvolutionNode)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(SparseInputValue)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(InputValue)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(LearnableParameter)) return New>(forward<_Types>(_Args)...); else if (nodeType == OperationNameOf(MaxPoolingNode)) return New>(forward<_Types>(_Args)...); //else if (nodeType == OperationNameOf(SparseLearnableParameter)) return New>(forward<_Types>(_Args)...); else return CreateStandardNode(nodeType, forward<_Types>(_Args)...); } // this function is called from SimpleNetworkBuilder and old NDL template /*static*/ shared_ptr> ComputationNetworkBuilder::NewStandardNode(const std::wstring& nodeType, DEVICEID_TYPE deviceId, const wstring& name) { return CreateStandardNode(nodeType, deviceId, name); } // this function is used when loading from file template /*static*/ shared_ptr> ComputationNetworkBuilder::NewNode(const std::wstring& nodeType, DEVICEID_TYPE deviceId, const wstring& name) { return CreateNode(nodeType, deviceId, name); } shared_ptr NewComputationNodeFromConfig(const Microsoft::MSR::ScriptableObjects::IConfigRecordPtr configp) { wstring precision = configp->Get(L"precision"); // dispatch on ElemType wstring operationName = configp->Get(L"operation"); ComputationNodeBasePtr node; if (precision == L"float") node = CreateNode(operationName, configp); else if (precision == L"double") node = CreateNode(operationName, configp); else RuntimeError("NewStandardNode: Invalid value '%ls' for 'precision' parameter. Must be 'float' or 'double'.", precision.c_str()); // add a tag // Tags are used to declare special node types tp ComputationNetwork. const auto nodeWithTag = dynamic_pointer_cast(node); if (nodeWithTag) nodeWithTag->SetTag(configp->Get(L"tag")); return node; } // ----------------------------------------------------------------------- // node creation // ----------------------------------------------------------------------- // The following functions create nodes and add them to the net, but don't attach inputs (some don't have inputs). // There are special versions for nodes with custom constructors, and a catch-all, CreateComputationNode(), for all others. // TODO: Do we really need these? Folks who want to use C++ can instead say net->AddNodeToNet(New<>(...)), which is not that different. // TODO: separate into nodes that have inputs and those that duplicate functions with input adding except just not adding inputs. Clear? template shared_ptr> ComputationNetworkBuilder::CreateLearnableParameter(const std::wstring& paramName, const size_t rows, const size_t cols) { // TODO: in SimpleNetworkBuilder, this is very often followed by InitLearnableParameter()--we should have an overload that just does it right away return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), paramName, rows, cols)); } template shared_ptr> ComputationNetworkBuilder::CreateLearnableParameter(const std::wstring& paramName, const TensorShape& tensorShape) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), paramName, tensorShape)); } #if 0 // not functional at present //sparse matrix size is optionally specified template shared_ptr> ComputationNetworkBuilder::CreateSparseLearnableParameter(const std::wstring & paramName, const size_t rows, const size_t cols, const size_t size) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), paramName, rows, cols, size)); } #endif template shared_ptr> ComputationNetworkBuilder::CreateInputNode(const std::wstring& inputName, const size_t rows) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), inputName, rows)); } template shared_ptr> ComputationNetworkBuilder::CreateSparseInputNode(const std::wstring& inputName, const size_t rows) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), inputName, rows)); } template shared_ptr> ComputationNetworkBuilder::CreateInputNode(const std::wstring& inputName, const TensorShape& sampleLayout) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), inputName, sampleLayout)); } template shared_ptr> ComputationNetworkBuilder::CreateSparseInputNode(const std::wstring& inputName, const TensorShape& imageLayout) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), inputName, imageLayout)); } template shared_ptr> ComputationNetworkBuilder::CreatePairNetworkNode(const std::wstring& inputName, const size_t rows, const size_t cols) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), inputName, rows, cols)); } template shared_ptr> ComputationNetworkBuilder::CreateConvolutionNode(const std::wstring& nodeName, const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const bool zeroPadding, const size_t maxTempMemSizeInSamples) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), nodeName, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, imageLayoutKind, zeroPadding, maxTempMemSizeInSamples)); } template shared_ptr> ComputationNetworkBuilder::CreateMaxPoolingNode(const std::wstring& nodeName, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), nodeName, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind)); } template shared_ptr> ComputationNetworkBuilder::CreateAveragePoolingNode(const std::wstring& nodeName, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind) { return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), nodeName, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind)); } // this is the catch-all for all cases not covered as special cases above // Unlike the specialized ones above, this one creates nodes by type given as a string. template shared_ptr> ComputationNetworkBuilder::CreateComputationNode(const std::wstring& nodeType, const std::wstring& nodeName) { return net.AddNodeToNetWithElemType(NewStandardNode(nodeType, net.GetDeviceId(), nodeName)); } // ----------------------------------------------------------------------- // node creation // ----------------------------------------------------------------------- // The following functions create nodes and link them to the network and their inputs. // TODO: Do we need both this set and the one above that does not add inputs? Can they share more code? template shared_ptr> ComputationNetworkBuilder::PairNetwork(const ComputationNodePtr& a, const std::wstring nodeName) { if (net.GetNodeFromName(a->NodeName(), nullptr, false) != nullptr) { fprintf(stderr, "PairNetwork: asked to pair a node with name %ls in another network. However, this network has already a node with the same name. Should avoid this case.\n", a->NodeName().c_str()); RuntimeError("PairNetwork: asked to pair a node with name in another network. However, this network has already a node with the same name. Should avoid this case.\n"); } return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Convolution(const ComputationNodePtr weight, const ComputationNodePtr inputValues, const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const bool zeroPadding, const size_t maxTempMemSizeInSamples, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, imageLayoutKind, zeroPadding, maxTempMemSizeInSamples), weight, inputValues); } template shared_ptr> ComputationNetworkBuilder::MaxPooling(const ComputationNodePtr inputValues, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind), inputValues); } template shared_ptr> ComputationNetworkBuilder::AveragePooling(const ComputationNodePtr inputValues, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind), inputValues); } template shared_ptr> ComputationNetworkBuilder::ErrorPrediction(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::PerDimMeanVarNormalization(const ComputationNodePtr feature, const ComputationNodePtr mean, const ComputationNodePtr InvStdDev, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), feature, mean, InvStdDev); } template shared_ptr> ComputationNetworkBuilder::PerDimMeanVarDeNormalization(const ComputationNodePtr feature, const ComputationNodePtr mean, const ComputationNodePtr InvStdDev, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), feature, mean, InvStdDev); } template shared_ptr> ComputationNetworkBuilder::SquareError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::Logistic(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::Logistic(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b, c); } template shared_ptr> ComputationNetworkBuilder::SequenceDecoder(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr pairscore, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), label, prediction, pairscore); } template shared_ptr> ComputationNetworkBuilder::CrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), label, prediction); } template shared_ptr> ComputationNetworkBuilder::SequenceWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr loglikelihood, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), label, prediction, loglikelihood); } template shared_ptr> ComputationNetworkBuilder::NoiseContrastiveEstimation(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr input_bias, const std::wstring nodeName, NCEEvalMode mode) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, mode), label, prediction, input_weight, input_bias); } template shared_ptr> ComputationNetworkBuilder::ClassCrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr cls_log_post_prob, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), label, prediction, input_weight, cls_log_post_prob); } template shared_ptr> ComputationNetworkBuilder::CRF(const ComputationNodePtr label, const ComputationNodePtr postDepScore, const ComputationNodePtr transition_score, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), label, postDepScore, transition_score); } template shared_ptr> ComputationNetworkBuilder::DummyCriterion(const ComputationNodePtr objectives, const ComputationNodePtr derivatives, const ComputationNodePtr prediction, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), objectives, derivatives, prediction); } template shared_ptr> ComputationNetworkBuilder::LSTM(const ComputationNodePtr obs, const ComputationNodePtr inputGate, const ComputationNodePtr forgetGate, const ComputationNodePtr outputGate, const ComputationNodePtr memoryCellWgt, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), obs, inputGate, forgetGate, outputGate, memoryCellWgt); } template shared_ptr> ComputationNetworkBuilder::CrossEntropy(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), label, prediction); } template shared_ptr> ComputationNetworkBuilder::MatrixL1Reg(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::MatrixL2Reg(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Mean(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::InvStdDev(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Negate(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::RectifiedLinear(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Sigmoid(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Tanh(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Exp(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Log(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Cos(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Hardmax(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Softmax(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::LogSoftmax(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Sum(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } #ifndef ENABLE_BROADCASTING_ELEMENTTIMES template shared_ptr> ComputationNetworkBuilder::Scale(const ComputationNodePtr scalar, const ComputationNodePtr matrix, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), scalar, matrix); } #endif template shared_ptr> ComputationNetworkBuilder::Transpose(const ComputationNodePtr matrix, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), matrix); } template shared_ptr> ComputationNetworkBuilder::Times(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::TransposeTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } #ifndef ENABLE_BROADCASTING_ELEMENTTIMES template shared_ptr> ComputationNetworkBuilder::RowElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::ColumnElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } #endif template shared_ptr> ComputationNetworkBuilder::StrideTimes(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b, c); } template shared_ptr> ComputationNetworkBuilder::DiagTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::CosDistance(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::KhatriRaoProduct(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::Plus(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::Minus(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::Dropout(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::Reshape(const ComputationNodePtr a, const TensorShape& imageLayout, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, imageLayout), a); } #if 1 template shared_ptr> ComputationNetworkBuilder::DeprecatedReshape(const ComputationNodePtr a, const size_t numRows, const TensorShape& imageLayout, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, numRows, imageLayout), a); } #endif template shared_ptr> ComputationNetworkBuilder::RowRepeat(const ComputationNodePtr a, const size_t num_repeat, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, num_repeat), a); } template shared_ptr> ComputationNetworkBuilder::Diagonal(const ComputationNodePtr a, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a); } template shared_ptr> ComputationNetworkBuilder::PastValue(const ComputationNodePtr a, const float initHiddenActivity, const size_t row_size, size_t timeStep, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, initHiddenActivity, row_size, timeStep), a); } template shared_ptr> ComputationNetworkBuilder::FutureValue(const ComputationNodePtr a, const float initHiddenActivity, const size_t row_size, size_t timeStep, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, initHiddenActivity, row_size, timeStep), a); } template shared_ptr> ComputationNetworkBuilder::Parallel(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } template shared_ptr> ComputationNetworkBuilder::RowSlice(const ComputationNodePtr a, const size_t start_index, const size_t num_rows, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, start_index, num_rows), a); } template shared_ptr> ComputationNetworkBuilder::RowStack(const std::vector pinputs, const std::wstring nodeName) { vector inputs(pinputs.size()); for (size_t i = 0; i < inputs.size(); i++) inputs[i] = pinputs[i]; // convert to ComputationNodeBasePtr return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), inputs); } template shared_ptr> ComputationNetworkBuilder::GMMLogLikelihood(const ComputationNodePtr unnormedPrior, const ComputationNodePtr mean, const ComputationNodePtr logStddev, const ComputationNodePtr feature, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), unnormedPrior, mean, logStddev, feature); } template shared_ptr> ComputationNetworkBuilder::TimeReverse(const ComputationNodePtr input, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), input); } template shared_ptr> ComputationNetworkBuilder::LookupTable(const ComputationNodePtr dictionary, const ComputationNodePtr input, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), dictionary, input); } template shared_ptr> ComputationNetworkBuilder::BatchNormalization(const ComputationNodePtr input, const ComputationNodePtr scale, const ComputationNodePtr bias, const ComputationNodePtr runMean, const ComputationNodePtr runInvStdDev, bool eval, bool spatial, double expAvgFactor, ImageLayoutKind imageLayoutKind, const std::wstring nodeName) { return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, eval, spatial, expAvgFactor, imageLayoutKind), input, scale, bias, runMean, runInvStdDev); } template class ComputationNetworkBuilder; template class ComputationNetworkBuilder; } } }