// https://github.com/Microsoft/CNTK
// Tip revision: b881dabb9bb554b3b88108a3b2c1a81fc4604fac authored by Mark Hillebrand on 18 January 2016, 08:38:04 UTC
// License change
// License change
// Tip revision: b881dab
// LUSequenceReader.h
//
// <copyright file="LUSequenceReader.h" company="Microsoft">
// Copyright (c) Microsoft Corporation. All rights reserved.
// </copyright>
//
// LUSequenceReader.h - Include file for the MTK and MLF format of features and samples
#pragma once
//#define LEAKDETECT
#include "DataReader.h"
#include "DataWriter.h"
#include "LUSequenceParser.h"
#include "commandArgUtil.h" // for intargvector
#include <string>
#include <map>
#include <vector>
#include "minibatchsourcehelpers.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// NOTE: preprocessor macros are not scoped by the enclosing namespace; these four names are
// visible to every translation unit that includes this header.
// NOTE(review): none of these macros is referenced anywhere else in this header, so the
// descriptions below are inferred from the names and values only -- confirm against the
// implementation file before relying on them.

// Presumably a cache block size ("BLOG" looks like a typo for "BLOCK") -- TODO confirm.
#define CACHE_BLOG_SIZE 50000
// Wide-string literal "idx2cls"; presumably a section/key name -- TODO confirm.
#define STRIDX2CLS L"idx2cls"
// Wide-string literal "classinfo"; presumably a section/key name -- TODO confirm.
#define CLASSINFO L"classinfo"
// Presumably the maximum length of a string buffer -- TODO confirm.
#define MAX_STRING 2048
/// How a reader treats the labels attached to its data.
/// The numeric values are spelled out explicitly and must stay stable.
enum LabelKind
{
    labelNone     = 0,  ///< there are no labels to handle
    labelCategory = 1,  ///< categorical labels; mapping tables get created
    labelNextWord = 2,  ///< sentence mapping: the label is the next word to predict
    labelOther    = 3,  ///< any other kind of label
};
/// LUSequenceReader - a sequence reader implementing the IDataReader<ElemType> interface.
///
/// This header only declares the class: apart from the constructor and three empty
/// inline stubs, every member function is defined outside of this header.
///
/// Data members are kept in their original declaration order and access sections;
/// the constructor's member-initializer list follows that same order.
///
/// @tparam ElemType element type of the feature/label buffers handed to callers.
template<class ElemType>
class LUSequenceReader : public IDataReader<ElemType>
{
protected:
    // Both flags start out false (see constructor).
    // NOTE(review): presumably they record whether the "idx2cls" / "classinfo" data
    // has already been loaded -- confirm against the implementation file.
    bool m_idx2clsRead;
    bool m_clsinfoRead;

    std::wstring m_file;

public:
    using LabelType = typename IDataReader<ElemType>::LabelType;
    using LabelIdType = typename IDataReader<ElemType>::LabelIdType;

    // NOTE(review): meaning of these public counters is not visible in this header.
    int nwords;
    int dims;
    int nsamps;
    int nglen;
    int nmefeats;

    int m_seed;
    bool mRandomize;

    int class_size;
    std::map<int, std::vector<int>> class_words;
    std::vector<int> class_cn;

public:
    /// deal with OOV
    std::map<std::string, std::string> mWordMapping;
    std::string mWordMappingFn;
    std::string mUnkStr;

protected:
    LULUSequenceParser<ElemType, LabelType> m_parser;
    // LUBatchLUSequenceParser<ElemType, LabelType> m_parser;

    size_t m_mbSize;            // size of minibatch requested
    size_t m_mbStartSample;     // starting sample # of the next minibatch
    size_t m_epochSize;         // size of an epoch
    size_t m_epoch;             // which epoch are we on
    size_t m_epochStartSample;  // the starting sample for the epoch
    size_t m_totalSamples;      // number of samples in the dataset

    size_t m_featureDim;        // feature dimensions for extra features
    size_t m_featureCount;      // total number of non-zero features (in labelsDim + extra features dim)
    /// for language modeling, the m_featureCount = 1, since there is only one nonzero element

    size_t m_readNextSampleLine; // next sample to read Line
    size_t m_readNextSample;     // next sample to read
    size_t m_seqIndex;           // index into the m_sequence array
    bool m_labelFirst;           // the label is the first element in a line

