https://github.com/Microsoft/CNTK
Tip revision: 7c8e3fb76cd6f65d9b4e0c6ff96bd00eb16211fb authored by Wayne Xiong on 26 February 2016, 22:23:19 UTC
Merge remote-tracking branch 'origin/jdroppo/ivector' into weixi/waynecoding
Merge remote-tracking branch 'origin/jdroppo/ivector' into weixi/waynecoding
Tip revision: 7c8e3fb
SimpleOutputWriter.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "Basics.h"
#include "DataReader.h"
#include "ComputationNetwork.h"
#include "DataReaderHelpers.h"
#include "Helpers.h"
#include "File.h"
#include "fileutil.h"
#include <vector>
#include <string>
#include <stdexcept>
#include <fstream>
#include <cstdio>
using namespace std;
namespace Microsoft { namespace MSR { namespace CNTK {
template <class ElemType>
class SimpleOutputWriter
{
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;
private:
std::vector<ComputationNodeBasePtr> DetermineOutputNodes(const std::vector<std::wstring>& outputNodeNames)
{
std::vector<ComputationNodeBasePtr> outputNodes;
if (outputNodeNames.size() == 0)
{
if (m_verbosity > 0)
fprintf(stderr, "OutputNodeNames are not specified, using the default outputnodes.\n");
if (m_net->OutputNodes().size() == 0)
LogicError("There is no default output node specified in the network.");
outputNodes = m_net->OutputNodes();
}
else
{
for (int i = 0; i < outputNodeNames.size(); i++)
outputNodes.push_back(m_net->GetNodeFromName(outputNodeNames[i]));
}
return outputNodes;
}
std::vector<ComputationNodeBasePtr> DetermineInputNodes(const std::vector<ComputationNodeBasePtr>& outputNodes)
{
//use map to remove duplicated items
std::set<ComputationNodeBasePtr> inputNodesMap;
for (auto& onode : outputNodes)
{
for (auto& inode : m_net->InputNodes(onode))
inputNodesMap.insert(inode);
}
std::vector<ComputationNodeBasePtr> inputNodes;
for (auto& inode : inputNodesMap)
inputNodes.push_back(inode);
return inputNodes;
}
std::map<std::wstring, Matrix<ElemType>*> RetrieveInputMatrices(const std::vector<ComputationNodeBasePtr>& inputNodes)
{
std::map<std::wstring, Matrix<ElemType>*> inputMatrices;
for (auto& inode : inputNodes)
inputMatrices[inode->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(inode)->Value();
return inputMatrices;
}
public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
: m_net(net), m_verbosity(verbosity)
{
}
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, IDataWriter<ElemType>& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
{
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);
// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);
std::map<std::wstring, Matrix<ElemType>*> inputMatrices = RetrieveInputMatrices(inputNodes);
// evaluate with minibatches
dataReader.StartMinibatchLoop(mbSize, 0, numOutputSamples);
if (!dataWriter.SupportMultiUtterances())
dataReader.SetNumParallelSequences(1);
m_net->StartEvaluateMinibatchLoop(outputNodes);
size_t totalEpochSamples = 0;
std::map<std::wstring, void*, nocase_compare> outputMatrices;
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
for (int i = 0; i < outputNodes.size(); i++)
{
m_net->ForwardProp(outputNodes[i]);
outputMatrices[outputNodes[i]->NodeName()] = (void*) (&dynamic_pointer_cast<ComputationNode<ElemType>>(outputNodes[i])->Value());
}
if (doUnitTest)
{
std::map<std::wstring, void*, nocase_compare> inputMatricesUnitTest;
for (auto iter = inputMatrices.begin(); iter != inputMatrices.end(); iter++)
inputMatricesUnitTest[iter->first] = (void*) (iter->second);
dataWriter.SaveData(0, inputMatricesUnitTest, actualMBSize, actualMBSize, 0);
}
else
dataWriter.SaveData(0, outputMatrices, actualMBSize, actualMBSize, 0);
totalEpochSamples += actualMBSize;
// call DataEnd function in dataReader to do
// reader specific process if sentence ending is reached
dataReader.DataEnd();
}
if (m_verbosity > 0)
fprintf(stderr, "Total Samples Evaluated = %lu\n", totalEpochSamples);
// clean up
}
// pass this to WriteOutput() (to file-path, below) to specify how the output should be formatted
struct WriteFormattingOptions
{
// How to interpret the data:
bool isCategoryLabel; // true: find max value in column and output the index instead of the entire vector
std::wstring labelMappingFile; // optional dictionary for pretty-printing category labels
bool transpose; // true: one line per sample, each sample (column vector) forms one line; false: one column per sample
// The following strings are interspersed with the data:
// overall
std::string prologue; // print this at the start (e.g. a global header or opening bracket)
std::string epilogue; // and this at the end
// sequences
std::string sequenceSeparator; // print this between sequences (i.e. before all sequences but the first)
std::string sequencePrologue; // print this before each sequence (after sequenceSeparator)
std::string sequenceEpilogue; // and this after each sequence
// elements
std::string elementSeparator; // print this between elements on a row
std::string sampleSeparator; // and this between rows
// Optional printf precision parameter:
std::string precisionFormat; // printf precision, e.g. ".2" to get a "%.2f"
WriteFormattingOptions() :
isCategoryLabel(false), transpose(true), sequenceEpilogue("\n"), elementSeparator(" "), sampleSeparator("\n")
{ }
// Process -- replace newlines and all %s by the given string
static std::string Processed(const std::wstring& nodeName, std::string fragment)
{
fragment = msra::strfun::ReplaceAll<std::string>(fragment, "\\n", "\n");
fragment = msra::strfun::ReplaceAll<std::string>(fragment, "\\t", "\t");
if (fragment.find("%s") != fragment.npos)
fragment = msra::strfun::ReplaceAll<std::string>(fragment, "%s", msra::strfun::utf8(nodeName));
return fragment;
}
};
// TODO: Remove code dup with above function by creating a fake Writer object and then calling the other function.
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, const WriteFormattingOptions & formattingOptions, size_t numOutputSamples = requestDataSize)
{
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);
// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);
std::map<std::wstring, Matrix<ElemType>*> inputMatrices = RetrieveInputMatrices(inputNodes);
// load a label mapping if requested
std::vector<std::string> labelMapping;
if (formattingOptions.isCategoryLabel && !formattingOptions.labelMappingFile.empty())
File::LoadLabelFile(formattingOptions.labelMappingFile, labelMapping);
// open output files
File::MakeIntermediateDirs(outputPath);
std::map<ComputationNodeBasePtr, shared_ptr<File>> outputStreams; // TODO: why does unique_ptr not work here? Complains about non-existent default_delete()
for (auto & onode : outputNodes)
{
std::wstring nodeOutputPath = outputPath;
if (nodeOutputPath != L"-")
nodeOutputPath += L"." + onode->NodeName();
auto f = make_shared<File>(nodeOutputPath, fileOptionsWrite | fileOptionsText);
outputStreams[onode] = f;
}
// evaluate with minibatches
dataReader.StartMinibatchLoop(mbSize, 0, numOutputSamples);
m_net->StartEvaluateMinibatchLoop(outputNodes);
size_t totalEpochSamples = 0;
size_t numMBsRun = 0;
size_t tempArraySize = 0;
ElemType* tempArray = nullptr;
for (auto & onode : outputNodes)
{
FILE * f = *outputStreams[onode];
fprintfOrDie(f, "%s", formattingOptions.prologue.c_str());
}
char formatChar = !formattingOptions.isCategoryLabel ? 'f' : !formattingOptions.labelMappingFile.empty() ? 's' : 'u';
std::string valueFormatString = "%" + formattingOptions.precisionFormat + formatChar; // format string used in fprintf() for formatting the values
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
ComputationNetwork::BumpEvalTimeStamp(inputNodes);
for (auto & onode : outputNodes)
{
// compute the node value
// Note: Intermediate values are memoized, so in case of multiple output nodes, we only compute what has not been computed already.
m_net->ForwardProp(onode);
// get it (into a flat CPU-side vector)
Matrix<ElemType>& outputValues = dynamic_pointer_cast<ComputationNode<ElemType>>(onode)->Value();
outputValues.CopyToArray(tempArray, tempArraySize);
ElemType* pCurValue = tempArray;
// sequence separator
FILE * f = *outputStreams[onode];
const auto sequenceSeparator = formattingOptions.Processed(onode->NodeName(), formattingOptions.sequenceSeparator);
const auto sequencePrologue = formattingOptions.Processed(onode->NodeName(), formattingOptions.sequencePrologue);
const auto sequenceEpilogue = formattingOptions.Processed(onode->NodeName(), formattingOptions.sequenceEpilogue);
const auto elementSeparator = formattingOptions.Processed(onode->NodeName(), formattingOptions.elementSeparator);
const auto sampleSeparator = formattingOptions.Processed(onode->NodeName(), formattingOptions.sampleSeparator);
if (numMBsRun > 0 && !sequenceSeparator.empty())
fprintfOrDie(f, "%s", sequenceSeparator.c_str());
fprintfOrDie(f, "%s", sequencePrologue.c_str());
// output it according to our format specification
size_t T = outputValues.GetNumCols();
size_t dim = outputValues.GetNumRows();
if (formattingOptions.isCategoryLabel)
{
if (formatChar == 's') // verify label dimension
{
if (dim != labelMapping.size())
InvalidArgument("write: Row dimension %d does not match number of entries %d in labelMappingFile '%ls'", (int)dim, (int)labelMapping.size(), formattingOptions.labelMappingFile.c_str());
}
// update the matrix in-place from one-hot (or max) to index
// find the max in each column
foreach_column(j, outputValues)
{
double maxPos = -1;
double maxVal = 0;
foreach_row(i, outputValues)
{
double val = pCurValue[i + j * dim];
if (maxPos < 0 || val >= maxVal)
{
maxPos = (double)i;
maxVal = val;
}
}
pCurValue[j] = (ElemType) maxPos; // overwrite in-place, assuming a flat vector
}
dim = 1;
}
size_t iend = formattingOptions.transpose ? dim : T;
size_t jend = formattingOptions.transpose ? T : dim;
size_t istride = formattingOptions.transpose ? 1 : jend;
size_t jstride = formattingOptions.transpose ? iend : 1;
for (size_t j = 0; j < jend; j++)
{
if (j > 0)
fprintfOrDie(f, "%s", sampleSeparator.c_str());
for (size_t i = 0; i < iend; i++)
{
if (i > 0)
fprintfOrDie(f, "%s", elementSeparator.c_str());
if (formatChar == 'f') // print as real number
{
double dval = pCurValue[i * istride + j * jstride];
fprintfOrDie(f, valueFormatString.c_str(), dval);
}
else if (formatChar == 'u') // print category as integer index
{
unsigned int uval = (unsigned int) pCurValue[i * istride + j * jstride];
fprintfOrDie(f, valueFormatString.c_str(), uval);
}
else if (formatChar == 's') // print category as a label string
{
size_t uval = (size_t) pCurValue[i * istride + j * jstride];
assert(uval < labelMapping.size());
const char * sval = labelMapping[uval].c_str();
fprintfOrDie(f, valueFormatString.c_str(), sval);
}
}
}
fprintfOrDie(f, "%s", sequenceEpilogue.c_str());
}
totalEpochSamples += actualMBSize;
fprintf(stderr, "Minibatch[%lu]: ActualMBSize = %lu\n", ++numMBsRun, actualMBSize);
}
for (auto & onode : outputNodes)
{
FILE * f = *outputStreams[onode];
fprintfOrDie(f, "%s", formattingOptions.epilogue.c_str());
}
delete[] tempArray;
fprintf(stderr, "Total Samples Evaluated = %lu\n", totalEpochSamples);
// flush all files (where we can catch errors) so that we can then destruct the handle cleanly without error
for (auto & iter : outputStreams)
iter.second->Flush();
}
private:
ComputationNetworkPtr m_net;
int m_verbosity;
void operator=(const SimpleOutputWriter&); // (not assignable)
};
}}}