swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Raw File
Tip revision: e8670cecf0e544172fdd51279896989482a86d5c authored by Mark Hillebrand on 03 November 2016, 11:31:28 UTC
bindings/python: edit dev scripts
Tip revision: e8670ce
TextParser.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "DataDeserializerBase.h"
#include "Descriptors.h"
#include "TextConfigHelper.h"
#include "Indexer.h"
#include "CorpusDescriptor.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// Forward declaration of the test runner template. TextParser<ElemType>
// declares the matching specialization as a friend, which gives it access to
// the parser's private members.
template <typename ElemType>
class CNTKTextFormatReaderTestRunner;

// TODO: more details when tracing warnings
// (e.g., buffer content around the char that triggered the warning)
template <class ElemType>
class TextParser : public DataDeserializerBase {
public:
    explicit TextParser(const TextConfigHelper& helper);

    TextParser(CorpusDescriptorPtr corpus, const TextConfigHelper& helper);

    ~TextParser();

    // Retrieves a chunk of data.
    ChunkPtr GetChunk(ChunkIdType chunkId) override;

    // Get information about chunks.
    ChunkDescriptions GetChunkDescriptions() override;

    // Get information about particular chunk.
    void GetSequencesForChunk(ChunkIdType chunkId, std::vector<SequenceDescription>& result) override;

    bool GetSequenceDescriptionByKey(const KeyType&, SequenceDescription&) override;

private:
    // Builds an index of the input data.
    void Initialize();

    struct DenseInputStreamBuffer : DenseSequenceData
    {
        // capacity = expected number of samples * sample size
        DenseInputStreamBuffer(size_t capacity)
        {
            m_buffer.reserve(capacity);
        }

        const void* GetDataBuffer() override
        {
            return m_buffer.data();
        }

        std::vector<ElemType> m_buffer;
    };

    // In case of sparse input, we also need a vector of
    // indices (one index for each input value) and a vector
    // of NNZ counts (one for each sample).
    struct SparseInputStreamBuffer : SparseSequenceData
    {
        SparseInputStreamBuffer()
        {
            m_totalNnzCount = 0;
        }

        const void* GetDataBuffer() override
        {
            return m_buffer.data();
        }

        std::vector<IndexType> m_indicesBuffer;
        std::vector<ElemType> m_buffer;
    };

    // A sequence buffer is a vector that contains sequence data for each input stream.
    typedef std::vector<SequenceDataPtr> SequenceBuffer;

    // A chunk of input data in the text format.
    class TextDataChunk;

    typedef std::shared_ptr<TextDataChunk> TextChunkPtr;

    enum TraceLevel
    {
        Error = 0,
        Warning = 1,
        Info = 2
    };

    const std::wstring m_filename;
    FILE* m_file;

    // An internal structure to assist with copying from input stream buffers into
    // into sequence data in a proper format.
    struct StreamInfo;
    std::vector<StreamInfo> m_streamInfos;

    size_t m_maxAliasLength;
    std::map<std::string, size_t> m_aliasToIdMap;

    std::unique_ptr<Indexer> m_indexer;

    int64_t m_fileOffsetStart;
    int64_t m_fileOffsetEnd;

    // TODO: not DRY (same in the Indexer), needs refactoring
    unique_ptr<char[]> m_buffer;
    const char* m_bufferStart;
    const char* m_bufferEnd;
    const char* m_pos; // buffer index

    unique_ptr<char[]> m_scratch; // local buffer for string parsing

    size_t m_chunkSizeBytes;
    unsigned int m_traceLevel;
    bool m_hadWarnings;
    unsigned int m_numAllowedErrors;
    bool m_skipSequenceIds;
    unsigned int m_numRetries; // specifies the number of times an unsuccessful
    // file operation should be repeated (default value is 5).

    // Corpus descriptor.
    CorpusDescriptorPtr m_corpus;

    // throws runtime exception when number of parsing errors is
    // greater than the specified threshold
    void IncrementNumberOfErrorsOrDie();

    // prints a messages that there were warnings which might
    // have been swallowed.
    void PrintWarningNotification();

    void SetFileOffset(int64_t position);

    void SkipToNextValue(size_t& bytesToRead);
    void SkipToNextInput(size_t& bytesToRead);

    bool TryRefillBuffer();

    int64_t GetFileOffset() const { return m_fileOffsetStart + (m_pos - m_bufferStart); }

    // Returns a string containing input file information (current offset, file name, etc.),
    // which can be included as a part of the trace/log message.
    std::wstring GetFileInfo();

    // Reads an alias/name and converts it to an internal stream id (= stream index).
    bool TryGetInputId(size_t& id, size_t& bytesToRead);

    bool TryReadRealNumber(ElemType& value, size_t& bytesToRead);

    bool TryReadUint64(size_t& value, size_t& bytesToRead);

    // Reads dense sample values into the provided vector.
    bool TryReadDenseSample(std::vector<ElemType>& values, size_t sampleSize, size_t& bytesToRead);

    // Reads sparse sample values and corresponding indices into the provided vectors.
    bool TryReadSparseSample(std::vector<ElemType>& values, std::vector<IndexType>& indices,
        size_t sampleSize, size_t& bytesToRead);

    // Reads one sample (an input identifier followed by a list of values)
    bool TryReadSample(SequenceBuffer& sequence, size_t& bytesToRead);

    // Reads one whole row (terminated by a row delimiter) of samples
    bool TryReadRow(SequenceBuffer& sequence, size_t& bytesToRead);

    // Returns true if there's still data available.
    bool inline CanRead() { return m_pos != m_bufferEnd || TryRefillBuffer(); }

    // Returns true if the trace level is greater or equal to 'Warning'
    bool inline ShouldWarn() { m_hadWarnings = true; return m_traceLevel >= Warning; }

    // Given a descriptor, retrieves the data for the corresponding sequence from the file.
    SequenceBuffer LoadSequence(const SequenceDescriptor& descriptor);

    // Given a descriptor, retrieves the data for the corresponding chunk from the file.
    void LoadChunk(TextChunkPtr& chunk, const ChunkDescriptor& descriptor);

    TextParser(CorpusDescriptorPtr corpus, const std::wstring& filename, const vector<StreamDescriptor>& streams);

    // Fills some metadata members to be conformant to the exposed SequenceData interface.
    void FillSequenceMetadata(SequenceBuffer& sequenceBuffer, size_t sequenceId);

    void SetTraceLevel(unsigned int traceLevel);

    void SetMaxAllowedErrors(unsigned int maxErrors);

    void SetSkipSequenceIds(bool skip);

    void SetChunkSize(size_t size);

    void SetNumRetries(unsigned int numRetries);

    friend class CNTKTextFormatReaderTestRunner<ElemType>;

    const std::string& GetSequenceKey(const SequenceDescriptor& s) const;

    DISABLE_COPY_AND_MOVE(TextParser);
};
}}}
back to top