https://github.com/Microsoft/CNTK
Raw File
Tip revision: 1360e31b10664148665af21c5720825f75463df9 authored by Linquan Liu on 01 March 2019, 06:05:46 UTC
skip utterance with inconsistency </s>
Tip revision: 1360e31
latticeforwardbackward.cpp
// latticearchive.cpp -- managing lattice archives
//
// F. Seide, V-hansu

#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms  --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#endif

#include "Basics.h"
#include "simple_checked_arrays.h"
#include "latticearchive.h"
#include "simplesenonehmm.h" // the model
#include "ssematrix.h"       // the matrices
#include "latticestorage.h"
#include <unordered_map>
#include <list>
#include <stdexcept>
#include <regex>

using namespace std;

#define VIRGINLOGZERO (10 * LOGZERO) // used for printing statistics on unseen states
#undef CPU_VERIFICATION

#ifdef _WIN32
int msra::numa::node_override = -1; // for numahelpers.h
#endif

namespace msra { namespace lattices {

// ---------------------------------------------------------------------------
// helper class for allocation lots of small matrices, no free
// ---------------------------------------------------------------------------

class littlematrixheap
{
    static const size_t CHUNKSIZE;
    typedef msra::math::ssematrixfrombuffer matrixfrombuffer;
    std::list<std::vector<float>> heap;
    size_t allocatedinlast; // in last heap element
    size_t totalallocated;
    std::vector<matrixfrombuffer> matrices;

public:
    littlematrixheap(size_t estimatednumentries)
        : totalallocated(0), allocatedinlast(0)
    {
        matrices.reserve(estimatednumentries + 1);
    }
    msra::math::ssematrixbase &newmatrix(size_t rows, size_t cols)
    {
        const size_t elementsneeded = matrixfrombuffer::elementsneeded(rows, cols);
        if (heap.empty() || (heap.back().size() - allocatedinlast) < elementsneeded)
        {
            const size_t nelem = max(CHUNKSIZE, elementsneeded + 3 /*+3 for SSE alignment*/);
            heap.push_back(std::vector<float>(nelem));
            allocatedinlast = 0;
            // make sure starting element is SSE-aligned (the constructor demands that)
            const size_t offelem = (((size_t) &heap.back()[allocatedinlast]) / sizeof(float)) % 4;
            if (offelem != 0)
                allocatedinlast += 4 - offelem;
        }
        auto &buffer = heap.back();
        if (elementsneeded > heap.back().size() - allocatedinlast)
            LogicError("newmatrix: allocation logic screwed up");
        // get our buffer into a handy vector-like thingy
        array_ref<float> vecbuffer(&buffer[allocatedinlast], elementsneeded);
        // allocate in the current heap location
        matrices.resize(matrices.size() + 1);
        if (matrices.size() + 1 > matrices.capacity())
            LogicError("newmatrix: littlematrixheap cannot grow but was constructed with too small number of eements");
        auto &matrix = matrices.back();
        matrix = matrixfrombuffer(vecbuffer, rows, cols);
        allocatedinlast += elementsneeded;
        totalallocated += elementsneeded;
        return matrix;
    }
};

const size_t littlematrixheap::CHUNKSIZE = 256 * 1024; // 1 MB

// ---------------------------------------------------------------------------
// helpers for log-domain addition
// ---------------------------------------------------------------------------

#ifndef LOGZERO
#define LOGZERO -1e30f
#endif

// logadd (loga, logb) -> a += b, or loga = log [ exp(loga) + exp(logb) ]
static void logaddratio(float &loga, float diff)
{
    if (diff < -17.0f)
        return; // log (2^-24), 23-bit mantissa -> cut of after 24th bit
    loga += logf(1.0f + expf(diff));
}
static void logaddratio(double &loga, double diff)
{
    if (diff < -37.0f)
        return; // log (2^-53), 52-bit mantissa -> cut of after 53th bit
    loga += log(1.0 + exp(diff));
}
// loga <- log (exp (loga) + exp (logb)) = log (exp (loga) * (1.0 + exp (logb - loga)) = loga + log (1.0 + exp (logb - loga))
template <typename FLOAT>
static void logadd(FLOAT &loga, FLOAT logb)
{
    if (logb > loga) // we add smaller to bigger
        ::swap(loga, logb);
    if (loga <= LOGZERO) // both are 0
        return;
    logaddratio(loga, logb - loga);
}
template <typename FLOAT>
static void logmax(FLOAT &loga, FLOAT logb) // for testing (max approx)
{
    if (logb > loga)
        loga = logb;
}

template <typename FLOAT>
static FLOAT expdiff(FLOAT a, FLOAT b) // for testing
{
    if (b > a)
        return exp(b) * (exp(a - b) - 1);
    else
        return exp(a) * (1 - exp(b - a));
}

template <typename FLOAT>
static bool islogzero(FLOAT v)
{
    return v < LOGZERO / 2;
} // is this number to be considered 0

// ---------------------------------------------------------------------------
// other helpers go here
// ---------------------------------------------------------------------------

// helper to reconstruct the phonetic transcript
/*static*/ std::string lattice::gettranscript(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset)
{
    std::string trans;
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        if (k > 0)
            trans.push_back(' ');
        trans.append(hset.gethmm(units[k].unit).getname());
    }
    return trans;
}

// ---------------------------------------------------------------------------
// forwardbackwardedge() -- perform state-level forward-backward on a single lattice edge
//
// Results:
//  - gammas(j,t) for valid time ranges (remaining areas are not initialized)
//  - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------

/*static*/ float lattice::forwardbackwardedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                              msra::math::ssematrixbase &loggammas, size_t edgeindex)
{
    // alphas and betas are stored in-place inside the loggammas matrix shifted by one?two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logalphas(loggammas, 1, logLLs.cols()); // shifted views into gammas(,) for alphas and betas
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logbetas(loggammas, 2, logLLs.cols());

    // alphas(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
    // betas(j,t) store the sum of all paths exiting from state j at time t, not including logLL(j,t)
    // gammas(j,t) = alphas(j,t) * betas(j,t) / totalLL

    // backward pass   --token passing
    size_t te = logbetas.cols();
    size_t je = logbetas.rows();
    float bwscore = 0.0f; // backward score
    for (size_t k = units.size() - 1; k + 1 > 0; k--)
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t ts = te - units[k].frames; // end time of current unit
        const size_t js = je - n;               // range of state indices

        // pass in the transition score
        // t = ts: exit transition (last frame only or tee transition)
        float exitscore = 1e30f; // (something impossible)
        if (te == ts)            // tee transition
        {
            exitscore = bwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from; // origin trellis node
                logbetas(i, te - 1) = bwscore + transP(from, n);
            }
        }

        // expand from states j at time t (not yet including LL) to time t-1
        for (size_t t = te - 1; t + 1 > ts /*note: cannot test t >= ts because t < 0 possible*/; t--)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to;             // source trellis node
                const size_t s = hmm.getsenoneid(to); // senone id for state at position 'to' in the HMM
                const float acLL = logLLs(s, t);
                if (islogzero(acLL))
                    fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) ac score(%d,%d) is zero (%d st, %d fr: %s)\n",
                            (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te,
                            (int) s, (int) t,
                            (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
                const float betajt = logbetas(j, t);   // sum over all all path exiting from (j,t) to end
                const float betajtpll = betajt + acLL; // incorporate acoustic score
                if (t > ts)
                    for (size_t from = 0 /*no transition from entry state*/; from < n; from++)
                    {
                        const size_t i = js + from; // target trellis node
                        const float pathscore = betajtpll + transP(from, to);
                        if (to == 0)
                            logbetas(i, t - 1 /*propagate into preceding frame*/) = pathscore;
                        else
                            logadd(logbetas(i, t - 1 /*propagate into preceding frame*/), pathscore);
                    }
                else // transition to entry state
                {
                    const float pathscore = betajtpll + transP(-1, to);
                    if (to == 0)
                        exitscore = pathscore;
                    else
                        logadd(exitscore, pathscore); // propagate into preceding unit
                }
            }
        }

        bwscore = exitscore;
        if (islogzero(bwscore))
            fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) bw score is zero (%d st, %d fr: %s)\n",
                    (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te, (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());

        te = ts;
        je = js;
    }
    assert(te == 0 && je == 0);
    const float totalbwscore = bwscore;

    // forward pass   --regular Viterbi
    // This also computes the gammas right away.
    size_t ts = 0;           // start frame for unit 'k'
    size_t js = 0;           // first row index of unit ' k'
    float fwscore = 0.0f;    // score passed across phone boundaries
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te2 = ts + units[k].frames; // end time of current unit
        const size_t je2 = js + n;               // range of state indices

        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te2; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to; // target trellis node
                const size_t s = hmm.getsenoneid(to);
                const float acLL = logLLs(s, t);
                float alphajtnoll = LOGZERO;
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    alphajtnoll = pathscore;
                }
                else
                    for (size_t from = 0 /*no entering possible*/; from < n; from++)
                    {
                        const size_t i = js + from; // origin trellis node
                        const float alphaitm1 = logalphas(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        logadd(alphajtnoll, pathscore);
                    }
                logalphas(j, t) = alphajtnoll + acLL;
            }
            // update the gammas  --do it here because in next frame, betas get overwritten by alphas (they share memory)
            for (size_t j = js; j < je2; j++)
            {
                if (!islogzero(totalbwscore))
                    loggammas(j, t) = logalphas(j, t) + logbetas(j, t) - totalbwscore;
                else // 0/0 problem, can occur if an ac score is so bad that it is 0 after going through softmax
                    loggammas(j, t) = LOGZERO;
            }
        }
        // t = te2: exit transition (last frame only or tee transition)
        float exitscore;
        if (te2 == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            exitscore = LOGZERO;
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from;                   // origin trellis node
                const float alphaitm1 = logalphas(i, te2 - 1); // newly computed path score, transiting to t=te2
                const float pathscore = alphaitm1 + transP(from, n);
                logadd(exitscore, pathscore);
            }
        }
        fwscore = exitscore; // score passed on to next unit
        js = je2;
        ts = te2;
    }
    assert(js == logalphas.rows() && ts == logalphas.cols());
    const float totalfwscore = fwscore;

    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(totalbwscore) ^ islogzero(totalfwscore))
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw 0 score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
    if (islogzero(totalbwscore))
    {
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }

    if (fabsf(totalfwscore - totalbwscore) / ts > 1e-4f)
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());

    // we return the full path score
    return totalfwscore;
}

// ---------------------------------------------------------------------------
// alignedge() -- perform Viterbi alignment on a single edge
//
// This is an alternative to forwardbackwardedge() that just uses the best path.
// Results:
//  - if not returnsenoneids -> 'binary gammas(j,t)' for valid time ranges (remaining areas are not initialized); MMI-compatible
//  - if returnsenoneids ->  loggammas(0,t) will contain the senone ids directly instead (for sMBR mode)
//  - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------

