Revision 6cb945e453d8e5b19221a5ad820b8c10a4864368 authored by Tianjun Xiao on 21 October 2016, 14:38:14 UTC, committed by Tianjun Xiao on 21 October 2016, 14:38:14 UTC
1 parent 28ddc7e
Raw File
Utils.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include <istream>
#include <memory>
#include <ostream>

using namespace std;

namespace CNTK
{
    // Thin adapter over std::ostream whose operator<< performs unformatted (binary) writes:
    // the raw in-memory bytes of the value are emitted instead of a textual representation.
    struct BinaryOStreamWrapper
    {
        BinaryOStreamWrapper(std::ostream& s) : m_stream(s) {}

        // The wrapper only holds a reference to the stream, so copying and moving are disabled.
        BinaryOStreamWrapper(const BinaryOStreamWrapper&) = delete;
        BinaryOStreamWrapper(BinaryOStreamWrapper&&) = delete;
        BinaryOStreamWrapper& operator=(const BinaryOStreamWrapper&) = delete;
        BinaryOStreamWrapper& operator=(BinaryOStreamWrapper&&) = delete;

        // POD values: exactly sizeof(T) bytes are copied to the stream verbatim.
        template<typename T>
        typename std::enable_if<std::is_pod<T>::value, BinaryOStreamWrapper&>::type
        operator<<(const T& value)
        {
            const char* rawBytes = reinterpret_cast<const char*>(&value);
            m_stream.write(rawBytes, sizeof(T));
            return *this;
        }

        // Wide strings: the character count is written first (through the POD overload),
        // followed by the raw wchar_t payload without a terminator.
        BinaryOStreamWrapper& operator<<(const std::wstring& str)
        {
            *this << str.length();
            const char* payload = reinterpret_cast<const char*>(str.c_str());
            m_stream.write(payload, str.length() * sizeof(wchar_t));
            return *this;
        }

        // Gives access to the wrapped stream, e.g. so the wrapper can be returned as ostream&.
        operator std::ostream& () { return m_stream; }

        std::ostream& m_stream;
    };

    // Thin adapter over std::istream whose operator>> performs unformatted (binary) reads.
    // The layout consumed here matches what BinaryOStreamWrapper produces: POD values as
    // sizeof(T) raw bytes, wide strings as a size_t character count followed by raw wchar_t data.
    struct BinaryIStreamWrapper
    {
        BinaryIStreamWrapper(istream& s) : m_stream(s) {}

        // Reads sizeof(T) raw bytes directly into 'value'.
        // Reading in place (rather than into a char scratch buffer that is subsequently
        // dereferenced as T) avoids the aliasing/alignment hazard of that cast, never touches
        // uninitialised scratch memory, and lifts the former sizeof(T) <= sizeof(size_t) limit.
        // If the stream cannot supply enough bytes, read() flags the failure on the stream
        // and 'value' is not fully populated; callers should check the stream state.
        template<typename T>
        typename std::enable_if<std::is_pod<T>::value, BinaryIStreamWrapper&>::type
        operator>>(T& value)
        {
            m_stream.read(reinterpret_cast<char*>(&value), sizeof(T));
            return *this;
        }

        // Reads a length-prefixed wide string.
        BinaryIStreamWrapper& operator>>(wstring& str)
        {
            size_t length = 0;
            *this >> length;
            if (!m_stream)
            {
                // The length prefix could not be read: do not size the string from garbage.
                str.clear();
                return *this;
            }

            str.resize(length);
            if (length > 0)
            {
                // The characters are stored contiguously, so fetch the payload in one call.
                m_stream.read(reinterpret_cast<char*>(&str[0]),
                              static_cast<std::streamsize>(length * sizeof(wchar_t)));
            }

            return *this;
        }

        // Gives access to the wrapped stream, e.g. so the wrapper can be returned as istream&.
        operator istream& () const { return m_stream; }

        istream& m_stream;

        // The wrapper only holds a reference to the stream, so copying and moving are disabled.
        BinaryIStreamWrapper(const BinaryIStreamWrapper&) = delete;
        BinaryIStreamWrapper(BinaryIStreamWrapper&&) = delete;
        BinaryIStreamWrapper& operator=(const BinaryIStreamWrapper&) = delete;
        BinaryIStreamWrapper& operator=(BinaryIStreamWrapper&&) = delete;
    };

