https://github.com/Microsoft/CNTK
Tip revision: 16a41cef30894ca92667bd93079cd6fa11b3e92d authored by Sayan Pathak on 02 November 2017, 16:10:10 UTC
Added super resolution tutorial contributed by Borna with added code to minimize test downloads, fix tests, added documentation and small editorial changes to LSGAN tutorial
Added super resolution tutorial contributed by Borna with added code to minimize test downloads, fix tests, added documentation and small editorial changes to LSGAN tutorial
Tip revision: 16a41ce
latticeforwardbackward.cpp
// latticeforwardbackward.cpp -- state-level and lattice-level forward/backward and Viterbi alignment on lattices
//
// F. Seide, V-hansu
#ifndef _CRT_SECURE_NO_WARNINGS
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#endif
#include "Basics.h"
#include "simple_checked_arrays.h"
#include "latticearchive.h"
#include "simplesenonehmm.h" // the model
#include "ssematrix.h" // the matrices
#include "latticestorage.h"
#include <unordered_map>
#include <list>
#include <stdexcept>
using namespace std;
// Log-domain value ten times below LOGZERO (LOGZERO is only expanded at the point of use).
#define VIRGINLOGZERO (10 * LOGZERO) // used for printing statistics on unseen states
// Keep the CPU cross-check of the parallel implementation switched off: forwardbackwardalign() only enables it under #ifdef CPU_VERIFICATION.
#undef CPU_VERIFICATION
#ifdef _WIN32
// Definition of the msra::numa::node_override global used by numahelpers.h (Windows builds only).
// NOTE(review): -1 presumably means "no NUMA node override" -- confirm in numahelpers.h.
int msra::numa::node_override = -1; // for numahelpers.h
#endif
namespace msra { namespace lattices {
// ---------------------------------------------------------------------------
// helper class for allocating lots of small matrices (memory is never freed individually)
// ---------------------------------------------------------------------------
class littlematrixheap
{
    static const size_t CHUNKSIZE; // number of floats per storage chunk (defined after the class)
    typedef msra::math::ssematrixfrombuffer matrixfrombuffer;

    std::list<std::vector<float>> heap;     // storage chunks; a std::list so existing chunks never move
    size_t allocatedinlast;                 // number of floats already used in the last chunk (heap.back())
    size_t totalallocated;                  // total floats handed out so far (bookkeeping only)
    std::vector<matrixfrombuffer> matrices; // matrix headers; must never reallocate, callers hold references into it

public:
    // estimatednumentries: number of matrices the caller intends to request.
    // The header vector is reserved once here and never grows afterwards.
    littlematrixheap(size_t estimatednumentries)
        : allocatedinlast(0), totalallocated(0) // same order as the member declarations
    {
        matrices.reserve(estimatednumentries + 1);
    }

