//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#define _CRT_SECURE_NO_WARNINGS
#define _SCL_SECURE_NO_WARNINGS

#include <cmath>
#include "TruncatedBpttPacker.h"
#include "ReaderUtil.h"

namespace Microsoft { namespace MSR { namespace CNTK {

using namespace std;

// Represents a slot where we accumulate sequences from which the minibatch is created.
// The number of slots equals number of parallel sequences we want to pack.
class Slot
{
public:
    Slot() : m_length(0), m_sampleCursor(0), m_sampleOffset(0)
    {}

    // Checks if slot is empty.
    bool IsEmpty() const
    {
        return m_sequences.empty();
    }

    void Reset() 
    {
        m_length = 0;
        m_sampleCursor = 0;
        m_sampleOffset = 0;
        m_sequences.clear();
        m_endOfSweepFlags.clear();
    }

    // Gets the number of available samples in the slot.
    size_t AvailableNumberOfSamples() const
    {
        assert(m_length >= m_sampleCursor);
        assert(m_sequences.empty() ? m_sampleCursor == 0 : FrontSequence()->m_numberOfSamples >= m_sampleCursor);
        return m_length - m_sampleCursor;
    }

    // Adds a new sequence to the end of the slot.
    void PushSequence(SequenceDataPtr s, bool endOfSweep)
    {
        m_sequences.push_back(s);
        m_endOfSweepFlags.push_back(endOfSweep);
        m_length += s->m_numberOfSamples;
    }

    SequenceDataPtr FrontSequence() const
    {
        assert(!m_sequences.empty());
        return m_sequences.front();
    }

    // Pops the front sequence at the beginning of the slot.
    bool PopSequence()
    {
        assert(!m_sequences.empty());
        m_sampleCursor = 0;
        m_sampleOffset = 0;
        m_length -= m_sequences.front()->m_numberOfSamples;
        m_sequences.pop_front();
        bool endOfSweepFlag = m_endOfSweepFlags.front();
        m_endOfSweepFlags.pop_front();
        return endOfSweepFlag;
    }

    // Contains the current sample cursor in the first sequence(m_sequences.front()) of the slot.
    size_t m_sampleCursor;

    // offset of the current sample into the data region of the first sequence.
    // For dense input, this is just (sample cursor) x (sample size in bytes).
    // For sparse input, this is (element size in bytes) x (sum of nnz counts 
    // of all preceding samples).
    size_t m_sampleOffset; 

private:
    // Prepared sequences.
    deque<SequenceDataPtr> m_sequences;
    
    // For each 'in-flight' sequence we keep a flag that indicate whether 
    // the sequence data comes from an the end of a sweep.
    std::deque<bool> m_endOfSweepFlags;

    // Contains the size of the slot in samples (accumulated over all m_sequences).
    size_t m_length;
};

// Represents a buffer of slots from which the minibatch is created.
struct SequenceBuffer
{
    SequenceBuffer(size_t numParallelSequences)
    {
        // Allocates required slots.
        m_slots.resize(numParallelSequences);
    }

    // Checks whether there is more data available in any of the slots.
    bool NothingToPack() const
    {
        auto it = find_if(m_slots.begin(), m_slots.end(), [](const Slot& s) -> bool { return !s.IsEmpty(); });
        return it == m_slots.end();
    }

