swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Raw File
Tip revision: e8670cecf0e544172fdd51279896989482a86d5c authored by Mark Hillebrand on 03 November 2016, 11:31:28 UTC
bindings/python: edit dev scripts
Tip revision: e8670ce
TextParser.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#define __STDC_FORMAT_MACROS
#include <inttypes.h>
#include <cfloat>
#include "Indexer.h"
#include "TextParser.h"
#include "TextReaderConstants.h"

// Character-class helpers used by the floating point scanner below.
// The macro argument is parenthesized so that an expression argument (for
// example a conditional expression) is compared as a whole instead of being
// re-associated with the surrounding '==' / '||' operators.
// Note: the argument is still evaluated twice, so it must have no side effects.
#define isSign(c) (((c) == '-' || (c) == '+'))
#define isE(c) (((c) == 'e' || (c) == 'E'))

namespace Microsoft { namespace MSR { namespace CNTK {

// Returns true iff c is one of the characters '0'..'9'.
// Plain range comparison: no locale lookup is involved.
inline bool IsDigit(char c)
{
    const bool belowZero = c < '0';
    const bool aboveNine = c > '9';
    return !(belowZero || aboveNine);
}

// States of the hand-written scanner in TextParser::TryReadRealNumber.
// The numeric values are spelled out; they match the implicit numbering.
enum State
{
    Init = 0,           // nothing consumed yet
    Sign = 1,           // consumed a leading '+' or '-'
    IntegralPart = 2,   // inside the digits before the period
    Period = 3,         // consumed the '.'
    FractionalPart = 4, // inside the digits after the period
    TheLetterE = 5,     // consumed 'e' or 'E'
    ExponentSign = 6,   // consumed the sign of the exponent
    Exponent = 7        // inside the exponent digits
};

// A chunk of parsed text data: holds the sequence buffers loaded for one chunk
// descriptor and hands them out by sequence id.
// NOTE(review): the class derives from enable_shared_from_this<Chunk> (the base
// type, not TextDataChunk itself) -- confirm this is intended.
template <class ElemType>
class TextParser<ElemType>::TextDataChunk : public Chunk, public std::enable_shared_from_this<Chunk>
{
public:
    // Stores the descriptor's chunk id and the (non-owned) parser pointer.
    explicit TextDataChunk(const ChunkDescriptor& descriptor, TextParser* parser);

    // Gets sequences by id: appends the per-stream data of the given sequence to result.
    void GetSequence(size_t sequenceId, std::vector<SequenceDataPtr>& result) override;

    // A map from sequence ids to the sequence data (indexed directly by sequence id).
    std::vector<SequenceBuffer> m_sequenceMap;

    // chunk id (copied from the descriptor)
    ChunkIdType m_id;

