swh:1:snp:f50ab94432af916b5fb8b4ad831e8dddded77084
Tip revision: 1c2bc5a7783570f55330bbfdeec43d94477f4d2d authored by Marko Radmilac on 28 April 2016, 17:21:13 UTC
Add support for new CU
Add support for new CU
Tip revision: 1c2bc5a
ImageDataDeserializer.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#define __STDC_FORMAT_MACROS
#include <inttypes.h>
#include <opencv2/opencv.hpp>
#include <numeric>
#include <limits>
#include "ImageDataDeserializer.h"
#include "ImageConfigHelper.h"
namespace Microsoft { namespace MSR { namespace CNTK {
// Interface for objects that produce the label of an image.
// Implementations fill the supplied sparse sequence for a given class id.
class ImageDataDeserializer::LabelGenerator
{
public:
    virtual ~LabelGenerator() { }

    // Fills 'data' with the label that corresponds to 'classId'.
    virtual void CreateLabelFor(size_t classId, SparseSequenceData& data) = 0;
};
// A helper class to generate a typed label in a sparse format.
// A label is just a category/class the image belongs to.
// It is represented as an array indexed by the category with zero values for all categories the image does not belong to,
// and a single one for a category it belongs to: [ 0, .. 0.. 1 .. 0 ]
// The class is parameterized because the representation of 1 is type specific.
template <class TElement>
class TypedLabelGenerator : public ImageDataDeserializer::LabelGenerator
{
public:
    // Builds the index table 0, 1, ..., labelDimension - 1 that CreateLabelFor hands out pointers into.
    // Reports an error through RuntimeError if labelDimension does not fit into IndexType.
    explicit TypedLabelGenerator(size_t labelDimension) : m_value(1)
    {
        // Validate first: allocating the table for an out-of-range dimension could fail
        // with an allocation error before the descriptive message below is ever reported.
        if (labelDimension > static_cast<size_t>(std::numeric_limits<IndexType>::max()))
        {
            RuntimeError("Label dimension (%" PRIu64 ") exceeds the maximum allowed "
                         "value (%" PRIu64 ")\n", labelDimension, static_cast<size_t>(std::numeric_limits<IndexType>::max()));
        }

        m_indices.resize(labelDimension);
        // Seed with an IndexType value so the running counter has the element type
        // (an int literal would make the counter an int regardless of IndexType).
        std::iota(m_indices.begin(), m_indices.end(), static_cast<IndexType>(0));
    }

    // Describes a single non-zero entry (the value 1 at position classId) in 'data'.
    // 'data' receives pointers into this generator's members, so the generator has to
    // stay alive for as long as 'data' is in use.
    // classId must be smaller than the label dimension passed to the constructor.
    virtual void CreateLabelFor(size_t classId, SparseSequenceData& data) override
    {
        assert(classId < m_indices.size());

        data.m_nnzCounts.resize(1);
        data.m_nnzCounts[0] = 1;
        data.m_totalNnzCount = 1;
        data.m_data = &m_value;
        data.m_indices = &(m_indices[classId]);
    }

private:
    // The single non-zero value (1) shared by all labels.
    TElement m_value;
    // m_indices[i] == i; a label for class i points at element i.
    std::vector<IndexType> m_indices;
};
// Dense sequence that additionally holds the decoded OpenCV image.
// Consumers access it only through the DenseSequenceData interface.
struct DeserializedImage : public DenseSequenceData
{
    cv::Mat m_image;
};
// Chunk of the image reader. Each chunk wraps exactly one image.
class ImageDataDeserializer::ImageChunk : public Chunk, public std::enable_shared_from_this<ImageChunk>
{
    // Private copy of the description of the image served by this chunk.
    ImageSequenceDescription m_description;
    // Deserializer that created the chunk; used for reading images and creating labels.
    ImageDataDeserializer& m_parent;

public:
    ImageChunk(ImageSequenceDescription& description, ImageDataDeserializer& parent)
        : m_description(description), m_parent(parent)
    {
    }

