// Origin: https://github.com/Microsoft/CNTK
// (raw-file banner converted to comments so the header remains valid C++)
// Tip revision: 5c3f708097bdcdaf2c06d2aa8a9b3fdc772ae27a authored by Mark Hillebrand on 18 January 2016, 08:36:30 UTC
// License change
// Tip revision: 5c3f708
// latticefunctionskernels.h
// latticefunctionskernels.cu(.h) -- kernels for lattice ops intended for use with CUDA, to be called from actual CUDA kernels.
//
// To make this compile for the CPU (emulation for testing), add this line before #including this:
// #define __device__
//
// F. Seide, V-hansu

#define FORBID_INVALID_SIL_PATHS    // [v-hansu] prune path that start from sil(sp) and go into sil, only used with addsil is adopted

#pragma once

#pragma push_macro ("__device__")
#pragma push_macro ("atomicAdd")
#pragma push_macro ("atomicCAS")

#include "latticestorage.h"
#include <limits>

namespace msra { namespace cuda { class passtextureref; }}

#ifdef CPUONLY
#define __kernel_emulation__
#endif

#ifdef __kernel_emulation__
#include "math.h"       // to get exp() and log() compiled correctly with c++
#include<stdexcept>
using namespace std;
#ifndef __device__
#define __device__
#endif
#define CUDART_MIN_DENORM_F numeric_limits<float>::denorm_min()
#define atomicAdd(address,value) (*(address)+=(value))  // don't forget to #undef (#praga pop_macro)! Otherwise CUDA might compile with this...
#define atomicCAS(address, compare, val) *address; *address = *address == compare ? val : *address;
#define __double_as_longlong(in) (*(unsigned long long int *) &in)
#define __longlong_as_double(in) (*(double *) &in)
#define __float_as_int(in) (*(int *) &in)
#define __int_as_float(in) (*(float *) &in)
#else    // TODO: remove this once we got this figured out
#include "math_constants.h"
#if __CUDA_ARCH__ < 200
//#warning Sequence training not supported on 1.x CUDA machines.
#define force_crash() (*((int*)-1)=0)               // TODO: this does not in fact seem to crash it...
#define atomicAdd(a,v) (force_crash(),*(a)=v)       // force a crash if used with 1.x devices 
#define atomicCAS(address, compare, val) (*(address) = compare + val, *((int*)-1)=0)
#define __double_as_longlong(in) (force_crash(), in)
#define __longlong_as_double(in) (force_crash(), in)
#define __float_as_int(in) (force_crash(), in)
#define __int_as_float(in) (force_crash(), in)
#endif
#endif

namespace msra { namespace lattices {

struct somedata         // example type to have a pattern to copy from
{
    // placeholder member; never read by the code in this file
    size_t fortytwo;
};

struct empty {};    // zero-member tag type, used where a template argument is required but no data is needed

// Note that the code that uses this (edgealignmentj) will assume 3-state left-to-right
// except for /sil/ and /sp/ which are treated specially.
// TODO: either check when creating this whether this assumption is true, or control this through a flag in here.
// Note that the code that uses this (edgealignmentj) assumes 3-state left-to-right,
// except for /sil/ and /sp/ which are treated specially.
struct lr3transP        // lr3 = 3-state left-to-right architecture
{
    static const size_t MAXSTATES = 3;          // fixed allocation; row/col 0 is the virtual entry state, hence +1 below
    size_t numstates;                           // actual number of states (1 for /sp/, else 3)
    float loga[MAXSTATES+1][MAXSTATES+1];       // log transition probabilities; loga[from+1][to], 'from' = -1 is the entry state

    lr3transP ()
    {
#ifdef INITIAL_STRANGE
        // poison-initialize so that use of an unset transP is detectable
        // (fixed: original wrote NUMSTATES, which is not declared anywhere -- the member is MAXSTATES)
        numstates = 3;
        for (size_t i = 0; i < MAXSTATES+1; i++)
            for (size_t j = 0; j < MAXSTATES+1; j++)
            {
                loga[i][j] = LOGZERO;
            }
#endif
    }
};

struct lrhmmdef         // left-to-right HMM (no /sil/)
{
    static const size_t MAXSTATES = 3;              // we use a fixed memory allocation since it's almost always 3 anyway

    unsigned char transPindex;                      // index of monophone to find transP matrix
    unsigned char numstates;                        // number of states; code supports only either 1 or 3
    unsigned short senoneids[MAXSTATES];            // [0..numstates-1] senone indices

    // accessors (widen the packed storage back to size_t for callers)
    size_t getsenoneid (size_t i) const { return (size_t) senoneids[i]; }
    size_t getnumstates() const { return (size_t) numstates; }
    const struct lr3transP & gettransP(const lr3transP * transPs) const { return transPs[transPindex]; }

