// cudalatticeops.h -- contains all actual CUDA-side lattice ops
//
// F. Seide, V-hansu
//
// Declares class msra::cuda::latticefunctionsops, whose protected const member
// functions (edgealignment, forwardbackwardlattice, sMBRerrorsignal,
// mmierrorsignal, stateposteriors) form the lattice-operation interface.
// Only declarations appear here; the definitions are not in this file.
//
// NOTE(review): this copy of the file appears to have lost every template
// argument list (text between '<' and '>'). Examples:
//   - "template class Matrix;" is not a valid forward declaration of a class template;
//   - the base class and all parameter types are bare "vectorref"/"matrixref".
// The include comment "for vectorref<>" suggests vectorref is a template.
// Restore the element types from version control before relying on these
// declarations; they must match the implementation exactly.

#pragma once

#include "cudabasetypes.h"           // for vectorref<>
#include "latticestorage.h"          // for the lattice types
#include "latticefunctionskernels.h" // for the actual inner kernels and any argument types that are not yet defined in latticestorage.h

// NOTE(review): a using-directive at namespace scope in a header injects
// msra::lattices into every translation unit that includes this file.
using namespace msra::lattices;

// Forward declarations
// (neither Matrix nor ssematrixbase is referenced by the declarations below;
// presumably they are needed by code that includes this header -- confirm)
namespace Microsoft { namespace MSR { namespace CNTK {
template class Matrix;
} } }
namespace msra { namespace math {
class ssematrixbase;
} }

namespace msra { namespace cuda {

// The XXXvectorops classes must derive from vectorref.
// (The original text likely read "vectorref<XXX>"; the template argument seems to
// have been stripped in this copy, as it was from the base-class clause below.)
//
// latticefunctionsops inherits vectorref protected, and every operation is a
// protected const member function. They are therefore callable only from this
// class or from classes derived from it.
// In every method below, parameters passed by non-const reference can be
// written by the callee; const-reference and by-value parameters cannot.
class latticefunctionsops : protected vectorref
{
protected:
    // edgealignment: declaration only.
    // Writable parameters: backptrstorage, alignresult, edgeacscores
    // (edgeacscores is explicitly marked as an output).
    void edgealignment(const vectorref& hmms, const vectorref& transPs, const size_t spalignunitid,
                       const size_t silalignunitid, const matrixref& logLLs,
                       const vectorref& nodes, const vectorref& edges,
                       const vectorref& aligns, const vectorref& alignoffsets,
                       vectorref& backptrstorage, const vectorref& backptroffsets,
                       vectorref& alignresult, vectorref& edgeacscores) const; // output

    // forwardbackwardlattice: declaration only.
    // batchsizeforward / batchsizebackward are read-only pointers. Whether they
    // point to host or device memory is not visible here -- confirm.
    // presumably numlaunchforward / numlaunchbackward give the number of entries
    // in those arrays -- confirm against the implementation.
    // Writable parameters: logpps, logalphas, logbetas, logaccalphas,
    // logaccbetas, logframescorrectedge, logEframescorrect, Eframescorrectbuf,
    // and the scalars logEframescorrecttotal and totalfwscore.
    // NOTE(review): "aligments" / "aligmentoffsets" are misspellings of
    // "alignments"; these are declaration-only parameter names, so this is
    // cosmetic.
    void forwardbackwardlattice(const size_t* batchsizeforward, const size_t* batchsizebackward,
                                const size_t numlaunchforward, const size_t numlaunchbackward,
                                const size_t spalignunitid, const size_t silalignunitid,
                                const vectorref& edgeacscores, const vectorref& edges,
                                const vectorref& nodes, const vectorref& aligns,
                                const vectorref& aligments, const vectorref& aligmentoffsets,
                                vectorref& logpps, vectorref& logalphas, vectorref& logbetas,
                                const float lmf, const float wp, const float amf, const float boostingfactor,
                                const bool returnEframescorrect,
                                const vectorref& uids, const vectorref& senone2classmap,
                                vectorref& logaccalphas, vectorref& logaccbetas,
                                vectorref& logframescorrectedge, vectorref& logEframescorrect,
                                vectorref& Eframescorrectbuf, double& logEframescorrecttotal,
                                double& totalfwscore) const;

    // sMBRerrorsignal: declaration only.
    // Takes logpps and logEframescorrect as read-only inputs and the scalar
    // logEframescorrecttotal by value.
    // Writable parameters: errorsignal, errorsignalneg.
    void sMBRerrorsignal(const vectorref& alignstateids, const vectorref& alignoffsets,
                         const vectorref& edges, const vectorref& nodes, const vectorref& logpps,
                         const float amf, const vectorref& logEframescorrect,
                         const double logEframescorrecttotal,
                         matrixref& errorsignal, matrixref& errorsignalneg) const;

    // mmierrorsignal: declaration only.
    // Writable parameter: errorsignal.
    void mmierrorsignal(const vectorref& alignstateids, const vectorref& alignoffsets,
                        const vectorref& edges, const vectorref& nodes, const vectorref& logpps,
                        matrixref& errorsignal) const;

    // stateposteriors: declaration only.
    // Writable parameter: logacc.
    void stateposteriors(const vectorref& alignstateids, const vectorref& alignoffsets,
                         const vectorref& edges, const vectorref& nodes, const vectorref& logqs,
                         matrixref& logacc) const;
};
};
};