// Source: https://github.com/Microsoft/CNTK (raw file)
// Tip revision: 10a8ffcf50d7b9225f3236ffcfdc422b2014fb92 authored by microsoft-github-policy-service[bot] on 23 September 2022, 14:06:50 UTC
// Microsoft mandatory file (#3870)
// Tip revision: 10a8ffc
// File: latticefunctionskernels.h
// latticefunctionskernels.cu(.h) -- kernels for lattice ops intended for use with CUDA, to be called from actual CUDA kernels.
//
// To make this compile for the CPU (emulation for testing), add this line before #including this:
// #define __device__
//
// F. Seide, V-hansu

#define FORBID_INVALID_SIL_PATHS // [v-hansu] prune path that start from sil(sp) and go into sil, only used with addsil is adopted

#pragma once

#pragma push_macro("__device__")
#pragma push_macro("atomicAdd")
#pragma push_macro("atomicCAS")

#include "latticestorage.h"
#include <limits>

namespace msra { namespace cuda { class passtextureref; } }

// A CPU-only build always uses the kernel emulation path below.
#ifdef CPUONLY
#define __kernel_emulation__
#endif

#ifdef __kernel_emulation__
// ----- CPU emulation: provide stand-ins for the CUDA names used by the kernels in this file -----
#include "math.h" // to get exp() and log() compiled correctly with c++
#include <stdexcept>
using namespace std;
// make the __device__ qualifier vanish if nobody defined it yet
#ifndef __device__
#define __device__
#endif
#define CUDART_MIN_DENORM_F numeric_limits<float>::denorm_min()
// renamed to x- so we make sure to not accidentally use these; rename back if ever needed again
#define xatomicAdd(address, value) (*(address) += (value)) // don't forget to #undef (#pragma pop_macro)! Otherwise CUDA might compile with this...
// Non-atomic stand-in for atomicCAS(). It expands to TWO statements (read the old value, then
// conditionally store), so it only works in the form 'x = atomicCAS(a, c, v);' and its arguments
// are pasted without parentheses.
#define atomicCAS(address, compare, val) \
    *address;                            \
    *address = *address == compare ? val : *address;
// Bit-reinterpreting casts. They take the address of their argument, so the argument must be an lvalue.
#define __double_as_longlong(in) (*(unsigned long long int *) &in)
#define __longlong_as_double(in) (*(double *) &in)
#define __float_as_int(in) (*(int *) &in)
#define __int_as_float(in) (*(float *) &in)
#else // TODO: remove this once we got this figured out
// ----- real CUDA build -----
#include "math_constants.h"
// NOTE(review): an undefined __CUDA_ARCH__ evaluates to 0 in #if, so this branch is presumably
// also taken wherever __CUDA_ARCH__ is not defined -- confirm.
#if __CUDA_ARCH__ < 200
//#warning Sequence training not supported on 1.x CUDA machines.
#define force_crash() (*((int *) -1) = 0)         // TODO: this does not in fact seem to crash it...
#define xatomicAdd(a, v) (force_crash(), *(a) = v) // force a crash if used with 1.x devices
#define xatomicCAS(address, compare, val) (*(address) = compare + val, *((int *) -1) = 0)
// the bit-cast helpers are replaced by versions that write through an invalid pointer first
#define __double_as_longlong(in) (force_crash(), in)
#define __longlong_as_double(in) (force_crash(), in)
#define __float_as_int(in) (force_crash(), in)
#define __int_as_float(in) (force_crash(), in)
#endif
#endif

namespace msra { namespace lattices {

// Example type to have a pattern to copy from; it holds a single size_t member.
struct somedata
{
    size_t fortytwo; // the only member of this example type
};

// Type without any members.
struct empty
{
};

// Table of log transition scores for one HMM topology (lr3 = 3-state left-to-right architecture).
//
// Note that the code that uses this (edgealignmentj) will assume 3-state left-to-right
// except for /sil/ and /sp/ which are treated specially.
// TODO: either check when creating this whether this assumption is true, or control this through a flag in here.
struct lr3transP
{

    static const size_t MAXSTATES = 3; // dimension constant; loga has MAXSTATES + 1 rows and columns
    size_t numstates;                  // only assigned by the constructor when INITIAL_STRANGE is defined
    // Log transition scores. getlogtransp() reads the score for (from, to) at loga[from + 1][to],
    // and it is called with from == -1 and with to == 3 in this file, hence the extra row and column.
    float loga[MAXSTATES + 1][MAXSTATES + 1];

    // Leaves all members uninitialized unless INITIAL_STRANGE is defined, in which case
    // numstates becomes 3 and every entry of loga becomes LOGZERO.
    lr3transP()
    {
#ifdef INITIAL_STRANGE
        // NOTE: LOGZERO must already be defined at this point when this branch is enabled.
        numstates = 3;
        // iterate over the full (MAXSTATES + 1) x (MAXSTATES + 1) table
        for (size_t i = 0; i < MAXSTATES + 1; i++)
            for (size_t j = 0; j < MAXSTATES + 1; j++)
            {
                loga[i][j] = LOGZERO;
            }
#endif
    }
};

// Compact description of one left-to-right HMM (no /sil/): which transition table it uses,
// how many states it has, and the senone id of each state.
struct lrhmmdef
{
    static const size_t MAXSTATES = 3; // we use a fixed memory allocation since it's almost always 3 anyway

    unsigned char transPindex;           // index of monophone to find transP matrix
    unsigned char numstates;             // number of states; code supports only either 1 or 3
    unsigned short senoneids[MAXSTATES]; // [0..numstates-1] senone indices

    // Senone id of state i, widened to size_t. No range check is performed on i.
    size_t getsenoneid(size_t i) const
    {
        return (size_t) senoneids[i];
    }
    // Number of states, widened to size_t.
    size_t getnumstates() const
    {
        return (size_t) numstates;
    }
    // Select this HMM's transition table out of the array 'transPs' using transPindex.
    const struct lr3transP &gettransP(const lr3transP *transPs) const
    {
        return transPs[transPindex];
    }