    // a non-owned pointer to the parser that created this chunk
    TextParser* m_parser;
};


// Per-stream parsing information, filled in by the TextParser constructor
// from the corresponding StreamDescriptor.
template <class ElemType>
struct TextParser<ElemType>::StreamInfo
{
    // storage type of the stream (compared against StorageType::dense when parsing)
    StorageType m_type;
    // sample dimension of the stream: the number of values in a dense sample and
    // the exclusive upper bound for sparse indices
    size_t m_sampleDimension;
};

// Convenience constructor: creates a fresh default-constructed CorpusDescriptor
// and delegates to the (corpus, helper) constructor.
template <class ElemType>
TextParser<ElemType>::TextParser(const TextConfigHelper& helper) : TextParser(std::make_shared<CorpusDescriptor>(), helper)
{}

// Constructs a parser from a config helper: delegates to the
// (corpus, filename, streams) constructor, copies the remaining settings
// from the helper and then initializes the parser.
template <class ElemType>
TextParser<ElemType>::TextParser(CorpusDescriptorPtr corpus, const TextConfigHelper& helper) :
TextParser(corpus, helper.GetFilePath(), helper.GetStreams())
{
    SetTraceLevel(helper.GetTraceLevel());
    SetMaxAllowedErrors(helper.GetMaxAllowedErrors());
    SetChunkSize(helper.GetChunkSize());
    SetSkipSequenceIds(helper.ShouldSkipSequenceIds());

    // Must run after the setters above: Initialize() passes m_skipSequenceIds
    // and m_chunkSizeBytes on to the Indexer it creates.
    Initialize();
}


// Primary constructor: records the file name and the stream layout and allocates
// the read buffer and the alias scratch buffer. The input file is not opened here
// (m_file starts out as nullptr); Initialize() takes care of that.
template <class ElemType>
TextParser<ElemType>::TextParser(CorpusDescriptorPtr corpus, const std::wstring& filename, const vector<StreamDescriptor>& streams) :
    m_filename(filename),
    m_file(nullptr),
    m_streamInfos(streams.size()),
    m_indexer(nullptr),
    m_fileOffsetStart(0),
    m_fileOffsetEnd(0),
    m_buffer(new char[BUFFER_SIZE + 1]),
    m_bufferStart(nullptr),
    m_bufferEnd(nullptr),
    m_pos(nullptr),
    m_chunkSizeBytes(0),
    m_traceLevel(TraceLevel::Error),
    m_hadWarnings(false),
    m_numAllowedErrors(0),
    m_skipSequenceIds(false),
    m_numRetries(5),
    m_corpus(corpus)
{
    assert(streams.size() > 0);

    m_maxAliasLength = 0;

    size_t streamIndex = 0;
    for (const StreamDescriptor& descriptor : streams)
    {
        // Track the longest alias: it determines the scratch buffer size below.
        const string& alias = descriptor.m_alias;
        if (alias.length() > m_maxAliasLength)
        {
            m_maxAliasLength = alias.length();
        }

        // The position of a stream in the input vector doubles as its id.
        m_aliasToIdMap[alias] = streamIndex;

        StreamInfo& info = m_streamInfos[streamIndex];
        info.m_type = descriptor.m_storageType;
        info.m_sampleDimension = descriptor.m_sampleDimension;

        auto description = std::make_shared<StreamDescription>(descriptor);
        description->m_sampleLayout = std::make_shared<TensorShape>(descriptor.m_sampleDimension);
        m_streams.push_back(description);

        ++streamIndex;
    }

    assert(m_maxAliasLength > 0);

    // One extra character on top of the longest alias.
    m_scratch = unique_ptr<char[]>(new char[m_maxAliasLength + 1]);
}

// Closes the input file if it was ever opened.
template <class ElemType>
TextParser<ElemType>::~TextParser()
{
    if (m_file != nullptr)
    {
        fclose(m_file);
    }
}

// Prints a one-line hint to stderr when warnings were recorded (m_hadWarnings)
// but the current trace level is below Warning.
template <class ElemType>
void TextParser<ElemType>::PrintWarningNotification()
{
    if (!m_hadWarnings || m_traceLevel >= Warning)
    {
        // Either nothing to report or the trace level is already high enough.
        return;
    }

    fprintf(stderr,
        "A number of warnings were generated while reading input data, "
        "to see them please set 'traceLevel' to a value greater or equal to %d.\n", Warning);
}

// Opens the input file, rejects UTF-16 input, builds the index and records the
// resulting file position as the (empty) buffered window. Does nothing if the
// indexer already exists. The open/index step is run through attempt() with
// m_numRetries.
// Raises (via RuntimeError) if the file carries a UTF-16 BOM or if the current
// file position cannot be retrieved.
template <class ElemType>
void TextParser<ElemType>::Initialize()
{
    if (m_indexer != nullptr)
    {
        return;
    }

    attempt(m_numRetries, [this]()
    {
        if (m_file != nullptr && ferror(m_file) != 0)
        {
            // The stream is in an error state: close it and forget the handle
            // *before* reopening. Should the reopen below fail, m_file must not
            // keep pointing at a closed stream (a later retry or the destructor
            // would otherwise touch/close it a second time).
            fclose(m_file);
            m_file = nullptr;
        }

        if (m_file == nullptr)
        {
            m_file = fopenOrDie(m_filename, L"rbS");
        }

        if (funicode(m_file))
        {
            // Retrying won't help here, the file is UTF-16 encoded.
            m_numRetries = 0;
            RuntimeError("Found a UTF-16 BOM at the beginning of the input file (%ls). "
                "UTF-16 encoding is currently not supported.", m_filename.c_str());
        }

        m_indexer = make_unique<Indexer>(m_file, m_skipSequenceIds, m_chunkSizeBytes);

        m_indexer->Build(m_corpus);
    });

    assert(m_indexer != nullptr);

    int64_t position = _ftelli64(m_file);
    if (position == -1L)
    {
        RuntimeError("Error retrieving current position in the input file (%ls).", m_filename.c_str());
    }

    // Nothing is buffered yet: start and end of the buffered window coincide.
    m_fileOffsetStart = position;
    m_fileOffsetEnd = position;
}

// Returns one ChunkDescription (id, number of samples, number of sequences)
// for every chunk in the index, in index order.
template <class ElemType>
ChunkDescriptions TextParser<ElemType>::GetChunkDescriptions()
{
    assert(m_indexer != nullptr);

    const auto& index = m_indexer->GetIndex();
    const auto& chunks = index.m_chunks;

    ChunkDescriptions descriptions;
    descriptions.reserve(chunks.size());

    for (auto it = chunks.begin(); it != chunks.end(); ++it)
    {
        descriptions.push_back(shared_ptr<ChunkDescription>(
            new ChunkDescription { it->m_id, it->m_numberOfSamples, it->m_numberOfSequences }));
    }

    return descriptions;
}

// Appends a SequenceDescription (id, number of samples, chunk id, key) to result
// for every sequence of the chunk at position chunkId in the index.
template <class ElemType>
void TextParser<ElemType>::GetSequencesForChunk(ChunkIdType chunkId, std::vector<SequenceDescription>& result)
{
    const auto& index = m_indexer->GetIndex();
    const auto& sequences = index.m_chunks[chunkId].m_sequences;

    result.reserve(sequences.size());

    for (auto it = sequences.begin(); it != sequences.end(); ++it)
    {
        result.push_back({ it->m_id, it->m_numberOfSamples, it->m_chunkId, it->m_key });
    }
}

// Remembers the id of the described chunk and the parser that owns the data source.
// The parser pointer is not owned by the chunk.
template <class ElemType>
TextParser<ElemType>::TextDataChunk::TextDataChunk(const ChunkDescriptor& descriptor, TextParser* parser) :
    m_id(descriptor.m_id),
    m_parser(parser)
{
}

// Appends the per-stream data of the sequence stored at position sequenceId
// to result (existing elements of result are kept).
template <class ElemType>
void TextParser<ElemType>::TextDataChunk::GetSequence(size_t sequenceId, std::vector<SequenceDataPtr>& result)
{
    assert(sequenceId < m_sequenceMap.size());

    const size_t streamCount = m_parser->m_streamInfos.size();
    result.reserve(streamCount);

    for (const auto& streamData : m_sequenceMap[sequenceId])
    {
        result.push_back(streamData);
    }
}

// Creates a TextDataChunk for the chunk at position chunkId in the index and
// loads all of its sequences from the input file. Loading is run through
// attempt() with m_numRetries; before each attempt a stream that is in an
// error state (or missing) is (re)opened.
template <class ElemType>
ChunkPtr TextParser<ElemType>::GetChunk(ChunkIdType chunkId)
{
    const auto& chunkDescriptor = m_indexer->GetIndex().m_chunks[chunkId];
    auto textChunk = make_shared<TextDataChunk>(chunkDescriptor, this);

    attempt(m_numRetries, [this, &textChunk, &chunkDescriptor]()
    {
        if (m_file != nullptr && ferror(m_file) != 0)
        {
            // Forget the handle before reopening: if the reopen below fails,
            // m_file must not keep pointing at an already closed stream
            // (the next attempt or the destructor would use/close it again).
            fclose(m_file);
            m_file = nullptr;
        }

        if (m_file == nullptr)
        {
            m_file = fopenOrDie(m_filename, L"rbS");
            // NOTE(review): after a reopen the buffered window
            // (m_fileOffsetStart/m_fileOffsetEnd) still describes the old
            // stream position -- confirm LoadSequence always re-seeks here.
        }

        LoadChunk(textChunk, chunkDescriptor);
    });

    return textChunk;
}

// Loads every sequence listed in the descriptor and stores it in the chunk's
// sequence map at the position given by the sequence id.
// NOTE(review): this presumes sequence ids are smaller than the number of
// sequences in the chunk -- confirm against the indexer.
template <class ElemType>
void TextParser<ElemType>::LoadChunk(TextChunkPtr& chunk, const ChunkDescriptor& descriptor)
{
    const auto& sequences = descriptor.m_sequences;
    auto& sequenceMap = chunk->m_sequenceMap;

    sequenceMap.resize(sequences.size());

    for (auto it = sequences.begin(); it != sequences.end(); ++it)
    {
        sequenceMap[it->m_id] = LoadSequence(*it);
    }
}

// Consumes one unit of the remaining error budget (m_numAllowedErrors).
// If the budget is already used up, prints the warning hint and reports
// a RuntimeError instead.
template <class ElemType>
void TextParser<ElemType>::IncrementNumberOfErrorsOrDie()
{
    const bool budgetExhausted = (m_numAllowedErrors == 0);

    if (budgetExhausted)
    {
        PrintWarningNotification();
        RuntimeError("Reached the maximum number of allowed errors"
            " while reading the input file (%ls).",
            m_filename.c_str());
    }

    --m_numAllowedErrors;
}

// Reads the next block (up to BUFFER_SIZE bytes) from the input file into
// m_buffer and makes it the current buffered window: the file offsets are
// advanced and m_bufferStart/m_pos/m_bufferEnd are reset to the new data.
// Returns false (leaving the window untouched) if no bytes could be read.
// Raises (via RuntimeError) if the stream reports a read error.
template <class ElemType>
bool TextParser<ElemType>::TryRefillBuffer()
{
    size_t bytesRead = fread(m_buffer.get(), 1, BUFFER_SIZE, m_file);

    // fread() never returns (size_t)-1: a failure shows up as a short (possibly
    // zero) count together with the stream's error indicator. Without this check
    // a read error would be indistinguishable from the end of the file.
    if (ferror(m_file) != 0)
    {
        PrintWarningNotification();
        RuntimeError("Could not read from the input file (%ls).", m_filename.c_str());
    }

    if (!bytesRead)
    {
        return false;
    }

    // The new window starts where the previous one ended.
    m_fileOffsetStart = m_fileOffsetEnd;
    m_fileOffsetEnd += bytesRead;
    m_bufferStart = m_buffer.get();
    m_pos = m_bufferStart;
    m_bufferEnd = m_bufferStart + bytesRead;
    return true;
}

// Seeks the input file to the given absolute offset, collapses the buffered
// window onto that offset and then tries to refill the buffer from there.
// Raises (via RuntimeError) if the seek fails.
template <class ElemType>
void TextParser<ElemType>::SetFileOffset(int64_t offset)
{
    if (_fseeki64(m_file, offset, SEEK_SET) != 0)
    {
        PrintWarningNotification();
        RuntimeError("Error seeking to position %" PRId64 " in the input file (%ls).",
            offset, m_filename.c_str());
    }

    // Empty window at the new position; the refill below extends it.
    m_fileOffsetEnd = offset;
    m_fileOffsetStart = offset;

    TryRefillBuffer();
}

// Loads and parses the sequence described by sequenceDsc and returns one buffer
// per configured stream.
// Steps: position the read cursor at the sequence's file offset (seeking only if
// the offset lies outside the currently buffered window), create the per-stream
// buffers, read up to m_numberOfSamples rows, validate the per-stream sample
// counts and finally fill in the sequence metadata.
// Failures are charged to the error budget through IncrementNumberOfErrorsOrDie();
// a sequence with an empty input is reported through RuntimeError.
template <class ElemType>
typename TextParser<ElemType>::SequenceBuffer TextParser<ElemType>::LoadSequence(const SequenceDescriptor& sequenceDsc)
{
    auto fileOffset = sequenceDsc.m_fileOffsetBytes;

    // Only seek when the requested offset is not covered by the buffered window.
    if (fileOffset < m_fileOffsetStart || fileOffset > m_fileOffsetEnd)
    {
        SetFileOffset(fileOffset);
    }

    // Translate the absolute file offset into a position inside the buffer.
    size_t bufferOffset = fileOffset - m_fileOffsetStart;
    m_pos = m_bufferStart + bufferOffset;
    // Byte budget for this sequence; the TryRead* helpers decrement it as they consume input.
    size_t bytesToRead = sequenceDsc.m_byteSize;

    SequenceBuffer sequence;

    // TODO: reuse loaded sequences instead of creating new ones!
    for (auto const & stream : m_streamInfos)
    {
        if (stream.m_type == StorageType::dense)
        {
            // The constructor argument is the total number of values expected for this stream.
            sequence.push_back(make_unique<DenseInputStreamBuffer>(
                stream.m_sampleDimension * sequenceDsc.m_numberOfSamples));
        }
        else
        {
            sequence.push_back(make_unique<SparseInputStreamBuffer>());
        }
    }

    size_t numRowsRead = 0, expectedRowCount = sequenceDsc.m_numberOfSamples;
    for (size_t i = 0; i < expectedRowCount; i++)
    {
        if ((TryReadRow(sequence, bytesToRead)))
        {
            ++numRowsRead;
        }
        else
        {
            // A row that could not be read counts against the error budget.
            IncrementNumberOfErrorsOrDie();
            if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Could not read a row (# %" PRIu64 ")"
                    " while loading sequence (id = %s) %ls.\n",
                    i + 1,
                    GetSequenceKey(sequenceDsc).c_str(),
                    GetFileInfo().c_str());
            }
        }

        // Stop early once the byte budget is used up but rows are still missing.
        if (!bytesToRead && numRowsRead < expectedRowCount)
        {
            if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Exhausted all input"
                    " expected for the current sequence (id = %s) %ls,"
                    " but only read %" PRIu64 " out of %" PRIu64 " expected rows.\n",
                    GetSequenceKey(sequenceDsc).c_str(),
                    GetFileInfo().c_str(), numRowsRead, expectedRowCount);
            }
            break;
        }
    }

