https://github.com/Microsoft/CNTK
Raw File
Tip revision: 10a8ffcf50d7b9225f3236ffcfdc422b2014fb92 authored by microsoft-github-policy-service[bot] on 23 September 2022, 14:06:50 UTC
Microsoft mandatory file (#3870)
Tip revision: 10a8ffc
cudalattice.cpp
// cudalattice.cpp -- lattice forward/backward functions for CUDA execution (glue code)
//
// F. Seide, V-hansu

#define _CRT_SECURE_NO_WARNINGS // "secure" CRT not available on all platforms  --add this at the top of all CPP files that give "function or variable may be unsafe" warnings

#define DLLEXPORT
#define __kernel_emulation__         // allow the compilation of CUDA kernels on the CPU
#include "latticefunctionskernels.h" // for the actual inner kernels and any argument types that are not yet defined in latticestorage.h
#undef __kernel_emulation__
#include "cudalattice.h"    // this exports the class
#include "cudalatticeops.h" // brings in the actual lattice functions/kernels
#include "cudalib.h"        // generic CUDA helpers
#include "cudadevice.h"
#include <math.h>
#include <memory> // for auto_ptr
#include <assert.h>
#include <float.h>

namespace msra { namespace cuda {

extern void operator||(cudaError_t rc, const char *msg); // TODO: imported from cudamatrix.cpp --better move to cudalib.h

// this implements the basic operations of exported interface vectorbase<>, from which all vectors derive
// TODO: This really should not be in cudalattice, since it is more general; we need a cudavector.cpp/h
template <typename VECTORTYPE, typename OPSTYPE>
class vectorbaseimpl : public /*interface*/ VECTORTYPE, // user-type interface; must derive from vectorbase<VECTORBASE::elemtype>
                       public OPSTYPE,                  // type of class that implements the kernels; must derive from vectorref<VECTORBASE::elemtype>
                       public objectondevice            // setdevice()
{
    typedef typename VECTORTYPE::elemtype elemtype; // (for convenience)
    size_t capacity;                                // amount of allocated storage (like capacity() vs. vectorref::n = size())
    void release()
    {
        ondevice no(deviceid);
        free(this->reset(NULL, 0));
    }

public:
    vectorbaseimpl(size_t deviceid)
        : capacity(0), objectondevice(deviceid)
    {
    }
    ~vectorbaseimpl()
    {
        release();
    }
    void allocate(size_t sz)
    {
        if (sz > capacity) // need to grow
        {
            ondevice no(deviceid);                          // switch to desired CUDA card
            cuda_ptr<elemtype> pnew = malloc<elemtype>(sz); // allocate memory inside CUDA device (or throw)
            capacity = sz;                                  // if succeeded then: remember
            cuda_ptr<elemtype> p2 = this->reset(pnew, sz);  //  and swap the pointers and update n
            free(p2);                                       //  then release the old one
        }
        else // not growing: keep same allocation
            this->reset(this->get(), sz);
    }
    size_t size() const throw()
    {
        return vectorref<elemtype>::size();
    }
    void assign(const elemtype *p2, size_t nelem, bool synchronize)
    {
        allocate(nelem);       // assign will resize the target appropriately
        ondevice no(deviceid); // switch to desired CUDA card
        if (nelem > 0)
            memcpy(this->get(), 0, p2, nelem);
        if (synchronize)
            join();
    }
    void fetch(elemtype *p2, size_t nelem, bool synchronize) const
    {
        if (nelem != size()) // fetch() cannot resize the target; caller must do that
            LogicError("fetch: vector size mismatch");
        ondevice no(deviceid); // switch to desired CUDA card
        if (nelem > 0)
            memcpy(p2, this->get(), 0, nelem);
        if (synchronize)
            join();
    };
};

// ---------------------------------------------------------------------------
// glue code for lattice-related classes
// The XXXvectorimpl classes must derive from vectorbaseimpl<XXXvector,XXXvectorops>.
// For classes without kernels that operate on the vector, XXXvectorimpl is not
// needed, use vectorbaseimpl<XXXvector,vectorref<XXX>> instead, where
// XXXvector is an alias for vectorbase<XXX> (but better keep that alias in cudalattice.h
// to document which vectors are implemented).
// ---------------------------------------------------------------------------

matrixref<float> tomatrixref(const Microsoft::MSR::CNTK::Matrix<float> &m)
{
    return matrixref<float>(m.Data(), m.GetNumRows(), m.GetNumCols(), m.GetNumRows());
}

class latticefunctionsimpl : public vectorbaseimpl<latticefunctions, latticefunctionsops>
{
public:
    latticefunctionsimpl(size_t deviceid)
        : vectorbaseimpl(deviceid)
    {
    }

private:
    void edgealignment(const lrhmmdefvector &hmms, const lr3transPvector &transPs, const size_t spalignunitid,
                       const size_t silalignunitid, const Microsoft::MSR::CNTK::Matrix<float> &logLLs, const nodeinfovector &nodes,
                       const edgeinfowithscoresvector &edges, const aligninfovector &aligns,
                       const uintvector &alignoffsets, ushortvector &backptrstorage, const sizetvector &backptroffsets,
                       ushortvector &alignresult, floatvector &edgeacscores) // output
    {
        ondevice no(deviceid);

        matrixref<float> logLLsMatrixRef = tomatrixref(logLLs);
        latticefunctionsops::edgealignment(dynamic_cast<const vectorbaseimpl<lrhmmdefvector, vectorref<lrhmmdef>> &>(hmms),
                                           dynamic_cast<const vectorbaseimpl<lr3transPvector, vectorref<lr3transP>> &>(transPs),
                                           spalignunitid, silalignunitid, logLLsMatrixRef,
                                           dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
                                           dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
                                           dynamic_cast<const vectorbaseimpl<aligninfovector, vectorref<msra::lattices::aligninfo>> &>(aligns),
                                           dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
                                           dynamic_cast<vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(backptrstorage),
                                           dynamic_cast<const vectorbaseimpl<sizetvector, vectorref<size_t>> &>(backptroffsets),
                                           dynamic_cast<vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignresult),
                                           dynamic_cast<vectorbaseimpl<floatvector, vectorref<float>> &>(edgeacscores));
    }

    void forwardbackwardlattice(const size_t *batchsizeforward, const size_t *batchsizebackward,
                                const size_t numlaunchforward, const size_t numlaunchbackward,
                                const size_t spalignunitid, const size_t silalignunitid,
                                const floatvector &edgeacscores, const edgeinfowithscoresvector &edges,
                                const nodeinfovector &nodes, const aligninfovector &aligns,
                                const ushortvector &alignments, const uintvector &alignoffsets,
                                doublevector &logpps, doublevector &logalphas, doublevector &logbetas,
                                const float lmf, const float wp, const float amf, const float boostingfactor, const bool returnEframescorrect,
                                const ushortvector &uids, const ushortvector &senone2classmap, doublevector &logaccalphas,
                                doublevector &logaccbetas, doublevector &logframescorrectedge,
                                doublevector &logEframescorrect, doublevector &Eframescorrectbuf, double &logEframescorrecttotal, double &totalfwscore)
    {
        ondevice no(deviceid);
        latticefunctionsops::forwardbackwardlattice(batchsizeforward, batchsizebackward, numlaunchforward, numlaunchbackward,
                                                    spalignunitid, silalignunitid,
                                                    dynamic_cast<const vectorbaseimpl<floatvector, vectorref<float>> &>(edgeacscores),
                                                    dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
                                                    dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
                                                    dynamic_cast<const vectorbaseimpl<aligninfovector, vectorref<msra::lattices::aligninfo>> &>(aligns),
                                                    dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignments),
                                                    dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logpps),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logalphas),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logbetas),
                                                    lmf, wp, amf, boostingfactor, returnEframescorrect,
                                                    dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(uids),
                                                    dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(senone2classmap),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logaccalphas),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logaccbetas),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logframescorrectedge),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(logEframescorrect),
                                                    dynamic_cast<vectorbaseimpl<doublevector, vectorref<double>> &>(Eframescorrectbuf),
                                                    logEframescorrecttotal, totalfwscore);
    }

    void sMBRerrorsignal(const ushortvector &alignstateids,
                         const uintvector &alignoffsets,
                         const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
                         const doublevector &logpps, const float amf, const doublevector &logEframescorrect,
                         const double logEframescorrecttotal, Microsoft::MSR::CNTK::Matrix<float> &dengammas, Microsoft::MSR::CNTK::Matrix<float> &dengammasbuf)
    {
        ondevice no(deviceid);

        matrixref<float> dengammasMatrixRef = tomatrixref(dengammas);
        matrixref<float> dengammasbufMatrixRef = tomatrixref(dengammasbuf);
        latticefunctionsops::sMBRerrorsignal(dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignstateids),
                                             dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
                                             dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
                                             dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
                                             dynamic_cast<const vectorbaseimpl<doublevector, vectorref<double>> &>(logpps),
                                             amf,
                                             dynamic_cast<const vectorbaseimpl<doublevector, vectorref<double>> &>(logEframescorrect),
                                             logEframescorrecttotal, dengammasMatrixRef, dengammasbufMatrixRef);
    }

    void mmierrorsignal(const ushortvector &alignstateids, const uintvector &alignoffsets,
                        const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
                        const doublevector &logpps, Microsoft::MSR::CNTK::Matrix<float> &dengammas)
    {
        ondevice no(deviceid);

        matrixref<float> dengammasMatrixRef = tomatrixref(dengammas);
        latticefunctionsops::mmierrorsignal(dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignstateids),
                                            dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
                                            dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
                                            dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
                                            dynamic_cast<const vectorbaseimpl<doublevector, vectorref<double>> &>(logpps),
                                            dengammasMatrixRef);
    }

    void stateposteriors(const ushortvector &alignstateids, const uintvector &alignoffsets,
                         const edgeinfowithscoresvector &edges, const nodeinfovector &nodes,
                         const doublevector &logqs, Microsoft::MSR::CNTK::Matrix<float> &logacc)
    {
        ondevice no(deviceid);

        matrixref<float> logaccMatrixRef = tomatrixref(logacc);
        latticefunctionsops::stateposteriors(dynamic_cast<const vectorbaseimpl<ushortvector, vectorref<unsigned short>> &>(alignstateids),
                                             dynamic_cast<const vectorbaseimpl<uintvector, vectorref<unsigned int>> &>(alignoffsets),
                                             dynamic_cast<const vectorbaseimpl<edgeinfowithscoresvector, vectorref<msra::lattices::edgeinfowithscores>> &>(edges),
                                             dynamic_cast<const vectorbaseimpl<nodeinfovector, vectorref<msra::lattices::nodeinfo>> &>(nodes),
                                             dynamic_cast<const vectorbaseimpl<doublevector, vectorref<double>> &>(logqs),
                                             logaccMatrixRef);
    }
};

