// // // Copyright (c) Microsoft Corporation. All rights reserved. // // // #include "stdafx.h" #include #include "Basics.h" #include #include #define DATAWRITER_EXPORTS // creating the exports here #include "DataWriter.h" #include "SequenceReader.h" #include "SequenceWriter.h" #include "commandArgUtil.h" #ifdef LEAKDETECT #include // for memory leak detection #endif namespace Microsoft { namespace MSR { namespace CNTK { // Create a Data Writer //DATAWRITER_API IDataWriter* DataWriterFactory(void) // comparison, not case sensitive. template bool LMSequenceWriter::compare_val(const ElemType& first, const ElemType& second) { return (first < second); } template void LMSequenceWriter::Init(const ConfigParameters& writerConfig) { udims.clear(); ConfigArray outputNames = writerConfig("outputNodeNames", ""); if (outputNames.size()<1) RuntimeError("writer needs at least one outputNodeName specified in config"); foreach_index(i, outputNames) // inputNames should map to node names { ConfigParameters thisOutput = writerConfig(outputNames[i]); outputFiles[outputNames[i]] = thisOutput("file"); int iN = thisOutput("nbest", "1"); nBests[outputNames[i]] = iN; wstring fname = thisOutput("token"); /// read unk sybol mUnk[outputNames[i]] = writerConfig("unk", ""); SequenceReader::ReadClassInfo(fname, class_size, word4idx[outputNames[i]], idx4word[outputNames[i]], idx4class[outputNames[i]], idx4cnt[outputNames[i]], 0, mUnk[outputNames[i]], m_noiseSampler, false); size_t dim = idx4word[outputNames[i]].size(); udims.push_back(dim); } } template void LMSequenceWriter::ReadLabelInfo(const wstring & vocfile, map & word4idx, map& idx4word) { char strFileName[MAX_STRING]; char stmp[MAX_STRING]; string strtmp; size_t sz; int b; wcstombs_s(&sz, strFileName, 2048, vocfile.c_str(), vocfile.length()); FILE * vin; vin = fopen(strFileName, "rt"); if (vin == nullptr) { RuntimeError("cannot open word class file"); } b = 0; while (!feof(vin)){ fscanf_s(vin, "%s\n", stmp, _countof(stmp)); word4idx[stmp] = b; idx4word[b++] = stmp; } fclose(vin); } template void LMSequenceWriter::Destroy() { for (auto ptr = outputFileIds.begin(); ptr != outputFileIds.end(); ptr++) { fclose(ptr->second); } } template bool LMSequenceWriter::SaveData(size_t /*recordStart*/, const std::map& matrices, size_t /*numRecords*/, size_t /*datasetSize*/, size_t /*byteVariableSized*/) { for (auto iter = matrices.begin(); iter != matrices.end(); iter++) { string outputName = ws2s(iter->first); Matrix& outputData = *(static_cast*>(iter->second)); wstring outFile = outputFiles[s2ws(outputName)]; SaveToFile(outFile, outputData, idx4word[iter->first], nBests[outputName]); } return true; } template void LMSequenceWriter::SaveToFile(std::wstring& outputFile, const Matrix& outputData, const map& idx2wrd, const int& nbest) { size_t nT = outputData.GetNumCols(); size_t nD = min(idx2wrd.size(), outputData.GetNumRows()); FILE *fp = nullptr; vector> lv; auto NbestComparator = [](const pair& lv, const pair& rv){return lv.second > rv.second; }; if (outputFileIds.find(outputFile) == outputFileIds.end()) { FILE* ofs; msra::files::make_intermediate_dirs(outputFile); string str(outputFile.begin(), outputFile.end()); ofs = fopen(str.c_str(), "wt"); if (ofs == nullptr) RuntimeError("Cannot open %s for writing", str.c_str()); outputFileIds[outputFile] = ofs; fp = ofs; } else fp = outputFileIds[outputFile]; for (int j = 0; j< nT; j++) { int imax = 0; ElemType fmax = outputData(imax, j); lv.clear(); if (nbest > 1) lv.push_back(pair(0, fmax)); for (int i = 1; i 1) lv.push_back(pair(i, outputData(i, j))); if (outputData(i, j) > fmax) { fmax = outputData(i, j); imax = i; } } if (nbest > 1) sort(lv.begin(), lv.end(), NbestComparator); for (int i = 0; i < nbest; i++) { if (nbest > 1) { if (lv[i].second != 0) { int idx = (int)lv[i].first; string sRes = idx2wrd.find(idx)->second; fprintf(fp, "%s ", sRes.c_str()); } } else { string sRes = idx2wrd.find(imax)->second; fprintf(fp, "%s ", sRes.c_str()); fprintf(stderr, "%s ", sRes.c_str()); } } } fprintf(fp, "\n"); fprintf(stderr, "\n"); } template class LMSequenceWriter; template class LMSequenceWriter; } } }