    // Allocates a copy of 'value' with new (using T's copy constructor) and returns the
    // raw pointer to it; the caller is responsible for deleting the result.
    template <typename T>
    T* CreateDataPtr(const T& value)
    {
        T* heapCopy = new T(value);
        return heapCopy;
    }

    // Specialization for NDArrayView: instead of copy-constructing, a new view with the same
    // data type and shape is created on the CPU device and the contents of 'value' are copied
    // into it. The caller owns the returned pointer.
    template <>
    NDArrayView* CreateDataPtr<NDArrayView>(const NDArrayView& value)
    {
        // TODO: replace this copy with an alias to value.
        // The new view is owned by a unique_ptr until the copy has completed, so it is
        // released rather than leaked should CopyFrom throw.
        std::unique_ptr<NDArrayView> viewPtr(new NDArrayView(value.GetDataType(), value.Shape(), DeviceDescriptor::CPUDevice()));
        viewPtr->CopyFrom(value);
        return viewPtr.release();
    }

    // Stores a newly allocated copy of 'value' (obtained from CreateDataPtr<T>) in m_data.m_ptr.
    // Only the types listed in the static_assert below are accepted.
    template <typename T>
    void DictionaryValue::AllocateDataPtr(const T& value)
    {
        static_assert(is_same<T, NDShape>::value ||
                      is_same<T, Axis>::value ||
                      is_same<T, wstring>::value ||
                      is_same<T, vector<DictionaryValue>>::value ||
                      is_same<T, Dictionary>::value ||
                      is_same<T, NDArrayView>::value,
                      "AllocateDataPtr called with invalid type");

        T* allocatedCopy = CreateDataPtr<T>(value);
        m_data.m_ptr = allocatedCopy;
    }

    // Deletes the object referenced by m_data.m_ptr, treating it as a T, and then clears
    // the pointer.
    template <typename T>
    void DictionaryValue::FreePtrAsType()
    {
        delete reinterpret_cast<T*>(m_data.m_ptr);
        m_data.m_ptr = nullptr;
    }

    // Element-wise comparison of two views whose elements are of type ElementType.
    // Returns false when the data types or shapes differ, or when any pair of corresponding
    // elements compares unequal; returns true otherwise.
    // A view that does not reside on the CPU device is first copied into a temporary CPU view.
    template <typename ElementType> 
    bool AreEqual(NDArrayView& view1, NDArrayView& view2)
    {
        if (view1.GetDataType() != view2.GetDataType() ||
            view1.Shape() != view2.Shape())
        {
            return false;
        }

        // The temporary CPU copies (if any) are owned at function scope so that the buffers
        // obtained from them stay alive for the whole comparison loop below.
        NDArrayViewPtr cpuCopy1;
        NDArrayViewPtr cpuCopy2;

        // Returns a CPU-side buffer for 'view'. Each view is examined on its own, so the two
        // views are allowed to live on different devices.
        auto cpuBufferOf = [](NDArrayView& view, NDArrayViewPtr& cpuCopy) -> ElementType*
        {
            if (view.Device().Type() == DeviceKind::CPU)
            {
                return view.WritableDataBuffer<ElementType>();
            }

            cpuCopy = MakeSharedObject<NDArrayView>(AsDataType<ElementType>(), view.Shape(), DeviceDescriptor::CPUDevice());
            cpuCopy->CopyFrom(view);
            return cpuCopy->WritableDataBuffer<ElementType>();
        };

        ElementType* data1 = cpuBufferOf(view1, cpuCopy1);
        ElementType* data2 = cpuBufferOf(view2, cpuCopy2);

        size_t numElements = view1.Shape().TotalSize();

        for (size_t i = 0; i < numElements; ++i)
        {
            if (data1[i] != data2[i])
            {
                return false;
            }
        }
        return true;
    }