    lrhmmdef ()
    {
#ifdef INITIAL_STRANGE
        // poison-initialize all fields so that use of an unset hmm is detectable
        // (fixed: 'unsigned char (-1)' is not valid C++ -- a functional cast requires a
        //  single-word type name; use C-style casts instead)
        transPindex = (unsigned char) (-1);
        numstates = (unsigned char) (-1);
        for (size_t i = 0; i < MAXSTATES; i++)
        {
            senoneids[i] = (unsigned short) (-1);
        }
#endif
    }
};

#if 1   // straight-forward version
#else   // CUDA hacked version -- disabled, kept for reference; fetches the packed 8-byte lrhmmdef with a single ushort4 load and unpacks the fields (does not compile as-is)
hmm gethmm (hmms, i)
{
    ushort4 u4 = *((ushort4) &hmms[i]);
    lrhmmdef hmm;
    hmm.transPindex = u4.x & 0xff;
    hmm.numstates = u4.x >> 8;
    hmm.senoneids[0] = u4.y;
    hmm.senoneids[1] = u4.z;
    hmm.senoneids[2] = u4.w;
}
#endif

#ifndef LOGZERO
#define LOGZERO -1e30f
#endif

// non-owning view of a pre-allocated unsigned short buffer as a 2D matrix in column-major
// storage; used for the backpointer matrix of the ergodic (silence) Viterbi in edgealignmentj
class bpmatrixref
{
private:
    unsigned short * p;      // pointer in CUDA space of this device
    size_t numrows;     // rows()
    size_t numcols;     // cols()
    size_t colstride;   // height of column; NOTE(review): the constructor sets this to rows() exactly -- the original claim "rounded to multiples of 4" is not implemented here
    // bounds check; on violation: throw in CPU emulation, force a crash by writing to an invalid address on device
    __device__ void checkbounds (size_t i, size_t j) const 
    {
        if (i >= numrows || j >= numcols)
#ifdef __kernel_emulation__
            throw::logic_error ("out of boundary!!!");
#else
            *((int*)-1)=0;
#endif
    }
    // map (row i, column j) to a linear offset into p
    __device__ size_t locate (size_t i, size_t j) const 
    {
        checkbounds (i,j);
        return j * colstride + i; 
    }   // matrix in column-wise storage
public:
    // wrap an existing buffer of at least n*m elements as an n x m matrix (no allocation, no ownership)
    __device__ bpmatrixref (unsigned short * address, size_t n, size_t m)
    {
        numrows = n;
        numcols = m;
        colstride = n;
        p = address;
    }

    __device__ unsigned short &       operator() (size_t i, size_t j)       { return p[locate(i,j)]; }
    __device__ const unsigned short & operator() (size_t i, size_t j) const { return p[locate(i,j)]; }
};

// this class contains all-static methods that are inner pieces of thread kernels for use with CUDA
struct latticefunctionskernels
{
    // [v-hansu] float flavor of atomicCAS: reinterpret the float bits as int so the
    // integer atomicCAS can be used; returns the previous value at *address as a float
    static __device__ float atomicCASfloatdouble (float *address, float compare, float val)
    {
        int * addressasint = (int *) address;
        const int compareasint = __float_as_int(compare);
        const int valasint = __float_as_int(val);
        const int oldasint = atomicCAS (addressasint, compareasint, valasint);
        return __int_as_float(oldasint);
    }

    // double flavor: reinterpret the double bits as unsigned long long so the 64-bit
    // integer atomicCAS can be used; returns the previous value at *address as a double
    static __device__ double atomicCASfloatdouble (double *address, double compare, double val)
    {
        unsigned long long int * addressasull = (unsigned long long int *) address;
        const unsigned long long int compareasull = __double_as_longlong (compare);
        const unsigned long long int valasull = __double_as_longlong (val);
        const unsigned long long int oldasull = atomicCAS(addressasull, compareasull, valasull);
        return __longlong_as_double(oldasull);
    }

    // second half of a log-add: loga += log (1 + exp (diff)), assuming diff = logb - loga <= 0.
    // If exp(diff) is below the mantissa resolution of the result, adding would be a no-op,
    // so we skip the transcendentals entirely.
    static __device__ void logaddratio (double & loga, double diff)
    {
        if (diff < -37.0) return;       // log (2^-53), 52-bit mantissa -> cut off after 53rd bit (fixed: double literal, was a float literal)
        loga += log1p (exp (diff));     // log1p(x) is more accurate than log(1+x) for small x
    }
    // float flavor of the above; see double version for the contract
    static __device__ void logaddratio (float & loga, float diff) 
    { 
        if (diff < -17.0f) return;      // log (2^-24), 23-bit mantissa -> cut off after 24th bit 
        loga += log1pf (expf (diff));   // log1pf(x) is more accurate than logf(1+x) for small x
    }

    // exchange the values of 'left' and 'right' (device-side stand-in for std::swap)
    template<typename FLOAT> static __device__ void swap (FLOAT & left, FLOAT & right)
    {
        const FLOAT savedleft = left;
        left = right;
        right = savedleft;
    }

    // overloads for exp() for float and double, so that we can use templates
    // (each forwards to the precision-matched C math function)
    static __device__ float expfd (float x) { return ::expf (x); }  // TODO: ain't there an overload for this?
    static __device__ double expfd (double x) { return ::exp (x); }

    // Compute the difference of two numbers, which are represented as their logs.
    // The return value is a non-log value: exp(loga) - exp(logb).
    // The larger exponent is factored out so that the inner exp() stays in [0,1) and cannot overflow.
    template<typename FLOAT>
    static __device__ FLOAT expdiff (FLOAT loga, FLOAT logb)
    {
        if (loga > logb)                                        // positive result: factor out exp(loga)
            return  expfd (loga) * (1 - expfd (logb - loga));
        return -expfd (logb) * (1 - expfd (loga - logb));       // negative (or zero) result: factor out exp(logb)
    }

    // loga <- log (exp (loga) + exp (logb)) = log (exp (loga) * (1.0 + exp (logb - loga)) = loga + log (1.0 + exp (logb - loga))
    template<typename FLOAT> static __device__ void logadd (FLOAT & loga, FLOAT logb)
    {
        if (loga < logb)                        // ensure loga holds the larger value; the smaller is folded into it
            swap (loga, logb);
        if (loga > LOGZERO)                     // if both are "log 0" the sum stays "log 0"
            logaddratio (loga, logb - loga);
    }

    // does the same as above but if the bigger one is too small, we assign a small value to it
    template<typename FLOAT> static __device__ void logaddseen (FLOAT & loga, FLOAT logb)
    {
        if (loga < logb)                        // ensure loga holds the larger value; the smaller is folded into it
            swap (loga, logb);
        if (loga > LOGZERO)
            logaddratio (loga, logb - loga);
        else                                    // both are "log 0": mark as seen-but-tiny instead of unseen
            loga = logf(CUDART_MIN_DENORM_F);   // [v-hansu] we hope to separate LOGZERO (unseen states) and logf(CUDART_MIN_DENORM_F) (seen states with small prob)
    }

#if 1
    // bit-level reinterpretation between the floating type and the same-width integer type,
    // so that atomicCAS (integer-only) can operate on float/double values; overload resolution
    // picks the correct width from the argument type
    static inline __device__ float  bitsasfloat (                   int b) { return       __int_as_float (b); }
    static inline __device__ double bitsasfloat (unsigned long long int b) { return __longlong_as_double (b); }
    static inline __device__                    int floatasbits (float  f) { return       __float_as_int (f); }
    static inline __device__ unsigned long long int floatasbits (double f) { return __double_as_longlong (f); }