    // Leaves all members uninitialized unless INITIAL_STRANGE is defined, in which case every
    // member is set to the all-ones value of its type.
    lrhmmdef()
    {
#ifdef INITIAL_STRANGE
        // C-style casts: the function-style form 'unsigned char(-1)' is not valid standard C++
        // because a functional cast requires a single-word type name.
        transPindex = (unsigned char) -1;
        numstates = (unsigned char) -1;
        for (size_t i = 0; i < MAXSTATES; i++)
        {
            senoneids[i] = (unsigned short) -1;
        }
#endif
    }
};

#if 1 // straight-forward version
#else // CUDA hacked version
// NOTE: this branch is never compiled because the condition above is '#if 1'.
// It is an untyped sketch: read entry i of 'hmms' as one ushort4 and unpack it into an lrhmmdef
// (.x low byte -> transPindex, .x high byte -> numstates, .y/.z/.w -> the three senone ids).
hmm gethmm(hmms, i)
{
    ushort4 u4 = *((ushort4) &hmms[i]);
    lrhmmdef hmm;
    hmm.transPindex = u4.x & 0xff;
    hmm.numstates = u4.x >> 8;
    hmm.senoneids[0] = u4.y;
    hmm.senoneids[1] = u4.z;
    hmm.senoneids[2] = u4.w;
}
#endif

#ifndef LOGZERO
#define LOGZERO -1e30f
#endif

// Non-owning, bounds-checked view of an 'unsigned short' buffer as a matrix in column-wise storage.
// Element (i, j) lives at offset j * colstride + i from the base pointer.
class bpmatrixref
{
private:
    unsigned short *p; // base pointer of the viewed buffer
    size_t numrows;    // rows()
    size_t numcols;    // cols()
    size_t colstride;  // distance between the starts of two consecutive columns (set to the row count by the constructor)

    // Fail hard when (i, j) is outside the matrix: throws in the emulation build,
    // writes through an invalid pointer otherwise.
    __device__ void checkbounds(size_t i, size_t j) const
    {
        const bool inside = (i < numrows) && (j < numcols);
        if (inside)
            return;
#ifdef __kernel_emulation__
        throw ::logic_error("out of boundary!!!");
#else
        *((int *) -1) = 0;
#endif
    }

    // Validate (i, j) and translate it into a linear offset (column-wise storage).
    __device__ size_t locate(size_t i, size_t j) const
    {
        checkbounds(i, j);
        const size_t columnstart = j * colstride;
        return columnstart + i;
    }

public:
    // View 'address' as a matrix of n rows and m columns.
    __device__ bpmatrixref(unsigned short *address, size_t n, size_t m)
        : p(address), numrows(n), numcols(m), colstride(n)
    {
    }

    // writable access to element (i, j)
    __device__ unsigned short &operator()(size_t i, size_t j)
    {
        const size_t offset = locate(i, j);
        return p[offset];
    }
    // read-only access to element (i, j)
    __device__ const unsigned short &operator()(size_t i, size_t j) const
    {
        const size_t offset = locate(i, j);
        return p[offset];
    }
};

// this class contains all-static methods that are inner pieces of thread kernels for use with CUDA
struct latticefunctionskernels
{
    // [v-hansu] float flavour of atomicCAS(): the compare-and-swap is carried out on the
    // bit patterns of the floats, reinterpreted as int.
    // Returns, converted back to float, whatever atomicCAS() reported for *address.
    static __device__ float atomicCASfloatdouble(float *address, float compare, float val)
    {
        int *wordaddress = (int *) address;
        int expectedbits = __float_as_int(compare); // __float_as_int : read float bits as int
        int replacementbits = __float_as_int(val);
        // keep this as a plain 'x = atomicCAS(...);' statement with simple arguments
        int previousbits = atomicCAS(wordaddress, expectedbits, replacementbits);
        return __int_as_float(previousbits);
    }

    // double flavour of atomicCAS(): the compare-and-swap is carried out on the bit patterns
    // of the doubles, reinterpreted as unsigned long long.
    // Returns, converted back to double, whatever atomicCAS() reported for *address.
    static __device__ double atomicCASfloatdouble(double *address, double compare, double val)
    {
        unsigned long long int *wordaddress = (unsigned long long int *) address;
        unsigned long long int expectedbits = __double_as_longlong(compare); // __double_as_longlong : read double as unsigned long long
        unsigned long long int replacementbits = __double_as_longlong(val);
        // keep this as a plain 'x = atomicCAS(...);' statement with simple arguments
        unsigned long long int previousbits = atomicCAS(wordaddress, expectedbits, replacementbits);
        return __longlong_as_double(previousbits);
    }

    // loga <- loga + log (1 + exp (diff)), double precision.
    // Used by logadd()/logaddseen() with diff = logb - loga after they made loga the bigger value.
    // Contributions below the resolution of a double are skipped.
    static __device__ void logaddratio(double &loga, double diff)
    {
        if (diff < -37.0f)
            return; // log (2^-53), 52-bit mantissa -> cut off after 53rd bit
        // log1p() instead of log (1.0 + x): forming 1.0 + exp (diff) first would round away
        // most bits of exp (diff) when it is close to machine epsilon
        loga += log1p(exp(diff));
    }
    // loga <- loga + log (1 + exp (diff)), single precision.
    // Used by logadd()/logaddseen() with diff = logb - loga after they made loga the bigger value.
    // Contributions below the resolution of a float are skipped.
    static __device__ void logaddratio(float &loga, float diff)
    {
        if (diff < -17.0f)
            return; // log (2^-24), 23-bit mantissa -> cut off after 24th bit
        // log1pf() instead of logf (1.0f + x): forming 1.0f + expf (diff) first would round away
        // most bits of expf (diff) when it is close to machine epsilon
        loga += log1pf(expf(diff));
    }

    // exchange the values held by 'left' and 'right' through one temporary copy
    template <typename FLOAT>
    static __device__ void swap(FLOAT &left, FLOAT &right)
    {
        const FLOAT formerright = right;
        right = left;
        left = formerright;
    }

    // overloads for exp() for float and double, so that we can use templates
    // (templated code such as expdiff() calls expfd() and gets the overload matching its FLOAT type)

    // single precision: forwards to ::expf()
    static __device__ float expfd(float x)
    {
        return ::expf(x);
    } // TODO: ain't there an overload for this?
    // double precision: forwards to ::exp()
    static __device__ double expfd(double x)
    {
        return ::exp(x);
    }

    // Difference of two numbers that are given as their logs: returns exp (loga) - exp (logb).
    // The result is a plain (non-log) value. The bigger of the two is factored out so that the
    // inner exp() always sees an argument <= 0.
    template <typename FLOAT>
    static __device__ FLOAT expdiff(FLOAT loga, FLOAT logb)
    {
        const bool aisbigger = (logb < loga);
        if (!aisbigger) // exp (loga) <= exp (logb): result is -(exp (logb) - exp (loga))
            return -expfd(logb) * (1 - expfd(loga - logb));
        // logb - loga < 0 => exp (logb - loga) < 1
        return expfd(loga) * (1 - expfd(logb - loga));
    }

    // Log-domain addition, in place:
    // loga <- log (exp (loga) + exp (logb)) = loga + log (1.0 + exp (logb - loga))
    // If even the bigger operand is <= LOGZERO, loga is left at that bigger operand.
    template <typename FLOAT>
    static __device__ void logadd(FLOAT &loga, FLOAT logb)
    {
        if (logb > loga) // order the operands: the bigger one goes into loga
        {
            const FLOAT bigger = logb;
            logb = loga;
            loga = bigger;
        }
        if (loga <= LOGZERO) // both are 0
            return;
        const FLOAT diff = logb - loga; // smaller minus bigger
        logaddratio(loga, diff);
    }

