Raw File
ImageConfigHelper.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#include "stdafx.h"
#include "ImageConfigHelper.h"
#include "StringUtil.h"
#include "ConfigUtil.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// Builds the stream descriptions and reader options from an ImageReader configuration.
//
// The configuration must contain exactly one feature section (a section that has a
// "width" parameter) and exactly one label section (a section that has a "labelDim"
// parameter); otherwise RuntimeError is raised. The feature stream is registered with
// id 0 and the label stream with id 1, which is what GetFeatureStreamId() and
// GetLabelStreamId() return.
//
// Parameters read:
//   feature section: width, height, channels, mbFormat (default "nchw"), cropType
//   label section:   labelDim
//   reader level:    file, grayscale (default: channels == 1), randomize (default
//                    "auto"), precision (default "float"), numCPUThreads (default 0)
//
// Raises RuntimeError for an unsupported mbFormat, randomize value, precision or crop type.
ImageConfigHelper::ImageConfigHelper(const ConfigParameters& config)
    : m_dataFormat(CHW)
{
    std::vector<std::string> featureNames = GetSectionsWithParameter("ImageReader", config, "width");
    std::vector<std::string> labelNames = GetSectionsWithParameter("ImageReader", config, "labelDim");

    // REVIEW alexeyk: currently support only one feature and label section.
    if (featureNames.size() != 1 || labelNames.size() != 1)
    {
        RuntimeError(
            "ImageReader currently supports a single feature and label stream. '%d' features , '%d' labels found.",
            static_cast<int>(featureNames.size()),
            static_cast<int>(labelNames.size()));
    }

    ConfigParameters featureSection = config(featureNames[0]);
    size_t w = featureSection("width");
    size_t h = featureSection("height");
    size_t c = featureSection("channels");

    // Sample layout: "nhwc" (alias "legacy") selects HWC, while "nchw" (alias "cudnn")
    // keeps the default CHW. The negation must wrap BOTH accepted spellings; negating
    // only the "nchw" test would make the "cudnn" alias fall into the error branch.
    std::string mbFmt = featureSection("mbFormat", "nchw");
    if (AreEqualIgnoreCase(mbFmt, "nhwc") || AreEqualIgnoreCase(mbFmt, "legacy"))
    {
        m_dataFormat = HWC;
    }
    else if (!(AreEqualIgnoreCase(mbFmt, "nchw") || AreEqualIgnoreCase(mbFmt, "cudnn")))
    {
        RuntimeError("ImageReader does not support the sample format '%s', only 'nchw' and 'nhwc' are supported.", mbFmt.c_str());
    }

    // Feature stream: id 0, dense, shaped by the image dimensions in the chosen layout.
    auto features = std::make_shared<StreamDescription>();
    features->m_id = 0;
    features->m_name = msra::strfun::utf16(featureSection.ConfigName());
    features->m_sampleLayout = std::make_shared<TensorShape>(ImageDimensions(w, h, c).AsTensorShape(m_dataFormat));
    features->m_storageType = StorageType::dense;
    m_streams.push_back(features);

    // Label stream: id 1, dense, one dimension of size labelDim.
    ConfigParameters label = config(labelNames[0]);
    size_t labelDimension = label("labelDim");

    auto labelSection = std::make_shared<StreamDescription>();
    labelSection->m_id = 1;
    labelSection->m_name = msra::strfun::utf16(label.ConfigName());
    labelSection->m_sampleLayout = std::make_shared<TensorShape>(labelDimension);
    labelSection->m_storageType = StorageType::dense;
    m_streams.push_back(labelSection);

    m_mapPath = config(L"file");

    // Grayscale loading defaults to on when the feature section declares a single channel.
    m_grayscale = config(L"grayscale", c == 1);

    // "auto" enables randomization, "none" disables it; nothing else is accepted.
    std::string rand = config(L"randomize", "auto");
    if (AreEqualIgnoreCase(rand, "auto"))
    {
        m_randomize = true;
    }
    else if (AreEqualIgnoreCase(rand, "none"))
    {
        m_randomize = false;
    }
    else
    {
        RuntimeError("'randomize' parameter must be set to 'auto' or 'none'");
    }

    // Identify precision. The element type is set on the same objects that were
    // already pushed into m_streams above, so the registered descriptions see it.
    string precision = config.Find("precision", "float");
    if (AreEqualIgnoreCase(precision, "float"))
    {
        features->m_elementType = ElementType::tfloat;
        labelSection->m_elementType = ElementType::tfloat;
    }
    else if (AreEqualIgnoreCase(precision, "double"))
    {
        features->m_elementType = ElementType::tdouble;
        labelSection->m_elementType = ElementType::tdouble;
    }
    else
    {
        RuntimeError("Not supported precision '%s'. Expected 'double' or 'float'.", precision.c_str());
    }

    // NOTE(review): 0 presumably means "let the reader choose the thread count" — confirm with the consumer.
    m_cpuThreadCount = config(L"numCPUThreads", 0);

    // An absent cropType reaches ParseCropType as "" and selects center cropping.
    m_cropType = ParseCropType(featureSection(L"cropType", ""));
}

std::vector<StreamDescriptionPtr> ImageConfigHelper::GetStreams() const
{
    return m_streams;
}

// Id of the feature stream. Only a single feature/label pair is supported and the
// constructor registers the feature stream with m_id = 0, so the value is fixed.
size_t ImageConfigHelper::GetFeatureStreamId() const
{
    const size_t featureStreamId = 0;
    return featureStreamId;
}

// Id of the label stream. Only a single feature/label pair is supported and the
// constructor registers the label stream with m_id = 1, so the value is fixed.
size_t ImageConfigHelper::GetLabelStreamId() const
{
    const size_t labelStreamId = 1;
    return labelStreamId;
}

// Returns (by value) the map file path taken from the reader's "file" parameter.
std::string ImageConfigHelper::GetMapPath() const
{
    std::string path = m_mapPath;
    return path;
}

// Maps a cropType configuration value to a CropType, ignoring case.
// An empty value selects CropType::Center; recognized names are "center", "random"
// and "multiview10". Any other value raises RuntimeError.
CropType ImageConfigHelper::ParseCropType(const std::string &src)
{
    if (src.empty())
    {
        return CropType::Center;
    }

    struct CropTypeName
    {
        const char* name;
        CropType type;
    };

    const CropTypeName knownCropTypes[] = {
        { "center", CropType::Center },
        { "random", CropType::Random },
        { "multiview10", CropType::MultiView10 },
    };

    for (const auto& entry : knownCropTypes)
    {
        if (AreEqualIgnoreCase(src, entry.name))
        {
            return entry.type;
        }
    }

    RuntimeError("Invalid crop type: %s.", src.c_str());
}

}}}
back to top