    template<typename FLOAT>    // adapted from [http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ixzz32EuzZjxV]
    static __device__ FLOAT atomicLogAdd (FLOAT * address, FLOAT val) // direct adaptation from NVidia source code
    {
        // atomically perform *address = logaddseen (*address, val) via a CAS retry loop on the raw bits;
        // returns the previous value at *address
        typedef decltype (floatasbits (val)) bitstype;      // int for float, unsigned long long for double
        bitstype * address_as_ull = (bitstype *) address;
        bitstype old = *address_as_ull, assumed;
        do {
            assumed = old;
            FLOAT sum = bitsasfloat (assumed);
            logaddseen (sum, val);                          // compute the log-add on a local copy
            old = atomicCAS (address_as_ull, assumed, floatasbits (sum));   // publish only if *address still holds 'assumed'
        } while (assumed != old);
        // note: critically, ^^ this comparison must compare the bits ('int') instead of the converted float values, since this will fail for NaNs (NaN != NaN is true always)
        return bitsasfloat(old);
    }
#else   // this code does not work because (assumed != old) will not compare correctly in case of NaNs
    //same pattern as atomicAdd(), but performing the log-add operation instead
    // (disabled variant, kept for reference only; the active version above compares raw bits instead)
    template<typename FLOAT> static __device__ FLOAT atomicLogAdd (FLOAT * address, FLOAT val)
    {
        FLOAT old = *address;
        FLOAT assumed;
        FLOAT logaddresult;
        do {
            assumed = old;                                  // old is the assumed value at address
            logaddresult = assumed;                         // for next step to compute logaddresult, assumed shall be the same as before
            logaddseen (logaddresult, val);
            old = atomicCASfloatdouble (address, assumed, logaddresult);
        } while (assumed != old);                           // if old == assumed, the *address is not changed in this loop, so this is safe
        return old;
    }
#endif

    // [v-hansu] shuffling accessing order for a item in cubic(Ni, Nj, Nk) with index i, j, k according to shufflemode
    // mode 0 is the natural i-fastest layout; modes 1..4 permute the roles of (i, j, k)
    static inline __device__ size_t shuffle (size_t i, size_t Ni, size_t j, size_t Nj, size_t k, size_t Nk, size_t shufflemode)
    {
        switch (shufflemode)
        {
        case 0: return i + j * Ni + k * Nj * Ni;    // natural order
        case 1: return k + j * Nk + i * Nj * Nk;    // inverse
        case 2: return j + i * Nj + k * Ni * Nj;    // flip i and j
        case 3: return i + k * Ni + j * Nk * Ni;    // flip j and k
        case 4: return j + k * Nj + i * Nk * Nj;
        default:
            *((int*)-1)=0;                          // shall not get here; force a crash (this file's error convention)
            return 0;
        }
    }

    // store 'value' into element j of 'thisvector' (trivial per-thread assignment kernel body)
    template<typename doublevector>
    static __device__ void setvaluej (size_t j, doublevector & thisvector, double value)
    {
        thisvector[j] = value;
    }