    // Log-domain addition like logadd(), except for the case that even the bigger operand is
    // <= LOGZERO: then loga is set to logf (CUDART_MIN_DENORM_F) instead of staying at LOGZERO.
    // [v-hansu] we hope to separate LOGZERO (unseen states) and logf(CUDART_MIN_DENORM_F) (seen states with small prob)
    template <typename FLOAT>
    static __device__ void logaddseen(FLOAT &loga, FLOAT logb)
    {
        if (logb > loga) // order the operands: the bigger one goes into loga
        {
            const FLOAT bigger = logb;
            logb = loga;
            loga = bigger;
        }
        if (loga <= LOGZERO) // both are 0
        {
            loga = logf(CUDART_MIN_DENORM_F);
            return;
        }
        const FLOAT diff = logb - loga; // smaller minus bigger
        logaddratio(loga, diff);
    }

#if 1
    // Overloaded bit-pattern conversions, so that templated code (atomicLogAdd) can convert
    // between a floating-point type and the integer type that carries its bits:
    // float <-> int, double <-> unsigned long long int.
    // In the emulation build the names called here are macros that take the address of their
    // argument, so each argument must remain a named variable.

    // int bit pattern -> float
    static inline __device__ float bitsasfloat(int b)
    {
        return __int_as_float(b);
    }
    // unsigned long long bit pattern -> double
    static inline __device__ double bitsasfloat(unsigned long long int b)
    {
        return __longlong_as_double(b);
    }
    // float -> int bit pattern
    static inline __device__ int floatasbits(float f)
    {
        return __float_as_int(f);
    }
    // double -> unsigned long long bit pattern
    static inline __device__ unsigned long long int floatasbits(double f)
    {
        return __double_as_longlong(f);
    }

    // Atomic log-domain accumulation: *address <- logaddseen (*address, val), implemented as a
    // compare-and-swap retry loop on the bit pattern of *address.
    // Returns the value that atomicCAS() reported in the final, successful iteration.
    // adapted from [http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ixzz32EuzZjxV]
    template <typename FLOAT>
    static __device__ FLOAT atomicLogAdd(FLOAT *address, FLOAT val) // direct adaptation from NVidia source code
    {
        // integer type carrying the bits of FLOAT, as selected by the floatasbits() overloads
        typedef decltype(floatasbits(val)) bitstype;
        bitstype *address_as_ull = (bitstype *) address;
        bitstype old = *address_as_ull, assumed;
        do
        {
            assumed = old; // the bits we believe are currently stored
            FLOAT sum = bitsasfloat(assumed);
            logaddseen(sum, val); // sum <- log-add of the assumed value and val
            // store the new bits only if nobody changed *address in the meantime
            old = atomicCAS(address_as_ull, assumed, floatasbits(sum));
        } while (assumed != old); // retry if the stored bits were not the assumed ones
        // note: critically, ^^ this comparison must compare the bits ('int') instead of the converted float values,
        // since a float comparison would fail for NaNs (NaN != NaN is true always)
        return bitsasfloat(old);
    }
#else // this code does not work because (assumed != old) will not compare correctly in case of NaNs
    // Disabled alternative (the enclosing condition is '#if 1', so this is not compiled).
    // same pattern as atomicAdd(), but performing the log-add operation instead
    // It compares the values as FLOAT rather than as bits, which is the reason given above for
    // why it does not work in case of NaNs.
    template <typename FLOAT>
    static __device__ FLOAT atomicLogAdd(FLOAT *address, FLOAT val)
    {
        FLOAT old = *address;
        FLOAT assumed;
        FLOAT logaddresult;
        do
        {
            assumed = old;          // old is the assumed value at address
            logaddresult = assumed; // for next step to compute logaddresult, assumed shall be the same as before
            logaddseen(logaddresult, val);
            old = atomicCASfloatdouble(address, assumed, logaddresult);
        } while (assumed != old); // if old == assumed, the *address is not changed in this loop, so this is safe
        return old;
    }
#endif

    // [v-hansu] Linear index of item (i, j, k) in a cube of extents (Ni, Nj, Nk), with the
    // nesting order of the three dimensions selected by shufflemode (0..4).
    // Any other shufflemode writes through an invalid pointer (intended as a crash) and yields 0.
    static inline __device__ size_t shuffle(size_t i, size_t Ni, size_t j, size_t Nj, size_t k, size_t Nk, size_t shufflemode)
    {
        switch (shufflemode)
        {
        case 0: // i fastest, then j, then k
            return i + j * Ni + k * Nj * Ni;
        case 1: // inverse: k fastest, then j, then i
            return k + j * Nk + i * Nj * Nk;
        case 2: // flip i and j
            return j + i * Nj + k * Ni * Nj;
        case 3: // flip j and k
            return i + k * Ni + j * Nk * Ni;
        case 4: // j fastest, then k, then i
            return j + k * Nj + i * Nk * Nj;
        default:
            *((int *) -1) = 0; // shall not get here, WRONG
            return 0;
        }
    }

    // Per-element piece of a "set all elements" operation: stores 'value' into element j of
    // 'thisvector'. The vector type only needs a writable operator[]; no range check is done here.
    template <typename doublevector>
    static __device__ void setvaluej(size_t j, doublevector &thisvector, double value)
    {
        thisvector[j] = value;
    }