    // Double check if there are empty input streams.
    // TODO this handling needs to be graceful, but currently CNTK complains when we return empty sequences.
    bool hasEmptyInputs = false, hasDuplicateInputs = false;
    uint32_t maxInputLength = 0;
    for (size_t i = 0; i < sequence.size(); ++i)
    {
        if (sequence[i]->m_numberOfSamples == 0)
        {
            // Printed unconditionally (not gated by ShouldWarn()).
            fprintf(stderr,
                "ERROR: Input ('%ls') is empty in sequence (id = %s) %ls.\n",
                m_streams[i]->m_name.c_str(), GetSequenceKey(sequenceDsc).c_str(), GetFileInfo().c_str());
            hasEmptyInputs = true;
        }

        // More samples than rows for a stream means it was repeated within a row.
        if (sequence[i]->m_numberOfSamples > expectedRowCount)
        {
            hasDuplicateInputs = true;
            if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Input ('%ls') contains more samples than expected"
                    " (%u vs. %" PRIu64 ") for sequence (id = %s) %ls.\n",
                    m_streams[i]->m_name.c_str(), sequence[i]->m_numberOfSamples, expectedRowCount,
                    GetSequenceKey(sequenceDsc).c_str(), GetFileInfo().c_str());
            }
        }
        maxInputLength = max(sequence[i]->m_numberOfSamples, maxInputLength);
    }

    if (hasEmptyInputs)
    {
        PrintWarningNotification();
        RuntimeError("Malformed input file. Bailing out.");
    }

    if (hasDuplicateInputs)
    {
        IncrementNumberOfErrorsOrDie();
    }
    else if (maxInputLength < expectedRowCount)
    {
        // Even the longest input is shorter than the number of rows announced by the descriptor.
        if (ShouldWarn())
        {
            fprintf(stderr,
                "WARNING: Maximum per-input number of samples for sequence (id = %s) %ls"
                " is less than expected (%u vs. %" PRIu64 ").\n",
                GetSequenceKey(sequenceDsc).c_str(),
                GetFileInfo().c_str(), maxInputLength, expectedRowCount);
        }
        IncrementNumberOfErrorsOrDie();
    }

    if (m_traceLevel >= Info)
    {
        fprintf(stderr,
            "INFO: Finished loading sequence (id = %s) %ls,"
            " successfully read %" PRIu64 " out of expected %" PRIu64 " rows.\n",
            GetSequenceKey(sequenceDsc).c_str(), GetFileInfo().c_str(), numRowsRead, expectedRowCount);
    }

    FillSequenceMetadata(sequence, sequenceDsc.m_id);
    return sequence;
}

