swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Tip revision: 6b2bcac054816cd6e8679e60738f369ea60efcc4 authored by Mark Hillebrand on 09 June 2017, 11:45:25 UTC
configure: bump openjdk
configure: bump openjdk
Tip revision: 6b2bcac
BinaryDataDeserializer.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#pragma once
#include "DataDeserializerBase.h"
#include "BinaryConfigHelper.h"
#include "CorpusDescriptor.h"
#include "BinaryDataChunk.h"
#include "FileHelper.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Common base for the per-stream deserializers of a binary input file.
// On construction it reads the stream header from the file (stream name, element data type
// and sample dimension, in that order) and checks it against the requested precision.
// Derived classes decide the storage type and know how to slice a raw chunk buffer into sequences.
class BinaryDataDeserialzer
{
public:
    // Reads the stream header starting at the current position of 'file'.
    // Raises LogicError if 'precision' is neither tfloat nor tdouble, or if it does not match
    // the data type recorded in the file; raises RuntimeError (via the Read* helpers) if the
    // header itself is invalid.
    BinaryDataDeserialzer(FILE* file, ElementType precision = ElementType::tfloat)
    {
        ReadName(file);
        ReadDataType(file);
        ReadSampleSize(file);

        if (precision != ElementType::tfloat && precision != ElementType::tdouble)
            LogicError("Unsupported precision type %u.", (unsigned int)precision);

        if ((m_dataType == DataType::tfloat && precision != ElementType::tfloat) ||
            (m_dataType == DataType::tdouble && precision != ElementType::tdouble))
            LogicError("Unsupported combination of the input data type %u and precision %u. "
                       "At the moment, both have to match.", (unsigned int)m_dataType, (unsigned int)precision);

        m_precision = precision;
    }

    // Splits the raw buffer 'data' into 'numSequences' sequences stored in 'result'.
    // Returns a size_t; in the implementations visible in this file it is the number of bytes consumed.
    virtual size_t GetSequenceDataForChunk(size_t numSequences, void* data, std::vector<SequenceDataPtr>& result) = 0;

    // Storage type (e.g. dense or sparse) of the sequences produced by this deserializer.
    virtual StorageType GetStorageType() = 0;

    // Builds a new stream description from the header values read in the constructor.
    StreamDescriptionPtr GetStreamDescription()
    {
        auto streamDescription = std::make_shared<StreamDescription>();
        streamDescription->m_elementType = m_precision;
        streamDescription->m_storageType = GetStorageType();
        streamDescription->m_sampleLayout = GetSampleLayout();
        streamDescription->m_name = m_name;
        return streamDescription;
    }

    // Returns a newly allocated tensor shape constructed from the sample dimension.
    TensorShapePtr GetSampleLayout()
    {
        return make_shared<TensorShape>(m_sampleDimension);
    }

    // Size in bytes of a single value of the data type recorded in the stream header.
    size_t SizeOfDataType()
    {
        if (m_dataType == DataType::tfloat)
            return sizeof(float);
        if (m_dataType == DataType::tdouble)
            return sizeof(double);
        LogicError("Unsupported input data type %u.", (unsigned int)m_dataType);
    }

protected:
    // Element data type as encoded (one byte) in the stream header.
    enum class DataType : unsigned char
    {
        tfloat = 0,
        tdouble = 1,
        // TODO:
        // tbool = 2, 1 bit per value (one-hot data)
        // tbyte = 3, 1 byte per value
    };

    virtual ~BinaryDataDeserialzer() = default;

    // Reads the stream name: a uint32_t length followed by that many chars, converted to a wide string.
    void ReadName(FILE* file)
    {
        uint32_t len;
        // read the name
        CNTKBinaryFileHelper::ReadOrDie(&len, sizeof(len), 1, file);

        // Compute the buffer size in size_t. In 32-bit unsigned arithmetic 'len + 1' wraps to 0
        // for len == 0xFFFFFFFF, which would leave the read below without any backing storage.
        const size_t bufferSize = static_cast<size_t>(len) + 1;
        if (bufferSize <= len) // still wrapped (size_t no wider than 32 bits)
            RuntimeError("Stream name length %u is too large.", (unsigned int)len);

        // One extra zero-initialized char keeps the buffer null-terminated for the conversion.
        vector<char> temp(bufferSize, '\0');
        CNTKBinaryFileHelper::ReadOrDie(temp.data(), sizeof(char), len, file);
        m_name = msra::strfun::utf16(temp.data());
    }

    // Reads the one-byte data type tag and rejects values beyond the supported ones.
    void ReadDataType(FILE* file)
    {
        CNTKBinaryFileHelper::ReadOrDie(&m_dataType, sizeof(m_dataType), 1, file);
        if (m_dataType > DataType::tdouble)
            RuntimeError("Unsupported input data type %u.", (unsigned int)m_dataType);
    }

    // Reads the sample dimension (uint32_t).
    void ReadSampleSize(FILE* file)
    {
        CNTKBinaryFileHelper::ReadOrDie(&m_sampleDimension, sizeof(m_sampleDimension), 1, file);
    }