    // zhaorui
    // Look up the log transition score for going from state 'from' to state 'to'.
    // The row index is from + 1, so from == -1 selects row 0; this file calls it with
    // from in -1..2 and to in 0..3, matching the [MAXSTATES + 1][MAXSTATES + 1] table.
    // No range check is performed (an earlier check returning LOGZERO was deliberately disabled),
    // so callers must stay within those ranges.
    // The table is taken by const reference to avoid copying the whole struct on every call.
    static inline __device__ float getlogtransp(const lr3transP &transP, int from, int to)
    {
        return transP.loga[from + 1][to];
    }
    // Viterbi alignment of lattice edge j.
    // Walks the alignment units aligns[as..ae-1] of the edge in time order, and for each unit runs a
    // Viterbi pass over its (up to) three states using the transition scores from transPs and the
    // frame scores logLLs(senoneid, t). Results:
    //  - alignresult[alignoffsets[j] ...] receives one senone id per frame
    //  - edgeacscores[j] receives the path score accumulated over all units
    //  - backptrstorage (starting at backptroffsets[j]) is used as scratch space for back pointers of silence units
    // Returns without writing anything if the edge has no alignment units, or (unless PARALLEL_SIL
    // is defined) if its first or last unit is silence.
    template <typename lrhmmdefvector, typename lr3transPvector, typename matrix, typename nodeinfovector, typename edgeinfowithscoresvector, typename aligninfovector, typename ushortvector, typename uintvector, typename floatvector, typename sizetvector>
    static inline __device__ void edgealignmentj(size_t j, const lrhmmdefvector &hmms, const lr3transPvector &transPs, const size_t spalignunitid,
                                                 const size_t silalignunitid, const matrix &logLLs, const nodeinfovector &nodes,
                                                 const edgeinfowithscoresvector &edges, const aligninfovector &aligns,
                                                 const uintvector &alignoffsets, ushortvector &backptrstorage, const sizetvector &backptroffsets,
                                                 ushortvector &alignresult, floatvector &edgeacscores)
    { // TODO: alignresult will change to (start,end)
        // mostly finished
        // some preparation
        size_t as = edges[j].firstalign; // align start
        // align end = first align of the next edge, or the end of 'aligns' for the last edge
        size_t ae = (j + 1) < edges.size() ? (size_t) edges[j + 1].firstalign : aligns.size();
        if (as == ae) // the last empty alignment
            return;
        size_t ts = nodes[edges[j].S].t; // start time of the current unit; begins at the time of the edge's start node

        float fwscore = 0.0f;                // score passed across phone boundaries
        size_t alignindex = alignoffsets[j]; // index to set (result)
#ifndef PARALLEL_SIL
        const bool isSil = (aligns[as].unit == silalignunitid || aligns[ae - 1].unit == silalignunitid);
        if (isSil)
            return; // we do not support silence edge now, which is computed by cpu, may change when we support it
#endif
        // Viterbi alignment
        for (size_t k = as; k < ae; k++)
        {
            const aligninfo align = aligns[k];
            const size_t numframes = align.frames;
            const bool isSp = (align.unit == spalignunitid);
            const bool isSil = (align.unit == silalignunitid); // note: per-unit flag, distinct from the edge-level isSil above

            const lrhmmdef hmm = hmms[align.unit];
            const lr3transP transP = transPs[hmm.transPindex];

            // pre-fetch senone ids into registers
            size_t senoneid0 = hmm.senoneids[0];
            size_t senoneid1 = 0;
            size_t senoneid2 = 0;
            if (!isSp) // fetch only if needed--may save some memory cycles
            {
                senoneid1 = hmm.senoneids[1];
                senoneid2 = hmm.senoneids[2];
            }

            const size_t te = ts + numframes; // end time of current unit

            // Inflection points (frame at which a state change happens) on the best path into
            // states 1 and 2; the value te means "no such transition recorded".
            size_t state1step0to1 = te; // inflection point from state 0 to 1, record in state 1
                                        // size_t state1stepm1to1 = te;
            size_t state2step0to1 = te; // inflection point from state 0 to 1, record in state 2
            // size_t state2stepm1to1 = te;                    // inflection point from state 0 to 1, record in state 2
            size_t state2step1to2 = te; // inflection point from state 1 to 2, record in state 2
            size_t state2step0to2 = te; // inflection point from state 0 to 2, record in state 2

            // now we only support transition from -1 to 0 or 2 for sil
            float pathscore0 = fwscore; // log pp in state 0
            float pathscore1 = fwscore; // log pp in state 1
            float pathscore2 = fwscore; // log pp in state 2

            // first frame
            if (ts != te) // for t = ts, initialization
            {
                /*    if (isSil)                                                              // for sil, -1 to 2 and -1 to 0 is permitted
                {
                    pathscore0 += getlogtransp(transP,-1,0) + logLLs(senoneid0,ts);
                    pathscore2 += getlogtransp(transP,-1,2) + logLLs(senoneid2,ts);
                }
                else                                                                    // for others, only -1 to 0 is permitted
                {
                    pathscore0 += getlogtransp(transP, -1, 0) + logLLs(senoneid0, ts);
                    pathscore1 += getlogtransp(transP, -1, 1) + logLLs(senoneid1, ts);

                }*/
                // all three states are entered from the entry state (-1) using the model's own transition scores
                pathscore2 += getlogtransp(transP, -1, 2) + logLLs(senoneid2, ts);
                pathscore1 += getlogtransp(transP, -1, 1) + logLLs(senoneid1, ts);
                // state1stepm1to1 = ts;
                pathscore0 += getlogtransp(transP, -1, 0) + logLLs(senoneid0, ts);
            }

            float pathscore2last = pathscore2; // allocate last for state 2 because the order of computation below is 2->1->0, last state 2 is needed because from 2 to 0 or 1 is permitted for sil.
            float pathscore1last = pathscore1; // allocate last for state 1 because the order of computation below is 2->1->0, last state 1 is needed because from 1 to 0 is permitted for sil.

            size_t backptroffset = backptroffsets[j]; // we make use of backptrstorage in backptroffsets[j] for viterbi of ergodic model (silence)

            // back pointer matrix: one row per state, one column per frame; only written and read for silence units
            bpmatrixref backptrmatrix(&backptrstorage[backptroffset], hmm.MAXSTATES, numframes);
            // subsequent frames
            for (size_t t = ts + 1; t < te; t++)
            {
                if (!isSp)
                {
                    // state [2]
                    pathscore2 += getlogtransp(transP, 2, 2); // log pp from state 2 to 2
                    if (isSil)
                        backptrmatrix(2, t - ts - 1) = 2;
                    const float pathscore12 = pathscore1 + getlogtransp(transP, 1, 2); // log pp from state 1 to 2
                    if (pathscore12 >= pathscore2)                                     // if state 1->2
                    {
                        pathscore2 = pathscore12;
                        state2step0to1 = state1step0to1; // record the inflection point
                                                         // state2stepm1to1 = state1stepm1to1;
                        state2step1to2 = t;              // record the inflection point
                        state2step0to2 = te;
                        if (isSil)
                            backptrmatrix(2, t - ts - 1) = 1;
                    }
                    // if (isSil)                                                                  // only silence have path from 0 to 2
                    {
                        const float pathscore02 = pathscore0 + getlogtransp(transP, 0, 2); // log pp from state 0 to 2
                        if (pathscore02 >= pathscore2)                                     // if state 0->2
                        {
                            pathscore2 = pathscore02;
                            if (isSil)
                                backptrmatrix(2, t - ts - 1) = 0;
                            state2step0to2 = t;
                            state2step1to2 = te;
                        }
                    }

                    // state [1]
                    pathscore1 += getlogtransp(transP, 1, 1); // log pp from state 1 to 1
                    if (isSil)
                        backptrmatrix(1, t - ts - 1) = 1;
                    const float pathscore01 = pathscore0 + getlogtransp(transP, 0, 1); // log pp from state 0 to 1
                    if (pathscore01 >= pathscore1)                                     // if state 0 -> 1
                    {
                        pathscore1 = pathscore01;
                        state1step0to1 = t; // record the inflection point
                                            // state1stepm1to1 = te;
                        if (isSil)
                            backptrmatrix(1, t - ts - 1) = 0;
                    }

                    if (isSil) // only silence have path from 2 to 1
                    {
                        // uses the previous frame's state-2 score, since pathscore2 was already updated above
                        const float pathscore21 = pathscore2last + getlogtransp(transP, 2, 1);
                        if (pathscore21 >= pathscore1) // if state 2 -> 1
                        {
                            pathscore1 = pathscore21;
                            backptrmatrix(1, t - ts - 1) = 2;
                        }
                    }
                }
                // state [0]
                pathscore0 += getlogtransp(transP, 0, 0);
                if (isSil) // only silence have path from 2 or 1 to 0
                {
                    backptrmatrix(0, t - ts - 1) = 0;
                    const float pathscore20 = pathscore2last + getlogtransp(transP, 2, 0); // log pp from state 2 to 0
                    if (pathscore20 >= pathscore0)
                    {
                        pathscore0 = pathscore20;
                        backptrmatrix(0, t - ts - 1) = 2;
                    }
                    const float pathscore10 = pathscore1last + getlogtransp(transP, 1, 0); // log pp from state 1 to 0
                    if (pathscore10 >= pathscore0)
                    {
                        pathscore0 = pathscore10;
                        backptrmatrix(0, t - ts - 1) = 1;
                    }
                }

                // add log LLs
                pathscore0 += logLLs(senoneid0, t);
                if (!isSp) // only fetch if needed, saves mem access
                {
                    pathscore1 += logLLs(senoneid1, t);
                    pathscore2 += logLLs(senoneid2, t);
                }
                pathscore1last = pathscore1; // update pathscore1last
                pathscore2last = pathscore2; // update pathscore2last
            }

            // final 'next' transition that exits from last frame
            // (after this block, pathscore2 holds the exit score of the unit in all cases)

            if (ts == te) // if sp tee model, will not in next loop
            {
                pathscore2 = pathscore0 + getlogtransp(transP, -1, 1);
            }
            else if (isSp)
            {
                pathscore2 = pathscore0 + getlogtransp(transP, 0, 1); // sp model, from 0 to 1
                // printf(" sp, %f\n", pathscore2);
            }

            else if (isSil) // for sil, the exit state can be 0 or 2.
            {
                const float pathscore03 = pathscore0 + getlogtransp(transP, 0, 3);
                pathscore2 += getlogtransp(transP, 2, 3);
                if (pathscore03 > pathscore2)
                {
                    pathscore2 = pathscore03;
                }
            }
            else
                pathscore2 += getlogtransp(transP, 2, 3);

            fwscore = pathscore2; // propagate across phone boundaries

            // emit alignment

            if (!isSil)
            {
                // non-silence: the path is reconstructed from the recorded inflection points
                if (state2step0to2 < te) // from 0 to 2
                {
                    state2step0to2 += alignindex - ts; // convert from time to index into alignresult
                    for (size_t t = alignindex; t < alignindex + numframes; t++) // set the final alignment
                    {
                        size_t senoneid;
                        if (t < state2step0to2) // in state 0
                            senoneid = senoneid0;
                        else // in state 2
                            senoneid = senoneid2;
                        alignresult[t] = (unsigned short) senoneid;
                    }
                }
                else // from 1 to 2
                {
                    state2step0to1 += alignindex - ts; // convert to align measure
                    state2step1to2 += alignindex - ts;
                    for (size_t t = alignindex; t < alignindex + numframes; t++) // set the final alignment
                    {
                        size_t senoneid;
                        // state 0 only if a 0->1 inflection point was recorded (i.e. it differs from the converted te) and not yet reached
                        if (state2step0to1 < alignindex - ts + te && t < state2step0to1)
                            senoneid = senoneid0;
                        else if (t < state2step1to2) // in state 1
                            senoneid = senoneid1;
                        else // in state 2
                            senoneid = senoneid2;
                        alignresult[t] = (unsigned short) senoneid;
                    }
                }
            }
            else // for silence
            {
                // silence: pick the exit state, then follow the back pointers from the last frame to the first
                size_t lastpointer = 2;
                const float pathscore03 = pathscore0 + getlogtransp(transP, 0, 3);
                if (pathscore03 >= pathscore2) // exit state is 0
                {
                    alignresult[alignindex + numframes - 1] = (unsigned short) senoneid0;
                    lastpointer = 0;
                }
                else // exit state is 2
                    alignresult[alignindex + numframes - 1] = (unsigned short) senoneid2;

                // t is unsigned, so the loop condition is written as (t + 1) > alignindex to stop after t == alignindex
                for (size_t t = alignindex + numframes - 2; (t + 1) > alignindex; t--) // set the final alignment
                {
                    lastpointer = backptrmatrix(lastpointer, t - alignindex);
                    size_t senoneid = (size_t)(-1);
                    if (lastpointer == 0)
                        senoneid = senoneid0;
                    else if (lastpointer == 1)
                        senoneid = senoneid1;
                    else if (lastpointer == 2)
                        senoneid = senoneid2;
                    alignresult[t] = (unsigned short) senoneid;
                }
            }
            // advance to the next unit
            ts = te;
            alignindex += numframes;
        }
        edgeacscores[j] = fwscore;
    }