// Finalizes freshly parsed per-stream buffers: dense buffers receive the sample
// layout of their stream, sparse buffers get their index pointer set to the
// index storage, and every buffer is tagged with the sequence id.
template<class ElemType>
void TextParser<ElemType>::FillSequenceMetadata(SequenceBuffer& sequenceData, size_t sequenceId)
{
    size_t streamIndex = 0;
    for (const StreamInfo& info : m_streamInfos)
    {
        SequenceDataBase* data = sequenceData[streamIndex].get();

        if (info.m_type != StorageType::dense)
        {
            auto sparseData = static_cast<SparseInputStreamBuffer*>(data);
            sparseData->m_indices = sparseData->m_indicesBuffer.data();
            // One nnz count is recorded per sparse sample.
            assert(data->m_numberOfSamples == sparseData->m_nnzCounts.size());
        }
        else
        {
            auto denseData = static_cast<DenseInputStreamBuffer*>(data);
            denseData->m_sampleLayout = m_streams[streamIndex]->m_sampleLayout;
        }

        data->m_id = sequenceId;
        ++streamIndex;
    }
}

// Reads one input row into the per-stream buffers of 'sequence'.
// Leading digits (the sequence id) are skipped first; then samples are read
// until the row delimiter is found. 'bytesToRead' is the remaining byte budget
// of the current sequence and is decremented for every consumed character.
// Returns true if the row delimiter was reached and at least one sample was read;
// returns false for an empty row or if the input was exhausted before the
// row delimiter was found.
template <class ElemType>
bool TextParser<ElemType>::TryReadRow(SequenceBuffer& sequence, size_t& bytesToRead)
{
    while (bytesToRead && CanRead() && IsDigit(*m_pos))
    {
        // skip sequence ids
        ++m_pos;
        --bytesToRead;
    }

    size_t numSampleRead = 0;
    while (bytesToRead && CanRead())
    {
        char c = *m_pos;

        if (c == ROW_DELIMITER)
        {
            // found the end of row, skip the delimiter, return.
            ++m_pos;
            --bytesToRead;

            if (numSampleRead == 0 && ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Empty input row %ls.\n", GetFileInfo().c_str());
            }
            else if (numSampleRead > m_streams.size() && ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Input row %ls contains more"
                    " samples than expected (%" PRIu64 " vs. %" PRIu64 ").\n",
                    GetFileInfo().c_str(), numSampleRead, m_streams.size());
            }

            // An empty row is reported as a failure to the caller.
            return numSampleRead > 0;
        }

        if (isColumnDelimiter(c))
        {
            // skip column (input) delimiters.
            ++m_pos;
            --bytesToRead;
            continue;
        }

        if (TryReadSample(sequence, bytesToRead))
        {
            numSampleRead++;
        }
        else
        {
            // skip over until the next sample/end of row
            SkipToNextInput(bytesToRead);
        }
    }

    // The loop ended without seeing a row delimiter.
    if (ShouldWarn())
    {
        fprintf(stderr,
            "WARNING: Exhausted all input expected for the current sequence"
            " while reading an input row %ls."
            " Possibly, a trailing newline is missing.\n", GetFileInfo().c_str());
    }
    return false;
}