    // A matrix of prepared sequences. The number of rows(RN) = m_numParallelSequences = number of slots
    // in each row we at least holding sequences to fill in the truncation length.
    // Only at the end of the epoch there could be less than truncation length number of samples in this matrix
    //
    // It looks something like that:
    // slot1: /***s11***/ /***s12**/
    //  ....
    // slotM: /**********sM1****/
    //  ....
    // slotN: /*sRN1*//*sRN2*//*sRN2*/
    vector<Slot> m_slots;
};

// Builds the packer on top of PackerBase, rejects sparse output streams, and creates
// one minibatch layout per output stream.
TruncatedBPTTPacker::TruncatedBPTTPacker(
    SequenceEnumeratorPtr sequenceEnumerator,
    const vector<StreamDescriptionPtr>& streams,
    size_t numberOfBuffers,
    CorpusDescriptorPtr corpus)
    : PackerBase(corpus, sequenceEnumerator, streams, numberOfBuffers)
{
    // TODO: add support for sparse.
    const bool hasSparseOutput = std::any_of(
        m_outputStreamDescriptions.begin(),
        m_outputStreamDescriptions.end(),
        [](const StreamDescriptionPtr& description) { return description->m_storageType == StorageType::sparse_csc; });
    if (hasSparseOutput)
    {
        RuntimeError("Sparse output is not supported in BPTT mode.");
    }

    // Every output stream gets its own layout object, all sharing the same axis name.
    const size_t streamCount = m_outputStreamDescriptions.size();
    for (size_t streamIndex = 0; streamIndex < streamCount; ++streamIndex)
    {
        auto layout = std::make_shared<MBLayout>();
        layout->SetUniqueAxisName(L"TruncatedBPTTPacker");
        m_currentLayouts.push_back(layout);
    }
}

// Empties every slot of every sequence buffer, discarding any partially consumed sequences.
void TruncatedBPTTPacker::Reset()
{
    for (auto& sequenceBuffer : m_sequenceBufferPerStream)
    {
        auto& slots = sequenceBuffer->m_slots;
        for (auto slot = slots.begin(); slot != slots.end(); ++slot)
        {
            slot->Reset();
        }
    }
}

// Applies a new reader configuration.
//
// config          - the configuration to apply (forwarded to PackerBase::SetConfiguration).
// memoryProviders - forwarded unchanged to PackerBase::SetConfiguration.
//
// If the minibatch size or the truncation size differ from the previous configuration,
// the number of parallel sequences (slots) for this worker is recomputed, the stream
// buffers are resized and the sequence buffers are recreated. Otherwise the existing
// slots are only emptied via Reset().
//
// Reports LogicError when the truncation size is zero and InvalidArgument when there are
// more workers than parallel sequences.
void TruncatedBPTTPacker::SetConfiguration(const ReaderConfiguration& config, const std::vector<MemoryProviderPtr>& memoryProviders)
{
    // Remember the values that determine the buffer geometry before they are overwritten.
    const auto previousMinibatchSize = m_config.m_minibatchSizeInSamples;
    const auto previousTruncationSize = m_config.m_truncationSize;

    PackerBase::SetConfiguration(config, memoryProviders);

    if (m_config.m_truncationSize == 0)
        LogicError("Truncation size cannot be zero.");

    const bool geometryChanged =
        previousMinibatchSize != m_config.m_minibatchSizeInSamples ||
        previousTruncationSize != m_config.m_truncationSize;

    if (!geometryChanged)
    {
        Reset();
        return;
    }

    // Estimate the total number of parallel sequences (slots) from the minibatch size and
    // the truncation size, with a minimum of one. The quotient is kept in size_t: the
    // previous round trip through std::floor/double and static_cast<int> was undefined
    // for quotients above INT_MAX and added nothing for smaller ones.
    m_numParallelSequences = std::max<size_t>(
        1,
        static_cast<size_t>(m_config.m_minibatchSizeInSamples / m_config.m_truncationSize));

    if (m_config.m_numberOfWorkers > m_numParallelSequences)
    {
        InvalidArgument("Too many workers for minibatch size; please increase minibatch size or decrease number of workers.");
    }

    // Split the slots between the workers: everybody gets the integer share, and the
    // workers with the lowest ranks take one extra slot each until the remainder is used up.
    m_numParallelSequences =
        (m_numParallelSequences / m_config.m_numberOfWorkers) +
        (m_config.m_workerRank < (m_numParallelSequences % m_config.m_numberOfWorkers) ? 1 : 0);

    m_sequenceBufferPerStream.clear();

    // Preparing the buffers: every stream buffer must hold one truncation window for each slot.
    // NOTE(review): a SequenceBuffer is appended for every (buffer, stream) pair, although the
    // code in this file only indexes m_sequenceBufferPerStream by stream index. Kept as is
    // because other users of the container are not visible here.
    for (size_t bufferIndex = 0; bufferIndex < m_streamBuffers.size(); ++bufferIndex)
    {
        for (size_t streamIndex = 0; streamIndex < m_outputStreamDescriptions.size(); ++streamIndex)
        {
            const auto& stream = m_outputStreamDescriptions[streamIndex];
            auto& buffer = m_streamBuffers[bufferIndex][streamIndex];
            buffer.Resize(m_numParallelSequences * m_config.m_truncationSize * GetSampleSize(stream));
            m_sequenceBufferPerStream.push_back(std::make_shared<SequenceBuffer>(m_numParallelSequences));
        }
    }
}

// Produces the next minibatch: tops up the slots, packs one truncation window from every
// slot of every stream, and advances to the next stream buffer.
// Returns an empty end-of-epoch minibatch when there is nothing left to pack.
Minibatch TruncatedBPTTPacker::ReadMinibatch()
{
    FillOutAvailableSlots();

    // Sequences are currently expected to have identical lengths across streams,
    // so inspecting the first stream is enough to know whether any data is left.
    if (m_sequenceBufferPerStream.front()->NothingToPack())
    {
        return Minibatch(/*endOfSweep = */false,/*endOfEpoch = */ true);
    }

    Minibatch result;

    // Maps a sequence id inside the minibatch layout to the corpus sequence id.
    // It is rebuilt for every stream, so after the loop it describes the last stream;
    // in BPTT mode all layouts are expected to match anyway.
    std::vector<size_t> mbSeqIdToCorpusSeqId;

    const size_t streamCount = m_outputStreamDescriptions.size();
    for (size_t streamIndex = 0; streamIndex < streamCount; ++streamIndex)
    {
        mbSeqIdToCorpusSeqId.clear();

        const auto& layout = m_currentLayouts[streamIndex];
        layout->Init(m_numParallelSequences, m_config.m_truncationSize);

        // Pack every slot of this stream; sequence ids are numbered per stream.
        size_t sequenceId = 0;
        for (size_t slotIndex = 0; slotIndex < m_numParallelSequences; ++slotIndex)
        {
            if (PackSlot(streamIndex, slotIndex, sequenceId, mbSeqIdToCorpusSeqId))
            {
                result.m_endOfSweep = true;
            }
        }

        // Expose the packed buffer and its layout for this stream.
        StreamMinibatchPtr streamMinibatch = make_shared<StreamMinibatch>();
        streamMinibatch->m_data = m_streamBuffers[m_currentBufferIndex][streamIndex].m_data.get();
        streamMinibatch->m_layout = layout;
        result.m_data.push_back(streamMinibatch);
    }

    // The next call packs into the next buffer.
    m_currentBufferIndex = (m_currentBufferIndex + 1) % m_numberOfBuffers;

    // Eagerly set the end of epoch flag if all the data have been packed.
    result.m_endOfEpoch = m_sequenceBufferPerStream.front()->NothingToPack();

    // Provide the mapping from a sequence id inside the minibatch layout
    // to the string representation of the sequence key.
    if (m_corpus != nullptr)
    {
        result.m_getKeyById = [mbSeqIdToCorpusSeqId, this](size_t id) { return m_corpus->IdToKey(mbSeqIdToCorpusSeqId[id]); };
    }
    else
    {
        // Without a corpus descriptor the mapping cannot be provided.
        result.m_getKeyById = [](size_t)
        {
            RuntimeError("Sequence Id mapping is not available for old style configurations. Please use deserializers.");
            return "";
        };
    }

    return result;
}

// Packs up to one truncation window of samples from a single slot of a single stream
// into the current stream buffer and records the packed sequences in the stream's layout.
//
// streamIndex - index of the stream whose slot, buffer and layout are used.
// slotIndex   - index of the slot to pack; it is also the row passed to the layout.
// sequenceId  - in/out: next free sequence id inside the minibatch layout; incremented
//               once for every sequence added to the layout.
// idToKey     - in/out: grown so that idToKey[id] holds m_key.m_sequence of the sequence
//               that was added to the layout under that id.
//
// Returns true if at least one sequence popped from the slot during this call carried
// the end-of-sweep flag; returns false when the slot had no samples available.
bool TruncatedBPTTPacker::PackSlot(size_t streamIndex, size_t slotIndex, size_t& sequenceId, std::vector<size_t>& idToKey)
{
    bool containsEndOfSweepSequence = false;
    auto& slot = m_sequenceBufferPerStream[streamIndex]->m_slots[slotIndex];

    // Determine how many samples to read: at most one truncation window,
    // and never more than the slot has available.
    size_t numberOfSamples = min(m_config.m_truncationSize, slot.AvailableNumberOfSamples());
    if (numberOfSamples == 0)
    {
        // Reached the end of the data, mark the whole row of the minibatch layout as a gap.
        m_currentLayouts[streamIndex]->AddSequence(GAP_SEQUENCE_ID, slotIndex, 0, m_config.m_truncationSize);

        // Check that nothing is in the slot any more.
        assert(slot.IsEmpty());
        return false;
    }

    // Note: sample geometry is taken from the input stream descriptions.
    size_t sampleSize = GetSampleSize(m_inputStreamDescriptions[streamIndex]);
    StorageType storageType = m_inputStreamDescriptions[streamIndex]->m_storageType;
    size_t elementSize = GetSizeByType(m_inputStreamDescriptions[streamIndex]->m_elementType);

    // Distance between two samples of the same sequence in bytes:
    // samples of all slots for one timestep are stored next to each other.
    size_t strideSize = m_numParallelSequences * sampleSize;

    // Add current (front) sequence to the minibatch layout.
    // The begin argument is minus the number of samples already consumed, and the end
    // argument is the number of samples still remaining in the front sequence.
    // NOTE(review): arguments are read here as (id, row, begin, end) — confirm against MBLayout::AddSequence.
    idToKey.resize(sequenceId + 1);
    idToKey[sequenceId] = slot.FrontSequence()->m_key.m_sequence;
    m_currentLayouts[streamIndex]->AddSequence(
        sequenceId++,
        slotIndex,
        -(int)slot.m_sampleCursor,
        slot.FrontSequence()->m_numberOfSamples - slot.m_sampleCursor);

    // Ok, now fill in the buffer with data, one sample per timestep.
    for (size_t currentTimestep = 0; currentTimestep < numberOfSamples; ++currentTimestep)
    {
        // Check whether the front sequence has been fully consumed.
        if (slot.m_sampleCursor >= slot.FrontSequence()->m_numberOfSamples)
        {
            // Starting a new sequence. PopSequence resets the cursor and offset and
            // returns the end-of-sweep flag of the sequence that was just finished.
            containsEndOfSweepSequence |= slot.PopSequence();

            // Adding the next sequence to the minibatch layout, starting at the current timestep.
            idToKey.resize(sequenceId + 1);
            idToKey[sequenceId] = slot.FrontSequence()->m_key.m_sequence;
            m_currentLayouts[streamIndex]->AddSequence(
                sequenceId++,
                slotIndex,
                currentTimestep,
                currentTimestep + slot.FrontSequence()->m_numberOfSamples);
        }

        // Fill in the data from the first sequence in the slot.
        auto data = slot.FrontSequence();
        // Get buffer destination for the current sample:
        // skip currentTimestep full rows of samples, then slotIndex samples within the row.
        auto& buffer = m_streamBuffers[m_currentBufferIndex][streamIndex];
        auto offset = strideSize * currentTimestep + slotIndex * sampleSize;
        // offset is unsigned, so only the upper bound of this check is effective.
        assert(offset >= 0 && offset < buffer.m_size);
        char* destination = buffer.m_data.get() + offset;

        // Pack the sample.
        if (storageType == StorageType::dense)
        {
            // For dense data the offset is always cursor * sample size in bytes.
            assert(slot.m_sampleOffset == slot.m_sampleCursor * sampleSize);
            PackDenseSample(destination, data, slot.m_sampleOffset, sampleSize);
            slot.m_sampleOffset += sampleSize;
        }
        else
        {
            assert(storageType == StorageType::sparse_csc);
            // TODO: make type casts members of the SparseSequenceData
            SparseSequenceDataPtr sparseSequence = static_pointer_cast<SparseSequenceData>(data);
            assert(slot.m_sampleCursor < sparseSequence->m_nnzCounts.size());
            PackSparseSampleAsDense(destination, sparseSequence, slot.m_sampleCursor, 
                slot.m_sampleOffset, sampleSize, elementSize);
            // For sparse data the offset advances by the nnz count of the packed sample,
            // so it is bounded by the total nnz count of the sequence.
            slot.m_sampleOffset += sparseSequence->m_nnzCounts[slot.m_sampleCursor];
            assert(slot.m_sampleOffset <= sparseSequence->m_totalNnzCount);
        }

        slot.m_sampleCursor++;
    }

    // Cleaning up the last sequence we have just read if it has been fully consumed.
    if (slot.m_sampleCursor >= slot.FrontSequence()->m_numberOfSamples)
    {
        containsEndOfSweepSequence |= slot.PopSequence();
    }

    // Adding the trailing gap if fewer samples than the truncation size were packed.
    if (numberOfSamples < m_config.m_truncationSize)
    {
        m_currentLayouts[streamIndex]->AddSequence(
            GAP_SEQUENCE_ID,
            slotIndex,
            numberOfSamples,
            m_config.m_truncationSize);
    }

    return containsEndOfSweepSequence;
}

// Visits every slot in index order and lets it fetch more sequences if it needs them.
void TruncatedBPTTPacker::FillOutAvailableSlots()
{
    size_t slotIndex = 0;
    while (slotIndex < m_numParallelSequences)
    {
        ReadSequencesToSlot(slotIndex);
        ++slotIndex;
    }
}

// Fetches sequences from the enumerator into the given slot (for all streams) until the
// slot holds more samples than the truncation size, or until the enumerator reports the
// end of the epoch.
void TruncatedBPTTPacker::ReadSequencesToSlot(size_t slotIndex)
{
    // The slot of the first stream decides how much data is still needed.
    const auto& referenceSlot = m_sequenceBufferPerStream.front()->m_slots[slotIndex];

    while (referenceSlot.AvailableNumberOfSamples() <= m_config.m_truncationSize)
    {
        // A single sequence is requested per call. Requesting
        // (m_truncationSize - slot.AvailableNumberOfSamples()) could be more efficient, but in
        // practice the truncation size is usually smaller than the sequence size.
        // Bptt always operates on a local timeline, so the global minibatch count is not limited.
        const auto& fetched = m_sequenceEnumerator->GetNextSequences(SIZE_MAX, 1);

        // The number of input streams is assumed to equal the number of output streams.
        // This does not have to hold in general, but the current implementation relies on it
        // implicitly, so it is made explicit here until the assumption can be removed.
        assert(fetched.m_endOfEpoch || fetched.m_data.size() == m_outputStreamDescriptions.size());

        // Append the fetched sequence to this slot in every stream.
        const auto& perStreamData = fetched.m_data;
        const size_t streamCount = perStreamData.size();
        for (size_t streamIndex = 0; streamIndex < streamCount; ++streamIndex)
        {
            const auto& streamSequences = perStreamData[streamIndex];
            assert(streamSequences.size() == 1);

            auto& slot = m_sequenceBufferPerStream[streamIndex]->m_slots[slotIndex];

            // Every stream must deliver a sequence of the same length as the first stream.
            const auto& sequence = streamSequences.front();
            if (perStreamData.front().front()->m_numberOfSamples != sequence->m_numberOfSamples)
            {
                RuntimeError("For BPTT sequences between different input stream should have the same length.");
            }

            slot.PushSequence(sequence, fetched.m_endOfSweep);

            assert(referenceSlot.AvailableNumberOfSamples() == slot.AvailableNumberOfSamples());
        }

        // Nothing more can be fetched in this epoch.
        if (fetched.m_endOfEpoch)
        {
            return;
        }
    }
}

}}}
// namespace Microsoft::MSR::CNTK