swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Raw File
Tip revision: a87d69a778403f3174045f2cfb590476e7646189 authored by Sergii Dymchenko on 02 August 2018, 17:41:05 UTC
Fix Hardmax/Softmax/LogSoftmax ONNX export.
Tip revision: a87d69a
ReaderUtil.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "Reader.h"
#include "SequenceEnumerator.h"
#include "Config.h"
#include <boost/algorithm/string.hpp>

namespace CNTK {

size_t GetRandomizationWindowFromConfig(const Microsoft::MSR::CNTK::ConfigParameters& config);

inline size_t GetRandomSeed(const Microsoft::MSR::CNTK::ConfigParameters& config)
{
    return config(L"randomizationSeed", size_t(0));
}

static std::vector<unsigned char> FillIndexTable()
{
    std::vector<unsigned char> indexTable;
    indexTable.resize(std::numeric_limits<unsigned char>().max());
    char value = 0;
    for (unsigned char i = 'A'; i <= 'Z'; i++)
        indexTable[i] = value++;
    assert(value == 26);

    for (unsigned char i = 'a'; i <= 'z'; i++)
        indexTable[i] = value++;
    assert(value == 52);

    for (unsigned char i = '0'; i <= '9'; i++)
        indexTable[i] = value++;
    assert(value == 62);
    indexTable['+'] = value++;
    indexTable['/'] = value++;
    assert(value == 64);
    return indexTable;
}

const static char* base64IndexTable = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
static std::vector<unsigned char> base64DecodeTable = FillIndexTable();

inline bool IsBase64Char(char c)
{
    return isalnum(c) || c == '/' || c == '+' || c == '=';
}

inline bool DecodeBase64(const char* begin, const char* end, std::vector<char>& result)
{
    assert(std::find_if(begin, end, [](char c) { return !IsBase64Char(c); }) == end);

    size_t length = end - begin;
    if (length % 4 != 0)
        return false;

    result.resize((length * 3) / 4); // Upper bound on the max number of decoded symbols.
    size_t currentDecodedIndex = 0;
    while (begin < end)
    {
        result[currentDecodedIndex++] = base64DecodeTable[*begin] << 2 | base64DecodeTable[*(begin + 1)] >> 4;
        result[currentDecodedIndex++] = base64DecodeTable[*(begin + 1)] << 4 | base64DecodeTable[*(begin + 2)] >> 2;
        result[currentDecodedIndex++] = base64DecodeTable[*(begin + 2)] << 6 | base64DecodeTable[*(begin + 3)];
        begin += 4;
    }

    // In Base 64 each 3 characters are encoded with 4 bytes. Plus there could be padding (last two bytes)
    size_t resultingLength = (length * 3) / 4 - (*(end - 2) == '=' ? 2 : (*(end - 1) == '=' ? 1 : 0));
    result.resize(resultingLength);
    return true;
}

// Class to clean/keep track of invalid sequences.
class SequenceCleaner
{
public:
    SequenceCleaner(size_t maxNumberOfInvalidSequences)
        : m_maxNumberOfInvalidSequences(maxNumberOfInvalidSequences),
        m_numberOfCleanedSequences(0)
    {}

    // Removes invalid sequences in place.
    void Clean(Sequences& sequences)
    {
        if (sequences.m_data.empty())
            return;

        size_t clean = 0;
        for (size_t i = 0; i < sequences.m_data.front().size(); ++i)
        {
            bool invalid = false;
            for (const auto& s : sequences.m_data)
            {
                if (!s[i]->m_isValid)
                {
                    invalid = true;
                    break;
                }
            }

            if (invalid)
            {
                m_numberOfCleanedSequences++;
                continue;
            }

            // For all streams reassign the sequence.
            for (auto& s : sequences.m_data)
                s[clean] = s[i];
            clean++;
        }

        if (clean == 0)
        {
            sequences.m_data.resize(0);
            return;
        }

        // For all streams set new size.
        for (auto& s : sequences.m_data)
            s.resize(clean);

        if (m_numberOfCleanedSequences > m_maxNumberOfInvalidSequences)
            RuntimeError("Number of invalid sequences '%d' in the input exceeded the specified maximum number '%d'",
                (int)m_numberOfCleanedSequences,
                (int)m_maxNumberOfInvalidSequences);
    }

private:
    // Number of sequences cleaned.
    size_t m_numberOfCleanedSequences;

    // Max number of allowed invalid sequences.
    size_t m_maxNumberOfInvalidSequences;
};

// Boost split is too slow, this one gives almost 200% better results for initial parsing of big text files.
// Splits the incoming sequence given by begin and end according to the delimiters without string copies.
template<class T>
inline void Split(T* begin, T* end, const std::vector<bool>& delimiters, std::vector<boost::iterator_range<T*>>& result)
{
    assert(delimiters.size() > std::numeric_limits<unsigned char>::max());

    auto start = begin;
    while (begin != end)
    {
        if (delimiters[static_cast<unsigned char>(*begin)])
        {
            result.push_back(boost::make_iterator_range(start, begin));
            start = begin + 1;
        }
        ++begin;
    }

    // Adding last.
    result.push_back(boost::make_iterator_range(start, end));
}

// Function that is used to build delimiter hashes.
inline std::vector<bool> DelimiterHash(const std::vector<char>& values)
{
    std::vector<bool> delim_equal(256, false);
    for (const auto& c : values)
        delim_equal[c] = true;
    return delim_equal;
}

// Reads from start to end till one of the delimiters is reached.
// Can be used on character buffers that are not proper C strings.
// Returns a new start pointer.
template<class T>
inline T* ReadTillDelimiter(T* begin, T* end, const std::vector<bool>& delimiters, boost::iterator_range<T*>& result)
{
    assert(delimiters.size() > std::numeric_limits<unsigned char>::max());

    auto start = begin;
    while (begin != end)
    {
        if (delimiters[static_cast<unsigned char>(*begin)])
        {
            result = boost::make_iterator_range(start, begin);
            return begin + 1;
        }

        ++begin;
    }

    result = boost::make_iterator_range(begin, end);
    return end;
}

}
back to top