/*static*/ float lattice::alignedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                    msra::math::ssematrixbase &loggammas, size_t edgeindex /*for diagnostic messages*/, const bool returnsenoneids,
                                    array_ref<unsigned short> thisedgealignmentsj)
{
    // alphas and betas are stored in-place inside the loggammas matrix shifted by one?two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> backpointers(loggammas, 0, logLLs.cols());
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> pathscores(loggammas, 2, logLLs.cols());

    // pathscores(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
    // backpointers(j,t) are the relative states that it came from
    // gammas(j,t) <- 1 if on best path, 0 otherwise

    const int invalidbp = -2;

    // Viterbi alignment
    size_t ts = 0;           // start frame for unit 'k'
    size_t js = 0;           // first row index of unit 'k'
    float fwscore = 0.0f;    // score passed across phone boundaries
    int fwbackpointer = -1;  // bp passed across phone boundaries, -1 means start of utterance
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te = ts + units[k].frames;    // end time of current unit
        const size_t je = js + hmm.getnumstates(); // range of state indices

        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t j = js; j < je; j++)
            {
                const size_t to = j - js; // relative state
                const size_t s = hmm.getsenoneid(to);
                pathscores(j, t) = LOGZERO;
                backpointers(j, t) = invalidbp;
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    pathscores(j, t) = pathscore;
                    backpointers(j, t) = (float) fwbackpointer;
                }
                else
                    for (size_t i = js; i < je; i++)
                    {
                        const size_t from = i - js;
                        const float alphaitm1 = pathscores(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        if (pathscore > pathscores(j, t))
                        {
                            pathscores(j, t) = pathscore;
                            backpointers(j, t) = (float) i;
                        }
                    }
                const float acLL = logLLs(s, t);
                pathscores(j, t) += acLL;
            }
        }
        // t = te: exit transition (last frame only or tee transition)
        float exitscore = LOGZERO;
        int exitbackpointer = invalidbp;
        if (te == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
            exitbackpointer = fwbackpointer;
        }
        else // not tee: expand all last states
        {
            for (size_t i = js; i < je; i++)
            {
                const size_t from = i - js;
                const float alphaitm1 = pathscores(i, te - 1); // newly computed path score, transiting to t=te
                const float pathscore = alphaitm1 + transP(from, n);
                if (pathscore > exitscore)
                {
                    exitscore = pathscore;
                    exitbackpointer = (int) i;
                }
            }
        }
        if (exitbackpointer == invalidbp)
            LogicError("exitbackpointer came up empty");
        fwscore = exitscore;             // score passed on to next unit
        fwbackpointer = exitbackpointer; // and accompanying backpointer
        js = je;
        ts = te;
    }
    assert(js == pathscores.rows() && ts == pathscores.cols());

    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(fwscore))
    {
        fprintf(stderr, "alignedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }

    // traceback & gamma update
    size_t te = backpointers.cols();
    size_t je = backpointers.rows();
    int j = fwbackpointer;
    for (size_t k = units.size() - 1; k + 1 > 0; k--) // go in units because we also need to clear out the column
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t ts2 = te - units[k].frames;    // end time of current unit
        const size_t js2 = je - hmm.getnumstates(); // range of state indices
        for (size_t t = te - 1; t + 1 > ts2; t--)
        {
            if (j < (int)js2 || j >= (int) je)
                LogicError("invalid backpointer resulting in state index out of range");

            int bp = (int) backpointers(j, t); // save the backpointer before overwriting it (gammas and backpointers are aliases of each other)
                                               // thisedgealignmentsj[t] = (unsigned short)hmm.getsenoneid(j - js2);
            if (!returnsenoneids)              // return binary gammas (for MMI; this mode is compatible with softalignmode)
                for (size_t i = js2; i < je; i++)
                    loggammas(i, t) = ((int) i == j) ? 0.0f : LOGZERO;
            else // return senone id (for sMBR; note: NOT compatible with softalignmode; calling code must know this)
                thisedgealignmentsj[t] = (unsigned short) hmm.getsenoneid(j - js2);

            if (bp == invalidbp)
                LogicError("deltabackpointer not initialized");
            j = bp; // trace back one step
        }

        te = ts2;
        je = js2;
    }
    if (j != -1)
        LogicError("invalid backpointer resulting in not reaching start of utterance when tracing back");
    assert(je == 0 && te == 0);

    // we return the full path score
    return fwscore;
}

// ---------------------------------------------------------------------------
// forwardbackwardlattice() -- lattice-level forward/backward
//
// This computes word posteriors, and also returns the per-node alphas and betas.
// Per-edge acoustic scores are passed in via a lambda, as this function is
// intended for use at multiple places with different scores.
// (Specifically, we also use it to determine a pruning threshold, based on
// the original lattice's ac. scores, before even bothering to compute the
// new ac. scores.)
// ---------------------------------------------------------------------------

double lattice::forwardbackwardlattice(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<double> &logpps,
                                       std::vector<double> &logalphas, std::vector<double> &logbetas,
                                       const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode,
                                       const_array_ref<size_t> &uids, const edgealignments &thisedgealignments,
                                       std::vector<double> &logEframescorrect, std::vector<double> &Eframescorrectbuf, double &logEframescorrecttotal) const
{ // ^^ TODO: remove this
    // --- hand off to parallelized (CUDA) implementation if available
    if (parallelstate.enabled())
    {
        double totalfwscore = parallelforwardbackwardlattice(parallelstate, edgeacscores, thisedgealignments, lmf, wp, amf, boostingfactor, logpps, logalphas, logbetas, sMBRmode, uids, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);

        parallelstate.getlogbetas(logbetas);
        if (nodes.size() != logbetas.size())
        {
            // it is possible if #define TWO_CHANNEL in parallelforwardbackward.cpp: in which case, logbetas will be doulbe the size of (nodes)
            if (logbetas.size() != (nodes.size() * 2))
            {
                RuntimeError("forwardbackwardlattice: logbetas size is not equal or twice of node size, logbetas.size() = %d, nodes.size() = %d", int(logbetas.size()), int(nodes.size()));
            }

            //only taket the first half of the data
            logbetas.erase(logbetas.begin() + nodes.size(), logbetas.begin() + logbetas.size());
        }
        return totalfwscore;
    }
    // if we get here, we have no CUDA, and do it the good ol' way

    // allocate return values
    logpps.resize(edges.size()); // this is our primary return value

    // TODO: these are return values as well, but really shouldn't anymore; only used in some older baseline code we some day may want to compare against
    logalphas.assign(nodes.size(), LOGZERO);
    logalphas.front() = 0.0f;
    logbetas.assign(nodes.size(), LOGZERO);
    logbetas.back() = 0.0f;

    // --- sMBR version

    if (sMBRmode)
    {
        logEframescorrect.resize(edges.size());
        Eframescorrectbuf.resize(edges.size());

        std::vector<double> logaccalphas(nodes.size(), LOGZERO); // [i] expected frames-correct count over all paths from start to node i
        std::vector<double> logaccbetas(nodes.size(), LOGZERO);  // [i] likewise
        std::vector<double> logframescorrectedge(edges.size());  // raw counts of correct frames in each edge

        // forward pass
        foreach_index (j, edges)
        {
            if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
                continue;
            const auto &e = edges[j];
            const double inscore = logalphas[e.S];
            const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
            const double pathscore = inscore + edgescore;
            logadd(logalphas[e.E], pathscore);

            size_t ts = nodes[e.S].t;
            size_t te = nodes[e.E].t;
            size_t framescorrect = 0; // count raw number of correct frames
            for (size_t t = ts; t < te; t++)
                framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
            logframescorrectedge[j] = (framescorrect > 0) ? log((double) framescorrect) : LOGZERO; // remember for backward pass
            double loginaccs = logaccalphas[e.S] - logalphas[e.S];
            logadd(loginaccs, logframescorrectedge[j]);
            double logpathacc = loginaccs + logalphas[e.S] + edgescore;
            logadd(logaccalphas[e.E], logpathacc);
        }
        foreach_index (j, logaccalphas)
            logaccalphas[j] -= logalphas[j];

        const double totalfwscore = logalphas.back();
        const double totalfwacc = logaccalphas.back();
        if (islogzero(totalfwscore))
        {
            fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
            return LOGZERO; // failed, do not use resulting matrix
        }

        // backward pass and computation of state-conditioned frames-correct count
        for (size_t j = edges.size() - 1; j + 1 > 0; j--)
        {
            if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
                continue;
            const auto &e = edges[j];
            const double inscore = logbetas[e.E];
            const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
            const double pathscore = inscore + edgescore;
            logadd(logbetas[e.S], pathscore);

            double loginaccs = logaccbetas[e.E] - logbetas[e.E];
            logadd(loginaccs, logframescorrectedge[j]);
            double logpathacc = loginaccs + logbetas[e.E] + edgescore;
            logadd(logaccbetas[e.S], logpathacc);

            // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
            double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
            if (logpp > 1e-2)
                fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
            if (logpp > 0.0)
                logpp = 0.0;
            logpps[j] = logpp;
            double tmplogeframecorrect = logframescorrectedge[j];
            logadd(tmplogeframecorrect, logaccalphas[e.S]);
            logadd(tmplogeframecorrect, logaccbetas[e.E] - logbetas[e.E]);
            Eframescorrectbuf[j] = exp(tmplogeframecorrect);
        }
        foreach_index (j, logaccbetas)
            logaccbetas[j] -= logbetas[j];
        const double totalbwscore = logbetas.front();
        const double totalbwacc = logaccbetas.front();
        if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
            fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());

        if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
            fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());

        logEframescorrecttotal = totalbwacc;
        return totalbwscore;
    }

    // --- MMI version

    // forward pass
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        const double inscore = logalphas[e.S];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        const double pathscore = inscore + edgescore;
        logadd(logalphas[e.E], pathscore);
    }
    const double totalfwscore = logalphas.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }

    // backward pass
    // this also computes the word posteriors on the fly, since we are at it
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        const auto &e = edges[j];
        const double inscore = logbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        const double pathscore = inscore + edgescore;
        logadd(logbetas[e.S], pathscore);

        // compute lattice posteriors on the fly since we are at it
        double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
        if (logpp > 1e-2)
            fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
        if (logpp > 0.0)
            logpp = 0.0;
        logpps[j] = logpp;
    }

    const double totalbwscore = logbetas.front();
    if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());

    return totalfwscore;
}

