https://github.com/Microsoft/CNTK
Raw File
Tip revision: 52331f59b6891484d452800d723f800ca0c4b9f4 authored by Mark Hillebrand on 18 January 2016, 08:33:07 UTC
License change
Tip revision: 52331f5
cudalatticeops.cu.h
// cudamatrix.cu(.h) -- CUDA kernels for lattice ops. Consider this a .cu/.cpp file.
//
// F. Seide, V-hansu

#if 0   // [v-hansu] set aside codes with log
#endif

#undef DIRECT_MODE          // [v-hansu] use the direct formula for smbr mode, proven makes no difference

#include <cuda_runtime_api.h>
#include <cuda.h>
#include "cudalib.h"
#include "cudabasetypes.h"
#include "latticestorage.h"
#include "latticefunctionskernels.h"
#include "cudalatticeops.h"
#include "math.h"
#include <assert.h>
#include <stdexcept>

#ifdef _WIN32
#include <windows.h>    // for timer
#endif

#if __unix__
#include <sys/time.h>
#endif

namespace msra { namespace cuda {

    cudaStream_t GetCurrentStream();

    // auto_timer timer; run(); double seconds = timer; // now can abandon the object    
    #ifdef __unix__
    typedef timeval LARGE_INTEGER;
    #endif
    // auto_timer -- stopwatch; construction records the start time, and each
    // conversion to double returns seconds elapsed since construction.
    // Uses QueryPerformanceCounter on Windows and gettimeofday on unix.
    class auto_timer
    {
        LARGE_INTEGER freq, start;      // freq is only used on Windows (ticks/second)
        auto_timer (const auto_timer &); void operator= (const auto_timer &);   // non-copyable
    public:
        auto_timer()
        {
    #ifdef _WIN32
            if (!QueryPerformanceFrequency (&freq)) // count ticks per second
                RuntimeError("auto_timer: QueryPerformanceFrequency failure");
            QueryPerformanceCounter (&start);
    #endif
    #ifdef __unix__
            gettimeofday (&start, NULL);
    #endif
        }
        operator double() const     // each read gives time elapsed since start, in seconds
        {
            LARGE_INTEGER end;
    #ifdef _WIN32
            QueryPerformanceCounter (&end);
            return (end.QuadPart - start.QuadPart) / (double) freq.QuadPart;
    #endif
    #ifdef __unix__
            gettimeofday (&end,NULL);
            // note: the microsecond difference must be scaled in floating point;
            // integer division by (1000*1000) truncated the sub-second part to 0.
            return (end.tv_sec - start.tv_sec) + (end.tv_usec - start.tv_usec) / (1000.0 * 1000.0);
    #endif
        }
        // print elapsed time (in ms) with a message prefix to stderr
        void show (const std::string & msg) const
        {
            double elapsed = *this;
            fprintf (stderr, "%s: %.6f ms\n", msg.c_str(), elapsed * 1000.0/*to ms*/);
        }
    };
    
    // -----------------------------------------------------------------------
    // edgealignment --do alignment on a per edge level, only support normal left to right hmms and ergodic silence hmm
    // output alignresult
    // -----------------------------------------------------------------------

    __global__ void edgealignmentj (const vectorref<lrhmmdef> hmms, const vectorref<lr3transP> transPs, const size_t spalignunitid, const size_t silalignunitid,
                                    const matrixref<float> logLLs, const vectorref<msra::lattices::nodeinfo> nodes, const vectorref<msra::lattices::edgeinfowithscores> edges,
                                    const vectorref<msra::lattices::aligninfo> aligns, const vectorref<unsigned int> alignoffsets, 
                                    vectorref<unsigned short> backptrstorage, const vectorref<size_t> backptroffsets, 
                                    vectorref<unsigned short> alignresult, vectorref<float> edgeacscores)       // output
    {
        // one thread per edge: flatten the 2-D thread block and 1-D grid into a linear edge index
        const size_t threadsperblock = blockDim.x * blockDim.y;
        const size_t idxinblock = threadIdx.y * blockDim.x + threadIdx.x;
        const size_t edgeindex = blockIdx.x * threadsperblock + idxinblock;
        if (edgeindex >= edges.size())      // tail-block guard; note: will cause issues if we ever use __syncthreads()
            return;
        msra::lattices::latticefunctionskernels::edgealignmentj (edgeindex, hmms, transPs, spalignunitid, silalignunitid, logLLs, nodes, edges, aligns, alignoffsets, backptrstorage, backptroffsets, alignresult, edgeacscores);
    }

