// (archive metadata) source: https://github.com/Microsoft/CNTK
// tip revision: 5c3f708097bdcdaf2c06d2aa8a9b3fdc772ae27a, authored by Mark Hillebrand
// on 18 January 2016, 08:36:30 UTC ("License change")
// file: latticefunctionskernels.h
// latticefunctionskernels.cu(.h) -- kernels for lattice ops intended for use with CUDA, to be called from actual CUDA kernels.
//
// To make this compile for the CPU (emulation for testing), add this line before #including this:
// #define __device__
//
// F. Seide, V-hansu
#define FORBID_INVALID_SIL_PATHS // [v-hansu] prune path that start from sil(sp) and go into sil, only used with addsil is adopted
#pragma once
#pragma push_macro ("__device__")
#pragma push_macro ("atomicAdd")
#pragma push_macro ("atomicCAS")
#include "latticestorage.h"
#include <limits>
namespace msra { namespace cuda { class passtextureref; }}
#ifdef CPUONLY
#define __kernel_emulation__
#endif
#ifdef __kernel_emulation__
#include "math.h" // to get exp() and log() compiled correctly with c++
#include<stdexcept>
using namespace std;
#ifndef __device__
#define __device__
#endif
#define CUDART_MIN_DENORM_F numeric_limits<float>::denorm_min()
#define atomicAdd(address,value) (*(address)+=(value)) // don't forget to #undef (#pragma pop_macro)! Otherwise CUDA might compile with this...
#define atomicCAS(address, compare, val) *address; *address = *address == compare ? val : *address;
#define __double_as_longlong(in) (*(unsigned long long int *) &in)
#define __longlong_as_double(in) (*(double *) &in)
#define __float_as_int(in) (*(int *) &in)
#define __int_as_float(in) (*(float *) &in)
#else // TODO: remove this once we got this figured out
#include "math_constants.h"
#if __CUDA_ARCH__ < 200
//#warning Sequence training not supported on 1.x CUDA machines.
#define force_crash() (*((int*)-1)=0) // TODO: this does not in fact seem to crash it...
#define atomicAdd(a,v) (force_crash(),*(a)=v) // force a crash if used with 1.x devices
#define atomicCAS(address, compare, val) (*(address) = compare + val, *((int*)-1)=0)
#define __double_as_longlong(in) (force_crash(), in)
#define __longlong_as_double(in) (force_crash(), in)
#define __float_as_int(in) (force_crash(), in)
#define __int_as_float(in) (force_crash(), in)
#endif
#endif
namespace msra { namespace lattices {
struct somedata // example type to have a pattern to copy from
{
size_t fortytwo; // placeholder payload (example member only; not used by the kernels below)
};
struct empty {};
// Note that the code that uses this (edgealignmentj) will assume 3-state left-to-right
// except for /sil/ and /sp/ which are treated specially.
// TODO: either check when creating this whether this assumption is true, or control this through a flag in here.
// Transition matrix (log probabilities) for a 3-state left-to-right HMM.
// loga[i+1][j] holds the log transition prob from state i to state j, where
// row 0 represents the entry pseudo-state (i == -1); see getlogtransp() below.
struct lr3transP // lr3 = 3-state left-to-right architecture
{
static const size_t MAXSTATES = 3;
size_t numstates;
float loga[MAXSTATES+1][MAXSTATES+1];
lr3transP ()
{
#ifdef INITIAL_STRANGE
numstates = 3;
// bug fix: this used 'NUMSTATES', which is not declared anywhere (the member
// is MAXSTATES, and the loga[][] dimensions are MAXSTATES+1), so this did not
// compile with INITIAL_STRANGE defined
for (size_t i = 0; i < MAXSTATES+1; i++)
for (size_t j = 0; j < MAXSTATES+1; j++)
{
loga[i][j] = LOGZERO; // NOTE(review): LOGZERO must be provided by an earlier include when INITIAL_STRANGE is on; it is #defined further down in this file
}
#endif
}
};
// Definition of one left-to-right HMM (no /sil/): per-state senone ids plus an
// index into the shared table of transition matrices (transPs).
struct lrhmmdef // left-to-right HMM (no /sil/)
{
static const size_t MAXSTATES = 3; // we use a fixed memory allocation since it's almost always 3 anyway
unsigned char transPindex; // index of monophone to find transP matrix
unsigned char numstates; // number of states; code supports only either 1 or 3
unsigned short senoneids[MAXSTATES]; // [0..numstates-1] senone indices
size_t getsenoneid (size_t i) const { return (size_t) senoneids[i]; }
size_t getnumstates() const { return (size_t) numstates; }
const struct lr3transP & gettransP(const lr3transP * transPs) const { return transPs[transPindex]; }
lrhmmdef ()
{
#ifdef INITIAL_STRANGE
// bug fix: 'unsigned char (-1)' is not a valid functional-style cast in standard
// C++ (the type name has more than one token; it only compiled as an MSVC
// extension) -- use C-style casts instead
transPindex = (unsigned char) -1;
numstates = (unsigned char) -1;
for (size_t i = 0; i < MAXSTATES; i++)
{
senoneids[i] = (unsigned short) -1;
}
#endif
}
};
#if 1 // straight-forward version
#else // CUDA hacked version
hmm gethmm (hmms, i)
{
ushort4 u4 = *((ushort4) &hmms[i]);
lrhmmdef hmm;
hmm.transPindex = u4.x & 0xff;
hmm.numstates = u4.x >> 8;
hmm.senoneids[0] = u4.y;
hmm.senoneids[1] = u4.z;
hmm.senoneids[2] = u4.w;
}
#endif
#ifndef LOGZERO
#define LOGZERO -1e30f
#endif
// Lightweight view over a column-major matrix of unsigned shorts living in CUDA
// memory (used as the Viterbi back-pointer store). Performs bounds checking on
// every access: under CPU emulation it throws, on device it deliberately crashes.
class bpmatrixref
{
private:
unsigned short * data; // pointer in CUDA space of this device
size_t nrows; // rows()
size_t ncols; // cols()
size_t colstride; // column stride in elements (the constructor sets this to the row count)
// abort on out-of-range (i,j): throw under emulation, force a crash on device
__device__ void checkbounds (size_t i, size_t j) const
{
if (i >= nrows || j >= ncols)
#ifdef __kernel_emulation__
throw ::logic_error ("out of boundary!!!");
#else
*((int*)-1)=0;
#endif
}
// flat element offset of (i,j) in column-major storage
__device__ size_t locate (size_t i, size_t j) const
{
checkbounds (i,j);
return j * colstride + i;
}
public:
// wrap an existing buffer holding n rows by m columns (stride == n)
__device__ bpmatrixref (unsigned short * address, size_t n, size_t m)
: data (address), nrows (n), ncols (m), colstride (n)
{
}
__device__ unsigned short & operator() (size_t i, size_t j) { return data[locate(i,j)]; }
__device__ const unsigned short & operator() (size_t i, size_t j) const { return data[locate(i,j)]; }
};
// this class contains all-static methods that are inner pieces of thread kernels for use with CUDA
struct latticefunctionskernels
{
// [v-hansu] mimic the float version of atomicCAS
// atomic compare-and-swap on a float: reinterpret the bits as int and use the
// int overload of atomicCAS; returns the previous value at 'address'
static __device__ float atomicCASfloatdouble (float *address, float compare, float val)
{
int * intaddress = (int *) address;
int intcompare = __float_as_int(compare); // __float_as_int : reinterpret float bits as int (comment fixed; previously referred to __double_as_longlong)
int intval = __float_as_int(val);
int result = atomicCAS (intaddress, intcompare, intval);
return __int_as_float(result);
}
// atomic compare-and-swap on a double, via the unsigned long long overload of atomicCAS
static __device__ double atomicCASfloatdouble (double *address, double compare, double val)
{
unsigned long long int * longlongintaddress = (unsigned long long int *) address;
unsigned long long int longlongintcompare = __double_as_longlong (compare); // __double_as_longlong : read double as unsigned long long
unsigned long long int longlongintval = __double_as_longlong (val);
unsigned long long int result = atomicCAS(longlongintaddress, longlongintcompare, longlongintval);
return __longlong_as_double(result);
}
// loga += log (1 + exp (diff)), skipping the add once 'diff' is so small that it
// cannot change the result at this precision (double version)
static __device__ void logaddratio (double & loga, double diff)
{
if (diff < -37.0f) return; // log (2^-53), 52-bit mantissa -> cut off after 53th bit
loga += log (1.0 + exp (diff));
}
// float version of the above
static __device__ void logaddratio (float & loga, float diff)
{
if (diff < -17.0f) return; // log (2^-24), 23-bit mantissa -> cut off after 24th bit
loga += logf (1.0f + expf (diff));
}
template<typename FLOAT> static __device__ void swap (FLOAT & left, FLOAT & right)
{ // exchange values stored at _Left and _Right
FLOAT tmp = left;
left = right;
right = tmp;
}
// overloads for exp() for float and double, so that we can use templates
static __device__ float expfd (float x) { return ::expf (x); } // TODO: ain't there an overload for this?
static __device__ double expfd (double x) { return ::exp (x); }
// Compute the difference of two numbers, which are represented as their logs.
// The return value is a non-log value. exp(loga) - exp(logb)
template<typename FLOAT>
static __device__ FLOAT expdiff (FLOAT loga, FLOAT logb)
{
if (logb < loga) // logb - loga < 0 => exp(logb-loga) < 1
return expfd (loga) * (1 - expfd (logb - loga));
else
return -expfd (logb) * (1 - expfd (loga - logb));
}
// loga <- log (exp (loga) + exp (logb)) = log (exp (loga) * (1.0 + exp (logb - loga)) = loga + log (1.0 + exp (logb - loga))
template<typename FLOAT> static __device__ void logadd (FLOAT & loga, FLOAT logb)
{
if (logb > loga) // we add smaller to bigger
swap (loga, logb); // loga is bigger
if (loga <= LOGZERO) // both are 0
return;
logaddratio (loga, logb - loga);
}
// does the same as above but if the bigger one is too small, we assign a small value to it
template<typename FLOAT> static __device__ void logaddseen (FLOAT & loga, FLOAT logb)
{
if (logb > loga) // we add smaller to bigger
swap (loga, logb); // loga is bigger
if (loga <= LOGZERO) // both are 0
{
loga = logf(CUDART_MIN_DENORM_F); // [v-hansu] we hope to separate LOGZERO (unseen states) and logf(CUDART_MIN_DENORM_F) (seen states with small prob)
return;
}
logaddratio (loga, logb - loga);
}
#if 1
// bit-level reinterpretation helpers so atomicLogAdd below can be written once
// for both float (int bits) and double (unsigned long long bits)
static inline __device__ float bitsasfloat ( int b) { return __int_as_float (b); }
static inline __device__ double bitsasfloat (unsigned long long int b) { return __longlong_as_double (b); }
static inline __device__ int floatasbits (float f) { return __float_as_int (f); }
static inline __device__ unsigned long long int floatasbits (double f) { return __double_as_longlong (f); }
// atomically perform *address = logaddseen (*address, val) via a CAS retry loop;
// returns the previous value stored at 'address'
template<typename FLOAT> // adapted from [http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ixzz32EuzZjxV]
static __device__ FLOAT atomicLogAdd (FLOAT * address, FLOAT val) // direct adaptation from NVidia source code
{
typedef decltype (floatasbits (val)) bitstype;
bitstype * address_as_ull = (bitstype *) address;
bitstype old = *address_as_ull, assumed;
do {
assumed = old;
FLOAT sum = bitsasfloat (assumed);
logaddseen (sum, val);
old = atomicCAS (address_as_ull, assumed, floatasbits (sum));
} while (assumed != old);
// note: critically, ^^ this comparison must compare the bits ('int') instead of the converted float values, since this will fail for NaNs (NaN != NaN is true always)
return bitsasfloat(old);
}
#else // this code does not work because (assumed != old) will not compare correctly in case of NaNs
//same pattern as atomicAdd(), but performing the log-add operation instead
template<typename FLOAT> static __device__ FLOAT atomicLogAdd (FLOAT * address, FLOAT val)
{
FLOAT old = *address;
FLOAT assumed;
FLOAT logaddresult;
do {
assumed = old; // old is the assumed value at address
logaddresult = assumed; // for next step to compute logaddresult, assumed shall be the same as before
logaddseen (logaddresult, val);
old = atomicCASfloatdouble (address, assumed, logaddresult);
} while (assumed != old); // if old == assumed, the *address is not changed in this loop, so this is safe
return old;
}
#endif
// [v-hansu] shuffling accessing order for a item in cubic(Ni, Nj, Nk) with index i, j, k according to shufflemode
// maps (i,j,k) to a flat index with one of five axis orderings; deliberately crashes on any other mode
static inline __device__ size_t shuffle (size_t i, size_t Ni, size_t j, size_t Nj, size_t k, size_t Nk, size_t shufflemode)
{
if (shufflemode == 0)
return i + j * Ni + k * Nj * Ni;
else if (shufflemode == 1) // inverse
return k + j * Nk + i * Nj * Nk;
else if (shufflemode == 2) // flip i and j
return j + i * Nj + k * Ni * Nj;
else if (shufflemode == 3) // flip j and k
return i + k * Ni + j * Nk * Ni;
else if (shufflemode == 4)
return j + k * Nj + i * Nk * Nj;
else
*((int*)-1)=0; // shall not get here, WRONG (deliberate crash on invalid shufflemode)
return 0;
}
// helper: write 'value' into thisvector[j] (one thread per element)
template<typename doublevector>
static __device__ void setvaluej (size_t j, doublevector & thisvector, double value)
{
thisvector[j] = value;
}
//zhaorui
// transition log-prob from state 'from' to state 'to'; 'from' == -1 denotes the
// entry pseudo-state, hence the +1 row shift into loga[][]
static inline __device__ float getlogtransp (lr3transP transP, int from, int to)
{
/*if (from < -1 || from >= transP.MAXSTATES || to > transP.MAXSTATES)
{
//printf("from: %d to: %d\n", from, to);
return LOGZERO;
}*/
return transP.loga[from+1][to];
}
// Viterbi-align one lattice edge j: walk the edge's sequence of align units,
// run a 3-state Viterbi per unit (sil is treated as ergodic, with an explicit
// back-pointer matrix; sp as a single-state model), write per-frame senone ids
// into alignresult[] and the edge's total acoustic path score into edgeacscores[j].
template<typename lrhmmdefvector, typename lr3transPvector, typename matrix, typename nodeinfovector, typename edgeinfowithscoresvector, typename aligninfovector, typename ushortvector, typename uintvector, typename floatvector, typename sizetvector>
static inline __device__ void edgealignmentj (size_t j, const lrhmmdefvector & hmms, const lr3transPvector & transPs, const size_t spalignunitid,
const size_t silalignunitid, const matrix & logLLs, const nodeinfovector & nodes,
const edgeinfowithscoresvector & edges, const aligninfovector & aligns,
const uintvector & alignoffsets, ushortvector & backptrstorage, const sizetvector & backptroffsets,
ushortvector & alignresult, floatvector & edgeacscores)
{ // TODO: alignresult will change to (start,end)
// mostly finished
// some preparation
size_t as = edges[j].firstalign; // align start
size_t ae = (j+1) < edges.size() ? (size_t) edges[j+1].firstalign : aligns.size();
if (as == ae) // the last empty alignment
return;
size_t ts = nodes[edges[j].S].t;
float fwscore = 0.0f; // score passed across phone boundaries
size_t alignindex = alignoffsets[j]; // index to set (result)
#ifndef PARALLEL_SIL
const bool isSil = (aligns[as].unit == silalignunitid || aligns[ae-1].unit == silalignunitid);
if (isSil) return; // we do not support silence edge now, which is computed by cpu, may change when we support it
#endif
// Viterbi alignment
for (size_t k = as; k < ae; k++)
{
const aligninfo align = aligns[k];
const size_t numframes = align.frames;
const bool isSp = (align.unit == spalignunitid);
const bool isSil = (align.unit == silalignunitid);
const lrhmmdef hmm = hmms[align.unit];
const lr3transP transP = transPs[hmm.transPindex];
// pre-fetch senone ids into registers
size_t senoneid0 = hmm.senoneids[0];
size_t senoneid1 = 0;
size_t senoneid2 = 0;
if (!isSp) // fetch only if needed--may save some memory cycles
{
senoneid1 = hmm.senoneids[1];
senoneid2 = hmm.senoneids[2];
}
const size_t te = ts + numframes; // end time of current unit
size_t state1step0to1 = te; // inflection point from state 0 to 1, record in state 1
size_t state2step0to1 = te; // inflection point from state 0 to 1, record in state 2
size_t state2step1to2 = te; // inflection point from state 1 to 2, record in state 2
//now we only support transition from -1 to 0 or 2 for sil
float pathscore0 = fwscore ; // log pp in state 0
float pathscore1 = LOGZERO; // log pp in state 1
float pathscore2 = LOGZERO; // log pp in state 2
if(isSil)
pathscore2 = fwscore;
// first frame
if (ts != te) // for t = ts, initialization
{
if (isSil) //for sil, -1 to 2 and -1 to 0 is permitted
{
pathscore0 += getlogtransp(transP,-1,0) + logLLs(senoneid0,ts);
pathscore2 += getlogtransp(transP,-1,2) + logLLs(senoneid2,ts);
}
else //for others, only -1 to 0 is permitted
pathscore0 += logLLs(senoneid0,ts); // Note: no need to incorporate LLs for state [1] and [2] because the path log LLs are LOGZERO anyway
}
float pathscore2last = pathscore2; // allocate last for state 2 because the order of computation below is 2->1->0, last state 2 is needed because from 2 to 0 or 1 is permitted for sil.
float pathscore1last = pathscore1; // allocate last for state 1 because the order of computation below is 2->1->0, last state 1 is needed because from 1 to 0 is permitted for sil.
size_t backptroffset = backptroffsets[j]; // we make use of backptrstorage in backptroffsets[j] for viterbi of ergodic model (silence)
bpmatrixref backptrmatrix (&backptrstorage[backptroffset], hmm.MAXSTATES, numframes);
//subsequent frames
for (size_t t = ts + 1; t < te; t++)
{
if (!isSp)
{
// state [2]
pathscore2 += getlogtransp(transP,2,2); // log pp from state 2 to 2
if (isSil)
backptrmatrix (2, t-ts-1) = 2;
const float pathscore12 = pathscore1 + getlogtransp(transP,1,2); // log pp from state 1 to 2
if (pathscore12 >= pathscore2) // if state 1->2
{
pathscore2 = pathscore12;
state2step0to1 = state1step0to1; // record the inflection point
state2step1to2 = t; // record the inflection point
if (isSil)
backptrmatrix (2, t-ts-1) = 1;
}
if (isSil) // only silence have path from 0 to 2
{
const float pathscore02 = pathscore0 + getlogtransp(transP,0,2); // log pp from state 0 to 2
if (pathscore02 >= pathscore2) // if state 0->2
{
pathscore2 = pathscore02;
backptrmatrix (2, t-ts-1) = 0;
}
}
// state [1]
pathscore1 += getlogtransp(transP,1,1); // log pp from state 1 to 1
if (isSil)
backptrmatrix (1, t-ts-1) = 1;
const float pathscore01 = pathscore0 + getlogtransp(transP,0,1); // log pp from state 0 to 1
if (pathscore01 >= pathscore1) // if state 0 -> 1
{
pathscore1 = pathscore01;
state1step0to1 = t; // record the inflection point
if (isSil)
backptrmatrix (1, t-ts-1) = 0;
}
if (isSil) // only silence have path from 2 to 1
{
const float pathscore21 = pathscore2last + getlogtransp(transP,2,1);
if (pathscore21 >= pathscore1) // if state 2 -> 1
{
pathscore1 = pathscore21;
backptrmatrix (1, t-ts-1) = 2;
}
}
}
// state [0]
pathscore0 += getlogtransp(transP,0,0);
if(isSil) // only silence have path from 2 or 1 to 0
{
backptrmatrix (0, t-ts-1) = 0;
const float pathscore20 = pathscore2last + getlogtransp(transP,2,0); // log pp from state 2 to 0
if (pathscore20 >= pathscore0)
{
pathscore0 = pathscore20;
backptrmatrix (0, t-ts-1) = 2;
}
const float pathscore10 = pathscore1last + getlogtransp(transP,1,0); // log pp from state 1 to 0
if (pathscore10 >= pathscore0)
{
pathscore0 = pathscore10;
backptrmatrix (0, t-ts-1) = 1;
}
}
// add log LLs
pathscore0 += logLLs(senoneid0,t);
if (!isSp) // only fetch if needed, saves mem access
{
pathscore1 += logLLs(senoneid1,t);
pathscore2 += logLLs(senoneid2,t);
}
pathscore1last = pathscore1; // update pathscore1last
pathscore2last = pathscore2; // update pathscore2last
}
// final 'next' transition that exits from last frame
if (ts == te) // if sp tee model, will not in next loop
{
pathscore2 = pathscore0 + getlogtransp(transP,-1,1);
}
else if (isSp)
{
pathscore2 = pathscore0 + getlogtransp(transP,0,1) ; // sp model, from 0 to 1
//printf(" sp, %f\n", pathscore2);
}
else if(isSil) //for sil, the exit state can be 0 or 2.
{
const float pathscore03 = pathscore0 + getlogtransp(transP,0,3);
pathscore2 += getlogtransp(transP,2,3);
if(pathscore03 > pathscore2)
{
pathscore2 = pathscore03;
}
}
else
pathscore2 += getlogtransp(transP,2,3);
fwscore = pathscore2; // propagate across phone boundaries
// emit alignment
if (!isSil)
{
// left-to-right model: reconstruct the state segmentation from the recorded
// inflection points (no back-pointer matrix needed)
state2step0to1 += alignindex - ts; // convert to align measure
state2step1to2 += alignindex - ts;
for (size_t t = alignindex; t < alignindex + numframes; t++) // set the final alignment
{
size_t senoneid;
if (t < state2step0to1) // in state 0
senoneid = senoneid0;
else if(t < state2step1to2) // in state 1
senoneid = senoneid1;
else // in state 2
senoneid = senoneid2;
alignresult[t] = (unsigned short) senoneid;
}
}
else // for silence
{
// ergodic model: trace back through the back-pointer matrix from the best exit state
size_t lastpointer = 2;
const float pathscore03 = pathscore0 + getlogtransp(transP,0,3);
if(pathscore03 >= pathscore2) //exit state is 0
{
alignresult[alignindex + numframes - 1] = (unsigned short) senoneid0;
lastpointer = 0;
}
else //exit state is 2
alignresult[alignindex + numframes - 1] = (unsigned short) senoneid2;
for (size_t t = alignindex + numframes - 2; (t + 1) > alignindex; t--) // set the final alignment; (t+1) guard avoids size_t wrap-around below alignindex
{
lastpointer = backptrmatrix (lastpointer, t-alignindex);
size_t senoneid = (size_t) (-1);
if (lastpointer == 0)
senoneid = senoneid0;
else if (lastpointer == 1)
senoneid = senoneid1;
else if (lastpointer == 2)
senoneid = senoneid2;
alignresult[t] = (unsigned short) senoneid;
}
}
ts = te;
alignindex += numframes;
}
edgeacscores[j] = fwscore;
}
// compute the final error signal from gammas and state-consolidated Eframescorrect
// in-place operation is supported (i.e. output = one of the inputs)
// fills row s of errorsignal: gamma(s,t) * (Eframescorrect(s,t) - Eframescorrecttotal) * kappa
template<typename matrix>
static inline __device__ void computesMBRerrorsignals (const size_t s, const matrix & loggammas, const matrix & logEframescorrect,
const double logEframescorrecttotal, const float kappa, matrix & errorsignal)
{
const float Eframescorrecttotal = expf ((float)logEframescorrecttotal);
const size_t T = errorsignal.cols();
for (size_t t = 0; t < T; t++)
errorsignal(s,t) = expf (loggammas(s,t)) * (expf (logEframescorrect(s,t)) - Eframescorrecttotal) * kappa;
}
// test if a state is silence [v-hansu]
// WARNING, this function only support models with 9304 states
// (senone ids 7670..7672 are hard-coded as the silence states for that model)
// TODO: change this later on
static inline __device__ bool issilencestate (size_t stateid, size_t numsenones)
{
if (numsenones == 9304 && (stateid == 7670 || stateid == 7671 || stateid == 7672))
return true;
else
return false;
}
// compare two states and check if they are of the same class [v-hansu]
template<typename ushortvector>
static inline __device__ bool isofsameclass (size_t statea, size_t stateb, ushortvector senone2classmap)
{
if (senone2classmap.size() == 0) // no map provided, we just do normal comparison
return (statea == stateb);
else
return senone2classmap[statea] == senone2classmap[stateb];
}
// Phase 1 of forwardbackward algorithm
// returnEframescorrect means sMBR mode
// Processes edge j: computes its score, its frames-correct count, and atomically
// accumulates forward alphas (and accumulated alphas for sMBR) into the end node.
// Under FORBID_INVALID_SIL_PATHS the node space is doubled and added-sil edges
// are routed through shadow 'silence state' nodes (offset by nodes.size()).
template<typename edgeinforvector, typename nodeinfovector, typename aligninfovector, typename ushortvector, typename uintvector, typename floatvector, typename doublevector>
static inline __device__ void forwardlatticej (const size_t j, const floatvector & edgeacscores,
const size_t /*spalignunitid --unused*/, const size_t silalignunitid,
const edgeinforvector & edges, const nodeinfovector & nodes, const aligninfovector & aligns,
const ushortvector & alignments, const uintvector & alignmentoffsets,
doublevector & logalphas, float lmf, float wp, float amf, const float boostingfactor,
const ushortvector & uids, const ushortvector senone2classmap, const bool returnEframescorrect,
doublevector & logframescorrectedge, doublevector & logaccalphas)
{
// edge info
const edgeinfowithscores & e = edges[j];
double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
//zhaorui to deal with the abnormal score for sent start.
if(e.l < -200.0f)
edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
const bool boostmmi = (boostingfactor != 0.0f);
// compute the frames-correct count for this edge
double logframescorrectedgej = LOGZERO;
const size_t numsenones = 9304; // WARNING: this is a hack, please fix this once smbr or bmmi is working! [v-hansu]
bool skipsilence = true; // currently we skip silence for BMMI and sMBR [v-hansu]
if (returnEframescorrect || boostmmi)
{
size_t ts = nodes[e.S].t;
size_t te = nodes[e.E].t;
size_t framescorrect = 0; // count raw number of correct frames
size_t startindex = alignmentoffsets[j];
size_t as = e.firstalign; // align start
size_t ae = (j+1) < edges.size() ? (size_t) edges[j+1].firstalign : aligns.size();
const bool isSil = (ae == as + 1 && aligns[as].unit == silalignunitid); // the order of this judgement shall be changed to save memory access
if (!(isSil && skipsilence)) // we don't count silence when 1. is silence; 2, skip them
{
for (size_t t = ts; t < te; t++)
{
if (!(skipsilence && issilencestate (alignments[t-ts+startindex], numsenones)))
framescorrect += isofsameclass (alignments[t-ts+startindex], uids[t], senone2classmap); // we only count correct && non-silence state
}
}
logframescorrectedgej = (framescorrect > 0) ? log ((double) framescorrect) : LOGZERO; // remember for backward pass
logframescorrectedge[j] = logframescorrectedgej;
}
if (boostmmi)
edgescore -= boostingfactor * exp (logframescorrectedge[j]);
#ifdef FORBID_INVALID_SIL_PATHS
const bool forbidinvalidsilpath = (logalphas.size() > nodes.size()); // we constrain sil to sil path if node space has been blown up
const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1); // HACK: 'unused' indicates artificially added sil/sp edge
// original mode
if (!isaddedsil)
#endif
{
const size_t S = e.S;
const size_t E = e.E;
const double inscore = logalphas[S];
const double pathscore = inscore + edgescore;
atomicLogAdd (&logalphas[E], pathscore);
if (returnEframescorrect)
{
#ifdef DIRECT_MODE
double loginaccs = logaccalphas[e.S] + edgescore;
double logpathaccs = logalphas[e.S] + edgescore + logframescorrectedgej;
logadd (logpathaccs, loginaccs);
atomicLogAdd (&logaccalphas[e.E], logpathaccs);
#else
double loginaccs = logaccalphas[S] - logalphas[S];
logadd (loginaccs, logframescorrectedgej);
double logpathacc = loginaccs + logalphas[S] + edgescore;
atomicLogAdd (&logaccalphas[E], logpathacc);
#endif
}
}
#ifdef FORBID_INVALID_SIL_PATHS
// silence edge or second speech edge
if ((isaddedsil && e.E != nodes.size() -1) || (forbidinvalidsilpath && e.S != 0))
{
const size_t S = (size_t) (!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node
const size_t E = (size_t) (isaddedsil ? e.E + nodes.size() : e.E); // silence edge goes into special 'silence state' node
// remaining lines here are 100% code dup from above, just operating on different (S, E)
const double inscore = logalphas[S];
const double pathscore = inscore + edgescore;
atomicLogAdd (&logalphas[E], pathscore);
if (returnEframescorrect)
{
double loginaccs = logaccalphas[S] - logalphas[S];
logadd (loginaccs, logframescorrectedgej);
double logpathacc = loginaccs + logalphas[S] + edgescore;
atomicLogAdd (&logaccalphas[E], logpathacc);
}
}
#endif
}
// Backward phase for edge j: accumulates betas (and accumulated betas) into the
// start node and computes the edge posterior logpps[j] and, in sMBR mode, the
// per-edge expected frames-correct logEframescorrect[j]; mirrors forwardlatticej.
template<typename edgeinforvector, typename nodeinfovector, typename aligninfovector, typename floatvector, typename doublevector>
static inline __device__ void backwardlatticej (size_t j, const floatvector & edgeacscores,
const size_t /*spalignunitid --unused*/, const size_t /*silalignunitid --unused*/,
const edgeinforvector & edges, const nodeinfovector & nodes,
const aligninfovector & /*aligns -- unused*/, const double totalfwscore, doublevector & logpps,
doublevector & logalphas, doublevector & logbetas, float lmf, float wp,
float amf, const float boostingfactor, const bool returnEframescorrect,
doublevector & logframescorrectedge, doublevector & logaccalphas,
doublevector & logEframescorrect, doublevector & logaccbetas)
{
// output values
double logpp = LOGZERO;
double logEframescorrectj = LOGZERO;
const bool boostmmi = (boostingfactor != 0.0f);
// edge info
const edgeinfowithscores & e = edges[j];
double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
//zhaorui to deal with the abnormal score for sent start.
if (e.l < -200.0f)
edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
if (boostmmi)
edgescore -= boostingfactor * exp (logframescorrectedge[j]);
// get the frames-correct count for this edge that was computed during the forward pass
double logframescorrectedgej = (returnEframescorrect || boostmmi) ? logframescorrectedge[j] : LOGZERO;
#ifdef FORBID_INVALID_SIL_PATHS
// original mode
const bool forbidinvalidsilpath = (logalphas.size() > nodes.size()); // we prune sil to sil path if alphabetablowup != 1
const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1); // HACK: 'unused' indicates artificially added sil/sp edge
if (!isaddedsil) // original mode
#endif
{
const size_t S = e.S;
const size_t E = e.E;
// backward pass
const double inscore = logbetas[E];
const double pathscore = inscore + edgescore;
atomicLogAdd (&logbetas[S], pathscore);
// compute lattice posteriors on the fly since we are at it
logpp = logalphas[S] + edgescore + logbetas[E] - totalfwscore;
// similar logic for Eframescorrect
if (returnEframescorrect)
{
#ifdef DIRECT_MODE
double loginaccs = logaccbetas[e.E] + edgescore;
double logpathaccs = logbetas[e.E] + edgescore + logframescorrectedgej;
logadd (logpathaccs, loginaccs);
atomicLogAdd (&logaccbetas[e.S], logpathaccs);
double logecorrect = logaccalphas[e.S] + edgescore + logbetas[e.E];
logadd (logecorrect, logalphas[e.S] + edgescore + logframescorrectedgej + logbetas[e.E]);
logadd (logecorrect, logalphas[e.S] + edgescore + logaccbetas[e.E]);
logEframescorrectj = logecorrect - totalfwscore; // be careful, this includes the denominator
#else
// backward pass
double loginaccs = logaccbetas[E] - logbetas[E];
logadd (loginaccs, logframescorrectedgej);
double logpathacc = loginaccs + logbetas[E] + edgescore;
atomicLogAdd (&logaccbetas[S], logpathacc);
// sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
double logsum = logframescorrectedgej; // sum over this edge, left partial (alpha), right partial (beta)
double edgelogaccalpha = logaccalphas[S] - logalphas[S]; // incoming partial expected frames correct
logadd (logsum, edgelogaccalpha);
double edgelogaccbeta = logaccbetas[E] - logbetas[E]; // partial expected frames correct from the end
logadd (logsum, edgelogaccbeta);
logEframescorrectj = logsum; // that's it
#endif
}
}
#ifdef FORBID_INVALID_SIL_PATHS
double logpp2 = LOGZERO;
double logEframescorrectj2 = LOGZERO;
// silence edge or second speech edge
if ((isaddedsil && e.E != nodes.size() -1) || (forbidinvalidsilpath && e.S != 0))
{
const size_t S = (size_t) (!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node
const size_t E = (size_t) (isaddedsil ? e.E + nodes.size() : e.E); // silence edge goes into special 'silence state' node
// remaining lines here are code dup from above, with two changes: logadd2/logEframescorrectj2 instead of logadd/logEframescorrectj
// backward pass
const double inscore = logbetas[E];
const double pathscore = inscore + edgescore;
atomicLogAdd (&logbetas[S], pathscore);
// compute lattice posteriors on the fly since we are at it
logpp2 = logalphas[S] + edgescore + logbetas[E] - totalfwscore; // second edge (logpp2)
// similar logic for Eframescorrect
if (returnEframescorrect)
{
// backward pass
double loginaccs = logaccbetas[E] - logbetas[E];
logadd (loginaccs, logframescorrectedgej);
double logpathacc = loginaccs + logbetas[E] + edgescore;
atomicLogAdd (&logaccbetas[S], logpathacc);
// sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
double logsum = logframescorrectedgej; // sum over this edge, left partial (alpha), right partial (beta)
double edgelogaccalpha = logaccalphas[S] - logalphas[S]; // incoming partial expected frames correct
logadd (logsum, edgelogaccalpha);
double edgelogaccbeta = logaccbetas[E] - logbetas[E]; // partial expected frames correct from the end
logadd (logsum, edgelogaccbeta);
logEframescorrectj2 = logsum; // that's it for this second edge
}
// sum logpp2 and logEframescorrectj2
// Eframescorrect must be summed up in a weighted fashion, weighted by PP
double numer = logEframescorrectj + logpp;
logadd (numer, logEframescorrectj2 + logpp2); // weighted sum, weighted by respective (log)pp
logadd (logpp, logpp2); // (log)pp is just the sum of the two posteriors
double denom = logpp; // and that is also the denominator for the weighted sum
logEframescorrectj = numer - denom; // weighted sum
}
#else
nodes; // touch the parameter to avoid an unreferenced-parameter warning in this configuration
#endif
// write back return values
if (logpp > 0.0) // clip to log 1 (may be possible due to small numeric inaccuracies, although it really shouldn't happen)
logpp = 0.0;
logpps[j] = logpp;
if (returnEframescorrect)
logEframescorrect[j] = logEframescorrectj;
}
// Accumulate edge j's sMBR contribution into the per-(senone,frame) error signal,
// splitting positive and negative contributions into errorsignal/errorsignalneg.
template<typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
static inline __device__ void sMBRerrorsignalj (size_t j, const ushortvector & alignstateids, const uintvector & alignoffsets,
const edgeinfowithscoresvector & edges,
const nodeinfovector & nodes, const doublevector & logpps, const float amf,
const doublevector & logEframescorrect, const double logEframescorrecttotal,
matrix & errorsignal, matrix & errorsignalneg)
{
size_t ts = nodes[edges[j].S].t;
size_t te = nodes[edges[j].E].t;
if (ts != te)
{
#ifdef DIRECT_MODE
float logEframescorrectj = logEframescorrect[j];
size_t offset = alignoffsets[j];
for (size_t t = ts; t < te; t++)
{
const size_t s = (size_t) alignstateids[t - ts + offset];
atomicLogAdd (&errorsignal(s,t), logEframescorrectj);
}
#else
const double diff = expdiff (logEframescorrect[j], logEframescorrecttotal);
// Note: the contribution of the states of an edge to their senones is the same for all states
// so we compute it once and add it to all; this will not be the case without hard alignments.
double absdiff = fabs (diff);
if (absdiff == 0.0f)
return;
const float logedgecorrect = (float) (logpps[j] + log (absdiff));
size_t offset = alignoffsets[j];
for (size_t t = ts; t < te; t++)
{
const size_t s = (size_t) alignstateids[t - ts + offset];
if (diff > 0.0)
atomicLogAdd(&errorsignal(s,t), logedgecorrect);
else
atomicLogAdd(&errorsignalneg(s,t), logedgecorrect);
}
absdiff = amf; // just to reference it because we hope to support linear mode as well
#endif
}
}
// accumulate a per-edge quantity into the states that the edge is aligned with
// Use this for MMI passing the edge posteriors logpps[] as logq, or for sMBR passing logEframescorrect[].
// j=edge index, alignment in (alignstateids, alignoffsets)
template<typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector,typename doublevector, typename matrix>
static inline __device__ void stateposteriorsj (size_t j, const ushortvector & alignstateids, const uintvector & alignoffsets,
const edgeinfowithscoresvector & edges,
const nodeinfovector & nodes, const doublevector & logqs /*quantity to accumulate*/,
matrix & logacc /*accumulator to accumulate into*/)
{
size_t ts = nodes[edges[j].S].t;
size_t te = nodes[edges[j].E].t;
if (ts != te)
{
const float logq = (float) logqs[j]; // per-edge quantity to accumulate, e.g. edge posteriors -> state posteriors
size_t offset = alignoffsets[j];
for (size_t t = ts; t < te; t++)
{
const size_t s = (size_t) alignstateids[t - ts + offset]; // get state for this (j,t)
atomicLogAdd (&logacc(s,t), logq);
}
}
}
};
};};
#pragma pop_macro ("atomicCAS")
#pragma pop_macro ("atomicAdd")
#pragma pop_macro ("__device__")