https://github.com/Microsoft/CNTK
Tip revision: a05c3c642648373f4ede0956e4286257c3d59a61 authored by liqfu on 24 August 2018, 17:46:51 UTC
CNTK splice allows broadcast. This case is handled in the change. For no-op (identity) ops, their input and output types shall be set according to the upstream ops. ToBatch/ToSequence and UnpackBatch/UnpackSequence ops added during model import need to be skipped. Model import needs to handle ops with multiple outputs.
CNTK splice allows broadcast. This case is handled in the change. For no-op (identity) ops, their input and output types shall be set according to the upstream ops. ToBatch/ToSequence and UnpackBatch/UnpackSequence ops added during model import need to be skipped. Model import needs to handle ops with multiple outputs.
Tip revision: a05c3c6
SpecialPurposeNodes.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "Basics.h"
#include "ComputationNode.h"
#include "SpecialPurposeNodes.h"
#include <string>
#include <vector>
#include <stdexcept>
#include <memory>
namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------
// Trace (node, say='', logFrequency=10, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[])
//
// Debugging aid to trace a node's value using WriteMinibatchWithFormatting().
// -----------------------------------------------------------------------
// Builds a TraceNode from a BrainScript configuration record.
// The delegated constructor receives the configured "deviceId" and the string
// L"<placeholder>" (presumably a temporary node name -- confirm in the header).
// The inputs are then attached, and the trace options are read from these keys:
// "say", "logFirst", "logFrequency", "logGradientToo", "onlyUpToRow", "onlyUpToT",
// plus whatever WriteFormattingOptions pulls from the same record.
template <class ElemType>
TraceNode<ElemType>::TraceNode(const ScriptableObjects::IConfigRecordPtr configp) :
    TraceNode(configp->Get(L"deviceId"), L"<placeholder>")
{
    AttachInputsFromConfig(configp, this->GetExpectedNumInputs());

    auto& config = *configp; // shorthand for the option lookups below

    m_message           = (const std::wstring&)config.Get(L"say");
    m_logFirst          = config.Get(L"logFirst");
    m_logFrequency      = config.Get(L"logFrequency");
    m_logGradientToo    = config.Get(L"logGradientToo");
    m_formattingOptions = WriteFormattingOptions(config);
    m_onlyUpToRow       = config.Get(L"onlyUpToRow");
    m_onlyUpToT         = config.Get(L"onlyUpToT");
}
// Serializes this node: the base-class state first, then the trace settings.
// The write order below defines the on-disk layout and must stay in sync with
// TraceNode<ElemType>::Load(), which reads the fields back in the same order.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Save(File& fstream) const /*override*/
{
Base::Save(fstream);
fstream << m_message;
fstream << m_logFirst;
fstream << m_logFrequency;
fstream << m_logGradientToo;
m_formattingOptions.Save(fstream);
// BUGBUG: This serializes the pathname of the mapping file to disk. Not nice. But no better solution.
// (This refers to m_formattingOptions.Save() on the previous line; the formatting
// options carry labelMappingFile, which Validate() uses to load the label mapping.)
fstream << m_onlyUpToRow;
fstream << m_onlyUpToT;
}
// Deserializes this node: the base-class state first, then the trace settings, in
// the same order that TraceNode<ElemType>::Save() writes them.
// NOTE(review): every field is read unconditionally. modelVersion is only forwarded
// to Base::Load() and m_formattingOptions.Load(). Confirm that all model files
// accepted here contain onlyUpToRow/onlyUpToT, or gate these reads on modelVersion.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Load(File& fstream, size_t modelVersion) /*override*/
{
Base::Load(fstream, modelVersion);
fstream >> m_message;
fstream >> m_logFirst;
fstream >> m_logFrequency;
fstream >> m_logGradientToo;
m_formattingOptions.Load(fstream, modelVersion);
fstream >> m_onlyUpToRow;
fstream >> m_onlyUpToT;
}
// Forward-pass prologue hook: after the base-class work, this advances the
// minibatch counter. Log() compares that counter against logFirst/logFrequency,
// and Validate() resets it to zero.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::BeginForwardProp() /*override*/
{
    Base::BeginForwardProp();
    m_numMBsRun += 1;
}
// Identity forward pass. Copies input 0's value into this node's value for the
// frame range fr, then traces that value through Log().
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::ForwardProp(const FrameRange& fr) /*override*/
{
    const size_t tensorRank = DetermineElementwiseTensorRank();

    auto outputSlice = ValueTensorFor(tensorRank, fr);
    auto inputSlice  = InputRef(0).ValueTensorFor(tensorRank, fr);
    outputSlice.AssignCopyOf(inputSlice);

    Log(fr, /*logGradientInstead=*/false);
}
// Identity backward pass. Adds this node's gradient into input 0's gradient for
// the frame range fr. If logGradientToo was configured, the gradient is then traced.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::BackpropTo(const size_t inputIndex, const FrameRange& fr) /*override*/
{
    assert(inputIndex == 0);
    (void)inputIndex; // referenced only by assert(); avoids unused-parameter warnings when NDEBUG is set

    const size_t tensorRank = DetermineElementwiseTensorRank();
    auto outputGradient = GradientTensorFor(tensorRank, fr);             // source of the gradient
    auto inputGradient  = InputRef(0).GradientTensorFor(tensorRank, fr); // accumulated into here
    inputGradient.AddCopyOf(outputGradient);

    if (!m_logGradientToo)
        return;
    Log(fr, /*logGradientInstead=*/true);
}
// Writes a trace of input 0 to stderr for the frame range fr. This is the value by
// default, or the gradient when logGradientInstead is true.
// Whenever this is called during minibatch 1, the configured prologue is printed first.
// A trace is written for minibatches 1..m_logFirst, and after that for every
// m_logFrequency-th minibatch (1, 1+freq, 1+2*freq, ...). A frequency of 0 disables
// the periodic part.
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Log(const FrameRange& fr, bool logGradientInstead) const
{
    if (m_numMBsRun == 1)
    {
        const auto prologue = m_formattingOptions.Processed(NodeName(), m_formattingOptions.prologue, m_numMBsRun);
        fprintf(stderr, "%s", prologue.c_str());
    }

    const bool inFirstBatches = m_numMBsRun <= m_logFirst;
    if (!inFirstBatches && !(m_logFrequency != 0 && (m_numMBsRun - 1) % m_logFrequency == 0))
        return; // not a minibatch we are asked to trace

    // fprintf() conversion for each value: 'f' for plain values; for category labels,
    // 's' when a label mapping file is configured, otherwise 'u'.
    char formatChar = 'f';
    if (m_formattingOptions.isCategoryLabel)
        formatChar = m_formattingOptions.labelMappingFile.empty() ? 'u' : 's';
    auto valueFormatString = "%" + m_formattingOptions.precisionFormat + formatChar;

    // Expand the configured separators/prologues for this node and minibatch.
    const auto sequenceSeparator = m_formattingOptions.Processed(NodeName(), m_formattingOptions.sequenceSeparator, m_numMBsRun);
    const auto sequencePrologue  = m_formattingOptions.Processed(NodeName(), m_formattingOptions.sequencePrologue,  m_numMBsRun);
    const auto sequenceEpilogue  = m_formattingOptions.Processed(NodeName(), m_formattingOptions.sequenceEpilogue,  m_numMBsRun);
    const auto elementSeparator  = m_formattingOptions.Processed(NodeName(), m_formattingOptions.elementSeparator,  m_numMBsRun);
    const auto sampleSeparator   = m_formattingOptions.Processed(NodeName(), m_formattingOptions.sampleSeparator,   m_numMBsRun);

    // Header line: "------- Trace[<t> | <t0>..<t1>] <message> [(gradient) ]--> <input prototype>"
    const auto timeRange = fr.GetTimeRange();
    fprintf(stderr, "------- Trace["); // --- for better visual separability from actual content
    if (!fr.IsAllFrames())
    {
        if (timeRange.second == timeRange.first + 1)
            fprintf(stderr, "%d", static_cast<int>(timeRange.first)); // a single time step
        else if (timeRange.second > timeRange.first + 1)
            fprintf(stderr, "%d..%d", static_cast<int>(timeRange.first), static_cast<int>(timeRange.second) - 1); // inclusive range
    }
    fprintf(stderr, "] %ls %s--> %s\n", m_message.c_str(), logGradientInstead ? "(gradient) " : "", InputRef(0).FormatOperationPrototype("").c_str());

    InputRef(0).WriteMinibatchWithFormatting(stderr, fr, m_onlyUpToRow, m_onlyUpToT,
                                             m_formattingOptions.transpose, m_formattingOptions.isCategoryLabel, m_formattingOptions.isSparse,
                                             m_labelMapping,
                                             sequenceSeparator, sequencePrologue, sequenceEpilogue, elementSeparator, sampleSeparator,
                                             valueFormatString, logGradientInstead);
}
// Validates the node through ValidateUnaryMap(). On the final validation pass,
// the label mapping is loaded from labelMappingFile if four conditions hold: it is
// not loaded yet, category-label or sparse formatting is requested, and a file
// name is configured. Every call also resets the minibatch counter used by Log().
template <class ElemType>
/*virtual*/ void TraceNode<ElemType>::Validate(bool isFinalValidationPass) // override
{
    ValidateUnaryMap(isFinalValidationPass);

    const bool needsLabelNames = m_formattingOptions.isCategoryLabel || m_formattingOptions.isSparse;
    const bool haveMappingFile = !m_formattingOptions.labelMappingFile.empty();
    if (isFinalValidationPass && m_labelMapping.empty() && needsLabelNames && haveMappingFile)
        File::LoadLabelFile(m_formattingOptions.labelMappingFile, m_labelMapping);

    m_numMBsRun = 0;
}
// Explicitly instantiate TraceNode for the float, double and half element types.
// The member templates above are defined in this .cpp file, so these
// instantiations are what make them available to other translation units.
template class TraceNode<float>;
template class TraceNode<double>;
template class TraceNode<half>;
}}}