    //zhaorui
    // look up the log transition probability from state 'from' to state 'to';
    // from = -1 denotes the virtual entry state (row 0 of loga).
    // transP is now taken by const reference: the struct holds a 4x4 float matrix, and the
    // original pass-by-value copied it on every call inside the inner Viterbi loops.
    static inline __device__ float getlogtransp (const lr3transP & transP, int from, int to)
    {
        /*if (from < -1 || from >= transP.MAXSTATES || to > transP.MAXSTATES) 
        {
            //printf("from: %d to: %d\n", from, to);
            return LOGZERO;
        }*/
        return transP.loga[from+1][to];
    }
    // edgealignmentj -- Viterbi-align lattice edge j against the acoustic log-likelihoods.
    // Writes the per-frame senone ids into alignresult[alignoffsets[j] ..] and the edge's total
    // acoustic score into edgeacscores[j].
    // Assumes 3-state left-to-right HMMs, except /sp/ (single-state, identified by spalignunitid)
    // and /sil/ (ergodic 3-state, identified by silalignunitid, which needs a full backpointer
    // matrix taken from backptrstorage at backptroffsets[j]).
    // One invocation processes one edge; each edge is a sequence of align units [as, ae).
    template<typename lrhmmdefvector, typename lr3transPvector, typename matrix, typename nodeinfovector, typename edgeinfowithscoresvector, typename aligninfovector, typename ushortvector, typename uintvector, typename floatvector, typename sizetvector>
    static inline __device__ void edgealignmentj (size_t j, const lrhmmdefvector & hmms, const lr3transPvector & transPs, const size_t spalignunitid,
                                                  const size_t silalignunitid, const matrix & logLLs, const nodeinfovector & nodes,
                                                  const edgeinfowithscoresvector & edges, const aligninfovector & aligns, 
                                                  const uintvector & alignoffsets, ushortvector & backptrstorage, const sizetvector & backptroffsets,
                                                  ushortvector & alignresult, floatvector & edgeacscores)
    {                                            // TODO: alignresult will change to (start,end) 
        // mostly finished
        // some preparation
        size_t as = edges[j].firstalign;        // align start
        size_t ae = (j+1) < edges.size() ? (size_t) edges[j+1].firstalign : aligns.size();
        if (as == ae)       // the last empty alignment
            return;
        size_t ts = nodes[edges[j].S].t;        // start frame of the edge (from its start node)

        float fwscore = 0.0f;       // score passed across phone boundaries
        size_t alignindex = alignoffsets[j];        // index to set (result)
#ifndef PARALLEL_SIL
        // silence edges are handled on the CPU in this configuration; bail out here
        const bool isSil = (aligns[as].unit == silalignunitid || aligns[ae-1].unit == silalignunitid);
        if (isSil)  return;     // we do not support silence edge now, which is computed by cpu, may change when we support it
#endif
        // Viterbi alignment, one align unit (phone instance) at a time
        for (size_t k = as; k < ae; k++)
        {
            const aligninfo align = aligns[k];
            const size_t numframes = align.frames;
            const bool isSp = (align.unit == spalignunitid);
            const bool isSil = (align.unit == silalignunitid);  // note: shadows the outer isSil when PARALLEL_SIL is undefined

            const lrhmmdef hmm = hmms[align.unit];
            const lr3transP transP = transPs[hmm.transPindex];
          

            // pre-fetch senone ids into registers
            size_t senoneid0 = hmm.senoneids[0];
            size_t senoneid1 = 0;
            size_t senoneid2 = 0;
            if (!isSp)      // fetch only if needed--may save some memory cycles
            {
                senoneid1 = hmm.senoneids[1];
                senoneid2 = hmm.senoneids[2];
            }

            const size_t te = ts + numframes;               // end time of current unit

            // inflection points (frame where the best path entered a state); te = "not yet"
            size_t state1step0to1 = te;                     // inflection point from state 0 to 1, record in state 1
            size_t state2step0to1 = te;                     // inflection point from state 0 to 1, record in state 2
            size_t state2step1to2 = te;                     // inflection point from state 1 to 2, record in state 2

            //now we only support transition from -1 to 0 or 2 for sil
            float pathscore0 = fwscore ;                     // log pp in state 0
            float pathscore1 = LOGZERO;                     // log pp in state 1
            float pathscore2 = LOGZERO;                     // log pp in state 2
            if(isSil)
                pathscore2 = fwscore;                    
                
            // first frame
            if (ts != te)                                                              // for t = ts, initialization
            {                           
                if (isSil)                                                              //for sil, -1 to 2 and -1 to 0 is permitted
                {
                    pathscore0 += getlogtransp(transP,-1,0) + logLLs(senoneid0,ts); 
                    pathscore2 += getlogtransp(transP,-1,2) + logLLs(senoneid2,ts);      
                }
                else                                                                    //for others, only -1 to 0 is permitted
                    pathscore0 +=  logLLs(senoneid0,ts);                                // Note: no need to incorporate LLs for state [1] and [2] because the path log LLs are LOGZERO anyway
            }
            
            
            float pathscore2last = pathscore2;              // allocate last for state 2 because the order of computation below is 2->1->0, last state 2 is needed because from 2 to 0 or 1 is permitted for sil.
            float pathscore1last = pathscore1;              // allocate last for state 1 because the order of computation below is 2->1->0, last state 1 is needed because from 1 to 0 is permitted for sil.
            
            size_t backptroffset = backptroffsets[j];               // we make use of backptrstorage in backptroffsets[j] for viterbi of ergodic model (silence)

            bpmatrixref backptrmatrix (&backptrstorage[backptroffset], hmm.MAXSTATES, numframes);
            //subsequent frames: states are updated in the order 2 -> 1 -> 0 so that the
            //'last' copies still hold the previous frame's scores where needed
            for (size_t t = ts + 1; t < te; t++)
            {
                if (!isSp)
                {
                    // state [2]                    
                    pathscore2 += getlogtransp(transP,2,2);                                           // log pp from state 2 to 2
                    if (isSil)
                        backptrmatrix (2, t-ts-1) = 2;
                    const float pathscore12 = pathscore1 + getlogtransp(transP,1,2);                  // log pp from state 1 to 2
                    if (pathscore12 >= pathscore2)                                              // if state 1->2
                    {
                        pathscore2 = pathscore12;
                        state2step0to1 = state1step0to1;                                        // record the inflection point
                        state2step1to2 = t;                                                     // record the inflection point
                        if (isSil)
                            backptrmatrix (2, t-ts-1) = 1;
                    }
                    if (isSil)                                                                  // only silence have path from 0 to 2
                    {
                        const float pathscore02 = pathscore0 + getlogtransp(transP,0,2);          // log pp from state 0 to 2
                        if (pathscore02 >= pathscore2)                                          // if state 0->2
                        {
                            pathscore2 = pathscore02;
                            backptrmatrix (2, t-ts-1) = 0;
                        }
                    }

                    // state [1]
                    pathscore1 += getlogtransp(transP,1,1);                                           // log pp from state 1 to 1
                    if (isSil)
                        backptrmatrix (1, t-ts-1) = 1;
                    const float pathscore01 = pathscore0 + getlogtransp(transP,0,1);                  // log pp from state 0 to 1
                    if (pathscore01 >= pathscore1)                                              // if state 0 -> 1
                    {
                        pathscore1 = pathscore01;
                        state1step0to1 = t;                                                     // record the inflection point
                        if (isSil)
                            backptrmatrix (1, t-ts-1) = 0;
                    }
                    if (isSil)                                                                  // only silence have path from 2 to 1
                    {
                        const float pathscore21 = pathscore2last + getlogtransp(transP,2,1); 
                        if (pathscore21 >= pathscore1)                                              // if state 2 -> 1
                        {
                            pathscore1 = pathscore21;
                            backptrmatrix (1, t-ts-1) = 2;
                        }
                    }
                }
                // state [0]
                pathscore0 += getlogtransp(transP,0,0);
                if(isSil)                                                                       // only silence have path from 2 or 1 to 0
                {
                    backptrmatrix (0, t-ts-1) = 0;
                    const float pathscore20 = pathscore2last + getlogtransp(transP,2,0);      // log pp from state 2 to 0
                    if (pathscore20 >= pathscore0)
                    {
                        pathscore0 = pathscore20;
                        backptrmatrix (0, t-ts-1) = 2;
                    }
                    const float pathscore10 = pathscore1last + getlogtransp(transP,1,0);      // log pp from state 1 to 0
                    if (pathscore10 >= pathscore0)
                    {
                        pathscore0 = pathscore10;
                        backptrmatrix (0, t-ts-1) = 1;
                    }
                }

                // add log LLs
                pathscore0 += logLLs(senoneid0,t);
                if (!isSp)      // only fetch if needed, saves mem access
                {
                    pathscore1 += logLLs(senoneid1,t);
                    pathscore2 += logLLs(senoneid2,t);
                }
                pathscore1last = pathscore1;                                                    // update pathscore1last
                pathscore2last = pathscore2;                                                   // update pathscore2last 
            }

            // final 'next' transition that exits from last frame
            // (the exit score is accumulated into pathscore2, which becomes the next unit's fwscore)

            if (ts == te)                                                                   // if sp tee model, will not in next loop
            {                
                pathscore2 = pathscore0 + getlogtransp(transP,-1,1);                                               
            }
            else if (isSp)
            {
                pathscore2 = pathscore0 + getlogtransp(transP,0,1) ;          // sp model, from 0 to 1
                //printf(" sp, %f\n", pathscore2);
            }
            
            else if(isSil)                                                                   //for sil, the exit state can be 0 or 2.
            {
                const float pathscore03 =  pathscore0 + getlogtransp(transP,0,3);
                pathscore2 += getlogtransp(transP,2,3);
                if(pathscore03 > pathscore2)
                {
                    pathscore2 = pathscore03;
                }
            }
            else
                pathscore2 += getlogtransp(transP,2,3);

            fwscore = pathscore2;                                                           // propagate across phone boundaries
            
            // emit alignment

            if (!isSil)
            {
                // left-to-right case: reconstruct the alignment from the two inflection points
                state2step0to1 += alignindex - ts;                              // convert to align measure
                state2step1to2 += alignindex - ts;
                for (size_t t = alignindex; t < alignindex + numframes; t++)    // set the final alignment
                {
                    size_t senoneid;
                    if (t < state2step0to1)                                     // in state 0
                        senoneid = senoneid0;
                    else if(t < state2step1to2)                                 // in state 1
                        senoneid = senoneid1;
                    else                                                        // in state 2
                        senoneid = senoneid2;
                    alignresult[t] = (unsigned short) senoneid;
                }
            }
            else                                                                        // for silence
            {
                // ergodic case: trace back through the backpointer matrix from the best exit state
                size_t lastpointer = 2;
                const float pathscore03 =  pathscore0 + getlogtransp(transP,0,3);
                if(pathscore03 >= pathscore2)                                                       //exit state is 0
                {
                    alignresult[alignindex + numframes - 1] = (unsigned short) senoneid0;                  
                    lastpointer = 0;
                }
                else                                                                                  //exit state is 2
                    alignresult[alignindex + numframes - 1] = (unsigned short) senoneid2;             
                
                for (size_t t = alignindex + numframes - 2; (t + 1) > alignindex; t--)  // set the final alignment
                {
                    lastpointer = backptrmatrix (lastpointer, t-alignindex);
                    size_t senoneid = (size_t) (-1);
                    if (lastpointer == 0)
                        senoneid = senoneid0;
                    else if (lastpointer == 1)
                        senoneid = senoneid1;
                    else if (lastpointer == 2)
                        senoneid = senoneid2;
                    alignresult[t] = (unsigned short) senoneid;
                }
            }
            ts = te;
            alignindex += numframes;
        }
        edgeacscores[j] = fwscore;      // total acoustic score of the edge = score of the last unit's exit
    }