// Reads one sample (a pipe-prefixed input identifier followed by a list of values)
// and appends it to the buffer of the stream the identifier maps to.
// 'bytesToRead' is the remaining byte budget of the current sequence and is
// decremented for every consumed character.
// Returns true if a complete sample was stored. Returns false if the name prefix
// is missing, the sample is an escaped comment (|#), the input name is unknown,
// or the values could not be read; in the latter case the partially appended
// values are removed again. All failures except the comment case are charged to
// the error budget through IncrementNumberOfErrorsOrDie().
template <class ElemType>
bool TextParser<ElemType>::TryReadSample(SequenceBuffer& sequence, size_t& bytesToRead)
{
    assert(m_pos < m_bufferEnd);

    // prefix check.
    if (*m_pos != NAME_PREFIX)
    {
        if (ShouldWarn())
        {
            fprintf(stderr,
                "WARNING: Unexpected character('%c') in place of a name prefix ('%c')"
                " in an input name %ls.\n",
                *m_pos, NAME_PREFIX, GetFileInfo().c_str());
        }
        IncrementNumberOfErrorsOrDie();
        return false;
    }

    // skip name prefix
    ++m_pos;
    --bytesToRead;

    if (bytesToRead && CanRead() && *m_pos == ESCAPE_SYMBOL)
    {
        // A vertical bar followed by the number sign (|#) is treated as an escape sequence,
        // everything that follows is ignored until the next vertical bar or the end of
        // row, whichever comes first.
        ++m_pos;
        --bytesToRead;
        return false;
    }

    size_t id;
    if (!TryGetInputId(id, bytesToRead))
    {
        IncrementNumberOfErrorsOrDie();
        return false;
    }

    const StreamInfo& stream = m_streamInfos[id];

    // The buffers are stored through their common base class. Downcast with
    // static_cast (as FillSequenceMetadata does): unlike reinterpret_cast it
    // performs a proper base-to-derived pointer conversion.
    SequenceDataBase* base = sequence[id].get();

    if (stream.m_type == StorageType::dense)
    {
        DenseInputStreamBuffer* data = static_cast<DenseInputStreamBuffer*>(base);
        vector<ElemType>& values = data->m_buffer;
        // Remember the current size to be able to roll back a partial sample.
        size_t size = values.size();
        assert(size % stream.m_sampleDimension == 0);
        if (!TryReadDenseSample(values, stream.m_sampleDimension, bytesToRead))
        {
            // expected a dense sample, but was not able to fully read it, ignore it.
            if (values.size() != size)
            {
                //clean up the buffer
                values.resize(size);
            }
            IncrementNumberOfErrorsOrDie();
            return false;
        }
        // everything went well, increment the number of samples.
        ++data->m_numberOfSamples;
    }
    else
    {
        SparseInputStreamBuffer* data = static_cast<SparseInputStreamBuffer*>(base);
        vector<ElemType>& values = data->m_buffer;
        vector<IndexType>& indices = data->m_indicesBuffer;
        assert(values.size() == indices.size());
        // Remember the current size to be able to roll back a partial sample.
        size_t size = values.size();
        if (!TryReadSparseSample(values, indices, stream.m_sampleDimension, bytesToRead))
        {
            // expected a sparse sample, but something went south, ignore it.
            if (values.size() != size)
            {
                //clean up the buffer
                values.resize(size);
            }
            if (indices.size() != size)
            {
                //clean up the buffer
                indices.resize(size);
            }

            IncrementNumberOfErrorsOrDie();
            return false;
        }
        assert(values.size() == indices.size());
        ++data->m_numberOfSamples;
        // Number of (index, value) pairs contributed by this sample.
        IndexType count = static_cast<IndexType>(values.size() - size);
        data->m_nnzCounts.push_back(count);
        data->m_totalNnzCount += count;
    }

    return true;
}