    // Compute row s of the final error signal from gammas and state-consolidated Eframescorrect:
    //   errorsignal(s, t) = exp (loggammas(s, t)) * (exp (logEframescorrect(s, t)) - exp (logEframescorrecttotal)) * kappa
    // for every column t of errorsignal.
    // in-place operation is supported (i.e. output = one of the inputs)
    template <typename matrix>
    static inline __device__ void computesMBRerrorsignals(const size_t s, const matrix &loggammas, const matrix &logEframescorrect,
                                                          const double logEframescorrecttotal, const float kappa, matrix &errorsignal)
    {
        const size_t numframes = errorsignal.cols();
        const float totalcorrect = expf((float) logEframescorrecttotal); // same for all frames
        size_t t = 0;
        while (t < numframes)
        {
            // both inputs are read before the output element is written
            const float gamma = expf(loggammas(s, t));
            const float centered = expf(logEframescorrect(s, t)) - totalcorrect;
            errorsignal(s, t) = gamma * centered * kappa;
            ++t;
        }
    }

    // test if a state is silence [v-hansu]
    // WARNING, this function only support models with 9304 states: for that senone count the
    // states 7670, 7671 and 7672 are reported as silence; for any other count the answer is false.
    // TODO: change this later on
    static inline __device__ bool issilencestate(size_t stateid, size_t numsenones)
    {
        if (numsenones != 9304)
            return false;
        return stateid >= 7670 && stateid <= 7672;
    }

