https://github.com/Microsoft/CNTK
Raw File
Tip revision: bb18eb24c7ad17ebca1cfa1fc94d9cd2eff52774 authored by Mark Hillebrand on 18 January 2016, 08:37:59 UTC
License change
Tip revision: bb18eb2
cudalattice.cpp
// cudalattice.cpp -- lattice forward/backward functions for CUDA execution (glue code)
//
// F. Seide, V-hansu

#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms  --add this at the top of all CPP files that give "function or variable may be unsafe" warnings

#define DLLEXPORT
#define __kernel_emulation__    // allow the compilation of CUDA kernels on the CPU
#include "latticefunctionskernels.h"    // for the actual inner kernels and any argument types that are not yet defined in latticestorage.h
#undef __kernel_emulation__
#include "cudalattice.h"        // this exports the class
#include "cudalatticeops.h"     // brings in the actual lattice functions/kernels
#include "cudalib.h"            // generic CUDA helpers
#include "cudadevice.h"
#include <math.h>
#include <memory>               // for auto_ptr
#include <assert.h>
#include <float.h>

namespace msra { namespace cuda {

// Error-check operator for CUDA runtime return codes, defined in cudamatrix.cpp and re-declared here
// (it is not yet exported from cudalib.h, per the TODO). Presumably used as `cudaCall(...) || "message"`
// to report a failing cudaError_t -- confirm against cudamatrix.cpp. Not referenced in the visible part of this file.
extern void operator|| (cudaError_t rc, const char * msg);      // TODO: imported from cudamatrix.cpp --better move to cudalib.h

// vectorbaseimpl -- implements the basic operations of the exported interface vectorbase<>, from which all vectors derive.
// The object owns its device buffer. allocate() obtains the buffer with malloc<>() while the object's CUDA device
// ('deviceid') is selected, then installs it into the OPSTYPE (vectorref) part via reset(). The destructor frees it.
// TODO: This really should not be in cudalattice, since it is more general; we need a cudavector.cpp/h
template<typename VECTORTYPE,typename OPSTYPE> class vectorbaseimpl :
    public /*interface*/VECTORTYPE,                 // user-type interface; must derive from vectorbase<VECTORTYPE::elemtype>
    public OPSTYPE,                                 // type of class that implements the kernels; must derive from vectorref<VECTORTYPE::elemtype>
    public objectondevice                           // setdevice()
{
    typedef typename VECTORTYPE::elemtype elemtype; // (for convenience)
    size_t capacity;                                // amount of allocated storage (like capacity() vs. vectorref::n = size())

    // Free the device buffer (with this object's device selected) and leave the vector empty.
    void release()
    {
        ondevice no (deviceid);
        free (this->reset (NULL, 0));
        capacity = 0;                               // storage is gone: keep 'capacity' in sync so a later allocate() re-mallocs
    }
public:
    // Base classes are always initialized before data members, so objectondevice is listed first
    // to match the actual initialization order (avoids -Wreorder).
    vectorbaseimpl(size_t deviceid) : objectondevice(deviceid), capacity(0) {}
    ~vectorbaseimpl() { release(); }

    // Set the logical size to 'sz' elements.
    // If sz exceeds the current capacity, a new device buffer of exactly sz elements is allocated and the old
    // one is freed. The old contents are NOT copied over. Otherwise the existing buffer is kept and only the
    // size is updated.
    void allocate (size_t sz)
    {
        if (sz > capacity)                                      // need to grow
        {
            ondevice no (deviceid);                             // switch to desired CUDA card
            cuda_ptr<elemtype> pnew = malloc<elemtype> (sz);    // allocate memory inside CUDA device (or throw)
            capacity = sz;                                      // if succeeded then: remember
            cuda_ptr<elemtype> p = this->reset (pnew, sz);      //  and swap the pointers and update n
            free (p);                                           //  then release the old one
        }
        else                                                    // not growing: keep same allocation
            this->reset (this->get(), sz);
    }

    // Number of elements currently in the vector (the vectorref size, not the allocated capacity).
    size_t size() const throw() { return vectorref<elemtype>::size(); }

    // Copy 'nelem' host elements from 'p' into the device vector, resizing it to nelem first.
    // If 'synchronize' is set, join() is called afterwards (presumably waits for the copy to finish -- TODO confirm).
    void assign (const elemtype * p, size_t nelem, bool synchronize)
    {
        allocate (nelem);           // assign will resize the target appropriately
        ondevice no (deviceid);     // switch to desired CUDA card
        if (nelem > 0)
            memcpy (this->get(), 0, p, nelem);
        if (synchronize)
            join();
    }

    // Copy the device vector into the host buffer 'p'. The buffer must hold exactly size() elements;
    // otherwise LogicError is raised, because fetch() cannot resize the caller's buffer.
    void fetch (elemtype * p, size_t nelem, bool synchronize) const
    {
        if (nelem != size())        // fetch() cannot resize the target; caller must do that
            LogicError("fetch: vector size mismatch");
        ondevice no (deviceid);     // switch to desired CUDA card
        if (nelem > 0)
            memcpy (p, this->get(), 0, nelem);
        if (synchronize)
            join();
    }
};

// ---------------------------------------------------------------------------
// glue code for lattice-related classes
// The XXXvectorimpl classes must derive from vectorbaseimpl<XXXvector,XXXvectorops>.
// Classes without kernels that operate on the vector do not need an XXXvectorimpl;
// use vectorbaseimpl<XXXvector,vectorref<XXX>> instead, where XXXvector is an
// alias for vectorbase<XXX>. Keep that alias in cudalattice.h so that it documents
// which vectors are implemented.
// ---------------------------------------------------------------------------

// Wrap a CNTK float matrix as a matrixref<float> over the same buffer (m.BufferPointer()).
// The last argument passed is the row count, i.e. the matrix is presumably assumed to be dense
// with columns stored back to back (no padding) -- TODO confirm for all callers.
matrixref<float> tomatrixref(const Microsoft::MSR::CNTK::Matrix<float>& m)
{
    const auto rows = m.GetNumRows();
    const auto cols = m.GetNumCols();
    return matrixref<float>(m.BufferPointer(), rows, cols, rows);
}

// latticefunctionsimpl -- host-side glue that forwards each latticefunctions entry point to the
// matching latticefunctionsops routine. Every call first selects this object's CUDA device
// (ondevice). It then downcasts the interface-typed vector arguments to their concrete
// vectorbaseimpl<XXXvector, vectorref<XXX>> types; this is a reference dynamic_cast, so a mismatched
// argument throws std::bad_cast. CNTK matrices are wrapped as matrixref<float> via tomatrixref().
class latticefunctionsimpl : public vectorbaseimpl<latticefunctions,latticefunctionsops>
{
public:
    latticefunctionsimpl(size_t deviceid) : vectorbaseimpl(deviceid) {}

private:
    // Downcast a read-only interface vector to its implementation type; the element type E is given explicitly.
    template<typename E, typename V>
    static const vectorbaseimpl<V, vectorref<E>> & asimpl (const V & v)
    {
        return dynamic_cast<const vectorbaseimpl<V, vectorref<E>> &> (v);
    }