void lattice::constructnodenbestoken(std::vector<NBestToken> &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const
{
    std::map<double, std::vector<PrevTokenInfo>>::iterator mp_itr;
    std::map<uint64_t, std::vector<size_t>> mp_wid_tokenidx;
    std::map<uint64_t, std::vector<size_t>>::iterator mp_itr1;
    size_t count;
    bool done;
    TokenInfo tokeninfo;
    uint64_t wid;
    vector<size_t> vt_tokenidx;

    if (wordNbest) mp_wid_tokenidx.clear();

    count = 0;
    done = false;
    // Sometime,s numtokens is larger than numPathsEMBR. if </s>, keep tokens to be numPathsEMBR

    for (mp_itr = tokenlattice[nidx].mp_score_token_infos.begin(); mp_itr != tokenlattice[nidx].mp_score_token_infos.end(); mp_itr++)
    {        
        for (size_t i = 0; i < mp_itr->second.size(); i++)
        {
            tokeninfo.prev_edge_index = mp_itr->second[i].prev_edge_index;
            tokeninfo.prev_token_index = mp_itr->second[i].prev_token_index;
            tokeninfo.score = mp_itr->second[i].path_score;

            if (wordNbest)
            {
                wid = nodes[edges[tokeninfo.prev_edge_index].S].wid;
                mp_itr1 = mp_wid_tokenidx.find(wid);

                bool different = true;

                if (mp_itr1 == mp_wid_tokenidx.end())
                {
                    // the wid does not exist in previous tokens of this node, so it is a path with different word sequence
                    vt_tokenidx.clear();
                    vt_tokenidx.push_back(count);
                    mp_wid_tokenidx.insert(pair<uint64_t, std::vector<size_t>>(wid, vt_tokenidx));
                }
                else
                {
                    for (size_t j = 0; j < mp_itr1->second.size(); j++)
                    {

                        size_t oldnodeidx, oldtokenidx, newnodeidx, newtokenidx;

                        oldnodeidx = edges[tokenlattice[nidx].vt_nbest_tokens[mp_itr1->second[j]].prev_edge_index].S;
                        oldtokenidx = tokenlattice[nidx].vt_nbest_tokens[mp_itr1->second[j]].prev_token_index;
                        newnodeidx = edges[tokeninfo.prev_edge_index].S; newtokenidx = tokeninfo.prev_token_index;


                        while (1)
                        {
                            if (nodes[oldnodeidx].wid != nodes[newnodeidx].wid) break;
                            if (oldnodeidx == newnodeidx)
                            {
                                if (oldtokenidx == newtokenidx)  different = false;
                                break;
                            }

                            if (oldnodeidx == 0 || newnodeidx == 0)
                            {
                                fprintf(stderr, "nbestlatticeEMBR: WARNING: should not come her, oldnodeidx = %d, newnodeidx = %d\n", int(oldnodeidx), int(newnodeidx));
                                break;
                            }
                            size_t tmpnodeix, tmptokenidx;
                        

                            tmpnodeix = edges[tokenlattice[oldnodeidx].vt_nbest_tokens[oldtokenidx].prev_edge_index].S;
                            tmptokenidx = tokenlattice[oldnodeidx].vt_nbest_tokens[oldtokenidx].prev_token_index;
                            oldnodeidx = tmpnodeix; oldtokenidx = tmptokenidx;
                            

                            tmpnodeix = edges[tokenlattice[newnodeidx].vt_nbest_tokens[newtokenidx].prev_edge_index].S;
                            tmptokenidx = tokenlattice[newnodeidx].vt_nbest_tokens[newtokenidx].prev_token_index;
                            newnodeidx = tmpnodeix; newtokenidx = tmptokenidx;
                        }
                        if (!different) break;
                    }

                    if (different)
                    {
                        mp_itr1->second.push_back(count);
                    }
                }

                if (different)
                {
                    tokenlattice[nidx].vt_nbest_tokens.push_back(tokeninfo);
                    count++;
                }
            }
            else
            {
                tokenlattice[nidx].vt_nbest_tokens.push_back(tokeninfo);
                count++;
            }

            if (count >= numtokens2keep)
            {
                done = true;
                break;
            }
        }
        if (done) break;
    }

    // free the space.
    tokenlattice[nidx].mp_score_token_infos.clear();

}
float compute_wer(vector<size_t> &ref, vector<size_t> &rec)
{
    short ** mat;
    size_t i, j;

    mat = new short*[rec.size() + 1];
    for (i = 0; i <= rec.size(); i++) mat[i] = new short[ref.size() + 1];

    for (i = 0; i <= rec.size(); i++) mat[i][0] = short(i);
    for (j = 1; j <= ref.size(); j++) mat[0][j] = short(j);

    for (i = 1; i <= rec.size(); i++)
        for (j = 1; j <= ref.size(); j++)
        {
            mat[i][j] = mat[i - 1][j - 1];

            if (rec[i - 1] != ref[j - 1])
            {

                if ((mat[i - 1][j]) < mat[i][j]) mat[i][j] = mat[i - 1][j];
                if ((mat[i][j - 1]) < mat[i][j]) mat[i][j] = mat[i][j - 1];
                mat[i][j] ++;
            }
        }
    float wer = float(mat[rec.size()][ref.size()]) / ref.size();
    for (i = 0; i < rec.size(); i++) delete[] mat[i];
    delete[] mat;
    return wer;
}

//linquan's
std::vector<std::wstring> splitword2character(const std::wstring &s)
{
    /*std::vector<std::wstring> char_array;
    std::wstring tgt;
    for (wchar_t ch : s)
    {
    if (ch <= 0x9fa5 && ch >= 0x4e00)
    {
    if (tgt.length() > 0)
    {
    char_array.push_back(tgt);
    tgt.clear();
    }
    char_array.push_back(to_wstring(ch));
    }
    else if (ch == L' ')
    {
    if (tgt.length() > 0)
    {
    char_array.push_back(tgt);
    tgt.clear();
    }
    }
    else
    tgt.push_back(ch);
    }

    if (tgt.length() > 0)
    {
    char_array.push_back(tgt);
    }

    return char_array;*/

    std::wregex words_regex(L"([\u4e00-\u9fa5]|[^\u4e00-\u9fa5\\s]+)");
    auto words_begin = std::wsregex_iterator(s.begin(), s.end(), words_regex);
    auto words_end = std::wsregex_iterator();
    std::vector<std::wstring> tgt;

    for (std::wsregex_iterator it = words_begin; it != words_end; ++it)
    {
        std::wsmatch match = *it;
        std::wstring match_str = match.str();
        tgt.push_back(match_str);
    }

    return tgt;
}

bool istagword(const std::wstring &s)
{
    std::wregex words_regex(L"^[\\<\\[\\{].*?[\\>\\]\\}]$");
    std::wsmatch match;

    if (std::regex_match(s.cbegin(), s.cend(), match, words_regex))
        return true;

    return false;
}

float computewerandcer(std::vector<size_t> &wids, std::vector<size_t> &path_ids, const std::unordered_map<size_t, std::wstring> *ptr_id2wordmap4node)
{
    float wer;
    if (ptr_id2wordmap4node->size() > 0)
    {
        std::vector<std::wstring> refwords;
        std::vector<std::wstring> regwords;
        std::vector<size_t> refid;
        std::vector<size_t> regid;
        std::wstring temp_string;
        std::vector<std::wstring> character_array;
        std::unordered_map<std::wstring, size_t> idmappingtable;

        refwords.clear();
        regwords.clear();
        character_array.clear();
        refid.clear();
        regid.clear();
        idmappingtable.clear();
        std::unordered_map<size_t, std::wstring>::const_iterator maptable_itr;
        for (std::vector<size_t>::const_iterator it = wids.begin(); it != wids.end(); ++it)
        {
            maptable_itr = ptr_id2wordmap4node->find(*it);
            temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
            character_array = splitword2character(temp_string);

            foreach_index(_i, character_array)
            {
                refwords.push_back(character_array[_i]);
                if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
                {
                    idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
                }
            }
        }

        for (std::vector<size_t>::const_iterator it = path_ids.begin(); it != path_ids.end(); ++it)
        {
            maptable_itr = ptr_id2wordmap4node->find(*it);

            temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
            character_array = splitword2character(temp_string);

            foreach_index(_i, character_array)
            {
                regwords.push_back(character_array[_i]);
                if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
                {
                    idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
                }
            }
        }

        //map characters to id to be compatiable with egacy code
        //skip tag words
        foreach_index(_k, refwords)
        {
            if (!istagword(refwords[_k]))
                refid.push_back(idmappingtable.find(refwords[_k])->second);
        }

        foreach_index(_k, regwords)
        {
            if (!istagword(regwords[_k]))
                regid.push_back(idmappingtable.find(regwords[_k])->second);
        }

        wer = compute_wer(refid, regid);
    }
    else
    {
        wer = compute_wer(wids, path_ids);
    }

    return wer;
}

double lattice::nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<NBestToken> &tokenlattice, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords, 
    const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids) const
{ // ^^ TODO: remove this
  // --- hand off to parallelized (CUDA) implementation if available
    std::map<double, std::vector<PrevTokenInfo>>::iterator mp_itr;
    size_t numtokens2keep;
    // TODO: support parallel state
    parallelstate;
    PrevTokenInfo prevtokeninfo;
    std::vector<PrevTokenInfo> vt_prevtokeninfo;
    
    // if we get here, we have no CUDA, and do it the good ol' way

    // allocate return values
    tokenlattice.resize(nodes.size());
    
    tokenlattice[0].vt_nbest_tokens.resize(1);
    tokenlattice[0].vt_nbest_tokens[0].score = 0.0f;
    tokenlattice[0].vt_nbest_tokens[0].prev_edge_index = 0;
    tokenlattice[0].vt_nbest_tokens[0].prev_token_index = 0;
    // forward pass
    foreach_index(j, edges)
    {
        const auto &e = edges[j];
        if (enforceValidPathEMBR)
        {
            if (e.S == 0 && nodes[e.E].wid != 1) continue;
        }
        if (excludeSpecialWords)
        {
            // 0~4 is: !NULL, <s>, </s>, !sent_start, and !sent_end
            if (nodes[e.E].wid > 4)
            {
                if (is_special_words[e.E]) continue;                
            }
            if (nodes[e.S].wid > 4)
            {
                if (is_special_words[e.S]) continue;
            }

        }

        if (tokenlattice[e.S].mp_score_token_infos.size() != 0)
        {
            //sanity check
            if(tokenlattice[e.S].vt_nbest_tokens.size() != 0)
                RuntimeError("nbestlatticeEMBR: node = %d,  mp_score_token_infos.size() = %d, vt_nbest_tokens.size() = %d, both are not 0!", int(e.S), int(tokenlattice[e.S].mp_score_token_infos.size()), int(tokenlattice[e.S].vt_nbest_tokens.size()));
            
          
           
            // Sometime,s numtokens is larger than numPathsEMBR. if </s>, keep tokens to be numPathsEMBR

            if (nodes[e.S].wid == 2) numtokens2keep = numPathsEMBR;
            else numtokens2keep = numtokens;

            constructnodenbestoken(tokenlattice, wordNbest, numtokens2keep, e.S);

        }

        if (tokenlattice[e.S].vt_nbest_tokens.size() == 0)
        {
            // it is possible to happen, when you exclude specialwords
            continue;
        }
        prevtokeninfo.prev_edge_index = j;

        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned

        for (size_t i = 0; i < tokenlattice[e.S].vt_nbest_tokens.size(); i++)
        {
            prevtokeninfo.prev_token_index = i;

            double pathscore = tokenlattice[e.S].vt_nbest_tokens[i].score + edgescore;

            prevtokeninfo.path_score = pathscore;

            
            if (useAccInNbest && nodes[e.E].wid == 2)
            {
                // add the wegithed path Accuracy into path score

                std::vector<size_t> path, path_ids; // stores the edges in the path
               
                size_t curnodeidx, curtokenidx, prevtokenidx, prevnodeidx;
                // ignore the edge with ending node </s> in the path, as </s> will anyway not be used for WER computation
                path.clear(); // store the edge sequence of the path
                path_ids.clear(); // store the wid sequence of the path
                curnodeidx = e.S;
                curtokenidx = i;
                while (curnodeidx != 0)
                {
                    path.insert(path.begin(), tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index);

                    prevtokenidx = tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_token_index;
                    prevnodeidx = edges[tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index].S;

                    curnodeidx = prevnodeidx;
                    curtokenidx = prevtokenidx;
                 }
                 
                 
                 for (size_t k = 0; k < path.size(); k++)
                 {
                     if (k == 0)
                     {
                         if (!is_special_words[edges[path[k]].S]) path_ids.push_back(nodes[edges[path[k]].S].wid);
                     }
                     if (!is_special_words[edges[path[k]].E]) path_ids.push_back(nodes[edges[path[k]].E].wid);
                 }

                 
                 //linquan
                 //float wer = compute_wer(wids, path_ids);
                 float wer = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
                 // will favor the path with better WER
                 pathscore -= double(accWeightInNbest*wer);

                 // If you only want WER to affect the selection of Nbest, disable the below line. If you aslo want the WER as weight in error computation, enable this line
                 prevtokeninfo.path_score = pathscore;
            }
                
            mp_itr = tokenlattice[e.E].mp_score_token_infos.find(pathscore);
            if (mp_itr != tokenlattice[e.E].mp_score_token_infos.end())
            {
                mp_itr->second.push_back(prevtokeninfo);
            }
            else
            {
                vt_prevtokeninfo.clear();
                vt_prevtokeninfo.push_back(prevtokeninfo);
                tokenlattice[e.E].mp_score_token_infos.insert(std::pair<double, std::vector<PrevTokenInfo>>(pathscore, vt_prevtokeninfo));
            }

        }   
    }
    // for the last node, which is </s> or !NULL (!NULL if you do not merge numerator lattice into denominator lattice)
    numtokens2keep = numPathsEMBR;
    constructnodenbestoken(tokenlattice, wordNbest, numtokens2keep, tokenlattice.size() - 1);
    
    
    double bestscore;
    if (tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size() == 0)
    {
        if (!excludeSpecialWords)   RuntimeError("nbestlatticeEMBR: no token survive while excludeSpecialWords is false");
        else bestscore = LOGZERO;
        
    }
    else bestscore = tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[0].score;
    
    
    if (islogzero(bestscore))
    {
        
        fprintf(stderr, "nbestlatticeEMBR: WARNING: best score is logzero in lattice \n");
        return LOGZERO; // failed, do not use resulting matrix
    }

    
    return bestscore;
}

