Revision e9396480025b9ca457d26b6f33dd07c474c6aa04 authored by liqunfu on 31 March 2020, 15:55:14 UTC, committed by GitHub on 31 March 2020, 15:55:14 UTC
1 parent e1467a7
Raw File
EvalWriter.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once

#define DATAWRITER_LOCAL
#include "DataWriter.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// Evaluation Writer class
// interface to pass to evaluation DLL
template <class ElemType>
class EvalWriter : public IDataWriter
{
    std::map<std::wstring, std::vector<ElemType>*>* m_outputs; // our output data
    std::map<std::wstring, size_t>* m_dimensions;              // the number of rows for the output data
    size_t m_recordCount;                                      // count of records in this data
    size_t m_currentRecord;                                    // next record number to read

public:
    // Method to setup the data for the reader
    void SetData(std::map<std::wstring, std::vector<ElemType>*>* outputs, std::map<std::wstring, size_t>* dimensions)
    {
        m_outputs = outputs;
        m_dimensions = dimensions;
        m_currentRecord = 0;
        m_recordCount = 0;
        for (auto iter = outputs->begin(); iter != outputs->end(); ++iter)
        {
            // figure out the dimension of the data
            const std::wstring& val = iter->first;
            size_t count = (*outputs)[val]->size();

            if (dimensions->find(val) == dimensions->end())
            {
                RuntimeError("Output %ls not found in CNTK model.", val.c_str());
            }

            size_t rows = (*dimensions)[val];
            size_t recordCount = count / rows;

            if (m_recordCount != 0)
            {
                // record count must be the same for all the data
                if (recordCount != m_recordCount)
                    RuntimeError("Record Count of %ls (%lux%lu) does not match the record count of previous entries (%lu).", val.c_str(), rows, recordCount, m_recordCount);
            }
            else
            {
                m_recordCount = recordCount;
            }
        }
    }

    virtual void Init(const ConfigParameters& /*config*/) override
    {
    }
    virtual void Init(const ScriptableObjects::IConfigRecord& /*config*/) override
    {
    }

    // Destroy - cleanup and remove this class
    // NOTE: this destroys the object, and it can't be used past this point
    virtual void Destroy()
    {
        delete this;
    }

    // EvalWriter Constructor
    // config - [in] configuration parameters for the datareader
    template <class ConfigRecordType>
    EvalWriter(const ConfigRecordType& config)
    {
        m_recordCount = m_currentRecord = 0;
        Init(config);
    }

    // Destructor - free up the matrix values we allocated
    virtual ~EvalWriter()
    {
    }

    virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& /*sections*/)
    {
        assert(false);
        NOT_IMPLEMENTED;
    }
    virtual bool SaveData(size_t /*recordStart*/, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t /*datasetSize*/, size_t /*byteVariableSized*/)
    {
        // loop through all the output vectors to copy the data over
        for (auto iter = m_outputs->begin(); iter != m_outputs->end(); ++iter)
        {
            // figure out the dimension of the data
            std::wstring val = iter->first;
            size_t rows = (*m_dimensions)[val];
            // size_t count = rows*numRecords;

            // find the output matrix we want to fill
            const std::map<std::wstring, void*, nocase_compare>::const_iterator iterIn = matrices.find(val);

            // allocate the matrix if we don't have one yet
            if (iterIn == matrices.end())
            {
                RuntimeError("No matrix data found for key '%ls', cannot continue", val.c_str());
            }

            Matrix<ElemType>* matrix = (Matrix<ElemType>*) iterIn->second;

            // copy over the data
            std::vector<ElemType>* data = iter->second;
            size_t index = m_currentRecord * rows;
            size_t numberToCopy = rows * numRecords;
            data->resize(index + numberToCopy);
            ElemType* dataPtr = ((ElemType*)data->data()) + index;
            if (matrix->GetNumElements() > numberToCopy)
                RuntimeError("The output matrix being saved has more data than the numRecords (%d) requested to be saved", (int)numRecords);

            matrix->CopyToArray(dataPtr, numberToCopy);
        }

        // increment our record pointer
        m_currentRecord += numRecords;

        // return the "done with all records" value
        return (m_currentRecord >= m_recordCount);
    }
    virtual void SaveMapping(std::wstring saveId, const std::map<typename EvalWriter<ElemType>::LabelIdType, typename EvalWriter<ElemType>::LabelType>& /*labelMapping*/){};
    virtual bool SupportMultiUtterances() const
    {
        return false;
    };
};
} } }
back to top