    void latticefunctionsops::edgealignment (const vectorref<lrhmmdef> & hmms, const vectorref<lr3transP> & transPs, const size_t spalignunitid,
                                             const size_t silalignunitid, const matrixref<float> & logLLs, 
                                             const vectorref<msra::lattices::nodeinfo> & nodes, const vectorref<msra::lattices::edgeinfowithscores> & edges, 
                                             const vectorref<msra::lattices::aligninfo> & aligns, const vectorref<unsigned int> & alignoffsets, 
                                             vectorref<unsigned short> & backptrstorage, const vectorref<size_t> & backptroffsets,
                                             vectorref<unsigned short> & alignresult, vectorref<float> & edgeacscores) const        // output
    {
        // Launch layout: 2-D blocks of 32x8 = 256 threads, one thread per edge, and a 1-D grid of ceil(#edges/256) blocks.
        // This limits us to 16 million edges. If you need more, please adjust to either use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
        dim3 blockdim (32,8);
        const size_t threadsperblock = blockdim.x * blockdim.y;
        dim3 griddim ((unsigned int) ((edges.size() + threadsperblock - 1) / threadsperblock));
        //cudaarrayref<float> logLLsarray;        // TODO: pass this in, of course
        //passtextureref texref (logLLstex, logLLsarray);    // use the same name as that global texref one, so it will match the name inside the kernel
        edgealignmentj << <griddim, blockdim, 0, /*GetCurrentStream()*/ cudaStreamDefault >> > (hmms, transPs, spalignunitid, silalignunitid, logLLs, nodes, edges, aligns, alignoffsets, backptrstorage, backptroffsets, alignresult, edgeacscores);
        checklaunch ("edgealignment");
    }

    // setvalue --helper to initialize an array to a constant value, e.g. LOGZERO
    __global__ void setvaluej (vectorref<double> arraytoset, double value, size_t nelem)
    {
        // one thread per element; flatten (block, 2-D thread) into a linear index
        const size_t blocksize = blockDim.x * blockDim.y;
        const size_t j = blockIdx.x * blocksize + threadIdx.y * blockDim.x + threadIdx.x;
        if (j >= nelem)     // guard the partial tail block
            return;
        msra::lattices::latticefunctionskernels::setvaluej (j, arraytoset, value);
    }

    // expfi -- exponentiate every element of 'mata' in place, one thread per row
    __global__ void expfi (matrixref<float> mata)
    {
        const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= mata.rows())
            return;
        const size_t numcols = mata.cols();
        for (size_t col = 0; col < numcols; col++)
            mata(row,col) = expf (mata(row,col));
    }