// ---------------------------------------------------------------------------
// backwardlatticeEMBR() -- lattice-level backward
//
// This computes per-node  betas for EMBR
// ---------------------------------------------------------------------------

double lattice::backwardlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<double> &edgelogbetas, std::vector<double> &logbetas,
    const float lmf, const float wp, const float amf) const
{ // ^^ TODO: remove this
  // --- hand off to parallelized (CUDA) implementation if available
    if (parallelstate.enabled())
    {
        double totalbwscore = parallelbackwardlatticeEMBR(parallelstate, edgeacscores, lmf, wp, amf, edgelogbetas, logbetas);
        
        parallelstate.getlogbetas(logbetas);
        parallelstate.getedgelogbetas(edgelogbetas);
        if (nodes.size() != logbetas.size())
        {
            // it is possible if #define TWO_CHANNEL in parallelforwardbackward.cpp: in which case, logbetas will be doulbe the size of (nodes)
            if (logbetas.size() != (nodes.size() * 2))
            {
                RuntimeError("forwardbackwardlattice: logbetas size is not equal or twice of node size, logbetas.size() = %d, nodes.size() = %d", int(logbetas.size()), int(nodes.size()));
            }

            //only taket the first half of the data
            logbetas.erase(logbetas.begin() + nodes.size(), logbetas.begin() + logbetas.size());
        }

        
        return totalbwscore;
    }
    // if we get here, we have no CUDA, and do it the good ol' way

    // allocate return values

    logbetas.assign(nodes.size(), LOGZERO);
    logbetas.back() = 0.0f;

    edgelogbetas.assign(edges.size(), LOGZERO);

    
    // backward pass
    // this also computes the word posteriors on the fly, since we are at it
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        const auto &e = edges[j];
        const double inscore = logbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        const double pathscore = inscore + edgescore;

        edgelogbetas[j] = pathscore;

        logadd(logbetas[e.S], pathscore);

    }

    const double totalbwscore = logbetas.front();

    if (islogzero(totalbwscore))
    {
        fprintf(stderr, "backwardlatticeEMBR: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int)nodes.size(), (int)edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }

    return totalbwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlatticesMBR() -- compute expected frame-accuracy counts,
// both the conditioned one (corresponding to c(q) in Dan Povey's thesis)
// and the global one (which is the sMBR criterion to optimize).
//
// Outputs:
//  - Eframescorrect[j] == expected frames-correct count conditioned on a state of edge[j].
//    We currently assume a hard state alignment. With that, the value turns out
//    to be identical for all states of an edge, so we only store it once per edge.
//  - return value: expected frames-correct count for entire lattice
//
// Call forwardbackwardlattices() first to compute logalphas/betas.
// ---------------------------------------------------------------------------

double lattice::forwardbackwardlatticesMBR(const std::vector<float> &edgeacscores, const msra::asr::simplesenonehmm &hset,
                                           const std::vector<double> &logalphas, const std::vector<double> &logbetas,
                                           const float lmf, const float wp, const float amf, const_array_ref<size_t> &uids,
                                           const edgealignments &thisedgealignments, std::vector<double> &Eframescorrect) const
{
    std::vector<double> accalphas(nodes.size(), 0);  // [i] expected frames-correct count over all paths from start to node i
    std::vector<double> accbetas(nodes.size(), 0);   // [i] likewise
    std::vector<size_t> maxcorrect(nodes.size(), 0); // [i] max correct frames up to this node (oracle)

    std::vector<double> framescorrectedge(edges.size()); // raw counts of correct frames in each edge

    std::vector<int> backpointersformaxcorr(nodes.size(), -2); // keep track of backpointer for the max corr
    backpointersformaxcorr.front() = -1;

    // forward pass
    foreach_index (j, edges)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accalphas[e.S];
        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;

        size_t framescorrect = 0; // count raw number of correct frames
        for (size_t t = ts; t < te; t++)
            framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
        framescorrectedge[j] = (double) framescorrect; // remember for backward pass

        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        // contribution to end node's path acc = start node's plus edge's correct count, weighted by LL, and divided by sum over LLs
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logalphas[e.S] + edgescore - logalphas[e.E]);
        accalphas[e.E] += pathacc;
        // also keep track of max accuracy, so we can find out whether the lattice contains the correct path
        size_t oracleframescorrect = maxcorrect[e.S] + framescorrect; // keep track of most correct path up to end of this edge
        if (oracleframescorrect > maxcorrect[e.E])
        {
            maxcorrect[e.E] = oracleframescorrect;
            backpointersformaxcorr[size_t(e.E)] = j;
        }
    }
    const double totalfwacc = accalphas.back();

    hset; // just for reference

    // report on ground-truth path
    // TODO: we will later have code that adds this path if needed
    size_t oracleframeacc = maxcorrect.back();
    if (oracleframeacc != info.numframes)
        fprintf(stderr, "forwardbackwardlatticesMBR: ground-truth path missing from lattice (most correct path: %d out of %d frames correct)\n", (unsigned int) oracleframeacc, (int) info.numframes);

    // backward pass and computation of state-conditioned frames-correct count
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logbetas[e.E] + edgescore - logbetas[e.S]);
        accbetas[e.S] += pathacc;

        // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
        Eframescorrect[j] = (float) (accalphas[e.S] + accbetas[e.E] + framescorrectedge[j]);
    }

    const double totalbwacc = accbetas.front();

    if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());

    return totalbwacc;
}

// ---------------------------------------------------------------------------
// bestpathlattice() -- lattice-level "forward/backward" that only returns the
// best path, but in the form of word posteriors, which are 1 or 0, just like
// a real lattice-level forward/backward would do.
// We don't really use this; this was only for a contrast experiment.
// ---------------------------------------------------------------------------

double lattice::bestpathlattice(const std::vector<float> &edgeacscores, std::vector<double> &logpps,
                                const float lmf, const float wp, const float amf) const
{
    // forward pass --sortnedness => regular Viterbi
    std::vector<double> logalphas(nodes.size(), LOGZERO);
    std::vector<int> backpointers(nodes.size(), -2);
    logalphas.front() = 0.0f;
    backpointers.front() = -1;
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        const double inscore = logalphas[e.S];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        const double pathscore = inscore + edgescore;
        if (pathscore > logalphas[e.E])
        {
            logalphas[e.E] = pathscore;
            backpointers[e.E] = j;
        }
    }

    const double totalfwscore = logalphas.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "bestpathlattice: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }

    // traceback
    // We encode the result by storing log 1 in edges on the best path, and log 0 else;
    // this makes it naturally compatible with softalign mode
    logpps.resize(edges.size());
    foreach_index (j, edges)
        logpps[j] = LOGZERO;

    int backpos = backpointers[nodes.size() - 1];
    while (backpos >= 0)
    {
        logpps[backpos] = 0.0f; // edge is on best path -> PP = 1.0
        backpos = backpointers[edges[backpos].S];
    }
    assert(backpos == -1);

    return totalfwscore;
}

