// cudamatrix.cu(.h) -- CUDA kernels for lattice ops. Consider this a .cu/.cpp file.
//
// F. Seide, V-hansu

#undef DIRECT_MODE // [v-hansu] use the direct formula for sMBR mode; proven to make no difference

#include <cuda_runtime_api.h>
#include <cuda.h>
#include "cudalib.h"
#include "cudabasetypes.h"
#include "latticestorage.h"
#include "latticefunctionskernels.h"
#include "cudalatticeops.h"
#include "math.h"
#include <assert.h>
#include <stdexcept>
#ifdef _WIN32
#define NOMINMAX
#include "Windows.h" // for timer
#endif
#if __unix__
#include <sys/time.h>
#endif

namespace msra { namespace cuda {

cudaStream_t GetCurrentStream();

// auto_timer timer; run(); double seconds = timer; // now can abandon the object
#ifdef __unix__
typedef timeval LARGE_INTEGER;
#endif
class auto_timer
{
    LARGE_INTEGER freq, start;
    auto_timer(const auto_timer &);
    void operator=(const auto_timer &);

public:
    auto_timer()
    {
#ifdef _WIN32
        if (!QueryPerformanceFrequency(&freq)) // count ticks per second
            RuntimeError("auto_timer: QueryPerformanceFrequency failure");
        QueryPerformanceCounter(&start);
#endif
#ifdef __unix__
        gettimeofday(&start, NULL);
#endif
    }
    operator double() const // each read gives time elapsed since start, in seconds
    {
        LARGE_INTEGER end;
#ifdef _WIN32
        QueryPerformanceCounter(&end);
        return (end.QuadPart - start.QuadPart) / (double) freq.QuadPart;
#endif
#ifdef __unix__
        gettimeofday(&end, NULL);
        return (end.tv_sec - start.tv_sec) + (end.tv_usec - start.tv_usec) / (1000.0 * 1000.0); // 1000.0: avoid integer division truncating the sub-second part
#endif
    }
    void show(const std::string &msg) const
    {
        double elapsed = *this;
        fprintf(stderr, "%s: %.6f ms\n", msg.c_str(), elapsed * 1000.0 /*to ms*/);
    }
};

// -----------------------------------------------------------------------
// edgealignment -- do alignment at the per-edge level; only supports normal
// left-to-right HMMs and the ergodic silence HMM
// output: alignresult
// -----------------------------------------------------------------------

__global__ void edgealignmentj(const vectorref<lrhmmdef> hmms, const vectorref<lr3transP> transPs, const size_t spalignunitid,
                               const size_t silalignunitid, const matrixref<float> logLLs,
                               const vectorref<msra::lattices::nodeinfo> nodes, const vectorref<msra::lattices::edgeinfowithscores> edges,
                               const vectorref<msra::lattices::aligninfo> aligns, const vectorref<unsigned int> alignoffsets,
                               vectorref<unsigned short> backptrstorage, const vectorref<size_t> backptroffsets,
                               vectorref<unsigned short> alignresult, vectorref<float> edgeacscores) // output
{
    const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block
    const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x;
    const size_t j = jinblock + blockIdx.x * tpb;
    if (j < edges.size()) // note: will cause issues if we ever use __syncthreads()
    {
        msra::lattices::latticefunctionskernels::edgealignmentj(j, hmms, transPs, spalignunitid, silalignunitid, logLLs, nodes, edges, aligns,
                                                                alignoffsets, backptrstorage, backptroffsets, alignresult, edgeacscores);
    }
}
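// All host-side wrappers below share the same launch recipe: a 32 x 8 thread
// block (256 threads) and a 1-D grid with ceil(#items / 256) blocks, capped by
// the 65535-block grid-dimension limit mentioned in the layout comments.
// A minimal sketch of that recipe as a helper ('makelaunchgrid' is a
// hypothetical name, not part of the original code):
static inline void makelaunchgrid(size_t n, dim3 &t, dim3 &b)
{
    t = dim3(32, 8);                                // 256 threads per block
    const size_t tpb = t.x * t.y;
    b = dim3((unsigned int) ((n + tpb - 1) / tpb)); // ceil(n / tpb) blocks in x
}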
void latticefunctionsops::edgealignment(const vectorref<lrhmmdef> &hmms, const vectorref<lr3transP> &transPs, const size_t spalignunitid,
                                        const size_t silalignunitid, const matrixref<float> &logLLs,
                                        const vectorref<msra::lattices::nodeinfo> &nodes, const vectorref<msra::lattices::edgeinfowithscores> &edges,
                                        const vectorref<msra::lattices::aligninfo> &aligns, const vectorref<unsigned int> &alignoffsets,
                                        vectorref<unsigned short> &backptrstorage, const vectorref<size_t> &backptroffsets,
                                        vectorref<unsigned short> &alignresult, vectorref<float> &edgeacscores) const // output
{
    // Layout: each thread block takes 256 threads (32 x 8), and we have #edges/256 blocks.
    // This limits us to about 16 million edges (65535 blocks). If you need more, please adjust to either
    // use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
    const size_t numedges = edges.size();
    dim3 t(32, 8);
    const size_t tpb = t.x * t.y;
    dim3 b((unsigned int) ((numedges + tpb - 1) / tpb));
    // cudaarrayref<float> logLLsarray; // TODO: pass this in, of course
    // passtextureref texref(logLLstex, logLLsarray); // use the same name as that global texref one, so it will match the name inside the kernel
    edgealignmentj<<<b, t, 0, GetCurrentStream()>>>(hmms, transPs, spalignunitid, silalignunitid, logLLs, nodes, edges, aligns,
                                                    alignoffsets, backptrstorage, backptroffsets, alignresult, edgeacscores);
    checklaunch("edgealignment");
}

// setvaluej -- helper to initialize an array to a constant value, e.g. LOGZERO
__global__ void setvaluej(vectorref<double> arraytoset, double value, size_t nelem)
{
    const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block
    const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x;
    const size_t j = jinblock + blockIdx.x * tpb;
    if (j < nelem)
    {
        msra::lattices::latticefunctionskernels::setvaluej(j, arraytoset, value);
    }
}

// expfi -- exponentiate a matrix in place, one thread per row
__global__ void expfi(matrixref<float> mata)
{
    const size_t i = threadIdx.x + (blockIdx.x * blockDim.x);
    if (i < mata.rows())
    {
        const size_t m = mata.cols();
        for (size_t j = 0; j < m; j++)
            mata(i, j) = expf(mata(i, j));
    }
}

// dotprodi -- element-wise product of two matrices, stored into the first
__global__ void dotprodi(matrixref<float> mata, matrixref<float> matb)
{
    const size_t i = threadIdx.x + (blockIdx.x * blockDim.x);
    if (i < mata.rows())
    {
        const size_t m = mata.cols();
        for (size_t j = 0; j < m; j++)
            mata(i, j) = mata(i, j) * matb(i, j);
    }
}

// setunseeni -- mark cells that were touched by neither accumulator as log-zero in the aux buffer
__global__ void setunseeni(matrixref<float> errorsignal, matrixref<float> errorsignalauxbuf)
{
    const size_t i = threadIdx.x + (blockIdx.x * blockDim.x);
    if (i < errorsignal.rows())
    {
        const size_t m = errorsignal.cols();
        for (size_t j = 0; j < m; j++)
            if (errorsignal(i, j) == logf(CUDART_MIN_DENORM_F) && errorsignalauxbuf(i, j) == logf(CUDART_MIN_DENORM_F))
                errorsignalauxbuf(i, j) = LOGZERO;
    }
}

// errorsignal(i,j) = (exp(errorsignal(i,j)) - exp(errorsignalauxbuf(i,j))) / amf
__global__ void errorcomputationi(matrixref<float> errorsignal, matrixref<float> errorsignalauxbuf, float amf)
{
    const size_t i = threadIdx.x + (blockIdx.x * blockDim.x);
    if (i < errorsignal.rows())
    {
        const size_t m = errorsignal.cols();
        for (size_t j = 0; j < m; j++)
            errorsignal(i, j) = msra::lattices::latticefunctionskernels::expdiff(errorsignal(i, j), errorsignalauxbuf(i, j)) / amf;
    }
}

// errorsignal(i,j) = (exp(errorsignal(i,j)) - exp(logEframescorrecttotal + errorsignalauxbuf(i,j))) / amf
__global__ void directerrorcomputationi(matrixref<float> errorsignal, matrixref<float> errorsignalauxbuf, float logEframescorrecttotal, float amf)
{
    const size_t i = threadIdx.x + (blockIdx.x * blockDim.x);
    if (i < errorsignal.rows())
    {
        const size_t m = errorsignal.cols();
        for (size_t j = 0; j < m; j++)
            errorsignal(i, j) = msra::lattices::latticefunctionskernels::expdiff(errorsignal(i, j), logEframescorrecttotal + errorsignalauxbuf(i, j)) / amf;
    }
}
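// 'expdiff(a, b)' above is assumed to compute exp(a) - exp(b) for log-domain
// inputs without overflowing; the real implementation lives in
// latticefunctionskernels.h. A minimal sketch of the standard trick (factor
// out the larger exponent before exponentiating); 'expdiff_sketch' is
// illustrative only:
static __device__ float expdiff_sketch(float a, float b)
{
    const float m = fmaxf(a, b);                  // larger of the two exponents
    return expf(m) * (expf(a - m) - expf(b - m)); // == exp(a) - exp(b), without overflow
}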
// compute the final error signal from gammas and state-consolidated Eframescorrect
// in-place operation is supported (i.e. output = one of the inputs)
__global__ void computesMBRerrorsignals(const matrixref<float> loggammas, const matrixref<float> logEframescorrect,
                                        const double logEframecorrecttotal, const float kappa, matrixref<float> errorsignal)
{
    const size_t s = threadIdx.x + (blockIdx.x * blockDim.x);
    if (s < loggammas.rows())
        msra::lattices::latticefunctionskernels::computesMBRerrorsignals(s, loggammas, logEframescorrect, logEframecorrecttotal, kappa, errorsignal);
}

// forwardlatticej -- forward pass over one batch of edges, one thread per edge
__global__ void forwardlatticej(const size_t batchsize, const size_t startindex, const vectorref<float> edgeacscores,
                                const size_t spalignunitid, const size_t silalignunitid,
                                vectorref<msra::lattices::edgeinfowithscores> edges, vectorref<msra::lattices::nodeinfo> nodes,
                                const vectorref<msra::lattices::aligninfo> aligns,
                                vectorref<unsigned short> alignments, vectorref<unsigned int> alignmentoffsets,
                                vectorref<double> logalphas, float lmf, float wp, float amf, const float boostingfactor,
                                const vectorref<unsigned short> uids, const vectorref<unsigned short> senone2classmap,
                                const bool returnEframescorrect, vectorref<double> logframescorrectedge, vectorref<double> logaccalphas)
{
    const size_t shufflemode = 1; // [v-hansu] this gives us about a 100% speed-up over shufflemode = 0 (no shuffle)
    const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
    if (j < batchsize) // note: will cause issues if we ever use __syncthreads()
    {
        msra::lattices::latticefunctionskernels::forwardlatticej(j + startindex, edgeacscores, spalignunitid, silalignunitid,
                                                                 edges, nodes, aligns, alignments, alignmentoffsets,
                                                                 logalphas, lmf, wp, amf, boostingfactor, uids, senone2classmap,
                                                                 returnEframescorrect, logframescorrectedge, logaccalphas);
    }
}

// backwardlatticej -- backward pass over one batch of edges, one thread per edge
__global__ void backwardlatticej(const size_t batchsize, const size_t startindex, const vectorref<float> edgeacscores,
                                 const size_t spalignunitid, const size_t silalignunitid,
                                 vectorref<msra::lattices::edgeinfowithscores> edges, vectorref<msra::lattices::nodeinfo> nodes,
                                 vectorref<msra::lattices::aligninfo> aligns, const double totalfwscore,
                                 vectorref<double> logpps, vectorref<double> logalphas, vectorref<double> logbetas,
                                 float lmf, float wp, float amf, const float boostingfactor, const bool returnEframescorrect,
                                 vectorref<double> logframescorrectedge, vectorref<double> logaccalphas,
                                 vectorref<double> logEframescorrect, vectorref<double> logaccbetas)
{
    const size_t tpb = blockDim.x * blockDim.y; // total #threads in a block
    const size_t jinblock = threadIdx.x + threadIdx.y * blockDim.x;
    size_t j = jinblock + blockIdx.x * tpb;
    if (j < batchsize) // note: will cause issues if we ever use __syncthreads()
    {
        msra::lattices::latticefunctionskernels::backwardlatticej(j + startindex, edgeacscores, spalignunitid, silalignunitid,
                                                                  edges, nodes, aligns, totalfwscore, logpps, logalphas, logbetas,
                                                                  lmf, wp, amf, boostingfactor, returnEframescorrect,
                                                                  logframescorrectedge, logaccalphas, logEframescorrect, logaccbetas);
    }
}
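// The 'shuffle' call in forwardlatticej permutes the flat thread index before
// it is mapped to an edge (shufflemode = 1); the [v-hansu] comment above
// reports a ~2x speed-up, presumably from better load balance across edges of
// very different lengths. A minimal sketch of one such permutation
// ('shuffle_sketch' is illustrative only; the real function and its exact
// permutation live in latticefunctionskernels.h):
static __device__ size_t shuffle_sketch(size_t jinblock, size_t tpb, size_t blockid, size_t numblocks)
{
    // stripe edges across blocks instead of the contiguous j = blockid * tpb + jinblock,
    // so that neighboring (similar-length) edges land in different blocks
    return jinblock * numblocks + blockid;
}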
void latticefunctionsops::forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward,
                                                 const size_t numlaunchforward, const size_t numlaunchbackward,
                                                 const size_t spalignunitid, const size_t silalignunitid,
                                                 const vectorref<float> &edgeacscores,
                                                 const vectorref<msra::lattices::edgeinfowithscores> &edges,
                                                 const vectorref<msra::lattices::nodeinfo> &nodes,
                                                 const vectorref<msra::lattices::aligninfo> &aligns,
                                                 const vectorref<unsigned short> &alignments, const vectorref<unsigned int> &aligmentoffsets,
                                                 vectorref<double> &logpps, vectorref<double> &logalphas, vectorref<double> &logbetas,
                                                 const float lmf, const float wp, const float amf, const float boostingfactor,
                                                 const bool returnEframescorrect,
                                                 const vectorref<unsigned short> &uids, const vectorref<unsigned short> &senone2classmap,
                                                 vectorref<double> &logaccalphas, vectorref<double> &logaccbetas,
                                                 vectorref<double> &logframescorrectedge, vectorref<double> &logEframescorrect,
                                                 vectorref<double> & /*Eframescorrectbuf*/, double &logEframescorrecttotal, double &totalfwscore) const
{
    // initialize log{,acc}(alphas/betas)
    dim3 t(32, 8);
    const size_t tpb = t.x * t.y;
    dim3 b((unsigned int) ((logalphas.size() + tpb - 1) / tpb));
    // TODO: is this really efficient? One thread per value?
    setvaluej<<<b, t, 0, GetCurrentStream()>>>(logalphas, LOGZERO, logalphas.size());
    checklaunch("setvaluej");
    setvaluej<<<b, t, 0, GetCurrentStream()>>>(logbetas, LOGZERO, logalphas.size());
    checklaunch("setvaluej");
    if (returnEframescorrect)
    {
        setvaluej<<<b, t, 0, GetCurrentStream()>>>(logaccalphas, LOGZERO, logalphas.size());
        checklaunch("setvaluej");
        setvaluej<<<b, t, 0, GetCurrentStream()>>>(logaccbetas, LOGZERO, logalphas.size());
        checklaunch("setvaluej");
    }
    // set initial tokens to probability 1 (0 in log)
    double log1 = 0.0;
    memcpy(logalphas.get(), 0, &log1, 1);
    memcpy(logbetas.get(), nodes.size() - 1, &log1, 1);

    // forward pass
    size_t startindex = 0;
    for (size_t i = 0; i < numlaunchforward; i++)
    {
        dim3 b2((unsigned int) ((batchsizeforward[i] + tpb - 1) / tpb));
        forwardlatticej<<<b2, t, 0, GetCurrentStream()>>>(batchsizeforward[i], startindex, edgeacscores,
                                                          spalignunitid, silalignunitid, edges, nodes, aligns,
                                                          alignments, aligmentoffsets, logalphas, lmf, wp, amf, boostingfactor,
                                                          uids, senone2classmap, returnEframescorrect, logframescorrectedge, logaccalphas);
        checklaunch("forwardlattice");
        startindex += batchsizeforward[i];
    }
    memcpy(&totalfwscore, logalphas.get(), nodes.size() - 1, 1);
    double totalfwacc = 0;
    if (returnEframescorrect)
    {
        memcpy(&totalfwacc, logaccalphas.get(), nodes.size() - 1, 1);
        totalfwacc -= totalfwscore;
    }

    // backward pass
    startindex = edges.size();
    for (size_t i = 0; i < numlaunchbackward; i++)
    {
        dim3 b2((unsigned int) ((batchsizebackward[i] + tpb - 1) / tpb));
        backwardlatticej<<<b2, t, 0, GetCurrentStream()>>>(batchsizebackward[i], startindex - batchsizebackward[i],
                                                           edgeacscores, spalignunitid, silalignunitid, edges, nodes, aligns,
                                                           totalfwscore, logpps, logalphas, logbetas, lmf, wp, amf, boostingfactor,
                                                           returnEframescorrect, logframescorrectedge, logaccalphas, logEframescorrect, logaccbetas);
        checklaunch("backwardlattice");
        startindex -= batchsizebackward[i];
    }
    double totalbwscore = 0;
    memcpy(&totalbwscore, logbetas.get(), 0, 1);
    double totalbwacc = 0;
    if (returnEframescorrect)
    {
        memcpy(&totalbwacc, logaccbetas.get(), 0, 1);
        totalbwacc -= totalbwscore;
        logEframescorrecttotal = totalbwacc;
    }

    // sanity check: forward and backward total scores should agree
    double difffwbwscore = totalfwscore - totalbwscore;
    double absdifffwbwscore = difffwbwscore > 0 ? difffwbwscore : 0 - difffwbwscore;
    if (absdifffwbwscore / nodes.size() > 1e-4)
        fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw scores %.10f vs. %.10f (%d nodes/%d edges)\n",
                (float) totalfwscore, (float) totalbwscore, (int) nodes.size(), (int) edges.size());

    if (returnEframescorrect)
    {
        double difffwbwacc = totalfwacc - totalbwacc;
        double absdifffwbwacc = difffwbwacc > 0 ? difffwbwacc : 0 - difffwbwacc;
        if (absdifffwbwacc / nodes.size() > 1e-4)
            fprintf(stderr, "forwardbackward: WARNING: lattice fw and bw acc %.10f vs. %.10f (%d nodes/%d edges)\n",
                    (float) totalfwacc, (float) totalbwacc, (int) nodes.size(), (int) edges.size());
    }
}
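// The 'memcpy' calls above are cudalib's element-wise host<->device copy
// helpers, not the C library memcpy: memcpy(devptr, offset, hostptr, n)
// copies to the device, memcpy(hostptr, devptr, offset, n) copies back.
// A minimal sketch of the assumed device-to-host semantics ('memcpy_sketch'
// is a hypothetical name):
static void memcpy_sketch(double *hostdst, const double *devsrc, size_t devoffset, size_t n)
{
    // copy n doubles starting at element 'devoffset' of the device array
    cudaMemcpy(hostdst, devsrc + devoffset, n * sizeof(double), cudaMemcpyDeviceToHost);
}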
// -----------------------------------------------------------------------
// sMBRerrorsignal -- accumulate the difference of logEframescorrect and
// logEframescorrecttotal into errorsignal
// -----------------------------------------------------------------------

__global__ void sMBRerrorsignalj(const vectorref<unsigned short> alignstateids, const vectorref<unsigned int> alignoffsets,
                                 const vectorref<msra::lattices::edgeinfowithscores> edges, const vectorref<msra::lattices::nodeinfo> nodes,
                                 vectorref<double> logpps, const float amf,
                                 const vectorref<double> logEframescorrect, const double logEframescorrecttotal,
                                 matrixref<float> errorsignal, matrixref<float> errorsignalneg)
{
    const size_t shufflemode = 1; // [v-hansu] this gives us about a 100% speed-up over shufflemode = 0 (no shuffle)
    const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
    if (j < edges.size()) // note: will cause issues if we ever use __syncthreads()
    {
        msra::lattices::latticefunctionskernels::sMBRerrorsignalj(j, alignstateids, alignoffsets, edges, nodes, logpps, amf,
                                                                  logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalneg);
    }
}

// -----------------------------------------------------------------------
// stateposteriors -- accumulate a per-edge quantity into the states that the
// edge is aligned with
// -----------------------------------------------------------------------

__global__ void stateposteriorsj(const vectorref<unsigned short> alignstateids, const vectorref<unsigned int> alignoffsets,
                                 const vectorref<msra::lattices::edgeinfowithscores> edges, const vectorref<msra::lattices::nodeinfo> nodes,
                                 const vectorref<double> logqs, matrixref<float> logacc)
{
    const size_t shufflemode = 1; // [v-hansu] this gives us about a 100% speed-up over shufflemode = 0 (no shuffle)
    const size_t j = msra::lattices::latticefunctionskernels::shuffle(threadIdx.x, blockDim.x, threadIdx.y, blockDim.y, blockIdx.x, gridDim.x, shufflemode);
    if (j < edges.size()) // note: will cause issues if we ever use __syncthreads()
    {
        msra::lattices::latticefunctionskernels::stateposteriorsj(j, alignstateids, alignoffsets, edges, nodes, logqs, logacc);
    }
}

// setvaluei -- set all elements of a matrix to a constant value, one thread per row
__global__ void setvaluei(matrixref<float> us, float value)
{
    const size_t i = threadIdx.x + (blockIdx.x * blockDim.x);
    if (i >= us.rows())
        return;
    // set all columns
    const size_t m = us.cols();
    for (size_t j = 0; j < m; j++)
        us(i, j) = value;
}

void latticefunctionsops::stateposteriors(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
                                          const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
                                          const vectorref<double> &logqs, matrixref<float> &logacc) const
{
    // Layout: each thread block takes 256 threads (32 x 8), and we have #edges/256 blocks.
    // This limits us to about 16 million edges (65535 blocks). If you need more, please adjust to either
    // use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
    const size_t numedges = edges.size();
    dim3 t(32, 8);
    const size_t tpb = t.x * t.y;
    dim3 b((unsigned int) ((numedges + tpb - 1) / tpb));
    setvaluei<<<b, t, 0, GetCurrentStream()>>>(logacc, LOGZERO);
    checklaunch("setvaluei");
    stateposteriorsj<<<b, t, 0, GetCurrentStream()>>>(alignstateids, alignoffsets, edges, nodes, logqs, logacc);
    checklaunch("stateposteriors");
}
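// stateposteriorsj accumulates log-domain quantities from many edges into the
// same (state, frame) cell of 'logacc', which requires an atomic log-add on
// the device. A minimal sketch of the usual compare-and-swap loop (assumed;
// the real accumulation is inside latticefunctionskernels.h and may differ):
static __device__ void atomiclogadd_sketch(float *addr, float logval)
{
    int old = __float_as_int(*addr), assumed;
    do
    {
        assumed = old;
        const float a = __int_as_float(assumed);
        const float m = fmaxf(a, logval); // logadd(a, b) = m + log(exp(a - m) + exp(b - m))
        const float sum = m + logf(expf(a - m) + expf(logval - m));
        old = atomicCAS((int *) addr, assumed, __float_as_int(sum));
    } while (assumed != old); // retry if another thread updated the cell meanwhile
}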
void latticefunctionsops::sMBRerrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
                                          const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
                                          const vectorref<double> &logpps, const float amf,
                                          const vectorref<double> &logEframescorrect, const double logEframescorrecttotal,
                                          matrixref<float> &errorsignal, matrixref<float> &errorsignalauxbuf) const
{
    // Layout: each thread block takes 256 threads (32 x 8), and we have #edges/256 blocks.
    // This limits us to about 16 million edges (65535 blocks). If you need more, please adjust to either
    // use wider thread blocks or a second dimension for the grid. Don't forget to adjust the kernel as well.
    const size_t numedges = edges.size();
    dim3 t(32, 8);
    const size_t tpb = t.x * t.y;
    dim3 b((unsigned int) ((numedges + tpb - 1) / tpb));
#ifdef DIRECT_MODE // compute Eframescorrect in a more direct way, proven to get the same result as below
    setvaluei<<<b, t, 0, GetCurrentStream()>>>(errorsignal, LOGZERO);
    checklaunch("setvaluei");
    sMBRerrorsignalj<<<b, t, 0, GetCurrentStream()>>>(alignstateids, alignoffsets, edges, nodes, logpps, amf,
                                                      logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalauxbuf);
    checklaunch("sMBRerrorsignal");
    // now we have the state-based logEframescorrect
    matrixref<float> &loggammas = errorsignalauxbuf;
    setvaluei<<<b, t, 0, GetCurrentStream()>>>(errorsignalauxbuf, LOGZERO);
    checklaunch("setvaluei");
    stateposteriorsj<<<b, t, 0, GetCurrentStream()>>>(alignstateids, alignoffsets, edges, nodes, logpps, loggammas);
    checklaunch("stateposteriorsj");
    // now we have the state-based loggammas
    directerrorcomputationi<<<b, t, 0, GetCurrentStream()>>>(errorsignal, errorsignalauxbuf, logEframescorrecttotal, amf);
    checklaunch("directerrorcomputationi");
#else // this saves some computation compared with DIRECT_MODE
    setvaluei<<<b, t, 0, GetCurrentStream()>>>(errorsignal, LOGZERO);
    checklaunch("setvaluei");
    setvaluei<<<b, t, 0, GetCurrentStream()>>>(errorsignalauxbuf, LOGZERO);
    checklaunch("setvaluei");
    sMBRerrorsignalj<<<b, t, 0, GetCurrentStream()>>>(alignstateids, alignoffsets, edges, nodes, logpps, amf,
                                                      logEframescorrect, logEframescorrecttotal, errorsignal, errorsignalauxbuf);
    checklaunch("sMBRerrorsignal");
    setunseeni<<<b, t, 0, GetCurrentStream()>>>(errorsignal, errorsignalauxbuf);
    checklaunch("setunseeni");
    errorcomputationi<<<b, t, 0, GetCurrentStream()>>>(errorsignal, errorsignalauxbuf, amf);
    checklaunch("errorcomputationi");
#endif
}

void latticefunctionsops::mmierrorsignal(const vectorref<unsigned short> &alignstateids, const vectorref<unsigned int> &alignoffsets,
                                         const vectorref<msra::lattices::edgeinfowithscores> &edges, const vectorref<msra::lattices::nodeinfo> &nodes,
                                         const vectorref<double> &logpps, matrixref<float> &errorsignal) const
{
    const size_t numedges = edges.size();
    dim3 t(32, 8);
    const size_t tpb = t.x * t.y;
    dim3 b((unsigned int) ((numedges + tpb - 1) / tpb));

    matrixref<float> &loggammas = errorsignal; // remember -- this is an alias to 'errorsignal'
    setvaluei<<<b, t, 0, GetCurrentStream()>>>(loggammas, LOGZERO);
    checklaunch("setvaluei");
    stateposteriorsj<<<b, t, 0, GetCurrentStream()>>>(alignstateids, alignoffsets, edges, nodes, logpps, loggammas);
    checklaunch("stateposteriorsj");
    expfi<<<b, t, 0, GetCurrentStream()>>>(errorsignal);
    checklaunch("expfi");
}
};
};
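// Reference sketch of the per-cell quantity both sMBR paths above compute
// (the DIRECT_MODE branch spells it out directly; the default branch
// accumulates the two exp() terms separately and then takes expdiff).
// Host-side and for illustration only; 'smbr_cellerror_sketch' is not part
// of the original code:
static float smbr_cellerror_sketch(float logEframescorrect, float loggamma,
                                   float logEframescorrecttotal, float amf)
{
    // errorsignal = (E[frames correct, state s at frame t] - gamma(s,t) * E[frames correct]) / amf
    return (expf(logEframescorrect) - expf(logEframescorrecttotal + loggamma)) / amf;
}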