    // Same as asimpl(), for vectors that the kernels write to.
    template<typename E, typename V>
    static vectorbaseimpl<V, vectorref<E>> & asmutableimpl (V & v)
    {
        return dynamic_cast<vectorbaseimpl<V, vectorref<E>> &> (v);
    }

    void edgealignment (const lrhmmdefvector & hmms, const lr3transPvector & transPs, const size_t spalignunitid, 
                        const size_t silalignunitid, const Microsoft::MSR::CNTK::Matrix<float>& logLLs, const nodeinfovector & nodes, 
                        const edgeinfowithscoresvector & edges, const aligninfovector & aligns,
                        const uintvector & alignoffsets, ushortvector & backptrstorage, const sizetvector & backptroffsets,
                        ushortvector & alignresult, floatvector & edgeacscores)         // output
    {
        ondevice no (deviceid);
        matrixref<float> logLLsref = tomatrixref (logLLs);
        latticefunctionsops::edgealignment (asimpl<lrhmmdef> (hmms),
                                            asimpl<lr3transP> (transPs),
                                            spalignunitid, silalignunitid, logLLsref,
                                            asimpl<msra::lattices::nodeinfo> (nodes),
                                            asimpl<msra::lattices::edgeinfowithscores> (edges),
                                            asimpl<msra::lattices::aligninfo> (aligns),
                                            asimpl<unsigned int> (alignoffsets),
                                            asmutableimpl<unsigned short> (backptrstorage),
                                            asimpl<size_t> (backptroffsets),
                                            asmutableimpl<unsigned short> (alignresult),
                                            asmutableimpl<float> (edgeacscores));
    }