    // compare two states and check if they are of the same class [v-hansu]
    // With an empty senone2classmap the state ids themselves are compared; otherwise the
    // map entries of the two states are compared.
    template <typename ushortvector>
    static inline __device__ bool isofsameclass(size_t statea, size_t stateb, ushortvector senone2classmap)
    {
        const bool havemap = (senone2classmap.size() != 0);
        return havemap ? (senone2classmap[statea] == senone2classmap[stateb])
                       : (statea == stateb); // no map provided, we just do normal comparison
    }

    // Phase 1 of forwardbackward algorithm
    // Processes edge j: computes the edge score, and log-adds (start-node alpha + edge score)
    // into the alpha of the end node using atomicLogAdd().
    // returnEframescorrect means sMBR mode: additionally the frames-correct count of the edge is
    // stored (as log) in logframescorrectedge[j], and accumulated through logaccalphas.
    // A non-zero boostingfactor (boosted MMI) subtracts boostingfactor * frames-correct from the edge score.
    template <typename edgeinforvector, typename nodeinfovector, typename aligninfovector, typename ushortvector, typename uintvector, typename floatvector, typename doublevector>
    static inline __device__ void forwardlatticej(const size_t j, const floatvector &edgeacscores,
                                                  const size_t /*spalignunitid --unused*/, const size_t silalignunitid,
                                                  const edgeinforvector &edges, const nodeinfovector &nodes, const aligninfovector &aligns,
                                                  const ushortvector &alignments, const uintvector &alignmentoffsets,
                                                  doublevector &logalphas, float lmf, float wp, float amf, const float boostingfactor,
                                                  const ushortvector &uids, const ushortvector senone2classmap, const bool returnEframescorrect,
                                                  doublevector &logframescorrectedge, doublevector &logaccalphas)
    {
        // edge info
        const edgeinfowithscores &e = edges[j];
        double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf; // note: edgeacscores[j] == LOGZERO if edge was pruned
        // zhaorui to deal with the abnormal score for sent start.
        // (an e.l below -200 is replaced by 0 in the score formula)
        if (e.l < -200.0f)
            edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
        const bool boostmmi = (boostingfactor != 0.0f);
        // compute the frames-correct count for this edge
        double logframescorrectedgej = LOGZERO;
        const size_t numsenones = 9304; // WARNING: this is a hack, please fix this once smbr or bmmi is working! [v-hansu]
        bool skipsilence = true;        // currently we skip silence for BMMI and sMBR [v-hansu]
        if (returnEframescorrect || boostmmi)
        {
            size_t ts = nodes[e.S].t; // time of the edge's start node
            size_t te = nodes[e.E].t; // time of the edge's end node
            size_t framescorrect = 0; // count raw number of correct frames
            size_t startindex = alignmentoffsets[j]; // position of this edge's first frame in 'alignments'

            size_t as = e.firstalign; // align start
            size_t ae = (j + 1) < edges.size() ? (size_t) edges[j + 1].firstalign : aligns.size();
            // the edge counts as silence if it consists of exactly one alignment unit and that unit is silence
            const bool isSil = (ae == as + 1 && aligns[as].unit == silalignunitid); // the order of this judgement shall be changed to save memory access
            if (!(isSil && skipsilence))                                            // we don't count silence when 1. is silence; 2, skip them
            {
                for (size_t t = ts; t < te; t++)
                {
                    if (!(skipsilence && issilencestate(alignments[t - ts + startindex], numsenones)))
                        framescorrect += isofsameclass(alignments[t - ts + startindex], uids[t], senone2classmap); // we only count correct && non-silence state
                }
            }
            // a count of 0 is represented as LOGZERO in the log domain
            logframescorrectedgej = (framescorrect > 0) ? log((double) framescorrect) : LOGZERO; // remember for backward pass
            logframescorrectedge[j] = logframescorrectedgej;
        }
        if (boostmmi)
            edgescore -= boostingfactor * exp(logframescorrectedge[j]);

#ifdef FORBID_INVALID_SIL_PATHS
        // In this mode logalphas may hold more entries than there are nodes; indices at
        // 'node index + nodes.size()' are then used as special 'silence state' nodes (see below).
        const bool forbidinvalidsilpath = (logalphas.size() > nodes.size()); // we constrain sil to sil path if node space has been blown up
        const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1);     // HACK: 'unused' indicates artificially added sil/sp edge