// Reads an input name (alias) starting at the current position and maps it to
// the id of the corresponding stream.
// Characters are copied into the scratch buffer until a value delimiter, a name
// prefix or a non-printable character is found; the collected name is then
// looked up in m_aliasToIdMap. 'bytesToRead' is decremented for every consumed
// character.
// Returns true and sets 'id' if the name is known. Returns false if the name is
// empty, unknown, longer than the longest configured alias, or if the input was
// exhausted before a terminating character was found.
template <class ElemType>
bool TextParser<ElemType>::TryGetInputId(size_t& id, size_t& bytesToRead)
{
    // Write cursor into the scratch buffer.
    char* scratchIndex = m_scratch.get();

    while (bytesToRead && CanRead())
    {
        char c = *m_pos;

        // stop as soon as there's a value delimiter, an input prefix
        // or a non-printable character (e.g., newline, carriage return).
        if (isValueDelimiter(c) || c == NAME_PREFIX || isNonPrintable(c))
        {
            // Number of name characters collected so far.
            size_t size = scratchIndex - m_scratch.get();
            if (size)
            {
                string name(m_scratch.get(), size);
                auto it = m_aliasToIdMap.find(name);
                if (it != m_aliasToIdMap.end())
                {
                    id = it->second;
                    return true;
                }

                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: Invalid input ('%s') %ls. "
                        "Input name '%s' was not specified in the reader config section.\n",
                        name.c_str(), GetFileInfo().c_str(), name.c_str());
                }
            }
            else if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Input name prefix ('%c') is followed by"
                    " an invalid character ('%c') %ls.\n",
                    NAME_PREFIX, c, GetFileInfo().c_str());
            }

            // The terminating character is left unconsumed in both failure cases.
            return false;
        }
        else if (scratchIndex < (m_scratch.get() + m_maxAliasLength))
        {
            // Still within the length of the longest known alias: keep collecting.
            *scratchIndex = c;
            ++scratchIndex;
        }
        else
        {
            // the current string length is already equal to the maximum expected length,
            // yet it's not followed by a delimiter.
            if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Did not find a valid input name %ls.\n",
                    GetFileInfo().c_str());
            }
            return false;
        }

        ++m_pos;
        --bytesToRead;
    }

    if (ShouldWarn())
    {
        fprintf(stderr,
            "WARNING: Exhausted all input expected for the current sequence"
            " while reading an input name %ls.\n", GetFileInfo().c_str());
    }
    return false;
}

// Reads the values of one dense sample and appends them to 'values'.
// Reading stops at the first non-printable character or name prefix. A sample
// with fewer than 'sampleSize' values is padded with zeros; a sample with more
// values is rejected. 'bytesToRead' is decremented for every consumed character.
// Returns true if a complete (possibly padded) sample was appended. On failure
// 'values' may contain a partial sample; the caller is expected to roll it back
// and to charge the failure to the error budget. For that reason this function
// does not touch the error budget itself (doing so on the exhausted-input path
// would count a single failure twice, and could abort before the warning below
// is printed).
template <class ElemType>
bool TextParser<ElemType>::TryReadDenseSample(vector<ElemType>& values, size_t sampleSize, size_t& bytesToRead)
{
    // Number of values read for the current sample.
    size_t counter = 0;
    ElemType value;

    while (bytesToRead && CanRead())
    {
        char c = *m_pos;

        if (isValueDelimiter(c))
        {
            // skip value delimiters
            ++m_pos;
            --bytesToRead;
            continue;
        }

        // return as soon as we hit a non-printable or a name prefix
        if (isNonPrintable(c) || c == NAME_PREFIX)
        {
            if (counter > sampleSize)
            {
                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: Dense sample (size = %" PRIu64 ") %ls"
                        " exceeds the expected size (%" PRIu64 ").\n",
                        counter, GetFileInfo().c_str(), sampleSize);
                }
                return false;
            }

            // For dense matrices, it should be possible to input only the left part
            // if the suffix is sparse. Fill up the rest with zeros.
            if (counter < sampleSize)
            {
                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: A dense sample %ls has a sparse suffix "
                        "(expected size = %" PRIu64 ", actual size = %" PRIu64 ").\n",
                        GetFileInfo().c_str(), sampleSize, counter);
                }
                for (; counter < sampleSize; ++counter)
                {
                    values.push_back(0.0f);
                }
            }

            return true;
        }

        if (!TryReadRealNumber(value, bytesToRead))
        {
            // bail out.
            return false;
        }

        values.push_back(value);
        ++counter;
    }

    // The input ended before the sample was terminated.
    if (ShouldWarn())
    {
        fprintf(stderr,
            "WARNING: Exhausted all input expected for the current sequence"
            " while reading a dense sample %ls.\n", GetFileInfo().c_str());
    }
    return false;
}

// Reads the (index, value) pairs of one sparse sample and appends them to
// 'indices' and 'values'. Each pair has the form <index><INDEX_DELIMITER><value>.
// Reading stops at the first non-printable character or name prefix; an empty
// sample is valid. 'bytesToRead' is decremented for every consumed character.
// Returns true if the sample was terminated properly. Returns false if an index
// or value cannot be parsed, an index is not smaller than 'sampleSize', the index
// delimiter is missing, or the input was exhausted; in that case 'values' and
// 'indices' may contain a partial sample.
template <class ElemType>
bool TextParser<ElemType>::TryReadSparseSample(std::vector<ElemType>& values, std::vector<IndexType>& indices,
    size_t sampleSize, size_t& bytesToRead)
{
    size_t index = 0;
    ElemType value;

    while (bytesToRead && CanRead())
    {
        char c = *m_pos;

        if (isValueDelimiter(c))
        {
            // skip value delimiters
            ++m_pos;
            --bytesToRead;
            continue;
        }

        // return as soon as we hit a non-printable or a name prefix
        if (isNonPrintable(c) || c == NAME_PREFIX)
        {
            // empty sparse samples are allowed ("|InputeName_1|InputName2...")
            return true;
        }

        // read next sparse index
        if (!TryReadUint64(index, bytesToRead))
        {
            // bail out.
            return false;
        }

        if (index >= sampleSize)
        {
            if (ShouldWarn())
            {
                // NOTE(review): 'sampleSize - 1' wraps around if sampleSize is 0 -- message only.
                fprintf(stderr,
                    "WARNING: Sparse index value (%" PRIu64 ") %ls"
                    " exceeds the maximum expected value (%" PRIu64 ").\n",
                    index, GetFileInfo().c_str(), sampleSize - 1);
            }
            // bail out.
            return false;
        }

        // an index must be followed by a delimiter
        // (TryReadUint64 returns true only after stopping at a readable non-digit character)
        c = *m_pos;
        if (c != INDEX_DELIMITER)
        {
            if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Unexpected character('%c')"
                    " in place of the index delimiter ('%c')"
                    " after a sparse value index (%" PRIu64 ") %ls.\n",
                    c, INDEX_DELIMITER, index, GetFileInfo().c_str());
            }
            return false;
        }

        // skip index delimiter
        ++m_pos;
        --bytesToRead;

        // read the corresponding value
        if (!TryReadRealNumber(value, bytesToRead))
        {
            // bail out.
            return false;
        }

        // Store the pair only once both parts were read successfully.
        values.push_back(value);
        indices.push_back(static_cast<IndexType>(index));
    }

    if (ShouldWarn())
    {
        fprintf(stderr,
            "WARNING: Exhausted all input expected for the current sequence"
            " while reading a sparse sample %ls.\n", GetFileInfo().c_str());
    }

    return false;
}

