// (repository metadata, preserved from the original scrape:)
// Source: https://github.com/Microsoft/CNTK
// Tip revision: 1360e31b10664148665af21c5720825f75463df9 (1360e31), authored by
// Linquan Liu on 01 March 2019, 06:05:46 UTC: "skip utterance with inconsistency </s>"
// File: latticeforwardbackward.cpp
// latticeforwardbackward.cpp -- lattice-level and edge-level forward/backward (split from latticearchive.cpp)
//
// F. Seide, V-hansu
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#endif
#include "Basics.h"
#include "simple_checked_arrays.h"
#include "latticearchive.h"
#include "simplesenonehmm.h" // the model
#include "ssematrix.h" // the matrices
#include "latticestorage.h"
#include <limits> // for numeric_limits (empty-reference guard in compute_wer)
#include <list>
#include <regex>
#include <stdexcept>
#include <unordered_map>
#include <vector>
using namespace std;
#define VIRGINLOGZERO (10 * LOGZERO) // used for printing statistics on unseen states
#undef CPU_VERIFICATION
#ifdef _WIN32
// NOTE(review): -1 presumably means "no NUMA node override" -- confirm against numahelpers.h
int msra::numa::node_override = -1; // for numahelpers.h
#endif
namespace msra { namespace lattices {
// ---------------------------------------------------------------------------
// helper class for allocation lots of small matrices, no free
// ---------------------------------------------------------------------------
class littlematrixheap
{
    static const size_t CHUNKSIZE;
    typedef msra::math::ssematrixfrombuffer matrixfrombuffer;
    std::list<std::vector<float>> heap;     // chunks of raw float storage; a list so existing chunks never move
    size_t allocatedinlast;                 // elements already used up in the last heap chunk
    size_t totalallocated;                  // total elements handed out (statistics only)
    std::vector<matrixfrombuffer> matrices; // matrix headers; references into this vector are returned to callers

public:
    // 'estimatednumentries' must be an upper bound on the number of newmatrix() calls,
    // because 'matrices' must never reallocate (callers hold references into it).
    littlematrixheap(size_t estimatednumentries)
        : totalallocated(0), allocatedinlast(0)
    {
        matrices.reserve(estimatednumentries + 1);
    }
    // allocate a (rows x cols) matrix from the heap; the returned reference stays
    // valid for the lifetime of this object (there is no free)
    msra::math::ssematrixbase &newmatrix(size_t rows, size_t cols)
    {
        const size_t elementsneeded = matrixfrombuffer::elementsneeded(rows, cols);
        if (heap.empty() || (heap.back().size() - allocatedinlast) < elementsneeded)
        {
            const size_t nelem = max(CHUNKSIZE, elementsneeded + 3 /*+3 for SSE alignment*/);
            heap.push_back(std::vector<float>(nelem));
            allocatedinlast = 0;
            // make sure starting element is SSE-aligned (the constructor demands that)
            const size_t offelem = (((size_t) &heap.back()[allocatedinlast]) / sizeof(float)) % 4;
            if (offelem != 0)
                allocatedinlast += 4 - offelem;
        }
        auto &buffer = heap.back();
        if (elementsneeded > heap.back().size() - allocatedinlast)
            LogicError("newmatrix: allocation logic screwed up");
        // get our buffer into a handy vector-like thingy
        array_ref<float> vecbuffer(&buffer[allocatedinlast], elementsneeded);
        // allocate in the current heap location
        // BUGFIX: check capacity BEFORE growing--growing past the reserved capacity would
        // reallocate 'matrices' and dangle all previously returned references
        if (matrices.size() >= matrices.capacity())
            LogicError("newmatrix: littlematrixheap cannot grow but was constructed with too small number of elements");
        matrices.resize(matrices.size() + 1);
        auto &matrix = matrices.back();
        matrix = matrixfrombuffer(vecbuffer, rows, cols);
        allocatedinlast += elementsneeded;
        totalallocated += elementsneeded;
        return matrix;
    }
};

const size_t littlematrixheap::CHUNKSIZE = 256 * 1024; // 256k floats = 1 MB per chunk
// ---------------------------------------------------------------------------
// helpers for log-domain addition
// ---------------------------------------------------------------------------
#ifndef LOGZERO
#define LOGZERO -1e30f
#endif
// logaddratio (loga, diff) -- helper for logadd: loga += log (1 + exp (diff)),
// where diff = logb - loga <= 0. Differences below the resolution of the 24-bit
// float mantissa cannot change the result and are skipped.
static void logaddratio(float &loga, float diff)
{
    if (!(diff < -17.0f)) // -17 ~ log (2^-24): smaller contributions vanish in float
        loga += logf(1.0f + expf(diff));
}
// double-precision variant: differences below the resolution of the 52-bit
// double mantissa cannot change the result and are skipped
static void logaddratio(double &loga, double diff)
{
    if (!(diff < -37.0f)) // -37 ~ log (2^-53): smaller contributions vanish in double
        loga += log(1.0 + exp(diff));
}
// loga <- log (exp (loga) + exp (logb)), computed as loga + log (1 + exp (logb - loga))
template <typename FLOAT>
static void logadd(FLOAT &loga, FLOAT logb)
{
    if (logb > loga) // order the operands so that loga holds the larger value
    {
        const FLOAT smaller = loga;
        loga = logb;
        logb = smaller;
    }
    if (!(loga <= LOGZERO)) // if both are log 0, the result stays log 0
        logaddratio(loga, logb - loga);
}
// max-approximation of logadd, for testing only: loga <- max (loga, logb)
template <typename FLOAT>
static void logmax(FLOAT &loga, FLOAT logb)
{
    loga = (loga < logb) ? logb : loga;
}
// exp (a) - exp (b), factoring out the larger exponent for numerical stability (for testing)
template <typename FLOAT>
static FLOAT expdiff(FLOAT a, FLOAT b)
{
    if (b > a)
        return -exp(b) * (1 - exp(a - b)); // result is negative
    return exp(a) * (1 - exp(b - a));
}
// is this log-domain value to be considered a represented zero?
// (anything below half of LOGZERO counts, to absorb accumulated round-off)
template <typename FLOAT>
static bool islogzero(FLOAT v)
{
    const FLOAT logzerothreshold = LOGZERO / 2;
    return v < logzerothreshold;
}
// ---------------------------------------------------------------------------
// other helpers go here
// ---------------------------------------------------------------------------
// helper to reconstruct the phonetic transcript of an edge, as a blank-separated
// string of the HMM names of its alignment units
/*static*/ std::string lattice::gettranscript(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset)
{
    std::string result;
    for (size_t k = 0; k < units.size(); k++) // we exploit that units have fixed boundaries
    {
        if (k > 0)
            result.push_back(' '); // blank-separate the unit names
        result.append(hset.gethmm(units[k].unit).getname());
    }
    return result;
}
// ---------------------------------------------------------------------------
// forwardbackwardedge() -- perform state-level forward-backward on a single lattice edge
//
// Results:
// - gammas(j,t) for valid time ranges (remaining areas are not initialized)
// - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------
/*static*/ float lattice::forwardbackwardedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                              msra::math::ssematrixbase &loggammas, size_t edgeindex)
{
    // alphas and betas are stored in-place inside the loggammas matrix, shifted by one/two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logalphas(loggammas, 1, logLLs.cols()); // shifted views into gammas(,) for alphas and betas
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> logbetas(loggammas, 2, logLLs.cols());
    // alphas(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
    // betas(j,t) store the sum of all paths exiting from state j at time t, not including logLL(j,t)
    // gammas(j,t) = alphas(j,t) * betas(j,t) / totalLL

    // backward pass --token passing, processing units right-to-left
    size_t te = logbetas.cols(); // end of current unit's frame range (shrinks as we move left)
    size_t je = logbetas.rows(); // end of current unit's state-index range
    float bwscore = 0.0f;        // backward score passed across unit boundaries
    for (size_t k = units.size() - 1; k + 1 > 0; k--)
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t ts = te - units[k].frames; // start time of current unit
        const size_t js = je - n;               // first state index of current unit
        // pass in the transition score
        // t = ts: exit transition (last frame only or tee transition)
        float exitscore = 1e30f; // (something impossible; always overwritten below)
        if (te == ts)            // tee (zero-frame) transition: pass straight through the unit
        {
            exitscore = bwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from; // origin trellis node
                logbetas(i, te - 1) = bwscore + transP(from, n);
            }
        }
        // expand from states j at time t (not yet including LL) to time t-1
        for (size_t t = te - 1; t + 1 > ts /*note: cannot test t >= ts because t < 0 possible*/; t--)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to;             // source trellis node
                const size_t s = hmm.getsenoneid(to); // senone id for state at position 'to' in the HMM
                const float acLL = logLLs(s, t);
                if (islogzero(acLL))
                    fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) ac score(%d,%d) is zero (%d st, %d fr: %s)\n",
                            (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te,
                            (int) s, (int) t,
                            (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
                const float betajt = logbetas(j, t);   // sum over all paths exiting from (j,t) to end
                const float betajtpll = betajt + acLL; // incorporate acoustic score
                if (t > ts)                            // regular frame-to-frame transition
                    for (size_t from = 0 /*no transition from entry state*/; from < n; from++)
                    {
                        const size_t i = js + from; // target trellis node
                        const float pathscore = betajtpll + transP(from, to);
                        if (to == 0) // first 'to' initializes the cell, later ones accumulate into it
                            logbetas(i, t - 1 /*propagate into preceding frame*/) = pathscore;
                        else
                            logadd(logbetas(i, t - 1 /*propagate into preceding frame*/), pathscore);
                    }
                else // transition to entry state
                {
                    const float pathscore = betajtpll + transP(-1, to);
                    if (to == 0)
                        exitscore = pathscore;
                    else
                        logadd(exitscore, pathscore); // propagate into preceding unit
                }
            }
        }
        bwscore = exitscore;
        if (islogzero(bwscore))
            fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) bw score is zero (%d st, %d fr: %s)\n",
                    (int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te, (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
        te = ts; // move on to the preceding unit
        je = js;
    }
    assert(te == 0 && je == 0);
    const float totalbwscore = bwscore;

    // forward pass --regular Viterbi
    // This also computes the gammas right away.
    size_t ts = 0;           // start frame for unit 'k'
    size_t js = 0;           // first row index of unit 'k'
    float fwscore = 0.0f;    // score passed across phone boundaries
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te2 = ts + units[k].frames; // end time of current unit
        const size_t je2 = js + n;               // end of current unit's state-index range
        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te2; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t to = 0; to < n; to++)
            {
                const size_t j = js + to; // target trellis node
                const size_t s = hmm.getsenoneid(to);
                const float acLL = logLLs(s, t);
                float alphajtnoll = LOGZERO; // alpha(j,t) without the acoustic LL
                if (t == ts)                 // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    alphajtnoll = pathscore;
                }
                else // accumulate over all predecessor states of this unit
                    for (size_t from = 0 /*no entering possible*/; from < n; from++)
                    {
                        const size_t i = js + from; // origin trellis node
                        const float alphaitm1 = logalphas(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        logadd(alphajtnoll, pathscore);
                    }
                logalphas(j, t) = alphajtnoll + acLL;
            }
            // update the gammas --do it here because in next frame, betas get overwritten by alphas (they share memory)
            for (size_t j = js; j < je2; j++)
            {
                if (!islogzero(totalbwscore))
                    loggammas(j, t) = logalphas(j, t) + logbetas(j, t) - totalbwscore;
                else // 0/0 problem, can occur if an ac score is so bad that it is 0 after going through softmax
                    loggammas(j, t) = LOGZERO;
            }
        }
        // t = te2: exit transition (last frame only or tee transition)
        float exitscore;
        if (te2 == ts) // tee transition
        {
            exitscore = fwscore + transP(-1, n);
        }
        else // not tee: expand all last states
        {
            exitscore = LOGZERO;
            for (size_t from = 0 /*no tee possible here*/; from < n; from++)
            {
                const size_t i = js + from;                    // origin trellis node
                const float alphaitm1 = logalphas(i, te2 - 1); // newly computed path score, transiting to t=te2
                const float pathscore = alphaitm1 + transP(from, n);
                logadd(exitscore, pathscore);
            }
        }
        fwscore = exitscore; // score passed on to next unit
        js = je2;
        ts = te2;
    }
    assert(js == logalphas.rows() && ts == logalphas.cols());
    const float totalfwscore = fwscore;
    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(totalbwscore) ^ islogzero(totalfwscore)) // only one is 0: inconsistent fw/bw passes
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw 0 score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
    if (islogzero(totalbwscore)) // failed--tell the caller to skip this edge
    {
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }
    if (fabsf(totalfwscore - totalbwscore) / ts > 1e-4f) // per-frame fw/bw discrepancy check
        fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw score %.10f vs. %.10f (%d st, %d fr: %s)\n",
                (int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
    // we return the full path score
    return totalfwscore;
}
// ---------------------------------------------------------------------------
// alignedge() -- perform Viterbi alignment on a single edge
//
// This is an alternative to forwardbackwardedge() that just uses the best path.
// Results:
// - if not returnsenoneids -> 'binary gammas(j,t)' for valid time ranges (remaining areas are not initialized); MMI-compatible
// - if returnsenoneids -> loggammas(0,t) will contain the senone ids directly instead (for sMBR mode)
// - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------
/*static*/ float lattice::alignedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
                                    msra::math::ssematrixbase &loggammas, size_t edgeindex /*for diagnostic messages*/, const bool returnsenoneids,
                                    array_ref<unsigned short> thisedgealignmentsj)
{
    // backpointers and path scores are stored in-place inside the loggammas matrix, shifted by zero/two columns
    assert(loggammas.cols() == logLLs.cols() + 2);
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> backpointers(loggammas, 0, logLLs.cols());
    msra::math::ssematrixstriperef<msra::math::ssematrixbase> pathscores(loggammas, 2, logLLs.cols());
    // pathscores(j,t) store the best path score up to including state j at time t, including logLL(j,t)
    // backpointers(j,t) are the absolute state indices that it came from (stored in a float matrix)
    // gammas(j,t) <- 1 if on best path, 0 otherwise
    const int invalidbp = -2; // marker for an uninitialized backpointer
    // Viterbi alignment
    size_t ts = 0;           // start frame for unit 'k'
    size_t js = 0;           // first row index of unit 'k'
    float fwscore = 0.0f;    // score passed across phone boundaries
    int fwbackpointer = -1;  // bp passed across phone boundaries, -1 means start of utterance
    foreach_index (k, units) // we exploit that units have fixed boundaries
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t n = hmm.getnumstates();
        const auto &transP = hmm.gettransP();
        const size_t te = ts + units[k].frames;    // end time of current unit
        const size_t je = js + hmm.getnumstates(); // end of current unit's state-index range
        // expand from states j at time t (including LL) to time t+1
        for (size_t t = ts; t < te; t++) // note: loop not entered for 0-frame units (tees)
        {
            for (size_t j = js; j < je; j++)
            {
                const size_t to = j - js; // relative state within the HMM
                const size_t s = hmm.getsenoneid(to);
                pathscores(j, t) = LOGZERO;
                backpointers(j, t) = invalidbp;
                if (t == ts) // entering score
                {
                    const float pathscore = fwscore + transP(-1, to);
                    pathscores(j, t) = pathscore;
                    backpointers(j, t) = (float) fwbackpointer;
                }
                else // Viterbi max over all predecessor states within the unit
                    for (size_t i = js; i < je; i++)
                    {
                        const size_t from = i - js;
                        const float alphaitm1 = pathscores(i, t - 1 /*previous frame*/);
                        const float pathscore = alphaitm1 + transP(from, to);
                        if (pathscore > pathscores(j, t))
                        {
                            pathscores(j, t) = pathscore;
                            backpointers(j, t) = (float) i;
                        }
                    }
                const float acLL = logLLs(s, t);
                pathscores(j, t) += acLL;
            }
        }
        // t = te: exit transition (last frame only or tee transition)
        float exitscore = LOGZERO;
        int exitbackpointer = invalidbp;
        if (te == ts) // tee transition: pass score and backpointer straight through
        {
            exitscore = fwscore + transP(-1, n);
            exitbackpointer = fwbackpointer;
        }
        else // not tee: take the best over all last states
        {
            for (size_t i = js; i < je; i++)
            {
                const size_t from = i - js;
                const float alphaitm1 = pathscores(i, te - 1); // newly computed path score, transiting to t=te
                const float pathscore = alphaitm1 + transP(from, n);
                if (pathscore > exitscore)
                {
                    exitscore = pathscore;
                    exitbackpointer = (int) i;
                }
            }
        }
        if (exitbackpointer == invalidbp)
            LogicError("exitbackpointer came up empty");
        fwscore = exitscore;             // score passed on to next unit
        fwbackpointer = exitbackpointer; // and accompanying backpointer
        js = je;
        ts = te;
    }
    assert(js == pathscores.rows() && ts == pathscores.cols());
    // in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
    // These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
    if (islogzero(fwscore))
    {
        fprintf(stderr, "alignedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
                (int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
        return LOGZERO;
    }
    // traceback & gamma update, walking the backpointer chain right-to-left
    size_t te = backpointers.cols();
    size_t je = backpointers.rows();
    int j = fwbackpointer;
    for (size_t k = units.size() - 1; k + 1 > 0; k--) // go in units because we also need to clear out the column
    {
        const auto &hmm = hset.gethmm(units[k].unit);
        const size_t ts2 = te - units[k].frames;    // start time of current unit
        const size_t js2 = je - hmm.getnumstates(); // first state index of current unit
        for (size_t t = te - 1; t + 1 > ts2; t--)
        {
            if (j < (int)js2 || j >= (int) je)
                LogicError("invalid backpointer resulting in state index out of range");
            int bp = (int) backpointers(j, t); // save the backpointer before overwriting it (gammas and backpointers are aliases of each other)
            // thisedgealignmentsj[t] = (unsigned short)hmm.getsenoneid(j - js2);
            if (!returnsenoneids) // return binary gammas (for MMI; this mode is compatible with softalignmode)
                for (size_t i = js2; i < je; i++)
                    loggammas(i, t) = ((int) i == j) ? 0.0f : LOGZERO;
            else // return senone id (for sMBR; note: NOT compatible with softalignmode; calling code must know this)
                thisedgealignmentsj[t] = (unsigned short) hmm.getsenoneid(j - js2);
            if (bp == invalidbp)
                LogicError("deltabackpointer not initialized");
            j = bp; // trace back one step
        }
        te = ts2;
        je = js2;
    }
    if (j != -1)
        LogicError("invalid backpointer resulting in not reaching start of utterance when tracing back");
    assert(je == 0 && te == 0);
    // we return the full path score
    return fwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlattice() -- lattice-level forward/backward
//
// This computes word posteriors, and also returns the per-node alphas and betas.
// Per-edge acoustic scores are passed in via a lambda, as this function is
// intended for use at multiple places with different scores.
// (Specifically, we also use it to determine a pruning threshold, based on
// the original lattice's ac. scores, before even bothering to compute the
// new ac. scores.)
// ---------------------------------------------------------------------------
double lattice::forwardbackwardlattice(const std::vector<float> &edgeacscores, parallelstate ¶llelstate, std::vector<double> &logpps,
std::vector<double> &logalphas, std::vector<double> &logbetas,
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode,
const_array_ref<size_t> &uids, const edgealignments &thisedgealignments,
std::vector<double> &logEframescorrect, std::vector<double> &Eframescorrectbuf, double &logEframescorrecttotal) const
{ // ^^ TODO: remove this
// --- hand off to parallelized (CUDA) implementation if available
if (parallelstate.enabled())
{
double totalfwscore = parallelforwardbackwardlattice(parallelstate, edgeacscores, thisedgealignments, lmf, wp, amf, boostingfactor, logpps, logalphas, logbetas, sMBRmode, uids, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
parallelstate.getlogbetas(logbetas);
if (nodes.size() != logbetas.size())
{
// it is possible if #define TWO_CHANNEL in parallelforwardbackward.cpp: in which case, logbetas will be doulbe the size of (nodes)
if (logbetas.size() != (nodes.size() * 2))
{
RuntimeError("forwardbackwardlattice: logbetas size is not equal or twice of node size, logbetas.size() = %d, nodes.size() = %d", int(logbetas.size()), int(nodes.size()));
}
//only taket the first half of the data
logbetas.erase(logbetas.begin() + nodes.size(), logbetas.begin() + logbetas.size());
}
return totalfwscore;
}
// if we get here, we have no CUDA, and do it the good ol' way
// allocate return values
logpps.resize(edges.size()); // this is our primary return value
// TODO: these are return values as well, but really shouldn't anymore; only used in some older baseline code we some day may want to compare against
logalphas.assign(nodes.size(), LOGZERO);
logalphas.front() = 0.0f;
logbetas.assign(nodes.size(), LOGZERO);
logbetas.back() = 0.0f;
// --- sMBR version
if (sMBRmode)
{
logEframescorrect.resize(edges.size());
Eframescorrectbuf.resize(edges.size());
std::vector<double> logaccalphas(nodes.size(), LOGZERO); // [i] expected frames-correct count over all paths from start to node i
std::vector<double> logaccbetas(nodes.size(), LOGZERO); // [i] likewise
std::vector<double> logframescorrectedge(edges.size()); // raw counts of correct frames in each edge
// forward pass
foreach_index (j, edges)
{
if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
continue;
const auto &e = edges[j];
const double inscore = logalphas[e.S];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
logadd(logalphas[e.E], pathscore);
size_t ts = nodes[e.S].t;
size_t te = nodes[e.E].t;
size_t framescorrect = 0; // count raw number of correct frames
for (size_t t = ts; t < te; t++)
framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
logframescorrectedge[j] = (framescorrect > 0) ? log((double) framescorrect) : LOGZERO; // remember for backward pass
double loginaccs = logaccalphas[e.S] - logalphas[e.S];
logadd(loginaccs, logframescorrectedge[j]);
double logpathacc = loginaccs + logalphas[e.S] + edgescore;
logadd(logaccalphas[e.E], logpathacc);
}
foreach_index (j, logaccalphas)
logaccalphas[j] -= logalphas[j];
const double totalfwscore = logalphas.back();
const double totalfwacc = logaccalphas.back();
if (islogzero(totalfwscore))
{
fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
return LOGZERO; // failed, do not use resulting matrix
}
// backward pass and computation of state-conditioned frames-correct count
for (size_t j = edges.size() - 1; j + 1 > 0; j--)
{
if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
continue;
const auto &e = edges[j];
const double inscore = logbetas[e.E];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
logadd(logbetas[e.S], pathscore);
double loginaccs = logaccbetas[e.E] - logbetas[e.E];
logadd(loginaccs, logframescorrectedge[j]);
double logpathacc = loginaccs + logbetas[e.E] + edgescore;
logadd(logaccbetas[e.S], logpathacc);
// sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
if (logpp > 1e-2)
fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
if (logpp > 0.0)
logpp = 0.0;
logpps[j] = logpp;
double tmplogeframecorrect = logframescorrectedge[j];
logadd(tmplogeframecorrect, logaccalphas[e.S]);
logadd(tmplogeframecorrect, logaccbetas[e.E] - logbetas[e.E]);
Eframescorrectbuf[j] = exp(tmplogeframecorrect);
}
foreach_index (j, logaccbetas)
logaccbetas[j] -= logbetas[j];
const double totalbwscore = logbetas.front();
const double totalbwacc = logaccbetas.front();
if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());
if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
logEframescorrecttotal = totalbwacc;
return totalbwscore;
}
// --- MMI version
// forward pass
foreach_index (j, edges)
{
const auto &e = edges[j];
const double inscore = logalphas[e.S];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
const double pathscore = inscore + edgescore;
logadd(logalphas[e.E], pathscore);
}
const double totalfwscore = logalphas.back();
if (islogzero(totalfwscore))
{
fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
return LOGZERO; // failed, do not use resulting matrix
}
// backward pass
// this also computes the word posteriors on the fly, since we are at it
for (size_t j = edges.size() - 1; j + 1 > 0; j--)
{
const auto &e = edges[j];
const double inscore = logbetas[e.E];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
logadd(logbetas[e.S], pathscore);
// compute lattice posteriors on the fly since we are at it
double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
if (logpp > 1e-2)
fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
if (logpp > 0.0)
logpp = 0.0;
logpps[j] = logpp;
}
const double totalbwscore = logbetas.front();
if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());
return totalfwscore;
}
// constructnodenbestoken -- finalize the n-best token list for node 'nidx':
// move the best-scoring tokens from the node's score-sorted map (mp_score_token_infos)
// into vt_nbest_tokens, keeping at most 'numtokens2keep'. If 'wordNbest' is set,
// only tokens with distinct word histories are kept (paths that carry the same
// word sequence are collapsed to the best-scoring one).
void lattice::constructnodenbestoken(std::vector<NBestToken> &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const
{
    std::map<double, std::vector<PrevTokenInfo>>::iterator mp_itr;
    std::map<uint64_t, std::vector<size_t>> mp_wid_tokenidx; // wid -> indices of kept tokens whose predecessor ends in that word
    std::map<uint64_t, std::vector<size_t>>::iterator mp_itr1;
    size_t count; // number of tokens kept so far
    bool done;    // set once numtokens2keep has been reached
    TokenInfo tokeninfo;
    uint64_t wid;
    vector<size_t> vt_tokenidx;
    if (wordNbest) mp_wid_tokenidx.clear();
    count = 0;
    done = false;
    // Sometimes numtokens is larger than numPathsEMBR; if </s>, keep tokens to be numPathsEMBR
    for (mp_itr = tokenlattice[nidx].mp_score_token_infos.begin(); mp_itr != tokenlattice[nidx].mp_score_token_infos.end(); mp_itr++)
    {
        for (size_t i = 0; i < mp_itr->second.size(); i++)
        {
            tokeninfo.prev_edge_index = mp_itr->second[i].prev_edge_index;
            tokeninfo.prev_token_index = mp_itr->second[i].prev_token_index;
            tokeninfo.score = mp_itr->second[i].path_score;
            if (wordNbest)
            {
                // check whether an already-kept token carries the same word history
                wid = nodes[edges[tokeninfo.prev_edge_index].S].wid;
                mp_itr1 = mp_wid_tokenidx.find(wid);
                bool different = true;
                if (mp_itr1 == mp_wid_tokenidx.end())
                {
                    // the wid does not exist in previous tokens of this node, so it is a path with different word sequence
                    vt_tokenidx.clear();
                    vt_tokenidx.push_back(count);
                    mp_wid_tokenidx.insert(pair<uint64_t, std::vector<size_t>>(wid, vt_tokenidx));
                }
                else
                {
                    // same last word: walk both backpointer chains in lockstep to
                    // decide whether the full word histories are identical
                    for (size_t j = 0; j < mp_itr1->second.size(); j++)
                    {
                        size_t oldnodeidx, oldtokenidx, newnodeidx, newtokenidx;
                        oldnodeidx = edges[tokenlattice[nidx].vt_nbest_tokens[mp_itr1->second[j]].prev_edge_index].S;
                        oldtokenidx = tokenlattice[nidx].vt_nbest_tokens[mp_itr1->second[j]].prev_token_index;
                        newnodeidx = edges[tokeninfo.prev_edge_index].S; newtokenidx = tokeninfo.prev_token_index;
                        while (1)
                        {
                            if (nodes[oldnodeidx].wid != nodes[newnodeidx].wid) break; // word sequences diverge -> different
                            if (oldnodeidx == newnodeidx)
                            {
                                // chains merged: histories are identical iff they merged on the same token
                                if (oldtokenidx == newtokenidx) different = false;
                                break;
                            }
                            if (oldnodeidx == 0 || newnodeidx == 0)
                            {
                                // one chain reached the start node while the other did not: should be impossible
                                fprintf(stderr, "nbestlatticeEMBR: WARNING: should not come her, oldnodeidx = %d, newnodeidx = %d\n", int(oldnodeidx), int(newnodeidx));
                                break;
                            }
                            // step both chains one token back
                            size_t tmpnodeix, tmptokenidx;
                            tmpnodeix = edges[tokenlattice[oldnodeidx].vt_nbest_tokens[oldtokenidx].prev_edge_index].S;
                            tmptokenidx = tokenlattice[oldnodeidx].vt_nbest_tokens[oldtokenidx].prev_token_index;
                            oldnodeidx = tmpnodeix; oldtokenidx = tmptokenidx;
                            tmpnodeix = edges[tokenlattice[newnodeidx].vt_nbest_tokens[newtokenidx].prev_edge_index].S;
                            tmptokenidx = tokenlattice[newnodeidx].vt_nbest_tokens[newtokenidx].prev_token_index;
                            newnodeidx = tmpnodeix; newtokenidx = tmptokenidx;
                        }
                        if (!different) break;
                    }
                    if (different)
                    {
                        mp_itr1->second.push_back(count);
                    }
                }
                if (different)
                {
                    tokenlattice[nidx].vt_nbest_tokens.push_back(tokeninfo);
                    count++;
                }
            }
            else
            {
                // plain n-best: keep every token, in score order
                tokenlattice[nidx].vt_nbest_tokens.push_back(tokeninfo);
                count++;
            }
            if (count >= numtokens2keep)
            {
                done = true;
                break;
            }
        }
        if (done) break;
    }
    // free the space.
    tokenlattice[nidx].mp_score_token_infos.clear();
}
// compute_wer -- word error rate between a reference and a recognized id sequence.
// Computes the Levenshtein edit distance (unit-cost substitutions, insertions and
// deletions) via dynamic programming and divides by the reference length.
// 'ref' is the reference sequence, 'rec' the recognized one.
// Returns 0 for two empty sequences and +inf for a non-empty 'rec' against an
// empty 'ref' (the previous version divided by zero in that case).
float compute_wer(std::vector<size_t> &ref, std::vector<size_t> &rec)
{
    // guard the empty-reference case, which would otherwise divide by zero
    if (ref.empty())
        return rec.empty() ? 0.0f : std::numeric_limits<float>::infinity();
    // classic (rec+1) x (ref+1) DP matrix; std::vector fixes the memory leak of the
    // previous raw-new version (its cleanup loop never freed the last row) and
    // avoids the 'short' overflow risk for long sequences
    std::vector<std::vector<int>> mat(rec.size() + 1, std::vector<int>(ref.size() + 1));
    for (size_t i = 0; i <= rec.size(); i++)
        mat[i][0] = (int) i; // deleting all of rec[0..i)
    for (size_t j = 1; j <= ref.size(); j++)
        mat[0][j] = (int) j; // inserting all of ref[0..j)
    for (size_t i = 1; i <= rec.size(); i++)
        for (size_t j = 1; j <= ref.size(); j++)
        {
            mat[i][j] = mat[i - 1][j - 1]; // match: carry the diagonal at no cost
            if (rec[i - 1] != ref[j - 1])
            {
                // mismatch: best of substitution (diag), deletion (up), insertion (left), plus 1
                if (mat[i - 1][j] < mat[i][j])
                    mat[i][j] = mat[i - 1][j];
                if (mat[i][j - 1] < mat[i][j])
                    mat[i][j] = mat[i][j - 1];
                mat[i][j]++;
            }
        }
    return float(mat[rec.size()][ref.size()]) / ref.size();
}
// split a mixed-script string into "characters" for CER scoring: each CJK
// ideograph (U+4E00..U+9FA5) becomes its own token, while maximal runs of
// non-CJK, non-whitespace characters stay together as one token; whitespace
// only separates tokens and is never emitted
std::vector<std::wstring> splitword2character(const std::wstring &s)
{
    std::wregex token_regex(L"([\u4e00-\u9fa5]|[^\u4e00-\u9fa5\\s]+)");
    std::vector<std::wstring> tokens;
    for (std::wsregex_iterator it(s.begin(), s.end(), token_regex), end; it != end; ++it)
        tokens.push_back(it->str());
    return tokens;
}
// returns true if 's' is a tag word, i.e. it starts with one of '<', '[', '{'
// and ends with one of '>', ']', '}' (e.g. "<s>", "[noise]", "{laugh}")
bool istagword(const std::wstring &s)
{
    const std::wregex tag_regex(L"^[\\<\\[\\{].*?[\\>\\]\\}]$");
    return std::regex_match(s, tag_regex);
}
float computewerandcer(std::vector<size_t> &wids, std::vector<size_t> &path_ids, const std::unordered_map<size_t, std::wstring> *ptr_id2wordmap4node)
{
float wer;
if (ptr_id2wordmap4node->size() > 0)
{
std::vector<std::wstring> refwords;
std::vector<std::wstring> regwords;
std::vector<size_t> refid;
std::vector<size_t> regid;
std::wstring temp_string;
std::vector<std::wstring> character_array;
std::unordered_map<std::wstring, size_t> idmappingtable;
refwords.clear();
regwords.clear();
character_array.clear();
refid.clear();
regid.clear();
idmappingtable.clear();
std::unordered_map<size_t, std::wstring>::const_iterator maptable_itr;
for (std::vector<size_t>::const_iterator it = wids.begin(); it != wids.end(); ++it)
{
maptable_itr = ptr_id2wordmap4node->find(*it);
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
character_array = splitword2character(temp_string);
foreach_index(_i, character_array)
{
refwords.push_back(character_array[_i]);
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
{
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
}
}
}
for (std::vector<size_t>::const_iterator it = path_ids.begin(); it != path_ids.end(); ++it)
{
maptable_itr = ptr_id2wordmap4node->find(*it);
temp_string = (maptable_itr != ptr_id2wordmap4node->end()) ? maptable_itr->second : std::to_wstring(*it);
character_array = splitword2character(temp_string);
foreach_index(_i, character_array)
{
regwords.push_back(character_array[_i]);
if (idmappingtable.find(character_array[_i]) == idmappingtable.end())
{
idmappingtable.insert(pair<std::wstring, size_t>(character_array[_i], idmappingtable.size() + 1));
}
}
}
//map characters to id to be compatiable with egacy code
//skip tag words
foreach_index(_k, refwords)
{
if (!istagword(refwords[_k]))
refid.push_back(idmappingtable.find(refwords[_k])->second);
}
foreach_index(_k, regwords)
{
if (!istagword(regwords[_k]))
regid.push_back(idmappingtable.find(regwords[_k])->second);
}
wer = compute_wer(refid, regid);
}
else
{
wer = compute_wer(wids, path_ids);
}
return wer;
}
// ---------------------------------------------------------------------------
// nbestlatticeEMBR() -- token-passing N-best search over the lattice, used to
// select the N best paths for EMBR training.
// Per node, 'tokenlattice' accumulates candidate tokens in a score-sorted map
// (mp_score_token_infos) which is collapsed into the pruned n-best token list
// (vt_nbest_tokens) the first time the node is expanded. Tokens carry
// backpointers (prev_edge_index/prev_token_index) so paths can be recovered.
// 'wids' is the reference word-id sequence, used when useAccInNbest is set.
// Returns the best surviving path score, or LOGZERO on failure.
// ---------------------------------------------------------------------------
double lattice::nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<NBestToken> &tokenlattice, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords,
                                 const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids) const
{
    std::map<double, std::vector<PrevTokenInfo>>::iterator mp_itr;
    size_t numtokens2keep;
    // TODO: support parallel state
    parallelstate; // touch the parameter to silence the unused-parameter warning
    PrevTokenInfo prevtokeninfo;
    std::vector<PrevTokenInfo> vt_prevtokeninfo;
    // CPU-only implementation
    // allocate return values: one token list per lattice node; node 0 (start)
    // is seeded with a single token of score 0
    tokenlattice.resize(nodes.size());
    tokenlattice[0].vt_nbest_tokens.resize(1);
    tokenlattice[0].vt_nbest_tokens[0].score = 0.0f;
    tokenlattice[0].vt_nbest_tokens[0].prev_edge_index = 0;
    tokenlattice[0].vt_nbest_tokens[0].prev_token_index = 0;
    // forward pass (edges are assumed topologically sorted)
    foreach_index(j, edges)
    {
        const auto &e = edges[j];
        if (enforceValidPathEMBR)
        {
            // only allow paths leaving the start node through <s> (wid 1)
            if (e.S == 0 && nodes[e.E].wid != 1) continue;
        }
        if (excludeSpecialWords)
        {
            // 0~4 is: !NULL, <s>, </s>, !sent_start, and !sent_end
            if (nodes[e.E].wid > 4)
            {
                if (is_special_words[e.E]) continue;
            }
            if (nodes[e.S].wid > 4)
            {
                if (is_special_words[e.S]) continue;
            }
        }
        // first expansion from e.S: collapse its score->token map into the
        // pruned per-node n-best token list
        if (tokenlattice[e.S].mp_score_token_infos.size() != 0)
        {
            // sanity check: the map and the token list must never both be populated
            if(tokenlattice[e.S].vt_nbest_tokens.size() != 0)
                RuntimeError("nbestlatticeEMBR: node = %d, mp_score_token_infos.size() = %d, vt_nbest_tokens.size() = %d, both are not 0!", int(e.S), int(tokenlattice[e.S].mp_score_token_infos.size()), int(tokenlattice[e.S].vt_nbest_tokens.size()));
            // Sometimes numtokens is larger than numPathsEMBR; at </s> (wid 2) keep only numPathsEMBR tokens
            if (nodes[e.S].wid == 2) numtokens2keep = numPathsEMBR;
            else numtokens2keep = numtokens;
            constructnodenbestoken(tokenlattice, wordNbest, numtokens2keep, e.S);
        }
        if (tokenlattice[e.S].vt_nbest_tokens.size() == 0)
        {
            // possible when special words are excluded: a node may end up with no surviving tokens
            continue;
        }
        prevtokeninfo.prev_edge_index = j;
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        // extend every surviving token of the start node across this edge
        for (size_t i = 0; i < tokenlattice[e.S].vt_nbest_tokens.size(); i++)
        {
            prevtokeninfo.prev_token_index = i;
            double pathscore = tokenlattice[e.S].vt_nbest_tokens[i].score + edgescore;
            prevtokeninfo.path_score = pathscore;
            if (useAccInNbest && nodes[e.E].wid == 2)
            {
                // reaching </s>: add the weighted path accuracy into the path score
                std::vector<size_t> path, path_ids; // stores the edges in the path
                size_t curnodeidx, curtokenidx, prevtokenidx, prevnodeidx;
                // ignore the edge with ending node </s> in the path, as </s> will anyway not be used for WER computation
                path.clear(); // store the edge sequence of the path
                path_ids.clear(); // store the wid sequence of the path
                curnodeidx = e.S;
                curtokenidx = i;
                // trace this token back to the start node to recover its path
                while (curnodeidx != 0)
                {
                    path.insert(path.begin(), tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index);
                    prevtokenidx = tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_token_index;
                    prevnodeidx = edges[tokenlattice[curnodeidx].vt_nbest_tokens[curtokenidx].prev_edge_index].S;
                    curnodeidx = prevnodeidx;
                    curtokenidx = prevtokenidx;
                }
                // collect the word ids along the path, skipping special words
                for (size_t k = 0; k < path.size(); k++)
                {
                    if (k == 0)
                    {
                        if (!is_special_words[edges[path[k]].S]) path_ids.push_back(nodes[edges[path[k]].S].wid);
                    }
                    if (!is_special_words[edges[path[k]].E]) path_ids.push_back(nodes[edges[path[k]].E].wid);
                }
                // WER (or CER when a word map is available) against the reference
                float wer = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
                // will favor the path with better WER
                pathscore -= double(accWeightInNbest*wer);
                // If you only want WER to affect the selection of the n-best, disable the line below.
                // If you also want the WER as a weight in the error computation, enable it.
                prevtokeninfo.path_score = pathscore;
            }
            // merge the new token into the end node's score->token map
            // (tokens with identical scores share one map entry)
            mp_itr = tokenlattice[e.E].mp_score_token_infos.find(pathscore);
            if (mp_itr != tokenlattice[e.E].mp_score_token_infos.end())
            {
                mp_itr->second.push_back(prevtokeninfo);
            }
            else
            {
                vt_prevtokeninfo.clear();
                vt_prevtokeninfo.push_back(prevtokeninfo);
                tokenlattice[e.E].mp_score_token_infos.insert(std::pair<double, std::vector<PrevTokenInfo>>(pathscore, vt_prevtokeninfo));
            }
        }
    }
    // finalize the last node, which is </s> or !NULL (!NULL if you do not merge the numerator lattice into the denominator lattice)
    numtokens2keep = numPathsEMBR;
    constructnodenbestoken(tokenlattice, wordNbest, numtokens2keep, tokenlattice.size() - 1);
    double bestscore;
    if (tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens.size() == 0)
    {
        if (!excludeSpecialWords) RuntimeError("nbestlatticeEMBR: no token survive while excludeSpecialWords is false");
        else bestscore = LOGZERO;
    }
    else bestscore = tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens[0].score;
    if (islogzero(bestscore))
    {
        fprintf(stderr, "nbestlatticeEMBR: WARNING: best score is logzero in lattice \n");
        return LOGZERO; // failed, do not use resulting matrix
    }
    return bestscore;
}
// ---------------------------------------------------------------------------
// backwardlatticeEMBR() -- lattice-level backward
//
// This computes per-node betas for EMBR
// ---------------------------------------------------------------------------
double lattice::backwardlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate ¶llelstate, std::vector<double> &edgelogbetas, std::vector<double> &logbetas,
const float lmf, const float wp, const float amf) const
{ // ^^ TODO: remove this
// --- hand off to parallelized (CUDA) implementation if available
if (parallelstate.enabled())
{
double totalbwscore = parallelbackwardlatticeEMBR(parallelstate, edgeacscores, lmf, wp, amf, edgelogbetas, logbetas);
parallelstate.getlogbetas(logbetas);
parallelstate.getedgelogbetas(edgelogbetas);
if (nodes.size() != logbetas.size())
{
// it is possible if #define TWO_CHANNEL in parallelforwardbackward.cpp: in which case, logbetas will be doulbe the size of (nodes)
if (logbetas.size() != (nodes.size() * 2))
{
RuntimeError("forwardbackwardlattice: logbetas size is not equal or twice of node size, logbetas.size() = %d, nodes.size() = %d", int(logbetas.size()), int(nodes.size()));
}
//only taket the first half of the data
logbetas.erase(logbetas.begin() + nodes.size(), logbetas.begin() + logbetas.size());
}
return totalbwscore;
}
// if we get here, we have no CUDA, and do it the good ol' way
// allocate return values
logbetas.assign(nodes.size(), LOGZERO);
logbetas.back() = 0.0f;
edgelogbetas.assign(edges.size(), LOGZERO);
// backward pass
// this also computes the word posteriors on the fly, since we are at it
for (size_t j = edges.size() - 1; j + 1 > 0; j--)
{
const auto &e = edges[j];
const double inscore = logbetas[e.E];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
edgelogbetas[j] = pathscore;
logadd(logbetas[e.S], pathscore);
}
const double totalbwscore = logbetas.front();
if (islogzero(totalbwscore))
{
fprintf(stderr, "backwardlatticeEMBR: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int)nodes.size(), (int)edges.size());
return LOGZERO; // failed, do not use resulting matrix
}
return totalbwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlatticesMBR() -- compute expected frame-accuracy counts,
// both the conditioned one (corresponding to c(q) in Dan Povey's thesis)
// and the global one (which is the sMBR criterion to optimize).
//
// Outputs:
// - Eframescorrect[j] == expected frames-correct count conditioned on a state of edge[j].
// We currently assume a hard state alignment. With that, the value turns out
// to be identical for all states of an edge, so we only store it once per edge.
// - return value: expected frames-correct count for entire lattice
//
// Call forwardbackwardlattices() first to compute logalphas/betas.
// ---------------------------------------------------------------------------
double lattice::forwardbackwardlatticesMBR(const std::vector<float> &edgeacscores, const msra::asr::simplesenonehmm &hset,
                                           const std::vector<double> &logalphas, const std::vector<double> &logbetas,
                                           const float lmf, const float wp, const float amf, const_array_ref<size_t> &uids,
                                           const edgealignments &thisedgealignments, std::vector<double> &Eframescorrect) const
{
    std::vector<double> accalphas(nodes.size(), 0); // [i] expected frames-correct count over all paths from start to node i
    std::vector<double> accbetas(nodes.size(), 0);  // [i] likewise, over all paths from node i to the end
    std::vector<size_t> maxcorrect(nodes.size(), 0); // [i] max correct frames up to this node (oracle)
    std::vector<double> framescorrectedge(edges.size()); // raw counts of correct frames in each edge
    std::vector<int> backpointersformaxcorr(nodes.size(), -2); // keep track of backpointer for the max corr
    backpointersformaxcorr.front() = -1; // start node has no predecessor
    // forward pass
    foreach_index (j, edges)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accalphas[e.S];
        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;
        size_t framescorrect = 0; // count raw number of correct frames
        // hard alignment: a frame is correct iff its aligned senone equals the ground truth
        for (size_t t = ts; t < te; t++)
            framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
        framescorrectedge[j] = (double) framescorrect; // remember for backward pass
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        // contribution to end node's path acc = start node's plus edge's correct count, weighted by LL, and divided by sum over LLs
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logalphas[e.S] + edgescore - logalphas[e.E]);
        accalphas[e.E] += pathacc;
        // also keep track of max accuracy, so we can find out whether the lattice contains the correct path
        size_t oracleframescorrect = maxcorrect[e.S] + framescorrect; // keep track of most correct path up to end of this edge
        if (oracleframescorrect > maxcorrect[e.E])
        {
            maxcorrect[e.E] = oracleframescorrect;
            backpointersformaxcorr[size_t(e.E)] = j;
        }
    }
    const double totalfwacc = accalphas.back();
    hset; // just for reference (silences the unused-parameter warning)
    // report on ground-truth path
    // TODO: we will later have code that adds this path if needed
    size_t oracleframeacc = maxcorrect.back();
    if (oracleframeacc != info.numframes)
        fprintf(stderr, "forwardbackwardlatticesMBR: ground-truth path missing from lattice (most correct path: %d out of %d frames correct)\n", (unsigned int) oracleframeacc, (int) info.numframes);
    // backward pass and computation of state-conditioned frames-correct count
    for (size_t j = edges.size() - 1; j + 1 > 0; j--) // j + 1 > 0 handles unsigned wrap-around at j == 0
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logbetas[e.E] + edgescore - logbetas[e.S]);
        accbetas[e.S] += pathacc;
        // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
        Eframescorrect[j] = (float) (accalphas[e.S] + accbetas[e.E] + framescorrectedge[j]);
    }
    const double totalbwacc = accbetas.front();
    // forward and backward accumulations must agree up to numerical error
    if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
    return totalbwacc;
}
// ---------------------------------------------------------------------------
// bestpathlattice() -- lattice-level "forward/backward" that only returns the
// best path, but in the form of word posteriors, which are 1 or 0, just like
// a real lattice-level forward/backward would do.
// We don't really use this; this was only for a contrast experiment.
// ---------------------------------------------------------------------------
double lattice::bestpathlattice(const std::vector<float> &edgeacscores, std::vector<double> &logpps,
                                const float lmf, const float wp, const float amf) const
{
    // Viterbi forward pass; edges are topologically sorted, so one sweep suffices
    std::vector<double> viterbiscore(nodes.size(), LOGZERO);
    std::vector<int> bestincomingedge(nodes.size(), -2);
    viterbiscore.front() = 0.0f; // start node: log 1
    bestincomingedge.front() = -1; // sentinel marking the start node
    foreach_index (j, edges)
    {
        const auto &edge = edges[j];
        // note: edgeacscores[j] == LOGZERO if the edge was pruned
        const double candidate = viterbiscore[edge.S] + (edge.l * lmf + wp + edgeacscores[j]) / amf;
        if (candidate > viterbiscore[edge.E])
        {
            viterbiscore[edge.E] = candidate;
            bestincomingedge[edge.E] = j;
        }
    }
    const double totalfwscore = viterbiscore.back();
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "bestpathlattice: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }
    // traceback
    // Encode the result as log 1 for edges on the best path and log 0 for all
    // others; this makes it naturally compatible with softalign mode.
    logpps.assign(edges.size(), LOGZERO);
    int backpos;
    for (backpos = bestincomingedge[nodes.size() - 1]; backpos >= 0; backpos = bestincomingedge[edges[backpos].S])
        logpps[backpos] = 0.0f; // edge is on best path -> PP = 1.0
    assert(backpos == -1); // must terminate at the start node, not an unreached one
    return totalfwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardalign() -- compute the state-level gammas or Viterbi alignments
// for every lattice edge; the first phase of lattice::forwardbackward
//
// Outputs:
//  - abcs               -- per-edge alpha/beta/gamma matrices, allocated here on demand
//  - edgeacscores       -- per-edge acoustic scores
//  - thisedgealignments -- per-edge per-frame senone alignment (hard-alignment mode)
//  - thisbackpointers   -- backpointers filled by the GPU alignment path
//  - uids               -- overwritten from 'bounds' when forced-alignment phone
//                          boundaries are given (see the "align to reference mlf" phase)
// ---------------------------------------------------------------------------
void lattice::forwardbackwardalign(parallelstate &parallelstate,
                                   const msra::asr::simplesenonehmm &hset, const bool softalignstates,
                                   const double minlogpp, const std::vector<double> &origlogpps,
                                   std::vector<msra::math::ssematrixbase *> &abcs, littlematrixheap &matrixheap,
                                   const bool returnsenoneids,
                                   std::vector<float> &edgeacscores, const msra::math::ssematrixbase &logLLs,
                                   edgealignments &thisedgealignments, backpointers &thisbackpointers, array_ref<size_t> &uids, const_array_ref<size_t> bounds) const
{ // NOTE: this will be removed and replaced by a proper representation of alignments someday
    // do forward-backward or alignment on a per-edge basis. This gives us:
    // - per-edge gamma[j,t] = P(s(t)==s_j|edge) if forwardbackward, per-edge alignment thisedgealignments[j] if alignment
    // - per-edge acoustic scores
    const size_t silunitid = hset.gethmmid("sil"); // shall be the same as parallelstate.getsilunitid()
    bool parallelsil = true;
    bool cpuverification = false;
#ifndef PARALLEL_SIL // we use a define to make this marked
    parallelsil = false;
#endif
#ifdef CPU_VERIFICATION
    cpuverification = true;
#endif
    // Phase 1: allocate the abcs (alpha/beta/gamma) matrices
    if (!parallelstate.enabled() || !parallelsil || cpuverification) // allocate abcs when 1. parallelstate not enabled (cpu mode); 2. enabled but not PARALLEL_SIL (silence needs allocation); 3. cpuverification
    {
        abcs.resize(edges.size(), NULL); // [edge index] -> alpha/beta/gamma matrices for each edge
        size_t countskip = 0; // if pruning: count how many edges are pruned
        foreach_index (j, edges)
        {
            // determine number of frames
            // TODO: this is not efficient--we only use a block-diagonal-like structure, rest is empty (exploiting the fixed boundaries)
            const size_t edgeframes = nodes[edges[j].E].t - nodes[edges[j].S].t;
            if (edgeframes == 0) // dummy !NULL edge at end of lattice
            {
                if ((size_t) j != edges.size() - 1)
                    RuntimeError("forwardbackwardalign: unxpected 0-frame edge (only allowed at very end)");
                // note: abcs[j] is already initialized to be NULL in this case, which protects us from accidentally using it
            }
            else
            {
                // determine the number of states in an edge
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                size_t edgestates = 0;
                bool edgehassil = false;
                foreach_index (i, aligntokens)
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;
                if (!cpuverification && !edgehassil && parallelstate.enabled()) // !cpuverification, parallel & is non sil, we do not allocate
                {
                    abcs[j] = NULL;
                    continue;
                }
                foreach_index (k, aligntokens)
                    edgestates += hset.gethmm(aligntokens[k].unit).getnumstates();
                // allocate the matrix
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    countskip++; // edge is pruned away; no matrix needed
                else
                    abcs[j] = &matrixheap.newmatrix(edgestates, edgeframes + 2); // +2 to have one extra column for betas and one for gammas
            }
        }
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: %d of %d edges pruned\n", (int) countskip, (int) edges.size());
    }
    // Phase 2: alignment on CPU
    // silence edges shall be processed separately on the CPU if PARALLEL_SIL is not defined
    if (parallelstate.enabled() && !parallelsil)
    {
        if (softalignstates)
            LogicError("forwardbackwardalign: parallelized version currently only handles hard alignments");
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: pruning not supported (we won't need it!) :)\n");
        edgeacscores.resize(edges.size());
        for (size_t j = 0; j < edges.size(); j++)
        {
            const auto &aligntokens = getaligninfo(j); // get alignment tokens
            if (aligntokens.size() == 0)
                continue;
            bool edgehassil = false;
            foreach_index (i, aligntokens)
            {
                if (aligntokens[i].unit == silunitid)
                    edgehassil = true;
            }
            if (!edgehassil) // only process sil
                continue;
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            // view onto the columns of logLLs covering this edge's frame range
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, true, thisedgealignments[j]);
        }
    }
    // Phase 3: alignment on GPU
    if (parallelstate.enabled())
        parallelforwardbackwardalign(parallelstate, hset, logLLs, edgeacscores, thisedgealignments, thisbackpointers);
    // zhaorui: align to reference mlf (forced alignment against given phone boundaries)
    if (bounds.size() > 0)
    {
        size_t framenum = bounds.size();
        msra::math::ssematrixbase *refabcs;
        size_t ts, te, t;
        ts = te = 0;
        vector<aligninfo> refinfo(1);
        vector<unsigned short> refalign(framenum);
        array_ref<aligninfo> refunits(refinfo.data(), 1);
        array_ref<unsigned short> refedgealignmentsj(refalign.data(), framenum);
        while (te < framenum)
        {
            // find one phone's boundary (ts, te); a non-zero bound marks a phone start
            t = ts + 1;
            while (t < framenum && bounds[t] == 0)
                t++;
            te = t;
            // make one phone unit
            size_t phoneid = bounds[ts] - 1; // bounds stores phone id + 1 (0 == no boundary)
            refunits[0].unit = phoneid;
            refunits[0].frames = te - ts;
            size_t edgestates = hset.gethmm(phoneid).getnumstates();
            littlematrixheap refmatrixheap(1); // for abcs
            refabcs = &refmatrixheap.newmatrix(edgestates, te - ts + 2);
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            // do alignment
            alignedge((const_array_ref<aligninfo>) refunits, hset, edgeLLs, *refabcs, 0, true, refedgealignmentsj);
            // overwrite the ground-truth senone ids with the forced alignment
            for (t = ts; t < te; t++)
            {
                uids[t] = (size_t) refedgealignmentsj[t - ts];
            }
            ts = te;
        }
    }
    // Phase 4: alignment or forwardbackward on CPU, for non-parallel mode or verification
    if (!parallelstate.enabled() || cpuverification)
    {
        edgeacscores.resize(edges.size());
        std::vector<float> edgeacscoresgpu;
        edgealignments thisedgealignmentsgpu(thisedgealignments);
        if (cpuverification)
        {
            // fetch the GPU results so they can be compared against the CPU results below
            parallelstate.getedgeacscores(edgeacscoresgpu);
            parallelstate.copyalignments(thisedgealignmentsgpu);
        }
        foreach_index (j, edges)
        {
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            if (ts == te) // dummy !NULL edge at end
                edgeacscores[j] = 0.0f;
            else
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    edgeacscores[j] = LOGZERO; // will kill word level forwardbackward hypothesis
                else if (softalignstates)
                    edgeacscores[j] = forwardbackwardedge(aligntokens, hset, edgeLLs, *abcs[j], j);
                else
                    edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, returnsenoneids, thisedgealignments[j]);
            }
            if (cpuverification)
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                bool edgehassil = false;
                foreach_index (i, aligntokens)
                {
                    if (aligntokens[i].unit == silunitid)
                        edgehassil = true;
                }
                // report acoustic-score mismatches between CPU and GPU
                if (fabs(edgeacscores[j] - edgeacscoresgpu[j]) > 1e-3)
                {
                    fprintf(stderr, "edge %d, sil ? %d, edgeacscores / edgeacscoresgpu MISMATCH %f v.s. %f, diff %e\n",
                            j, edgehassil ? 1 : 0, (float) edgeacscores[j], (float) edgeacscoresgpu[j],
                            (float) (edgeacscores[j] - edgeacscoresgpu[j]));
                    fprintf(stderr, "aligntokens: ");
                    foreach_index (i, aligntokens)
                        fprintf(stderr, "%d %d; ", i, aligntokens[i].unit);
                    fprintf(stderr, "\n");
                }
                // report per-frame alignment mismatches between CPU and GPU
                for (size_t t = ts; t < te; t++)
                {
                    if (thisedgealignments[j][t - ts] != thisedgealignmentsgpu[j][t - ts])
                        fprintf(stderr, "edge %d, sil ? %d, time %d, alignment / alignmentgpu MISMATCH %d v.s. %d\n", j, edgehassil ? 1 : 0, (int) (t - ts), thisedgealignments[j][t - ts], thisedgealignmentsgpu[j][t - ts]);
                }
            }
        }
    }
    // make sure thisedgealignments has values for later CPU use
    if (parallelstate.enabled())
    {
        parallelstate.copyalignments(thisedgealignments);
        parallelstate.getedgeacscores(edgeacscores);
    }
}
// compute the error signal for sMBR mode
void lattice::sMBRerrorsignal(parallelstate &parallelstate,
                              msra::math::ssematrixbase &errorsignal, msra::math::ssematrixbase &errorsignalneg, // output
                              const std::vector<double> &logpps, const float amf,
                              double minlogpp, const std::vector<double> &origlogpps, const std::vector<double> &logEframescorrect,
                              const double logEframescorrecttotal, const edgealignments &thisedgealignments) const
{
    // parallel (CUDA) version
    // (measured: errorsignalcompute 19.871935 ms (cuda) v.s. 448.711444 ms (emu))
    if (parallelstate.enabled())
    {
        if (minlogpp > LOGZERO)
            fprintf(stderr, "sMBRerrorsignal: pruning not supported (we won't need it!) :)\n");
        parallelsMBRerrorsignal(parallelstate, thisedgealignments, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
        return;
    }
    // CPU version: clear the accumulator, then add every edge's contribution
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = 0.0f; // Note: we don't actually put anything into the numgammas
    foreach_index (j, edges)
    {
        const auto &edge = edges[j];
        const size_t ts = nodes[edge.S].t;
        const size_t te = nodes[edge.E].t;
        if (ts == te) // dummy !NULL edge at end of file
            continue;
        if (minlogpp > LOGZERO && origlogpps[j] < minlogpp) // edge was pruned
            continue;
        // Under hard alignment the contribution of all states of an edge to
        // their senones is identical, so compute it once per edge.
        const double pp = exp(logpps[j]); // edge posterior
        const double diff = logEframescorrect[j] - logEframescorrecttotal;
        const float contribution = (float) (pp * diff) / amf;
        for (size_t t = ts; t < te; t++)
            errorsignal(thisedgealignments[j][t - ts], t) += contribution;
    }
}
// helper: draw a random bin index from a discrete cumulative distribution
// Draws a random bin index from a discrete distribution given as cumulative
// (possibly unnormalized) probabilities, e.g. {0.3, 0.4, 1.0} for the
// probabilities {0.3, 0.1, 0.6}. Returns the first bin whose cumulative mass
// covers the draw, defaulting to the last bin.
size_t sample_from_cumulative_prob(const std::vector<double> &cumulative_prob)
{
    // a draw requires at least one bin
    if (cumulative_prob.empty())
    {
        RuntimeError("sample_from_cumulative_prob: the number of bins is 0 \n");
    }
    const double draw = (double)rand() / (double)RAND_MAX * cumulative_prob.back();
    size_t bin = 0;
    // advance to the first bin covering the draw; stop at the last bin regardless
    while (bin + 1 < cumulative_prob.size() && draw > cumulative_prob[bin])
        bin++;
    return bin;
}
// Samples numPathsEMBR paths from the lattice, edge by edge, with transition
// probabilities proportional to the backward scores from backwardlatticeEMBR().
// 'edgelogbetas'/'logbetas' are per-edge/per-node backward log-scores;
// 'vt_paths' receives one edge-index sequence per sampled path.
// enforceValidPathEMBR restricts the start node to edges leading to <s>;
// excludeSpecialWords first removes edges touching special words.
void lattice::EMBRsamplepaths(const std::vector<double> &edgelogbetas,
                              const std::vector<double> &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector<vector<size_t>> & vt_paths) const
{
    // In mp_node_ocp, key is the node id, and value stores the outgoing cumulative locally normalized probability,
    // e.g., if the outgoing probabilities of the node are 0.3 0.1 0.6, the ocp stores: 0.3 0.4 1.0.
    // This serves as a cache to avoid recomputation if sampling visits the same node twice.
    std::map<size_t, vector<double>> mp_node_ocp;
    std::map<size_t, vector<double>>::iterator mp_itr;
    std::vector<size_t> path; // stores the edges in the path
    std::vector<double> ocp;
    mp_node_ocp.clear();
    vt_paths.clear();
    size_t curnodeidx, edgeidx;
    if (enforceValidPathEMBR)
    {
        // NOTE(review): erasing index i and then incrementing i skips the element
        // that shifts into slot i -- confirm erase_node_out_edges' semantics
        for (size_t i = 0; i < vt_node_out_edge_indices[0].size(); i++)
        {
            // remove the start-node edge unless it leads to <s> (wid 1)
            if (nodes[edges[vt_node_out_edge_indices[0][i]].E].wid != 1) lattice::erase_node_out_edges(0, i, i);
        }
    }
    // this is an inefficient implementation; we should think of more efficient ways to do it later
    if (excludeSpecialWords)
    {
        size_t nidx;
        for (size_t j = 0; j < vt_node_out_edge_indices.size(); j++)
        {
            for (size_t i = 0; i < vt_node_out_edge_indices[j].size(); i++)
            {
                // remove the edge if either of its endpoints is a special word
                // 0~4 is: !NULL, <s>, </s>, !sent_start, and !sent_end
                nidx = edges[vt_node_out_edge_indices[j][i]].E;
                if (nodes[nidx].wid > 4)
                {
                    if (is_special_words[nidx])
                    {
                        lattice::erase_node_out_edges(j, i, i);
                        continue;
                    }
                }
                nidx = edges[vt_node_out_edge_indices[j][i]].S;
                if (nodes[nidx].wid > 4)
                {
                    if (is_special_words[nidx]) lattice::erase_node_out_edges(j, i, i);
                }
            }
        }
    }
    while (vt_paths.size() < numPathsEMBR)
    {
        // start sampling from node 0, one outgoing edge at a time
        path.clear();
        curnodeidx = 0;
        bool success = false;
        while (true)
        {
            // get (or compute and cache) this node's cumulative outgoing probabilities
            mp_itr = mp_node_ocp.find(curnodeidx);
            if (mp_itr == mp_node_ocp.end())
            {
                ocp.clear();
                for (size_t i = 0; i < vt_node_out_edge_indices[curnodeidx].size(); i++)
                {
                    // locally normalized transition probability: exp(edge beta - node beta)
                    double prob = exp(edgelogbetas[vt_node_out_edge_indices[curnodeidx][i]] - logbetas[curnodeidx]);
                    if (i == 0) ocp.push_back(prob);
                    else ocp.push_back(prob + ocp.back());
                }
                mp_node_ocp.insert(pair<size_t, vector<double>>(curnodeidx, ocp));
                edgeidx = vt_node_out_edge_indices[curnodeidx][sample_from_cumulative_prob(ocp)];
            }
            else
            {
                edgeidx = vt_node_out_edge_indices[curnodeidx][sample_from_cumulative_prob(mp_itr->second)];
            }
            path.push_back(edgeidx);
            curnodeidx = edges[edgeidx].E;
            // the end of lattice is not !NULL (the ending !NULL is deleted in dbn.exe when converting a lattice of htk format to chunk)
            // wid = 2 is for </s>, the lattice ends with </s>
            // stop when the node has no outgoing arc
            if (vt_node_out_edge_indices[curnodeidx].size() == 0)
            {
                if ((nodes[curnodeidx].wid == 2 || nodes[curnodeidx].wid == 0) && nodes[curnodeidx].t == info.numframes)
                {
                    success = true;
                    break;
                }
                else
                {
                    // dead-end node that is not a valid sentence end: discard this path and resample
                    fprintf(stderr, "EMBRsamplepaths: WARNING: the node with index = %d has no outgoing arc, but it is not the node </s> with timing ending with last frame \n", int(curnodeidx));
                    success = false;
                    break;
                }
            }
        }
        if (success == true) vt_paths.push_back(path);
    }
    if (vt_paths.size() != numPathsEMBR)
    {
        // was fprintf + exit(-1); RuntimeError is consistent with the rest of this
        // file and lets the host process handle the failure instead of aborting
        RuntimeError("EMBRsamplepaths: Error: vt_paths.size() = %d, and numPathsEMBR = %d \n", int(vt_paths.size()), int(numPathsEMBR));
    }
}
void lattice::EMBRnbestpaths(std::vector<NBestToken>& tokenlattice, std::vector<vector<size_t>> & vt_paths, std::vector<double>& path_posterior_probs) const
{
    // tokens that survived at the final lattice node, one per n-best path
    const auto &finaltokens = tokenlattice[tokenlattice.size() - 1].vt_nbest_tokens;
    // normalize the n-best scores into posterior probabilities
    double log_nbest_posterior_prob = LOGZERO;
    for (size_t i = 0; i < finaltokens.size(); i++)
        logadd(log_nbest_posterior_prob, finaltokens[i].score);
    path_posterior_probs.resize(finaltokens.size());
    for (size_t i = 0; i < finaltokens.size(); i++)
        path_posterior_probs[i] = exp(finaltokens[i].score - log_nbest_posterior_prob);
    // trace each final token back to the start node to recover its edge sequence
    vt_paths.clear();
    for (size_t i = 0; i < finaltokens.size(); i++)
    {
        std::vector<size_t> path; // edges on this path, in order
        size_t nodeidx = tokenlattice.size() - 1;
        size_t tokenidx = i;
        while (nodeidx != 0)
        {
            const auto &tok = tokenlattice[nodeidx].vt_nbest_tokens[tokenidx];
            path.insert(path.begin(), tok.prev_edge_index);
            tokenidx = tok.prev_token_index;
            nodeidx = edges[tok.prev_edge_index].S;
        }
        vt_paths.push_back(path);
    }
}
double lattice::get_edge_weights(std::vector<size_t>& wids, std::vector<std::vector<size_t>>& vt_paths, std::vector<double>& vt_edge_weights, std::vector<double>& vt_path_posterior_probs, string getPathMethodEMBR, double& onebest_wer) const
{
struct PATHINFO
{
size_t count;
float WER;
};
std::map<string, PATHINFO> mp_path_info;
std::map<string, PATHINFO>::iterator mp_itr;
std::unordered_set<string> set_edge_path;
std::vector<double> vt_path_weights;
vt_path_weights.resize(vt_paths.size());
vector<size_t> path_ids;
double avg_wer;
avg_wer = 0;
for (size_t i = 0; i < vt_paths.size(); i++)
{
path_ids.clear();
for (size_t j = 0; j < vt_paths[i].size(); j++)
{
if (j == 0)
{
if (!is_special_words[edges[vt_paths[i][j]].S]) path_ids.push_back(nodes[edges[vt_paths[i][j]].S].wid);
nodes[edges[vt_paths[i][j]].S].wid;
}
if (!is_special_words[edges[vt_paths[i][j]].E]) path_ids.push_back(nodes[edges[vt_paths[i][j]].E].wid);
nodes[edges[vt_paths[i][j]].E].wid;
}
//linquan
//vt_path_weights[i] = compute_wer(wids, path_ids);
vt_path_weights[i] = computewerandcer(wids, path_ids, ptr_id2wordmap4node);
string pathidstr = "$";
for (size_t j = 0; j < path_ids.size(); j++) pathidstr += ("_" + std::to_string(path_ids[j]));
mp_itr = mp_path_info.find(pathidstr);
if (mp_itr != mp_path_info.end())
{
mp_itr->second.count++;
}
else
{
PATHINFO pathinfo;
pathinfo.count = 1;
pathinfo.WER = float(vt_path_weights[i]);
mp_path_info.insert(pair<string, PATHINFO>(pathidstr, pathinfo));
}
// this uses weighted avg wer
avg_wer += (vt_path_weights[i] * vt_path_posterior_probs[i]);
}
if (getPathMethodEMBR == "sampling") onebest_wer = -10000;
else onebest_wer = vt_path_weights[0];
for (size_t i = 0; i < vt_path_weights.size(); i++)
{
// loss - mean_loss
vt_path_weights[i] -= avg_wer;
if(getPathMethodEMBR == "sampling") vt_path_weights[i] /= (vt_paths.size() - 1);
else vt_path_weights[i] *= (vt_path_posterior_probs[i]);
}
for (size_t i = 0; i < vt_paths.size(); i++)
{
for (size_t j = 0; j < vt_paths[i].size(); j++)
// substraction, since we want to minimize the loss function, rather than maximize
vt_edge_weights[vt_paths[i][j]] -= vt_path_weights[i];
}
set_edge_path.clear();
for (size_t i = 0; i < vt_paths.size(); i++)
{
string pathedgeidstr = "$";
for (size_t j = 0; j < vt_paths[i].size(); j++)
{
pathedgeidstr += ("_" + std::to_string(vt_paths[i][j]));
}
set_edge_path.insert(pathedgeidstr);
}
return avg_wer;
}
void lattice::EMBRerrorsignal(parallelstate &parallelstate,
                              const edgealignments &thisedgealignments, std::vector<double>& edge_weights, msra::math::ssematrixbase &errorsignal) const
{
    Microsoft::MSR::CNTK::Matrix<float> errorsignalcpu(-1);
    if (parallelstate.enabled()) // parallel (CUDA) version
    {
        parallelstate.setedgeweights(edge_weights);
        // read the weights back (debugging aid; not otherwise used here)
        std::vector<double> verify_edge_weights;
        parallelstate.getedgeweights(verify_edge_weights);
        parallelEMBRerrorsignal(parallelstate, thisedgealignments, edge_weights, errorsignal);
        // fetch gammas to the CPU side (debugging aid; not otherwise used here)
        parallelstate.getgamma(errorsignalcpu);
        return;
    }
    // CPU version: reset the accumulator, then add each edge's weight to the
    // senone/frame cells along its alignment
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = 0.0f; // Note: we don't actually put anything into the numgammas
    foreach_index (j, edges)
    {
        const auto &edge = edges[j];
        const size_t ts = nodes[edge.S].t;
        const size_t te = nodes[edge.E].t;
        if (ts == te) // dummy !NULL edge at end of file
            continue;
        for (size_t t = ts; t < te; t++)
        {
            const size_t s = thisedgealignments[j][t - ts];
            errorsignal(s, t) += float(edge_weights[j]);
        }
    }
}
// compute the error signal for MMI mode
// Accumulates, per (senone, frame) cell, the log state posterior over all surviving lattice
// edges -- P(s,t) = sum_e P(s,t|e) * P(e), done in the log domain via logadd() -- and finally
// converts the matrix to non-log denominator gammas with expf().
//  - abcs[j]: edge j's per-state log-gamma matrix from the per-edge forward/backward
//  - logpps[j]: log posterior of edge j from the lattice-level forward/backward
//  - minlogpp/origlogpps: optional pruning (disabled when minlogpp == LOGZERO)
//  - softalignstates: soft alignment within edges; not supported in the parallel path
void lattice::mmierrorsignal(parallelstate &parallelstate, double minlogpp, const std::vector<double> &origlogpps,
                             std::vector<msra::math::ssematrixbase *> &abcs, const bool softalignstates,
                             const std::vector<double> &logpps, const msra::asr::simplesenonehmm &hset,
                             const edgealignments &thisedgealignments, msra::math::ssematrixbase &errorsignal) const
{
    if (parallelstate.enabled())
    {
        // the parallel (GPU) path handles neither pruning nor within-edge soft alignment
        if (minlogpp > LOGZERO)
            fprintf(stderr, "mmierrorsignal: pruning not supported (we won't need it!) :)\n");
        if (softalignstates)
            LogicError("mmierrorsignal: parallel version for softalignstates mode is not supported yet");
        parallelmmierrorsignal(parallelstate, thisedgealignments, logpps, errorsignal);
        return;
    }
    // initialize every cell to the "never seen" sentinel so unseen states can be counted below
    for (size_t j = 0; j < (errorsignal).cols(); j++)
        for (size_t i = 0; i < (errorsignal).rows(); i++)
            errorsignal(i, j) = VIRGINLOGZERO; // set to zero --note: may be in-place with logLLs, which now get overwritten
    // size_t warnings = 0; // [v-hansu] check code for mmi; search this comment to see all related codes
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file
            continue;
        if (minlogpp > LOGZERO && origlogpps[j] < minlogpp) // this is pruned
            continue;
        const auto &aligntokens = getaligninfo(j); // get alignment tokens
        auto &loggammas = *abcs[j];
        const float edgelogP = (float) logpps[j];
        // if (islogzero (edgelogP)) // we had a 0 prob
        //     continue;
        // accumulate this edge's gamma matrix into target posteriors
        const size_t tedge = nodes[e.S].t; // edge start frame w.r.t. the utterance
        size_t ts = 0;                     // time index into gamma matrix
        size_t js = 0;                     // state index into gamma matrix
        foreach_index (k, aligntokens) // we exploit that units have fixed boundaries
        {
            const auto &unit = aligntokens[k];
            const size_t te = ts + unit.frames;
            const auto &hmm = hset.gethmm(unit.unit); // TODO: inline these expressions
            const size_t n = hmm.getnumstates();
            const size_t je = js + n;
            // P(s) = P(s|e) * P(e)   (in log domain: statelogP = edgelogP + gammajt)
            for (size_t t = ts; t < te; t++)
            {
                const size_t tutt = t + tedge; // time index w.r.t. utterance
                // double logsum = LOGZERO; // [v-hansu] check code for mmi; search this comment to see all related codes
                for (size_t i = 0; i < n; i++)
                {
                    const size_t j2 = js + i;            // state index for this unit in matrix
                    const size_t s = hmm.getsenoneid(i); // state class index
                    const float gammajt = loggammas(j2, t);
                    const float statelogP = edgelogP + gammajt;
                    logadd(errorsignal(s, tutt), statelogP);
                }
            }
            ts = te;
            js = je;
        }
        // the gamma matrix carries 2 extra columns beyond the edge's frames -- presumably scratch
        // from the shared alpha/beta/gamma storage; TODO confirm against forwardbackwardalign
        assert(ts + 2 == loggammas.cols() && js == loggammas.rows());
    }
    // check normalizedness (is that an actual English word?)
    // also count non-zero probs
    // NOTE(review): despite the name, this counts log-ZERO (zero-probability) cells; the final
    // fprintf below compensates by printing 100 - percentage.
    size_t nonzerostates = 0;
    foreach_column (t, errorsignal)
    {
        double logsum = LOGZERO;
        foreach_row (s, errorsignal)
        {
            if (islogzero(errorsignal(s, t)))
                nonzerostates++;
            else
                logadd(logsum, (double) errorsignal(s, t));
            // TODO: count VIRGINLOGZERO, print per frame
        }
        // per column the posteriors should sum to 1, i.e. log-sum == 0 up to numerical tolerance
        if (fabs(logsum) / errorsignal.rows() > 1e-6)
            fprintf(stderr, "forwardbackward: WARNING: overall posterior column(%d) sum = exp (%.10f) != 1\n", (int) t, logsum);
    }
    fprintf(stderr, "forwardbackward: %.3f%% non-zero state posteriors\n", 100.0f - nonzerostates * 100.0f / errorsignal.rows() / errorsignal.cols());
    // convert to non-log posterior --that's what we return
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = expf(errorsignal(i, j));
}
// compute ground truth's score
// It is critical to get all details consistent with the lattice, to avoid skewing the weights.
// ... TODO: we don't need this to be a class member, actually; try to just make it a 'static' function.
// Walks the reference senone sequence 'uids' frame by frame, accumulating HMM transition
// scores (including enter/exit transitions at phoneme changes), acoustic log-likelihoods
// logLLs(senone, t), and per-word LM scores (transcriptunigrams[i] * lmf + wp); the total is
// finally divided by the acoustic scale factor 'amf'. /sil/ and /sp/ need special handling
// because they share the "silst" senone.
/*static*/ double lattice::scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigrams,
                                            const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset, const float lmf, const float wp, const float amf)
{
    if (transcript[0].firstframe != 0) // TODO: should we store the #frames instead? Then we can validate the total duration
        LogicError("scoregroundtruth: first transcript[] token does not start at frame 0");
    // get the silence models, since they are treated specially
    const size_t numframes = logLLs.cols();
    const auto &sil = hset.gethmm(hset.gethmmid("sil"));
    const auto &sp = hset.gethmm(hset.gethmmid("sp"));
    if (sp.numstates != 1 || sil.numstates != 3)
        RuntimeError("scoregroundtruth: only supports 1-state /sp/ and 3-state /sil/ tied to /sp/");
    const size_t silst = sp.senoneids[0]; // the senone shared between /sp/ and the /sil/ center state
    // loop over words
    double pathscore = 0.0;
    foreach_index (i, transcript)
    {
        // word i spans frames [ts, te); the next word's first frame (or utterance end) bounds it
        size_t ts = transcript[i].firstframe;
        size_t te = ((size_t) i + 1 < transcript.size()) ? transcript[i + 1].firstframe : numframes;
        if (ts >= te)
            LogicError("scoregroundtruth: transcript[] tokens out of order");
        // acoustic score: loop over frames
        const msra::asr::simplesenonehmm::transP *prevtransP = NULL; // previous transP
        int prevs = -1;                                              // previous state index
        for (size_t t = ts; t < te; t++)
        {
            size_t senoneid = uids[t];
            // recover the transP and state index
            int s;
            const msra::asr::simplesenonehmm::transP *transP;
            if (senoneid == silst)
            {
                if (prevtransP == &sil.gettransP()) // "silst" may be tied to /sp/, and thus the /sil/ center state may be ambiguous
                {
                    transP = prevtransP; // remain in /sil/
                    s = 1;               // note that this will fail if /sp/ can follow /sil/ and if /sil/ may end with "silst" (both not allowed currently)
                }
                else // "silst" -> we are in /sp/
                {
                    transP = &sp.gettransP();
                    s = 0;
                }
            }
            else // all others must be non-ambiguous
            {
                int transPindex = hset.senonetransP(senoneid);
                int sindex = hset.senonestate(senoneid);
                if (transPindex == -1 || sindex == -1)
                    RuntimeError("scoregroundtruth: failed to resolve ambiguous senone %s", hset.getsenonename(senoneid));
                transP = &hset.transPs[transPindex];
                s = sindex;
            }
            // if changing phoneme then add necessary enter/exit score
            if (transP != prevtransP)
            {
                if (prevtransP) // previous exit transition
                    pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
                prevs = -1; // enter transition
            }
            // add inter-state transP
            pathscore += (*transP)(prevs, s);
            // acoustic score
            pathscore += logLLs(senoneid, t);
            // remember state for next frame
            prevtransP = transP;
            prevs = s;
        }
        // add the last exit transition
        if (prevtransP) // previous exit transition
            pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
        // need to add a /sp/ tee transition if we don't end in sp, since our transcript dictionary includes a /sp/ at the end of each entry
        if (uids[te - 1] != sp.senoneids[0])
            pathscore += sp.gettransP()(-1, sp.numstates);
        // lm
        pathscore += transcriptunigrams[i] * lmf + wp;
    }
    if (islogzero(pathscore))
        fprintf(stderr, "scoregroundtruth: ground-truth path has zero probability; some model inconsistency, maybe?\n");
    // account for amf
    pathscore /= amf;
    fprintf(stderr, "scoregroundtruth: ground-truth score %.6f (%d frames)\n", pathscore, (int) numframes);
    return pathscore;
}
// ---------------------------------------------------------------------------
// sMBRdiagnostics() -- helper to print some diagnostics for analyzing sMBR results
// For each frame, prints the reference senone's error-signal value next to the
// strongest negative and strongest positive competitor, plus the best-path state,
// then a summary line with frame-accuracy counts.
// ---------------------------------------------------------------------------
// static // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
void sMBRdiagnostics(const msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids,
                     const_array_ref<size_t> bestpath, const vector<bool> &refseen, const msra::asr::simplesenonehmm &hset)
{
    // TODO:
    //  - print best positive runner-up
    //  - WARN tag if neg > pos
    //  - check the sum and warn if not 0
    //  - indicate whether best path state is correct or not
    size_t framescorrect = 0;     // frames where best path equals the reference
    size_t framesnegstronger = 0; // frames where the strongest negative competitor beats the reference
    size_t framesposstronger = 0; // frames where the strongest positive competitor beats the reference
    for (size_t t = 0; t < errorsignal.cols(); t++)
    {
        const size_t sref = uids[t];
        const char *srefname = hset.getsenonename(sref);
        const size_t sbest = bestpath[t];
        const char *sbestname = hset.getsenonename(sbest);
        const bool iscorrect = (sref == sbest);
        if (iscorrect)
            framescorrect++;
        // scan all competitors for the most negative and most positive error-signal values
        size_t negstate = SIZE_MAX; // strongest negative competitor
        float negval = 0.0f;
        size_t posstate = SIZE_MAX; // strongest positive competitor
        float posval = 0.0f;
        for (size_t s = 0; s < errorsignal.rows(); s++)
        {
            if (s == sref)
                continue;
            const float e = errorsignal(s, t);
            if (e < negval)
            {
                negstate = s;
                negval = e;
            }
            if (e > posval)
            {
                posstate = s;
                posval = e;
            }
        }
        const float eref = errorsignal(sref, t);
        const bool significant = fabs(eref) > 0.0001f; // ignore near-zero reference signal
        if (significant && eref < -negval)
            framesnegstronger++;
        if (significant && eref < posval)
            framesposstronger++;
        const char *snegname = negstate == SIZE_MAX ? "-" : hset.getsenonename(negstate);
        const char *sposname = posstate == SIZE_MAX ? "-" : hset.getsenonename(posstate);
        fprintf(stderr, "e(%d): ref %s: %.6f / %s: %.6f / %s: %.6f / top %s: %.6f%s%s%s%s%s\n",
                (int) t, srefname, eref, snegname, negval, sposname, posval, sbestname, errorsignal(sbest, t),
                iscorrect ? "" : " ERR",
                significant && eref < 0 ? " INV!!" : "",
                significant && eref < -negval ? " WEAK" : "",
                significant && eref < posval ? " 2ND" : "",
                refseen[t] ? "" : " NOREF");
    }
    // print this to validate our bestpath computation
    fprintf(stderr, "sMBRdiagnostics: %d frames correct out of %d, %.2f%%, neg better in %d, pos in %d\n",
            (int) framescorrect, (int) errorsignal.cols(), 100.0f * framescorrect / errorsignal.cols(),
            (int) framesnegstronger, (int) framesposstronger);
}
// static // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
// sMBRsuppressweirdstuff() -- zero out entire error-signal columns for "weird" frames.
// A frame is weird when the reference state's value is negative, or when any competitor's
// magnitude exceeds it; such columns are flattened to 0 and counted.
void sMBRsuppressweirdstuff(msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids)
{
    size_t numweird = 0;
    for (size_t t = 0; t < errorsignal.cols(); t++)
    {
        const size_t sref = uids[t];
        const float eref = errorsignal(sref, t);
        // negative signal for the reference is suspicious by itself ...
        bool isweird = eref < 0.0f;
        // ... as is any competitor whose magnitude exceeds the reference's value
        for (size_t s = 0; s < errorsignal.rows() && !isweird; s++)
        {
            if (s == sref)
                continue;
            if (fabs(errorsignal(s, t)) > eref)
                isweird = true;
        }
        if (!isweird)
            continue;
        // flatten the whole column
        for (size_t s = 0; s < errorsignal.rows(); s++)
            errorsignal(s, t) = 0.0f;
        numweird++;
    }
    // print this to validate our bestpath computation
    fprintf(stderr, "sMBRsuppressweirdstuff: %d weird frames out of %d, %.2f%% were flattened\n",
            (int) numweird, (int) errorsignal.cols(), 100.0f * numweird / errorsignal.cols());
}
// ---------------------------------------------------------------------------
// forwardbackward() -- main function for MMI/sMBR/EMBR
//
// This computes the lattice state-level statistics for sequence training using MMI, sMBR, or EMBR.
//
// Outputs, MMI mode:
//  - result = dengammas = denominator gammas (non-log form)
//  - returns log of sum over all paths' likelihoods (the denominator of the MMI objective)
//  - note: numgammas is not used/touched in MMI mode
//
// Outputs, sMBR mode:
// TODO: fix this comment
//  - result = errorsignal = (abs value of) negative contributions to error signal
//  - errorsignalbuf = for temporarily use to get errorsignal
//  - returns expected frames-correct count (the sMBR objective)
//
// Outputs, EMBR mode:
//  - result = error signal accumulated from sampled or n-best paths' edge weights
//  - returns the weighted average WER of the selected paths (or the one-best WER
//    when getPathMethodEMBR == "nbest" and showWERMode == "onebest")
// ---------------------------------------------------------------------------
double lattice::forwardbackward(parallelstate &parallelstate, const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset,
                                msra::math::ssematrixbase &result, msra::math::ssematrixbase &errorsignalbuf,
                                const float lmf, const float wp, const float amf, const float boostingfactor,
                                const bool sMBRmode, const bool EMBR, const string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const string getPathMethodEMBR, const string showWERMode, const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR, array_ref<size_t> uids, vector<size_t> wids, const_array_ref<size_t> bounds,
                                const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigram) const
{
    std::vector<NBestToken> tokenlattice;
    tokenlattice.clear();
    // nothing to do for an empty word sequence
    if (wids.size() == 0)
        return 0;
    if (numPathsEMBR < 1)
    {
        fprintf(stderr, "forwardbackward: WARNING: numPathsEMBR = %d , which is smaller than 1\n", (int) numPathsEMBR);
        return LOGZERO; // failed, do not use resulting matrix
    }
    if (EMBRUnit != "word")
    {
        fprintf(stderr, "forwardbackward: Error: Currently do not support EMBR unit other than word\n");
        return LOGZERO; // failed, do not use resulting matrix
    }
    // sanity check
    if (nodes[0].wid != 0)
        RuntimeError("The first node is not 0 (i.e.) !NULL, but is %d \n", int(nodes[0].wid));
    // the lattice last node could be either 0 or 2, i.e., if it is an merged lattice (merged numerator and denominator the dnb code dedicately removes ending !NULL, it is 0. If it is not merged lattice (the one that I changed TAMER code to only use denominator lattice), the last node could be !NULL
    // BUGFIX: the message used to print nodes[0].wid; report the *last* node's wid, which is what was checked
    if (nodes[nodes.size() - 1].wid != 2 && nodes[nodes.size() - 1].wid != 0)
        RuntimeError("The last node is not 2 (i.e.) </s> or 0 (i.e, !NULL), but is %d \n", int(nodes[nodes.size() - 1].wid));
    // I want to make sure there is only one </s>, it is crucial when the useAccinNbest is true: we add sentence acc into nbest cost function in the </s>.
    size_t sent_end_count = 0;
    if (nodes[nodes.size() - 1].wid == 2)
        sent_end_count = 1;
    for (size_t i = 1; i < nodes.size() - 1; i++)
    {
        if (nodes[i].wid == 2)
        {
            if (nodes[nodes.size() - 1].wid == 2 && (i != nodes.size() - 1))
            {
                //RuntimeError("The node %d wid is 2 (i.e.) </s>, but it is not the last node, total number of node is %d \n", int(i), int(nodes.size()));
                fprintf(stderr, "The node %d wid is 2 (i.e.) </s>, but it is not the last node, total number of node is %d \n", int(i), int(nodes.size()));
                return LOGZERO; // bad data, do not use resulting matrix
            }
            sent_end_count++;
        }
        // BUGFIX: wid 0 denotes !NULL (see the first-node check above), not </s>; message corrected
        if (nodes[i].wid == 0)
            RuntimeError("The node %d wid is 0 (i.e.) !NULL, but it is not the first node or last node, total number of node is %d \n", int(i), int(nodes.size()));
    }
    if (sent_end_count != 1)
    {
        RuntimeError("</s> count is not 1 in the lattice, but %d, and total number of node is %d \n", int(sent_end_count), int(nodes.size()));
    }
    bool softalign = true;
    bool softalignstates = false;      // true if soft alignment within edges, currently we only support soft within edge in cpu mode
    bool softalignlattice = softalign; // w.r.t. whole lattice
    edgealignments thisedgealignments(*this);   // alignments memory allocate for this lattice
    backpointers thisbackpointers(*this, hset); // memory for forwardbackward
    if (info.numframes != logLLs.cols())
        LogicError("forwardbackward: #frames mismatch between lattice (%d) and LLs (%d)", (int) info.numframes, (int) logLLs.cols());
    // TODO: the following checks should throw, but I don't dare in case this will crash a critical job... if we never see this warning, then
    if (info.numframes != uids.size())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and uids (%d)\n", (int) info.numframes, (int) uids.size());
    if (info.numframes != result.cols())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and result (%d)\n", (int) info.numframes, (int) result.cols());
    littlematrixheap matrixheap(info.numedges); // for abcs
    // PHASE 0: fake word level forward backwards --only used when pruning enabled
    const double minlogpp = LOGZERO; // pruning threshold --LOGZERO means disabled
    std::vector<double> origlogpps;  // word posterior from original lattice, for pruning decision
    // PHASE 1: per-edge forward backwards (="time alignments")
    // score the ground truth --only if a transcript is provided, which happens if the user provides a language model
    // TODO: no longer used, remove this. 'transcript' parameter is no longer used in this function.
    transcript;
    transcriptunigram;
    // allocate alpha/beta/gamma matrices (all are sharing the same memory in-place)
    std::vector<msra::math::ssematrixbase *> abcs;
    std::vector<float> edgeacscores; // [edge index] acoustic scores
    // return senone id for EMBR or sMBR, but not for MMI
    forwardbackwardalign(parallelstate, hset, softalignstates, minlogpp, origlogpps, abcs, matrixheap, (sMBRmode || EMBR) /*returnsenoneids*/, edgeacscores, logLLs, thisedgealignments, thisbackpointers, uids, bounds);
    // PHASE 2: lattice-level forward backward
    // we exploit that the lattice is sorted by (end node, start node) for in-place processing
    // checklattice(); // comment out by v-hansu to save time
#ifdef PRINT_TIME_MEASUREMENT
    auto_timer latlevelfwbw;
#endif
    std::vector<double> logpps;
    std::vector<double> Eframescorrectbuf; // this is used for compute the Eframescorrectdiff
    std::vector<double> logEframescorrect; // this is the final output of PHASE 2
    std::vector<double> logalphas;
    std::vector<double> logbetas;
    std::vector<double> edgelogbetas;        // the edge score plus the edge's outgoing node's beta scores
    double totalfwscore = 0;                 // TODO: name no longer precise in sMBRmode
    double logEframescorrecttotal = LOGZERO;
    double totalbwscore = 0;
    bool returnEframescorrect = sMBRmode;
    if (softalignlattice)
    {
        if (EMBR)
        {
            //compute Beta only,
            if (getPathMethodEMBR == "sampling")
            {
                totalbwscore = backwardlatticeEMBR(edgeacscores, parallelstate, edgelogbetas, logbetas, lmf, wp, amf);
                totalfwscore = totalbwscore; // to make the existing code happy
            }
            else //nbest
            {
                double bestscore = nbestlatticeEMBR(edgeacscores, parallelstate, tokenlattice, numRawPathsEMBR, enforceValidPathEMBR, excludeSpecialWords, lmf, wp, amf, wordNbest, useAccInNbest, accWeightInNbest, numPathsEMBR, wids);
                totalfwscore = bestscore; // to make the code happy, it should be called bestscore, rather than totalfwscore though, will fix later
            }
        }
        else
        {
            totalfwscore = forwardbackwardlattice(edgeacscores, parallelstate, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor, returnEframescorrect, (const_array_ref<size_t> &) uids, thisedgealignments, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
            if (sMBRmode && !returnEframescorrect)
                logEframescorrecttotal = forwardbackwardlatticesMBR(edgeacscores, hset, logalphas, logbetas, lmf, wp, amf, (const_array_ref<size_t> &) uids, thisedgealignments, Eframescorrectbuf);
            // ^^ BUGBUG not tested
        }
    }
    else
        totalfwscore = bestpathlattice(edgeacscores, logpps, lmf, wp, amf);
#ifdef PRINT_TIME_MEASUREMENT
    latlevelfwbw.show("latlevelfwbw"); // 68.395682 ms
#endif
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: totalforwardscore is zero: (%d nodes/%d edges), totalfwscore = %f \n", (int) nodes.size(), (int) edges.size(), totalfwscore);
        // in EMBR mode with excludeSpecialWords, a zero forward score can legitimately occur (no surviving token)
        if (!EMBR || !excludeSpecialWords)
            return LOGZERO; // failed, do not use resulting matrix
    }
    // PHASE 3: compute final state-level posteriors (MMI mode)
    // compute expected frames correct in sMBRmode
    const size_t numframes = logLLs.cols();
    assert(numframes == info.numframes);
    if (EMBR)
    {
        std::vector<vector<size_t>> vt_paths;
        std::vector<double> edge_weights(edges.size(), 0.0);
        std::vector<double> path_posterior_probs;
        double onebest_wer = 0.0;
        double avg_wer = 0.0;
        // for getPathMethodEMBR=sampling, the onebest_wer does not make any sense, pls. do not use it
        // ToDO: if it is logzero(totalfwscore), the criterion shown in the training log is not totally correct: for this problematic utterance, the wer is counted as 0. Problematic in the sense that: we set excludeSpecialWords is true, and found no token survive
        if (!islogzero(totalfwscore))
        {
            // Do path sampling
            if (getPathMethodEMBR == "sampling")
            {
                EMBRsamplepaths(edgelogbetas, logbetas, numPathsEMBR, enforceValidPathEMBR, excludeSpecialWords, vt_paths);
                path_posterior_probs.resize(vt_paths.size(), (1.0 / vt_paths.size())); // uniform over samples
            }
            else
            {
                EMBRnbestpaths(tokenlattice, vt_paths, path_posterior_probs);
            }
            avg_wer = get_edge_weights(wids, vt_paths, edge_weights, path_posterior_probs, getPathMethodEMBR, onebest_wer);
        }
        auto &errorsignal = result;
        EMBRerrorsignal(parallelstate, thisedgealignments, edge_weights, errorsignal);
        if (getPathMethodEMBR == "nbest" && showWERMode == "onebest")
            return onebest_wer;
        else
            return avg_wer;
    }
    else
    {
        // MMI mode
        if (!sMBRmode)
        {
            // we first take the sum in log domain to avoid numerical issues
            auto &dengammas = result; // result is denominator gammas
            mmierrorsignal(parallelstate, minlogpp, origlogpps, abcs, softalignstates, logpps, hset, thisedgealignments, dengammas);
            return totalfwscore / numframes; // return value is av. posterior
        }
        // sMBR mode
        else
        {
            auto &errorsignal = result;
            sMBRerrorsignal(parallelstate, errorsignal, errorsignalbuf, logpps, amf, minlogpp, origlogpps, logEframescorrect, logEframescorrecttotal, thisedgealignments);
            static bool dummyvariable = (fprintf(stderr, "note: new version with kappa adjustment, kappa = %.2f\n", 1 / amf), true); // we only print once
            return exp(logEframescorrecttotal) / numframes; // return value is av. expected frame-correct count
        }
    }
}
};
};