// Source: https://github.com/Microsoft/CNTK
// Tip revision: bb18eb24c7ad17ebca1cfa1fc94d9cd2eff52774 authored by Mark Hillebrand on 18 January 2016, 08:37:59 UTC
// License change
// License change
// Tip revision: bb18eb2
// cudalattice.cpp
// cudalattice.cpp -- lattice forward/backward functions for CUDA execution (glue code)
//
// F. Seide, V-hansu
#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms --add this at the top of all CPP files that give "function or variable may be unsafe" warnings
#define DLLEXPORT
#define __kernel_emulation__ // allow the compilation of CUDA kernels on the CPU
#include "latticefunctionskernels.h" // for the actual inner kernels and any argument types that are not yet defined in latticestorage.h
#undef __kernel_emulation__
#include "cudalattice.h" // this exports the class
#include "cudalatticeops.h" // brings in the actual lattice functions/kernels
#include "cudalib.h" // generic CUDA helpers
#include "cudadevice.h"
#include <math.h>
#include <memory> // for auto_ptr
#include <assert.h>
#include <float.h>
namespace msra { namespace cuda {
// Declared here, defined in another translation unit (cudamatrix.cpp according to the TODO on this declaration).
// NOTE(review): presumably written as '<cuda call> || "message"' to report 'msg' when 'rc' is not a success code -- confirm against the definition.
extern void operator|| (cudaError_t rc, const char * msg); // TODO: imported from cudamatrix.cpp --better move to cudalib.h
// this implements the basic operations of exported interface vectorbase<>, from which all vectors derive
// TODO: This really should not be in cudalattice, since it is more general; we need a cudavector.cpp/h
// vectorbaseimpl -- implements the basic operations of the exported interface vectorbase<>:
// allocate() (resize), size(), assign() (copy into the buffer) and fetch() (copy out of the buffer).
// The object owns its buffer: it is obtained with malloc<elemtype>() while device 'deviceid' is
// selected, and it is released by the destructor. Because of that ownership the class is non-copyable.
template<typename VECTORTYPE,typename OPSTYPE> class vectorbaseimpl :
    public /*interface*/VECTORTYPE,     // user-type interface; must derive from vectorbase<VECTORBASE::elemtype>
    public OPSTYPE,                     // type of class that implements the kernels; must derive from vectorref<VECTORBASE::elemtype>
    public objectondevice               // setdevice()
{
    typedef typename VECTORTYPE::elemtype elemtype;     // (for convenience)
    size_t capacity;                    // amount of allocated storage (like capacity() vs. vectorref::n = size())

    // Frees the buffer and returns the object to the empty state (null pointer, size 0, capacity 0).
    // capacity is cleared as well so that a later allocate() cannot mistake the freed buffer for usable storage.
    void release() { ondevice no (deviceid); free (this->reset (NULL, 0)); capacity = 0; }

    // The destructor frees the buffer, so a memberwise copy would lead to the same pointer being freed twice.
    // Copying is therefore disabled; instances are created by the factory functions and handled by pointer/reference.
    vectorbaseimpl (const vectorbaseimpl &) = delete;
    vectorbaseimpl & operator= (const vectorbaseimpl &) = delete;
public:
    // deviceid: device this vector belongs to; every operation below selects it first via 'ondevice'.
    // (initializers are listed in actual initialization order: base classes before members)
    vectorbaseimpl(size_t deviceid) : objectondevice(deviceid), capacity(0) {}
    ~vectorbaseimpl() { release(); }

    // Sets the logical size to 'sz' elements.
    // Growing beyond the current capacity allocates a fresh buffer and frees the old one; the old
    // contents are NOT copied over. Otherwise the existing allocation is kept and only the size changes.
    void allocate (size_t sz)
    {
        if (sz > capacity)                  // need to grow
        {
            ondevice no (deviceid);         // switch to desired CUDA card
            cuda_ptr<elemtype> pnew = malloc<elemtype> (sz);    // allocate memory inside CUDA device (or throw)
            capacity = sz;                  // if succeeded then: remember
            cuda_ptr<elemtype> p = this->reset (pnew, sz);      // and swap the pointers and update n
            free (p);                       // then release the old one
        }
        else                                // not growing: keep same allocation
            this->reset (this->get(), sz);
    }

    // Number of elements currently held (as tracked by vectorref<elemtype>).
    size_t size() const throw() { return vectorref<elemtype>::size(); }

    // Resizes this vector to 'nelem' elements and copies 'nelem' elements from 'p' into the buffer.
    // If 'synchronize' is set, join() is called before returning.
    // NOTE(review): the literal 0 passed to memcpy() is presumably an element offset into the buffer -- confirm in cudalib.h.
    void assign (const elemtype * p, size_t nelem, bool synchronize)
    {
        allocate (nelem);           // assign will resize the target appropriately
        ondevice no (deviceid);     // switch to desired CUDA card
        if (nelem > 0)
            memcpy (this->get(), 0, p, nelem);
        if (synchronize)
            join();
    }

    // Copies the buffer contents out to 'p', which must have room for 'nelem' elements.
    // 'nelem' must equal size(); otherwise LogicError() is raised (fetch() never resizes anything).
    // If 'synchronize' is set, join() is called before returning.
    void fetch (elemtype * p, size_t nelem, bool synchronize) const
    {
        if (nelem != size())        // fetch() cannot resize the target; caller must do that
            LogicError("fetch: vector size mismatch");
        ondevice no (deviceid);     // switch to desired CUDA card
        if (nelem > 0)
            memcpy (p, this->get(), 0, nelem);
        if (synchronize)
            join();
    }
};
// ---------------------------------------------------------------------------
// glue code for lattice-related classes
// The XXXvectorimpl classes must derive from vectorbaseimpl<XXXvector,XXXvectorops>.
// For classes without kernels that operate on the vector, XXXvectorimpl is not
// needed, use vectorbaseimpl<XXXvector,vectorref<XXX>> instead, where
// XXXvector is an alias for vectorbase<XXX> (but better keep that alias in cudalattice.h
// to document which vectors are implemented).
// ---------------------------------------------------------------------------
// Builds a matrixref<float> view from a CNTK float matrix: buffer pointer, row count, column count,
// and the row count once more as the fourth constructor argument.
// NOTE(review): the fourth argument is presumably the column stride (equal to the row count here) -- confirm against matrixref<>.
matrixref<float> tomatrixref(const Microsoft::MSR::CNTK::Matrix<float>& m)
{
    const auto numrows = m.GetNumRows();
    const auto numcols = m.GetNumCols();
    return matrixref<float>(m.BufferPointer(), numrows, numcols, numrows);
}
// latticefunctionsimpl -- implementation class behind the 'latticefunctions' interface.
// Every member function follows the same pattern:
//  1. select this object's device ('ondevice'),
//  2. wrap any CNTK matrix argument into a matrixref<float> via tomatrixref(),
//  3. dynamic_cast each interface-typed vector argument to its vectorbaseimpl<> type
//     (a failed reference cast throws std::bad_cast),
//  4. forward everything, in unchanged order, to the same-named function of latticefunctionsops.
class latticefunctionsimpl : public vectorbaseimpl<latticefunctions,latticefunctionsops>
{
    // shorthands for the implementation types that the interface vectors are cast to
    typedef vectorbaseimpl<lrhmmdefvector, vectorref<lrhmmdef>> hmmvecimpl;
    typedef vectorbaseimpl<lr3transPvector, vectorref<lr3transP>> transPvecimpl;
    typedef vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> nodevecimpl;
    typedef vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> edgevecimpl;
    typedef vectorbaseimpl<aligninfovector, vectorref<msra::lattices::aligninfo>> alignvecimpl;
    typedef vectorbaseimpl<uintvector, vectorref<unsigned int>> uintvecimpl;
    typedef vectorbaseimpl<ushortvector, vectorref<unsigned short>> ushortvecimpl;
    typedef vectorbaseimpl<sizetvector, vectorref<size_t>> sizetvecimpl;
    typedef vectorbaseimpl<floatvector, vectorref<float>> floatvecimpl;
    typedef vectorbaseimpl<doublevector, vectorref<double>> doublevecimpl;
public:
    latticefunctionsimpl(size_t deviceid) : vectorbaseimpl(deviceid) {}
private:
    // Forwards to latticefunctionsops::edgealignment().
    // backptrstorage, alignresult and edgeacscores are passed on as non-const (writable) vectors.
    void edgealignment (const lrhmmdefvector & hmms, const lr3transPvector & transPs, const size_t spalignunitid,
                        const size_t silalignunitid, const Microsoft::MSR::CNTK::Matrix<float>& logLLs, const nodeinfovector & nodes,
                        const edgeinfowithscoresvector & edges, const aligninfovector & aligns,
                        const uintvector & alignoffsets, ushortvector & backptrstorage, const sizetvector & backptroffsets,
                        ushortvector & alignresult, floatvector & edgeacscores)     // output
    {
        ondevice no (deviceid);
        matrixref<float> logLLsref = tomatrixref (logLLs);
        latticefunctionsops::edgealignment (
            dynamic_cast<const hmmvecimpl &> (hmms),
            dynamic_cast<const transPvecimpl &> (transPs),
            spalignunitid, silalignunitid, logLLsref,
            dynamic_cast<const nodevecimpl &> (nodes),
            dynamic_cast<const edgevecimpl &> (edges),
            dynamic_cast<const alignvecimpl &> (aligns),
            dynamic_cast<const uintvecimpl &> (alignoffsets),
            dynamic_cast<ushortvecimpl &> (backptrstorage),
            dynamic_cast<const sizetvecimpl &> (backptroffsets),
            dynamic_cast<ushortvecimpl &> (alignresult),
            dynamic_cast<floatvecimpl &> (edgeacscores));
    }