// Advances the read position until it points at a value delimiter, a name
// prefix or a row delimiter (or until the input is exhausted).
// 'bytesToRead' is decremented for every skipped character.
template <class ElemType>
void TextParser<ElemType>::SkipToNextValue(size_t& bytesToRead)
{
    for (; bytesToRead && CanRead(); ++m_pos, --bytesToRead)
    {
        const char current = *m_pos;
        const bool atBoundary = isValueDelimiter(current) || current == NAME_PREFIX || current == ROW_DELIMITER;
        if (atBoundary)
        {
            // Leave the boundary character for the caller to inspect.
            return;
        }
    }
}

// Advances the read position until it points at a name prefix or a row
// delimiter (or until the input is exhausted).
// 'bytesToRead' is decremented for every skipped character.
template <class ElemType>
void TextParser<ElemType>::SkipToNextInput(size_t& bytesToRead)
{
    for (; bytesToRead && CanRead(); ++m_pos, --bytesToRead)
    {
        const char current = *m_pos;
        if (current == NAME_PREFIX || current == ROW_DELIMITER)
        {
            // Leave the boundary character for the caller to inspect.
            return;
        }
    }
}

// Reads an unsigned decimal integer starting at the current position.
// Digits are consumed until the first non-digit character, which is left
// unconsumed. 'bytesToRead' is decremented for every consumed character.
// Returns true if at least one digit was read and the number was terminated by
// a non-digit character. Returns false if no digit was found, if the number
// does not fit into size_t, or if the input was exhausted before a terminating
// character was seen.
template <class ElemType>
bool TextParser<ElemType>::TryReadUint64(size_t& value, size_t& bytesToRead)
{
    const size_t maxValue = static_cast<size_t>(-1);

    value = 0;
    bool found = false;
    while (bytesToRead && CanRead())
    {
        char c = *m_pos;

        if (!IsDigit(c))
        {
            return found;
        }

        found = true;

        // Check for overflow *before* multiplying: comparing the wrapped result
        // with the previous value misses cases where the wrapped result is still
        // larger than the previous value.
        const size_t digit = static_cast<size_t>(c - '0');
        if (value > (maxValue - digit) / 10)
        {
            if (ShouldWarn())
            {
                fprintf(stderr,
                    "WARNING: Overflow while reading a uint64 value %ls.\n",
                    GetFileInfo().c_str());
            }

            return false;
        }
        value = value * 10 + digit;

        ++m_pos;
        --bytesToRead;
    }

    if (ShouldWarn())
    {
        fprintf(stderr,
            "WARNING: Exhausted all input expected for the current sequence"
            " while reading a uint64 value %ls.\n", GetFileInfo().c_str());
    }
    return false;
}



