https://github.com/Microsoft/CNTK
Tip revision: 16a41cef30894ca92667bd93079cd6fa11b3e92d authored by Sayan Pathak on 02 November 2017, 16:10:10 UTC
Added super resolution tutorial contributed by Borna with added code to minimize test downloads, fix tests, added documentation and small editorial changes to LSGAN tutorial
Added super resolution tutorial contributed by Borna with added code to minimize test downloads, fix tests, added documentation and small editorial changes to LSGAN tutorial
Tip revision: 16a41ce
KaldiSequenceTrainingDerivative.cpp
#include "basetypes.h"
#include "htkfeatio_utils.h"
#include "KaldiSequenceTrainingDerivative.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Constructor.
//
// Stores the scaling/criterion configuration, parses the colon-separated
// silence phone list, loads the Kaldi transition model and opens the
// random-access readers for denominator lattices and alignments.
//
//   denlatRspecifier   - passed through fileToStr()/trimmed() to build the
//                        lattice reader's rspecifier (presumably a file
//                        holding the rspecifier text -- confirm with caller).
//   aliRspecifier      - same, for the alignment reader.
//   transModelFilename - handed to ReadKaldiObject() to fill m_transModel.
//   silencePhoneStr    - colon-separated integer list of silence phones.
//   trainCriterion     - must be L"mpfe" or L"smbr".
//   oldAcousticScale, acousticScale, lmScale, oneSilenceClass
//                      - stored verbatim for use in ComputeDerivative().
//
// Raises LogicError on an unparsable silence phone list or an unsupported
// training criterion.
//
// All validation happens before the readers are allocated: the destructor is
// not run for an object whose constructor throws, so allocating first would
// leak the readers on any validation failure.
template <class ElemType>
KaldiSequenceTrainingDerivative<ElemType>::KaldiSequenceTrainingDerivative(
    const wstring& denlatRspecifier, const wstring& aliRspecifier,
    const wstring& transModelFilename, const wstring& silencePhoneStr,
    const wstring& trainCriterion,
    ElemType oldAcousticScale, ElemType acousticScale,
    ElemType lmScale, bool oneSilenceClass)
{
    using namespace msra::asr;
    assert(denlatRspecifier != L"");
    assert(aliRspecifier != L"");

    // Keep the owning pointers in a well-defined state until the readers
    // are actually created at the end of the constructor.
    m_denlatReader = nullptr;
    m_aliReader = nullptr;

    m_oldAcousticScale = oldAcousticScale;
    m_acousticScale = acousticScale;
    m_lmScale = lmScale;
    m_trainCriterion = trainCriterion;
    m_oneSilenceClass = oneSilenceClass;

    if (!kaldi::SplitStringToIntegers(toStr(silencePhoneStr),
                                      ":", false, &m_silencePhones))
    {
        LogicError("Invalid silence phone sequence.\n");
    }
    if (m_trainCriterion != L"mpfe" && m_trainCriterion != L"smbr")
    {
        LogicError("Supported sequence training criterion: mpfe, smbr.\n");
    }

    ReadKaldiObject(toStr(transModelFilename), &m_transModel);

    // Acquire the raw-pointer resources last, so nothing that can fail
    // afterwards is left except the second allocation, handled below.
    m_denlatReader = new kaldi::RandomAccessCompactLatticeReader(
        trimmed(fileToStr(toStr(denlatRspecifier))));
    try
    {
        m_aliReader = new kaldi::RandomAccessInt32VectorReader(
            trimmed(fileToStr(toStr(aliRspecifier))));
    }
    catch (...)
    {
        // The destructor will not run; release the first reader by hand.
        delete m_denlatReader;
        m_denlatReader = nullptr;
        throw;
    }
}
// Destructor: frees the lattice and alignment readers that the constructor
// allocated with new.
template <class ElemType>
KaldiSequenceTrainingDerivative<ElemType>::~KaldiSequenceTrainingDerivative()
{
    // delete on a null pointer is a no-op, so no guards are required.
    delete m_denlatReader;
    delete m_aliReader;
}
// Computes the sequence-training derivative and objective for one utterance.
//
//   uttID         - utterance key used to look up the alignment and the
//                   denominator lattice.
//   logLikelihood - network output; its row count must equal the number of
//                   pdfs in the transition model and its column count must
//                   equal the alignment length (both are checked below).
//   derivative    - output; resized and filled by
//                   ConvertPosteriorToDerivative().
//   objective     - output; set to (number of columns of logLikelihood)
//                   minus the value returned by
//                   kaldi::LatticeForwardBackwardMpeVariants().
//
// Returns true; every failure path raises RuntimeError/LogicError instead.
template <class ElemType>
bool KaldiSequenceTrainingDerivative<ElemType>::ComputeDerivative(
    const wstring& uttID,
    const Matrix<ElemType>& logLikelihood,
    Matrix<ElemType>* derivative,
    ElemType* objective)
{
    std::string uttIDStr = msra::asr::toStr(uttID);

    // Sanity check.
    if (m_transModel.NumPdfs() != logLikelihood.GetNumRows())
    {
        RuntimeError("Number of labels in logLikelihood does not match that"
                     " in the Kaldi model for utterance %S: %d v.s. %d\n",
                     uttID.c_str(), (int) logLikelihood.GetNumRows(),
                     (int) m_transModel.NumPdfs());
    }

    // Reads alignment.
    if (!m_aliReader->HasKey(uttIDStr))
    {
        RuntimeError("Alignment not found for utterance %s\n",
                     uttIDStr.c_str());
    }
    const std::vector<int32> ali = m_aliReader->Value(uttIDStr);
    if (ali.size() != logLikelihood.GetNumCols())
    {
        RuntimeError("Number of frames in logLikelihood does not match that"
                     " in the alignment for utterance %S: %d v.s. %d\n",
                     uttID.c_str(), (int) logLikelihood.GetNumCols(), (int) ali.size());
    }

    // Reads denominator lattice.
    if (!m_denlatReader->HasKey(uttIDStr))
    {
        RuntimeError("Denominator lattice not found for utterance %S\n",
                     uttID.c_str());
    }
    kaldi::CompactLattice clat = m_denlatReader->Value(uttIDStr);
    fst::CreateSuperFinal(&clat); /* One final state with weight One() */
    kaldi::Lattice lat;
    fst::ConvertLattice(clat, &lat);

    // Does a first pass of acoustic scaling. Typically this sets the old
    // acoustic scale to 0.
    if (m_oldAcousticScale != 1.0)
    {
        fst::ScaleLattice(fst::AcousticLatticeScale(m_oldAcousticScale),
                          &lat);
    }

    // Topsort lattice.
    kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
    if (!(props & fst::kTopSorted))
    {
        if (fst::TopSort(&lat) == false)
        {
            RuntimeError("Cycles detected in lattice: %S\n", uttID.c_str());
        }
    }

    // Does lattice acoustic rescoring with the new posteriors from the
    // neural network.
    LatticeAcousticRescore(uttID, logLikelihood, &lat);

    // Second pass acoustic and language model scale.
    if (m_acousticScale != 1.0 || m_lmScale != 1.0)
    {
        fst::ScaleLattice(fst::LatticeScale(m_lmScale, m_acousticScale),
                          &lat);
    }

    // Picks the criterion name handed to Kaldi. The constructor already
    // rejects anything else, but failing loudly here guarantees *objective
    // is never read uninitialized below if that invariant is ever broken.
    std::string criterion;
    if (m_trainCriterion == L"smbr")
    {
        criterion = "smbr";
    }
    else if (m_trainCriterion == L"mpfe")
    {
        criterion = "mpfe";
    }
    else
    {
        LogicError("Supported sequence training criterion: mpfe, smbr.\n");
    }

    // Forward-backward on the lattice.
    kaldi::Posterior post;
    *objective = kaldi::LatticeForwardBackwardMpeVariants(
        m_transModel, m_silencePhones, lat,
        ali, criterion, m_oneSilenceClass, &post);

    ConvertPosteriorToDerivative(post, derivative);
    assert(derivative->GetNumCols() == logLikelihood.GetNumCols());

    // Uses "expected error rate" instead of "expected accuracy".
    *objective = logLikelihood.GetNumCols() - *objective;
    return true;
}
// Turns a Kaldi posterior into a (NumPdfs x numFrames) derivative matrix.
//
// The posterior is first mapped to pdf-ids with kaldi::ConvertPosteriorToPdfs;
// the output matrix is resized, zeroed, and then every (pdf, weight) entry of
// frame t is subtracted from element (pdf, t), i.e. the sign is flipped.
template <class ElemType>
void KaldiSequenceTrainingDerivative<ElemType>::ConvertPosteriorToDerivative(
    const kaldi::Posterior& post,
    Matrix<ElemType>* derivative)
{
    kaldi::Posterior framePdfPost;
    kaldi::ConvertPosteriorToPdfs(m_transModel, post, &framePdfPost);

    const size_t numFrames = framePdfPost.size();
    derivative->Resize(m_transModel.NumPdfs(), numFrames);
    derivative->SetValue(0);

    for (size_t frame = 0; frame < numFrames; ++frame)
    {
        for (const auto& entry : framePdfPost[frame])
        {
            const size_t pdfIndex = entry.first;
            assert(pdfIndex < m_transModel.NumPdfs());

            // Subtract rather than add: the derivative is the negated posterior.
            (*derivative)(pdfIndex, frame) -= entry.second;
        }
    }
}
// Rescores the acoustic part of every arc in the lattice with the supplied
// log-likelihoods.
//
//   uttID         - utterance key, used only in the error message.
//   logLikelihood - indexed as (pdf_id, frame); its column count must equal
//                   the lattice's maximum state time.
//   lat           - lattice modified in place.
//
// For each state whose time t is below the frame count, every outgoing arc
// with a non-zero input label gets -logLikelihood(pdf_id, t) added to the
// second component of its weight; a non-Zero final weight on such a state has
// its second component reset to 0.
//
// Raises RuntimeError when the lattice length and the number of frames differ.
template <class ElemType>
void KaldiSequenceTrainingDerivative<ElemType>::LatticeAcousticRescore(
    const wstring& uttID,
    const Matrix<ElemType>& logLikelihood,
    kaldi::Lattice* lat) const
{
    const size_t numFrames = logLikelihood.GetNumCols();

    // Gets time information for the lattice.
    std::vector<kaldi::int32> stateTimes;
    kaldi::int32 maxTime = kaldi::LatticeStateTimes(*lat, &stateTimes);
    if (maxTime != numFrames)
    {
        // Report the frame (column) count -- the quantity actually compared --
        // rather than the row count.
        RuntimeError("Number of frames in the logLikelihood does not match"
                     " that in the denominator lattice for utterance %S: %d vs. %d\n",
                     uttID.c_str(), (int) numFrames, (int) maxTime);
    }

    // Buckets the lattice states by their time index. States sitting at
    // time == numFrames have no frame to be rescored with and are skipped.
    std::vector<std::vector<kaldi::int32>> timeStateMap(numFrames);
    size_t num_states = lat->NumStates();
    for (size_t s = 0; s < num_states; s++)
    {
        assert(stateTimes[s] >= 0 && stateTimes[s] <= numFrames);
        if (stateTimes[s] < numFrames)
        {
            timeStateMap[stateTimes[s]].push_back(s);
        }
    }

    for (size_t t = 0; t < numFrames; ++t)
    {
        for (size_t i = 0; i < timeStateMap[t].size(); ++i)
        {
            kaldi::int32 state = timeStateMap[t][i];
            for (fst::MutableArcIterator<kaldi::Lattice> aiter(lat, state);
                 !aiter.Done(); aiter.Next())
            {
                kaldi::LatticeArc arc = aiter.Value();
                kaldi::int32 trans_id = arc.ilabel;
                if (trans_id != 0)
                {
                    // Non-zero input label: add the negated log-likelihood of
                    // the arc's pdf at this frame to the acoustic cost.
                    kaldi::int32 pdf_id =
                        m_transModel.TransitionIdToPdf(trans_id);
                    arc.weight.SetValue2(-logLikelihood(pdf_id, t) + arc.weight.Value2());
                    aiter.SetValue(arc);
                }
            }

            // Checks final state.
            kaldi::LatticeWeight final = lat->Final(state);
            if (final != kaldi::LatticeWeight::Zero())
            {
                final.SetValue2(0.0);
                lat->SetFinal(state, final);
            }
        }
    }
}
// Reports whether both an alignment and a denominator lattice are available
// for the given utterance.
//
// Returns false (after printing a warning to stderr) when either reader has
// not been created, and false when either reader lacks the utterance key;
// returns true otherwise.
template <class ElemType>
bool KaldiSequenceTrainingDerivative<ElemType>::HasResourceForDerivative(
    const wstring& uttID) const
{
    // Compare against nullptr: the bool literal `false` is not a valid null
    // pointer constant in C++11 and later.
    if (m_aliReader == nullptr || m_denlatReader == nullptr)
    {
        fprintf(stderr, "WARNING: lattice or alignemnt reader has not been"
                        " set up yet.\n");
        return false;
    }

    const std::string uttIDStr = msra::asr::toStr(uttID);
    return m_aliReader->HasKey(uttIDStr) && m_denlatReader->HasKey(uttIDStr);
}
// Explicit instantiations: the member definitions above live in this .cpp
// file, so the float and double specializations are generated here.
template class KaldiSequenceTrainingDerivative<float>;
template class KaldiSequenceTrainingDerivative<double>;
} } }