    // Forwards to latticefunctionsops::forwardbackwardlattice().
    // All doublevector arguments as well as logEframescorrecttotal and totalfwscore are passed on as non-const references.
    void forwardbackwardlattice (const size_t * batchsizeforward, const size_t * batchsizebackward,
                                 const size_t numlaunchforward, const size_t numlaunchbackward,
                                 const size_t spalignunitid, const size_t silalignunitid,
                                 const floatvector & edgeacscores, const edgeinfowithscoresvector & edges,
                                 const nodeinfovector & nodes, const aligninfovector & aligns,
                                 const ushortvector & alignments, const uintvector & alignoffsets,
                                 doublevector & logpps, doublevector & logalphas, doublevector & logbetas,
                                 const float lmf, const float wp, const float amf, const float boostingfactor, const bool returnEframescorrect,
                                 const ushortvector & uids, const ushortvector & senone2classmap, doublevector & logaccalphas,
                                 doublevector & logaccbetas, doublevector & logframescorrectedge,
                                 doublevector & logEframescorrect, doublevector & Eframescorrectbuf, double & logEframescorrecttotal, double & totalfwscore)
    {
        ondevice no (deviceid);
        latticefunctionsops::forwardbackwardlattice (
            batchsizeforward, batchsizebackward, numlaunchforward, numlaunchbackward,
            spalignunitid, silalignunitid,
            dynamic_cast<const floatvecimpl &> (edgeacscores),
            dynamic_cast<const edgevecimpl &> (edges),
            dynamic_cast<const nodevecimpl &> (nodes),
            dynamic_cast<const alignvecimpl &> (aligns),
            dynamic_cast<const ushortvecimpl &> (alignments),
            dynamic_cast<const uintvecimpl &> (alignoffsets),
            dynamic_cast<doublevecimpl &> (logpps),
            dynamic_cast<doublevecimpl &> (logalphas),
            dynamic_cast<doublevecimpl &> (logbetas),
            lmf, wp, amf, boostingfactor, returnEframescorrect,
            dynamic_cast<const ushortvecimpl &> (uids),
            dynamic_cast<const ushortvecimpl &> (senone2classmap),
            dynamic_cast<doublevecimpl &> (logaccalphas),
            dynamic_cast<doublevecimpl &> (logaccbetas),
            dynamic_cast<doublevecimpl &> (logframescorrectedge),
            dynamic_cast<doublevecimpl &> (logEframescorrect),
            dynamic_cast<doublevecimpl &> (Eframescorrectbuf),
            logEframescorrecttotal, totalfwscore);
    }

