//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
// minibatchiterator.h -- iterator for minibatches
//

#pragma once
#define NONUMLATTICEMMI // [v-hansu] move from main.cpp, no numerator lattice for mmi training

#include <vector>
#include <unordered_map>
#include "ssematrix.h"
#include "latticearchive.h"        // for reading HTK phoneme lattices (MMI training)
#include "simple_checked_arrays.h" // for const_array_ref
#include "latticesource.h"

namespace msra { namespace dbn {

// ---------------------------------------------------------------------------
// minibatchsource -- abstracted interface into frame sources
// There are three implementations:
//  - the old minibatchframesource to randomize across frames and page to disk
//  - minibatchutterancesource that randomizes in chunks and pages from input files directly
//  - a wrapper that uses a thread to read ahead in parallel to CPU/GPU processing
// ---------------------------------------------------------------------------
class minibatchsource
{
public:
    // read a minibatch
    // This function returns all values in a "caller can keep them" fashion:
    //  - uids are stored in a huge 'const' array, and will never go away
    //  - transcripts are copied by value
    //  - lattices are returned as a shared_ptr
    // Thus, getbatch() can be called in a thread-safe fashion, allowing for a 'minibatchsource' implementation that wraps another with a read-ahead thread.
    // Return value is 'true' if it did read anything from disk, and 'false' if data came only from RAM cache. This is used for controlling the read-ahead thread.
    virtual bool getbatch(const size_t globalts,
                          const size_t framesrequested, msra::dbn::matrix &feat, std::vector<size_t> &uids,
                          std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
                          std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices) = 0;
    // alternate (updated) definition for multiple inputs/outputs - read as a vector of feature matrices and a vector of label sequences
    virtual bool getbatch(const size_t globalts,
                          const size_t framesrequested, std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
                          std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
                          std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
                          std::vector<std::vector<size_t>> &phoneboundaries) = 0;
    // getbatch() overload to support subsetting of mini-batches for parallel training
    // The default implementation does not support subsetting; it asserts that this overload
    // is called with subsetnum == 0 and numsubsets == 1 and forwards to the non-subsetting overload.
    virtual bool getbatch(const size_t globalts,
                          const size_t framesrequested, const size_t subsetnum, const size_t numsubsets, size_t &framesadvanced,
                          std::vector<msra::dbn::matrix> &feat, std::vector<std::vector<size_t>> &uids,
                          std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> &transcripts,
                          std::vector<std::shared_ptr<const latticesource::latticepair>> &lattices, std::vector<std::vector<size_t>> &sentendmark,
                          std::vector<std::vector<size_t>> &phoneboundaries)
    {
        assert((subsetnum == 0) && (numsubsets == 1) && !supportsbatchsubsetting());
        subsetnum;   // (referenced only to avoid unreferenced-parameter warnings)
        numsubsets;
        bool retVal = getbatch(globalts, framesrequested, feat, uids, transcripts, lattices, sentendmark, phoneboundaries);
        framesadvanced = feat[0].cols(); // without subsetting, the logical advance equals the number of frames actually returned

        return retVal;
    }
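
    // Sketch of the subsetting contract (illustrative only; 'k', 'N', and the surrounding variables are hypothetical):
    // in distributed data-parallel training, worker k of N would call
    //   source.getbatch (globalts, framesrequested, /*subsetnum=*/ k, /*numsubsets=*/ N, framesadvanced, feat, uids, transcripts, lattices, sentendmark, phoneboundaries);
    //   globalts += framesadvanced; // all workers advance the global time line by the same logical amount
    // Each worker only receives its own subset of the frames in 'feat', so feat[0].cols() may be < framesadvanced.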

    virtual bool supportsbatchsubsetting() const
    {
        return false;
    }

    virtual size_t totalframes() const = 0;

    virtual double gettimegetbatch() = 0;                         // used to report runtime
    virtual size_t firstvalidglobalts(const size_t globalts) = 0; // get first valid epoch start from intended 'globalts'
    virtual const std::vector<size_t> &unitcounts() const = 0;    // report number of senones
    virtual void setverbosity(int newverbosity) = 0;
    virtual ~minibatchsource()
    {
    }
};

// ---------------------------------------------------------------------------
// minibatchiterator -- class to iterate over one epoch, minibatch by minibatch
// This iterator supports both random frames and random utterances through the minibatchsource interface, which is common to both.
// It also supports multiple data passes with identical randomization, which is intended to be used for utterance-based training.
// ---------------------------------------------------------------------------
class minibatchiterator
{
    void operator=(const minibatchiterator &); // (non-copyable)

    const size_t epochstartframe;
    const size_t epochendframe;
    size_t firstvalidepochstartframe; // epoch start frame rounded up to first utterance boundary after epoch boundary
    const size_t requestedmbframes;   // requested mb size; actual minibatches can be smaller (or even larger for lattices)
    const size_t datapasses;          // we return the data this many times; caller must sub-sample with 'datapass'

    msra::dbn::minibatchsource &source; // feature source to read from

    // subset to read during distributed data-parallel training (no subsetting: (0,1))
    size_t subsetnum;
    size_t numsubsets;

    std::vector<msra::dbn::matrix> featbuf;                                                      // buffer for holding current minibatch's frames
    std::vector<std::vector<size_t>> uids;                                                       // buffer for storing current minibatch's frame-level label sequence
    std::vector<const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>> transcripts; // buffer for storing current minibatch's word-level label sequences (if available and used; empty otherwise)
    std::vector<std::shared_ptr<const latticesource::latticepair>> lattices;                          // lattices of the utterances in current minibatch (empty in frame mode)

    std::vector<std::vector<size_t>> sentendmark;     // buffer for storing current minibatch's utterance end
    std::vector<std::vector<size_t>> phoneboundaries; // buffer for storing phone boundaries
    size_t mbstartframe;                              // current start frame into generalized time line (used for frame-wise mode and for diagnostic messages)
    size_t actualmbframes;                            // actual number of frames in current minibatch
    size_t mbframesadvanced;                          // logical number of frames the current MB represents (to advance time; > featbuf.cols() possible, intended for the case of distributed data-parallel training)
    size_t datapass;                                  // current datapass = pass through the data
    double timegetbatch;                              // [v-hansu] for time measurement
    double timechecklattice;

private:
    // fetch the next mb
    // This updates featbuf, uids[], mbstartframe, and actualmbframes.
    void fillorclear()
    {
        if (!hasdata()) // we hit the end of the epoch: just cleanly clear out everything (not really needed, can't be requested ever)
        {
            foreach_index (i, featbuf)
                featbuf[i].resize(0, 0);

            foreach_index (i, uids)
                uids[i].clear();

            transcripts.clear();
            actualmbframes = 0;
            return;
        }
        // process one mini-batch (accumulation and update)
        assert(requestedmbframes > 0);
        const size_t requestedframes = std::min(requestedmbframes, epochendframe - mbstartframe); // (< mbsize at end)
        assert(requestedframes > 0);
        source.getbatch(mbstartframe, requestedframes, subsetnum, numsubsets, mbframesadvanced, featbuf, uids, transcripts, lattices, sentendmark, phoneboundaries);
        timegetbatch = source.gettimegetbatch();
        actualmbframes = featbuf[0].cols(); // for single input/output, featbuf has length 1
        // note:
        //  - in frame mode, actualmbframes may still return less if at end of sweep
        //  - in utterance mode, it likely returns less than requested, and
        //    it may also be > epochendframe (!) for the last utterance, which, most likely, crosses the epoch boundary
        //  - in case of data parallelism, featbuf.cols() < mbframesadvanced
        auto_timer timerchecklattice;
        if (!lattices.empty())
        {
            size_t totalframes = 0;
            foreach_index (i, lattices)
                totalframes += lattices[i]->getnumframes();
            if (totalframes != actualmbframes)
                LogicError("fillorclear: frames in lattices do not match minibatch size");
        }
        timechecklattice = timerchecklattice;
    }
    bool hasdata() const
    {
        return mbstartframe < epochendframe;
    } // true if we can access and/or advance
    void checkhasdata() const
    {
        if (!hasdata())
            LogicError("minibatchiterator: access beyond end of epoch");
    }

public:
    // interface: for (minibatchiterator i (...); i; i++) { ... }
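    // Usage sketch (illustrative only; 'mysource', 'epoch', 'epochframes', and 'mbframes' are placeholder names):
    //   for (msra::dbn::minibatchiterator mbiter (mysource, epoch, epochframes, mbframes, 0 /*subsetnum*/, 1 /*numsubsets*/, 1 /*datapasses*/); mbiter; mbiter++)
    //   {
    //       msra::dbn::matrixstripe feat = mbiter.frames();          // feature frames of the current minibatch
    //       const std::vector<size_t> &stateids = mbiter.labels();   // frame-level state labels for those frames
    //       // ... consume feat and stateids ...
    //   }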
    minibatchiterator(msra::dbn::minibatchsource &source, size_t epoch, size_t epochframes, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
        : source(source),
          epochstartframe(epoch * epochframes),
          epochendframe(epochstartframe + epochframes),
          requestedmbframes(requestedmbframes),
          subsetnum(subsetnum),
          numsubsets(numsubsets),
          datapasses(datapasses),
          timegetbatch(0),
          timechecklattice(0)
    {
        firstvalidepochstartframe = source.firstvalidglobalts(epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
        fprintf(stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
                (int) epoch, (int) epochstartframe, (int) epochendframe, (int) firstvalidepochstartframe, (int) subsetnum, (int) numsubsets, (int) datapasses);
        mbstartframe = firstvalidepochstartframe;
        datapass = 0;
        fillorclear(); // get the first batch
    }

    // TODO not nice, but don't know how to access these frames otherwise
    // alternate constructor that sets the epoch start and end frames explicitly
    minibatchiterator(msra::dbn::minibatchsource &source, size_t epoch, size_t epochstart, size_t epochend, size_t requestedmbframes, size_t subsetnum, size_t numsubsets, size_t datapasses)
        : source(source),
          epochstartframe(epochstart),
          epochendframe(epochend),
          requestedmbframes(requestedmbframes),
          subsetnum(subsetnum),
          numsubsets(numsubsets),
          datapasses(datapasses),
          timegetbatch(0),
          timechecklattice(0)
    {
        firstvalidepochstartframe = source.firstvalidglobalts(epochstartframe); // epochstartframe may fall between utterance boundaries; this gets us the first valid boundary
        fprintf(stderr, "minibatchiterator: epoch %d: frames [%d..%d] (first utterance at frame %d), data subset %d of %d, with %d datapasses\n",
                (int) epoch, (int) epochstartframe, (int) epochendframe, (int) firstvalidepochstartframe, (int) subsetnum, (int) numsubsets, (int) datapasses);
        mbstartframe = firstvalidepochstartframe;
        datapass = 0;
        fillorclear(); // get the first batch
    }

    // need virtual destructor to ensure proper destruction
    virtual ~minibatchiterator()
    {
    }

    // returns true if we still have data
    operator bool() const
    {
        return hasdata();
    }

    // advance to the next minibatch
    void operator++(int /*denotes postfix version*/)
    {
        checkhasdata();
        mbstartframe += mbframesadvanced;
        // if we hit the end, we will get mbstartframe >= epochendframe <=> !hasdata()
        // (most likely actually mbstartframe > epochendframe since the last utterance likely crosses the epoch boundary)
        // in case of multiple datapasses, reset to start when hitting the end
        if (!hasdata() && datapass + 1 < datapasses)
        {
            mbstartframe = firstvalidepochstartframe;
            datapass++;
            fprintf(stderr, "\nminibatchiterator: entering %d-th repeat pass through the data\n", (int) (datapass + 1));
        }
        fillorclear();
    }

    // accessors to current minibatch
    size_t currentmbstartframe() const
    {
        return mbstartframe;
    }
    size_t currentmbframes() const
    {
        return actualmbframes;
    }
    size_t currentmbframesadvanced() const
    {
        return mbframesadvanced;
    }
    size_t currentmblattices() const
    {
        return lattices.size();
    }
    size_t currentdatapass() const
    {
        return datapass;
    } // 0..datapasses-1; use this for sub-sampling
    size_t requestedframes() const
    {
        return requestedmbframes;
    }
    double gettimegetbatch()
    {
        return timegetbatch;
    }
    double gettimechecklattice()
    {
        return timechecklattice;
    }
    bool isfirst() const
    {
        return mbstartframe == firstvalidepochstartframe && datapass == 0;
    }
    float progress() const // (note: 100%+eps possible for last utterance)
    {
        const float epochframes = (float) (epochendframe - epochstartframe);
        return (mbstartframe + mbframesadvanced - epochstartframe + datapass * epochframes) / (datapasses * epochframes);
    }
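    // Worked example for progress() (hypothetical numbers): with epochstartframe == 0, epochendframe == 1000,
    // and datapasses == 2, finishing the first pass (mbstartframe + mbframesadvanced == 1000, datapass == 0)
    // yields 1000 / 2000 == 0.5.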
    std::pair<size_t, size_t> range() const
    {
        return std::make_pair(epochstartframe, epochendframe);
    }

    // return the current minibatch frames as a matrix ref into the feature buffer
    // Number of frames is frames().cols() == currentmbframes().
    // For frame-based randomization, this is 'requestedmbframes' most of the time, while for utterance randomization,
    // this depends highly on the utterance lengths.
    // User is allowed to manipulate the frames... for now--TODO: move silence filtering here as well

    msra::dbn::matrixstripe frames(size_t i)
    {
        checkhasdata();
        assert(featbuf.size() >= i + 1);
        return msra::dbn::matrixstripe(featbuf[i], 0, actualmbframes);
    }

    msra::dbn::matrixstripe frames()
    {
        checkhasdata();
        assert(featbuf.size() == 1);
        return msra::dbn::matrixstripe(featbuf[0], 0, actualmbframes);
    }

    // return the reference transcript labels (state alignment) for current minibatch
    /*const*/ std::vector<size_t> &labels()
    {
        checkhasdata();
        assert(uids.size() == 1);
        return uids[0];
    }
    /*const*/ std::vector<size_t> &labels(size_t i)
    {
        checkhasdata();
        assert(uids.size() >= i + 1);
        return uids[i];
    }

    std::vector<size_t> &sentends()
    {
        checkhasdata();
        assert(sentendmark.size() == 1);
        return sentendmark[0];
    }
    std::vector<size_t> &bounds()
    {
        checkhasdata();
        assert(phoneboundaries.size() == 1);
        return phoneboundaries[0];
    }
    std::vector<size_t> &bounds(size_t i)
    {
        checkhasdata();
        assert(phoneboundaries.size() >= i + 1);
        return phoneboundaries[i];
    }

    // return a lattice for an utterance (caller should first get total through currentmblattices())
    std::shared_ptr<const msra::dbn::latticepair> lattice(size_t uttindex) const
    {
        return lattices[uttindex];
    } // lattices making up the current minibatch
    bool haslattice()
    {
        return !lattices.empty();
    }
    // return the reference transcript labels (words with alignments) for current minibatch (or empty if no transcripts requested)
    const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word> transcript(size_t uttindex)
    {
        return transcripts.empty() ? const_array_ref<msra::lattices::lattice::htkmlfwordsequence::word>() : transcripts[uttindex];
    }
};
} }