    // dotprodi -- elementwise product: mata(i,j) *= matb(i,j), one thread per row
    __global__ void dotprodi (matrixref<float> mata, matrixref<float> matb)
    {
        const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= mata.rows())
            return;
        const size_t numcols = mata.cols();
        for (size_t col = 0; col < numcols; col++)
            mata(row,col) = mata(row,col) * matb(row,col);
    }

    // setunseeni -- wherever both matrices still hold the sentinel value
    // logf(CUDART_MIN_DENORM_F), reset errorsignalauxbuf(i,j) to LOGZERO.
    // (presumably this marks states that no edge touched -- confirm with callers)
    // One thread per row.
    __global__ void setunseeni (matrixref<float> errorsignal, matrixref<float> errorsignalauxbuf)
    {
        const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= errorsignal.rows())
            return;
        const float sentinel = logf (CUDART_MIN_DENORM_F);
        const size_t numcols = errorsignal.cols();
        for (size_t col = 0; col < numcols; col++)
            if (errorsignal(row,col) == sentinel && errorsignalauxbuf(row,col) == sentinel)
                errorsignalauxbuf(row,col) = LOGZERO;
    }

    // errorsignal(i,j) = (exp(errorsignal(i,j)) - exp(errorsignalauxbuf(i,j))) / amf
    __global__ void errorcomputationi (matrixref<float> errorsignal, matrixref<float> errorsignalauxbuf, float amf)
    {
        // one thread per row; result overwrites errorsignal in place
        const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= errorsignal.rows())
            return;
        const size_t numcols = errorsignal.cols();
        for (size_t col = 0; col < numcols; col++)
            errorsignal(row,col) = msra::lattices::latticefunctionskernels::expdiff (errorsignal(row,col), errorsignalauxbuf(row,col)) / amf;
    }

    // errorsignal(i,j) = (exp(errorsignal(i,j)) - exp(logEframescorrecttotal + errorsignalauxbuf(i,j))) / amf
    __global__ void directerrorcomputationi (matrixref<float> errorsignal, matrixref<float> errorsignalauxbuf, float logEframescorrecttotal,float amf)
    {
        // one thread per row; result overwrites errorsignal in place
        const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= errorsignal.rows())
            return;
        const size_t numcols = errorsignal.cols();
        for (size_t col = 0; col < numcols; col++)
            errorsignal(row,col) = msra::lattices::latticefunctionskernels::expdiff (errorsignal(row,col), logEframescorrecttotal + errorsignalauxbuf(row,col)) / amf;
    }

    // compute the final error signal from gammas and state-consolidated Eframescorrect
    // in-place operation is supported (i.e. output = one of the inputs) 
    __global__ void computesMBRerrorsignals (const matrixref<float> loggammas, const matrixref<float> logEframescorrect, 
                                             const double logEframecorrecttotal, const float kappa, matrixref<float> errorsignal)
    {
        // one thread per row of loggammas
        const size_t state = blockIdx.x * blockDim.x + threadIdx.x;
        if (state >= loggammas.rows())
            return;
        msra::lattices::latticefunctionskernels::computesMBRerrorsignals (state, loggammas, logEframescorrect, logEframecorrecttotal, kappa, errorsignal);
    }

    __global__ void forwardlatticej (const size_t batchsize, const size_t startindex, const vectorref<float> edgeacscores, 
                                     const size_t spalignunitid, const size_t silalignunitid,
                                     vectorref<msra::lattices::edgeinfowithscores> edges, vectorref<msra::lattices::nodeinfo> nodes, 
                                     const vectorref<msra::lattices::aligninfo> aligns, vectorref<unsigned short> alignments, 
                                     vectorref<unsigned int> alignmentoffsets, vectorref<double> logalphas, float lmf, float wp, float amf, 
                                     const float boostingfactor, const vectorref<unsigned short> uids, const vectorref<unsigned short> senone2classmap, 
                                     const bool returnEframescorrect, vectorref<double> logframescorrectedge, vectorref<double> logaccalphas)
    {
        // shuffled per-edge index within this batch; [v-hansu] shufflemode = 1 measured ~100% faster than shufflemode = 0 (no shuffle)
        const size_t shufflemode = 1;
        const size_t batchj = msra::lattices::latticefunctionskernels::shuffle (threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
        if (batchj >= batchsize)        // note: will cause issues if we ever use __syncthreads()
            return;
        msra::lattices::latticefunctionskernels::forwardlatticej (startindex + batchj, edgeacscores, spalignunitid, silalignunitid, edges, nodes, aligns, alignments, alignmentoffsets, 
                                                                  logalphas, lmf, wp, amf, boostingfactor, uids, senone2classmap, returnEframescorrect, logframescorrectedge, logaccalphas);
    }

    __global__ void backwardlatticej (const size_t batchsize, const size_t startindex, const vectorref<float> edgeacscores, 
                                      const size_t spalignunitid, const size_t silalignunitid,                              
                                      vectorref<msra::lattices::edgeinfowithscores> edges, vectorref<msra::lattices::nodeinfo> nodes,
                                      vectorref<msra::lattices::aligninfo> aligns, const double totalfwscore,
                                      vectorref<double> logpps, vectorref<double> logalphas, vectorref<double> logbetas, 
                                      float lmf, float wp, float amf, const float boostingfactor, const bool returnEframescorrect, 
                                      vectorref<double> logframescorrectedge, vectorref<double> logaccalphas, 
                                      vectorref<double> logEframescorrect, vectorref<double> logaccbetas)
    {
        // one thread per edge of this batch; flatten the 2-D thread block plus 1-D grid into a linear index
        const size_t threadsperblock = blockDim.x * blockDim.y;
        const size_t idxinblock = threadIdx.y * blockDim.x + threadIdx.x;
        const size_t batchj = blockIdx.x * threadsperblock + idxinblock;
        if (batchj >= batchsize)        // note: will cause issues if we ever use __syncthreads()
            return;
        msra::lattices::latticefunctionskernels::backwardlatticej (startindex + batchj, edgeacscores, spalignunitid, silalignunitid, 
                                                                   edges, nodes, aligns, totalfwscore, logpps, logalphas, logbetas, 
                                                                   lmf, wp, amf, boostingfactor, returnEframescorrect,
                                                                   logframescorrectedge, logaccalphas, 
                                                                   logEframescorrect, logaccbetas);
    }

    // forwardbackwardlattice -- host-side driver for the lattice forward/backward passes.
    // Clears the (acc)alpha/beta arrays to LOGZERO, seeds the start/end tokens with log 1 = 0,
    // then launches forwardlatticej and backwardlatticej once per pre-computed batch
    // (batchsizeforward/batchsizebackward hold the per-launch edge counts).
    // Outputs: totalfwscore, and logEframescorrecttotal when returnEframescorrect is set;
    // warns to stderr if the forward and backward totals disagree beyond tolerance.
    void latticefunctionsops::forwardbackwardlattice (const size_t * batchsizeforward, const size_t * batchsizebackward, 
                                                      const size_t numlaunchforward, const size_t numlaunchbackward,
                                                      const size_t spalignunitid, const size_t silalignunitid,
                                                      const vectorref<float> & edgeacscores,
                                                      const vectorref<msra::lattices::edgeinfowithscores> & edges, 
                                                      const vectorref<msra::lattices::nodeinfo> & nodes, 
                                                      const vectorref<msra::lattices::aligninfo> & aligns, 
                                                      const vectorref<unsigned short> & alignments,
                                                      const vectorref<unsigned int> & aligmentoffsets,
                                                      vectorref<double> & logpps, vectorref<double> & logalphas, vectorref<double> & logbetas,
                                                      const float lmf, const float wp, const float amf, const float boostingfactor, 
                                                      const bool returnEframescorrect, const vectorref<unsigned short> & uids, 
                                                      const vectorref<unsigned short> & senone2classmap, vectorref<double> & logaccalphas, 
                                                      vectorref<double> & logaccbetas, vectorref<double> & logframescorrectedge,
                                                      vectorref<double> & logEframescorrect, vectorref<double> & /*Eframescorrectbuf*/, 
                                                      double & logEframescorrecttotal, double & totalfwscore) const
    {
        // initialize log{,acc}(alhas/betas)
        dim3 t (32, 8);
        const size_t tpb = t.x * t.y;
        dim3 b ((unsigned int) ((logalphas.size() + tpb - 1) / tpb));

        // TODO: is this really efficient? One thread per value?
        // NOTE(review): logbetas/logaccalphas/logaccbetas are cleared with logalphas.size() elements --
        // presumably all four arrays share that size; confirm at the call site.
        setvaluej<<<b, t, 0, GetCurrentStream()>>> (logalphas, LOGZERO, logalphas.size());
        checklaunch ("setvaluej");
        setvaluej<<<b, t, 0, GetCurrentStream()>>> (logbetas, LOGZERO, logalphas.size());
        checklaunch ("setvaluej");
        if (returnEframescorrect)
        {
            setvaluej<<<b, t, 0, GetCurrentStream()>>> (logaccalphas, LOGZERO, logalphas.size());
            checklaunch ("setvaluej");
            setvaluej<<<b, t, 0, GetCurrentStream()>>> (logaccbetas, LOGZERO, logalphas.size());
            checklaunch ("setvaluej");
        }
        // set initial tokens to probability 1 (0 in log): alpha of the first node, beta of the last node
        double log1 = 0.0;
        memcpy (logalphas.get(), 0, &log1, 1);
        memcpy (logbetas.get(), nodes.size()-1, &log1, 1);

        // forward pass: one launch per batch of edges
        size_t startindex = 0;
        for (size_t i = 0; i < numlaunchforward; i++)
        {
            dim3 b ((unsigned int) ((batchsizeforward[i] + tpb - 1) / tpb));
            forwardlatticej <<<b, t, 0, GetCurrentStream()>>> (batchsizeforward[i], startindex, edgeacscores, 
                                                               spalignunitid, silalignunitid, edges, nodes, aligns,
                                                               alignments, aligmentoffsets, logalphas, lmf, wp, amf, 
                                                               boostingfactor, uids, senone2classmap, returnEframescorrect, 
                                                               logframescorrectedge, logaccalphas);
            checklaunch ("forwardlatticej");    // (label fixed; was mislabeled "edgealignment")
            startindex += batchsizeforward[i];
        }
        // total forward score = alpha of the last node
        memcpy<double>(&totalfwscore, logalphas.get(), nodes.size() - 1, 1);
        double totalfwacc = 0;
        if(returnEframescorrect)
        {
            memcpy<double>(&totalfwacc, logaccalphas.get(), nodes.size() - 1, 1);
            totalfwacc -= totalfwscore;
        }

        // backward pass: walk the batches from the end of the edge array backwards
        startindex = edges.size();
        for (size_t i = 0; i < numlaunchbackward; i++)
        {
            dim3 b ((unsigned int) ((batchsizebackward[i] + tpb - 1) / tpb));
            backwardlatticej <<<b, t, 0, GetCurrentStream()>>> (batchsizebackward[i], startindex - batchsizebackward[i], 
                                                                edgeacscores, spalignunitid, silalignunitid, edges, nodes, aligns,
                                                                totalfwscore, logpps, logalphas, logbetas, 
                                                                lmf, wp, amf, boostingfactor, returnEframescorrect, logframescorrectedge, 
                                                                logaccalphas, logEframescorrect, logaccbetas);
            checklaunch ("backwardlatticej");   // (label fixed; was mislabeled "edgealignment")
            startindex -= batchsizebackward[i];
        }
        // total backward score = beta of the first node; should match the forward total
        double totalbwscore = 0;
        memcpy<double>(&totalbwscore, logbetas.get(), 0, 1);
        double totalbwacc = 0;
        if (returnEframescorrect)
        {
            memcpy<double>(&totalbwacc, logaccbetas.get(), 0, 1);
            totalbwacc -= totalbwscore;
            logEframescorrecttotal = totalbwacc;
        }

        // sanity check: forward and backward totals should agree up to numerical error
        double absdifffwbwscore = fabs (totalfwscore - totalbwscore);
        if (absdifffwbwscore / nodes.size() > 1e-4)
            fprintf (stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());

        if (returnEframescorrect)
        {
            double absdifffwbwacc = fabs (totalfwacc - totalbwacc);
            if (absdifffwbwacc / nodes.size() > 1e-4)
                fprintf (stderr, "forwardbackward: WARNING: lattice fw and bw acc %.10f vs. %.10f (%d nodes/%d edges)\n", (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
        }
    }

    // -----------------------------------------------------------------------
    // sMBRerrorsignal -- accumulate difference of logEframescorrect and logEframescorrecttotal into errorsignal
    // -----------------------------------------------------------------------

    __global__ void sMBRerrorsignalj (const vectorref<unsigned short> alignstateids, const vectorref<unsigned int> alignoffsets,
                                      const vectorref<msra::lattices::edgeinfowithscores> edges, const vectorref<msra::lattices::nodeinfo> nodes, 
                                      vectorref<double> logpps, const float amf, const vectorref<double> logEframescorrect, const double logEframescorrecttotal, 
                                      matrixref<float> errorsignal, matrixref<float> errorsignalneg)
    {
        // shuffled per-edge index; [v-hansu] shufflemode = 1 measured ~100% faster than shufflemode = 0 (no shuffle)
        const size_t shufflemode = 1;
        const size_t edgeindex = msra::lattices::latticefunctionskernels::shuffle (threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
        if (edgeindex >= edges.size())      // note: will cause issues if we ever use __syncthreads()
            return;
        msra::lattices::latticefunctionskernels::sMBRerrorsignalj (edgeindex, alignstateids, alignoffsets, edges, nodes, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
    }

    // -----------------------------------------------------------------------
    // stateposteriors --accumulate a per-edge quantity into the states that the edge is aligned with
    // -----------------------------------------------------------------------

    __global__ void stateposteriorsj (const vectorref<unsigned short> alignstateids, const vectorref<unsigned int> alignoffsets,
                                      const vectorref<msra::lattices::edgeinfowithscores> edges, const vectorref<msra::lattices::nodeinfo> nodes, 
                                      const vectorref<double> logqs, matrixref<float> logacc)
    {
        // shuffled per-edge index; [v-hansu] shufflemode = 1 measured ~100% faster than shufflemode = 0 (no shuffle)
        const size_t shufflemode = 1;
        const size_t edgeindex = msra::lattices::latticefunctionskernels::shuffle (threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
        if (edgeindex >= edges.size())      // note: will cause issues if we ever use __syncthreads()
            return;
        msra::lattices::latticefunctionskernels::stateposteriorsj (edgeindex, alignstateids, alignoffsets, edges, nodes, logqs, logacc);
    }

    // setvaluei -- fill every element of matrix 'us' with 'value', one thread per row
    __global__ void setvaluei(matrixref<float> us, float value)
    {
        const size_t row = blockIdx.x * blockDim.x + threadIdx.x;
        if (row >= us.rows())
            return;
        // set all columns
        const size_t numcols = us.cols();
        for (size_t col = 0; col < numcols; col++)
            us(row, col) = value;
    }

    void latticefunctionsops::stateposteriors (const vectorref<unsigned short> & alignstateids, const vectorref<unsigned int> & alignoffsets,
                                               const vectorref<msra::lattices::edgeinfowithscores> & edges, const vectorref<msra::lattices::nodeinfo> & nodes, 
                                               const vectorref<double> & logqs, matrixref<float> & logacc) const
    {
        // Launch layout: 32x8 = 256 threads per block, one thread per edge, 1-D grid of ceil(#edges/256) blocks.
        // This limits us to 16 million edges. If you need more, please adjust to either use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
        dim3 blockdim (32,8);
        const size_t threadsperblock = blockdim.x * blockdim.y;
        dim3 griddim ((unsigned int) ((edges.size() + threadsperblock - 1) / threadsperblock));

        // clear the accumulator to LOGZERO, one thread per row
        setvaluei <<<dim3((((unsigned int) logacc.rows())+31)/32), 32, 0, GetCurrentStream()>>> (logacc, LOGZERO);
        checklaunch ("setvaluei");

        // accumulate the per-edge quantities into the aligned states
        stateposteriorsj <<<griddim, blockdim, 0, GetCurrentStream()>>> (alignstateids, alignoffsets, edges, nodes, logqs, logacc);
        checklaunch ("stateposteriors");
    }

    // sMBRerrorsignal -- host-side driver: accumulate the per-edge sMBR quantities into
    // per-state error signals and convert them to the final error signal in 'errorsignal'.
    // 'errorsignalauxbuf' is used as scratch; both matrices are fully overwritten.
    void latticefunctionsops::sMBRerrorsignal (const vectorref<unsigned short> & alignstateids, const vectorref<unsigned int> & alignoffsets,
                                               const vectorref<msra::lattices::edgeinfowithscores> & edges, const vectorref<msra::lattices::nodeinfo> & nodes, 
                                               const vectorref<double> & logpps, const float amf, const vectorref<double> & logEframescorrect, const double logEframescorrecttotal, 
                                               matrixref<float> & errorsignal, matrixref<float> & errorsignalauxbuf) const
    {
        // Layout: each thread block takes 1024 threads; and we have #edges/1024 blocks.
        // This limits us to 16 million edges. If you need more, please adjust to either use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
        const size_t numedges = edges.size();
        dim3 t (32,8);
        const size_t tpb = t.x * t.y;
        dim3 b ((unsigned int) ((numedges + tpb - 1) / tpb));

#ifdef DIRECT_MODE  // compute Eframescorrect in a more direct way, proven to get same result as below
        setvaluei <<<dim3((((unsigned int) errorsignal.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignal, LOGZERO);
        checklaunch ("setvaluei");
        
        sMBRerrorsignalj <<<b, t, 0, GetCurrentStream()>>> (alignstateids, alignoffsets, edges, nodes, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalauxbuf);
        checklaunch ("sMBRerrorsignal");        // now we get state based logEframescorrect

        matrixref<float> & loggammas = errorsignalauxbuf;
        setvaluei <<<dim3((((unsigned int) errorsignalauxbuf.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignalauxbuf, LOGZERO);
        checklaunch ("setvaluei");

        stateposteriorsj <<<b, t, 0, GetCurrentStream()>>> (alignstateids, alignoffsets, edges, nodes, logpps, loggammas);
        checklaunch ("stateposteriorsj");       // now we get state based loggammas

        directerrorcomputationi <<<dim3((((unsigned int) errorsignal.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignal, errorsignalauxbuf, logEframescorrecttotal, amf);
        checklaunch ("directerrorcomputationi");    // (label fixed; was mislabeled "errorcomputationj")
#else   // this saves some computation compared with DIRECT_MODE
        // clear both accumulators to LOGZERO, one thread per row
        setvaluei <<<dim3((((unsigned int) errorsignal.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignal, LOGZERO);
        checklaunch ("setvaluei");
        setvaluei <<<dim3((((unsigned int) errorsignalauxbuf.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignalauxbuf, LOGZERO);
        checklaunch ("setvaluei");
        // accumulate per-edge quantities into per-state buffers
        sMBRerrorsignalj <<<b, t, 0, GetCurrentStream()>>> (alignstateids, alignoffsets, edges, nodes, logpps, amf, logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalauxbuf);
        checklaunch ("sMBRerrorsignal");

        // patch up entries no edge touched
        setunseeni <<<dim3((((unsigned int) errorsignal.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignal, errorsignalauxbuf);
        checklaunch ("setunseeni");         // (label fixed; was mislabeled "setunseenj")

        // final error = (exp(errorsignal) - exp(errorsignalauxbuf)) / amf, in place
        errorcomputationi <<<dim3((((unsigned int) errorsignal.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignal, errorsignalauxbuf, amf);
        checklaunch ("errorcomputationi");  // (label fixed; was mislabeled "errorcomputationj")

#endif
    }

    void latticefunctionsops::mmierrorsignal (const vectorref<unsigned short> & alignstateids, const vectorref<unsigned int> & alignoffsets,
                                              const vectorref<msra::lattices::edgeinfowithscores> & edges, const vectorref<msra::lattices::nodeinfo> & nodes, 
                                              const vectorref<double> & logpps, matrixref<float> & errorsignal) const
    {
        // one thread per edge, 32x8 = 256 threads per block
        dim3 blockdim (32,8);
        const size_t threadsperblock = blockdim.x * blockdim.y;
        dim3 griddim ((unsigned int) ((edges.size() + threadsperblock - 1) / threadsperblock));

        // accumulate log state posteriors (gammas) directly into 'errorsignal'
        matrixref<float> & loggammas = errorsignal;                     // remember--this is an alias to 'errorsignal'
        setvaluei <<<dim3((((unsigned int) loggammas.rows())+31)/32), 32, 0, GetCurrentStream()>>> (loggammas, LOGZERO);
        checklaunch ("setvaluei");
        stateposteriorsj <<<griddim, blockdim, 0, GetCurrentStream()>>> (alignstateids, alignoffsets, edges, nodes, logpps, loggammas);
        checklaunch ("stateposteriorsj");

        // convert from the log domain to the linear domain, in place
        expfi <<<dim3((((unsigned int) errorsignal.rows())+31)/32), 32, 0, GetCurrentStream()>>> (errorsignal);
        checklaunch ("expfi");
    }
};};
back to top