    // Forwards to latticefunctionsops::sMBRerrorsignal(); both matrices are handed over as matrixref<float> views.
    void sMBRerrorsignal (const ushortvector & alignstateids,
                          const uintvector & alignoffsets,
                          const edgeinfowithscoresvector & edges, const nodeinfovector & nodes,
                          const doublevector & logpps, const float amf, const doublevector & logEframescorrect,
                          const double logEframescorrecttotal, Microsoft::MSR::CNTK::Matrix<float>& dengammas, Microsoft::MSR::CNTK::Matrix<float>& dengammasbuf)
    {
        ondevice no (deviceid);
        matrixref<float> dengammasref = tomatrixref (dengammas);
        matrixref<float> dengammasbufref = tomatrixref (dengammasbuf);
        latticefunctionsops::sMBRerrorsignal (
            dynamic_cast<const ushortvecimpl &> (alignstateids),
            dynamic_cast<const uintvecimpl &> (alignoffsets),
            dynamic_cast<const edgevecimpl &> (edges),
            dynamic_cast<const nodevecimpl &> (nodes),
            dynamic_cast<const doublevecimpl &> (logpps),
            amf,
            dynamic_cast<const doublevecimpl &> (logEframescorrect),
            logEframescorrecttotal, dengammasref, dengammasbufref);
    }