    intargvector m_wordContext;

    // Index into the per-direction arrays below (m_labelsName, m_labelInfo, ...):
    // labelInfoIn == 0, labelInfoOut == 1, labelInfoMax == 2 is the array size.
    enum LabelInfoType
    {
        labelInfoMin = 0,
        labelInfoIn = labelInfoMin,
        labelInfoOut,
        labelInfoMax
    };

    std::wstring m_labelsName[labelInfoMax];
    std::wstring m_featuresName;
    std::wstring m_labelsCategoryName[labelInfoMax];
    std::wstring m_labelsMapName[labelInfoMax];
    std::wstring m_sequenceName;

    // Raw buffers; all four are null after construction.
    // NOTE(review): ownership/allocation is handled outside this header
    // (ReleaseMemory() is declared below) -- confirm who frees each of them.
    ElemType* m_featuresBuffer;
    ElemType* m_labelsBuffer;
    LabelIdType* m_labelsIdBuffer;
    size_t* m_sequenceBuffer;

    bool m_endReached;
    int m_traceLevel;

    // feature and label data are parallel arrays
    std::vector<std::vector<std::vector<LabelIdType>>> m_featureWordContext;
    std::vector<std::vector<LabelIdType>> m_featureData;
    std::vector<LabelIdType> m_labelIdData;
    std::vector<ElemType> m_labelData;
    std::vector<size_t> m_sequence;

    // we have two one for input and one for output
    // NOTE(review): type, idMax, dim and busewordmap get no initial value here;
    // they must be assigned before being read.
    struct LabelInfo
    {
        LabelKind type; // labels are categories, create mapping table
        std::map<LabelIdType, LabelType> mapIdToLabel;
        std::map<LabelType, LabelIdType> mapLabelToId;
        std::map<std::string, int> word4idx;
        std::map<int, std::string> idx4word;
        LabelIdType idMax; // maximum label ID we have encountered so far
        LabelIdType dim; // maximum label ID we will ever see (used for array dimensions)
        std::string beginSequence; // starting sequence string (i.e. <s>)
        std::string endSequence; // ending sequence string (i.e. </s>)
        bool busewordmap; /// whether using wordmap to map unseen words to unk
        std::wstring mapName;
        std::wstring fileToWrite; // set to the path if we need to write out the label file
    } m_labelInfo[labelInfoMax];

    // caching support; both pointers are null after construction
    DataReader<ElemType>* m_cachingReader;
    DataWriter<ElemType>* m_cachingWriter;
    ConfigParameters m_readerConfig;

    void InitCache(const ConfigParameters& config);

    void UpdateDataVariables();
    void LMSetupEpoch();
    size_t RecordsToRead(size_t mbStartSample, bool tail=false);
    void ReleaseMemory();
    void WriteLabelFile();
    void LoadLabelFile(const std::wstring &filePath, std::vector<LabelType>& retLabels);

    LabelIdType GetIdFromLabel(const std::string& label, LabelInfo& labelInfo);
    bool GetIdFromLabel(const std::vector<std::string>& label, LabelInfo& labelInfo, std::vector<LabelIdType>& val);
    bool CheckIdFromLabel(const std::string& labelValue, const LabelInfo& labelInfo, unsigned & labelId);

    virtual bool ReadRecord(size_t readSample);
    bool SentenceEnd();

public:
    virtual void Init(const ConfigParameters& config);
    void ReadLabelInfo(const std::wstring & vocfile, std::map<std::string, int> & word4idx,
                       std::map<int, std::string>& idx4word);
    void ChangeMaping(const std::map<std::string, std::string>& maplist,
                      const std::string & unkstr,
                      std::map<std::string, int> & word4idx);

    void ReadWord(char *wrod, FILE *fin);

    virtual void Destroy();