    // Carve a rows x cols matrix out of the current chunk (opening a new, SSE-aligned chunk when the
    // current one is too small) and return a reference to it.
    // The reference stays valid for the lifetime of the heap because 'matrices' never reallocates.
    // Raises LogicError if the header capacity reserved in the constructor is exhausted; in that case
    // nothing is modified.
    msra::math::ssematrixbase &newmatrix(size_t rows, size_t cols)
    {
        // Check capacity BEFORE growing: growing past capacity would reallocate 'matrices' and leave every
        // reference returned earlier dangling.
        if (matrices.size() >= matrices.capacity())
            LogicError("newmatrix: littlematrixheap cannot grow but was constructed with too small number of eements");

        const size_t elementsneeded = matrixfrombuffer::elementsneeded(rows, cols);
        if (heap.empty() || (heap.back().size() - allocatedinlast) < elementsneeded)
        {
            // +3 leaves room to shift the start of the chunk to a multiple of 4 floats
            const size_t nelem = max(CHUNKSIZE, elementsneeded + 3 /*+3 for SSE alignment*/);
            heap.push_back(std::vector<float>(nelem));
            allocatedinlast = 0;
            // make sure starting element is SSE-aligned (the constructor demands that)
            const size_t offelem = (((size_t) &heap.back()[allocatedinlast]) / sizeof(float)) % 4;
            if (offelem != 0)
                allocatedinlast += 4 - offelem;
        }
        auto &buffer = heap.back();
        if (elementsneeded > buffer.size() - allocatedinlast)
            LogicError("newmatrix: allocation logic screwed up");

        // get our buffer into a handy vector-like thingy
        array_ref<float> vecbuffer(&buffer[allocatedinlast], elementsneeded);

        // allocate the header in place; cannot reallocate since capacity was checked above
        matrices.resize(matrices.size() + 1);
        auto &matrix = matrices.back();
        matrix = matrixfrombuffer(vecbuffer, rows, cols);
        allocatedinlast += elementsneeded;
        totalallocated += elementsneeded;
        return matrix;
    }
};
const size_t littlematrixheap::CHUNKSIZE = 256 * 1024; // 256K floats = 1 MB
// ---------------------------------------------------------------------------
// helpers for log-domain addition
// ---------------------------------------------------------------------------
#ifndef LOGZERO
#define LOGZERO (-1e30f) // log-domain stand-in for probability 0
#endif

// loga <- loga + log(1 + exp(diff)): add a ratio that is given in the log domain (float version).
// exp(-17) < 2^-24: such a ratio is below float resolution relative to 1, so it is skipped (cut off after the 24th bit).
static void logaddratio(float &loga, float diff)
{
    const bool negligible = diff < -17.0f;
    if (!negligible)
        loga += logf(1.0f + expf(diff));
}

// Double version; exp(-37) < 2^-53, the resolution of double's 53-bit significand.
static void logaddratio(double &loga, double diff)
{
    const bool negligible = diff < -37.0;
    if (!negligible)
        loga += log(1.0 + exp(diff));
}

// loga <- log(exp(loga) + exp(logb)).
// Computed as max + log(1 + exp(min - max)), i.e. the smaller term is added to the bigger one.
template <typename FLOAT>
static void logadd(FLOAT &loga, FLOAT logb)
{
    const bool bislarger = logb > loga;
    const FLOAT hi = bislarger ? logb : loga;
    const FLOAT lo = bislarger ? loga : logb;
    loga = hi;
    if (hi <= LOGZERO) // both are 0
        return;
    logaddratio(loga, lo - hi);
}

// Max approximation of logadd() (for testing).
template <typename FLOAT>
static void logmax(FLOAT &loga, FLOAT logb)
{
    loga = (logb > loga) ? logb : loga;
}

// exp(a) - exp(b), factoring out the larger exponent (for testing).
template <typename FLOAT>
static FLOAT expdiff(FLOAT a, FLOAT b)
{
    return (b > a) ? exp(b) * (exp(a - b) - 1)
                   : exp(a) * (1 - exp(b - a));
}

// True if v is low enough to be considered log(0).
template <typename FLOAT>
static bool islogzero(FLOAT v)
{
    const bool iszero = v < LOGZERO / 2;
    return iszero;
}
// ---------------------------------------------------------------------------
// other helpers go here
// ---------------------------------------------------------------------------
// helper to reconstruct the phonetic transcript
/*static*/ std::string lattice::gettranscript(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset)
{
    // Builds the space-separated sequence of HMM names of the given units, as used in diagnostic messages.
    std::string transcript;
    const size_t numunits = units.size();
    for (size_t u = 0; u < numunits; u++) // units have fixed boundaries, so their order is the time order
    {
        if (u != 0)
            transcript.push_back(' ');
        const auto &unithmm = hset.gethmm(units[u].unit);
        transcript.append(unithmm.getname());
    }
    return transcript;
}
// ---------------------------------------------------------------------------
// forwardbackwardedge() -- perform state-level forward-backward on a single lattice edge
//
// Results:
// - gammas(j,t) for valid time ranges (remaining areas are not initialized)
// - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------
// State-level forward/backward over a single lattice edge whose unit (phone) boundaries are fixed.
// Parameters:
//  - units:     the edge's units in time order; units[k].unit selects the HMM, units[k].frames is its duration (0 = tee)
//  - hset:      HMM set providing per-unit number of states, senone ids and transition matrix transP
//  - logLLs:    acoustic log-likelihoods accessed as logLLs(senone id, frame), one column per frame of the edge
//  - loggammas: output; logLLs.cols() + 2 columns; its rows must match the total number of states over all
//               units (asserted after the forward pass)
//  - edgeindex: used only in diagnostic messages
// Returns the total forward log score of the edge, or LOGZERO if the backward score came out as log(0).
/*static*/ float lattice::forwardbackwardedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
msra::math::ssematrixbase &loggammas, size_t edgeindex)
{
// alphas and betas are stored in-place inside the loggammas matrix, shifted by one (alphas) and two (betas) columns
assert(loggammas.cols() == logLLs.cols() + 2);
msra::math::ssematrixstriperef<msra::math::ssematrixbase> logalphas(loggammas, 1, logLLs.cols()); // shifted views into gammas(,) for alphas and betas
msra::math::ssematrixstriperef<msra::math::ssematrixbase> logbetas(loggammas, 2, logLLs.cols());
// alphas(j,t) store the sum of all paths up to including state j at time t, including logLL(j,t)
// betas(j,t) store the sum of all paths exiting from state j at time t, not including logLL(j,t)
// gammas(j,t) = alphas(j,t) * betas(j,t) / totalLL
// backward pass --token passing
size_t te = logbetas.cols();
size_t je = logbetas.rows();
float bwscore = 0.0f; // backward score
// k runs down to 0; 'k + 1 > 0' terminates because size_t k wraps around after 0
for (size_t k = units.size() - 1; k + 1 > 0; k--)
{
const auto &hmm = hset.gethmm(units[k].unit);
const size_t n = hmm.getnumstates();
const auto &transP = hmm.gettransP();
const size_t ts = te - units[k].frames; // start time of current unit (we walk backwards from its end te)
const size_t js = je - n; // range of state indices
// pass in the transition score
// t = ts: exit transition (last frame only or tee transition)
// exitscore is always overwritten below: by the tee case, or by to == 0 at t == ts
float exitscore = 1e30f; // (something impossible)
if (te == ts) // tee transition
{
exitscore = bwscore + transP(-1, n);
}
else // not tee: expand all last states
{
for (size_t from = 0 /*no tee possible here*/; from < n; from++)
{
const size_t i = js + from; // origin trellis node
logbetas(i, te - 1) = bwscore + transP(from, n);
}
}
// expand from states j at time t (not yet including LL) to time t-1
for (size_t t = te - 1; t + 1 > ts /*note: cannot test t >= ts because t < 0 possible*/; t--)
{
for (size_t to = 0; to < n; to++)
{
const size_t j = js + to; // source trellis node
const size_t s = hmm.getsenoneid(to); // senone id for state at position 'to' in the HMM
const float acLL = logLLs(s, t);
if (islogzero(acLL))
fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) ac score(%d,%d) is zero (%d st, %d fr: %s)\n",
(int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te,
(int) s, (int) t,
(int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
const float betajt = logbetas(j, t); // sum over all all path exiting from (j,t) to end
const float betajtpll = betajt + acLL; // incorporate acoustic score
if (t > ts)
for (size_t from = 0 /*no transition from entry state*/; from < n; from++)
{
const size_t i = js + from; // target trellis node
const float pathscore = betajtpll + transP(from, to);
// to == 0 is the first contribution to column t-1, so it initializes every row of this unit; later 'to' accumulate
if (to == 0)
logbetas(i, t - 1 /*propagate into preceding frame*/) = pathscore;
else
logadd(logbetas(i, t - 1 /*propagate into preceding frame*/), pathscore);
}
else // transition to entry state
{
const float pathscore = betajtpll + transP(-1, to);
if (to == 0)
exitscore = pathscore;
else
logadd(exitscore, pathscore); // propagate into preceding unit
}
}
}
bwscore = exitscore;
if (islogzero(bwscore))
fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d unit %d (%s) frames [%d,%d) bw score is zero (%d st, %d fr: %s)\n",
(int) edgeindex, (int) k, hmm.getname(), (int) ts, (int) te, (int) logbetas.rows(), (int) logbetas.cols(), gettranscript(units, hset).c_str());
te = ts;
je = js;
}
assert(te == 0 && je == 0);
const float totalbwscore = bwscore;
// forward pass --regular Viterbi
// This also computes the gammas right away.
size_t ts = 0; // start frame for unit 'k'
size_t js = 0; // first row index of unit ' k'
float fwscore = 0.0f; // score passed across phone boundaries
foreach_index (k, units) // we exploit that units have fixed boundaries
{
const auto &hmm = hset.gethmm(units[k].unit);
const size_t n = hmm.getnumstates();
const auto &transP = hmm.gettransP();
const size_t te2 = ts + units[k].frames; // end time of current unit
const size_t je2 = js + n; // range of state indices
// expand from states j at time t (including LL) to time t+1
for (size_t t = ts; t < te2; t++) // note: loop not entered for 0-frame units (tees)
{
for (size_t to = 0; to < n; to++)
{
const size_t j = js + to; // target trellis node
const size_t s = hmm.getsenoneid(to);
const float acLL = logLLs(s, t);
float alphajtnoll = LOGZERO;
if (t == ts) // entering score
{
const float pathscore = fwscore + transP(-1, to);
alphajtnoll = pathscore;
}
else
for (size_t from = 0 /*no entering possible*/; from < n; from++)
{
const size_t i = js + from; // origin trellis node
const float alphaitm1 = logalphas(i, t - 1 /*previous frame*/);
const float pathscore = alphaitm1 + transP(from, to);
logadd(alphajtnoll, pathscore);
}
logalphas(j, t) = alphajtnoll + acLL;
}
// update the gammas --do it here because in next frame, betas get overwritten by alphas (they share memory)
// (gammas(,t) in turn overwrite alphas(,t-1), which are no longer needed at this point)
for (size_t j = js; j < je2; j++)
{
if (!islogzero(totalbwscore))
loggammas(j, t) = logalphas(j, t) + logbetas(j, t) - totalbwscore;
else // 0/0 problem, can occur if an ac score is so bad that it is 0 after going through softmax
loggammas(j, t) = LOGZERO;
}
}
// t = te2: exit transition (last frame only or tee transition)
float exitscore;
if (te2 == ts) // tee transition
{
exitscore = fwscore + transP(-1, n);
}
else // not tee: expand all last states
{
exitscore = LOGZERO;
for (size_t from = 0 /*no tee possible here*/; from < n; from++)
{
const size_t i = js + from; // origin trellis node
const float alphaitm1 = logalphas(i, te2 - 1); // newly computed path score, transiting to t=te2
const float pathscore = alphaitm1 + transP(from, n);
logadd(exitscore, pathscore);
}
}
fwscore = exitscore; // score passed on to next unit
js = je2;
ts = te2;
}
assert(js == logalphas.rows() && ts == logalphas.cols());
const float totalfwscore = fwscore;
// in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
// These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
// XOR: warn if exactly one of the two totals is log(0)
if (islogzero(totalbwscore) ^ islogzero(totalfwscore))
fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw 0 score %.10f vs. %.10f (%d st, %d fr: %s)\n",
(int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
if (islogzero(totalbwscore))
{
fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
(int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
return LOGZERO;
}
// sanity check: per-frame mismatch between forward and backward totals
// NOTE(review): divides by ts (number of frames); a 0-frame edge would give inf/NaN here
if (fabsf(totalfwscore - totalbwscore) / ts > 1e-4f)
fprintf(stderr, "forwardbackwardedge: WARNING: edge J=%d fw and bw score %.10f vs. %.10f (%d st, %d fr: %s)\n",
(int) edgeindex, (float) totalfwscore, (float) totalbwscore, (int) js, (int) ts, gettranscript(units, hset).c_str());
// we return the full path score
return totalfwscore;
}
// ---------------------------------------------------------------------------
// alignedge() -- perform Viterbi alignment on a single edge
//
// This is an alternative to forwardbackwardedge() that just uses the best path.
// Results:
// - if not returnsenoneids -> 'binary gammas(j,t)' for valid time ranges (remaining areas are not initialized); MMI-compatible
// - if returnsenoneids -> thisedgealignmentsj[t] will contain the senone ids directly instead (for sMBR mode)
// - return value is edge acoustic score
// Gammas matrix must have two extra columns as buffer.
// ---------------------------------------------------------------------------
// Viterbi alignment of a single lattice edge with fixed unit (phone) boundaries.
// Parameters:
//  - units:               the edge's units in time order (units[k].unit = HMM, units[k].frames = duration, 0 = tee)
//  - hset:                HMM set providing states, senone ids and transition matrices
//  - logLLs:              acoustic log-likelihoods accessed as logLLs(senone id, frame)
//  - loggammas:           scratch + output, logLLs.cols() + 2 columns; when !returnsenoneids it receives binary
//                         log-gammas (0 on the best path, LOGZERO elsewhere) for the edge's states
//  - edgeindex:           used only in diagnostic messages
//  - returnsenoneids:     if true, write the best path's senone id per frame into thisedgealignmentsj[t] instead
//  - thisedgealignmentsj: per-frame senone-id output (only written when returnsenoneids)
// Returns the best path's log score, or LOGZERO if it is log(0).
// Raises LogicError on inconsistent backpointers.
/*static*/ float lattice::alignedge(const_array_ref<aligninfo> units, const msra::asr::simplesenonehmm &hset, const msra::math::ssematrixbase &logLLs,
msra::math::ssematrixbase &loggammas, size_t edgeindex /*for diagnostic messages*/, const bool returnsenoneids,
array_ref<unsigned short> thisedgealignmentsj)
{
// backpointers and path scores are stored in-place inside the loggammas matrix, at column offsets 0 and 2.
// Writing backpointers(j,t) overwrites pathscores(j,t-2), which is no longer needed by then.
assert(loggammas.cols() == logLLs.cols() + 2);
msra::math::ssematrixstriperef<msra::math::ssematrixbase> backpointers(loggammas, 0, logLLs.cols());
msra::math::ssematrixstriperef<msra::math::ssematrixbase> pathscores(loggammas, 2, logLLs.cols());
// pathscores(j,t) store the score of the best path ending in state j at time t, including logLL(j,t)
// backpointers(j,t) store (as float) the absolute trellis row index of the best predecessor state, or invalidbp
// gammas(j,t) <- 1 if on best path, 0 otherwise
const int invalidbp = -2;
// Viterbi alignment
size_t ts = 0; // start frame for unit 'k'
size_t js = 0; // first row index of unit 'k'
float fwscore = 0.0f; // score passed across phone boundaries
int fwbackpointer = -1; // bp passed across phone boundaries, -1 means start of utterance
foreach_index (k, units) // we exploit that units have fixed boundaries
{
const auto &hmm = hset.gethmm(units[k].unit);
const size_t n = hmm.getnumstates();
const auto &transP = hmm.gettransP();
const size_t te = ts + units[k].frames; // end time of current unit
const size_t je = js + hmm.getnumstates(); // range of state indices
// expand from states j at time t (including LL) to time t+1
for (size_t t = ts; t < te; t++) // note: loop not entered for 0-frame units (tees)
{
for (size_t j = js; j < je; j++)
{
const size_t to = j - js; // relative state
const size_t s = hmm.getsenoneid(to);
pathscores(j, t) = LOGZERO;
backpointers(j, t) = invalidbp;
if (t == ts) // entering score
{
const float pathscore = fwscore + transP(-1, to);
pathscores(j, t) = pathscore;
backpointers(j, t) = (float) fwbackpointer;
}
else
for (size_t i = js; i < je; i++)
{
const size_t from = i - js;
const float alphaitm1 = pathscores(i, t - 1 /*previous frame*/);
const float pathscore = alphaitm1 + transP(from, to);
if (pathscore > pathscores(j, t)) // Viterbi: keep the max, not the sum
{
pathscores(j, t) = pathscore;
backpointers(j, t) = (float) i;
}
}
const float acLL = logLLs(s, t);
pathscores(j, t) += acLL;
}
}
// t = te: exit transition (last frame only or tee transition)
float exitscore = LOGZERO;
int exitbackpointer = invalidbp;
if (te == ts) // tee transition
{
exitscore = fwscore + transP(-1, n);
exitbackpointer = fwbackpointer;
}
else // not tee: expand all last states
{
for (size_t i = js; i < je; i++)
{
const size_t from = i - js;
const float alphaitm1 = pathscores(i, te - 1); // newly computed path score, transiting to t=te
const float pathscore = alphaitm1 + transP(from, n);
if (pathscore > exitscore)
{
exitscore = pathscore;
exitbackpointer = (int) i;
}
}
}
if (exitbackpointer == invalidbp)
LogicError("exitbackpointer came up empty");
fwscore = exitscore; // score passed on to next unit
fwbackpointer = exitbackpointer; // and accompanying backpointer
js = je;
ts = te;
}
assert(js == pathscores.rows() && ts == pathscores.cols());
// in extreme cases, we may have 0 ac probs, which lead to 0 path scores and division by 0 (subtracting LOGZERO)
// These cases must be handled separately. If the whole path is 0 (0 prob is on the only path at some point) then skip the lattice.
if (islogzero(fwscore))
{
fprintf(stderr, "alignedge: WARNING: edge J=%d has zero ac. score (%d st, %d fr: %s)\n",
(int) edgeindex, (int) js, (int) ts, gettranscript(units, hset).c_str());
return LOGZERO;
}
// traceback & gamma update
size_t te = backpointers.cols();
size_t je = backpointers.rows();
int j = fwbackpointer; // absolute row index of the best final state
for (size_t k = units.size() - 1; k + 1 > 0; k--) // go in units because we also need to clear out the column
{
const auto &hmm = hset.gethmm(units[k].unit);
const size_t ts2 = te - units[k].frames; // start time of current unit
const size_t js2 = je - hmm.getnumstates(); // range of state indices
for (size_t t = te - 1; t + 1 > ts2; t--)
{
if (j < (int)js2 || j >= (int) je)
LogicError("invalid backpointer resulting in state index out of range");
int bp = (int) backpointers(j, t); // save the backpointer before overwriting it (gammas and backpointers are aliases of each other)
// thisedgealignmentsj[t] = (unsigned short)hmm.getsenoneid(j - js2);
if (!returnsenoneids) // return binary gammas (for MMI; this mode is compatible with softalignmode)
for (size_t i = js2; i < je; i++)
loggammas(i, t) = ((int) i == j) ? 0.0f : LOGZERO;
else // return senone id (for sMBR; note: NOT compatible with softalignmode; calling code must know this)
thisedgealignmentsj[t] = (unsigned short) hmm.getsenoneid(j - js2);
if (bp == invalidbp)
LogicError("deltabackpointer not initialized");
j = bp; // trace back one step
}
te = ts2;
je = js2;
}
// a consistent traceback ends at the -1 marker set for the start of the utterance
if (j != -1)
LogicError("invalid backpointer resulting in not reaching start of utterance when tracing back");
assert(je == 0 && te == 0);
// we return the full path score
return fwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlattice() -- lattice-level forward/backward
//
// This computes word posteriors, and also returns the per-node alphas and betas.
// Per-edge acoustic scores are passed in via a lambda, as this function is
// intended for use at multiple places with different scores.
// (Specifically, we also use it to determine a pruning threshold, based on
// the original lattice's ac. scores, before even bothering to compute the
// new ac. scores.)
// ---------------------------------------------------------------------------
double lattice::forwardbackwardlattice(const std::vector<float> &edgeacscores, parallelstate ¶llelstate, std::vector<double> &logpps,
std::vector<double> &logalphas, std::vector<double> &logbetas,
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode,
const_array_ref<size_t> &uids, const edgealignments &thisedgealignments,
std::vector<double> &logEframescorrect, std::vector<double> &Eframescorrectbuf, double &logEframescorrecttotal) const
{ // ^^ TODO: remove this
// --- hand off to parallelized (CUDA) implementation if available
if (parallelstate.enabled())
{
double totalfwscore = parallelforwardbackwardlattice(parallelstate, edgeacscores, thisedgealignments, lmf, wp, amf, boostingfactor, logpps, logalphas, logbetas, sMBRmode, uids, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
return totalfwscore;
}
// if we get here, we have no CUDA, and do it the good ol' way
// allocate return values
logpps.resize(edges.size()); // this is our primary return value
// TODO: these are return values as well, but really shouldn't anymore; only used in some older baseline code we some day may want to compare against
logalphas.assign(nodes.size(), LOGZERO);
logalphas.front() = 0.0f;
logbetas.assign(nodes.size(), LOGZERO);
logbetas.back() = 0.0f;
// --- sMBR version
if (sMBRmode)
{
logEframescorrect.resize(edges.size());
Eframescorrectbuf.resize(edges.size());
std::vector<double> logaccalphas(nodes.size(), LOGZERO); // [i] expected frames-correct count over all paths from start to node i
std::vector<double> logaccbetas(nodes.size(), LOGZERO); // [i] likewise
std::vector<double> logframescorrectedge(edges.size()); // raw counts of correct frames in each edge
// forward pass
foreach_index (j, edges)
{
if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
continue;
const auto &e = edges[j];
const double inscore = logalphas[e.S];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
logadd(logalphas[e.E], pathscore);
size_t ts = nodes[e.S].t;
size_t te = nodes[e.E].t;
size_t framescorrect = 0; // count raw number of correct frames
for (size_t t = ts; t < te; t++)
framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
logframescorrectedge[j] = (framescorrect > 0) ? log((double) framescorrect) : LOGZERO; // remember for backward pass
double loginaccs = logaccalphas[e.S] - logalphas[e.S];
logadd(loginaccs, logframescorrectedge[j]);
double logpathacc = loginaccs + logalphas[e.S] + edgescore;
logadd(logaccalphas[e.E], logpathacc);
}
foreach_index (j, logaccalphas)
logaccalphas[j] -= logalphas[j];
const double totalfwscore = logalphas.back();
const double totalfwacc = logaccalphas.back();
if (islogzero(totalfwscore))
{
fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
return LOGZERO; // failed, do not use resulting matrix
}
// backward pass and computation of state-conditioned frames-correct count
for (size_t j = edges.size() - 1; j + 1 > 0; j--)
{
if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
continue;
const auto &e = edges[j];
const double inscore = logbetas[e.E];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
logadd(logbetas[e.S], pathscore);
double loginaccs = logaccbetas[e.E] - logbetas[e.E];
logadd(loginaccs, logframescorrectedge[j]);
double logpathacc = loginaccs + logbetas[e.E] + edgescore;
logadd(logaccbetas[e.S], logpathacc);
// sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
if (logpp > 1e-2)
fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
if (logpp > 0.0)
logpp = 0.0;
logpps[j] = logpp;
double tmplogeframecorrect = logframescorrectedge[j];
logadd(tmplogeframecorrect, logaccalphas[e.S]);
logadd(tmplogeframecorrect, logaccbetas[e.E] - logbetas[e.E]);
Eframescorrectbuf[j] = exp(tmplogeframecorrect);
}
foreach_index (j, logaccbetas)
logaccbetas[j] -= logbetas[j];
const double totalbwscore = logbetas.front();
const double totalbwacc = logaccbetas.front();
if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());
if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
logEframescorrecttotal = totalbwacc;
return totalbwscore;
}
// --- MMI version
// forward pass
foreach_index (j, edges)
{
const auto &e = edges[j];
const double inscore = logalphas[e.S];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
const double pathscore = inscore + edgescore;
logadd(logalphas[e.E], pathscore);
}
const double totalfwscore = logalphas.back();
if (islogzero(totalfwscore))
{
fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
return LOGZERO; // failed, do not use resulting matrix
}
// backward pass
// this also computes the word posteriors on the fly, since we are at it
for (size_t j = edges.size() - 1; j + 1 > 0; j--)
{
const auto &e = edges[j];
const double inscore = logbetas[e.E];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
const double pathscore = inscore + edgescore;
logadd(logbetas[e.S], pathscore);
// compute lattice posteriors on the fly since we are at it
double logpp = logalphas[e.S] + edgescore + logbetas[e.E] - totalfwscore;
if (logpp > 1e-2)
fprintf(stderr, "forwardbackward: WARNING: edge J=%d log posterior %.10f > 0\n", (int) j, (float) logpp);
if (logpp > 0.0)
logpp = 0.0;
logpps[j] = logpp;
}
const double totalbwscore = logbetas.front();
if (fabs(totalfwscore - totalbwscore) / info.numframes > 1e-4)
fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());
return totalfwscore;
}
// ---------------------------------------------------------------------------
// forwardbackwardlatticesMBR() -- compute expected frame-accuracy counts,
// both the conditioned one (corresponding to c(q) in Dan Povey's thesis)
// and the global one (which is the sMBR criterion to optimize).
//
// Outputs:
// - Eframescorrect[j] == expected frames-correct count conditioned on a state of edge[j].
// We currently assume a hard state alignment. With that, the value turns out
// to be identical for all states of an edge, so we only store it once per edge.
// - return value: expected frames-correct count for entire lattice
//
// Call forwardbackwardlattice() first to compute logalphas/logbetas.
// ---------------------------------------------------------------------------
double lattice::forwardbackwardlatticesMBR(const std::vector<float> &edgeacscores, const msra::asr::simplesenonehmm &hset,
                                           const std::vector<double> &logalphas, const std::vector<double> &logbetas,
                                           const float lmf, const float wp, const float amf, const_array_ref<size_t> &uids,
                                           const edgealignments &thisedgealignments, std::vector<double> &Eframescorrect) const
{
    // Expected frames-correct accumulation in the linear domain, weighting each edge by its share of the
    // per-node forward (or backward) score. logalphas/logbetas must be the final per-node scores.
    // Edges with edgeacscores[j] == LOGZERO are pruned and skipped (their Eframescorrect[j] is left untouched).
    (void) hset; // not used by the computation; kept in the interface for reference

    // Eframescorrect[j] is written for every non-pruned edge j; make sure that index is valid.
    if (Eframescorrect.size() < edges.size())
        Eframescorrect.resize(edges.size());

    std::vector<double> accalphas(nodes.size(), 0);             // [i] expected frames-correct count over all paths from start to node i
    std::vector<double> accbetas(nodes.size(), 0);              // [i] likewise
    std::vector<size_t> maxcorrect(nodes.size(), 0);            // [i] max correct frames up to this node (oracle)
    std::vector<double> framescorrectedge(edges.size());        // raw counts of correct frames in each edge
    std::vector<int> backpointersformaxcorr(nodes.size(), -2);  // keep track of backpointer for the max corr (not read yet, see TODO below)
    backpointersformaxcorr.front() = -1;

    // forward pass
    foreach_index (j, edges)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accalphas[e.S];
        size_t ts = nodes[e.S].t;
        size_t te = nodes[e.E].t;
        size_t framescorrect = 0; // count raw number of correct frames
        for (size_t t = ts; t < te; t++)
            framescorrect += (thisedgealignments[j][t - ts] == uids[t]);
        framescorrectedge[j] = (double) framescorrect; // remember for backward pass
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        // contribution to end node's path acc = start node's plus edge's correct count, weighted by LL, and divided by sum over LLs
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logalphas[e.S] + edgescore - logalphas[e.E]);
        accalphas[e.E] += pathacc;
        // also keep track of max accuracy, so we can find out whether the lattice contains the correct path
        size_t oracleframescorrect = maxcorrect[e.S] + framescorrect; // keep track of most correct path up to end of this edge
        if (oracleframescorrect > maxcorrect[e.E])
        {
            maxcorrect[e.E] = oracleframescorrect;
            backpointersformaxcorr[size_t(e.E)] = j;
        }
    }
    const double totalfwacc = accalphas.back();

    // report on ground-truth path
    // TODO: we will later have code that adds this path if needed
    size_t oracleframeacc = maxcorrect.back();
    if (oracleframeacc != info.numframes)
        fprintf(stderr, "forwardbackwardlatticesMBR: ground-truth path missing from lattice (most correct path: %d out of %d frames correct)\n", (int) oracleframeacc, (int) info.numframes);

    // backward pass and computation of state-conditioned frames-correct count
    for (size_t j = edges.size() - 1; j + 1 > 0; j--)
    {
        if (islogzero(edgeacscores[j])) // indicates that this edge is pruned
            continue;
        const auto &e = edges[j];
        const double inaccs = accbetas[e.E];
        const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        double pathacc = (inaccs + framescorrectedge[j]) * exp(logbetas[e.E] + edgescore - logbetas[e.S]);
        accbetas[e.S] += pathacc;

        // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
        // (stored at full double precision; the output vector is double)
        Eframescorrect[j] = accalphas[e.S] + accbetas[e.E] + framescorrectedge[j];
    }
    const double totalbwacc = accbetas.front();
    if (fabs(totalfwacc - totalbwacc) / info.numframes > 1e-4)
        fprintf(stderr, "forwardbackwardlatticesMBR: WARNING: lattice fw and bw accs %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
    return totalbwacc;
}
// ---------------------------------------------------------------------------
// bestpathlattice() -- lattice-level "forward/backward" that only returns the
// best path, but in the form of word posteriors, which are 1 or 0, just like
// a real lattice-level forward/backward would do.
// We don't really use this; this was only for a contrast experiment.
// ---------------------------------------------------------------------------
double lattice::bestpathlattice(const std::vector<float> &edgeacscores, std::vector<double> &logpps,
const float lmf, const float wp, const float amf) const
{
// forward pass --sortnedness => regular Viterbi
std::vector<double> logalphas(nodes.size(), LOGZERO);
std::vector<int> backpointers(nodes.size(), -2);
logalphas.front() = 0.0f;
backpointers.front() = -1;
foreach_index (j, edges)
{
const auto &e = edges[j];
const double inscore = logalphas[e.S];
const double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
const double pathscore = inscore + edgescore;
if (pathscore > logalphas[e.E])
{
logalphas[e.E] = pathscore;
backpointers[e.E] = j;
}
}
const double totalfwscore = logalphas.back();
if (islogzero(totalfwscore))
{
fprintf(stderr, "bestpathlattice: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
return LOGZERO; // failed, do not use resulting matrix
}
// traceback
// We encode the result by storing log 1 in edges on the best path, and log 0 else;
// this makes it naturally compatible with softalign mode
logpps.resize(edges.size());
foreach_index (j, edges)
logpps[j] = LOGZERO;
int backpos = backpointers[nodes.size() - 1];
while (backpos >= 0)
{
logpps[backpos] = 0.0f; // edge is on best path -> PP = 1.0
backpos = backpointers[edges[backpos].S];
}
assert(backpos == -1);
return totalfwscore;
}
// ---------------------------------------------------------------------------
// alignmentcontainsunit() -- helper: true if any alignment token of an edge
// belongs to HMM unit 'unitid' (used to detect edges containing silence)
// ---------------------------------------------------------------------------
template <class ALIGNTOKENS>
static bool alignmentcontainsunit(const ALIGNTOKENS &aligntokens, const size_t unitid)
{
    for (size_t i = 0; i < aligntokens.size(); i++)
        if (aligntokens[i].unit == unitid)
            return true;
    return false;
}
// ---------------------------------------------------------------------------
// forwardbackwardalign() -- compute the statelevel gammas or viterbi alignments
// the first phase of lattice::forwardbackward
//
// Outputs:
//  - abcs[j]: per-edge alpha/beta/gamma matrix (allocated from matrixheap), or NULL if not
//    allocated (dummy 0-frame edge, pruned edge, or an edge handled by the parallel code)
//  - edgeacscores[j]: per-edge acoustic score; in the CPU path LOGZERO for pruned edges and
//    0 for the dummy 0-frame edge at the end
//  - thisedgealignments[j]: per-edge state alignment (hard-alignment mode)
//  - uids: if 'bounds' is non-empty, overwritten per frame with a forced alignment of the
//    reference phone segments given by 'bounds' (alignedge() called with returnsenoneids = true)
// ---------------------------------------------------------------------------
void lattice::forwardbackwardalign(parallelstate &parallelstate,
                                   const msra::asr::simplesenonehmm &hset, const bool softalignstates,
                                   const double minlogpp, const std::vector<double> &origlogpps,
                                   std::vector<msra::math::ssematrixbase *> &abcs, littlematrixheap &matrixheap,
                                   const bool returnsenoneids,
                                   std::vector<float> &edgeacscores, const msra::math::ssematrixbase &logLLs,
                                   edgealignments &thisedgealignments, backpointers &thisbackpointers, array_ref<size_t> &uids, const_array_ref<size_t> bounds) const
{ // NOTE: this will be removed and replaced by a proper representation of alignments someday
    // do forward-backward or alignment on a per-edge basis. This gives us:
    // - per-edge gamma[j,t] = P(s(t)==s_j|edge) if forwardbackward, per-edge alignment thisedgealignments[j] if alignment
    // - per-edge acoustic scores
    const size_t silunitid = hset.gethmmid("sil"); // shall be the same as parallelstate.getsilunitid()
    bool parallelsil = true;
    bool cpuverification = false;
#ifndef PARALLEL_SIL // we use a define to make this marked
    parallelsil = false;
#endif
#ifdef CPU_VERIFICATION
    cpuverification = true;
#endif
    // Phase 1: abcs allocate
    if (!parallelstate.enabled() || !parallelsil || cpuverification) // allocate abcs when 1.parallelstate not enabled (cpu mode); 2. enabled but not PARALLEL_SIL (silence need to be allocate); 3. cpuverfication
    {
        abcs.resize(edges.size(), NULL); // [edge index] -> alpha/beta/gamma matrices for each edge
        size_t countskip = 0;            // if pruning: count how many edges are pruned
        foreach_index (j, edges)
        {
            // determine number of frames
            // TODO: this is not efficient--we only use a block-diagonal-like structure, rest is empty (exploiting the fixed boundaries)
            const size_t edgeframes = nodes[edges[j].E].t - nodes[edges[j].S].t;
            if (edgeframes == 0) // dummy !NULL edge at end of lattice
            {
                if ((size_t) j != edges.size() - 1)
                    RuntimeError("forwardbackwardalign: unxpected 0-frame edge (only allowed at very end)");
                // note: abcs[j] is already initialized to be NULL in this case, which protects us from accidentally using it
            }
            else
            {
                // determine the number of states in an edge
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                // !cpuverification, parallel & is non sil: handled on the GPU, we do not allocate
                if (!cpuverification && !alignmentcontainsunit(aligntokens, silunitid) && parallelstate.enabled())
                {
                    abcs[j] = NULL;
                    continue;
                }
                size_t edgestates = 0;
                foreach_index (k, aligntokens)
                    edgestates += hset.gethmm(aligntokens[k].unit).getnumstates();
                // allocate the matrix
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    countskip++;
                else
                    abcs[j] = &matrixheap.newmatrix(edgestates, edgeframes + 2); // +2 to have one extra column for betas and one for gammas
            }
        }
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: %d of %d edges pruned\n", (int) countskip, (int) edges.size());
    }
    // Phase 2: alignment on CPU
    if (parallelstate.enabled() && !parallelsil) // silence edge shall be process separately if not cuda and not PARALLEL_SIL
    {
        if (softalignstates)
            LogicError("forwardbackwardalign: parallelized version currently only handles hard alignments");
        if (minlogpp > LOGZERO)
            fprintf(stderr, "forwardbackwardalign: pruning not supported (we won't need it!) :)\n");
        edgeacscores.resize(edges.size());
        for (size_t j = 0; j < edges.size(); j++)
        {
            const auto &aligntokens = getaligninfo(j); // get alignment tokens
            if (aligntokens.size() == 0)
                continue;
            if (!alignmentcontainsunit(aligntokens, silunitid)) // only process sil
                continue;
            // Pruned edges got no matrix in Phase 1 (abcs[j] == NULL); score them like Phase 4 does
            // instead of dereferencing NULL below.
            if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
            {
                edgeacscores[j] = LOGZERO;
                continue;
            }
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, true, thisedgealignments[j]);
        }
    }
    // Phase 3: alignment on GPU
    if (parallelstate.enabled())
        parallelforwardbackwardalign(parallelstate, hset, logLLs, edgeacscores, thisedgealignments, thisbackpointers);
    // zhaorui align to reference mlf
    // bounds[t] != 0 marks the first frame of a reference phone and holds (phone id + 1); each phone
    // segment is aligned on its own, and the per-frame result is written into uids[].
    if (bounds.size() > 0)
    {
        const size_t framenum = bounds.size();
        vector<aligninfo> refinfo(1);
        vector<unsigned short> refalign(framenum);
        array_ref<aligninfo> refunits(refinfo.data(), 1);
        array_ref<unsigned short> refedgealignmentsj(refalign.data(), framenum);
        size_t ts = 0;
        size_t te = 0;
        while (te < framenum)
        {
            // found one phone's boundary (ts, te)
            size_t t = ts + 1;
            while (t < framenum && bounds[t] == 0)
                t++;
            te = t;
            // After the first segment, ts always lands on a non-zero boundary; only frame 0 can be 0,
            // which would make 'bounds[ts] - 1' wrap around to a huge phone id.
            if (bounds[ts] == 0)
                RuntimeError("forwardbackwardalign: reference boundary at frame %d carries no phone id", (int) ts);
            // make one phone unit
            const size_t phoneid = bounds[ts] - 1;
            refunits[0].unit = phoneid;
            refunits[0].frames = te - ts;
            const size_t edgestates = hset.gethmm(phoneid).getnumstates();
            littlematrixheap refmatrixheap(1); // for abcs; released at the end of each segment
            msra::math::ssematrixbase *refabcs = &refmatrixheap.newmatrix(edgestates, te - ts + 2);
            const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
            // do alignment
            alignedge((const_array_ref<aligninfo>) refunits, hset, edgeLLs, *refabcs, 0, true, refedgealignmentsj);
            for (t = ts; t < te; t++)
                uids[t] = (size_t) refedgealignmentsj[t - ts];
            ts = te;
        }
    }
    // Phase 4: alignment or forwardbackward on CPU for non parallel mode or verification
    if (!parallelstate.enabled() || cpuverification) // non parallel mode or verification
    {
        edgeacscores.resize(edges.size());
        std::vector<float> edgeacscoresgpu;
        edgealignments thisedgealignmentsgpu(thisedgealignments); // only read when cpuverification is set
        if (cpuverification)
        {
            parallelstate.getedgeacscores(edgeacscoresgpu);
            parallelstate.copyalignments(thisedgealignmentsgpu);
        }
        foreach_index (j, edges)
        {
            const edgeinfowithscores &e = edges[j];
            const size_t ts = nodes[e.S].t;
            const size_t te = nodes[e.E].t;
            if (ts == te) // dummy !NULL edge at end
                edgeacscores[j] = 0.0f;
            else
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                const auto edgeLLs = msra::math::ssematrixstriperef<msra::math::ssematrixbase>(const_cast<msra::math::ssematrixbase &>(logLLs), ts, te - ts);
                if (minlogpp > LOGZERO && origlogpps[j] < minlogpp)
                    edgeacscores[j] = LOGZERO; // will kill word level forwardbackward hypothesis
                else if (softalignstates)
                    edgeacscores[j] = forwardbackwardedge(aligntokens, hset, edgeLLs, *abcs[j], j);
                else
                    edgeacscores[j] = alignedge(aligntokens, hset, edgeLLs, *abcs[j], j, returnsenoneids, thisedgealignments[j]);
            }
            if (cpuverification)
            {
                const auto &aligntokens = getaligninfo(j); // get alignment tokens
                const bool edgehassil = alignmentcontainsunit(aligntokens, silunitid);
                if (fabs(edgeacscores[j] - edgeacscoresgpu[j]) > 1e-3)
                {
                    fprintf(stderr, "edge %d, sil ? %d, edgeacscores / edgeacscoresgpu MISMATCH %f v.s. %f, diff %e\n",
                            j, edgehassil ? 1 : 0, (float) edgeacscores[j], (float) edgeacscoresgpu[j],
                            (float) (edgeacscores[j] - edgeacscoresgpu[j]));
                    fprintf(stderr, "aligntokens: ");
                    foreach_index (i, aligntokens)
                        fprintf(stderr, "%d %d; ", i, aligntokens[i].unit);
                    fprintf(stderr, "\n");
                }
                for (size_t t = ts; t < te; t++)
                {
                    if (thisedgealignments[j][t - ts] != thisedgealignmentsgpu[j][t - ts])
                        fprintf(stderr, "edge %d, sil ? %d, time %d, alignment / alignmentgpu MISMATCH %d v.s. %d\n", j, edgehassil ? 1 : 0, (int) (t - ts), thisedgealignments[j][t - ts], thisedgealignmentsgpu[j][t - ts]);
                }
            }
        }
    }
}
// compute the error signal for sMBR mode
// CPU ("linear") mode: errorsignal is cleared, then for every non-dummy, non-pruned edge j, each frame t
// of its hard state alignment s = thisedgealignments[j][t - ts] receives
//     exp(logpps[j]) * (Eframescorrect[j] - Eframescorrecttotal) / amf
// i.e. edge posterior times (edge-conditioned minus average expected frames correct), scaled by 1/amf.
// errorsignalneg is only used by the parallel implementation.
void lattice::sMBRerrorsignal(parallelstate &parallelstate,
                              msra::math::ssematrixbase &errorsignal, msra::math::ssematrixbase &errorsignalneg, // output
                              const std::vector<double> &logpps, const float amf,
                              double minlogpp, const std::vector<double> &origlogpps, const std::vector<double> &logEframescorrect,
                              const double logEframescorrecttotal, const edgealignments &thisedgealignments) const
{
    if (parallelstate.enabled()) // parallel version
    {
        /*  time measurement for parallel sMBRerrorsignal
            errorsignalcompute: 19.871935 ms (cuda) v.s. 448.711444 ms (emu) */
        if (minlogpp > LOGZERO)
            fprintf(stderr, "sMBRerrorsignal: pruning not supported (we won't need it!) :)\n");
        parallelsMBRerrorsignal(parallelstate, thisedgealignments, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
        return;
    }
    // linear mode
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = 0.0f; // Note: we don't actually put anything into the numgammas
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file
            continue;
        if (minlogpp > LOGZERO && origlogpps[j] < minlogpp) // this is pruned
            continue;
        const size_t ts = nodes[e.S].t;
        const size_t te = nodes[e.E].t;
        // logEframescorrect[j] and logEframescorrecttotal are expected frame-correct counts in the log
        // domain (forwardbackward() reports exp(logEframescorrecttotal) / numframes as the average count).
        // The sMBR signal needs the linear difference c(e) - c_avg, not the difference of the logs.
        const double diff = exp(logEframescorrect[j]) - exp(logEframescorrecttotal);
        // Note: the contribution of the states of an edge to their senones is the same for all states
        // so we compute it once and add it to all; this will not be the case without hard alignments.
        const double pp = exp(logpps[j]); // edge posterior
        const float edgecorrect = (float) (pp * diff) / amf;
        for (size_t t = ts; t < te; t++)
        {
            const size_t s = thisedgealignments[j][t - ts];
            errorsignal(s, t) += edgecorrect;
        }
    }
}
// compute the error signal for MMI mode
// CPU mode: for every non-dummy, non-pruned edge j, the per-edge log state gammas in *abcs[j] are offset
// by the edge log posterior logpps[j] and log-added into errorsignal(senone, frame). Untouched entries
// keep VIRGINLOGZERO. A per-frame normalization check and a non-zero statistic are printed, and the
// result is finally converted to linear (non-log) posteriors.
// Parallel mode delegates to parallelmmierrorsignal(), which is given thisedgealignments rather than
// abcs; it does not support softalignstates.
void lattice::mmierrorsignal(parallelstate &parallelstate, double minlogpp, const std::vector<double> &origlogpps,
                             std::vector<msra::math::ssematrixbase *> &abcs, const bool softalignstates,
                             const std::vector<double> &logpps, const msra::asr::simplesenonehmm &hset,
                             const edgealignments &thisedgealignments, msra::math::ssematrixbase &errorsignal) const
{
    if (parallelstate.enabled())
    {
        if (minlogpp > LOGZERO)
            fprintf(stderr, "mmierrorsignal: pruning not supported (we won't need it!) :)\n");
        if (softalignstates)
            LogicError("mmierrorsignal: parallel version for softalignstates mode is not supported yet");
        parallelmmierrorsignal(parallelstate, thisedgealignments, logpps, errorsignal);
        return;
    }
    for (size_t j = 0; j < (errorsignal).cols(); j++)
        for (size_t i = 0; i < (errorsignal).rows(); i++)
            errorsignal(i, j) = VIRGINLOGZERO; // set to zero --note: may be in-place with logLLs, which now get overwritten
    // size_t warnings = 0; // [v-hansu] check code for mmi; search this comment to see all related codes
    foreach_index (j, edges)
    {
        const auto &e = edges[j];
        if (nodes[e.S].t == nodes[e.E].t) // this happens for dummy !NULL edge at end of file
            continue;
        if (minlogpp > LOGZERO && origlogpps[j] < minlogpp) // this is pruned
            continue;
        const auto &aligntokens = getaligninfo(j); // get alignment tokens
        auto &loggammas = *abcs[j];
        const float edgelogP = (float) logpps[j];
        // if (islogzero (edgelogP)) // we had a 0 prob
        //     continue;
        // accumulate this edge's gamma matrix into target posteriors
        const size_t tedge = nodes[e.S].t;
        size_t ts = 0; // time index into gamma matrix
        size_t js = 0; // state index into gamma matrix
        foreach_index (k, aligntokens) // we exploit that units have fixed boundaries
        {
            const auto &unit = aligntokens[k];
            const size_t te = ts + unit.frames;
            const auto &hmm = hset.gethmm(unit.unit); // TODO: inline these expressions
            const size_t n = hmm.getnumstates();
            const size_t je = js + n;
            // P(s) = P(s|e) * P(e)
            for (size_t t = ts; t < te; t++)
            {
                const size_t tutt = t + tedge; // time index w.r.t. utterance
                // double logsum = LOGZERO; // [v-hansu] check code for mmi; search this comment to see all related codes
                for (size_t i = 0; i < n; i++)
                {
                    const size_t j2 = js + i;              // state index for this unit in matrix
                    const size_t s = hmm.getsenoneid(i);   // state class index
                    const float gammajt = loggammas(j2, t);
                    const float statelogP = edgelogP + gammajt;
                    logadd(errorsignal(s, tutt), statelogP);
                }
            }
            ts = te;
            js = je;
        }
        assert(ts + 2 == loggammas.cols() && js == loggammas.rows());
    }
    // check that each frame's posteriors are normalized (log-sum close to 0),
    // and count the (state, frame) entries that are still zero
    size_t zerostates = 0;
    foreach_column (t, errorsignal)
    {
        double logsum = LOGZERO;
        foreach_row (s, errorsignal)
        {
            if (islogzero(errorsignal(s, t)))
                zerostates++;
            else
                logadd(logsum, (double) errorsignal(s, t));
            // TODO: count VIRGINLOGZERO, print per frame
        }
        if (fabs(logsum) / errorsignal.rows() > 1e-6)
            fprintf(stderr, "forwardbackward: WARNING: overall posterior column(%d) sum = exp (%.10f) != 1\n", (int) t, logsum);
    }
    // non-zero percentage = 100% minus the zero percentage
    fprintf(stderr, "forwardbackward: %.3f%% non-zero state posteriors\n", 100.0f - zerostates * 100.0f / errorsignal.rows() / errorsignal.cols());
    // convert to non-log posterior --that's what we return
    foreach_coord (i, j, errorsignal)
        errorsignal(i, j) = expf(errorsignal(i, j));
}
// compute ground truth's score
// It is critical to get all details consistent with the lattice, to avoid skewing the weights.
// ... TODO: we don't need this to be a class member, actually; try to just make it a 'static' function.
//
// The reference path is given per frame by the senone ids in 'uids' and split into words by
// transcript[i].firstframe. The score sums, per frame, the HMM transition and the acoustic score
// logLLs(senone, t); per word, the exit transition, an /sp/ tee transition if the word does not end
// in /sp/, and transcriptunigrams[i] * lmf + wp. The total is divided by amf.
// Only 1-state /sp/ and 3-state /sil/ models are supported (RuntimeError otherwise); the empty
// transcript and missing unigram scores are rejected with LogicError.
/*static*/ double lattice::scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigrams,
                                            const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset, const float lmf, const float wp, const float amf)
{
    if (transcript.size() == 0) // transcript[0] is accessed below
        LogicError("scoregroundtruth: empty transcript");
    if (transcriptunigrams.size() < transcript.size()) // one LM score is needed per transcript token
        LogicError("scoregroundtruth: %d transcript tokens but only %d unigram scores", (int) transcript.size(), (int) transcriptunigrams.size());
    if (transcript[0].firstframe != 0) // TODO: should we store the #frames instead? Then we can validate the total duration
        LogicError("scoregroundtruth: first transcript[] token does not start at frame 0");
    // get the silence models, since they are treated specially
    const size_t numframes = logLLs.cols();
    const auto &sil = hset.gethmm(hset.gethmmid("sil"));
    const auto &sp = hset.gethmm(hset.gethmmid("sp"));
    if (sp.numstates != 1 || sil.numstates != 3)
        RuntimeError("scoregroundtruth: only supports 1-state /sp/ and 3-state /sil/ tied to /sp/");
    const size_t silst = sp.senoneids[0]; // the single /sp/ senone (possibly shared with the /sil/ center state)
    // loop over words
    double pathscore = 0.0;
    foreach_index (i, transcript)
    {
        const size_t ts = transcript[i].firstframe;
        const size_t te = ((size_t) i + 1 < transcript.size()) ? transcript[i + 1].firstframe : numframes;
        if (ts >= te)
            LogicError("scoregroundtruth: transcript[] tokens out of order");
        // acoustic score: loop over frames
        const msra::asr::simplesenonehmm::transP *prevtransP = NULL; // previous transP
        int prevs = -1;                                               // previous state index
        for (size_t t = ts; t < te; t++)
        {
            const size_t senoneid = uids[t];
            // recover the transP and state index
            int s;
            const msra::asr::simplesenonehmm::transP *transP;
            if (senoneid == silst)
            {
                if (prevtransP == &sil.gettransP()) // "silst" may be tied to /sp/, and thus the /sil/ center state may be ambiguous
                {
                    transP = prevtransP; // remain in /sil/
                    s = 1;               // note that this will fail if /sp/ can follow /sil/ and if /sil/ may end with "silst" (both not allowed currently)
                }
                else // "silst" -> we are in /sp/
                {
                    transP = &sp.gettransP();
                    s = 0;
                }
            }
            else // all others must be non-ambiguous
            {
                const int transPindex = hset.senonetransP(senoneid);
                const int sindex = hset.senonestate(senoneid);
                if (transPindex == -1 || sindex == -1)
                    RuntimeError("scoregroundtruth: failed to resolve ambiguous senone %s", hset.getsenonename(senoneid));
                transP = &hset.transPs[transPindex];
                s = sindex;
            }
            // if changing phoneme then add necessary enter/exit score
            if (transP != prevtransP)
            {
                if (prevtransP) // previous exit transition
                    pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
                prevs = -1; // enter transition
            }
            // add inter-state transP
            pathscore += (*transP)(prevs, s);
            // acoustic score
            pathscore += logLLs(senoneid, t);
            // remember state for next frame
            prevtransP = transP;
            prevs = s;
        }
        // add the last exit transition
        if (prevtransP) // previous exit transition
            pathscore += (*prevtransP)(prevs, prevtransP->getnumstates());
        // need to add a /sp/ tee transition if we don't end in sp, since our transcript dictionary includes a /sp/ at the end of each entry
        if (uids[te - 1] != silst)
            pathscore += sp.gettransP()(-1, sp.numstates);
        // lm
        pathscore += transcriptunigrams[i] * lmf + wp;
    }
    if (islogzero(pathscore))
        fprintf(stderr, "scoregroundtruth: ground-truth path has zero probability; some model inconsistency, maybe?\n");
    // account for amf
    pathscore /= amf;
    fprintf(stderr, "scoregroundtruth: ground-truth score %.6f (%d frames)\n", pathscore, (int) numframes);
    return pathscore;
}
// ---------------------------------------------------------------------------
// sMBRdiagnostics() -- helper to print some diagnostics for analyzing sMBR results
// ---------------------------------------------------------------------------
// static // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
//
// For every frame t this prints to stderr:
//  - the error signal of the reference senone uids[t],
//  - the most negative and the most positive competitor (reference excluded; "-" if none),
//  - the error signal of the best-path senone bestpath[t],
// followed by flags:
//  - ERR: best path differs from the reference;
//  - INV!!: the reference signal is negative;
//  - WEAK: a negative competitor outweighs the reference;
//  - 2ND: a positive competitor exceeds the reference;
//  - NOREF: refseen[t] is false.
// INV!!/WEAK/2ND are only raised when |reference signal| > 0.0001.
// A one-line summary with the number of correct frames is printed at the end.
void sMBRdiagnostics(const msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids,
                     const_array_ref<size_t> bestpath, const vector<bool> &refseen, const msra::asr::simplesenonehmm &hset)
{
    // TODO:
    //  - print best positive runner-up
    //  - WARN tag if neg > pos
    //  - check the sum and warn if not 0
    //  - indicate whether best path state is correct or not
    size_t correctframes = 0;  // frames where best path == reference
    size_t negdominated = 0;   // frames where the negative competitor is stronger
    size_t posdominated = 0;   // frames where the positive competitor is stronger
    for (size_t frame = 0; frame < errorsignal.cols(); frame++)
    {
        const size_t refstate = uids[frame];
        const size_t topstate = bestpath[frame];
        const float refsignal = errorsignal(refstate, frame);
        if (topstate == refstate)
            correctframes++;

        // strongest negative and strongest positive competitor, excluding the reference
        size_t negstate = SIZE_MAX;
        float negsignal = 0.0f;
        size_t posstate = SIZE_MAX;
        float possignal = 0.0f;
        for (size_t state = 0; state < errorsignal.rows(); state++)
        {
            if (state == refstate)
                continue;
            const float signal = errorsignal(state, frame);
            if (signal < negsignal)
            {
                negstate = state;
                negsignal = signal;
            }
            if (signal > possignal)
            {
                posstate = state;
                possignal = signal;
            }
        }

        const bool significant = fabs(refsignal) > 0.0001f;
        const bool negbetter = significant && refsignal < -negsignal;
        const bool posbetter = significant && refsignal < possignal;
        if (negbetter)
            negdominated++;
        if (posbetter)
            posdominated++;

        const char *refname = hset.getsenonename(refstate);
        const char *topname = hset.getsenonename(topstate);
        const char *negname = (negstate == SIZE_MAX) ? "-" : hset.getsenonename(negstate);
        const char *posname = (posstate == SIZE_MAX) ? "-" : hset.getsenonename(posstate);
        fprintf(stderr, "e(%d): ref %s: %.6f / %s: %.6f / %s: %.6f / top %s: %.6f%s%s%s%s%s\n",
                (int) frame, refname, refsignal, negname, negsignal, posname, possignal, topname, errorsignal(topstate, frame),
                (topstate == refstate) ? "" : " ERR",
                (significant && refsignal < 0) ? " INV!!" : "",
                negbetter ? " WEAK" : "",
                posbetter ? " 2ND" : "",
                refseen[frame] ? "" : " NOREF");
    }
    // print this to validate our bestpath computation
    fprintf(stderr, "sMBRdiagnostics: %d frames correct out of %d, %.2f%%, neg better in %d, pos in %d\n",
            (int) correctframes, (int) errorsignal.cols(), 100.0f * correctframes / errorsignal.cols(),
            (int) negdominated, (int) posdominated);
}
// static // with 'static', compiler will complain if the function is not used (we only compile it in sometimes for diagnostics)
// Zeroes every frame (column) of the error signal that looks "weird":
//  - the reference senone uids[t] has a negative signal, or
//  - some other senone's |signal| exceeds the reference's signal.
// Prints how many frames were flattened.
void sMBRsuppressweirdstuff(msra::math::ssematrixbase &errorsignal, const_array_ref<size_t> uids)
{
    size_t flattened = 0;
    for (size_t frame = 0; frame < errorsignal.cols(); frame++)
    {
        const size_t refstate = uids[frame];
        const float refsignal = errorsignal(refstate, frame);
        bool weird = refsignal < 0.0f; // negative for reference!?
        for (size_t state = 0; !weird && state < errorsignal.rows(); state++)
            weird = (state != refstate) && fabs(errorsignal(state, frame)) > refsignal;
        if (!weird)
            continue;
        for (size_t state = 0; state < errorsignal.rows(); state++)
            errorsignal(state, frame) = 0.0f;
        flattened++;
    }
    // summary of how many frames were flattened
    fprintf(stderr, "sMBRsuppressweirdstuff: %d weird frames out of %d, %.2f%% were flattened\n",
            (int) flattened, (int) errorsignal.cols(), 100.0f * flattened / errorsignal.cols());
}
// ---------------------------------------------------------------------------
// forwardbackward() -- main function for MMI/sMBR
//
// This computes the lattice state-level statistics for sequence training using MMI or sMBR.
//
// Outputs, MMI mode:
//  - result = dengammas = denominator gammas (non-log form), computed by mmierrorsignal()
//  - returns the lattice-level forward score (log of the sum over all paths' likelihoods, the
//    denominator of the MMI objective) divided by the number of frames
//  - errorsignalbuf is not used in MMI mode
//
// Outputs, sMBR mode:
//  - result = errorsignal as computed by sMBRerrorsignal(); errorsignalbuf is passed to it as
//    'errorsignalneg' (only used by the parallel implementation)
//  - returns the expected frames-correct count (the sMBR objective) divided by the number of frames
//
// In both modes, LOGZERO is returned if no path through the lattice is found; 'result' must not be
// used in that case.
// ---------------------------------------------------------------------------
double lattice::forwardbackward(parallelstate &parallelstate, const msra::math::ssematrixbase &logLLs, const msra::asr::simplesenonehmm &hset,
                                msra::math::ssematrixbase &result, msra::math::ssematrixbase &errorsignalbuf,
                                const float lmf, const float wp, const float amf, const float boostingfactor,
                                const bool sMBRmode, array_ref<size_t> uids, const_array_ref<size_t> bounds,
                                const_array_ref<htkmlfwordsequence::word> transcript, const std::vector<float> &transcriptunigrams) const
{
    bool softalign = true;
    bool softalignstates = false;                // true if soft alignment within edges, currently we only support soft within edge in cpu mode
    bool softalignlattice = softalign;           // w.r.t. whole lattice
    edgealignments thisedgealignments(*this);     // alignments memory allocate for this lattice
    backpointers thisbackpointers(*this, hset);   // memory for forwardbackward
    if (info.numframes != logLLs.cols())
        LogicError("forwardbackward: #frames mismatch between lattice (%d) and LLs (%d)", (int) info.numframes, (int) logLLs.cols());
    // TODO: the following checks should throw, but I don't dare in case this will crash a critical job... if we never see this warning, then
    if (info.numframes != uids.size())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and uids (%d)\n", (int) info.numframes, (int) uids.size());
    if (info.numframes != result.cols())
        fprintf(stderr, "forwardbackward: #frames mismatch between lattice (%d) and result (%d)\n", (int) info.numframes, (int) result.cols());
    littlematrixheap matrixheap(info.numedges); // for abcs
    // PHASE 0: fake word level forward backwards --only used when pruning enabled
    const double minlogpp = LOGZERO; // pruning threshold --LOGZERO means disabled
    std::vector<double> origlogpps;  // word posterior from original lattice, for pruning decision
    // PHASE 1: per-edge forward backwards (="time alignments")
    // score the ground truth --only if a transcript is provided, which happens if the user provides a language model
    // TODO: no longer used, remove this. 'transcript' parameter is no longer used in this function.
    (void) transcript;         // silence unused-parameter warnings
    (void) transcriptunigrams;
    // allocate alpha/beta/gamma matrices (all are sharing the same memory in-place)
    std::vector<msra::math::ssematrixbase *> abcs;
    std::vector<float> edgeacscores; // [edge index] acoustic scores
    // function call for forwardbackward on edge level
    forwardbackwardalign(parallelstate, hset, softalignstates, minlogpp, origlogpps, abcs, matrixheap, sMBRmode /*returnsenoneids*/, edgeacscores, logLLs, thisedgealignments, thisbackpointers, uids, bounds);
    // PHASE 2: lattice-level forward backward
    // we exploit that the lattice is sorted by (end node, start node) for in-place processing
    // checklattice(); // comment out by v-hansu to save time
#ifdef PRINT_TIME_MEASUREMENT
    auto_timer latlevelfwbw;
#endif
    std::vector<double> logpps;
    std::vector<double> Eframescorrectbuf; // this is used for compute the Eframescorrectdiff
    std::vector<double> logEframescorrect; // this is the final output of PHASE 2
    std::vector<double> logalphas;
    std::vector<double> logbetas;
    double totalfwscore = 0; // TODO: name no longer precise in sMBRmode
    double logEframescorrecttotal = LOGZERO;
    bool returnEframescorrect = sMBRmode;
    if (softalignlattice)
    {
        totalfwscore = forwardbackwardlattice(edgeacscores, parallelstate, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor, returnEframescorrect, (const_array_ref<size_t> &) uids, thisedgealignments, logEframescorrect, Eframescorrectbuf, logEframescorrecttotal);
        // note: since returnEframescorrect == sMBRmode, this branch is currently never taken
        if (sMBRmode && !returnEframescorrect)
            logEframescorrecttotal = forwardbackwardlatticesMBR(edgeacscores, hset, logalphas, logbetas, lmf, wp, amf, (const_array_ref<size_t> &) uids, thisedgealignments, Eframescorrectbuf);
        // ^^ BUGBUG not tested
    }
    else
        totalfwscore = bestpathlattice(edgeacscores, logpps, lmf, wp, amf);
#ifdef PRINT_TIME_MEASUREMENT
    latlevelfwbw.show("latlevelfwbw"); // 68.395682 ms
#endif
    if (islogzero(totalfwscore))
    {
        fprintf(stderr, "forwardbackward: WARNING: no path found in lattice (%d nodes/%d edges)\n", (int) nodes.size(), (int) edges.size());
        return LOGZERO; // failed, do not use resulting matrix
    }
    // PHASE 3: compute final state-level posteriors (MMI mode)
    // compute expected frames correct in sMBRmode
    const size_t numframes = logLLs.cols();
    assert(numframes == info.numframes);
    // fprintf (stderr, "forwardbackward: total forward score %.6f (%d frames)\n", totalfwscore, (int) numframes); // for now--while we are debugging the GPU port
    // MMI mode
    if (!sMBRmode)
    {
        // we first take the sum in log domain to avoid numerical issues
        auto &dengammas = result; // result is denominator gammas
        mmierrorsignal(parallelstate, minlogpp, origlogpps, abcs, softalignstates, logpps, hset, thisedgealignments, dengammas);
        return totalfwscore / numframes; // return value is av. posterior
    }
    // sMBR mode
    else
    {
        auto &errorsignal = result;
        sMBRerrorsignal(parallelstate, errorsignal, errorsignalbuf, logpps, amf, minlogpp, origlogpps, logEframescorrect, logEframescorrecttotal, thisedgealignments);
        static bool dummyvariable = (fprintf(stderr, "note: new version with kappa adjustment, kappa = %.2f\n", 1 / amf), true); // we only print once
        (void) dummyvariable;
        return exp(logEframescorrecttotal) / numframes; // return value is av. expected frame-correct count
    }
}
};
};