// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // #include "stdafx.h" #include "CNTKLibrary.h" #include "Utils.h" #include #include #include #include #include #ifdef _MSC_VER #include #endif #pragma warning(push) #pragma warning(disable : 4800 4267 4610 4512 4100 4510) #include "CNTK.pb.h" #include #include #pragma warning(pop) namespace CNTK { using namespace ::google::protobuf; static const uint32 MAGIC_NUMBER = 0x636e746bU; static const uint32 BLOCK_SIZE = 8 << 10; // 8Kb; static void SetUTF8Locale() { #ifndef _MSC_VER if (std::setlocale(LC_ALL, "C.UTF-8") == nullptr) { std::setlocale(LC_ALL, "en_US.UTF-8"); } #endif } static void UnsetUTF8Locale() { #ifndef _MSC_VER std::setlocale(LC_ALL, ""); #endif } struct UsingUTF8 { UsingUTF8() { SetUTF8Locale(); } ~UsingUTF8() { UnsetUTF8Locale(); } }; // Variation on EncodeFloat/EncodeDouble/DecodeFloat/DecodeDouble from // ::google::protobuf::internal::WireFormatLite. template inline DST Encode(SRC value) { union { SRC src; DST dst; }; src = value; return dst; } class RenewableCodedStream { public: RenewableCodedStream(io::ZeroCopyInputStream& input) : m_input(input) { Renew(); } template inline bool Read(T* value) { auto size = sizeof(T); if (m_codedInputPtr->CurrentPosition() > INT_MAX - size) Renew(); if (size == sizeof(uint32)) { uint32 tmp; if (!m_codedInputPtr->ReadLittleEndian32(&tmp)) return false; *value = Encode(tmp); return true; } if (size == sizeof(uint64)) { uint64 tmp; if (!m_codedInputPtr->ReadLittleEndian64(&tmp)) return false; *value = Encode(tmp); return true; } return false; } private: void Renew() { delete m_codedInputPtr.release(); m_codedInputPtr = make_unique(&m_input); m_codedInputPtr->SetTotalBytesLimit(INT_MAX, INT_MAX); } io::ZeroCopyInputStream& m_input; std::unique_ptr m_codedInputPtr; }; class Serializer { friend std::ostream& operator<<(std::ostream&, const Dictionary&); friend std::ostream& operator<<(std::ostream&, const DictionaryValue&); friend std::istream& operator>>(std::istream&, Dictionary&); friend std::istream& operator>>(std::istream&, DictionaryValue&); friend class Dictionary; friend class DictionaryValue; Serializer(const Dictionary& dict); Serializer(const DictionaryValue& dict); Serializer() = default; private: proto::DictionaryValue* CreateProto(const DictionaryValue& src, Arena* arena = nullptr); proto::Dictionary* CreateProto(const Dictionary& src, Arena* arena = nullptr); proto::Vector* CreateProto(const std::vector& src, Arena* arena = nullptr); proto::NDArrayView* CreateProto(const NDArrayView& src, Arena* arena = nullptr); proto::Axis* CreateProto(const Axis& src, Arena* arena = nullptr); proto::NDShape* CreateProto(const NDShape& src, Arena* arena = nullptr); void Copy(const DictionaryValue& src, proto::DictionaryValue& dst, Arena* arena = nullptr); void CopyNDArrayViewDataToProtos(); void WriteNDArrayViewData(io::CodedOutputStream& output); std::ostream& Write(std::ostream& stream); void Write(const std::wstring& filename); void Write(io::ZeroCopyOutputStream& stream); bool Read(std::istream& stream, Dictionary& dict); bool Read(std::istream& stream, DictionaryValue& value); bool Read(const std::wstring& filename, Dictionary& dict); bool Read(const std::wstring& filename, DictionaryValue& value); bool Read(std::wstring filename, const std::function& callback); bool Read(std::istream& stream, const std::function& callback); bool ReadNDArrayViewData(io::ZeroCopyInputStream& input); size_t GetTotalByteSize() { return m_byteSize + m_proto->ByteSizeLong(); } bool FitsIntoProtobuf() { return GetTotalByteSize() < static_cast(INT_MAX); } Dictionary* CreateFromProto(const proto::Dictionary& src); std::vector* CreateFromProto(const proto::Vector& src); NDArrayView* CreateFromProto(const proto::NDArrayView& src); Axis* CreateFromProto(const proto::Axis& src); NDShape* CreateFromProto(const proto::NDShape& src); void Copy(const proto::Dictionary& src, Dictionary& dst); void Copy(const proto::DictionaryValue& src, DictionaryValue& dst); static proto::NDArrayView::DataType ToProtoType(DataType type) { if (!proto::NDArrayView::DataType_IsValid((int)type)) { InvalidArgument("NDArrayView::DataType is invalid."); } return proto::NDArrayView_DataType(type); } static DataType FromProtoType(proto::NDArrayView::DataType type) { if (!proto::NDArrayView::DataType_IsValid(type)) { InvalidArgument("NDArrayView::DataType is invalid."); } return DataType(type); } static proto::NDArrayView::StorageFormat ToProtoType(StorageFormat type) { if (!proto::NDArrayView::StorageFormat_IsValid((int)type)) { InvalidArgument("NDArrayView::StorageFormat is invalid."); } return proto::NDArrayView_StorageFormat(type); } static StorageFormat FromProtoType(proto::NDArrayView::StorageFormat type) { if (!proto::NDArrayView::StorageFormat_IsValid((int)type)) { InvalidArgument("NDArrayView::StorageFormat is invalid."); } return StorageFormat(type); } static proto::DictionaryValue::Type ToProtoType(DictionaryValue::Type type) { if (!proto::DictionaryValue::Type_IsValid((int)type)) { InvalidArgument("DictionaryValue::Type is invalid."); } return proto::DictionaryValue_Type(type); } static DictionaryValue::Type FromProtoType(proto::DictionaryValue::Type type) { if (!proto::DictionaryValue::Type_IsValid((int)type)) { InvalidArgument("DictionaryValue::Type is invalid."); } return DictionaryValue::Type(type); } template static void CopyData(const NDArrayView& src, RepeatedField* dst) { auto size = src.Shape().TotalSize(); dst->Resize((int)size, T()); const T* buffer = src.DataBuffer(); memcpy(dst->mutable_data(), buffer, (int)size * sizeof(T)); } template static void WriteData(const NDArrayView& src, io::CodedOutputStream& output) { auto size = src.Shape().TotalSize(); const T* buffer = src.DataBuffer(); auto tSize = sizeof(T); for (auto i = 0; i < size; i++) { auto value = buffer[i]; if (tSize <= sizeof(uint32)) output.WriteLittleEndian32(Encode(value)); else output.WriteLittleEndian64(Encode(value)); } } template static bool ReadData(RenewableCodedStream& input, NDArrayView& dst) { auto size = dst.Shape().TotalSize(); T* buffer = dst.WritableDataBuffer(); for (auto i = 0; i < size; i++) { if (!input.Read(buffer+i)) return false; } return true; } template static void CopyData(const RepeatedField& src, NDArrayView* dst) { auto size = src.size(); assert(size == dst->Shape().TotalSize());; T* buffer = dst->WritableDataBuffer(); memcpy(buffer, src.data(), size * sizeof(T)); } UsingUTF8 m_locale; Arena m_arena; Message* m_proto; std::vector> m_arrayViews; size_t m_byteSize {0}; }; Serializer::Serializer(const Dictionary& dict) { m_proto = CreateProto(dict, &m_arena); } Serializer::Serializer(const DictionaryValue& value) { m_proto = CreateProto(value, &m_arena); } void Serializer::CopyNDArrayViewDataToProtos() { for (auto& pair : m_arrayViews) { const auto& src = *(pair.first); auto dst = pair.second; if (src.GetDataType() == DataType::Float) { CopyData(src, dst->mutable_float_values()->mutable_value()); } else if (src.GetDataType() == DataType::Double) { CopyData(src, dst->mutable_double_values()->mutable_value()); } } } void Serializer::WriteNDArrayViewData(io::CodedOutputStream& output) { for (auto& pair : m_arrayViews) { const auto& src = *(pair.first); if (src.GetDataType() == DataType::Float) { WriteData(src, output); } else if (src.GetDataType() == DataType::Double) { WriteData(src, output); } } } bool Serializer::ReadNDArrayViewData(io::ZeroCopyInputStream& input) { if (m_arrayViews.size() == 0) return true; RenewableCodedStream wrapper(input); for (auto& pair : m_arrayViews) { auto& dst = *(pair.first); if (dst.GetDataType() == DataType::Float) { if (!ReadData(wrapper, dst)) return false; } else if (dst.GetDataType() == DataType::Double) { if (!ReadData(wrapper, dst)) return false; } } return true; } proto::NDShape* Serializer::CreateProto(const NDShape& src, Arena* arena) { proto::NDShape* dst = (arena != nullptr) ? Arena::CreateMessage(arena) : new proto::NDShape(); auto size = src.Rank(); dst->mutable_shape_dim()->Reserve((int)size); for (auto i = 0; i < size; i++) { dst->add_shape_dim(src[i]); } return dst; } NDShape* Serializer::CreateFromProto(const proto::NDShape& src) { auto size = src.shape_dim_size(); NDShape* dst = new NDShape(size); for (auto i = 0; i < size; i++) { dst->operator[](i) = size_t(src.shape_dim()[i]); } return dst; } proto::Axis* Serializer::CreateProto(const Axis& src, Arena* arena) { proto::Axis* dst = (arena != nullptr) ? Arena::CreateMessage(arena) : new proto::Axis(); dst->set_static_axis_idx(src.StaticAxisIndex(false)); dst->set_name(ToString(src.Name())); dst->set_is_ordered_dynamic_axis(src.IsOrdered()); return dst; } Axis* Serializer::CreateFromProto(const proto::Axis& src) { if (!Axis(src.static_axis_idx()).IsDynamicAxis()) { return new Axis(src.static_axis_idx()); } else { return new Axis(ToWString(src.name()), src.is_ordered_dynamic_axis()); } } proto::NDArrayView* Serializer::CreateProto(const NDArrayView& src, Arena* arena) { proto::NDArrayView* dst = (arena != nullptr) ? Arena::CreateMessage(arena) : new proto::NDArrayView(); dst->set_data_type(ToProtoType(src.GetDataType())); dst->set_allocated_shape(CreateProto(src.Shape(), arena)); dst->set_storage_format(ToProtoType(src.GetStorageFormat())); m_arrayViews.push_back({const_cast(&src), dst }); auto numElements = src.Shape().TotalSize(); auto dataSize = DataTypeSize(src.GetDataType()); if (numElements > SIZE_MAX / dataSize) RuntimeError("Bytes size of NDArrayView exceeds %zu.", SIZE_MAX); m_byteSize += numElements * dataSize; return dst; } NDArrayView* Serializer::CreateFromProto(const proto::NDArrayView& src) { if (!proto::NDArrayView::DataType_IsValid(src.data_type()) || !proto::NDArrayView::StorageFormat_IsValid(src.storage_format())) { return nullptr; } std::unique_ptr shape(CreateFromProto(src.shape())); auto dataType = FromProtoType(src.data_type()); auto storageFormat = FromProtoType(src.storage_format()); NDArrayView* dst = new NDArrayView(dataType, storageFormat, *shape, DeviceDescriptor::CPUDevice()); if (dataType == DataType::Float) { if (src.float_values().value().size() == shape->TotalSize()) CopyData(src.float_values().value(), dst); else m_arrayViews.push_back({ dst, nullptr }); } else if (dataType == DataType::Double) { if (src.double_values().value().size() == shape->TotalSize()) CopyData(src.double_values().value(), dst); else m_arrayViews.push_back({ dst, nullptr }); } return dst; } proto::Vector* Serializer::CreateProto(const std::vector& src, Arena* arena) { proto::Vector* dst = (arena != nullptr) ? Arena::CreateMessage(arena) : new proto::Vector(); dst->mutable_value()->Reserve((int)src.size()); for (const auto& value : src) { dst->mutable_value()->AddAllocated(CreateProto(value, arena)); } return dst; } std::vector* Serializer::CreateFromProto(const proto::Vector& src) { std::vector* dst = new std::vector(src.value_size()); for (auto i = 0; i < src.value_size(); ++i) { Copy(src.value()[i], dst->at(i)); } return dst; } proto::Dictionary* Serializer::CreateProto(const Dictionary& src, Arena* arena) { proto::Dictionary* dst = (arena != nullptr) ? Arena::CreateMessage(arena) : new proto::Dictionary(); dst->set_version(src.s_version); for (const auto& kv : src) { Copy(kv.second, dst->mutable_data()->operator[](ToString(kv.first)), arena); } return dst; } Dictionary* Serializer::CreateFromProto(const proto::Dictionary& src) { Dictionary* dst = new Dictionary(); for (const auto& kv : src.data()) { Copy(kv.second, dst->operator[](ToWString(kv.first))); } return dst; } proto::DictionaryValue* Serializer::CreateProto(const DictionaryValue& src, Arena* arena) { proto::DictionaryValue* dst = (arena != nullptr) ? Arena::CreateMessage(arena) : new proto::DictionaryValue(); dst->set_version(src.s_version); Copy(src, *dst, arena); return dst; } void Serializer::Copy(const DictionaryValue& src, proto::DictionaryValue& dst, Arena* arena) { auto valueType = src.ValueType(); dst.set_value_type(ToProtoType(valueType)); dst.set_version(src.s_version); switch (valueType) { case DictionaryValue::Type::None: break; case DictionaryValue::Type::Bool: dst.set_bool_value(src.Value()); break; case DictionaryValue::Type::Int: dst.set_int_value(src.Value()); break; case DictionaryValue::Type::SizeT: dst.set_size_t_value(src.Value()); break; case DictionaryValue::Type::Float: dst.set_float_value(src.Value()); break; case DictionaryValue::Type::Double: dst.set_double_value(src.Value()); break; case DictionaryValue::Type::String: dst.set_string_value(ToString(src.Value())); break; case DictionaryValue::Type::NDShape: dst.set_allocated_nd_shape_value(CreateProto(src.Value(), arena)); break; case DictionaryValue::Type::Axis: dst.set_allocated_axis_value(CreateProto(src.Value(), arena)); break; case DictionaryValue::Type::Vector: dst.set_allocated_vector_value(CreateProto(src.Value>(), arena)); break; case DictionaryValue::Type::Dictionary: dst.set_allocated_dictionary_value(CreateProto(src.Value(), arena)); break; case DictionaryValue::Type::NDArrayView: dst.set_allocated_nd_array_view_value(CreateProto(src.Value(), arena)); break; default: NOT_IMPLEMENTED } } void Serializer::Copy(const proto::Dictionary& src, Dictionary& dst) { dst.m_dictionaryData->reserve(src.data_size()); for (const auto& kv : src.data()) { Serializer::Copy(kv.second, dst[ToWString(kv.first)]); } } void Serializer::Copy(const proto::DictionaryValue& src, DictionaryValue& dst) { auto valueType = src.value_type(); if (!proto::DictionaryValue::Type_IsValid(valueType)) { return; } dst.m_valueType = FromProtoType(valueType); switch (valueType) { case proto::DictionaryValue::None: break; case proto::DictionaryValue::Bool: dst.m_data.m_boolean = src.bool_value(); break; case proto::DictionaryValue::Int: dst.m_data.m_int = src.int_value(); break; case proto::DictionaryValue::SizeT: dst.m_data.m_sizeT = src.size_t_value(); break; case proto::DictionaryValue::Float: dst.m_data.m_float = src.float_value(); break; case proto::DictionaryValue::Double: dst.m_data.m_double = src.double_value(); break; case proto::DictionaryValue::String: dst.m_data.m_ptr = new std::wstring(ToWString(src.string_value())); break; case proto::DictionaryValue::NDShape: dst.m_data.m_ptr = CreateFromProto(src.nd_shape_value()); break; case proto::DictionaryValue::Axis: dst.m_data.m_ptr = CreateFromProto(src.axis_value()); break; case proto::DictionaryValue::Vector: dst.m_data.m_ptr = CreateFromProto(src.vector_value()); break; case proto::DictionaryValue::Dictionary: dst.m_data.m_ptr = CreateFromProto(src.dictionary_value()); break; case proto::DictionaryValue::NDArrayView: dst.m_data.m_ptr = CreateFromProto(src.nd_array_view_value()); break; } } void Serializer::Write(io::ZeroCopyOutputStream& stream) { io::CodedOutputStream output(&stream); // Protobufs have a hard limit on the maximum message size(INT_MAX = 2GBs). // Check if we fit into a single protobuf message. if (FitsIntoProtobuf()) { CopyNDArrayViewDataToProtos(); m_proto->SerializeToCodedStream(&output); } else { // If we don't, pull the metadata apart from the actual payload (NDArrayView content) // and store the payload separately, outside of the protobuf. // Prefix the metadata protobuf with a magic number and its bytes size. output.WriteLittleEndian32(MAGIC_NUMBER); output.WriteLittleEndian32(m_proto->ByteSize()); m_proto->SerializeToCodedStream(&output); WriteNDArrayViewData(output); } } std::ostream& Serializer::Write(std::ostream& stream) { io::OstreamOutputStream output(&stream); Write(output); return stream; } void Serializer::Write(const std::wstring& filename) { auto fd = GetFileDescriptor(filename, false); { io::FileOutputStream output(fd); Write(output); } #ifdef _MSC_VER _close(fd); #else close(fd); #endif } bool ParseMessage(io::ZeroCopyInputStream& input, Message& msg) { uint32 prefix = 0, limit = INT_MAX;; const void* temp; int size; bool success; do { success = input.Next(&temp, &size); } while (success && size == 0); if (!success) return false; if (size >= (sizeof(prefix) + sizeof(limit))) { io::CodedInputStream::ReadLittleEndian32FromArray(reinterpret_cast(temp), &prefix); } // the message is only prefixed with a magic number + message length, // if its size exceeds 2GBs. if (prefix == MAGIC_NUMBER) { io::CodedInputStream::ReadLittleEndian32FromArray( reinterpret_cast(temp) + sizeof(prefix), &limit); input.BackUp(size - sizeof(prefix) - sizeof(limit)); } else input.BackUp(size); io::CodedInputStream codedInput(&input); codedInput.SetTotalBytesLimit(limit, limit); return msg.ParseFromCodedStream(&codedInput) && codedInput.ConsumedEntireMessage(); } bool Serializer::Read(std::istream& stream, Dictionary& dict) { m_proto = Arena::CreateMessage(&m_arena); return Read(stream, [this, &dict](io::ZeroCopyInputStream& input) { Copy(*dynamic_cast(m_proto), dict); return ReadNDArrayViewData(input); }); } bool Serializer::Read(std::istream& stream, DictionaryValue& value) { m_proto = Arena::CreateMessage(&m_arena); return Read(stream, [this, &value](io::ZeroCopyInputStream& input) { Copy(*dynamic_cast(m_proto), value); return ReadNDArrayViewData(input); }); } bool Serializer::Read(const std::wstring& filename, Dictionary& dict) { m_proto = Arena::CreateMessage(&m_arena); return Read(filename, [this, &dict](io::ZeroCopyInputStream& input) { Copy(*dynamic_cast(m_proto), dict); return ReadNDArrayViewData(input); }); } bool Serializer::Read(const std::wstring& filename, DictionaryValue& value) { m_proto = Arena::CreateMessage(&m_arena); return Read(filename, [this, &value](io::ZeroCopyInputStream& input) { Copy(*dynamic_cast(m_proto), value); return ReadNDArrayViewData(input); }); } bool Serializer::Read(std::wstring filename, const std::function& callback) { bool result; auto fd = GetFileDescriptor(filename, true); { io::FileInputStream input(fd, BLOCK_SIZE); result = ParseMessage(input, *m_proto); result = result && callback(input); } #ifdef _MSC_VER _close(fd); #else close(fd); #endif return result; } bool Serializer::Read(std::istream& stream, const std::function& callback) { io::IstreamInputStream input(&stream, BLOCK_SIZE); if (ParseMessage(input, *m_proto)) { return callback(input); } return false; } std::ostream& operator<<(std::ostream& stream, const Dictionary& dictionary) { return Serializer(dictionary).Write(stream); } std::ostream& operator<<(std::ostream& stream, const DictionaryValue& value) { return Serializer(value).Write(stream); } void Dictionary::Save(const std::wstring& filename) { Serializer(*this).Write(filename); } void DictionaryValue::Save(const std::wstring& filename) { Serializer(*this).Write(filename); } std::istream& operator>>(std::istream& stream, Dictionary& dictionary) { if (!Serializer(dictionary).Read(stream, dictionary)) RuntimeError("Failed to parse Dictionary from the input stream."); return stream; } std::istream& operator>>(std::istream& stream, DictionaryValue& value) { if (!Serializer(value).Read(stream, value)) RuntimeError("Failed to parse DictionaryValue from the input stream."); return stream; } /*static*/ Dictionary Dictionary::Load(const std::wstring& filename) { Dictionary dictionary; if (!Serializer().Read(filename, dictionary)) RuntimeError("Failed to parse Dictionary from file (%ls).", filename.c_str()); return dictionary; } /*static*/ DictionaryValue DictionaryValue::Load(const std::wstring& filename) { DictionaryValue dictionaryValue; if (!Serializer().Read(filename, dictionaryValue)) RuntimeError("Failed to parse DictionaryValue from file (%ls).", filename.c_str()); return dictionaryValue; } }