    // Forwards to latticefunctionsops::mmierrorsignal(); dengammas is handed over as a matrixref<float> view.
    void mmierrorsignal (const ushortvector & alignstateids, const uintvector & alignoffsets,
                         const edgeinfowithscoresvector & edges, const nodeinfovector & nodes,
                         const doublevector & logpps, Microsoft::MSR::CNTK::Matrix<float>& dengammas)
    {
        ondevice no (deviceid);
        matrixref<float> dengammasref = tomatrixref (dengammas);
        latticefunctionsops::mmierrorsignal (
            dynamic_cast<const ushortvecimpl &> (alignstateids),
            dynamic_cast<const uintvecimpl &> (alignoffsets),
            dynamic_cast<const edgevecimpl &> (edges),
            dynamic_cast<const nodevecimpl &> (nodes),
            dynamic_cast<const doublevecimpl &> (logpps),
            dengammasref);
    }

    // Forwards to latticefunctionsops::stateposteriors(); logacc is handed over as a matrixref<float> view.
    void stateposteriors (const ushortvector & alignstateids, const uintvector & alignoffsets,
                          const edgeinfowithscoresvector & edges, const nodeinfovector & nodes,
                          const doublevector & logqs, Microsoft::MSR::CNTK::Matrix<float>& logacc)
    {
        ondevice no (deviceid);
        matrixref<float> logaccref = tomatrixref (logacc);
        latticefunctionsops::stateposteriors (
            dynamic_cast<const ushortvecimpl &> (alignstateids),
            dynamic_cast<const uintvecimpl &> (alignoffsets),
            dynamic_cast<const edgevecimpl &> (edges),
            dynamic_cast<const nodevecimpl &> (nodes),
            dynamic_cast<const doublevecimpl &> (logqs),
            logaccref);
    }
};
// Factory: creates a latticefunctionsimpl for device 'deviceid' and returns it through the
// 'latticefunctions' interface. The object is created with 'new' and returned as a raw pointer.
latticefunctions * newlatticefunctions(size_t deviceid)
{
    latticefunctions * instance = new latticefunctionsimpl(deviceid);
    return instance;
}
// implementation of lrhmmdefvector
// Class has no vector-level member functions, so no need for an extra type
// Factory: creates a vector of lrhmmdef for device 'deviceid' (raw pointer from 'new').
lrhmmdefvector * newlrhmmdefvector(size_t deviceid)
{
    typedef vectorbaseimpl<lrhmmdefvector, vectorref<lrhmmdef>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of lr3transP for device 'deviceid' (raw pointer from 'new').
lr3transPvector * newlr3transPvector(size_t deviceid)
{
    typedef vectorbaseimpl<lr3transPvector, vectorref<lr3transP>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of unsigned short for device 'deviceid' (raw pointer from 'new').
ushortvector * newushortvector(size_t deviceid)
{
    typedef vectorbaseimpl<ushortvector, vectorref<unsigned short>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of unsigned int for device 'deviceid' (raw pointer from 'new').
uintvector * newuintvector(size_t deviceid)
{
    typedef vectorbaseimpl<uintvector, vectorref<unsigned int>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of float for device 'deviceid' (raw pointer from 'new').
floatvector * newfloatvector(size_t deviceid)
{
    typedef vectorbaseimpl<floatvector, vectorref<float>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of double for device 'deviceid' (raw pointer from 'new').
doublevector * newdoublevector(size_t deviceid)
{
    typedef vectorbaseimpl<doublevector, vectorref<double>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of size_t for device 'deviceid' (raw pointer from 'new').
sizetvector * newsizetvector(size_t deviceid)
{
    typedef vectorbaseimpl<sizetvector, vectorref<size_t>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of nodeinfo for device 'deviceid' (raw pointer from 'new').
nodeinfovector * newnodeinfovector(size_t deviceid)
{
    typedef vectorbaseimpl<nodeinfovector, vectorref<nodeinfo>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of edgeinfowithscores for device 'deviceid' (raw pointer from 'new').
edgeinfowithscoresvector * newedgeinfovector(size_t deviceid)
{
    typedef vectorbaseimpl<edgeinfowithscoresvector, vectorref<edgeinfowithscores>> impltype;
    return new impltype(deviceid);
}
// Factory: creates a vector of aligninfo for device 'deviceid' (raw pointer from 'new').
aligninfovector * newaligninfovector(size_t deviceid)
{
    typedef vectorbaseimpl<aligninfovector, vectorref<aligninfo>> impltype;
    return new impltype(deviceid);
}
};};