latticefunctions *newlatticefunctions(size_t deviceid)
{
    return new latticefunctionsimpl(deviceid);
}

// implementation of lrhmmdefvector
// Class has no vector-level member functions, so no need for an extra type
lrhmmdefvector *newlrhmmdefvector(size_t deviceid)
{
    return new vectorbaseimpl<lrhmmdefvector, vectorref<lrhmmdef>>(deviceid);
}
lr3transPvector *newlr3transPvector(size_t deviceid)
{
    return new vectorbaseimpl<lr3transPvector, vectorref<lr3transP>>(deviceid);
}
ushortvector *newushortvector(size_t deviceid)
{
    return new vectorbaseimpl<ushortvector, vectorref<unsigned short>>(deviceid);
}
uintvector *newuintvector(size_t deviceid)
{
    return new vectorbaseimpl<uintvector, vectorref<unsigned int>>(deviceid);
}
floatvector *newfloatvector(size_t deviceid)
{
    return new vectorbaseimpl<floatvector, vectorref<float>>(deviceid);
}
doublevector *newdoublevector(size_t deviceid)
{
    return new vectorbaseimpl<doublevector, vectorref<double>>(deviceid);
}
sizetvector *newsizetvector(size_t deviceid)
{
    return new vectorbaseimpl<sizetvector, vectorref<size_t>>(deviceid);
}
nodeinfovector *newnodeinfovector(size_t deviceid)
{
    return new vectorbaseimpl<nodeinfovector, vectorref<nodeinfo>>(deviceid);
}
edgeinfowithscoresvector *newedgeinfovector(size_t deviceid)
{
    return new vectorbaseimpl<edgeinfowithscoresvector, vectorref<edgeinfowithscores>>(deviceid);
}
aligninfovector *newaligninfovector(size_t deviceid)
{
    return new vectorbaseimpl<aligninfovector, vectorref<aligninfo>>(deviceid);
}
};
};
back to top