// ---------------------------------------------------------------------------
// forwardbackwardalign() -- compute the statelevel gammas or viterbi alignments
// the first phase of lattice::forwardbackward
//
// Outputs:
// ---------------------------------------------------------------------------
void lattice::forwardbackwardalign(parallelstate &parallelstate,
                                   const msra::asr::simplesenonehmm &hset, const bool softalignstates,
                                   const double minlogpp, const std::vector<double> &origlogpps,
                                   std::vector<msra::math::ssematrixbase *> &abcs, littlematrixheap &matrixheap,
                                   const bool returnsenoneids,
                                   std::vector<float> &edgeacscores, const msra::math::ssematrixbase &logLLs,
                                   edgealignments &thisedgealignments, backpointers &thisbackpointers, array_ref<size_t> &uids, const_array_ref<size_t> bounds) const
{ // NOTE: this will be removed and replaced by a proper representation of alignments someday
    // do forward-backward or alignment on a per-edge basis. This gives us:
    //  - per-edge gamma[j,t] = P(s(t)==s_j|edge) if forwardbackward, per-edge alignment thisedgealignments[j] if alignment
    //  - per-edge acoustic scores
    const size_t silunitid = hset.gethmmid("sil"); // shall be the same as parallelstate.getsilunitid()
    bool parallelsil = true;
    bool cpuverification = false;

#ifndef PARALLEL_SIL // we use a define to make this marked
    parallelsil = false;
#endif
#ifdef CPU_VERIFICATION
    cpuverification = true;
#endif

    // Phase 1: abcs allocate
    if (!parallelstate.enabled() || !parallelsil || cpuverification) // allocate abcs when 1.parallelstate not enabled (cpu mode); 2. enabled but not PARALLEL_SIL (silence need to be allocate); 3. cpuverfication
    {
        abcs.resize(edges.size(), NULL); // [edge index] -> alpha/beta/gamma matrices for each edge
        size_t countskip = 0;            // if pruning: count how many edges are pruned

        foreach_index (j, edges)
        {
            // determine number of frames
            // TODO: this is not efficient--we only use a block-diagonal-like structure, rest is empty (exploiting the fixed boundaries)
            const size_t edgeframes = nodes[edges[j].E].t - nodes[edges[j].S].t;
            if (edgeframes == 0) // dummy !NULL edge at end of lattice
            {
                if ((size_t) j != edges.size() - 1)
                    RuntimeError("forwardbackwardalign: unxpected 0-frame edge (only allowed at very end)");
                // note: abcs[j] is already initialized to be NULL in this case, which protects us from accidentally using it
            }
            else
            {
                // determine the number of states in an edge
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                size_t edgestates = 0;

                bool edgehassil = false;
                foreach_index (i, aligntokens)
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;

                if (!cpuverification && !edgehassil && parallelstate.enabled()) // !cpuverification, parallel & is non sil, we do not allocate
                {
                    abcs[j] = NULL;
                    continue;
                }

                foreach_index (k, aligntokens)
                    edgestates += hset.gethmm(aligntokens[k].unit).getnumstates();

                // allocate the matrix
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    countskip++;
                else
                    abcs[j] = &matrixheap.newmatrix(edgestates, edgeframes + 2); // +2 to have one extra column for betas and one for gammas
            }
        }
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: %d of %d edges pruned\n", (int) countskip, (int) edges.size());
    }

    // Phase 2: alignment on CPU
    if (parallelstate.enabled() && !parallelsil) // silence edge shall be process separately if not cuda and not PARALLEL_SIL
    {
        if (softalignstates)
            LogicError("forwardbackwardalign: parallelized version currently only handles hard alignments");
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: pruning not supported (we won't need it!) :)\n");
        edgeacscores.resize(edges.size());
        for (size_t j = 0; j < edges.size(); j++)
        {
            const auto &aligntokens = getaligninfo(j); // get alignment tokens
            if (aligntokens.size() == 0)
                continue;
            bool edgehassil = false;
            foreach_index (i, aligntokens)
            {
                if (aligntokens[i].unit == silunitid)
                    edgehassil = true;
            }
            if (!edgehassil) // only process sil
                continue;
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, true, thisedgealignments[j]);
        }
    }

    // Phase 3: alignment on GPU
    if (parallelstate.enabled())
        parallelforwardbackwardalign(parallelstate, hset, logLLs, edgeacscores, thisedgealignments, thisbackpointers);

    // zhaorui align to reference mlf
    if (bounds.size() > 0)
    {
        size_t framenum = bounds.size();

        msra::math::ssematrixbase *refabcs;
        size_t ts, te, t;
        ts = te = 0;

        vector<aligninfo> refinfo(1);
        vector<unsigned short> refalign(framenum);

        array_ref<aligninfo> refunits(refinfo.data(), 1);
        array_ref<unsigned short> refedgealignmentsj(refalign.data(), framenum);

        while (te < framenum)
        {
            // found one phone's boundary (ts, te)
            t = ts + 1;
            while (t < framenum && bounds[t] == 0)
                t++;
            te = t;

            // make one phone unit
            size_t phoneid = bounds[ts] - 1;
            refunits[0].unit = phoneid;
            refunits[0].frames = te - ts;

            size_t edgestates = hset.gethmm(phoneid).getnumstates();
            littlematrixheap refmatrixheap(1); // for abcs
            refabcs = &refmatrixheap.newmatrix(edgestates, te - ts + 2);
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            // do alignment
            alignedge((const_array_ref<aligninfo>) refunits, hset, edgeLLs, *refabcs, 0, true, refedgealignmentsj);

            for (t = ts; t < te; t++)
            {
                uids[t] = (size_t) refedgealignmentsj[t - ts];
            }
            ts = te;
        }
    }

    // Phase 4: alignment or forwardbackward on CPU for non parallel mode or verification

    if (!parallelstate.enabled() || cpuverification) // non parallel mode or verification
    {
        edgeacscores.resize(edges.size());
        std::vector<float> edgeacscoresgpu;
        edgealignments thisedgealignmentsgpu(thisedgealignments);
        if (cpuverification)
        {
            parallelstate.getedgeacscores(edgeacscoresgpu);
            parallelstate.copyalignments(thisedgealignmentsgpu);
        }
        foreach_index (j, edges)
        {
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            if (ts == te) // dummy !NULL edge at end
                edgeacscores[j] = 0.0f;
            else
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    edgeacscores[j] = LOGZERO; // will kill word level forwardbackward hypothesis
                else if (softalignstates)
                    edgeacscores[j] = forwardbackwardedge(aligntokens, hset, edgeLLs, *abcs[j], j);
                else
                    edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, returnsenoneids, thisedgealignments[j]);
            }
            if (cpuverification)
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                bool edgehassil = false;
                foreach_index (i, aligntokens)
                {
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;
                }
                if (fabs(edgeacscores[j] - edgeacscoresgpu[j]) > 1e-3)
                {
                    fprintf(stderr, "edge %d, sil ? %d, edgeacscores / edgeacscoresgpu MISMATCH %f v.s. %f, diff %e\n",
                            j, edgehassil ? 1 : 0, (float) edgeacscores[j], (float) edgeacscoresgpu[j],
                            (float) (edgeacscores[j] - edgeacscoresgpu[j]));
                    fprintf(stderr, "aligntokens: ");
                    foreach_index (i, aligntokens)
                        fprintf(stderr, "%d %d; ", i, aligntokens[i].unit);
                    fprintf(stderr, "\n");
                }
                for (size_t t = ts; t < te; t++)
                {
                    if (thisedgealignments[j][t - ts] != thisedgealignmentsgpu[j][t - ts])
                        fprintf(stderr, "edge %d, sil ? %d, time %d, alignment / alignmentgpu MISMATCH %d v.s. %d\n", j, edgehassil ? 1 : 0, (int) (t - ts), thisedgealignments[j][t - ts], thisedgealignmentsgpu[j][t - ts]);
                }
            }
        }
    }

    // make sure thisedgealignment has values for later CPU use
    if (parallelstate.enabled())
    {
        parallelstate.copyalignments(thisedgealignments);
        parallelstate.getedgeacscores(edgeacscores);
    }
}

// compute the error signal for sMBR mode
void lattice::sMBRerrorsignal(parallelstate &parallelstate,
                              msra::math::ssematrixbase &errorsignal, msra::math::ssematrixbase &errorsignalneg, // output
                              const std::vector<double> &logpps, const float amf,
                              double minlogpp, const std::vector<double> &origlogpps, const std::vector<double> &logEframescorrect,
                              const double logEframescorrecttotal, const edgealignments &thisedgealignments) const
{
    if (parallelstate.enabled()) // parallel version
    {
        /*  time measurement for parallel sMBRerrorsignal
            errorsignalcompute: 19.871935 ms (cuda) v.s. 448.711444 ms (emu) */
        if (minlogpp > LOGZERO)
            fprintf(stderr, "sMBRerrorsignal: pruning not supported (we won't need it!) :)\n");
        parallelsMBRerrorsignal(parallelstate, thisedgealignments, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
        return;
    }

    //  linear mode
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = 0.0f; // Note: we don't actually put anything into the numgammas
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file
            continue;
        if (minlogpp > LOGZERO && origlogpps[j] < minlogpp) // this is pruned
            continue;

        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;

        const double diff = logEframescorrect[j] - logEframescorrecttotal;
        // Note: the contribution of the states of an edge to their senones is the same for all states
        // so we compute it once and add it to all; this will not be the case without hard alignments.
        const double pp = exp(logpps[j]); // edge posterior
        const float edgecorrect = (float) (pp * diff) / amf;
        for (size_t t = ts; t < te; t++)
        {
            const size_t s = thisedgealignments[j][t - ts];
            errorsignal(s, t) += edgecorrect;
        }
    }
}

// compute the error signal for sMBR mode
size_t sample_from_cumulative_prob(const std::vector<double> &cumulative_prob)
{
    if (cumulative_prob.size() < 1)
    {
        RuntimeError("sample_from_cumulative_prob: the number of bins is 0 \n");
    }
    double rand_prob = (double)rand() / (double)RAND_MAX * cumulative_prob.back();
    for (size_t i = 0; i < cumulative_prob.size() - 1; i++)
    {
        if (rand_prob <= cumulative_prob[i]) return i;
    }
    return cumulative_prob.size() - 1;
}