    // compute the final error signal from gammas and state-consolidated Eframescorrect
    // in-place operation is supported (i.e. output = one of the inputs) 
    // errorsignal(s,t) = gamma(s,t) * (Eframescorrect(s,t) - Eframescorrecttotal) * kappa,
    // where the gamma and Eframescorrect inputs are given in the log domain
    template<typename matrix>
    static inline __device__  void computesMBRerrorsignals (const size_t s, const matrix & loggammas, const matrix & logEframescorrect, 
                                                            const double logEframescorrecttotal, const float kappa, matrix & errorsignal)
    {
        const float avgframescorrect = expf ((float)logEframescorrecttotal);    // converted out of log domain once, reused per frame
        const size_t numframes = errorsignal.cols();
        for (size_t t = 0; t < numframes; t++)
        {
            const float gamma = expf (loggammas(s,t));
            const float frames = expf (logEframescorrect(s,t));
            errorsignal(s,t) = gamma * (frames - avgframescorrect) * kappa;
        }
    }

    // test if a state is silence [v-hansu]
    // WARNING, this function only support models with 9304 states
    // (the three silence senones of that model are hard-coded as 7670..7672)
    // TODO: change this later on
    static inline __device__ bool issilencestate (size_t stateid, size_t numsenones)
    {
        return numsenones == 9304 && stateid >= 7670 && stateid <= 7672;
    }

    // compare two states and check if they are of the same class [v-hansu]
    // with an empty map this degenerates to plain senone-id equality
    template<typename ushortvector>
    static inline __device__ bool isofsameclass (size_t statea, size_t stateb, ushortvector senone2classmap)
    {
        const bool havemap = (senone2classmap.size() != 0);
        return havemap ? (senone2classmap[statea] == senone2classmap[stateb]) : (statea == stateb);
    }