    /// Constructs the reader in a well-defined empty state: every raw pointer is null,
    /// every counter is zero and every flag is false, so that no code path can observe
    /// an indeterminate value before the real values are assigned.
    /// Initializers are listed in member declaration order.
    LUSequenceReader()
        : m_idx2clsRead(false),
          m_clsinfoRead(false),
          nwords(0), dims(0), nsamps(0), nglen(0), nmefeats(0),
          m_seed(0),
          mRandomize(false),
          class_size(0),
          m_mbSize(0),
          m_mbStartSample(0),
          m_epochSize(0),
          m_epoch(0),
          m_epochStartSample(0),
          m_totalSamples(0),
          m_featureDim(0),
          m_featureCount(0),
          m_readNextSampleLine(0),
          m_readNextSample(0),
          m_seqIndex(0),
          m_labelFirst(false),
          m_featuresBuffer(nullptr),
          m_labelsBuffer(nullptr),
          m_labelsIdBuffer(nullptr),
          m_sequenceBuffer(nullptr),
          m_endReached(false),
          m_traceLevel(0),
          m_cachingReader(nullptr),
          m_cachingWriter(nullptr)
    {
    }

    virtual ~LUSequenceReader();

    virtual void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);

    // The next three are intentionally empty in this class.
    void SetNbrSlicesEachRecurrentIter(const size_t /*mz*/) {}
    void SentenceEnd(std::vector<size_t> &/*sentenceEnd*/) {}
    void SetSentenceEndInBatch(std::vector<size_t> &/*sentenceEnd*/) {}

    virtual const std::map<LabelIdType, LabelType>& GetLabelMapping(const std::wstring& sectionName);
    virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
    virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart=0);
};
/// BatchLUSequenceReader - extends LUSequenceReader<ElemType> with minibatch-level
/// entry points (GetMinibatch, FindNextSentences, DataEnd, ...).
///
/// Only the constructor, the destructor and the size_t forwarding overloads of
/// SetSentenceBegin/SetSentenceEnd are defined here; everything else is defined
/// outside of this header.
///
/// Because the base class depends on ElemType, its members are not found by
/// unqualified lookup inside this template; the using-declarations below make
/// them nameable without a `this->` prefix.
template<class ElemType>
class BatchLUSequenceReader : public LUSequenceReader<ElemType>
{
public:
    using LabelType = typename IDataReader<ElemType>::LabelType;
    using LabelIdType = typename IDataReader<ElemType>::LabelIdType;
    using LUSequenceReader<ElemType>::mWordMappingFn;
    using LUSequenceReader<ElemType>::m_cachingReader;
    using LUSequenceReader<ElemType>::mWordMapping;
    using LUSequenceReader<ElemType>::mUnkStr;
    using LUSequenceReader<ElemType>::m_cachingWriter;
    using LUSequenceReader<ElemType>::m_featuresName;
    using LUSequenceReader<ElemType>::m_labelsName;
    using LUSequenceReader<ElemType>::labelInfoMin;
    using LUSequenceReader<ElemType>::labelInfoMax;
    using LUSequenceReader<ElemType>::m_featureDim;
    using LUSequenceReader<ElemType>::class_size;
    using LUSequenceReader<ElemType>::m_labelInfo;
    // using LUSequenceReader<ElemType>::m_labelInfoIn;
    using LUSequenceReader<ElemType>::m_mbStartSample;
    using LUSequenceReader<ElemType>::m_epoch;
    using LUSequenceReader<ElemType>::m_totalSamples;
    using LUSequenceReader<ElemType>::m_epochStartSample;
    using LUSequenceReader<ElemType>::m_seqIndex;
    using LUSequenceReader<ElemType>::m_endReached;
    using LUSequenceReader<ElemType>::m_readNextSampleLine;
    using LUSequenceReader<ElemType>::m_readNextSample;
    using LUSequenceReader<ElemType>::m_traceLevel;
    using LUSequenceReader<ElemType>::m_wordContext;
    using LUSequenceReader<ElemType>::m_featureCount;
    using typename LUSequenceReader<ElemType>::LabelInfo;
    using LUSequenceReader<ElemType>::labelInfoIn;
    using LUSequenceReader<ElemType>::labelInfoOut;
    // using LUSequenceReader<ElemType>::arrayLabels;
    using LUSequenceReader<ElemType>::m_readerConfig;
    using LUSequenceReader<ElemType>::m_featuresBuffer;
    using LUSequenceReader<ElemType>::m_labelsBuffer;
    using LUSequenceReader<ElemType>::m_labelsIdBuffer;
    using LUSequenceReader<ElemType>::m_mbSize;
    using LUSequenceReader<ElemType>::m_epochSize;
    using LUSequenceReader<ElemType>::m_featureData;
    using LUSequenceReader<ElemType>::m_sequence;
    using LUSequenceReader<ElemType>::m_labelData;
    using LUSequenceReader<ElemType>::m_labelIdData;
    using LUSequenceReader<ElemType>::m_idx2clsRead;
    using LUSequenceReader<ElemType>::m_clsinfoRead;
    using LUSequenceReader<ElemType>::m_featureWordContext;
    using LUSequenceReader<ElemType>::LoadLabelFile;
    using LUSequenceReader<ElemType>::ReleaseMemory;
    using LUSequenceReader<ElemType>::LMSetupEpoch;
    using LUSequenceReader<ElemType>::ChangeMaping;
    using LUSequenceReader<ElemType>::GetIdFromLabel;
    using LUSequenceReader<ElemType>::InitCache;
    using LUSequenceReader<ElemType>::ReadLabelInfo;
    using LUSequenceReader<ElemType>::mRandomize;
    using LUSequenceReader<ElemType>::m_seed;

private:
    // Sentence-tracking state; initial values are set in the constructor.
    // NOTE(review): exact semantics are defined by the implementation file.
    size_t mLastProcssedSentenceId;
    size_t mBlgSize;
    size_t mPosInSentence;
    std::vector<size_t> mToProcess;
    size_t mLastPosInSentence;
    size_t mNumRead;