void lattice::EMBRsamplepaths(const std::vector<double> &edgelogbetas,
    const std::vector<double> &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords,  std::vector<vector<size_t>> & vt_paths) const
{
    // In mp_node_ocp, key is the node id, and value stores the outgoing cumulative locally normalized probability. e.g., if the outgoing probabilities of the node are 0.3 0.1 0.6, the ocp stores: 0.3 0.4 1.0. 
    // This serves as a cache to avoid recomputation if sampling the same node twice
    std::map<size_t, vector<double>> mp_node_ocp; 
    std::map<size_t, vector<double>>::iterator mp_itr;
    std::vector<size_t> path; // stores the edges in the path
    std::vector<double> ocp;

    mp_node_ocp.clear();
    vt_paths.clear();
    size_t curnodeidx, edgeidx; 
    if(enforceValidPathEMBR)
    {
        for (size_t i = 0; i < vt_node_out_edge_indices[0].size(); i++)
        {
            // remove the edge
            if (nodes[edges[vt_node_out_edge_indices[0][i]].E].wid != 1)    lattice::erase_node_out_edges(0, i, i);
        }
    
    }

    // this is inefficent implementation, we should think of efficient ways to do it later
    if (excludeSpecialWords)
    {
        size_t nidx;
        for(size_t j = 0; j < vt_node_out_edge_indices.size(); j++)
        {
            for (size_t i = 0; i < vt_node_out_edge_indices[j].size(); i++)
            {
                // remove the edge
                // 0~4 is: !NULL, <s>, </s>, !sent_start, and !sent_end
                nidx = edges[vt_node_out_edge_indices[j][i]].E;

                if (nodes[nidx].wid > 4)
                {
                    if (is_special_words[nidx])
                    {
                        lattice::erase_node_out_edges(j, i, i);
                        continue;
                    }
                }

                nidx = edges[vt_node_out_edge_indices[j][i]].S;

                if (nodes[nidx].wid > 4)
                {
                    if (is_special_words[nidx])  lattice::erase_node_out_edges(j, i, i);
                }
            }
        }
    }
    while (vt_paths.size() < numPathsEMBR)
    {
        path.clear(); 
        curnodeidx = 0;
        //start sampling from node 0
        bool success = false;

        while(true)
        {
            mp_itr = mp_node_ocp.find(curnodeidx);
            if (mp_itr == mp_node_ocp.end())
            {
                ocp.clear();
                
                for (size_t i = 0; i < vt_node_out_edge_indices[curnodeidx].size(); i++)
                {
                    double prob = exp(edgelogbetas[vt_node_out_edge_indices[curnodeidx][i]] - logbetas[curnodeidx]);
                    if(i == 0)    ocp.push_back(prob);
                    else ocp.push_back(prob + ocp.back());
                } 
                mp_node_ocp.insert(pair<size_t, vector<double>>(curnodeidx, ocp));
                edgeidx = vt_node_out_edge_indices[curnodeidx][sample_from_cumulative_prob(ocp)];

            }
            else
            {
                edgeidx = vt_node_out_edge_indices[curnodeidx][sample_from_cumulative_prob(mp_itr->second)];
            }

            path.push_back(edgeidx);
            curnodeidx = edges[edgeidx].E;
            // the end of lattice is not !NULL (the end of !NULL is deleted in dbn.exe when converting lattice of htk format to chunk)
            // if (nodes[edges[edgeidx].E].t == nodes[edges[edgeidx].S].t)
            // wid = 2 is for </s>, the lattice ends with </s>
            // the node has no outgoing arc
            if (vt_node_out_edge_indices[curnodeidx].size() == 0)
            {
                if ( (nodes[curnodeidx].wid == 2 || nodes[curnodeidx].wid == 0) && nodes[curnodeidx].t == info.numframes)
                {
                    success = true;
                    break;
                }
                else
                {
                    fprintf(stderr, "EMBRsamplepaths: WARNING: the node with index = %d has no outgoing arc, but it is not the node </s> with timing ending with last frame \n", int(curnodeidx));
                    success = false;
                    break;
                }
            }                
        }
        if (success == true)    vt_paths.push_back(path);
    }
    if (vt_paths.size() != numPathsEMBR)
    {
        fprintf(stderr, "EMBRsamplepaths: Error: vt_paths.size() = %d, and  numPathsEMBR = %d \n", int(vt_paths.size()), int(numPathsEMBR));
        exit(-1);
    }
}

void lattice::EMBRnbestpaths(std::vector<NBestToken>& tokenlattice, std::vector<vector<size_t>> & vt_paths, std::vector<double>& path_posterior_probs) const
{

    double log_nbest_posterior_prob;

    path_posterior_probs.resize(tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size());
    log_nbest_posterior_prob = LOGZERO;
    
    for (size_t i = 0; i < tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size(); i++)
    {
        logadd(log_nbest_posterior_prob, tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[i].score);
    }
    for (size_t i = 0; i < tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size(); i++)
    {
        path_posterior_probs[i] = exp(tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[i].score - log_nbest_posterior_prob);
    }
    std::vector<size_t> path; // stores the edges in the path
    vt_paths.clear();
    size_t curnodeidx, curtokenidx, prevtokenidx, prevnodeidx;

    for (size_t i = 0; i < tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size(); i++)
    {
        path.clear();
        curnodeidx = tokenlattice.size() - 1;
        curtokenidx = i;
        while (curnodeidx != 0)
        {
            path.insert(path.begin(), tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index);

            prevtokenidx = tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_token_index;
            prevnodeidx = edges[tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index].S;

            curnodeidx = prevnodeidx;
            curtokenidx = prevtokenidx;

        }
        vt_paths.push_back(path);
    }
}
double lattice::get_edge_weights(std::vector<size_t>& wids, std::vector<std::vector<size_t>>& vt_paths, std::vector<double>& vt_edge_weights, std::vector<double>& vt_path_posterior_probs, string getPathMethodEMBR, double& onebest_wer) const
{
   
    struct PATHINFO
    {
        size_t count;
        float WER;
    };

    std::map<string, PATHINFO> mp_path_info;
    std::map<string, PATHINFO>::iterator mp_itr;
    std::unordered_set<string> set_edge_path;

    std::vector<double> vt_path_weights;
    vt_path_weights.resize(vt_paths.size());
    

    vector<size_t> path_ids;
    double avg_wer;
    avg_wer = 0;

    for (size_t i = 0; i < vt_paths.size(); i++)
    {
        path_ids.clear();

        for (size_t j = 0; j < vt_paths[i].size(); j++)
        {
            if (j == 0)
            {
                if (!is_special_words[edges[vt_paths[i][j]].S]) path_ids.push_back(nodes[edges[vt_paths[i][j]].S].wid);

                nodes[edges[vt_paths[i][j]].S].wid;
            }
            if (!is_special_words[edges[vt_paths[i][j]].E]) path_ids.push_back(nodes[edges[vt_paths[i][j]].E].wid);
            nodes[edges[vt_paths[i][j]].E].wid;
        }
                
        //linquan
        //vt_path_weights[i] = compute_wer(wids, path_ids);
        vt_path_weights[i] = computewerandcer(wids, path_ids, ptr_id2wordmap4node);

        string pathidstr = "$";
        for (size_t j = 0; j < path_ids.size(); j++) pathidstr += ("_" + std::to_string(path_ids[j]));
        mp_itr = mp_path_info.find(pathidstr);
        if (mp_itr != mp_path_info.end())
        {
            mp_itr->second.count++;
        }
        else
        {
            PATHINFO pathinfo;
            pathinfo.count = 1;
            pathinfo.WER = float(vt_path_weights[i]);
            mp_path_info.insert(pair<string, PATHINFO>(pathidstr, pathinfo));
        }

        // this uses weighted avg wer
        avg_wer += (vt_path_weights[i] * vt_path_posterior_probs[i]);

    }
    if (getPathMethodEMBR == "sampling") onebest_wer = -10000;
    else onebest_wer = vt_path_weights[0];
    
    for (size_t i = 0; i < vt_path_weights.size(); i++)
    {
        // loss - mean_loss
        vt_path_weights[i] -= avg_wer;
        if(getPathMethodEMBR == "sampling") vt_path_weights[i] /= (vt_paths.size() - 1);
        else vt_path_weights[i] *= (vt_path_posterior_probs[i]);        
    }


    for (size_t i = 0; i < vt_paths.size(); i++)
    {       
        for (size_t j = 0; j < vt_paths[i].size(); j++)
            // substraction, since we want to minimize the loss function, rather than maximize
            vt_edge_weights[vt_paths[i][j]] -= vt_path_weights[i];
    }
    
    set_edge_path.clear();

    for (size_t i = 0; i < vt_paths.size(); i++)
    {
        string pathedgeidstr = "$";
        for (size_t j = 0; j < vt_paths[i].size(); j++)
        {
            pathedgeidstr += ("_" + std::to_string(vt_paths[i][j]));

        }
        set_edge_path.insert(pathedgeidstr);
    }
    return avg_wer;
}
void lattice::EMBRerrorsignal(parallelstate &parallelstate,
    const edgealignments &thisedgealignments, std::vector<double>& edge_weights, msra::math::ssematrixbase &errorsignal) const

{
    Microsoft::MSR::CNTK::Matrix<float> errorsignalcpu(-1);
    if (parallelstate.enabled()) // parallel version
    {
        parallelstate.setedgeweights(edge_weights);        
        std::vector<double> verify_edge_weights;
        parallelstate.getedgeweights(verify_edge_weights);        
        parallelEMBRerrorsignal(parallelstate, thisedgealignments, edge_weights, errorsignal);
        parallelstate.getgamma(errorsignalcpu);
        return; 
        
    }
    //  linear mode
    foreach_coord(i, j, errorsignal)
        errorsignal(i, j) = 0.0f; // Note: we don't actually put anything into the numgammas

    foreach_index(j, edges)
    {

        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file
            continue;


        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;

        for (size_t t = ts; t < te; t++)
        {
            const size_t s = thisedgealignments[j][t - ts];
            errorsignal(s, t) = errorsignal(s, t) + float(edge_weights[j]);
        }
    }
}

