// https://github.com/Microsoft/CNTK
// Raw File
// Tip revision: f1930b09f010e793a590ab00fd44975df324aa2e authored by Wayne Xiong on 26 May 2017, 01:22:51 UTC
// Merge remote-tracking branch 'origin/master' into weixi/conttrain
// Tip revision: f1930b0
// KaldiSequenceTrainingDerivative.h
#pragma once

#include "kaldi.h"
#include "Matrix.h"
#include "basetypes.h"
#include "UtteranceDerivativeComputationInterface.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// This class handles the interaction with Kaldi that is needed to do sequence
// training in CNTK.
// Kaldi-backed implementation of UtteranceDerivativeComputationInterface.
//
// NOTE(review): only declarations are visible in this header. Remarks below
// that go beyond the signatures are marked as assumptions and should be
// confirmed against the implementation file.
template <class ElemType>
class KaldiSequenceTrainingDerivative : public UtteranceDerivativeComputationInterface<ElemType>
{
private:
    // Settings; each one has a constructor parameter of the matching name.
    bool m_oneSilenceClass;
    wstring m_trainCriterion;
    ElemType m_oldAcousticScale;
    ElemType m_acousticScale;
    ElemType m_lmScale;

    // Presumably parsed from the constructor's silencePhoneStr -- TODO confirm.
    std::vector<kaldi::int32> m_silencePhones;

    // Presumably loaded from the constructor's transModelFilename -- TODO confirm.
    kaldi::TransitionModel m_transModel;

    // Kaldi table readers, held by raw pointer.
    // NOTE(review): a destructor is declared for this class, which suggests
    // that these readers are owned here and released in it -- confirm in the
    // implementation file.
    kaldi::RandomAccessCompactLatticeReader* m_denlatReader;
    kaldi::RandomAccessInt32VectorReader* m_aliReader;

    // Rescores the lattice with the latest posteriors from the neural network.
    //   uttID   - ID of the utterance the lattice belongs to.
    //   outputs - neural network outputs used for the rescoring.
    //   lat     - lattice that is rescored in place.
    void LatticeAcousticRescore(const wstring& uttID,
                                const Matrix<ElemType>& outputs,
                                kaldi::Lattice* lat) const;

    // Converts a Kaldi posterior into a derivative matrix.
    //   post       - input posterior.
    //   derivative - output matrix that receives the result.
    void ConvertPosteriorToDerivative(const kaldi::Posterior& post,
                                      Matrix<ElemType>* derivative);

public:
    // Constructor.
    //   denlatRspecifier   - Kaldi rspecifier; presumably for the denominator
    //                        lattices read through m_denlatReader -- TODO confirm.
    //   aliRspecifier      - Kaldi rspecifier; presumably for the alignments
    //                        read through m_aliReader -- TODO confirm.
    //   transModelFilename - file name of the transition model.
    //   silencePhoneStr    - silence phones, given as a string.
    //   trainCriterion     - name of the training criterion.
    //   oldAcousticScale, acousticScale, lmScale - scaling factors.
    //   oneSilenceClass    - flag stored in m_oneSilenceClass.
    KaldiSequenceTrainingDerivative(const wstring& denlatRspecifier,
                                    const wstring& aliRspecifier,
                                    const wstring& transModelFilename,
                                    const wstring& silencePhoneStr,
                                    const wstring& trainCriterion,
                                    ElemType oldAcousticScale,
                                    ElemType acousticScale,
                                    ElemType lmScale,
                                    bool oneSilenceClass);

    // Not copyable. The class holds raw reader pointers and declares a
    // destructor; a memberwise copy would leave two instances holding the same
    // pointers, so copying is disabled rather than left implicit.
    KaldiSequenceTrainingDerivative(const KaldiSequenceTrainingDerivative&) = delete;
    KaldiSequenceTrainingDerivative& operator=(const KaldiSequenceTrainingDerivative&) = delete;

    // Destructor.
    ~KaldiSequenceTrainingDerivative();

    // Computes the derivative and the objective for one utterance.
    //   uttID         - ID of the utterance.
    //   logLikelihood - input log-likelihood matrix for the utterance.
    //   derivative    - output: receives the derivative.
    //   objective     - output: receives the objective value.
    // NOTE(review): the return value presumably signals success -- confirm.
    bool ComputeDerivative(const wstring& uttID,
                           const Matrix<ElemType>& logLikelihood,
                           Matrix<ElemType>* derivative,
                           ElemType* objective);

    // Reports whether the resources needed to compute the derivative for
    // uttID are available.
    // NOTE(review): presumably this checks the two readers for uttID -- confirm.
    bool HasResourceForDerivative(const wstring& uttID) const;
};
} } }
// back to top