    void forwardbackwardlattice (const size_t * batchsizeforward, const size_t * batchsizebackward, 
                                 const size_t numlaunchforward, const size_t numlaunchbackward,
                                 const size_t spalignunitid, const size_t silalignunitid,
                                 const floatvector & edgeacscores, const edgeinfowithscoresvector & edges, 
                                 const nodeinfovector & nodes, const aligninfovector & aligns, 
                                 const ushortvector & alignments, const uintvector & alignoffsets,
                                 doublevector & logpps, doublevector & logalphas, doublevector & logbetas,
                                 const float lmf, const float wp, const float amf, const float boostingfactor, const bool returnEframescorrect, 
                                 const ushortvector & uids, const ushortvector & senone2classmap, doublevector & logaccalphas, 
                                 doublevector & logaccbetas, doublevector & logframescorrectedge,
                                 doublevector & logEframescorrect, doublevector & Eframescorrectbuf, double & logEframescorrecttotal, double & totalfwscore)
    {
        ondevice no (deviceid);
        latticefunctionsops::forwardbackwardlattice (batchsizeforward, batchsizebackward, numlaunchforward, numlaunchbackward,
                                                     spalignunitid, silalignunitid,
                                                     asimpl<float> (edgeacscores),
                                                     asimpl<msra::lattices::edgeinfowithscores> (edges),
                                                     asimpl<msra::lattices::nodeinfo> (nodes),
                                                     asimpl<msra::lattices::aligninfo> (aligns),
                                                     asimpl<unsigned short> (alignments),
                                                     asimpl<unsigned int> (alignoffsets),
                                                     asmutableimpl<double> (logpps),
                                                     asmutableimpl<double> (logalphas),
                                                     asmutableimpl<double> (logbetas),
                                                     lmf, wp, amf, boostingfactor, returnEframescorrect,
                                                     asimpl<unsigned short> (uids),
                                                     asimpl<unsigned short> (senone2classmap),
                                                     asmutableimpl<double> (logaccalphas),
                                                     asmutableimpl<double> (logaccbetas),
                                                     asmutableimpl<double> (logframescorrectedge),
                                                     asmutableimpl<double> (logEframescorrect),
                                                     asmutableimpl<double> (Eframescorrectbuf),
                                                     logEframescorrecttotal, totalfwscore);
    }

    void sMBRerrorsignal (const ushortvector & alignstateids, 
                          const uintvector & alignoffsets,
                          const edgeinfowithscoresvector & edges, const nodeinfovector & nodes, 
                          const doublevector & logpps, const float amf, const doublevector & logEframescorrect,
                          const double logEframescorrecttotal, Microsoft::MSR::CNTK::Matrix<float>& dengammas, Microsoft::MSR::CNTK::Matrix<float>& dengammasbuf)
    {
        ondevice no (deviceid);
        matrixref<float> dengammasref = tomatrixref (dengammas);
        matrixref<float> dengammasbufref = tomatrixref (dengammasbuf);
        latticefunctionsops::sMBRerrorsignal (asimpl<unsigned short> (alignstateids),
                                              asimpl<unsigned int> (alignoffsets),
                                              asimpl<msra::lattices::edgeinfowithscores> (edges),
                                              asimpl<msra::lattices::nodeinfo> (nodes),
                                              asimpl<double> (logpps),
                                              amf,
                                              asimpl<double> (logEframescorrect),
                                              logEframescorrecttotal, dengammasref, dengammasbufref);
    }