// compute the error signal for MMI mode
void lattice::mmierrorsignal(parallelstate &parallelstate, double minlogpp, const std::vector<double> &origlogpps,
                             std::vector<msra::math::ssematrixbase *> &abcs, const bool softalignstates,
                             const std::vector<double> &logpps, const msra::asr::simplesenonehmm &hset,
                             const edgealignments &thisedgealignments, msra::math::ssematrixbase &errorsignal) const
{
    if (parallelstate.enabled())
    {
        if (minlogpp > LOGZERO)
            fprintf(stderr, "mmierrorsignal: pruning not supported (we won't need it!) :)\n");
        if (softalignstates)
            LogicError("mmierrorsignal: parallel version for softalignstates mode is not supported yet");
        parallelmmierrorsignal(parallelstate, thisedgealignments, logpps, errorsignal);
        return;
    }

    for (size_t j = 0; j < (errorsignal).cols(); j++)
        for (size_t i = 0; i < (errorsignal).rows(); i++)
            errorsignal(i, j) = VIRGINLOGZERO; // set to zero  --note: may be in-place with logLLs, which now get overwritten

    // size_t warnings = 0;   // [v-hansu] check code for mmi; search this comment to see all related codes
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file
            continue;
        if (minlogpp > LOGZERO && origlogpps[j] < minlogpp) // this is pruned
            continue;

        const auto &aligntokens = getaligninfo(j); // get alignment tokens
        auto &loggammas = *abcs[j];

        const float edgelogP = (float) logpps[j];
        // if (islogzero (edgelogP))               // we had a 0 prob
        //    continue;

        // accumulate this edge's gamma matrix into target posteriors
        const size_t tedge = nodes[e.S].t;
        size_t ts = 0;                 // time index into gamma matrix
        size_t js = 0;                 // state index into gamma matrix
        foreach_index (k, aligntokens) // we exploit that units have fixed boundaries
        {
            const auto &unit = aligntokens[k];
            const size_t te = ts + unit.frames;
            const auto &hmm = hset.gethmm(unit.unit); // TODO: inline these expressions
            const size_t n = hmm.getnumstates();
            const size_t je = js + n;
            // P(s) = P(s|e) * P(e)
            for (size_t t = ts; t < te; t++)
            {
                const size_t tutt = t + tedge; // time index w.r.t. utterance
                // double logsum = LOGZERO;         // [v-hansu] check code for mmi; search this comment to see all related codes
                for (size_t i = 0; i < n; i++)
                {
                    const size_t j2 = js + i;             // state index for this unit in matrix
                    const size_t s = hmm.getsenoneid(i); // state class index
                    const float gammajt = loggammas(j2, t);
                    const float statelogP = edgelogP + gammajt;
                    logadd(errorsignal(s, tutt), statelogP);
                }
            }
            ts = te;
            js = je;
        }
        assert(ts + 2 == loggammas.cols() && js == loggammas.rows());
    }

    // check normalizedness (is that an actual English word?)
    // also count non-zero probs
    size_t nonzerostates = 0;
    foreach_column (t, errorsignal)
    {
        double logsum = LOGZERO;
        foreach_row (s, errorsignal)
        {
            if (islogzero(errorsignal(s, t)))
                nonzerostates++;
            else
                logadd(logsum, (double) errorsignal(s, t));
            // TODO: count VIRGINLOGZERO, print per frame
        }
        if (fabs(logsum) / errorsignal.rows() > 1e-6)
            fprintf(stderr, "forwardbackward: WARNING: overall posterior column(%d) sum = exp (%.10f) != 1\n", (int) t, logsum);
    }
    fprintf(stderr, "forwardbackward: %.3f%% non-zero state posteriors\n", 100.0f - nonzerostates * 100.0f / errorsignal.rows() / errorsignal.cols());

    // convert to non-log posterior  --that's what we return
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = expf(errorsignal(i, j));
}

// compute ground truth's score
// It is critical to get all details consistent with the lattice, to avoid skewing the weights.
// ... TODO: we don't need this to be a class member, actually; try to just make it a 'static' function.
/*static*/ double lattice::scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigrams,
                                            const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset, const float lmf, const float wp, const float amf)
{
    if (transcript[0].firstframe != 0) // TODO: should we store the #frames instead? Then we can validate the total duration
        LogicError("scoregroundtruth: first transcript[] token does not start at frame 0");

    // get the silence models, since they are treated specially
    const size_t numframes = logLLs.cols();
    const auto &sil = hset.gethmm(hset.gethmmid("sil"));
    const auto &sp = hset.gethmm(hset.gethmmid("sp"));
    if (sp.numstates != 1 || sil.numstates != 3)
        RuntimeError("scoregroundtruth: only supports 1-state /sp/ and 3-state /sil/ tied to /sp/");
    const size_t silst = sp.senoneids[0];

    // loop over words
    double pathscore = 0.0;
    foreach_index (i, transcript)
    {
        size_t ts = transcript[i].firstframe;
        size_t te = ((size_t) i + 1 < transcript.size()) ? transcript[i + 1].firstframe : numframes;
        if (ts >= te)
            LogicError("scoregroundtruth: transcript[] tokens out of order");
        // acoustic score: loop over frames
        const msra::asr::simplesenonehmm::transP *prevtransP = NULL; // previous transP
        int prevs = -1;                                              // previous state index
        for (size_t t = ts; t < te; t++)
        {
            size_t senoneid = uids[t];
            // recover the transP and state index
            int s;
            const msra::asr::simplesenonehmm::transP *transP;
            if (senoneid == silst)
            {
                if (prevtransP == &sil.gettransP()) // "silst" may be tied to /sp/, and thus the /sil/ center state may be ambiguous
                {
                    transP = prevtransP; // remain in /sil/
                    s = 1;               // note that this will fail if /sp/ can follow /sil/ and if /sil/ may end with "silst" (both not allowed currently)
                }
                else // "silst" -> we are in /sp/
                {
                    transP = &sp.gettransP();
                    s = 0;
                }
            }
            else // all others must be non-ambiguous
            {
                int transPindex = hset.senonetransP(senoneid);
                int sindex = hset.senonestate(senoneid);
                if (transPindex == -1 || sindex == -1)
                    RuntimeError("scoregroundtruth: failed to resolve ambiguous senone %s", hset.getsenonename(senoneid));
                transP = &hset.transPs[transPindex];
                s = sindex;
            }
            // if changing phoneme then add necessary enter/exit score
            if (transP != prevtransP)
            {
                if (prevtransP) // previous exit transition
                    pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
                prevs = -1; // enter transition
            }
            // add inter-state transP
            pathscore += (*transP)(prevs, s);
            // acoustic score
            pathscore += logLLs(senoneid, t);
            // remember state for next frame
            prevtransP = transP;
            prevs = s;
        }
        // add the last exit transition
        if (prevtransP) // previous exit transition
            pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
        // need to add a /sp/ tee transition if we don't end in sp, since our transcript dictionary includes a /sp/ at the end of each entry
        if (uids[te - 1] != sp.senoneids[0])
            pathscore += sp.gettransP()(-1, sp.numstates);
        // lm
        pathscore += transcriptunigrams[i] * lmf + wp;
    }
    if (islogzero(pathscore))
        fprintf(stderr, "scoregroundtruth: ground-truth path has zero probability; some model inconsistency, maybe?\n");
    // account for amf
    pathscore /= amf;
    fprintf(stderr, "scoregroundtruth: ground-truth score %.6f (%d frames)\n", pathscore, (int) numframes);
    return pathscore;
}

// ---------------------------------------------------------------------------
// sMBRdiagnostics() -- helper to print some diagnostics for analyzing sMBR results
// ---------------------------------------------------------------------------

// static  // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
void sMBRdiagnostics(const msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids,
                     const_array_ref<size_t> bestpath, const vector<bool> &refseen, const msra::asr::simplesenonehmm &hset)
{
    // TODO:
    //  - print best positive runner-up
    //  - WARN tag if neg > pos
    //  - check the sum and warn if not 0
    //  - indicate whether best path state is correct or not
    size_t numcor = 0;
    size_t numnegbetter = 0; // # frames the neg competitor is better
    size_t numposbetter = 0; // # frames the pos competitor is better
    foreach_column (t, errorsignal)
    {
        const size_t sref = uids[t];
        const char *srefname = hset.getsenonename(sref);
        const size_t sbest = bestpath[t];
        const char *sbestname = hset.getsenonename(sbest);
        if (sref == sbest)
            numcor++;
        // for each frame, print error signal for ground truth and runner up (second largest abs value)
        size_t sneg = SIZE_MAX; // competitor
        float eneg = 0.0f;
        size_t spos = SIZE_MAX; // best positive competitor
        float epos = 0.0f;
        foreach_row (s, errorsignal)
        {
            if (s == sref)
                continue;
            if (errorsignal(s, t) < eneg)
            {
                sneg = s;
                eneg = errorsignal(s, t);
            }
            if (errorsignal(s, t) > epos)
            {
                spos = s;
                epos = errorsignal(s, t);
            }
        }
        if (fabs(errorsignal(sref, t)) > 0.0001f && errorsignal(sref, t) < -eneg)
            numnegbetter++;
        if (fabs(errorsignal(sref, t)) > 0.0001f && errorsignal(sref, t) < epos)
            numposbetter++;
        const char *snegname = sneg == SIZE_MAX ? "-" : hset.getsenonename(sneg);
        const char *sposname = spos == SIZE_MAX ? "-" : hset.getsenonename(spos);
        fprintf(stderr, "e(%d): ref %s: %.6f / %s: %.6f / %s: %.6f / top %s: %.6f%s%s%s%s%s\n",
                (int) t, srefname, errorsignal(sref, t), snegname, eneg, sposname, epos, sbestname, errorsignal(sbest, t),
                sbest == sref ? "" : " ERR",
                fabs(errorsignal(sref, t)) > 0.0001f && errorsignal(sref, t) < 0 ? " INV!!" : "",
                fabs(errorsignal(sref, t)) > 0.0001f && errorsignal(sref, t) < -eneg ? " WEAK" : "",
                fabs(errorsignal(sref, t)) > 0.0001f && errorsignal(sref, t) < epos ? " 2ND" : "",
                refseen[t] ? "" : " NOREF");
    }
    // print this to validate our bestpath computation
    fprintf(stderr, "sMBRdiagnostics: %d frames correct out of %d, %.2f%%, neg better in %d, pos in %d\n",
            (int) numcor, (int) errorsignal.cols(), 100.0f * numcor / errorsignal.cols(),
            (int) numnegbetter, (int) numposbetter);
}

// static  // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
void sMBRsuppressweirdstuff(msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids)
{
    size_t numweird = 0;
    foreach_column (t, errorsignal)
    {
        const size_t sref = uids[t];
        // for each frame, print error signal for ground truth and runner up (second largest abs value)
        const float eref = errorsignal(sref, t);
        bool isweird = eref < 0.0f; // negative for reference!?
        for (size_t s = 0; s < errorsignal.rows() && !isweird; s++)
        {
            if (s == sref)
                continue;
            if (fabs(errorsignal(s, t)) > eref)
                isweird = true;
        }
        if (isweird)
        {
            foreach_row (s, errorsignal)
                errorsignal(s, t) = 0.0f;
            numweird++;
        }
    }
    // print this to validate our bestpath computation
    fprintf(stderr, "sMBRsuppressweirdstuff: %d weird frames out of %d, %.2f%% were flattened\n",
            (int) numweird, (int) errorsignal.cols(), 100.0f * numweird / errorsignal.cols());
}