    // Phase 1 of forwardbackward algorithm
    // returnEframescorrect means sMBR mode
    // forwardlatticej -- process edge j of the forward pass:
    //  - computes the edge's combined score (LM * lmf + wp + acoustic / amf),
    //  - counts/records the edge's frames-correct statistic into logframescorrectedge[j]
    //    (needed for sMBR and boosted MMI),
    //  - atomically log-adds the path score into logalphas[e.E] (and logaccalphas[e.E] for sMBR).
    // Edges must be processed in an order where logalphas[e.S] is already final (callers launch
    // this per topological batch). With FORBID_INVALID_SIL_PATHS, the node space is doubled and
    // artificially added sil edges route through shadow 'silence state' nodes.
    template<typename edgeinforvector, typename nodeinfovector, typename aligninfovector, typename ushortvector, typename uintvector, typename floatvector, typename doublevector>
    static inline __device__ void forwardlatticej (const size_t j, const floatvector & edgeacscores, 
                                                   const size_t /*spalignunitid --unused*/, const size_t silalignunitid, 
                                                   const edgeinforvector & edges, const nodeinfovector & nodes, const aligninfovector & aligns, 
                                                   const ushortvector & alignments, const uintvector & alignmentoffsets,
                                                   doublevector & logalphas, float lmf, float wp, float amf, const float boostingfactor, 
                                                   const ushortvector & uids, const ushortvector senone2classmap, const bool returnEframescorrect, 
                                                   doublevector & logframescorrectedge, doublevector & logaccalphas)
    {
        // edge info
        const edgeinfowithscores & e = edges[j];
        double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;      // note: edgeacscores[j] == LOGZERO if edge was pruned
        //zhaorui to deal with the abnormal score for sent start. 
        if(e.l < -200.0f)
            edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
        const bool boostmmi = (boostingfactor != 0.0f);
        // compute the frames-correct count for this edge
        double logframescorrectedgej = LOGZERO;
        const size_t numsenones = 9304;             // WARNING: this is a hack, please fix this once smbr or bmmi is working! [v-hansu]
        bool skipsilence = true;                    // currently we skip silence for BMMI and sMBR [v-hansu]
        if (returnEframescorrect || boostmmi)
        {
            size_t ts = nodes[e.S].t;
            size_t te = nodes[e.E].t;
            size_t framescorrect = 0;                                                   // count raw number of correct frames
            size_t startindex = alignmentoffsets[j];

            size_t as = e.firstalign;                                                   // align start
            size_t ae = (j+1) < edges.size() ? (size_t) edges[j+1].firstalign : aligns.size();
            const bool isSil = (ae == as + 1 && aligns[as].unit == silalignunitid);     // the order of this judgement shall be changed to save memory access
            if (!(isSil && skipsilence))         // we don't count silence when 1. is silence; 2, skip them
            {
                // compare this edge's alignment frame by frame against the reference uids
                for (size_t t = ts; t < te; t++)
                {
                    if (!(skipsilence && issilencestate (alignments[t-ts+startindex], numsenones)))
                        framescorrect += isofsameclass (alignments[t-ts+startindex], uids[t], senone2classmap);  // we only count correct && non-silence state
                }
            }
            logframescorrectedgej = (framescorrect > 0) ? log ((double) framescorrect) : LOGZERO;          // remember for backward pass
            logframescorrectedge[j] = logframescorrectedgej;
        }
        if (boostmmi)
            edgescore -= boostingfactor * exp (logframescorrectedge[j]);
        
#ifdef FORBID_INVALID_SIL_PATHS
        const bool forbidinvalidsilpath = (logalphas.size() > nodes.size());    // we constrain sil to sil path if node space has been blown up
        const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1);        // HACK: 'unused' indicates artificially added sil/sp edge