    // Loads the image and appends two entries to 'result':
    // first the dense image sequence, then the sparse label sequence.
    virtual void GetSequence(size_t sequenceId, std::vector<SequenceDataPtr>& result) override
    {
        assert(sequenceId == m_description.m_id);

        // Both sequences keep the chunk alive through a shared pointer to it.
        const auto self = shared_from_this();

        auto image = std::make_shared<DeserializedImage>();
        image->m_image = m_parent.ReadImage(m_description.m_id, m_description.m_path, m_parent.m_grayscale);

        cv::Mat& pixels = image->m_image;
        if (pixels.data == nullptr)
        {
            RuntimeError("Cannot open file '%s'", m_description.m_path.c_str());
        }

        // Bring the pixels to the element type of the feature stream.
        const int depth = (m_parent.m_featureElementType == ElementType::tfloat) ? CV_32F : CV_64F;
        if (pixels.type() != CV_MAKETYPE(depth, pixels.channels()))
        {
            pixels.convertTo(pixels, depth);
        }

        // The raw data pointer is handed out below, so the buffer must be one contiguous block.
        if (!pixels.isContinuous())
        {
            pixels = pixels.clone();
        }
        assert(pixels.isContinuous());

        const ImageDimensions dimensions(pixels.cols, pixels.rows, pixels.channels());
        image->m_sampleLayout = std::make_shared<TensorShape>(dimensions.AsTensorShape(HWC));
        image->m_data = pixels.data;
        image->m_numberOfSamples = 1;
        image->m_id = sequenceId;
        image->m_chunk = self;
        result.push_back(image);

        SparseSequenceDataPtr label = std::make_shared<SparseSequenceData>();
        label->m_chunk = self;
        m_parent.m_labelGenerator->CreateLabelFor(m_description.m_classId, *label);
        label->m_numberOfSamples = 1;
        result.push_back(label);
    }
};
// Reads the reader configuration, sets up the feature and label streams,
// selects the label generator that matches the label element type and
// loads the sequence descriptions from the map file.
ImageDataDeserializer::ImageDataDeserializer(const ConfigParameters& config)
{
    ImageConfigHelper configHelper(config);

    m_streams = configHelper.GetStreams();
    assert(m_streams.size() == 2);
    m_grayscale = configHelper.UseGrayscale();

    const auto& labelStream = m_streams[configHelper.GetLabelStreamId()];
    const auto& featureStream = m_streams[configHelper.GetFeatureStreamId()];

    // Expect data in HWC.
    ImageDimensions featureDimensions(*featureStream->m_sampleLayout, configHelper.GetDataFormat());
    featureStream->m_sampleLayout = std::make_shared<TensorShape>(featureDimensions.AsTensorShape(HWC));

    labelStream->m_storageType = StorageType::sparse_csc;
    featureStream->m_storageType = StorageType::dense;

    m_featureElementType = featureStream->m_elementType;

    const size_t labelDimension = labelStream->m_sampleLayout->GetDim(0);
    switch (labelStream->m_elementType)
    {
    case ElementType::tfloat:
        m_labelGenerator = std::make_shared<TypedLabelGenerator<float>>(labelDimension);
        break;
    case ElementType::tdouble:
        m_labelGenerator = std::make_shared<TypedLabelGenerator<double>>(labelDimension);
        break;
    default:
        RuntimeError("Unsupported label element type '%d'.", (int)labelStream->m_elementType);
    }

    CreateSequenceDescriptions(configHelper.GetMapPath(), labelDimension, configHelper);
}
// Descriptions of chunks exposed by the image reader:
// one per image sequence, each consisting of one sequence with one sample.
ChunkDescriptions ImageDataDeserializer::GetChunkDescriptions()
{
    const size_t sequenceCount = m_imageSequences.size();

    ChunkDescriptions descriptions;
    descriptions.reserve(sequenceCount);
    for (size_t i = 0; i < sequenceCount; ++i)
    {
        auto entry = std::make_shared<ChunkDescription>();
        entry->m_numberOfSequences = 1;
        entry->m_numberOfSamples = 1;
        entry->m_id = m_imageSequences[i].m_chunkId;
        descriptions.push_back(entry);
    }
    return descriptions;
}
// Appends the description of the sequence stored in chunk 'chunkId' to 'result'.
// chunkId is used directly as an index into m_imageSequences.
void ImageDataDeserializer::GetSequencesForChunk(size_t chunkId, std::vector<SequenceDescription>& result)
{
    // A chunk currently holds exactly one sequence.
    const auto& onlySequence = m_imageSequences[chunkId];
    result.push_back(onlySequence);
}
// Parses the map file and fills m_imageSequences.
// Every line must hold two tab-delimited columns: the image path and the numeric class id.
// The class id has to be smaller than labelDimension. Each line yields one sequence
// description, or ten when multi-view cropping is enabled; ids are assigned consecutively
// starting at 0 and each description gets its own chunk id.
// Problems with the file or its content are reported through RuntimeError.
void ImageDataDeserializer::CreateSequenceDescriptions(std::string mapPath, size_t labelDimension, const ImageConfigHelper& config)
{
    std::ifstream mapFile(mapPath);
    if (!mapFile)
    {
        RuntimeError("Could not open %s for reading.", mapPath.c_str());
    }

    const size_t itemsPerLine = config.IsMultiViewCrop() ? 10 : 1;
    PathReaderMap knownReaders;

    // Working description; the fields that are the same for every sequence are set once.
    ImageSequenceDescription description;
    description.m_numberOfSamples = 1;
    description.m_isValid = true;
    description.m_key.m_sample = 0;

    size_t nextId = 0;
    size_t lineIndex = 0;
    std::string line;
    while (std::getline(mapFile, line))
    {
        std::stringstream columns(line);
        std::string imagePath;
        std::string classIdText;
        const bool hasPath = static_cast<bool>(std::getline(columns, imagePath, '\t'));
        const bool hasClassId = hasPath && static_cast<bool>(std::getline(columns, classIdText, '\t'));
        if (!hasClassId)
            RuntimeError("Invalid map file format, must contain 2 tab-delimited columns, line %" PRIu64 " in file %s.", lineIndex, mapPath.c_str());

        // strtoull signals "nothing parsed" by leaving the end pointer at the start
        // and "out of range" through errno.
        errno = 0;
        char* parseEnd = nullptr;
        const size_t classId = strtoull(classIdText.c_str(), &parseEnd, 10);
        const bool nothingParsed = (parseEnd == classIdText.c_str());
        if (nothingParsed || errno == ERANGE)
            RuntimeError("Cannot parse label value on line %" PRIu64 ", second column, in file %s.", lineIndex, mapPath.c_str());

        if (classId >= labelDimension)
        {
            RuntimeError(
                "Image '%s' has invalid class id '%" PRIu64 "'. Expected label dimension is '%" PRIu64 "'. Line %" PRIu64 " in file %s.",
                imagePath.c_str(), classId, labelDimension, lineIndex, mapPath.c_str());
        }

        // Path and class are shared by all items generated from this line.
        description.m_path = imagePath;
        description.m_classId = classId;
        for (size_t item = 0; item < itemsPerLine; ++item, ++nextId)
        {
            description.m_id = nextId;
            description.m_chunkId = nextId;
            description.m_key.m_sequence = description.m_id;
            m_imageSequences.push_back(description);
            RegisterByteReader(description.m_id, description.m_path, knownReaders);
        }

        ++lineIndex;
    }
}
// Creates the chunk for 'chunkId'.
// chunkId is used directly as an index into m_imageSequences and is not range-checked here.
ChunkPtr ImageDataDeserializer::GetChunk(size_t chunkId)
{
    // ImageChunk stores its own copy of the description, so the stored element is passed
    // directly instead of going through an additional temporary copy.
    return std::make_shared<ImageChunk>(m_imageSequences[chunkId], *this);
}
// Associates sequence 'seqId' with a container reader when 'path' refers to an item inside a container.
// A path without '@' is a plain image file and needs no registration.
// Otherwise the path is interpreted as <container>@<separator><item>: the part before '@'
// is the container and the part after '@' and one separator character is the item.
// Readers are shared per container through 'knownReaders'.
// Malformed container paths, or container paths in a build without zip support,
// are reported through RuntimeError.
void ImageDataDeserializer::RegisterByteReader(size_t seqId, const std::string& path, PathReaderMap& knownReaders)
{
    assert(!path.empty());

    auto atPos = path.find_first_of('@');
    // Is it container or plain image file?
    if (atPos == std::string::npos)
        return;
    // REVIEW alexeyk: only .zip container support for now.
#ifdef USE_ZIP
    // The path originates from the map file, so it is validated at run time:
    // an assert is compiled out of release builds, where a path ending in '@'
    // would otherwise surface as an opaque out-of-range error from substr below.
    if (atPos == 0 || atPos + 1 >= path.length())
        RuntimeError("Invalid container path '%s': expected <container>@<separator><item>.", path.c_str());

    auto containerPath = path.substr(0, atPos);
    // skip @ symbol and path separator (/ or \)
    auto itemPath = path.substr(atPos + 2);
    // zlib only supports / as path separator.
    std::replace(begin(itemPath), end(itemPath), '\\', '/');

    // Reuse the reader of a container that has been seen before; create one otherwise.
    std::shared_ptr<ByteReader> reader;
    auto r = knownReaders.find(containerPath);
    if (r == knownReaders.end())
    {
        reader = std::make_shared<ZipByteReader>(containerPath);
        knownReaders[containerPath] = reader;
    }
    else
    {
        reader = (*r).second;
    }
    reader->Register(seqId, itemPath);
    m_readers[seqId] = reader;
#else
    UNUSED(seqId);
    UNUSED(knownReaders);
    RuntimeError("The code is built without zip container support. Only plain image files are supported.");
#endif
}
// Reads the image of sequence 'seqId'.
// A reader registered for the sequence takes precedence; otherwise the default reader is used.
cv::Mat ImageDataDeserializer::ReadImage(size_t seqId, const std::string& path, bool grayscale)
{
    assert(!path.empty());

    // Skip the lookup altogether when no readers have been registered.
    if (!m_readers.empty())
    {
        const auto registered = m_readers.find(seqId);
        if (registered != m_readers.end())
            return registered->second->Read(seqId, path, grayscale);
    }
    return m_defaultReader.Read(seqId, path, grayscale);
}
// Loads the image stored at 'path' with OpenCV, as grayscale or as color.
// The sequence id is not needed for plain files.
cv::Mat FileByteReader::Read(size_t, const std::string& path, bool grayscale)
{
    assert(!path.empty());

    const int readMode = grayscale ? cv::IMREAD_GRAYSCALE : cv::IMREAD_COLOR;
    return cv::imread(path, readMode);
}
}}}