    // Equality comparison. Two values are equal when they have the same value type and
    // equal payloads: primitive payloads are compared directly, pointer-held payloads are
    // compared through the pointee's own operator==, and NDArrayView payloads are compared
    // element-wise via AreEqual.
    bool DictionaryValue::operator==(const DictionaryValue& other) const
    {
        // An object always equals itself.
        if (this == &other)
            return true;

        // Values of different types are never equal.
        if (m_valueType != other.m_valueType)
            return false;

        switch (m_valueType)
        {
        case DictionaryValue::Type::Bool:
            return (m_data.m_boolean == other.m_data.m_boolean);
        case DictionaryValue::Type::SizeT:
            return (m_data.m_sizeT == other.m_data.m_sizeT);
        case DictionaryValue::Type::Float:
            return (m_data.m_float == other.m_data.m_float);
        case DictionaryValue::Type::Double:
            return (m_data.m_double == other.m_data.m_double);
        case DictionaryValue::Type::String:
        {
            wstring& lhs = *reinterpret_cast<wstring*>(m_data.m_ptr);
            wstring& rhs = *reinterpret_cast<wstring*>(other.m_data.m_ptr);
            return (lhs == rhs);
        }
        case DictionaryValue::Type::NDShape:
        {
            NDShape& lhs = *reinterpret_cast<NDShape*>(m_data.m_ptr);
            NDShape& rhs = *reinterpret_cast<NDShape*>(other.m_data.m_ptr);
            return (lhs == rhs);
        }
        case DictionaryValue::Type::Axis:
        {
            Axis& lhs = *reinterpret_cast<Axis*>(m_data.m_ptr);
            Axis& rhs = *reinterpret_cast<Axis*>(other.m_data.m_ptr);
            return (lhs == rhs);
        }
        case DictionaryValue::Type::Vector:
        {
            vector<DictionaryValue>& lhs = *reinterpret_cast<vector<DictionaryValue>*>(m_data.m_ptr);
            vector<DictionaryValue>& rhs = *reinterpret_cast<vector<DictionaryValue>*>(other.m_data.m_ptr);
            return (lhs == rhs);
        }
        case DictionaryValue::Type::Dictionary:
        {
            Dictionary& lhs = *reinterpret_cast<Dictionary*>(m_data.m_ptr);
            Dictionary& rhs = *reinterpret_cast<Dictionary*>(other.m_data.m_ptr);
            return (lhs == rhs);
        }
        case DictionaryValue::Type::NDArrayView:
        {
            NDArrayView& lhs = *reinterpret_cast<NDArrayView*>(m_data.m_ptr);
            NDArrayView& rhs = *reinterpret_cast<NDArrayView*>(other.m_data.m_ptr);

            // Dispatch on the element type of the left-hand view.
            switch (lhs.GetDataType())
            {
            case DataType::Float:
                return AreEqual<float>(lhs, rhs);
            case DataType::Double:
                return AreEqual<double>(lhs, rhs);
            default:
                NOT_IMPLEMENTED;
            }
        }
        default:
            NOT_IMPLEMENTED;
        }
    }
    
    // Inequality is defined as the negation of operator==.
    bool DictionaryValue::operator!=(const DictionaryValue& other) const
    {
        const bool isEqual = (*this == other);
        return !isEqual;
    }

    
    // Serializes a shape as its rank followed by the dimension of each axis, in axis order.
    BinaryOStreamWrapper& operator<<(BinaryOStreamWrapper& stream, const NDShape& us)
    {
        const auto rank = us.Rank();
        stream << rank;

        // Use an unsigned index: comparing an int counter against the unsigned rank mixes
        // signedness and the int would overflow before a very large rank is reached.
        for (size_t i = 0; i < rank; i++)
        {
            stream << us[i];
        }
        return stream;
    }

    // Serializes an axis as three consecutive fields: the value of StaticAxisIndex(false),
    // the axis name, and the IsOrdered() flag.
    BinaryOStreamWrapper& operator<<(BinaryOStreamWrapper& stream, const Axis& us)
    {
        const auto staticAxisIndex = us.StaticAxisIndex(false);
        stream << staticAxisIndex;

        const auto& axisName = us.Name();
        stream << axisName;

        const auto isOrdered = us.IsOrdered();
        stream << isOrdered;

        return stream;
    }

    // Writes the elements of 'view' (interpreted as T) to the stream as raw bytes, in buffer
    // order. The view is expected to reside on the CPU device (checked by assert only).
    template <typename T>
    void Write(BinaryOStreamWrapper& stream, const NDArrayView& view)
    {
        static_assert(std::is_pod<T>::value, "Write supports POD element types only.");
        assert(view.Device().Type() == DeviceKind::CPU);

        const size_t numElements = view.Shape().TotalSize();
        const T* buffer = view.DataBuffer<T>();

        // A single bulk write produces exactly the same bytes as writing the elements one by
        // one, without per-element stream calls and without an int counter that could
        // overflow for very large views.
        stream.m_stream.write(reinterpret_cast<const char*>(buffer),
                              static_cast<std::streamsize>(numElements * sizeof(T)));
    }

