https://github.com/Microsoft/CNTK
Raw File
Tip revision: 517947be3d5bc31a5cefa6472781c3b6cc905855 authored by Mark Hillebrand on 18 January 2016, 08:36:51 UTC
License change
Tip revision: 517947b
SequenceWriter.cpp
//
// <copyright file="LMSequenceWriter.cpp" company="Microsoft">
//     Copyright (c) Microsoft Corporation.  All rights reserved.
// </copyright>
//

//

#include "stdafx.h"
#include <objbase.h>
#include "Basics.h"
#include <fstream>
#include <algorithm>

#define DATAWRITER_EXPORTS  // creating the exports here
#include "DataWriter.h"
#include "SequenceReader.h"
#include "SequenceWriter.h"
#include "commandArgUtil.h"
#ifdef LEAKDETECT
#include <vld.h> // for memory leak detection
#endif

namespace Microsoft {
    namespace MSR {
        namespace CNTK {

            // Create a Data Writer
            //DATAWRITER_API IDataWriter* DataWriterFactory(void)


            // Strict "less than" ordering of two element values.
            // Returns true when first compares below second, false otherwise
            // (including when the two compare equal).
            template<class ElemType>
            bool LMSequenceWriter<ElemType>::compare_val(const ElemType& first, const ElemType& second)
            {
                const bool firstIsSmaller = first < second;
                return firstIsSmaller;
            }

            // Configures the writer from writerConfig.
            //
            // Every name listed under "outputNodeNames" must have a sub-section of the
            // same name, from which the following entries are read:
            //   file  - stored in outputFiles for that node
            //   nbest - integer, default "1", stored in nBests for that node
            //   token - file name handed to SequenceReader::ReadClassInfo, which fills
            //           the per-node word/index/class/count tables
            // The unknown-word symbol comes from the top-level "unk" entry of
            // writerConfig (default "<unk>"), not from the node's sub-section.
            // After loading, the size of the node's index-to-word table is appended to udims.
            //
            // Raises RuntimeError when no output node name is configured.
            template<class ElemType>
            void LMSequenceWriter<ElemType>::Init(const ConfigParameters& writerConfig)
            {
                udims.clear();

                ConfigArray outputNames = writerConfig("outputNodeNames", "");
                if (outputNames.size() < 1)
                    RuntimeError("writer needs at least one outputNodeName specified in config");

                foreach_index(i, outputNames)
                {
                    // name of the output node; also the key of all per-node tables
                    auto&& nodeName = outputNames[i];

                    ConfigParameters nodeSection = writerConfig(nodeName);
                    outputFiles[nodeName] = nodeSection("file");

                    int nbestCount = nodeSection("nbest", "1");
                    nBests[nodeName] = nbestCount;

                    wstring tokenFile = nodeSection("token");

                    // unknown-word symbol
                    mUnk[nodeName] = writerConfig("unk", "<unk>");

                    SequenceReader<ElemType>::ReadClassInfo(
                        tokenFile,
                        class_size,
                        word4idx[nodeName],
                        idx4word[nodeName],
                        idx4class[nodeName],
                        idx4cnt[nodeName],
                        0,
                        mUnk[nodeName],
                        m_noiseSampler,
                        false);

                    const size_t vocabSize = idx4word[nodeName].size();
                    udims.push_back(vocabSize);
                }
            }

            // Loads a vocabulary file consisting of whitespace-delimited tokens and
            // numbers the tokens in file order, starting at 0.
            //
            // vocfile  - path of the vocabulary file; converted to a multibyte string
            //            before it is opened in text mode.
            // word4idx - receives token -> index. A token that occurs more than once
            //            ends up mapped to the index of its last occurrence.
            // idx4word - receives index -> token.
            // Entries already present in the two maps are not removed.
            //
            // Raises RuntimeError if the path cannot be converted or the file cannot be opened.
            template<class ElemType>
            void LMSequenceWriter<ElemType>::ReadLabelInfo(const wstring & vocfile,
                map<string, int> & word4idx,
                map<int, string>& idx4word)
            {
                char strFileName[MAX_STRING];
                char stmp[MAX_STRING];
                size_t sz = 0;

                // Size the conversion from the buffer itself. The count argument is a
                // limit in *bytes*, so using the wide-character length would cut off
                // paths containing multibyte characters.
                if (wcstombs_s(&sz, strFileName, _countof(strFileName), vocfile.c_str(), _countof(strFileName) - 1) != 0)
                {
                    RuntimeError("cannot convert word class file name");
                }

                FILE * vin = fopen(strFileName, "rt");
                if (vin == nullptr)
                {
                    RuntimeError("cannot open word class file");
                }

                // Drive the loop by the read result rather than feof(): feof() only
                // turns true after a read has already failed, which would otherwise
                // insert an uninitialised/stale buffer (e.g. for an empty file).
                int b = 0;
                while (fscanf_s(vin, "%s\n", stmp, static_cast<unsigned>(_countof(stmp))) == 1)
                {
                    word4idx[stmp] = b;
                    idx4word[b++] = stmp;
                }
                fclose(vin);
            }

            // Closes every output file this writer has opened and forgets the handles.
            // Safe to call more than once: after the first call outputFileIds is empty,
            // so no FILE* is closed twice and a later SaveToFile reopens its file
            // instead of writing to a closed stream.
            template<class ElemType>
            void LMSequenceWriter<ElemType>::Destroy()
            {
                for (auto& entry : outputFileIds)
                {
                    if (entry.second != nullptr)
                        fclose(entry.second);
                }
                // the stored pointers are dangling from here on; drop them
                outputFileIds.clear();
            }