    // Dense sequence whose data buffer is whatever 'm_data' is set to point at.
    struct DenseInputStreamBuffer : DenseSequenceData
    {
        const void* GetDataBuffer() override
        {
            return m_data;
        }

        void* m_data = nullptr;
        DataType m_dataType;
    };

    // Sparse sequence whose value buffer is whatever 'm_data' is set to point at.
    struct SparseInputStreamBuffer : SparseSequenceData
    {
        SparseInputStreamBuffer()
        {
            m_totalNnzCount = 0;
        }

        const void* GetDataBuffer() override
        {
            return m_data;
        }

        void* m_data = nullptr;
    };

    ElementType m_precision;    // precision requested by the caller (validated against m_dataType)
    DataType m_dataType;        // data type tag read from the stream header
    uint32_t m_sampleDimension; // sample dimension read from the stream header
    wstring m_name;             // stream name read from the stream header
};
// Shared-ownership handle to a stream deserializer.
using BinaryDataDeserializerPtr = shared_ptr<BinaryDataDeserialzer>;
class DenseBinaryDataDeserializer : public BinaryDataDeserialzer
{
public:
using BinaryDataDeserialzer::BinaryDataDeserialzer;
virtual StorageType GetStorageType() override { return StorageType::dense; }
size_t GetSequenceDataForChunk(size_t numSequences, void* data, std::vector<SequenceDataPtr>& result)
{
size_t valueSize = SizeOfDataType();
result.resize(numSequences);
size_t offset = 0;
for (size_t i = 0; i < numSequences; i++)
{
shared_ptr<DenseInputStreamBuffer> sequenceDataPtr = make_shared<DenseInputStreamBuffer>();
sequenceDataPtr->m_numberOfSamples = *(uint32_t*)((char*)data + offset);
offset += sizeof(uint32_t);
sequenceDataPtr->m_data = (char*)data + offset;
sequenceDataPtr->m_sampleLayout = GetSampleLayout();
sequenceDataPtr->m_elementType = m_precision;
result[i] = sequenceDataPtr;
offset += m_sampleDimension * valueSize * sequenceDataPtr->m_numberOfSamples;
}
return offset;
}
};
class SparseBinaryDataDeserializer : public BinaryDataDeserialzer
{
public:
SparseBinaryDataDeserializer(FILE* file, ElementType precision = ElementType::tfloat)
:BinaryDataDeserialzer(file, precision)
{
if (IndexType(m_sampleDimension) < 0)
{
RuntimeError("Sample dimension is too large for an IndexType value.");
}
}
virtual StorageType GetStorageType() override { return StorageType::sparse_csc; }
// The format of data is:
// sequence[numSequences], where each sequence consists of:
// uint32_t: numSamples
// uint32_t: nnz for the sequence
// ElemType[nnz]: the values for the sparse sequences
// int32_t[nnz]: the row offsets for the sparse sequences
// int32_t[numSamples]: sizes (nnz counts) for each sample in the sequence
size_t GetSequenceDataForChunk(size_t numSequences, void* data, std::vector<SequenceDataPtr>& result)
{
size_t offset = 0;
result.resize(numSequences);
for (size_t i = 0; i < numSequences; i++)
{
shared_ptr<SparseInputStreamBuffer> sequenceDataPtr = make_shared<SparseInputStreamBuffer>();
offset += GetSequenceData((char*)data + offset, sequenceDataPtr);
sequenceDataPtr->m_sampleLayout = GetSampleLayout();
sequenceDataPtr->m_elementType = m_precision;
result[i] = sequenceDataPtr;
}
return offset;
}
size_t GetSequenceData(void* data, shared_ptr<SparseInputStreamBuffer>& sequence)
{
size_t valueSize = SizeOfDataType();
size_t offset = 0;
// The very first value in the buffer is the number of samples in this sequence.
sequence->m_numberOfSamples = *(uint32_t*)data;
offset += sizeof(uint32_t);
// Next is the total number of elements in all of the samples.
uint32_t nnz = *(uint32_t*)((char*)data + offset);
if (IndexType(nnz) < 0)
{
RuntimeError("NNZ count is too large for an IndexType value.");
}
sequence->m_totalNnzCount = nnz;
offset += sizeof(uint32_t);
// the rest of this sequence
// Since we're not templating on ElemType, we use void for the values. Note that this is the only place
// this deserializer uses ElemType, the rest are int32_t for this deserializer.
// The data is already properly packed, so just use it.
sequence->m_data = (char*)data + offset;
offset += valueSize * sequence->m_totalNnzCount;
// The indices are supposed to be correctly packed (i.e., in increasing order)
sequence->m_indices = (int32_t*)((char*)data + offset);
offset += sizeof(int32_t) * sequence->m_totalNnzCount;
int32_t* begin = (int32_t*)((char*)data + offset);
offset += sizeof(int32_t) * sequence->m_numberOfSamples;
int32_t* end = (int32_t*)((char*)data + offset);
sequence->m_nnzCounts.reserve(sequence->m_numberOfSamples);
sequence->m_nnzCounts.assign(begin, end);
return offset;
}
};
}}}