        // original mode
        if (!isaddedsil)
#endif
        {
            const size_t S = e.S;
            const size_t E = e.E;

            const double inscore = logalphas[S];
            const double pathscore = inscore + edgescore;

            // several edges may end in the same node, hence the atomic accumulation
            atomicLogAdd(&logalphas[E], pathscore);

            if (returnEframescorrect)
            {
#ifdef DIRECT_MODE
                double loginaccs = logaccalphas[e.S] + edgescore;
                double logpathaccs = logalphas[e.S] + edgescore + logframescorrectedgej;
                logadd(logpathaccs, loginaccs);
                atomicLogAdd(&logaccalphas[e.E], logpathaccs);
#else
                // loginaccs = log (accalpha[S] / alpha[S]), to which the edge's frames-correct count is added
                double loginaccs = logaccalphas[S] - logalphas[S];
                logadd(loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logalphas[S] + edgescore;
                atomicLogAdd(&logaccalphas[E], logpathacc);
#endif
            }
        }

#ifdef FORBID_INVALID_SIL_PATHS
        // silence edge or second speech edge
        if ((isaddedsil && e.E != nodes.size() - 1) || (forbidinvalidsilpath && e.S != 0))
        {
            const size_t S = (size_t)(!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node
            const size_t E = (size_t)(isaddedsil ? e.E + nodes.size() : e.E);  // silence edge goes into special 'silence state' node
            // remaining lines here are 100% code dup from above, just operating on different (S, E)
            const double inscore = logalphas[S];
            const double pathscore = inscore + edgescore;
            atomicLogAdd(&logalphas[E], pathscore);

            if (returnEframescorrect)
            {
                double loginaccs = logaccalphas[S] - logalphas[S];
                logadd(loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logalphas[S] + edgescore;
                atomicLogAdd(&logaccalphas[E], logpathacc);
            }
        }
#endif
    }

    // Backward-pass step for one lattice edge j.
    // Combines the edge's score with the beta value of its end node and accumulates the result into the
    // beta value of its start node; on the fly it also derives the edge posterior (written to logpps[j])
    // and, if requested, the expected frames-correct value (written to logEframescorrect[j]).
    // Parameters:
    //  - j:                    index of the edge to process (indexes edges, edgeacscores, logpps, ...)
    //  - edgeacscores:         per-edge scores that enter the edge score formula together with e.l, lmf, wp, amf
    //  - (2 unnamed size_t), (unnamed aligninfovector): unused
    //  - edges, nodes:         lattice topology; nodes.size() also serves as the offset into the extra
    //                          'silence state' entries of the alpha/beta vectors (FORBID_INVALID_SIL_PATHS)
    //  - totalfwscore:         subtracted when forming the log posterior of the edge
    //  - logpps:               output; logpps[j] receives the log posterior, clipped to <= 0
    //  - logalphas:            only read by this function
    //  - logbetas:             accumulated into via atomicLogAdd()
    //  - lmf, wp, amf:         weights of the edge score formula: (e.l * lmf + wp + edgeacscores[j]) / amf
    //  - boostingfactor:       if non-zero, boostingfactor * exp(logframescorrectedge[j]) is subtracted from the edge score
    //  - returnEframescorrect: if true, logaccbetas is accumulated and logEframescorrect[j] is written
    //  - logframescorrectedge: per-edge log frames-correct count; only read here
    //  - logaccalphas:         only read here
    //  - logEframescorrect:    output, see returnEframescorrect
    //  - logaccbetas:          accumulated into via atomicLogAdd() when returnEframescorrect is set
    // NOTE(review): atomicLogAdd() and logadd() are defined outside this block; from their use here they
    // appear to accumulate, in the log domain, into their first argument -- confirm against their definitions.
    template <typename edgeinforvector, typename nodeinfovector, typename aligninfovector, typename floatvector, typename doublevector>
    static inline __device__ void backwardlatticej(size_t j, const floatvector &edgeacscores,
                                                   const size_t /*spalignunitid --unused*/, const size_t /*silalignunitid --unused*/,
                                                   const edgeinforvector &edges, const nodeinfovector &nodes,
                                                   const aligninfovector & /*aligns -- unused*/, const double totalfwscore, doublevector &logpps,
                                                   doublevector &logalphas, doublevector &logbetas, float lmf, float wp,
                                                   float amf, const float boostingfactor, const bool returnEframescorrect,
                                                   doublevector &logframescorrectedge, doublevector &logaccalphas,
                                                   doublevector &logEframescorrect, doublevector &logaccbetas)
    {
        // output values; they stay at LOGZERO unless one of the branches below overwrites them
        double logpp = LOGZERO;
        double logEframescorrectj = LOGZERO;
        const bool boostmmi = (boostingfactor != 0.0f); // a non-zero boosting factor switches on the boosting term

        // edge info
        const edgeinfowithscores &e = edges[j];
        double edgescore = (e.l * lmf + wp + edgeacscores[j]) / amf;
        // [zhaorui] work around abnormal scores at sentence start: if e.l is below -200, recompute the
        // edge score with the e.l term replaced by 0
        if (e.l < -200.0f)
            edgescore = (0.0 * lmf + wp + edgeacscores[j]) / amf;
        // boosting: subtract boostingfactor times the (linear-domain) frames-correct count of this edge
        if (boostmmi)
            edgescore -= boostingfactor * exp(logframescorrectedge[j]);

        // get the frames-correct count for this edge that was computed during the forward pass
        // (LOGZERO if neither the expected frames-correct output nor boosting is requested)
        double logframescorrectedgej = (returnEframescorrect || boostmmi) ? logframescorrectedge[j] : LOGZERO;

#ifdef FORBID_INVALID_SIL_PATHS
        // The sil-path constraint is considered active when logalphas has more entries than there are nodes;
        // the additional entries are addressed below as (node index + nodes.size()).
        const bool forbidinvalidsilpath = (logalphas.size() > nodes.size()); // true if the alpha/beta vectors carry the extra 'silence state' entries
        const bool isaddedsil = forbidinvalidsilpath && (e.unused == 1);     // HACK: 'unused' indicates artificially added sil/sp edge

        if (!isaddedsil) // original mode: the edge connects the regular entries of S and E
#endif
        {
            const size_t S = e.S;
            const size_t E = e.E;

            // backward pass: propagate beta from the end node E to the start node S
            const double inscore = logbetas[E];
            const double pathscore = inscore + edgescore;
            atomicLogAdd(&logbetas[S], pathscore);

            // compute lattice posteriors on the fly since we are at it
            logpp = logalphas[S] + edgescore + logbetas[E] - totalfwscore;

            // similar logic for Eframescorrect
            if (returnEframescorrect)
            {
#ifdef DIRECT_MODE
                // accumulate into logaccbetas[S]: (logaccbetas[E] + edgescore) combined with
                // (logbetas[E] + edgescore + frames-correct count of this edge)
                double loginaccs = logaccbetas[e.E] + edgescore;
                double logpathaccs = logbetas[e.E] + edgescore + logframescorrectedgej;
                logadd(logpathaccs, loginaccs);
                atomicLogAdd(&logaccbetas[e.S], logpathaccs);

                // combine three terms: the alpha-side accumulator, this edge's own count, and the beta-side accumulator
                double logecorrect = logaccalphas[e.S] + edgescore + logbetas[e.E];
                logadd(logecorrect, logalphas[e.S] + edgescore + logframescorrectedgej + logbetas[e.E]);
                logadd(logecorrect, logalphas[e.S] + edgescore + logaccbetas[e.E]);
                logEframescorrectj = logecorrect - totalfwscore; // be careful, this includes the denominator
#else
                // backward pass: (logaccbetas[E] - logbetas[E]) is combined with this edge's count,
                // then re-weighted by logbetas[E] + edgescore and accumulated into logaccbetas[S]
                double loginaccs = logaccbetas[E] - logbetas[E];
                logadd(loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logbetas[E] + edgescore;
                atomicLogAdd(&logaccbetas[S], logpathacc);

                // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
                double logsum = logframescorrectedgej; // sum over this edge, left partial (alpha), right partial (beta)

                double edgelogaccalpha = logaccalphas[S] - logalphas[S]; // incoming partial expected frames correct
                logadd(logsum, edgelogaccalpha);

                double edgelogaccbeta = logaccbetas[E] - logbetas[E]; // partial expected frames correct from the end
                logadd(logsum, edgelogaccbeta);

                logEframescorrectj = logsum; // that's it
#endif
            }
        }

#ifdef FORBID_INVALID_SIL_PATHS
        // results of the second traversal of this edge via the 'silence state' entries (if any)
        double logpp2 = LOGZERO;
        double logEframescorrectj2 = LOGZERO;

        // silence edge or second speech edge
        // Taken for an added sil/sp edge that does not end in the last node, or, whenever the constraint is
        // active, for any edge that does not start at node 0.
        // NOTE(review): since isaddedsil implies forbidinvalidsilpath, an added sil/sp edge with e.S != 0 enters
        // this branch through the second disjunct even if e.E is the last node -- confirm that this is intended.
        if ((isaddedsil && e.E != nodes.size() - 1) || (forbidinvalidsilpath && e.S != 0))
        {
            const size_t S = (size_t)(!isaddedsil ? e.S + nodes.size() : e.S); // second speech edge comes from special 'silence state' node
            const size_t E = (size_t)(isaddedsil ? e.E + nodes.size() : e.E);  // silence edge goes into special 'silence state' node
            // remaining lines here are code dup from above (non-DIRECT_MODE variant), with two changes:
            // results go to logpp2/logEframescorrectj2 instead of logpp/logEframescorrectj

            // backward pass
            const double inscore = logbetas[E];
            const double pathscore = inscore + edgescore;
            atomicLogAdd(&logbetas[S], pathscore);

            // compute lattice posteriors on the fly since we are at it
            logpp2 = logalphas[S] + edgescore + logbetas[E] - totalfwscore; // second edge (logpp2)

            // similar logic for Eframescorrect
            if (returnEframescorrect)
            {
                // backward pass
                double loginaccs = logaccbetas[E] - logbetas[E];
                logadd(loginaccs, logframescorrectedgej);
                double logpathacc = loginaccs + logbetas[E] + edgescore;
                atomicLogAdd(&logaccbetas[S], logpathacc);

                // sum up to get final expected frames-correct count per state == per edge (since we assume hard state alignment)
                double logsum = logframescorrectedgej; // sum over this edge, left partial (alpha), right partial (beta)

                double edgelogaccalpha = logaccalphas[S] - logalphas[S]; // incoming partial expected frames correct
                logadd(logsum, edgelogaccalpha);

                double edgelogaccbeta = logaccbetas[E] - logbetas[E]; // partial expected frames correct from the end
                logadd(logsum, edgelogaccbeta);

                logEframescorrectj2 = logsum; // that's it for this second edge
            }

            // merge the results of the first traversal (logpp, logEframescorrectj; both still LOGZERO if the
            // first branch was skipped) with those of the second one (logpp2, logEframescorrectj2)
            // Eframescorrect must be summed up in a weighted fashion, weighted by PP
            double numer = logEframescorrectj + logpp;
            logadd(numer, logEframescorrectj2 + logpp2); // weighted sum, weighted by respective (log)pp
            logadd(logpp, logpp2);                       // (log)pp is just the sum of the two posteriors
            double denom = logpp;                        // and that is also the denominator for the weighted sum
            logEframescorrectj = numer - denom;          // weighted sum
        }
#else
        nodes; // no-op statement that merely references the otherwise unused parameter
#endif

        // write back return values
        if (logpp > 0.0) // clip to log 1 (may be possible due to small numeric inaccuracies, although it really shouldn't happen)
            logpp = 0.0;
        logpps[j] = logpp;
        if (returnEframescorrect) // logEframescorrect[j] is left untouched otherwise
            logEframescorrect[j] = logEframescorrectj;
    }

    // Spread the sMBR error contribution of edge j over the (state, frame) cells that the edge is aligned with.
    // j = edge index; the alignment of edge j starts at alignstateids[alignoffsets[j]] and holds one state per frame.
    // The frame range of the edge is [nodes[S].t, nodes[E].t); an edge with identical start and end time is ignored.
    // Regular mode: the sign of expdiff(logEframescorrect[j], logEframescorrecttotal) selects the accumulator
    // (positive -> errorsignal, otherwise -> errorsignalneg); the value accumulated into every cell is
    // logpps[j] + log(|difference|), narrowed to float. A difference of exactly zero contributes nothing.
    // DIRECT_MODE: logEframescorrect[j] is accumulated into errorsignal as is.
    // 'amf' is not used yet.
    template <typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
    static inline __device__ void sMBRerrorsignalj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets,
                                                   const edgeinfowithscoresvector &edges,
                                                   const nodeinfovector &nodes, const doublevector &logpps, const float amf,
                                                   const doublevector &logEframescorrect, const double logEframescorrecttotal,
                                                   matrix &errorsignal, matrix &errorsignalneg)
    {
        const size_t tbegin = nodes[edges[j].S].t;
        const size_t tend = nodes[edges[j].E].t;
        if (tbegin == tend) // edge spans no frames
            return;

#ifdef DIRECT_MODE
        const float logcontribution = logEframescorrect[j];
        size_t alignpos = alignoffsets[j]; // walks through the edge's alignment in step with t
        for (size_t t = tbegin; t < tend; ++t, ++alignpos)
        {
            const size_t state = (size_t) alignstateids[alignpos];
            atomicLogAdd(&errorsignal(state, t), logcontribution);
        }
#else
        const double delta = expdiff(logEframescorrect[j], logEframescorrecttotal);
        const double magnitude = fabs(delta);
        if (magnitude == 0.0f) // nothing to add for this edge
            return;

        // With hard alignments every state of the edge receives the same contribution,
        // hence value and destination are determined once, outside of the frame loop.
        const float logcontribution = (float) (logpps[j] + log(magnitude));
        size_t alignpos = alignoffsets[j]; // walks through the edge's alignment in step with t
        matrix &destination = (delta > 0.0) ? errorsignal : errorsignalneg;
        for (size_t t = tbegin; t < tend; ++t, ++alignpos)
        {
            const size_t state = (size_t) alignstateids[alignpos];
            atomicLogAdd(&destination(state, t), logcontribution);
        }
        (void) amf; // just to reference it because we hope to support linear mode as well
#endif
    }

    // Add the per-edge quantity logqs[j] to the accumulator cell of every (state, frame) pair that edge j is aligned with.
    // Typical use: MMI with the edge posteriors logpps[] as logqs, or sMBR with logEframescorrect[].
    // j = edge index; the alignment of edge j starts at alignstateids[alignoffsets[j]] and holds one state per frame.
    // The frame range of the edge is [nodes[S].t, nodes[E].t); an edge with identical start and end time is ignored.
    // The quantity is narrowed to float before it is accumulated.
    template <typename ushortvector, typename uintvector, typename edgeinfowithscoresvector, typename nodeinfovector, typename doublevector, typename matrix>
    static inline __device__ void stateposteriorsj(size_t j, const ushortvector &alignstateids, const uintvector &alignoffsets,
                                                   const edgeinfowithscoresvector &edges,
                                                   const nodeinfovector &nodes, const doublevector &logqs /*quantity to accumulate*/,
                                                   matrix &logacc /*accumulator to accumulate into*/)
    {
        const size_t tbegin = nodes[edges[j].S].t;
        const size_t tend = nodes[edges[j].E].t;
        if (tbegin == tend) // edge spans no frames
            return;

        const float edgevalue = (float) logqs[j]; // same value for all frames of this edge
        size_t alignpos = alignoffsets[j];        // walks through the edge's alignment in step with t
        for (size_t t = tbegin; t < tend; ++t, ++alignpos)
        {
            const size_t state = (size_t) alignstateids[alignpos]; // state aligned with frame t of edge j
            atomicLogAdd(&logacc(state, t), edgevalue);
        }
    }
};

}};

#pragma pop_macro("atomicCAS")
#pragma pop_macro("atomicAdd")
#pragma pop_macro("__device__")
back to top