// Reads a floating point number of the form [sign]digits[.digits][(e|E)[sign]digits]
// using a small state machine (see the State enum) that inspects one character
// per loop iteration.
// TODO: better precision (at the moment we're at parity with UCIFast)?
// Assumes that bytesToRead is greater than the number of characters
// in the string representation of the floating point number
// (i.e., the string is followed by one of the delimiters)
// Post condition: m_pos points to the first character that
// cannot be parsed as part of a floating point number.
// Returns true if parsing was successful.
// Returns false on an unexpected character or if the input was exhausted before
// a terminating character was seen (even if a complete number was scanned).
template <class ElemType>
bool TextParser<ElemType>::TryReadRealNumber(ElemType& value, size_t& bytesToRead)
{
    State state = State::Init;
    // coefficient: the value in front of the exponent; number: the digit run
    // currently being accumulated; divider: power of ten that scales the fractional digits.
    double coefficient = .0, number = .0, divider = .0;
    // Sign of the coefficient; reused for the sign of the exponent once 'e'/'E' was consumed.
    bool negative = false;

    while (bytesToRead && CanRead())
    {
        char c = *m_pos;

        switch (state)
        {
        case State::Init:
            // the number must either start with a number or a sign
            if (IsDigit(c))
            {
                state = IntegralPart;
                number = (c - '0');
            }
            else if (isSign(c))
            {
                state = Sign;
                negative = (c == '-');
            }
            else
            {
                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: Unexpected character ('%c')"
                        " in a floating point value %ls.\n",
                        c, GetFileInfo().c_str());
                }
                return false;
            }
            break;
        case Sign:
            // the sign must be followed by a number
            if (IsDigit(c))
            {
                state = IntegralPart;
                number = (c - '0');
            }
            else
            {
                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: A sign symbol is followed by an invalid character('%c')"
                        " in a floating point value %ls.\n",
                        c, GetFileInfo().c_str());
                }
                return false;
            }
            break;
        case IntegralPart:
            if (IsDigit(c))
            {
                number = number * 10 + (c - '0');
            }
            else if (c == '.')
            {
                state = Period;
            }
            else if (isE(c))
            {
                // The sign is folded into the coefficient here, because
                // 'negative' is reused for the exponent sign from now on.
                state = TheLetterE;
                coefficient = (negative) ? -number : number;
                number = 0;
            }
            else
            {
                // A plain integer: the current character terminates the number.
                value = static_cast<ElemType>((negative) ? -number : number);
                return true;
            }
            break;
        case Period:
            if (IsDigit(c))
            {
                // Keep the integral part in 'coefficient' and start
                // accumulating the fractional digits in 'number'.
                state = FractionalPart;
                coefficient = number;
                number = (c - '0');
                divider = 10;
            }
            else
            {
                // A trailing period without fractional digits (e.g. "1.") is accepted.
                value = static_cast<ElemType>((negative) ? -number : number);
                return true;
            }
            break;
        case FractionalPart:
            if (IsDigit(c))
            {
                // TODO: ignore if number of precision digits > FLT_[MANT_]DIG/DBL_[MANT_]DIG
                // no state change
                number = number * 10 + (c - '0');
                divider *= 10;
            }
            else if (isE(c))
            {
                // Combine integral and fractional part and fold in the sign
                // before 'negative' gets reused for the exponent.
                state = TheLetterE;
                coefficient += (number / divider);
                if (negative)
                {
                    coefficient = -coefficient;
                }
            }
            else
            {
                coefficient += (number / divider);
                value = static_cast<ElemType>((negative) ? -coefficient : coefficient);
                return true;
            }
            break;
        case TheLetterE:
            // followed with optional minus or plus sign and nonempty sequence of decimal digits
            if (IsDigit(c))
            {
                state = Exponent;
                negative = false;
                number = (c - '0');
            }
            else if (isSign(c))
            {
                state = ExponentSign;
                negative = (c == '-');
            }
            else
            {
                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: An exponent symbol is followed by"
                        " an invalid character('%c')"
                        " in a floating point value %ls.\n", c, GetFileInfo().c_str());
                }
                return false;
            }
            break;
        case ExponentSign:
            // exponent sign must be followed by a number
            if (IsDigit(c))
            {
                state = Exponent;
                number = (c - '0');
            }
            else
            {
                if (ShouldWarn())
                {
                    fprintf(stderr,
                        "WARNING: An exponent sign symbol followed by"
                        " an unexpected character('%c')"
                        " in a floating point value %ls.\n", c, GetFileInfo().c_str());
                }
                return false;
            }
            break;
        case Exponent:
            if (IsDigit(c))
            {
                // no state change
                number = number * 10 + (c - '0');
            }
            else
            {
                // TODO: check the exponent value (see FLT_[MAX/MIN]_10_EXP).
                double exponent = (negative) ? -number : number;
                value = static_cast<ElemType>(coefficient * pow(10.0, exponent));
                return true;
            }
            break;
        default:
            LogicError("Reached an invalid state while reading a floating point value %ls.\n",
                GetFileInfo().c_str());
        }

        // The current character was accepted as part of the number: consume it.
        ++m_pos;
        --bytesToRead;
    }

    if (ShouldWarn())
    {
        fprintf(stderr,
            "WARNING: Exhausted all input expected for the current sequence"
            " while reading a floating point value %ls.\n", GetFileInfo().c_str());
    }

    return false;
}

// Sets the trace level that controls which diagnostics are printed.
template <class ElemType>
void TextParser<ElemType>::SetTraceLevel(unsigned int traceLevel)
{
    this->m_traceLevel = traceLevel;
}

// Sets the error budget: the number of errors tolerated before
// IncrementNumberOfErrorsOrDie() gives up.
template <class ElemType>
void TextParser<ElemType>::SetMaxAllowedErrors(unsigned int maxErrors)
{
    this->m_numAllowedErrors = maxErrors;
}

// Sets the flag that Initialize() hands to the Indexer as its
// skip-sequence-ids argument.
template <class ElemType>
void TextParser<ElemType>::SetSkipSequenceIds(bool skip)
{
    this->m_skipSequenceIds = skip;
}

// Sets the chunk size in bytes that Initialize() hands to the Indexer.
template <class ElemType>
void TextParser<ElemType>::SetChunkSize(size_t size)
{
    this->m_chunkSizeBytes = size;
}

// Sets the retry count passed to attempt() when opening/indexing the file
// and when loading chunks.
template <class ElemType>
void TextParser<ElemType>::SetNumRetries(unsigned int numRetries)
{
    this->m_numRetries = numRetries;
}

// Builds the location suffix used in diagnostics:
// "at offset <current file offset> in the input file (<file name>)".
template <class ElemType>
std::wstring TextParser<ElemType>::GetFileInfo()
{
    std::wstringstream description;
    description << L"at offset " << GetFileOffset();
    description << L" in the input file (" << m_filename << L")";
    return description.str();
}

// Looks up a sequence by its key. On success copies the sequence entry found at
// the recorded (chunk, sequence) position in the index into 'result' and returns
// true; returns false (leaving 'result' untouched) if the key is unknown.
template <class ElemType>
bool TextParser<ElemType>::GetSequenceDescriptionByKey(const KeyType& key, SequenceDescription& result)
{
    const auto& keyMap = m_indexer->GetIndex().m_keyToSequenceInChunk;
    const auto location = keyMap.find(key.m_sequence);

    if (location != keyMap.end())
    {
        // The mapped value is a (chunk position, sequence position) pair.
        const auto& chunkAndSequence = location->second;
        result = m_indexer->GetIndex().m_chunks[chunkAndSequence.first].m_sequences[chunkAndSequence.second];
        return true;
    }

    return false;
}

// Returns the string form of the sequence's key by looking up the key's
// m_sequence value in the corpus' string registry.
// The result is a reference to whatever the registry lookup returns;
// NOTE(review): presumably valid for as long as the corpus registry entry lives -- confirm.
template <class ElemType>
const string& TextParser<ElemType>::GetSequenceKey(const SequenceDescriptor& s) const
{
    return m_corpus->GetStringRegistry()[s.m_key.m_sequence];
}

// Explicit instantiations: the member definitions above live in this .cpp file,
// so the two supported element types are instantiated here.
template class TextParser<float>;
template class TextParser<double>;
}}}
back to top