// cudalattice.cpp -- lattice forward/backward functions for CUDA execution (glue code) // // F. Seide, V-hansu #define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings #define DLLEXPORT #define __kernel_emulation__ // allow the compilation of CUDA kernels on the CPU #include "latticefunctionskernels.h" // for the actual inner kernels and any argument types that are not yet defined in latticestorage.h #undef __kernel_emulation__ #include "cudalattice.h" // this exports the class #include "cudalatticeops.h" // brings in the actual lattice functions/kernels #include "cudalib.h" // generic CUDA helpers #include "cudadevice.h" #include #include // for auto_ptr #include #include namespace msra { namespace cuda { extern void operator||(cudaError_t rc, const char *msg); // TODO: imported from cudamatrix.cpp --better move to cudalib.h // this implements the basic operations of exported interface vectorbase<>, from which all vectors derive // TODO: This really should not be in cudalattice, since it is more general; we need a cudavector.cpp/h template class vectorbaseimpl : public /*interface*/ VECTORTYPE, // user-type interface; must derive from vectorbase public OPSTYPE, // type of class that implements the kernels; must derive from vectorref public objectondevice // setdevice() { typedef typename VECTORTYPE::elemtype elemtype; // (for convenience) size_t capacity; // amount of allocated storage (like capacity() vs. vectorref::n = size()) void release() { ondevice no(deviceid); free(this->reset(NULL, 0)); } public: vectorbaseimpl(size_t deviceid) : capacity(0), objectondevice(deviceid) { } ~vectorbaseimpl() { release(); } void allocate(size_t sz) { if (sz > capacity) // need to grow { ondevice no(deviceid); // switch to desired CUDA card cuda_ptr pnew = malloc(sz); // allocate memory inside CUDA device (or throw) capacity = sz; // if succeeded then: remember cuda_ptr p = this->reset(pnew, sz); // and swap the pointers and update n free(p); // then release the old one } else // not growing: keep same allocation this->reset(this->get(), sz); } size_t size() const throw() { return vectorref::size(); } void assign(const elemtype *p, size_t nelem, bool synchronize) { allocate(nelem); // assign will resize the target appropriately ondevice no(deviceid); // switch to desired CUDA card if (nelem > 0) memcpy(this->get(), 0, p, nelem); if (synchronize) join(); } void fetch(elemtype *p, size_t nelem, bool synchronize) const { if (nelem != size()) // fetch() cannot resize the target; caller must do that LogicError("fetch: vector size mismatch"); ondevice no(deviceid); // switch to desired CUDA card if (nelem > 0) memcpy(p, this->get(), 0, nelem); if (synchronize) join(); }; }; // --------------------------------------------------------------------------- // glue code for lattice-related classes // The XXXvectorimpl classes must derive from vectorbaseimpl. // For classes without kernels that operate on the vector, XXXvectorimpl is not // needed, use vectorbaseimpl> instead, where // XXXvector is an alias for vectorbase (but better keep that alias in cudalattice.h // to document which vectors are implemented). // --------------------------------------------------------------------------- matrixref tomatrixref(const Microsoft::MSR::CNTK::Matrix &m) { return matrixref(m.BufferPointer(), m.GetNumRows(), m.GetNumCols(), m.GetNumRows()); } class latticefunctionsimpl : public vectorbaseimpl { public: latticefunctionsimpl(size_t deviceid) : vectorbaseimpl(deviceid) { } private: void edgealignment(const lrhmmdefvector &hmms, const lr3transPvector &transPs, const size_t spalignunitid, const size_t silalignunitid, const Microsoft::MSR::CNTK::Matrix &logLLs, const nodeinfovector &nodes, const edgeinfowithscoresvector &edges, const aligninfovector &aligns, const uintvector &alignoffsets, ushortvector &backptrstorage, const sizetvector &backptroffsets, ushortvector &alignresult, floatvector &edgeacscores) // output { ondevice no(deviceid); matrixref logLLsMatrixRef = tomatrixref(logLLs); latticefunctionsops::edgealignment(dynamic_cast> &>(hmms), dynamic_cast> &>(transPs), spalignunitid, silalignunitid, logLLsMatrixRef, dynamic_cast> &>(nodes), dynamic_cast> &>(edges), dynamic_cast> &>(aligns), dynamic_cast> &>(alignoffsets), dynamic_cast> &>(backptrstorage), dynamic_cast> &>(backptroffsets), dynamic_cast> &>(alignresult), dynamic_cast> &>(edgeacscores)); } void forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward, const size_t numlaunchforward, const size_t numlaunchbackward, const size_t spalignunitid, const size_t silalignunitid, const floatvector &edgeacscores, const edgeinfowithscoresvector &edges, const nodeinfovector &nodes, const aligninfovector &aligns, const ushortvector &alignments, const uintvector &alignoffsets, doublevector &logpps, doublevector &logalphas, doublevector &logbetas, const float lmf, const float wp, const float amf, const float boostingfactor, const bool returnEframescorrect, const ushortvector &uids, const ushortvector &senone2classmap, doublevector &logaccalphas, doublevector &logaccbetas, doublevector &logframescorrectedge, doublevector &logEframescorrect, doublevector &Eframescorrectbuf, double &logEframescorrecttotal, double &totalfwscore) { ondevice no(deviceid); latticefunctionsops::forwardbackwardlattice(batchsizeforward, batchsizebackward, numlaunchforward, numlaunchbackward, spalignunitid, silalignunitid, dynamic_cast> &>(edgeacscores), dynamic_cast> &>(edges), dynamic_cast> &>(nodes), dynamic_cast> &>(aligns), dynamic_cast> &>(alignments), dynamic_cast> &>(alignoffsets), dynamic_cast> &>(logpps), dynamic_cast> &>(logalphas), dynamic_cast> &>(logbetas), lmf, wp, amf, boostingfactor, returnEframescorrect, dynamic_cast> &>(uids), dynamic_cast> &>(senone2classmap), dynamic_cast> &>(logaccalphas), dynamic_cast> &>(logaccbetas), dynamic_cast> &>(logframescorrectedge), dynamic_cast> &>(logEframescorrect), dynamic_cast> &>(Eframescorrectbuf), logEframescorrecttotal, totalfwscore); } void sMBRerrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets, const edgeinfowithscoresvector &edges, const nodeinfovector &nodes, const doublevector &logpps, const float amf, const doublevector &logEframescorrect, const double logEframescorrecttotal, Microsoft::MSR::CNTK::Matrix &dengammas, Microsoft::MSR::CNTK::Matrix &dengammasbuf) { ondevice no(deviceid); matrixref dengammasMatrixRef = tomatrixref(dengammas); matrixref dengammasbufMatrixRef = tomatrixref(dengammasbuf); latticefunctionsops::sMBRerrorsignal(dynamic_cast> &>(alignstateids), dynamic_cast> &>(alignoffsets), dynamic_cast> &>(edges), dynamic_cast> &>(nodes), dynamic_cast> &>(logpps), amf, dynamic_cast> &>(logEframescorrect), logEframescorrecttotal, dengammasMatrixRef, dengammasbufMatrixRef); } void mmierrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets, const edgeinfowithscoresvector &edges, const nodeinfovector &nodes, const doublevector &logpps, Microsoft::MSR::CNTK::Matrix &dengammas) { ondevice no(deviceid); matrixref dengammasMatrixRef = tomatrixref(dengammas); latticefunctionsops::mmierrorsignal(dynamic_cast> &>(alignstateids), dynamic_cast> &>(alignoffsets), dynamic_cast> &>(edges), dynamic_cast> &>(nodes), dynamic_cast> &>(logpps), dengammasMatrixRef); } void stateposteriors(const ushortvector &alignstateids, const uintvector &alignoffsets, const edgeinfowithscoresvector &edges, const nodeinfovector &nodes, const doublevector &logqs, Microsoft::MSR::CNTK::Matrix &logacc) { ondevice no(deviceid); matrixref logaccMatrixRef = tomatrixref(logacc); latticefunctionsops::stateposteriors(dynamic_cast> &>(alignstateids), dynamic_cast> &>(alignoffsets), dynamic_cast> &>(edges), dynamic_cast> &>(nodes), dynamic_cast> &>(logqs), logaccMatrixRef); } }; latticefunctions *newlatticefunctions(size_t deviceid) { return new latticefunctionsimpl(deviceid); } // implementation of lrhmmdefvector // Class has no vector-level member functions, so no need for an extra type lrhmmdefvector *newlrhmmdefvector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } lr3transPvector *newlr3transPvector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } ushortvector *newushortvector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } uintvector *newuintvector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } floatvector *newfloatvector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } doublevector *newdoublevector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } sizetvector *newsizetvector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } nodeinfovector *newnodeinfovector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } edgeinfowithscoresvector *newedgeinfovector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } aligninfovector *newaligninfovector(size_t deviceid) { return new vectorbaseimpl>(deviceid); } }; };