        // original mode
        if (!isaddedsil)
#endif
        {
            const size_t S = e.S;
            const size_t E = e.E;

            const double inscore = logalphas[S];
            const double pathscore = inscore + edgescore;

            // concurrent edges may share the end node, hence the atomic log-add
            atomicLogAdd (&logalphas[E], pathscore);

            if (returnEframescorrect)
            {
#ifdef DIRECT_MODE
                double loginaccs = logaccalphas[e.S] + edgescore;
                double logpathaccs = logalphas[e.S] + edgescore + logframescorrectedgej;
                logadd (logpathaccs, loginaccs);
                atomicLogAdd (&logaccalphas[e.E], logpathaccs);
#else
                double loginaccs = logaccalphas[S] - logalphas[S];
                logadd (loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logalphas[S] + edgescore;
                atomicLogAdd (&logaccalphas[E], logpathacc);
#endif
            }
        }

#ifdef FORBID_INVALID_SIL_PATHS
        // silence edge or second speech edge
        if ((isaddedsil && e.E != nodes.size() -1) || (forbidinvalidsilpath && e.S != 0))
        {
            const size_t S = (size_t) (!isaddedsil ? e.S + nodes.size() : e.S);          // second speech edge comes from special 'silence state' node
            const size_t E = (size_t) (isaddedsil  ? e.E + nodes.size() : e.E);          // silence edge goes into special 'silence state' node
            // remaining lines here are 100% code dup from above, just operating on different (S, E)
            const double inscore = logalphas[S];
            const double pathscore = inscore + edgescore;
            atomicLogAdd (&logalphas[E], pathscore);

            if (returnEframescorrect)
            {
                double loginaccs = logaccalphas[S] - logalphas[S];
                logadd (loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logalphas[S] + edgescore;
                atomicLogAdd (&logaccalphas[E], logpathacc);
            }
        }
#endif
    }

    // backwardlatticej -- backward (beta) pass over a single lattice edge j.
    // Accumulates (in log domain, via atomicLogAdd) this edge's contribution to the
    // backward score of its start node, and computes two per-edge outputs on the fly:
    //  - logpps[j]            -- log posterior probability of edge j (clipped to log 1 = 0.0)
    //  - logEframescorrect[j] -- log expected frames-correct count for edge j (only if returnEframescorrect)
    // It mirrors the forward pass above: the same edge score is computed (LM score * lmf + word
    // penalty + acoustic score, all scaled by 1/amf; optionally reduced by the boosted-MMI term),
    // and the same FORBID_INVALID_SIL_PATHS node-space blow-up trick is applied, where artificially
    // added sil/sp edges (flagged via e.unused == 1) operate on shadow nodes at index + nodes.size().
    // NOTE(review): caller is expected to invoke this once per edge, after the betas of all
    // end nodes reachable from edge j have been finalized -- confirm against the launching kernel.
    template<typename edgeinforvector, typename nodeinfovector, typename aligninfovector, typename floatvector, typename doublevector>
    static inline __device__ void backwardlatticej (size_t j, const floatvector & edgeacscores, 
                                                    const size_t /*spalignunitid --unused*/, const size_t /*silalignunitid --unused*/, 
                                                    const edgeinforvector & edges, const nodeinfovector & nodes, 
                                                    const aligninfovector & /*aligns -- unused*/, const double totalfwscore, doublevector & logpps, 
                                                    doublevector & logalphas, doublevector & logbetas, float lmf, float wp, 
                                                    float amf, const float boostingfactor, const bool returnEframescorrect,
                                                    doublevector & logframescorrectedge, doublevector & logaccalphas, 
                                                    doublevector & logEframescorrect, doublevector & logaccbetas)
    {
        // output values (log domain; LOGZERO = "no probability mass")
        double logpp = LOGZERO;
        double logEframescorrectj = LOGZERO;
        const bool boostmmi = (boostingfactor != 0.0f);     // boosted MMI active iff a non-zero boosting factor was passed

        // edge info
        const edgeinfowithscores & e = edges[j];
        // combined edge score: scaled LM score + word penalty + acoustic score, all divided by the acoustic-model scale factor
        double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        //zhaorui to deal with the abnormal score for sent start. 
        if (e.l < -200.0f)  // treat an extremely negative LM score as "no LM score" (sentence-start artifact)
            edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
        if (boostmmi)       // boosted MMI: subtract boostingfactor * (frames correct) from the edge score
            edgescore -= boostingfactor * exp (logframescorrectedge[j]);

        // get the frames-correct count for this edge that was computed during the forward pass
        double logframescorrectedgej = (returnEframescorrect || boostmmi) ? logframescorrectedge[j] : LOGZERO;

#ifdef FORBID_INVALID_SIL_PATHS
        // original mode
        const bool forbidinvalidsilpath = (logalphas.size() > nodes.size());    // we prune sil to sil path if alphabetablowup != 1
        const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1);        // HACK: 'unused' indicates artificially added sil/sp edge

        if (!isaddedsil) // original mode
#endif
        {
            const size_t S = e.S;
            const size_t E = e.E;

            // backward pass: propagate the beta of the end node back to the start node
            const double inscore = logbetas[E];
            const double pathscore = inscore + edgescore;
            atomicLogAdd (&logbetas[S], pathscore);     // atomic: multiple edges may share the same start node

            // compute lattice posteriors on the fly since we are at it
            logpp = logalphas[S] + edgescore + logbetas[E] - totalfwscore;

            // similar logic for Eframescorrect
            if (returnEframescorrect)
            {
#ifdef DIRECT_MODE
                // DIRECT_MODE: accumulate the expected-accuracy betas directly (without the
                // alpha/beta factorization used in the #else branch below)
                double loginaccs = logaccbetas[e.E] + edgescore;
                double logpathaccs = logbetas[e.E] + edgescore + logframescorrectedgej;
                logadd (logpathaccs, loginaccs);
                atomicLogAdd (&logaccbetas[e.S], logpathaccs);

                double logecorrect = logaccalphas[e.S] + edgescore + logbetas[e.E];
                logadd (logecorrect, logalphas[e.S] + edgescore + logframescorrectedgej + logbetas[e.E]);
                logadd (logecorrect, logalphas[e.S] + edgescore + logaccbetas[e.E]);
                logEframescorrectj = logecorrect - totalfwscore;            // be careful, this includes the denominator
#else
                // backward pass: propagate the accuracy-weighted beta, with the frames-correct
                // count of this edge logadd-ed into the partial expectation
                double loginaccs = logaccbetas[E] - logbetas[E];
                logadd (loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logbetas[E] + edgescore;
                atomicLogAdd (&logaccbetas[S], logpathacc);

                // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
                double logsum = logframescorrectedgej;                      // sum over this edge, left partial (alpha), right partial (beta)

                double edgelogaccalpha = logaccalphas[S] - logalphas[S];    // incoming partial expected frames correct
                logadd (logsum, edgelogaccalpha);

                double edgelogaccbeta = logaccbetas[E] - logbetas[E];       // partial expected frames correct from the end
                logadd (logsum, edgelogaccbeta);

                logEframescorrectj = logsum;                                // that's it
#endif
            }
        }

#ifdef FORBID_INVALID_SIL_PATHS
        // outputs for the second (shadow) instance of this edge, if any
        double logpp2 = LOGZERO;
        double logEframescorrectj2 = LOGZERO;

        // silence edge or second speech edge
        if ((isaddedsil && e.E != nodes.size() -1) || (forbidinvalidsilpath && e.S != 0))
        {
            const size_t S = (size_t) (!isaddedsil ? e.S + nodes.size() : e.S);          // second speech edge comes from special 'silence state' node
            const size_t E = (size_t) (isaddedsil  ? e.E + nodes.size() : e.E);          // silence edge goes into special 'silence state' node
            // remaining lines here are code dup from above, with two changes: logadd2/logEframescorrectj2 instead of logadd/logEframescorrectj

            // backward pass
            const double inscore = logbetas[E];
            const double pathscore = inscore + edgescore;
            atomicLogAdd (&logbetas[S], pathscore);

            // compute lattice posteriors on the fly since we are at it
            logpp2 = logalphas[S] + edgescore + logbetas[E] - totalfwscore; // second edge (logpp2)

            // similar logic for Eframescorrect
            if (returnEframescorrect)
            {
                // backward pass
                double loginaccs = logaccbetas[E] - logbetas[E];
                logadd (loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logbetas[E] + edgescore;
                atomicLogAdd (&logaccbetas[S], logpathacc);

                // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
                double logsum = logframescorrectedgej;                      // sum over this edge, left partial (alpha), right partial (beta)

                double edgelogaccalpha = logaccalphas[S] - logalphas[S];    // incoming partial expected frames correct
                logadd (logsum, edgelogaccalpha);

                double edgelogaccbeta = logaccbetas[E] - logbetas[E];       // partial expected frames correct from the end
                logadd (logsum, edgelogaccbeta);

                logEframescorrectj2 = logsum;                               // that's it for this second edge
            }

            // sum logpp2 and logEframescorrectj2
            // Eframescorrect must be summed up in a weighted fashion, weighted by PP
            double numer = logEframescorrectj + logpp;
            logadd (numer, logEframescorrectj2 + logpp2);       // weighted sum, weighted by respective (log)pp
            logadd (logpp, logpp2);                             // (log)pp is just the sum of the two posteriors
            double denom = logpp;                               // and that is also the denominator for the weighted sum
            logEframescorrectj = numer - denom;                 // weighted sum
        }
#else
        nodes;      // reference the parameter to silence an unused-parameter warning in this configuration
#endif

        // write back return values
        if (logpp > 0.0)            // clip to log 1 (may be possible due to small numeric inaccuracies, although it really shouldn't happen)
            logpp = 0.0;
        logpps[j] = logpp;
        if (returnEframescorrect)
            logEframescorrect[j] = logEframescorrectj;
    }

    // sMBRerrorsignalj -- distribute the sMBR error quantity of edge j over the senone states
    // it is aligned with.
    // For each frame t covered by edge j, the aligned state s = alignstateids[...] receives
    // (in log domain, via atomicLogAdd) the edge's posterior-weighted deviation of its expected
    // frames-correct count from the lattice average. Positive deviations go into 'errorsignal',
    // negative ones into 'errorsignalneg' (both log-domain accumulators); the final subtraction
    // happens elsewhere. Edges that span no frames (ts == te, i.e. !NULL edges) contribute nothing.
    template<typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
    static inline __device__ void sMBRerrorsignalj (size_t j, const ushortvector & alignstateids, const uintvector & alignoffsets,
                                                    const edgeinfowithscoresvector & edges, 
                                                    const nodeinfovector & nodes, const doublevector & logpps, const float amf,
                                                    const doublevector & logEframescorrect, const double logEframescorrecttotal,
                                                    matrix & errorsignal, matrix & errorsignalneg)
    {
        const size_t tbegin = nodes[edges[j].S].t;      // first frame covered by this edge
        const size_t tend   = nodes[edges[j].E].t;      // one past the last frame
        if (tbegin == tend)                             // !NULL edge: aligned with no frames, nothing to distribute
            return;
#ifdef DIRECT_MODE
        // DIRECT_MODE: logEframescorrect[] already carries the full per-edge quantity; just fan it out
        const float logEframescorrectj = logEframescorrect[j];
        const size_t alignbase = alignoffsets[j];       // start of this edge's alignment within alignstateids[]
        for (size_t t = tbegin; t < tend; t++)
        {
            const size_t s = (size_t) alignstateids[t - tbegin + alignbase];
            atomicLogAdd (&errorsignal(s,t), logEframescorrectj);
        }
#else
        // signed deviation of this edge's expected frames-correct count from the lattice average
        const double signeddiff = expdiff (logEframescorrect[j], logEframescorrecttotal);
        // Note: the contribution of the states of an edge to their senones is the same for all states
        // so we compute it once and add it to all; this will not be the case without hard alignments.
        double magnitude = fabs (signeddiff);
        if (magnitude == 0.0f)                          // zero deviation: log() below would be undefined, and there is nothing to add
            return;
        const float logedgecorrect = (float) (logpps[j] + log (magnitude));     // posterior-weighted, back in log domain
        const size_t alignbase = alignoffsets[j];       // start of this edge's alignment within alignstateids[]
        for (size_t t = tbegin; t < tend; t++)
        {
            const size_t s = (size_t) alignstateids[t - tbegin + alignbase];    // senone state aligned at (j,t)
            // route by sign: positive deviations accumulate in errorsignal, negative in errorsignalneg
            if (signeddiff > 0.0)
                atomicLogAdd (&errorsignal(s,t), logedgecorrect);
            else
                atomicLogAdd (&errorsignalneg(s,t), logedgecorrect);
        }
        magnitude = amf;        // just to reference it because we hope to support linear mode as well
#endif
    }

    // stateposteriorsj -- accumulate a per-edge quantity into the states that edge j is aligned with.
    // Use this for MMI passing the edge posteriors logpps[] as logqs, or for sMBR passing logEframescorrect[].
    // The quantity logqs[j] is logadd-ed (atomically) into logacc(s,t) for every frame t the edge
    // covers, where s is the aligned state from (alignstateids, alignoffsets).
    // Edges aligned with no frames (!NULL edges) are skipped.
    template<typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector,typename doublevector, typename matrix>
    static inline __device__ void stateposteriorsj (size_t j, const ushortvector & alignstateids, const uintvector & alignoffsets,
                                                    const edgeinfowithscoresvector & edges,
                                                    const nodeinfovector & nodes, const doublevector & logqs /*quantity to accumulate*/,
                                                    matrix & logacc /*accumulator to accumulate into*/)
    {
        const size_t tbegin = nodes[edges[j].S].t;  // first frame covered by this edge
        const size_t tend   = nodes[edges[j].E].t;  // one past the last frame
        if (tbegin == tend)                         // !NULL edge: nothing to accumulate
            return;
        const float logq = (float) logqs[j];        // per-edge quantity, e.g. edge posterior -> state posterior
        const size_t alignbase = alignoffsets[j];   // start of this edge's alignment within alignstateids[]
        for (size_t k = 0; k < tend - tbegin; k++)
        {
            const size_t s = (size_t) alignstateids[alignbase + k];     // state aligned at frame tbegin + k
            atomicLogAdd (&logacc(s, tbegin + k), logq);                // atomic: edges overlap in time
        }
    }
};

};};


#pragma pop_macro ("atomicCAS")
#pragma pop_macro ("atomicAdd")
#pragma pop_macro ("__device__")

back to top