    // Fills the buffer of 'view' (interpreted as T) with raw bytes from the stream, in buffer
    // order. The view is expected to reside on the CPU device (checked by assert only).
    // If the stream runs short, read() flags the failure on the stream and the remaining
    // elements are left unmodified.
    template <typename T>
    void Read(BinaryIStreamWrapper& stream, NDArrayView& view)
    {
        static_assert(std::is_pod<T>::value, "Read supports POD element types only.");
        assert(view.Device().Type() == DeviceKind::CPU);

        const size_t numElements = view.Shape().TotalSize();
        T* buffer = view.WritableDataBuffer<T>();

        // A single bulk read consumes exactly the same bytes as reading the elements one by
        // one, without per-element stream calls and without an int counter that could
        // overflow for very large views.
        stream.m_stream.read(reinterpret_cast<char*>(buffer),
                             static_cast<std::streamsize>(numElements * sizeof(T)));
    }

    istream& operator>>(istream& stdStream, DictionaryValue& us)
    {
        BinaryIStreamWrapper stream(stdStream);
        size_t version;
        stream >> version;
        
        unsigned int type;
        stream >> type;
        us.m_valueType = static_cast<DictionaryValue::Type>(type);

        switch (us.ValueType())
        {
        case DictionaryValue::Type::Bool:
            stream >> us.m_data.m_boolean;
            break;
        case DictionaryValue::Type::SizeT:
            stream >> us.m_data.m_sizeT;
            break;
        case DictionaryValue::Type::Float:
            stream >> us.m_data.m_float;
            break;
        case DictionaryValue::Type::Double:
            stream >> us.m_data.m_double;
            break;
        case DictionaryValue::Type::String:
        {
            wstring* strPtr = new wstring();
            stream >> *strPtr;
            us.m_data.m_ptr = strPtr;
            break;
        }
        case DictionaryValue::Type::NDShape:
        {
            size_t size;
            stream >> size;
            NDShape* shapePtr = new NDShape(size);
            for (auto i = 0; i < size; i++)
            {
                stream >> shapePtr->operator[](i);
            }
            us.m_data.m_ptr = shapePtr;
            break;
        }
        case DictionaryValue::Type::Axis:
        {
            size_t staticAxisIdx;
            stream >> staticAxisIdx;

            std::wstring axisName;
            stream >> axisName;

            bool isOrderedDynamicAxis;
            stream >> isOrderedDynamicAxis;

            Axis* axisPtr = nullptr;
            if (Axis(staticAxisIdx).IsStaticAxis())
                axisPtr = new Axis(staticAxisIdx);
            else
                axisPtr = new Axis(axisName, isOrderedDynamicAxis);

            us.m_data.m_ptr = axisPtr;
            break;
        }
        case DictionaryValue::Type::Vector:
        {   
            size_t size;
            stream >> size;
            vector<DictionaryValue>* vectorPtr = new vector<DictionaryValue>(size);
            for (auto i = 0; i < size; i++)
            {
                stream >> vectorPtr->at(i);
            }
            us.m_data.m_ptr = vectorPtr;
            break;
        }
        case DictionaryValue::Type::Dictionary:
        {
            Dictionary* dictPtr = new Dictionary();
            stream >> *dictPtr;
            us.m_data.m_ptr = dictPtr;
            break;
        }
        case DictionaryValue::Type::NDArrayView:
        {
            unsigned int type;
            stream >> type;
            DataType dtype = static_cast<DataType>(type);

            size_t size;
            stream >> size;
            NDShape shape(size);
            for (auto i = 0; i < size; i++)
            {
                stream >> shape[i];
            }

            NDArrayView* viewPtr = new NDArrayView(dtype, shape, DeviceDescriptor::CPUDevice());
            switch (dtype)
            {
            case DataType::Float:
                Read<float>(stream, *viewPtr);
                break;
            case DataType::Double:
                Read<double>(stream, *viewPtr);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(dtype));
            }

            us.m_data.m_ptr = viewPtr;
            break;
        }
        default:
            NOT_IMPLEMENTED;
        }
        return stream;
    }

    // Serializes a DictionaryValue to a binary stream.
    // Layout produced: the version field, the value type as an unsigned int tag, then a
    // payload whose format depends on the value type (see the individual cases below).
    ostream& operator<<(ostream& stdStream, const DictionaryValue& us)
    {
        BinaryOStreamWrapper stream(stdStream);

        stream << us.version;

        stream << static_cast<unsigned int>(us.ValueType());

        switch (us.ValueType())
        {
        case DictionaryValue::Type::Bool:
            stream << us.m_data.m_boolean;
            break;
        case DictionaryValue::Type::SizeT:
            stream << us.m_data.m_sizeT;
            break;
        case DictionaryValue::Type::Float:
            stream << us.m_data.m_float;
            break;
        case DictionaryValue::Type::Double:
            stream << us.m_data.m_double;
            break;
        case DictionaryValue::Type::String:
        {
            wstring* stringPtr = reinterpret_cast<wstring*>(us.m_data.m_ptr);
            stream << *stringPtr;
            break;
        }
        case DictionaryValue::Type::NDShape:
        {
            NDShape* shapePtr = reinterpret_cast<NDShape*>(us.m_data.m_ptr);
            stream << *shapePtr;
            break;
        }
        case DictionaryValue::Type::Axis:
        {
            Axis* axisPtr = reinterpret_cast<Axis*>(us.m_data.m_ptr);
            stream << *axisPtr;
            break;
        }
        case DictionaryValue::Type::Vector:
        {
            // Element count followed by each element in order.
            vector<DictionaryValue>* vectorPtr =
                reinterpret_cast<vector<DictionaryValue>*>(us.m_data.m_ptr);
            auto size = vectorPtr->size();
            stream << size;
            // Range-for avoids comparing an int counter against the unsigned size.
            for (const auto& element : *vectorPtr)
            {
                stream << element;
            }
            break;
        }
        case DictionaryValue::Type::Dictionary:
        {
            Dictionary* dictPtr = reinterpret_cast<Dictionary*>(us.m_data.m_ptr);
            stream << *dictPtr;
            break;
        }
        case DictionaryValue::Type::NDArrayView:
        {
            // Element data type tag, shape, then the raw element data.
            NDArrayView* viewPtr = reinterpret_cast<NDArrayView*>(us.m_data.m_ptr);
            stream << static_cast<unsigned int>(viewPtr->GetDataType());
            stream << viewPtr->Shape();
            switch (viewPtr->GetDataType())
            {
            case DataType::Float:
                Write<float>(stream, *viewPtr);
                break;
            case DataType::Double:
                Write<double>(stream, *viewPtr);
                break;
            default:
                LogicError("Unsupported DataType %s", DataTypeName(viewPtr->GetDataType()));
            }
            break;
        }
        default:
            NOT_IMPLEMENTED;
        }
        return stream;
    }

    // Creates an empty dictionary backed by a newly allocated, empty map.
    Dictionary::Dictionary()
        : m_dictionaryData(new std::unordered_map<std::wstring, DictionaryValue>())
    {
    }

    // Nothing to release explicitly; the members clean up after themselves.
    Dictionary::~Dictionary() = default;

    // Copy construction is implemented in terms of copy assignment.
    Dictionary::Dictionary(const Dictionary& other)
    {
        operator=(other);
    }

    // Copy assignment: replaces this dictionary's map with a newly allocated copy of the
    // other dictionary's map. Self-assignment is not expected (asserted).
    Dictionary& Dictionary::operator=(const Dictionary& other)
    {
        assert(this != &other);

        const auto& sourceEntries = *(other.m_dictionaryData);
        m_dictionaryData.reset(new unordered_map<wstring, DictionaryValue>(sourceEntries));

        return *this;
    }

    // Move construction: starts out with no map and then takes over the other dictionary's
    // map through move assignment.
    Dictionary::Dictionary(Dictionary&& other)
        : m_dictionaryData(nullptr)
    {
        operator=(std::move(other));
    }

    // Move assignment: takes over the other dictionary's map and leaves the other
    // dictionary without one. Self-assignment is not expected (asserted).
    Dictionary& Dictionary::operator=(Dictionary&& other)
    {
        assert(this != &other);

        m_dictionaryData = other.m_dictionaryData;
        other.m_dictionaryData.reset();

        return *this;
    }

    // Mutable lookup: returns a reference to the value stored under 'key'; as with the
    // underlying map's operator[], a default-constructed value is inserted if the key is absent.
    DictionaryValue& Dictionary::operator[](const wchar_t* key)
    {
        auto& entries = *m_dictionaryData;
        return entries[key];
    }

    // Read-only lookup: returns a copy of the value stored under 'key'. The lookup goes
    // through the map's at(), which throws if the key is absent.
    DictionaryValue Dictionary::operator[](const wchar_t* key) const
    {
        const auto& entries = *m_dictionaryData;
        return entries.at(key);
    }

    // Returns true when an entry is stored under 'key'.
    bool Dictionary::Contains(const wchar_t* key) const
    {
        return (m_dictionaryData->count(key) != 0);
    }

    // Two dictionaries are equal when they hold the same number of entries and every key of
    // this dictionary is present in the other one with an equal value.
    bool Dictionary::operator==(const Dictionary& other) const
    {
        if (this == &other)
            return true;

        const auto& ownEntries = *m_dictionaryData;
        const auto& otherEntries = *other.m_dictionaryData;

        if (ownEntries.size() != otherEntries.size())
            return false;

        for (auto entry = ownEntries.begin(); entry != ownEntries.end(); ++entry)
        {
            const auto match = otherEntries.find(entry->first);
            if (match == otherEntries.end())
                return false;

            if (entry->second != match->second)
                return false;
        }

        return true;
    }
    
    // Inequality is defined as the negation of operator==.
    bool Dictionary::operator!=(const Dictionary& other) const
    {
        const bool isEqual = (*this == other);
        return !isEqual;
    }

    // Serializes a dictionary to a binary stream: the version field, the number of entries,
    // then each entry as its key followed by its value, in the map's iteration order.
    ostream& operator<<(ostream& stdStream, const Dictionary& us)
    {
        BinaryOStreamWrapper stream(stdStream);
        stream << us.version;

        const auto& entries = *(us.m_dictionaryData);
        stream << entries.size();

        for (auto entry = entries.begin(); entry != entries.end(); ++entry)
        {
            stream << entry->first;
            stream << entry->second;
        }
        return stream;
    }

    // Deserializes a dictionary from a binary stream: a size_t version, a size_t entry count,
    // then that many entries, each consisting of a key followed by a value. Entries are
    // added to (or overwrite entries of) 'us'; existing entries are not removed first.
    istream& operator>>(istream& stdStream, Dictionary& us)
    {
        BinaryIStreamWrapper stream(stdStream);

        // The version field is consumed to advance the stream; it is not interpreted here.
        size_t version = 0;
        stream >> version;

        size_t size = 0;
        stream >> size;
        us.m_dictionaryData->reserve(size);

        // Use an unsigned index: comparing an int counter against the size_t count mixes
        // signedness and the int would overflow before a very large count is reached.
        for (size_t i = 0; i < size; i++)
        {
            wstring key;
            stream >> key;
            stream >> us[key];
        }
        return stream;
    }

    // Looks up the schedule entry for the given sample count: the value of the first entry
    // whose key is strictly greater than 'sampleCount', or the value of the last entry when
    // no key is greater. The schedule must not be empty (asserted).
    template <typename T>
    const T& TrainingParameterSchedule<T>::operator[](size_t sampleCount) const
    {
        assert(m_schedule.size() > 0);

        const auto firstGreater = m_schedule.upper_bound(sampleCount);
        if (firstGreater != m_schedule.end())
        {
            return firstGreater->second;
        }

        // No key exceeds sampleCount: fall back to the final entry.
        auto lastEntry = m_schedule.end();
        --lastEntry;
        return lastEntry->second;
    }

    // Explicit instantiations. The member templates above are defined in this source file,
    // so the specializations that are needed are instantiated here explicitly.

    // AllocateDataPtr: one instantiation per type accepted by its static_assert.
    template void DictionaryValue::AllocateDataPtr<NDShape>(const NDShape& value);
    template void DictionaryValue::AllocateDataPtr<Axis>(const Axis& value);
    template void DictionaryValue::AllocateDataPtr<vector<DictionaryValue>>(const vector<DictionaryValue>& value);
    template void DictionaryValue::AllocateDataPtr<wstring>(const wstring& value);
    template void DictionaryValue::AllocateDataPtr<Dictionary>(const Dictionary& value);
    template void DictionaryValue::AllocateDataPtr<NDArrayView>(const NDArrayView& value);

    // FreePtrAsType: instantiated for the same set of types as AllocateDataPtr.
    template void DictionaryValue::FreePtrAsType<NDShape>();
    template void DictionaryValue::FreePtrAsType<Axis>();
    template void DictionaryValue::FreePtrAsType<vector<DictionaryValue>>();
    template void DictionaryValue::FreePtrAsType<wstring>();
    template void DictionaryValue::FreePtrAsType<Dictionary>();
    template void DictionaryValue::FreePtrAsType<NDArrayView>();

    // TrainingParameterSchedule lookup: instantiated for double only.
    template const double& TrainingParameterSchedule<double>::operator[](size_t key) const;
}
back to top