            // Writes each matrix in `matrices` to the output file registered for its node.
            //
            // matrices - node name -> pointer that is reinterpreted as Matrix<ElemType>*.
            // The remaining parameters are unused.
            //
            // For every entry the node name selects the target file (outputFiles), the
            // index-to-word table (idx4word) and the n-best count (nBests); the actual
            // writing is delegated to SaveToFile. Lookups use operator[], so a name
            // that was never configured yields default-constructed values.
            //
            // Always returns true.
            template<class ElemType>
            bool LMSequenceWriter<ElemType>::SaveData(size_t /*recordStart*/, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t /*numRecords*/, size_t /*datasetSize*/, size_t /*byteVariableSized*/)
            {
                for (const auto& entry : matrices)
                {
                    string nodeName = ws2s(entry.first);
                    wstring targetFile = outputFiles[s2ws(nodeName)];
                    Matrix<ElemType>* pScores = static_cast<Matrix<ElemType>*>(entry.second);

                    SaveToFile(targetFile, *pScores, idx4word[entry.first], nBests[nodeName]);
                }

                return true;
            }

            // Writes, for every column of outputData, the best-scoring word(s) to outputFile.
            //
            // outputFile - path of the text file. It is opened in "wt" mode the first time
            //              it is seen (intermediate directories are created) and the handle
            //              is cached in outputFileIds; later calls reuse the open stream.
            // outputData - score matrix; only the first min(idx2wrd.size(), rows) rows
            //              of each column are considered.
            // idx2wrd    - row index -> word.
            // nbest      - number of words to emit per column:
            //                nbest == 1 : the highest-scoring word, written to the file and
            //                             echoed to stderr;
            //                nbest  > 1 : up to nbest words in descending score order,
            //                             written to the file only; entries whose score is
            //                             exactly 0 are skipped;
            //                nbest <= 0 : nothing is written for the columns.
            // Words are separated by a single space. After all columns a newline is
            // written to both the file and stderr.
            //
            // Raises RuntimeError if the file cannot be opened or if a selected row index
            // has no entry in idx2wrd.
            template<class ElemType>
            void LMSequenceWriter<ElemType>::SaveToFile(std::wstring& outputFile, const Matrix<ElemType>& outputData, const map<int, string>& idx2wrd, const int& nbest)
            {
                const size_t nT = outputData.GetNumCols();
                const size_t nD = min(idx2wrd.size(), outputData.GetNumRows());

                // Fetch the cached stream for this path, opening the file on first use.
                FILE *fp = nullptr;
                auto cached = outputFileIds.find(outputFile);
                if (cached == outputFileIds.end())
                {
                    msra::files::make_intermediate_dirs(outputFile);
                    string str(outputFile.begin(), outputFile.end());
                    fp = fopen(str.c_str(), "wt");
                    if (fp == nullptr)
                        RuntimeError("Cannot open %s for writing", str.c_str());
                    outputFileIds[outputFile] = fp;
                }
                else
                    fp = cached->second;

                // Checked lookup: a gap in the index table must not dereference end().
                auto wordFor = [&idx2wrd](int idx) -> const string&
                {
                    auto pos = idx2wrd.find(idx);
                    if (pos == idx2wrd.end())
                        RuntimeError("SaveToFile: no word for index %d", idx);
                    return pos->second;
                };

                // Orders (row, score) pairs best score first.
                auto byScoreDescending = [](const pair<size_t, ElemType>& a, const pair<size_t, ElemType>& b){ return a.second > b.second; };

                vector<pair<size_t, ElemType>> scored;
                if (nbest > 1)
                    scored.reserve(nD);

                // With nD == 0 there is no row to read and no word to report, so the
                // columns are skipped entirely instead of touching outputData(0, j).
                for (size_t j = 0; nD > 0 && j < nT; j++)
                {
                    if (nbest > 1)
                    {
                        scored.clear();
                        for (size_t i = 0; i < nD; i++)
                            scored.push_back(pair<size_t, ElemType>(i, outputData(i, j)));

                        // Never ask for more entries than exist; only that prefix needs ordering.
                        const size_t wanted = static_cast<size_t>(nbest) < scored.size() ? static_cast<size_t>(nbest) : scored.size();
                        std::partial_sort(scored.begin(), scored.begin() + wanted, scored.end(), byScoreDescending);

                        for (size_t k = 0; k < wanted; k++)
                        {
                            if (scored[k].second != 0)
                            {
                                const string& sRes = wordFor(static_cast<int>(scored[k].first));
                                fprintf(fp, "%s ", sRes.c_str());
                            }
                        }
                    }
                    else if (nbest == 1)
                    {
                        // arg-max over the considered rows; the first maximum wins on ties
                        size_t imax = 0;
                        ElemType fmax = outputData(0, j);
                        for (size_t i = 1; i < nD; i++)
                        {
                            const ElemType score = outputData(i, j);
                            if (score > fmax)
                            {
                                fmax = score;
                                imax = i;
                            }
                        }

                        const string& sRes = wordFor(static_cast<int>(imax));
                        fprintf(fp, "%s ", sRes.c_str());
                        fprintf(stderr, "%s ", sRes.c_str());
                    }
                }
                fprintf(fp, "\n");
                fprintf(stderr, "\n");
            }


            // Explicit instantiations of the writer for the float and double element types.
            template class LMSequenceWriter<float>;
            template class LMSequenceWriter<double>;

        }
    }
}
back to top