    void mmierrorsignal (const ushortvector & alignstateids, const uintvector & alignoffsets,
                         const edgeinfowithscoresvector & edges, const nodeinfovector & nodes, 
                         const doublevector & logpps, Microsoft::MSR::CNTK::Matrix<float>& dengammas)
    {
        ondevice no (deviceid);
        matrixref<float> dengammasref = tomatrixref (dengammas);
        latticefunctionsops::mmierrorsignal (asimpl<unsigned short> (alignstateids),
                                             asimpl<unsigned int> (alignoffsets),
                                             asimpl<msra::lattices::edgeinfowithscores> (edges),
                                             asimpl<msra::lattices::nodeinfo> (nodes),
                                             asimpl<double> (logpps),
                                             dengammasref);
    }

    void stateposteriors (const ushortvector & alignstateids, const uintvector & alignoffsets,
                          const edgeinfowithscoresvector & edges, const nodeinfovector & nodes, 
                          const doublevector & logqs, Microsoft::MSR::CNTK::Matrix<float>& logacc)
    {
        ondevice no (deviceid);
        matrixref<float> logaccref = tomatrixref (logacc);
        latticefunctionsops::stateposteriors (asimpl<unsigned short> (alignstateids),
                                              asimpl<unsigned int> (alignoffsets),
                                              asimpl<msra::lattices::edgeinfowithscores> (edges),
                                              asimpl<msra::lattices::nodeinfo> (nodes),
                                              asimpl<double> (logqs),
                                              logaccref);
    }
};

// Factory: create the CUDA implementation of the latticefunctions interface bound to device 'deviceid'.
// The caller takes ownership of the returned heap object.
latticefunctions * newlatticefunctions(size_t deviceid)
{
    latticefunctionsimpl * const impl = new latticefunctionsimpl(deviceid);
    return impl;
}

// implementation of lrhmmdefvector
// Class has no vector-level member functions, so no need for an extra type
lrhmmdefvector * newlrhmmdefvector(size_t deviceid) { return new vectorbaseimpl<lrhmmdefvector, vectorref<lrhmmdef>>(deviceid); }
lr3transPvector * newlr3transPvector(size_t deviceid) { return new vectorbaseimpl<lr3transPvector, vectorref<lr3transP>>(deviceid); }
ushortvector * newushortvector(size_t deviceid) { return new vectorbaseimpl<ushortvector, vectorref<unsigned short>>(deviceid); }
uintvector * newuintvector(size_t deviceid) { return new vectorbaseimpl<uintvector, vectorref<unsigned int>>(deviceid); }
floatvector * newfloatvector(size_t deviceid) { return new vectorbaseimpl<floatvector, vectorref<float>>(deviceid); }
doublevector * newdoublevector(size_t deviceid) { return new vectorbaseimpl<doublevector, vectorref<double>>(deviceid); }
sizetvector * newsizetvector(size_t deviceid) { return new vectorbaseimpl<sizetvector, vectorref<size_t>>(deviceid); }
nodeinfovector * newnodeinfovector(size_t deviceid) { return new vectorbaseimpl<nodeinfovector, vectorref<nodeinfo>>(deviceid); }
edgeinfowithscoresvector * newedgeinfovector(size_t deviceid) { return new vectorbaseimpl<edgeinfowithscoresvector, vectorref<edgeinfowithscores>>(deviceid); }
aligninfovector * newaligninfovector(size_t deviceid) { return new vectorbaseimpl<aligninfovector, vectorref<aligninfo>>(deviceid); }

};};
back to top