// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #pragma once #include "DataDeserializerBase.h" #include "Descriptors.h" #include "TextConfigHelper.h" #include "Indexer.h" #include "CorpusDescriptor.h" namespace Microsoft { namespace MSR { namespace CNTK { template class CNTKTextFormatReaderTestRunner; // TODO: more details when tracing warnings // (e.g., buffer content around the char that triggered the warning) template class TextParser : public DataDeserializerBase { public: explicit TextParser(const TextConfigHelper& helper); TextParser(CorpusDescriptorPtr corpus, const TextConfigHelper& helper); ~TextParser(); // Retrieves a chunk of data. ChunkPtr GetChunk(ChunkIdType chunkId) override; // Get information about chunks. ChunkDescriptions GetChunkDescriptions() override; // Get information about particular chunk. void GetSequencesForChunk(ChunkIdType chunkId, std::vector& result) override; bool GetSequenceDescriptionByKey(const KeyType&, SequenceDescription&) override; private: // Builds an index of the input data. void Initialize(); struct DenseInputStreamBuffer : DenseSequenceData { // capacity = expected number of samples * sample size DenseInputStreamBuffer(size_t capacity) { m_buffer.reserve(capacity); } const void* GetDataBuffer() override { return m_buffer.data(); } std::vector m_buffer; }; // In case of sparse input, we also need a vector of // indices (one index for each input value) and a vector // of NNZ counts (one for each sample). struct SparseInputStreamBuffer : SparseSequenceData { SparseInputStreamBuffer() { m_totalNnzCount = 0; } const void* GetDataBuffer() override { return m_buffer.data(); } std::vector m_indicesBuffer; std::vector m_buffer; }; // A sequence buffer is a vector that contains sequence data for each input stream. typedef std::vector SequenceBuffer; // A chunk of input data in the text format. class TextDataChunk; typedef std::shared_ptr TextChunkPtr; enum TraceLevel { Error = 0, Warning = 1, Info = 2 }; const std::wstring m_filename; FILE* m_file; // An internal structure to assist with copying from input stream buffers into // into sequence data in a proper format. struct StreamInfo; std::vector m_streamInfos; size_t m_maxAliasLength; std::map m_aliasToIdMap; std::unique_ptr m_indexer; int64_t m_fileOffsetStart; int64_t m_fileOffsetEnd; // TODO: not DRY (same in the Indexer), needs refactoring unique_ptr m_buffer; const char* m_bufferStart; const char* m_bufferEnd; const char* m_pos; // buffer index unique_ptr m_scratch; // local buffer for string parsing size_t m_chunkSizeBytes; unsigned int m_traceLevel; bool m_hadWarnings; unsigned int m_numAllowedErrors; bool m_skipSequenceIds; unsigned int m_numRetries; // specifies the number of times an unsuccessful // file operation should be repeated (default value is 5). // Corpus descriptor. CorpusDescriptorPtr m_corpus; // throws runtime exception when number of parsing errors is // greater than the specified threshold void IncrementNumberOfErrorsOrDie(); // prints a messages that there were warnings which might // have been swallowed. void PrintWarningNotification(); void SetFileOffset(int64_t position); void SkipToNextValue(size_t& bytesToRead); void SkipToNextInput(size_t& bytesToRead); bool TryRefillBuffer(); int64_t GetFileOffset() const { return m_fileOffsetStart + (m_pos - m_bufferStart); } // Returns a string containing input file information (current offset, file name, etc.), // which can be included as a part of the trace/log message. std::wstring GetFileInfo(); // Reads an alias/name and converts it to an internal stream id (= stream index). bool TryGetInputId(size_t& id, size_t& bytesToRead); bool TryReadRealNumber(ElemType& value, size_t& bytesToRead); bool TryReadUint64(size_t& value, size_t& bytesToRead); // Reads dense sample values into the provided vector. bool TryReadDenseSample(std::vector& values, size_t sampleSize, size_t& bytesToRead); // Reads sparse sample values and corresponding indices into the provided vectors. bool TryReadSparseSample(std::vector& values, std::vector& indices, size_t sampleSize, size_t& bytesToRead); // Reads one sample (an input identifier followed by a list of values) bool TryReadSample(SequenceBuffer& sequence, size_t& bytesToRead); // Reads one whole row (terminated by a row delimiter) of samples bool TryReadRow(SequenceBuffer& sequence, size_t& bytesToRead); // Returns true if there's still data available. bool inline CanRead() { return m_pos != m_bufferEnd || TryRefillBuffer(); } // Returns true if the trace level is greater or equal to 'Warning' bool inline ShouldWarn() { m_hadWarnings = true; return m_traceLevel >= Warning; } // Given a descriptor, retrieves the data for the corresponding sequence from the file. SequenceBuffer LoadSequence(const SequenceDescriptor& descriptor); // Given a descriptor, retrieves the data for the corresponding chunk from the file. void LoadChunk(TextChunkPtr& chunk, const ChunkDescriptor& descriptor); TextParser(CorpusDescriptorPtr corpus, const std::wstring& filename, const vector& streams); // Fills some metadata members to be conformant to the exposed SequenceData interface. void FillSequenceMetadata(SequenceBuffer& sequenceBuffer, size_t sequenceId); void SetTraceLevel(unsigned int traceLevel); void SetMaxAllowedErrors(unsigned int maxErrors); void SetSkipSequenceIds(bool skip); void SetChunkSize(size_t size); void SetNumRetries(unsigned int numRetries); friend class CNTKTextFormatReaderTestRunner; const std::string& GetSequenceKey(const SequenceDescriptor& s) const; DISABLE_COPY_AND_MOVE(TextParser); }; }}}