https://github.com/Microsoft/CNTK
Raw File
Tip revision: e2377424a28e9b54bcc0f92373c0d715ca2ed8a3 authored by Bowen Bao on 05 July 2018, 19:49:49 UTC
Moving sequential convolution in python to a new high level api, to maintain compatibility with previous implementation (special case 1d sequential convolution).
Tip revision: e237742
KaldiSequenceTrainingDerivative.h
#pragma once

#include "kaldi.h"
#include "Matrix.h"
#include "basetypes.h"
#include "UtteranceDerivativeComputationInterface.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// Handles the interaction with Kaldi that is needed to do sequence training
// in CNTK. Implements UtteranceDerivativeComputationInterface<ElemType>.
//
// NOTE(review): m_denlatReader and m_aliReader are raw pointers and the class
// declares a destructor. Ownership is not visible from this header -- confirm
// in the implementation file that they are released exactly once, and whether
// copying this object should be explicitly disabled.
template <typename ElemType>
class KaldiSequenceTrainingDerivative : public UtteranceDerivativeComputationInterface<ElemType>
{
public:
    // Constructs the object from its configuration.
    //   denlatRspecifier   - presumably a Kaldi rspecifier for the lattices
    //                        read through m_denlatReader (TODO confirm).
    //   aliRspecifier      - presumably a Kaldi rspecifier for the int32
    //                        vectors read through m_aliReader (TODO confirm).
    //   transModelFilename - presumably the file m_transModel is loaded from
    //                        (TODO confirm).
    //   silencePhoneStr    - textual description of the silence phones;
    //                        the exact format is defined by the implementation.
    //   trainCriterion     - name of the training criterion to use.
    //   oldAcousticScale, acousticScale, lmScale
    //                      - scaling factors; see the implementation for how
    //                        each one is applied.
    //   oneSilenceClass    - flag controlling silence handling; semantics are
    //                        defined by the implementation.
    KaldiSequenceTrainingDerivative(const wstring& denlatRspecifier,
                                    const wstring& aliRspecifier,
                                    const wstring& transModelFilename,
                                    const wstring& silencePhoneStr,
                                    const wstring& trainCriterion,
                                    ElemType oldAcousticScale,
                                    ElemType acousticScale,
                                    ElemType lmScale,
                                    bool oneSilenceClass);

    // Destructor.
    ~KaldiSequenceTrainingDerivative();

    // Computes the derivative and the objective for utterance <uttID> given
    // <logLikelihood>. Results are written through the <derivative> and
    // <objective> pointers. The meaning of the boolean result is defined by
    // the implementation (presumably true on success -- TODO confirm).
    bool ComputeDerivative(const wstring& uttID,
                           const Matrix<ElemType>& logLikelihood,
                           Matrix<ElemType>* derivative,
                           ElemType* objective);

    // Tells whether the resources needed to compute the derivative for
    // utterance <uttID> are available. Does not modify this object.
    bool HasResourceForDerivative(const wstring& uttID) const;

private:
    // Rescores the lattice <lat> with the latest posteriors (<outputs>) from
    // the neural network for utterance <uttID>.
    void LatticeAcousticRescore(const wstring& uttID,
                                const Matrix<ElemType>& outputs,
                                kaldi::Lattice* lat) const;

    // Converts the Kaldi posterior <post> into a derivative matrix, written
    // through <derivative>.
    void ConvertPosteriorToDerivative(const kaldi::Posterior& post,
                                      Matrix<ElemType>* derivative);

    // Configuration values. The declaration order below is kept as in the
    // original header because it determines member initialization order.
    bool m_oneSilenceClass;
    wstring m_trainCriterion;
    ElemType m_oldAcousticScale;
    ElemType m_acousticScale;
    ElemType m_lmScale;
    std::vector<kaldi::int32> m_silencePhones;

    // Kaldi-side state: the transition model and the two table readers.
    kaldi::TransitionModel m_transModel;
    kaldi::RandomAccessCompactLatticeReader* m_denlatReader;
    kaldi::RandomAccessInt32VectorReader* m_aliReader;
};
} } }
back to top