    std::vector<std::vector<LabelType>> m_featureTemp;
    std::vector<LabelType> m_labelTemp;

    bool mSentenceEnd;
    bool mSentenceBegin;

public:
    std::vector<bool> mProcessed;
    // Hides the base-class m_parser (which has a different parser type).
    LUBatchLUSequenceParser<ElemType, LabelType> m_parser;

    /// Sets every scalar member to a defined starting value (listed in declaration
    /// order), including mPosInSentence, so nothing is read while indeterminate.
    BatchLUSequenceReader()
        : mLastProcssedSentenceId(0),
          mBlgSize(1),
          mPosInSentence(0),
          mLastPosInSentence(0),
          mNumRead(0),
          mSentenceEnd(false),
          mSentenceBegin(true)
    {
    }

    /// Nothing to release explicitly: m_labelTemp and m_featureTemp are std::vectors
    /// and free their storage in their own destructors.
    ~BatchLUSequenceReader()
    {
    }

    void Init(const ConfigParameters& readerConfig);
    void Reset();

    /// return length of sentences size
    size_t FindNextSentences(size_t numSentences);
    bool DataEnd(EndDataType endDataType);
    void SetSentenceEnd(int wrd, int pos, int actualMbSize);
    void SetSentenceBegin(int wrd, int pos, int actualMbSize);

    // Convenience overloads: narrow the size_t arguments to int and forward to the
    // (int, int, int) versions declared above.
    void SetSentenceBegin(int wrd, size_t pos, size_t actualMbSize) { SetSentenceBegin(wrd, static_cast<int>(pos), static_cast<int>(actualMbSize)); } // TODO: clean this up
    void SetSentenceEnd(int wrd, size_t pos, size_t actualMbSize) { SetSentenceEnd(wrd, static_cast<int>(pos), static_cast<int>(actualMbSize)); }
    void SetSentenceBegin(size_t wrd, size_t pos, size_t actualMbSize) { SetSentenceBegin(static_cast<int>(wrd), static_cast<int>(pos), static_cast<int>(actualMbSize)); }
    void SetSentenceEnd(size_t wrd, size_t pos, size_t actualMbSize) { SetSentenceEnd(static_cast<int>(wrd), static_cast<int>(pos), static_cast<int>(actualMbSize)); }

    // Note: the second parameter is named like the inherited member m_mbStartSample
    // and hides it inside this declaration.
    void GetLabelOutput(std::map<std::wstring, Matrix<ElemType>*>& matrices,
                        size_t m_mbStartSample, size_t actualmbsize);

    void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples=requestDataSize);
    bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices);

    bool EnsureDataAvailable(size_t mbStartSample);
    size_t NumberSlicesInEachRecurrentIter();
    void SetNbrSlicesEachRecurrentIter(const size_t mz);

    void SetSentenceEndInBatch(std::vector<size_t> &sentenceEnd);

public:
    void LoadWordMapping(const ConfigParameters& readerConfig);
};
}}}