https://github.com/Microsoft/CNTK
Raw File
Tip revision: 9ab0e793c594a66a7049207d634be68ee7c26201 authored by Vadim Mazalov on 15 August 2018, 23:12:34 UTC
Remove template definition
Tip revision: 9ab0e79
latticeforwardbackward.cpp
// latticeforwardbackward.cpp -- forward-backward and Viterbi alignment on lattice edges and whole lattices
//
// F. Seide, V-hansu

#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms  --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#endif

#include "Basics.h"
#include "simple_checked_arrays.h"
#include "latticearchive.h"
#include "simplesenonehmm.h" // the model
#include "ssematrix.h"       // the matrices
#include "latticestorage.h"
#include <unordered_map>
#include <list>
#include <stdexcept>

using namespace std;

// VIRGINLOGZERO is ten times LOGZERO. LOGZERO is expanded at the point of use, so it may come from an
// included header or from the fallback definition further down in this file.
#define VIRGINLOGZERO (10 * LOGZERO) // used for printing statistics on unseen states
// CPU_VERIFICATION is forced off for this translation unit (NOTE(review): no use is visible in this part of the file)
#undef CPU_VERIFICATION

#ifdef _WIN32
// Windows-only definition of the NUMA node override; -1 presumably means "no override" -- TODO confirm in numahelpers.h
int msra::numa::node_override = -1; // for numahelpers.h
#endif

namespace msra { namespace lattices {

// ---------------------------------------------------------------------------
// helper class for allocating lots of small matrices; individual matrices are never freed
// ---------------------------------------------------------------------------

class littlematrixheap
{
    static const size_t CHUNKSIZE;
    typedef msra::math::ssematrixfrombuffer matrixfrombuffer;
    // Storage chunks. A std::list never moves its elements and no chunk is ever resized,
    // so buffers handed out to matrices stay valid for the lifetime of the heap.
    std::list<std::vector<float>> heap;
    size_t allocatedinlast; // elements in use in heap.back(), including alignment padding
    size_t totalallocated;  // elements handed out over all chunks, excluding padding (statistics only)
    // Matrix headers. Capacity is reserved in the constructor and never exceeded: growing this vector
    // would move the headers and invalidate every reference newmatrix() has returned.
    std::vector<matrixfrombuffer> matrices;

    // number of elements to skip from chunk[pos] to reach an element whose address is a multiple of 4 floats (16 bytes)
    static size_t alignmentpadding(const std::vector<float> &chunk, size_t pos)
    {
        const size_t offelem = (reinterpret_cast<size_t>(chunk.data() + pos) / sizeof(float)) % 4;
        return (offelem == 0) ? 0 : 4 - offelem;
    }

public:
    // estimatednumentries: maximum number of matrices newmatrix() may create; exceeding it is a LogicError
    littlematrixheap(size_t estimatednumentries)
        : allocatedinlast(0), totalallocated(0) // initialized in declaration order
    {
        matrices.reserve(estimatednumentries + 1);
    }

    // Create a rows x cols matrix on top of heap storage and return a reference to it.
    // The reference stays valid for the lifetime of the heap. Storage is not initialized here.
    msra::math::ssematrixbase &newmatrix(size_t rows, size_t cols)
    {
        // Refuse before touching 'matrices', so a failed request leaves the heap unchanged.
        // Same limit as before: the new size must stay below the reserved capacity.
        if (matrices.size() + 1 >= matrices.capacity())
            LogicError("newmatrix: littlematrixheap cannot grow but was constructed with too small number of eements");

        const size_t elementsneeded = matrixfrombuffer::elementsneeded(rows, cols);
        // Place the matrix at the next SSE-aligned element of the current chunk if it fits there
        // (the matrixfrombuffer constructor demands an aligned start, also for later matrices in a chunk)...
        size_t start = heap.empty() ? 0 : allocatedinlast + alignmentpadding(heap.back(), allocatedinlast);
        if (heap.empty() || start > heap.back().size() || heap.back().size() - start < elementsneeded)
        {
            // ...otherwise start a new chunk; the +3 leaves room to align its first element
            const size_t nelem = max(CHUNKSIZE, elementsneeded + 3 /*+3 for SSE alignment*/);
            heap.push_back(std::vector<float>(nelem));
            start = alignmentpadding(heap.back(), 0);
        }
        auto &buffer = heap.back();
        if (start > buffer.size() || elementsneeded > buffer.size() - start)
            LogicError("newmatrix: allocation logic screwed up");
        // get our buffer into a handy vector-like thingy (data() + start is valid even at the very end of the chunk)
        array_ref<float> vecbuffer(buffer.data() + start, elementsneeded);
        // allocate in the current heap location; cannot reallocate thanks to the capacity check above
        matrices.resize(matrices.size() + 1);
        auto &matrix = matrices.back();
        matrix = matrixfrombuffer(vecbuffer, rows, cols);
        allocatedinlast = start + elementsneeded;
        totalallocated += elementsneeded;
        return matrix;
    }
};

const size_t littlematrixheap::CHUNKSIZE = 256 * 1024; // in floats, i.e. 1 MB per chunk

// ---------------------------------------------------------------------------
// helpers for log-domain addition
// ---------------------------------------------------------------------------

#ifndef LOGZERO
#define LOGZERO -1e30f
#endif

// logaddratio (loga, diff), where diff = logb - loga is expected to be <= 0:
//   loga <- loga + log (1 + exp (diff)) == log (exp (loga) + exp (logb))
// A ratio too small to change loga at the working precision is simply dropped.
static void logaddratio(float &loga, float diff)
{
    // exp (-17) < 2^-24: below the resolution of a 24-bit float significand
    if (!(diff < -17.0f))
        loga += logf(1.0f + expf(diff));
}
static void logaddratio(double &loga, double diff)
{
    // exp (-37) < 2^-53: below the resolution of a 53-bit double significand
    if (!(diff < -37.0))
        loga += log(1.0 + exp(diff));
}

// logadd (loga, logb): loga <- log (exp (loga) + exp (logb)).
// The larger term is factored out, i.e. hi + log (1 + exp (lo - hi)), so the ratio
// passed on is never positive. If even the larger term is <= LOGZERO, the result is that term.
template <typename FLOAT>
static void logadd(FLOAT &loga, FLOAT logb)
{
    const bool bislarger = logb > loga;
    const FLOAT hi = bislarger ? logb : loga;
    const FLOAT lo = bislarger ? loga : logb;
    loga = hi;
    if (hi <= LOGZERO) // both terms are (log) zero
        return;
    logaddratio(loga, lo - hi);
}

// logmax (loga, logb): loga <- max (loga, logb); max approximation of logadd(), for testing
template <typename FLOAT>
static void logmax(FLOAT &loga, FLOAT logb)
{
    loga = (logb > loga) ? logb : loga;
}

// expdiff (a, b) = exp (a) - exp (b), computed by factoring out the larger exponential; for testing
template <typename FLOAT>
static FLOAT expdiff(FLOAT a, FLOAT b)
{
    return (b > a) ? exp(b) * (exp(a - b) - 1) : exp(a) * (1 - exp(b - a));
}

// islogzero (v): whether log-domain value v is to be treated as probability 0 (anything below LOGZERO / 2)
template <typename FLOAT>
static bool islogzero(FLOAT v)
{
    return LOGZERO / 2 > v;
}

// ---------------------------------------------------------------------------
// other helpers go here
// ---------------------------------------------------------------------------

// helper to reconstruct the phonetic transcript
/*static*/ std::string lattice::gettranscript(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset)
{
    // Joins the HMM names of all units, in order, separated by single spaces.
    // Units have fixed boundaries, so their sequence alone spells out the transcript.
    std::string transcript;
    const char *separator = ""; // nothing before the first name
    for (size_t i = 0; i < units.size(); i++)
    {
        transcript += separator;
        transcript += hset.gethmm(units[i].unit).getname();
        separator = " ";
    }
    return transcript;
}

// ---------------------------------------------------------------------------
// forwardbackwardedge() -- perform state-level forward-backward on a single lattice edge
//
// Results:
//  - gammas(j,t) for valid time ranges (remaining areas are not initialized)
//  - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------

// Parameters (as used below):
//  - units: the edge's unit sequence; units[k].unit selects the HMM in 'hset', units[k].frames its frame count (0 = tee)
//  - logLLs(s,t): acoustic log score of senone id s at frame t of the edge
//  - loggammas: output; must have logLLs.cols() + 2 columns, and its rows are asserted at the end to equal
//    the total number of HMM states over all units
//  - edgeindex: only used in diagnostic messages
// Returns the forward path score, or LOGZERO if the backward score is (log) zero.
/*static*/ float lattice::forwardbackwardedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                              msra::math::ssematrixbase &loggammas, size_t edgeindex)
{
    // alphas and betas are stored in-place inside the loggammas matrix, shifted by one/two columns,
    // i.e. (assuming the stripe view starts at the given column -- TODO confirm) logalphas(j,t) aliases
    // loggammas(j,t+1) and logbetas(j,t) aliases loggammas(j,t+2); the statement order below depends on this
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logalphas(loggammas, 1, logLLs.cols()); // shifted views into gammas(,) for alphas and betas
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logbetas(loggammas, 2, logLLs.cols());

    // alphas(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
    // betas(j,t) store the sum of all paths exiting from state j at time t, not including logLL(j,t)
    // gammas(j,t) = alphas(j,t) * betas(j,t) / totalLL

    // backward pass   --token passing
    // units are visited last to first; [ts,te) and [js,je) are the frame and state ranges of unit k
    size_t te = logbetas.cols();
    size_t je = logbetas.rows();
    float bwscore = 0.0f; // backward score
    for (size_t k = units.size() - 1; k + 1 > 0; k--)
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t ts = te - units[k].frames; // start time of current unit
        const size_t js = je - n;               // range of state indices

        // pass in the transition score
        // t = ts: exit transition (last frame only or tee transition)
        float exitscore = 1e30f; // (something impossible); overwritten at t == ts, to == 0 below whenever the unit has frames
        if (te == ts)            // tee transition
        {
            exitscore = bwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from; // origin trellis node
                logbetas(i, te - 1) = bwscore + transP(from, n);
            }
        }

        // expand from states j at time t (not yet including LL) to time t-1
        for (size_t t = te - 1; t + 1 > ts /*note: cannot test t >= ts because t < 0 possible*/; t--)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to;             // source trellis node
                const size_t s = hmm.getsenoneid(to); // senone id for state at position 'to' in the HMM
                const float acLL = logLLs(s, t);
                if (islogzero(acLL))
                    fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) ac score(%d,%d) is zero (%d st, %d fr: %s)\n",
                            (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te,
                            (int) s, (int) t,
                            (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
                const float betajt = logbetas(j, t);   // sum over all all path exiting from (j,t) to end
                const float betajtpll = betajt + acLL; // incorporate acoustic score
                if (t > ts)
                    for (size_t from = 0 /*no transition from entry state*/; from < n; from++)
                    {
                        const size_t i = js + from; // target trellis node
                        const float pathscore = betajtpll + transP(from, to);
                        if (to == 0) // first contribution initializes the cell, later ones are log-added
                            logbetas(i, t - 1 /*propagate into preceding frame*/) = pathscore;
                        else
                            logadd(logbetas(i, t - 1 /*propagate into preceding frame*/), pathscore);
                    }
                else // transition to entry state
                {
                    const float pathscore = betajtpll + transP(-1, to);
                    if (to == 0)
                        exitscore = pathscore;
                    else
                        logadd(exitscore, pathscore); // propagate into preceding unit
                }
            }
        }

        bwscore = exitscore;
        if (islogzero(bwscore))
            fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) bw score is zero (%d st, %d fr: %s)\n",
                    (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te, (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());

        te = ts;
        je = js;
    }
    assert(te == 0 && je == 0);
    const float totalbwscore = bwscore;

    // forward pass   --regular Viterbi
    // This also computes the gammas right away.
    // NOTE(review): despite the label, this pass log-adds over predecessors (a forward pass), it does not take the max
    size_t ts = 0;           // start frame for unit 'k'
    size_t js = 0;           // first row index of unit ' k'
    float fwscore = 0.0f;    // score passed across phone boundaries
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te2 = ts + units[k].frames; // end time of current unit
        const size_t je2 = js + n;               // range of state indices

        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te2; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to; // target trellis node
                const size_t s = hmm.getsenoneid(to);
                const float acLL = logLLs(s, t);
                float alphajtnoll = LOGZERO;
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    alphajtnoll = pathscore;
                }
                else
                    for (size_t from = 0 /*no entering possible*/; from < n; from++)
                    {
                        const size_t i = js + from; // origin trellis node
                        const float alphaitm1 = logalphas(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        logadd(alphajtnoll, pathscore);
                    }
                logalphas(j, t) = alphajtnoll + acLL;
            }
            // update the gammas  --do it here because in next frame, betas get overwritten by alphas (they share memory)
            for (size_t j = js; j < je2; j++)
            {
                if (!islogzero(totalbwscore))
                    loggammas(j, t) = logalphas(j, t) + logbetas(j, t) - totalbwscore;
                else // 0/0 problem, can occur if an ac score is so bad that it is 0 after going through softmax
                    loggammas(j, t) = LOGZERO;
            }
        }
        // t = te2: exit transition (last frame only or tee transition)
        float exitscore;
        if (te2 == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            exitscore = LOGZERO;
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from;                   // origin trellis node
                const float alphaitm1 = logalphas(i, te2 - 1); // newly computed path score, transiting to t=te2
                const float pathscore = alphaitm1 + transP(from, n);
                logadd(exitscore, pathscore);
            }
        }
        fwscore = exitscore; // score passed on to next unit
        js = je2;
        ts = te2;
    }
    assert(js == logalphas.rows() && ts == logalphas.cols());
    const float totalfwscore = fwscore;

    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(totalbwscore) ^ islogzero(totalfwscore))
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw 0 score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
    if (islogzero(totalbwscore))
    {
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }

    // per-frame consistency check of the two passes; ts equals the edge's frame count here
    // NOTE(review): if an edge can have 0 frames, this divides by zero (float inf/NaN, no trap) -- confirm it cannot
    if (fabsf(totalfwscore - totalbwscore) / ts > 1e-4f)
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());

    // we return the full path score
    return totalfwscore;
}

// ---------------------------------------------------------------------------
// alignedge() -- perform Viterbi alignment on a single edge
//
// This is an alternative to forwardbackwardedge() that just uses the best path.
// Results:
//  - if not returnsenoneids -> 'binary gammas(j,t)' for valid time ranges (remaining areas are not initialized); MMI-compatible
//  - if returnsenoneids ->  loggammas(0,t) will contain the senone ids directly instead (for sMBR mode)
//  - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------

// Parameters (as used below):
//  - units: the edge's unit sequence; units[k].unit selects the HMM in 'hset', units[k].frames its frame count (0 = tee)
//  - logLLs(s,t): acoustic log score of senone id s at frame t of the edge
//  - loggammas: work/output matrix with logLLs.cols() + 2 columns; its rows must equal the total number of HMM states
//  - returnsenoneids: if true, write the best path's senone id per frame into thisedgealignmentsj[t]
//    instead of writing binary gammas into loggammas
// Returns the best path score, or LOGZERO (with a warning) if no path has a score above (log) zero.
/*static*/ float lattice::alignedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                    msra::math::ssematrixbase &loggammas, size_t edgeindex /*for diagnostic messages*/, const bool returnsenoneids,
                                    array_ref<unsigned short> thisedgealignmentsj)
{
    // backpointers and path scores are stored in-place inside the loggammas matrix, shifted by zero/two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> backpointers(loggammas, 0, logLLs.cols());
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> pathscores(loggammas, 2, logLLs.cols());

    // pathscores(j,t) store the score of the best path up to and including state j at time t, including logLL(j,t)
    // backpointers(j,t) hold (as float) the absolute row index of the best predecessor state at t-1, or, on a unit's
    // entry frame, the exit state of the preceding unit (-1 = start of utterance)
    // gammas(j,t) <- 1 if on best path, 0 otherwise (log domain: 0.0f / LOGZERO)

    const int invalidbp = -2; // "no predecessor found"

    // Viterbi alignment
    size_t ts = 0;           // start frame for unit 'k'
    size_t js = 0;           // first row index of unit 'k'
    float fwscore = 0.0f;    // score passed across phone boundaries
    int fwbackpointer = -1;  // bp passed across phone boundaries, -1 means start of utterance
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te = ts + units[k].frames;    // end time of current unit (exclusive)
        const size_t je = js + hmm.getnumstates(); // range of state indices

        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t j = js; j < je; j++)
            {
                const size_t to = j - js; // relative state
                const size_t s = hmm.getsenoneid(to);
                pathscores(j, t) = LOGZERO;
                backpointers(j, t) = invalidbp;
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    pathscores(j, t) = pathscore;
                    backpointers(j, t) = (float) fwbackpointer;
                }
                else
                    for (size_t i = js; i < je; i++)
                    {
                        const size_t from = i - js;
                        const float alphaitm1 = pathscores(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        // strict '>': ties keep the lowest predecessor index, and scores <= LOGZERO never qualify
                        if (pathscore > pathscores(j, t))
                        {
                            pathscores(j, t) = pathscore;
                            backpointers(j, t) = (float) i;
                        }
                    }
                const float acLL = logLLs(s, t);
                pathscores(j, t) += acLL;
            }
        }
        // t = te: exit transition (last frame only or tee transition)
        float exitscore = LOGZERO;
        int exitbackpointer = invalidbp;
        if (te == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
            exitbackpointer = fwbackpointer;
        }
        else // not tee: expand all last states
        {
            for (size_t i = js; i < je; i++)
            {
                const size_t from = i - js;
                const float alphaitm1 = pathscores(i, te - 1); // newly computed path score, transiting to t=te
                const float pathscore = alphaitm1 + transP(from, n);
                if (pathscore > exitscore)
                {
                    exitscore = pathscore;
                    exitbackpointer = (int) i;
                }
            }
        }
        if (exitbackpointer == invalidbp)
        {
            // No exit candidate scored above LOGZERO, i.e. every path through this unit has zero probability
            // (e.g. an ac score that became 0 after the softmax). Treat it like the zero-score case below and
            // skip the edge, instead of aborting as a logic error.
            fprintf(stderr, "alignedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) has zero ac. score (%d st, %d fr: %s)\n",
                    (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te,
                    (int) pathscores.rows(), (int) pathscores.cols(), gettranscript(units, hset).c_str());
            return LOGZERO;
        }
        fwscore = exitscore;             // score passed on to next unit
        fwbackpointer = exitbackpointer; // and accompanying backpointer
        js = je;
        ts = te;
    }
    assert(js == pathscores.rows() && ts == pathscores.cols());

    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(fwscore))
    {
        fprintf(stderr, "alignedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }

    // traceback & gamma update
    size_t te = backpointers.cols();
    size_t je = backpointers.rows();
    int j = fwbackpointer;
    for (size_t k = units.size() - 1; k + 1 > 0; k--) // go in units because we also need to clear out the column
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t ts2 = te - units[k].frames;    // start time of current unit
        const size_t js2 = je - hmm.getnumstates(); // range of state indices
        for (size_t t = te - 1; t + 1 > ts2; t--)
        {
            if (j < (int) js2 || j >= (int) je)
                LogicError("invalid backpointer resulting in state index out of range");

            int bp = (int) backpointers(j, t); // save the backpointer before overwriting it (gammas and backpointers are aliases of each other)
            if (!returnsenoneids)              // return binary gammas (for MMI; this mode is compatible with softalignmode)
                for (size_t i = js2; i < je; i++)
                    loggammas(i, t) = ((int) i == j) ? 0.0f : LOGZERO;
            else // return senone id (for sMBR; note: NOT compatible with softalignmode; calling code must know this)
                thisedgealignmentsj[t] = (unsigned short) hmm.getsenoneid(j - js2);

            if (bp == invalidbp)
                LogicError("deltabackpointer not initialized");
            j = bp; // trace back one step
        }

        te = ts2;
        je = js2;
    }
    if (j != -1)
        LogicError("invalid backpointer resulting in not reaching start of utterance when tracing back");
    assert(je == 0 && te == 0);

    // we return the full path score
    return fwscore;
}

// ---------------------------------------------------------------------------
// forwardbackwardlattice() -- lattice-level forward/backward
//
// This computes word posteriors, and also returns the per-node alphas and betas.
// Per-edge acoustic scores are passed in via a lambda, as this function is
// intended for use at multiple places with different scores.
// (Specifically, we also use it to determine a pruning threshold, based on
// the original lattice's ac. scores, before even bothering to compute the
// new ac. scores.)
// ---------------------------------------------------------------------------

// Parameters (as used below):
//  - edgeacscores[j]: acoustic score of edge j; LOGZERO marks the edge as pruned
//  - lmf, wp, amf: combined per edge as (e.l * lmf + wp + edgeacscores[j]) / amf
//  - boostingfactor: only forwarded to the parallel (CUDA) implementation
//  - sMBRmode: also compute expected frames-correct counts, comparing the per-edge alignments
//    'thisedgealignments' against the reference 'uids' frame by frame
// Outputs (CPU path; the parallel implementation fills the same arguments -- not visible here):
//  - logpps[j]: log posterior of edge j, clamped to <= 0; LOGZERO for edges skipped as pruned in sMBR mode
//  - logalphas / logbetas: per-node forward / backward log scores
//  - sMBR only: Eframescorrectbuf[j] (expected frames-correct count conditioned on edge j; 0 for pruned edges),
//    logEframescorrecttotal (log expected frames-correct count over the lattice);
//    logEframescorrect is only resized here, not written
// Returns the total lattice log score, or LOGZERO if no path reaches the final node.
double lattice::forwardbackwardlattice(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<double> &logpps,
                                       std::vector<double> &logalphas, std::vector<double> &logbetas,
                                       const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode,
                                       const_array_ref<size_t> &uids, const edgealignments &thisedgealignments,
                                       std::vector<double> &logEframescorrect, std::vector<double> &Eframescorrectbuf, double &logEframescorrecttotal) const
{ // ^^ TODO: remove this
    // --- hand off to parallelized (CUDA) implementation if available
    if (parallelstate.enabled())
    {
        double totalfwscore = parallelforwardbackwardlattice(parallelstate, edgeacscores, thisedgealignments, lmf, wp, amf, boostingfactor, logpps, logalphas, logbetas, sMBRmode, uids, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);

        return totalfwscore;
    }
    // if we get here, we have no CUDA, and do it the good ol' way

    // allocate return values
    // Every posterior starts as LOGZERO (probability 0): the sMBR passes below skip pruned edges, and a plain
    // resize() would leave those at 0.0 (probability 1) or at stale values from a previous lattice.
    // The MMI passes overwrite every entry, so their result is unaffected.
    logpps.assign(edges.size(), LOGZERO); // this is our primary return value

    // TODO: these are return values as well, but really shouldn't anymore; only used in some older baseline code we some day may want to compare against
    logalphas.assign(nodes.size(), LOGZERO);
    logalphas.front() = 0.0f;
    logbetas.assign(nodes.size(), LOGZERO);
    logbetas.back() = 0.0f;

    // --- sMBR version

    if (sMBRmode)
    {
        logEframescorrect.resize(edges.size());
        Eframescorrectbuf.assign(edges.size(), 0.0); // pruned edges are skipped below; give them a defined count of 0

        std::vector<double> logaccalphas(nodes.size(), LOGZERO); // [i] expected frames-correct count over all paths from start to node i
        std::vector<double> logaccbetas(nodes.size(), LOGZERO);  // [i] likewise
        std::vector<double> logframescorrectedge(edges.size());  // raw counts of correct frames in each edge

        // forward pass
        // logaccalphas[] is accumulated unnormalized (weighted by path score) and normalized after the loop
        foreach_index (j, edges)
        {
            if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
                continue;
            const auto &e = edges[j];
            const double inscore = logalphas[e.S];
            const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
            const double pathscore = inscore + edgescore;
            logadd(logalphas[e.E], pathscore);

            size_t ts = nodes[e.S].t;
            size_t te = nodes[e.E].t;
            size_t framescorrect = 0; // count raw number of correct frames
            for (size_t t = ts; t < te; t++)
                framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
            logframescorrectedge[j] = (framescorrect > 0) ? log((double) framescorrect) : LOGZERO; // remember for backward pass
            double loginaccs = logaccalphas[e.S] - logalphas[e.S];
            logadd(loginaccs, logframescorrectedge[j]);
            double logpathacc = loginaccs + logalphas[e.S] + edgescore;
            logadd(logaccalphas[e.E], logpathacc);
        }
        foreach_index (j, logaccalphas)
            logaccalphas[j] -= logalphas[j];

        const double totalfwscore = logalphas.back();
        const double totalfwacc = logaccalphas.back();
        if (islogzero(totalfwscore))
        {
            fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
            return LOGZERO; // failed, do not use resulting matrix
        }

        // backward pass and computation of state-conditioned frames-correct count
        // logaccbetas[] is accumulated unnormalized here and normalized after the loop
        for (size_t j = edges.size() - 1; j + 1 > 0; j--)
        {
            if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
                continue;
            const auto &e = edges[j];
            const double inscore = logbetas[e.E];
            const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
            const double pathscore = inscore + edgescore;
            logadd(logbetas[e.S], pathscore);

            double loginaccs = logaccbetas[e.E] - logbetas[e.E];
            logadd(loginaccs, logframescorrectedge[j]);
            double logpathacc = loginaccs + logbetas[e.E] + edgescore;
            logadd(logaccbetas[e.S], logpathacc);

            // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
            double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
            if (logpp > 1e-2)
                fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
            if (logpp > 0.0)
                logpp = 0.0;
            logpps[j] = logpp;
            double tmplogeframecorrect = logframescorrectedge[j];
            logadd(tmplogeframecorrect, logaccalphas[e.S]);                 // already normalized above
            logadd(tmplogeframecorrect, logaccbetas[e.E] - logbetas[e.E]); // normalize on the fly
            Eframescorrectbuf[j] = exp(tmplogeframecorrect);
        }
        foreach_index (j, logaccbetas)
            logaccbetas[j] -= logbetas[j];
        const double totalbwscore = logbetas.front();
        const double totalbwacc = logaccbetas.front();
        if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
            fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());

        if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
            fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());

        logEframescorrecttotal = totalbwacc;
        return totalbwscore;
    }

    // --- MMI version

    // forward pass
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        const double inscore = logalphas[e.S];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        const double pathscore = inscore + edgescore;
        logadd(logalphas[e.E], pathscore);
    }
    const double totalfwscore = logalphas.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }

    // backward pass
    // this also computes the word posteriors on the fly, since we are at it
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        const auto &e = edges[j];
        const double inscore = logbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        const double pathscore = inscore + edgescore;
        logadd(logbetas[e.S], pathscore);

        // compute lattice posteriors on the fly since we are at it
        double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
        if (logpp > 1e-2)
            fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
        if (logpp > 0.0)
            logpp = 0.0;
        logpps[j] = logpp;
    }

    const double totalbwscore = logbetas.front();
    if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());

    return totalfwscore;
}

// ---------------------------------------------------------------------------
// forwardbackwardlatticesMBR() -- compute expected frame-accuracy counts,
// both the conditioned one (corresponding to c(q) in Dan Povey's thesis)
// and the global one (which is the sMBR criterion to optimize).
//
// Outputs:
//  - Eframescorrect[j] == expected frames-correct count conditioned on a state of edge[j].
//    We currently assume a hard state alignment. With that, the value turns out
//    to be identical for all states of an edge, so we only store it once per edge.
//  - return value: expected frames-correct count for entire lattice
//
// Call forwardbackwardlattices() first to compute logalphas/betas.
// ---------------------------------------------------------------------------

double lattice::forwardbackwardlatticesMBR(const std::vector<float> &edgeacscores, const msra::asr::simplesenonehmm &hset,
                                           const std::vector<double> &logalphas, const std::vector<double> &logbetas,
                                           const float lmf, const float wp, const float amf, const_array_ref<size_t> &uids,
                                           const edgealignments &thisedgealignments, std::vector<double> &Eframescorrect) const
{
    // Parameters:
    //  edgeacscores       - per-edge acoustic scores; a log-zero value marks a pruned edge, which is skipped
    //  hset               - not used by this function
    //  logalphas/logbetas - node-level forward/backward log scores (from forwardbackwardlattices())
    //  lmf, wp, amf       - LM weight, word penalty, acoustic scale used to form the edge score
    //  uids               - reference state id per utterance frame
    //  thisedgealignments - per-edge alignment, indexed [edge][frame offset within edge]
    //  Eframescorrect     - output; [j] is written for every non-pruned edge j
    // NOTE(review): forward pass runs in edge-index order and backward pass in reverse order,
    // which assumes edges are topologically sorted (as forwardbackwardlattices() also assumes).

    // make sure every edge index is writable (previously an undersized vector was written out of bounds)
    if (Eframescorrect.size() < edges.size())
        Eframescorrect.resize(edges.size());

    std::vector<double> accalphas(nodes.size(), 0);  // [i] expected frames-correct count over all paths from start to node i
    std::vector<double> accbetas(nodes.size(), 0);   // [i] likewise
    std::vector<size_t> maxcorrect(nodes.size(), 0); // [i] max correct frames up to this node (oracle)

    std::vector<double> framescorrectedge(edges.size()); // raw counts of correct frames in each edge

    std::vector<int> backpointersformaxcorr(nodes.size(), -2); // keep track of backpointer for the max corr (kept for the TODO below)
    backpointersformaxcorr.front() = -1;

    // forward pass
    foreach_index (j, edges)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accalphas[e.S];
        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;

        size_t framescorrect = 0; // count raw number of correct frames
        for (size_t t = ts; t < te; t++)
            framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
        framescorrectedge[j] = (double) framescorrect; // remember for backward pass

        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        // contribution to end node's path acc = start node's plus edge's correct count, weighted by LL, and divided by sum over LLs
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logalphas[e.S] + edgescore - logalphas[e.E]);
        accalphas[e.E] += pathacc;
        // also keep track of max accuracy, so we can find out whether the lattice contains the correct path
        size_t oracleframescorrect = maxcorrect[e.S] + framescorrect; // keep track of most correct path up to end of this edge
        if (oracleframescorrect > maxcorrect[e.E])
        {
            maxcorrect[e.E] = oracleframescorrect;
            backpointersformaxcorr[size_t(e.E)] = j;
        }
    }
    const double totalfwacc = accalphas.back();

    (void) hset; // unused parameter; cast silences "unused" warnings without a no-effect statement

    // report on ground-truth path
    // TODO: we will later have code that adds this path if needed
    size_t oracleframeacc = maxcorrect.back();
    if (oracleframeacc != info.numframes)
        fprintf(stderr, "forwardbackwardlatticesMBR: ground-truth path missing from lattice (most correct path: %d out of %d frames correct)\n", (int) oracleframeacc, (int) info.numframes);

    // backward pass and computation of state-conditioned frames-correct count
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logbetas[e.E] + edgescore - logbetas[e.S]);
        accbetas[e.S] += pathacc;

        // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
        // NOTE(review): the (float) cast truncates a double before storing it into a double vector;
        // kept to preserve existing numeric results — confirm whether this is intended.
        Eframescorrect[j] = (float) (accalphas[e.S] + accbetas[e.E] + framescorrectedge[j]);
    }

    const double totalbwacc = accbetas.front();

    if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());

    return totalbwacc;
}

// ---------------------------------------------------------------------------
// bestpathlattice() -- lattice-level "forward/backward" that only returns the
// best path, but in the form of word posteriors, which are 1 or 0, just like
// a real lattice-level forward/backward would do.
// We don't really use this; this was only for a contrast experiment.
// ---------------------------------------------------------------------------

double lattice::bestpathlattice(const std::vector<float> &edgeacscores, std::vector<double> &logpps,
                                const float lmf, const float wp, const float amf) const
{
    // forward pass --sortnedness => regular Viterbi
    std::vector<double> logalphas(nodes.size(), LOGZERO);
    std::vector<int> backpointers(nodes.size(), -2);
    logalphas.front() = 0.0f;
    backpointers.front() = -1;
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        const double inscore = logalphas[e.S];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        const double pathscore = inscore + edgescore;
        if (pathscore > logalphas[e.E])
        {
            logalphas[e.E] = pathscore;
            backpointers[e.E] = j;
        }
    }

    const double totalfwscore = logalphas.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "bestpathlattice: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }

    // traceback
    // We encode the result by storing log 1 in edges on the best path, and log 0 else;
    // this makes it naturally compatible with softalign mode
    logpps.resize(edges.size());
    foreach_index (j, edges)
        logpps[j] = LOGZERO;

    int backpos = backpointers[nodes.size() - 1];
    while (backpos >= 0)
    {
        logpps[backpos] = 0.0f; // edge is on best path -> PP = 1.0
        backpos = backpointers[edges[backpos].S];
    }
    assert(backpos == -1);

    return totalfwscore;
}

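// ---------------------------------------------------------------------------
// forwardbackwardalign() -- compute the state-level gammas or Viterbi alignments
// for every lattice edge; this is the first phase of lattice::forwardbackward().
//
// Outputs:
//  - edgeacscores[j]: acoustic score of edge j (0 for the dummy 0-frame end edge,
//    LOGZERO for an edge pruned on the CPU path)
//  - thisedgealignments[j]: per-frame alignment of edge j (hard-alignment mode)
//  - abcs[j]: per-edge alpha/beta/gamma matrix for the CPU path (NULL where not allocated)
//  - thisbackpointers: passed to the parallel (GPU) alignment
//  - uids: overwritten with a forced alignment to the reference when 'bounds' is non-empty
// ---------------------------------------------------------------------------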
void lattice::forwardbackwardalign(parallelstate &parallelstate,
                                   const msra::asr::simplesenonehmm &hset, const bool softalignstates,
                                   const double minlogpp, const std::vector<double> &origlogpps,
                                   std::vector<msra::math::ssematrixbase *> &abcs, littlematrixheap &matrixheap,
                                   const bool returnsenoneids,
                                   std::vector<float> &edgeacscores, const msra::math::ssematrixbase &logLLs,
                                   edgealignments &thisedgealignments, backpointers &thisbackpointers, array_ref<size_t> &uids, const_array_ref<size_t> bounds) const
{ // NOTE: this will be removed and replaced by a proper representation of alignments someday
    // do forward-backward or alignment on a per-edge basis. This gives us:
    //  - per-edge gamma[j,t] = P(s(t)==s_j|edge) if forwardbackward, per-edge alignment thisedgealignments[j] if alignment
    //  - per-edge acoustic scores
    // Parameters of note:
    //  - minlogpp/origlogpps: edge pruning; an edge j is pruned iff minlogpp > LOGZERO && origlogpps[j] < minlogpp
    //  - abcs/matrixheap: per-edge matrices are allocated from matrixheap into abcs (Phase 1)
    //  - bounds: if non-empty, used to force-align the reference and overwrite uids (see below)
    const size_t silunitid = hset.gethmmid("sil"); // shall be the same as parallelstate.getsilunitid()
    bool parallelsil = true;
    bool cpuverification = false;

#ifndef PARALLEL_SIL // we use a define to make this marked
    parallelsil = false;
#endif
#ifdef CPU_VERIFICATION
    cpuverification = true;
#endif

    // Phase 1: abcs allocate
    if (!parallelstate.enabled() || !parallelsil || cpuverification) // allocate abcs when 1.parallelstate not enabled (cpu mode); 2. enabled but not PARALLEL_SIL (silence need to be allocate); 3. cpuverfication
    {
        abcs.resize(edges.size(), NULL); // [edge index] -> alpha/beta/gamma matrices for each edge
        size_t countskip = 0;            // if pruning: count how many edges are pruned

        foreach_index (j, edges)
        {
            // determine number of frames
            // TODO: this is not efficient--we only use a block-diagonal-like structure, rest is empty (exploiting the fixed boundaries)
            const size_t edgeframes = nodes[edges[j].E].t - nodes[edges[j].S].t;
            if (edgeframes == 0) // dummy !NULL edge at end of lattice
            {
                if ((size_t) j != edges.size() - 1)
                    RuntimeError("forwardbackwardalign: unxpected 0-frame edge (only allowed at very end)");
                // note: abcs[j] is already initialized to be NULL in this case, which protects us from accidentally using it
            }
            else
            {
                // determine the number of states in an edge
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                size_t edgestates = 0;

                bool edgehassil = false;
                foreach_index (i, aligntokens)
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;

                // in parallel mode only silence edges are handled on the CPU (Phase 2), so only they need a matrix
                if (!cpuverification && !edgehassil && parallelstate.enabled()) // !cpuverification, parallel & is non sil, we do not allocate
                {
                    abcs[j] = NULL;
                    continue;
                }

                // total #states = sum over the HMMs of all units of this edge (matrix rows)
                foreach_index (k, aligntokens)
                    edgestates += hset.gethmm(aligntokens[k].unit).getnumstates();

                // allocate the matrix
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    countskip++;
                else
                    abcs[j] = &matrixheap.newmatrix(edgestates, edgeframes + 2); // +2 to have one extra column for betas and one for gammas
            }
        }
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: %d of %d edges pruned\n", (int) countskip, (int) edges.size());
    }

    // Phase 2: alignment on CPU
    // Only silence-containing edges are aligned here (hard alignment, senone ids requested via 'true').
    if (parallelstate.enabled() && !parallelsil) // silence edge shall be process separately if not cuda and not PARALLEL_SIL
    {
        if (softalignstates)
            LogicError("forwardbackwardalign: parallelized version currently only handles hard alignments");
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: pruning not supported (we won't need it!) :)\n");
        edgeacscores.resize(edges.size());
        for (size_t j = 0; j < edges.size(); j++)
        {
            const auto &aligntokens = getaligninfo(j); // get alignment tokens
            if (aligntokens.size() == 0)
                continue;
            bool edgehassil = false;
            foreach_index (i, aligntokens)
            {
                if (aligntokens[i].unit == silunitid)
                    edgehassil = true;
            }
            if (!edgehassil) // only process sil
                continue;
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            // view of the log-likelihood columns [ts, te) covered by this edge (const_cast only to build the stripe view)
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, true, thisedgealignments[j]);
        }
    }

    // Phase 3: alignment on GPU
    if (parallelstate.enabled())
        parallelforwardbackwardalign(parallelstate, hset, logLLs, edgeacscores, thisedgealignments, thisbackpointers);

    // zhaorui align to reference mlf
    // 'bounds' has one entry per frame: a non-zero value starts a new segment whose HMM id is value-1,
    // zeros continue the current segment. Each segment is aligned with its HMM and the
    // resulting per-frame ids overwrite uids[t].
    // NOTE(review): assumes bounds[0] != 0 (otherwise bounds[ts] - 1 underflows) and that
    // bounds.size() equals the number of frames in logLLs — confirm with the caller.
    if (bounds.size() > 0)
    {
        size_t framenum = bounds.size();

        msra::math::ssematrixbase *refabcs;
        size_t ts, te, t;
        ts = te = 0;

        vector<aligninfo> refinfo(1);
        vector<unsigned short> refalign(framenum);

        array_ref<aligninfo> refunits(refinfo.data(), 1);
        array_ref<unsigned short> refedgealignmentsj(refalign.data(), framenum);

        while (te < framenum)
        {
            // found one phone's boundary (ts, te)
            t = ts + 1;
            while (t < framenum && bounds[t] == 0)
                t++;
            te = t;

            // make one phone unit
            size_t phoneid = bounds[ts] - 1;
            refunits[0].unit = phoneid;
            refunits[0].frames = te - ts;

            size_t edgestates = hset.gethmm(phoneid).getnumstates();
            // NOTE(review): a fresh heap is created for every segment; it only ever holds one matrix
            littlematrixheap refmatrixheap(1); // for abcs
            refabcs = &refmatrixheap.newmatrix(edgestates, te - ts + 2);
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            // do alignment
            alignedge((const_array_ref<aligninfo>) refunits, hset, edgeLLs, *refabcs, 0, true, refedgealignmentsj);

            // copy the segment's alignment into the utterance-level uids (presumably senone ids, since returnsenoneids=true)
            for (t = ts; t < te; t++)
            {
                uids[t] = (size_t) refedgealignmentsj[t - ts];
            }
            ts = te;
        }
    }

    // Phase 4: alignment or forwardbackward on CPU for non parallel mode or verification

    if (!parallelstate.enabled() || cpuverification) // non parallel mode or verification
    {
        edgeacscores.resize(edges.size());
        std::vector<float> edgeacscoresgpu;
        // note: copied unconditionally, but only filled/used when cpuverification is on
        edgealignments thisedgealignmentsgpu(thisedgealignments);
        if (cpuverification)
        {
            parallelstate.getedgeacscores(edgeacscoresgpu);
            parallelstate.copyalignments(thisedgealignmentsgpu);
        }
        foreach_index (j, edges)
        {
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            if (ts == te) // dummy !NULL edge at end
                edgeacscores[j] = 0.0f;
            else
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    edgeacscores[j] = LOGZERO; // will kill word level forwardbackward hypothesis
                else if (softalignstates)
                    edgeacscores[j] = forwardbackwardedge(aligntokens, hset, edgeLLs, *abcs[j], j);
                else
                    edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, returnsenoneids, thisedgealignments[j]);
            }
            // verification: compare CPU results against the GPU results and report any mismatch
            if (cpuverification)
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                bool edgehassil = false;
                foreach_index (i, aligntokens)
                {
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;
                }
                if (fabs(edgeacscores[j] - edgeacscoresgpu[j]) > 1e-3)
                {
                    fprintf(stderr, "edge %d, sil ? %d, edgeacscores / edgeacscoresgpu MISMATCH %f v.s. %f, diff %e\n",
                            j, edgehassil ? 1 : 0, (float) edgeacscores[j], (float) edgeacscoresgpu[j],
                            (float) (edgeacscores[j] - edgeacscoresgpu[j]));
                    fprintf(stderr, "aligntokens: ");
                    foreach_index (i, aligntokens)
                        fprintf(stderr, "%d %d; ", i, aligntokens[i].unit);
                    fprintf(stderr, "\n");
                }
                for (size_t t = ts; t < te; t++)
                {
                    if (thisedgealignments[j][t - ts] != thisedgealignmentsgpu[j][t - ts])
                        fprintf(stderr, "edge %d, sil ? %d, time %d, alignment / alignmentgpu MISMATCH %d v.s. %d\n", j, edgehassil ? 1 : 0, (int) (t - ts), thisedgealignments[j][t - ts], thisedgealignmentsgpu[j][t - ts]);
                }
            }
        }
    }
}

// compute the error signal for sMBR mode
void lattice::sMBRerrorsignal(parallelstate &parallelstate,
                              msra::math::ssematrixbase &errorsignal, msra::math::ssematrixbase &errorsignalneg, // output
                              const std::vector<double> &logpps, const float amf,
                              double minlogpp, const std::vector<double> &origlogpps, const std::vector<double> &logEframescorrect,
                              const double logEframescorrecttotal, const edgealignments &thisedgealignments) const
{
    // Per frame t and aligned state s of every (non-pruned, non-dummy) edge j, adds
    //   PP(j) * (Eframescorrect[j] - Eframescorrecttotal) / amf
    // into errorsignal(s, t). The CPU path does not touch errorsignalneg.
    const bool pruning = minlogpp > LOGZERO; // a LOGZERO threshold disables pruning

    if (parallelstate.enabled())
    {
        // measured: errorsignalcompute 19.871935 ms (cuda) v.s. 448.711444 ms (emu)
        if (pruning)
            fprintf(stderr, "sMBRerrorsignal: pruning not supported (we won't need it!) :)\n");
        parallelsMBRerrorsignal(parallelstate, thisedgealignments, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
        return;
    }

    // CPU path: clear, then accumulate edge contributions
    foreach_coord (s, t, errorsignal)
        errorsignal(s, t) = 0.0f;

    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        const size_t ts = nodes[e.S].t;
        const size_t te = nodes[e.E].t;
        // skip the dummy !NULL edge at end of file (no frames) and pruned edges
        if (ts == te || (pruning && origlogpps[j] < minlogpp))
            continue;

        // with hard alignments every state of the edge gets the same contribution,
        // so it is computed once per edge
        const double pp = exp(logpps[j]); // edge posterior
        const double diff = logEframescorrect[j] - logEframescorrecttotal;
        const float edgecorrect = (float) (pp * diff) / amf;
        for (size_t t = ts; t < te; t++)
        {
            const size_t s = thisedgealignments[j][t - ts];
            errorsignal(s, t) += edgecorrect;
        }
    }
}

// compute the error signal for MMI mode
void lattice::mmierrorsignal(parallelstate &parallelstate, double minlogpp, const std::vector<double> &origlogpps,
                             std::vector<msra::math::ssematrixbase *> &abcs, const bool softalignstates,
                             const std::vector<double> &logpps, const msra::asr::simplesenonehmm &hset,
                             const edgealignments &thisedgealignments, msra::math::ssematrixbase &errorsignal) const
{
    // Accumulates state posteriors into errorsignal (row = senone id, column = utterance frame):
    //   P(s, t) = sum over edges e of P(e) * P(s at t | e),
    // from edge log posteriors logpps[] and per-edge log gammas abcs[]; the result is
    // converted from log to linear domain at the end.
    const bool pruning = minlogpp > LOGZERO; // a LOGZERO threshold disables pruning

    if (parallelstate.enabled())
    {
        if (pruning)
            fprintf(stderr, "mmierrorsignal: pruning not supported (we won't need it!) :)\n");
        if (softalignstates)
            LogicError("mmierrorsignal: parallel version for softalignstates mode is not supported yet");
        parallelmmierrorsignal(parallelstate, thisedgealignments, logpps, errorsignal);
        return;
    }

    // start from "virgin" log zero everywhere (note: may be in-place with logLLs, which get overwritten)
    foreach_coord (s, t, errorsignal)
        errorsignal(s, t) = VIRGINLOGZERO;

    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // dummy !NULL edge at end of file has no frames
            continue;
        if (pruning && origlogpps[j] < minlogpp) // pruned edge
            continue;

        const auto &aligntokens = getaligninfo(j);
        auto &loggammas = *abcs[j];                // [state row within edge, frame within edge]
        const float edgelogP = (float) logpps[j];
        const size_t edgestart = nodes[e.S].t;     // utterance frame where the edge begins

        // units have fixed boundaries: each one owns a contiguous frame range and row block
        size_t firstframe = 0;
        size_t firstrow = 0;
        foreach_index (k, aligntokens)
        {
            const auto &unit = aligntokens[k];
            const auto &hmm = hset.gethmm(unit.unit);
            const size_t numstates = hmm.getnumstates();
            const size_t endframe = firstframe + unit.frames;
            for (size_t t = firstframe; t < endframe; t++)
            {
                for (size_t i = 0; i < numstates; i++)
                {
                    const size_t s = hmm.getsenoneid(i);
                    const float statelogP = edgelogP + loggammas(firstrow + i, t); // log P(s|e) + log P(e)
                    logadd(errorsignal(s, t + edgestart), statelogP);
                }
            }
            firstframe = endframe;
            firstrow += numstates;
        }
        assert(firstframe + 2 == loggammas.cols() && firstrow == loggammas.rows());
    }

    // sanity check: each column should sum to 1 (log-sum ~ 0); also count entries that stayed zero
    size_t zerostates = 0;
    foreach_column (t, errorsignal)
    {
        double logsum = LOGZERO;
        foreach_row (s, errorsignal)
        {
            if (!islogzero(errorsignal(s, t)))
                logadd(logsum, (double) errorsignal(s, t));
            else
                zerostates++;
            // TODO: count VIRGINLOGZERO, print per frame
        }
        if (fabs(logsum) / errorsignal.rows() > 1e-6)
            fprintf(stderr, "forwardbackward: WARNING: overall posterior column(%d) sum = exp (%.10f) != 1\n", (int) t, logsum);
    }
    fprintf(stderr, "forwardbackward: %.3f%% non-zero state posteriors\n", 100.0f - zerostates * 100.0f / errorsignal.rows() / errorsignal.cols());

    // hand back linear-domain posteriors
    foreach_coord (s, t, errorsignal)
        errorsignal(s, t) = expf(errorsignal(s, t));
}

// compute ground truth's score
// It is critical to get all details consistent with the lattice, to avoid skewing the weights.
// ... TODO: we don't need this to be a class member, actually; try to just make it a 'static' function.
/*static*/ double lattice::scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigrams,
                                            const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset, const float lmf, const float wp, const float amf)
{
    // Scores the reference path given by the per-frame state ids 'uids' and the word
    // segmentation 'transcript' (each word covers [firstframe, next word's firstframe)):
    // HMM transition scores + acoustic log-likelihoods + per-word (unigram * lmf + wp),
    // all divided by amf at the end.
    // Throws LogicError on malformed input (empty transcript, unordered/out-of-range word
    // boundaries, too few unigram scores) instead of reading out of bounds.
    if (transcript.size() == 0)
        LogicError("scoregroundtruth: empty transcript[]");
    if (transcript[0].firstframe != 0) // TODO: should we store the #frames instead? Then we can validate the total duration
        LogicError("scoregroundtruth: first transcript[] token does not start at frame 0");
    if (transcriptunigrams.size() < transcript.size())
        LogicError("scoregroundtruth: fewer transcriptunigrams (%d) than transcript[] tokens (%d)", (int) transcriptunigrams.size(), (int) transcript.size());

    // get the silence models, since they are treated specially
    const size_t numframes = logLLs.cols();
    const auto &sil = hset.gethmm(hset.gethmmid("sil"));
    const auto &sp = hset.gethmm(hset.gethmmid("sp"));
    if (sp.numstates != 1 || sil.numstates != 3)
        RuntimeError("scoregroundtruth: only supports 1-state /sp/ and 3-state /sil/ tied to /sp/");
    const size_t silst = sp.senoneids[0];

    // loop over words
    double pathscore = 0.0;
    foreach_index (i, transcript)
    {
        size_t ts = transcript[i].firstframe;
        size_t te = ((size_t) i + 1 < transcript.size()) ? transcript[i + 1].firstframe : numframes;
        if (ts >= te)
            LogicError("scoregroundtruth: transcript[] tokens out of order");
        if (te > numframes) // would read past logLLs/uids otherwise
            LogicError("scoregroundtruth: transcript[] token extends beyond the last frame (%d > %d)", (int) te, (int) numframes);
        // acoustic score: loop over frames
        const msra::asr::simplesenonehmm::transP *prevtransP = NULL; // previous transP
        int prevs = -1;                                              // previous state index
        for (size_t t = ts; t < te; t++)
        {
            size_t senoneid = uids[t];
            // recover the transP and state index
            int s;
            const msra::asr::simplesenonehmm::transP *transP;
            if (senoneid == silst)
            {
                if (prevtransP == &sil.gettransP()) // "silst" may be tied to /sp/, and thus the /sil/ center state may be ambiguous
                {
                    transP = prevtransP; // remain in /sil/
                    s = 1;               // note that this will fail if /sp/ can follow /sil/ and if /sil/ may end with "silst" (both not allowed currently)
                }
                else // "silst" -> we are in /sp/
                {
                    transP = &sp.gettransP();
                    s = 0;
                }
            }
            else // all others must be non-ambiguous
            {
                int transPindex = hset.senonetransP(senoneid);
                int sindex = hset.senonestate(senoneid);
                if (transPindex == -1 || sindex == -1)
                    RuntimeError("scoregroundtruth: failed to resolve ambiguous senone %s", hset.getsenonename(senoneid));
                transP = &hset.transPs[transPindex];
                s = sindex;
            }
            // if changing phoneme then add necessary enter/exit score
            if (transP != prevtransP)
            {
                if (prevtransP) // previous exit transition
                    pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
                prevs = -1; // enter transition
            }
            // add inter-state transP
            pathscore += (*transP)(prevs, s);
            // acoustic score
            pathscore += logLLs(senoneid, t);
            // remember state for next frame
            prevtransP = transP;
            prevs = s;
        }
        // add the last exit transition
        if (prevtransP) // previous exit transition
            pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
        // need to add a /sp/ tee transition if we don't end in sp, since our transcript dictionary includes a /sp/ at the end of each entry
        if (uids[te - 1] != sp.senoneids[0])
            pathscore += sp.gettransP()(-1, sp.numstates);
        // lm
        pathscore += transcriptunigrams[i] * lmf + wp;
    }
    if (islogzero(pathscore))
        fprintf(stderr, "scoregroundtruth: ground-truth path has zero probability; some model inconsistency, maybe?\n");
    // account for amf
    pathscore /= amf;
    fprintf(stderr, "scoregroundtruth: ground-truth score %.6f (%d frames)\n", pathscore, (int) numframes);
    return pathscore;
}

// ---------------------------------------------------------------------------
// sMBRdiagnostics() -- helper to print some diagnostics for analyzing sMBR results
// ---------------------------------------------------------------------------

// static  // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
void sMBRdiagnostics(const msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids,
                     const_array_ref<size_t> bestpath, const vector<bool> &refseen, const msra::asr::simplesenonehmm &hset)
{
    // Prints, per frame, the error signal of the reference state, of the strongest negative
    // and strongest positive competitor, and of the best-path state, tagged with:
    //   ERR   - best-path state differs from the reference
    //   INV!! - reference signal is significantly negative
    //   WEAK  - reference signal is smaller than |strongest negative competitor|
    //   2ND   - reference signal is smaller than the strongest positive competitor
    //   NOREF - refseen[t] is false
    // TODO:
    //  - print best positive runner-up
    //  - WARN tag if neg > pos
    //  - check the sum and warn if not 0
    //  - indicate whether best path state is correct or not
    size_t numcor = 0;       // # frames where the best path matches the reference
    size_t numnegbetter = 0; // # frames the neg competitor is better
    size_t numposbetter = 0; // # frames the pos competitor is better
    foreach_column (t, errorsignal)
    {
        const size_t sref = uids[t];
        const char *srefname = hset.getsenonename(sref);
        const size_t sbest = bestpath[t];
        const char *sbestname = hset.getsenonename(sbest);
        numcor += (sref == sbest) ? 1 : 0;

        // find the most negative and the most positive competitor (excluding the reference)
        size_t sneg = SIZE_MAX;
        float eneg = 0.0f;
        size_t spos = SIZE_MAX;
        float epos = 0.0f;
        foreach_row (s, errorsignal)
        {
            if (s == sref)
                continue;
            const float value = errorsignal(s, t);
            if (value < eneg)
            {
                sneg = s;
                eneg = value;
            }
            if (value > epos)
            {
                spos = s;
                epos = value;
            }
        }

        const float eref = errorsignal(sref, t);
        const bool refsignificant = fabs(eref) > 0.0001f;
        const bool refinverted = refsignificant && eref < 0;
        const bool refweak = refsignificant && eref < -eneg;
        const bool refsecond = refsignificant && eref < epos;
        if (refweak)
            numnegbetter++;
        if (refsecond)
            numposbetter++;

        const char *snegname = (sneg != SIZE_MAX) ? hset.getsenonename(sneg) : "-";
        const char *sposname = (spos != SIZE_MAX) ? hset.getsenonename(spos) : "-";
        fprintf(stderr, "e(%d): ref %s: %.6f / %s: %.6f / %s: %.6f / top %s: %.6f%s%s%s%s%s\n",
                (int) t, srefname, eref, snegname, eneg, sposname, epos, sbestname, errorsignal(sbest, t),
                sbest == sref ? "" : " ERR",
                refinverted ? " INV!!" : "",
                refweak ? " WEAK" : "",
                refsecond ? " 2ND" : "",
                refseen[t] ? "" : " NOREF");
    }
    // summary; the correct-frame count also serves to validate the bestpath computation
    fprintf(stderr, "sMBRdiagnostics: %d frames correct out of %d, %.2f%%, neg better in %d, pos in %d\n",
            (int) numcor, (int) errorsignal.cols(), 100.0f * numcor / errorsignal.cols(),
            (int) numnegbetter, (int) numposbetter);
}

// static  // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
void sMBRsuppressweirdstuff(msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids)
{
    // Zeroes out ("flattens") every frame whose error signal looks implausible:
    // the reference state's signal is negative, or some other state's |signal|
    // exceeds the reference state's signal. Prints how many frames were flattened.
    size_t flattened = 0;
    foreach_column (t, errorsignal)
    {
        const size_t sref = uids[t];
        const float eref = errorsignal(sref, t);
        bool isweird = eref < 0.0f; // negative for reference!?
        for (size_t s = 0; !isweird && s < errorsignal.rows(); s++)
            isweird = (s != sref) && fabs(errorsignal(s, t)) > eref;

        if (isweird)
        {
            foreach_row (s, errorsignal)
                errorsignal(s, t) = 0.0f;
            flattened++;
        }
    }
    // report how many frames were flattened
    fprintf(stderr, "sMBRsuppressweirdstuff: %d weird frames out of %d, %.2f%% were flattened\n",
            (int) flattened, (int) errorsignal.cols(), 100.0f * flattened / errorsignal.cols());
}

// ---------------------------------------------------------------------------
// forwardbackward() -- main function for MMI/sMBR
//
// This computes the lattice state-level statistics for sequence training using MMI or sMBR.
// Processing runs in three phases:
//  1. per-edge forward-backward / time alignment (forwardbackwardalign)
//  2. lattice-level forward-backward (forwardbackwardlattice), or best path if soft alignment is disabled
//  3. conversion to state-level outputs (mmierrorsignal() in MMI mode, sMBRerrorsignal() in sMBR mode)
//
// Inputs:
//  - logLLs = per-frame acoustic scores; its number of columns must equal the lattice's frame count
//  - lmf, wp, amf, boostingfactor = scoring parameters forwarded to the lattice-level pass
//    (presumably LM factor, word penalty, acoustic scale; amf is reported as kappa = 1/amf)
//  - uids, bounds = per-frame information forwarded to the alignment/lattice passes;
//    uids.size() is expected to equal the frame count (mismatches are only warned about)
//  - transcript, transcriptunigrams = no longer used; kept for interface compatibility
//
// Outputs, MMI mode (sMBRmode == false):
//  - result = dengammas = denominator gammas (non-log form), filled by mmierrorsignal()
//  - errorsignalbuf is not used
//  - returns log of sum over all paths' likelihoods (the denominator of the MMI objective),
//    divided by the number of frames
//
// Outputs, sMBR mode:
//  - result = errorsignal, filled by sMBRerrorsignal()
//    (TODO(review): document the exact sign convention of the error signal)
//  - errorsignalbuf = temporary buffer used while computing errorsignal
//  - returns expected frames-correct count (the sMBR objective), divided by the number of frames
//
// Failure: if no path is found in the lattice, returns LOGZERO and 'result' must not be used.
// ---------------------------------------------------------------------------

double lattice::forwardbackward(parallelstate &parallelstate, const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset,
                                msra::math::ssematrixbase &result, msra::math::ssematrixbase &errorsignalbuf,
                                const float lmf, const float wp, const float amf, const float boostingfactor,
                                const bool sMBRmode, array_ref<size_t> uids, const_array_ref<size_t> bounds,
                                const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigrams) const
{
    bool softalign = true;
    bool softalignstates = false;      // true if soft alignment within edges, currently we only support soft within edge in cpu mode
    bool softalignlattice = softalign; // w.r.t. whole lattice

    edgealignments thisedgealignments(*this);   // alignments memory allocate for this lattice
    backpointers thisbackpointers(*this, hset); // memory for forwardbackward

    if (info.numframes != logLLs.cols())
        LogicError("forwardbackward: #frames mismatch between lattice (%d) and LLs (%d)", (int) info.numframes, (int) logLLs.cols());
    // TODO: the following checks should throw, but they only warn for now so that a critical job is not crashed;
    //       if this warning is never seen in practice, turn them into errors
    if (info.numframes != uids.size())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and uids (%d)\n", (int) info.numframes, (int) uids.size());
    if (info.numframes != result.cols())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and result (%d)\n", (int) info.numframes, (int) result.cols());

    littlematrixheap matrixheap(info.numedges); // for abcs

    // PHASE 0: fake word level forward backwards --only used when pruning enabled
    const double minlogpp = LOGZERO; // pruning threshold  --LOGZERO means disabled
    std::vector<double> origlogpps;  // word posterior from original lattice, for pruning decision

    // PHASE 1: per-edge forward backwards (="time alignments")

    // score the ground truth  --only if a transcript is provided, which happens if the user provides a language model
    // TODO: no longer used, remove this. 'transcript' parameter is no longer used in this function.
    (void) transcript;         // unused parameter, kept for interface compatibility
    (void) transcriptunigrams; // unused parameter, kept for interface compatibility

    // allocate alpha/beta/gamma matrices (all are sharing the same memory in-place)
    std::vector<msra::math::ssematrixbase *> abcs;
    std::vector<float> edgeacscores; // [edge index] acoustic scores
    // function call for forwardbackward on edge level
    forwardbackwardalign(parallelstate, hset, softalignstates, minlogpp, origlogpps, abcs, matrixheap, sMBRmode /*returnsenoneids*/, edgeacscores, logLLs, thisedgealignments, thisbackpointers, uids, bounds);

    // PHASE 2: lattice-level forward backward

    // we exploit that the lattice is sorted by (end node, start node) for in-place processing

    // checklattice();      // comment out by v-hansu to save time
#ifdef PRINT_TIME_MEASUREMENT
    auto_timer latlevelfwbw;
#endif
    std::vector<double> logpps;
    std::vector<double> Eframescorrectbuf; // this is used for compute the Eframescorrectdiff
    std::vector<double> logEframescorrect; // this is the final output of PHASE 2
    std::vector<double> logalphas;
    std::vector<double> logbetas;
    double totalfwscore = 0; // TODO: name no longer precise in sMBRmode
    double logEframescorrecttotal = LOGZERO;

    bool returnEframescorrect = sMBRmode;
    if (softalignlattice)
    {
        totalfwscore = forwardbackwardlattice(edgeacscores, parallelstate, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor, returnEframescorrect, (const_array_ref<size_t> &) uids, thisedgealignments, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
        if (sMBRmode && !returnEframescorrect)
            logEframescorrecttotal = forwardbackwardlatticesMBR(edgeacscores, hset, logalphas, logbetas, lmf, wp, amf, (const_array_ref<size_t> &) uids, thisedgealignments, Eframescorrectbuf);
        // ^^ BUGBUG not tested
    }
    else
        totalfwscore = bestpathlattice(edgeacscores, logpps, lmf, wp, amf);
#ifdef PRINT_TIME_MEASUREMENT
    latlevelfwbw.show("latlevelfwbw"); // 68.395682 ms
#endif
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }

    // PHASE 3: compute final state-level posteriors (MMI mode)

    // compute expected frames correct in sMBRmode

    const size_t numframes = logLLs.cols();
    assert(numframes == info.numframes);
    // fprintf (stderr, "forwardbackward: total forward score %.6f (%d frames)\n", totalfwscore, (int) numframes);   // for now--while we are debugging the GPU port

    // MMI mode
    if (!sMBRmode)
    {
        // we first take the sum in log domain to avoid numerical issues
        auto &dengammas = result; // result is denominator gammas
        mmierrorsignal(parallelstate, minlogpp, origlogpps, abcs, softalignstates, logpps, hset, thisedgealignments, dengammas);
        return totalfwscore / numframes; // return value is av. posterior
    }
    // sMBR mode
    else
    {
        auto &errorsignal = result;
        sMBRerrorsignal(parallelstate, errorsignal, errorsignalbuf, logpps, amf, minlogpp, origlogpps, logEframescorrect, logEframescorrecttotal, thisedgealignments);

        static bool dummyvariable = (fprintf(stderr, "note: new version with kappa adjustment, kappa = %.2f\n", 1 / amf), true); // we only print once
        (void) dummyvariable;                           // only the one-time initialization (the print) matters
        return exp(logEframescorrecttotal) / numframes; // return value is av. expected frame-correct count
    }
}
};
};
back to top