// ---------------------------------------------------------------------------
// forwardbackward() -- main function for MMI/sMBR
//
// This computes the lattice state-level statistics for sequence training using MMI or sMBR.
//
// Outputs, MMI mode:
//  - result = dengammas = denominator gammas (non-log form)
//  - returns log of sum over all paths' likelihoods (the denominator of the MMI objective)
//  - note: numgammas is not used/touched in MMI mode
//
// Outputs, sMBR mode:
// TODO: fix this comment
//  - result = errorsignal = (abs value of) negative contributions to error signal
//  - errorsignalbuf = for temporarily use to get errorsignal
//  - returns expected frames-correct count (the sMBR objective)
// ---------------------------------------------------------------------------

double lattice::forwardbackward(parallelstate &parallelstate, const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset,
                                msra::math::ssematrixbase &result, msra::math::ssematrixbase &errorsignalbuf,
                                const float lmf, const float wp, const float amf, const float boostingfactor,
                                const bool sMBRmode, const bool EMBR, const string  EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR,  const string getPathMethodEMBR, const string showWERMode, const bool excludeSpecialWords,  const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR, array_ref<size_t> uids, vector<size_t> wids, const_array_ref<size_t> bounds,
                                const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigram) const                                
{
   
    std::vector<NBestToken> tokenlattice;
    tokenlattice.clear();

    if (wids.size() == 0) return 0;
    
    if (numPathsEMBR < 1)
    {
        fprintf(stderr, "forwardbackward: WARNING: numPathsEMBR = %d , which is smaller than 1\n", (int)numPathsEMBR);
        return LOGZERO; // failed, do not use resulting matrix
    }
    if (EMBRUnit != "word")
    {
        fprintf(stderr, "forwardbackward: Error: Currently do not support EMBR unit other than word\n");
        return LOGZERO; // failed, do not use resulting matrix
    }
    // sanity check

    if (nodes[0].wid != 0)    RuntimeError("The first node is not 0 (i.e.) !NULL, but is %d \n", int(nodes[0].wid));

    // the lattice last node could be either 0 or 2, i.e., if it is an merged lattice (merged numerator and denominator the dnb code dedicately removes ending !NULL,  it is 0. If it is not merged lattice (the one that I changed TAMER code to only use denominator lattice), the last node could be !NULL
    if(nodes[nodes.size()-1].wid != 2 && nodes[nodes.size() - 1].wid != 0) RuntimeError("The last node is not 2 (i.e.) </s> or 0 (i.e, !NULL), but is %d \n", int(nodes[0].wid));
    // I want to make sure there is only one </s>, it is crucial when the useAccinNbest is true: we add sentence acc into nbest cost function in the </s>.
    size_t sent_end_count = 0;

    if (nodes[nodes.size() - 1].wid == 2) sent_end_count = 1;

    for (size_t i = 1; i < nodes.size() - 1; i++)
    {
        if (nodes[i].wid == 2)
        {
            if (nodes[nodes.size() - 1].wid == 2 && (i != nodes.size() - 1))
            {
                //RuntimeError("The  node %d wid is  2 (i.e.) </s>, but it is not the last node, total number of node is %d \n", int(i), int(nodes.size()));
                fprintf(stderr, "The  node %d wid is  2 (i.e.) </s>, but it is not the last node, total number of node is %d \n", int(i), int(nodes.size()));
                return LOGZERO; // bad data, do not use resulting matrix
            }
            sent_end_count++;
        }
        if (nodes[i].wid == 0) RuntimeError("The  node %d wid is  0 (i.e.) </s>, but it is not the first node or last node, total number of node is %d \n", int(i), int(nodes.size()));

    }
    if (sent_end_count != 1)
    {
        RuntimeError("</s> count is not 1 in the lattice, but %d, and total number of node is %d \n", int(sent_end_count), int(nodes.size()));
    }
    
    bool softalign = true;
    bool softalignstates = false;      // true if soft alignment within edges, currently we only support soft within edge in cpu mode
    bool softalignlattice = softalign; // w.r.t. whole lattice

    edgealignments thisedgealignments(*this);   // alignments memory allocate for this lattice
    backpointers thisbackpointers(*this, hset); // memory for forwardbackward

    if (info.numframes != logLLs.cols())
        LogicError("forwardbackward: #frames mismatch between lattice (%d) and LLs (%d)", (int) info.numframes, (int) logLLs.cols());
    // TODO: the following checks should throw, but I don't dare in case this will crash a critical job... if we never see this warning, then
    if (info.numframes != uids.size())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and uids (%d)\n", (int) info.numframes, (int) uids.size());
    if (info.numframes != result.cols())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and result (%d)\n", (int) info.numframes, (int) result.cols());

    littlematrixheap matrixheap(info.numedges); // for abcs

    // PHASE 0: fake word level forward backwards --only used when pruning enabled
    const double minlogpp = LOGZERO; // pruning threshold  --LOGZERO means disabled
    std::vector<double> origlogpps;  // word posterior from original lattice, for pruning decision

    // PHASE 1: per-edge forward backwards (="time alignments")

    // score the ground truth  --only if a transcript is provided, which happens if the user provides a language model
    // TODO: no longer used, remove this. 'transcript' parameter is no longer used in this function.
    transcript;
    transcriptunigram;

    // allocate alpha/beta/gamma matrices (all are sharing the same memory in-place)
    std::vector<msra::math::ssematrixbase *> abcs;
    std::vector<float> edgeacscores; // [edge index] acoustic scores    
    // return senone id for EMBR or sMBR, but not for MMI
    forwardbackwardalign(parallelstate, hset, softalignstates, minlogpp, origlogpps, abcs, matrixheap, (sMBRmode || EMBR) /*returnsenoneids*/, edgeacscores, logLLs, thisedgealignments, thisbackpointers, uids, bounds);
// PHASE 2: lattice-level forward backward

// we exploit that the lattice is sorted by (end node, start node) for in-place processing

// checklattice();      // comment out by v-hansu to save time
#ifdef PRINT_TIME_MEASUREMENT
    auto_timer latlevelfwbw;
#endif
    std::vector<double> logpps;
    std::vector<double> Eframescorrectbuf; // this is used for compute the Eframescorrectdiff
    std::vector<double> logEframescorrect; // this is the final output of PHASE 2
    std::vector<double> logalphas;
    std::vector<double> logbetas;

    std::vector<double> edgelogbetas; // the edge score plus the edge's outgoing node's beta scores
    double totalfwscore = 0; // TODO: name no longer precise in sMBRmode
    double logEframescorrecttotal = LOGZERO;
    double totalbwscore = 0;

    bool returnEframescorrect = sMBRmode;
    if (softalignlattice)
    {
        if (EMBR)
        {
            //compute Beta only, 
            if (getPathMethodEMBR == "sampling")
            {
                totalbwscore = backwardlatticeEMBR(edgeacscores, parallelstate, edgelogbetas, logbetas, lmf, wp, amf);
                totalfwscore = totalbwscore; // to make the existing code happy
            }
            else //nbest
            {           
                double bestscore = nbestlatticeEMBR(edgeacscores, parallelstate, tokenlattice, numRawPathsEMBR, enforceValidPathEMBR, excludeSpecialWords, lmf, wp, amf,  wordNbest, useAccInNbest, accWeightInNbest, numPathsEMBR, wids);
                totalfwscore = bestscore; // to make the code happy, it should be called bestscore, rather than totalfwscore though, will fix later
            }
        }
        else
        {
            totalfwscore = forwardbackwardlattice(edgeacscores, parallelstate, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor, returnEframescorrect, (const_array_ref<size_t> &) uids, thisedgealignments, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
            if (sMBRmode && !returnEframescorrect)
            logEframescorrecttotal = forwardbackwardlatticesMBR(edgeacscores, hset, logalphas, logbetas, lmf, wp, amf, (const_array_ref<size_t> &) uids, thisedgealignments, Eframescorrectbuf);
            // ^^ BUGBUG not tested
        }
    }
    else
        totalfwscore = bestpathlattice(edgeacscores, logpps, lmf, wp, amf);
#ifdef PRINT_TIME_MEASUREMENT
    latlevelfwbw.show("latlevelfwbw"); // 68.395682 ms
#endif
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: totalforwardscore is zero: (%d nodes/%d edges), totalfwscore = %f \n", (int)nodes.size(), (int)edges.size(), totalfwscore);
        if(!EMBR || !excludeSpecialWords)
            return LOGZERO; // failed, do not use resulting matrix
    }

    // PHASE 3: compute final state-level posteriors (MMI mode)

    // compute expected frames correct in sMBRmode

    const size_t numframes = logLLs.cols();
    assert(numframes == info.numframes);    
    if (EMBR)
    {
        std::vector<vector<size_t>> vt_paths;
        std::vector<double> edge_weights(edges.size(), 0.0);
        std::vector<double> path_posterior_probs;

        double onebest_wer = 0.0;
        double avg_wer = 0.0;
        // for getPathMethodEMBR=sampling, the onebest_wer does not make any sense, pls. do not  use it                
        // ToDO: if it is logzero(totalfwscore), the criterion shown in the training log is not totally correct: for this problematic utterance, the wer is counted as 0. Problematic in the sense that: we set excludeSpecialWords is true, and found no token survive

        if (!islogzero(totalfwscore))
        {
            // Do path sampling
            if (getPathMethodEMBR == "sampling")
            {
                EMBRsamplepaths(edgelogbetas, logbetas, numPathsEMBR, enforceValidPathEMBR, excludeSpecialWords, vt_paths);
                path_posterior_probs.resize(vt_paths.size(), (1.0 / vt_paths.size()));
            }
            else
            {
                EMBRnbestpaths(tokenlattice, vt_paths, path_posterior_probs);
            }

            avg_wer = get_edge_weights(wids, vt_paths, edge_weights, path_posterior_probs, getPathMethodEMBR, onebest_wer);
        }


        auto &errorsignal = result;
        EMBRerrorsignal(parallelstate, thisedgealignments, edge_weights, errorsignal);
        if(getPathMethodEMBR == "nbest" && showWERMode == "onebest") return onebest_wer;
        else return avg_wer;
    }
    
    else
    {
        // MMI mode
        if (!sMBRmode)
        {
            // we first take the sum in log domain to avoid numerical issues
            auto &dengammas = result; // result is denominator gammas
            mmierrorsignal(parallelstate, minlogpp, origlogpps, abcs, softalignstates, logpps, hset, thisedgealignments, dengammas);
            return totalfwscore / numframes; // return value is av. posterior
        }
        // sMBR mode
        else
        {
            auto &errorsignal = result;
            sMBRerrorsignal(parallelstate, errorsignal, errorsignalbuf, logpps, amf, minlogpp, origlogpps, logEframescorrect, logEframescorrecttotal, thisedgealignments);

            static bool dummyvariable = (fprintf(stderr, "note: new version with kappa adjustment, kappa = %.2f\n", 1 / amf), true); // we only print once
            return exp(logEframescorrecttotal) / numframes;                                                                          // return value is av. expected frame-correct count
        }
    }
}
};
};
back to top