https://github.com/Microsoft/CNTK
Tip revision: 16a41cef30894ca92667bd93079cd6fa11b3e92d authored by Sayan Pathak on 02 November 2017, 16:10:10 UTC
Added super resolution tutorial contributed by Borna with added code to minimize test downloads, fix tests, added documentation and small editorial changes to LSGAN tutorial
Added super resolution tutorial contributed by Borna with added code to minimize test downloads, fix tests, added documentation and small editorial changes to LSGAN tutorial
Tip revision: 16a41ce
ComputeInputStatistics.cpp
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "CNTKLibrary.h"
#include "Utils.h"
#include "CompositeFunction.h"
#include <tuple>
#include "ComputationNetworkBuilder.h"
using namespace Microsoft::MSR::CNTK;
namespace CNTK
{
void ComputeInputPerDimMeansAndInvStdDevs(const MinibatchSourcePtr& minibatchSource,
std::unordered_map<StreamInformation, std::pair<NDArrayViewPtr, NDArrayViewPtr>>& computedMeanAndInvStdDevs,
const DeviceDescriptor& device /*= DeviceDescriptor::CPUDevice()*/)
{
typedef std::shared_ptr<ComputationNode<float>> ComputationNodePtr;
const auto& minibatchSourceStreams = minibatchSource->StreamInfos();
auto computationNetwork = std::make_shared<ComputationNetwork>(AsCNTKImplDeviceId(device));
ComputationNetworkBuilder<float> builder(*computationNetwork);
std::vector<ComputationNodeBasePtr> allInputNodes;
std::unordered_map<StreamInformation, ComputationNodeBasePtr> streamToInputNodeMap;
std::unordered_map<StreamInformation, Variable> streamToDummyInputVariableMap;
std::unordered_map<StreamInformation, Variable> streamToDummyOutputVariableMap;
std::unordered_map<StreamInformation, ComputationNodeBasePtr> streamToMeanNodeMap;
std::unordered_map<StreamInformation, ComputationNodeBasePtr> streamToInvStdDevNodeMap;
size_t totalSizePerSample = 0;
for (auto& currentStreamKV : computedMeanAndInvStdDevs)
{
auto currentStreamInfo = currentStreamKV.first;
if (minibatchSourceStreams.find(currentStreamInfo) == minibatchSourceStreams.end())
InvalidArgument("Stream '%S' for which mean and variance are to be computed, is not supported by the specified minibatchSource.", currentStreamKV.first.AsString().c_str());
if (currentStreamInfo.m_elementType != DataType::Float)
LogicError("ComputeInputPerDimMeansAndInvStdDevs: Stream '%S' has unsupported DataType; only DataType::Float is currently supported by the CNTK built-in composite MinibatchSource.",
currentStreamInfo.AsString().c_str());
auto inputVariableShape = currentStreamInfo.m_sampleLayout;
auto inputTensorShape = AsTensorShape(inputVariableShape);
totalSizePerSample += (inputVariableShape.TotalSize() * sizeof(float));
ComputationNodePtr inputNode;
Variable inputVariable;
if (currentStreamInfo.m_storageFormat != StorageFormat::Dense)
{
inputNode = builder.CreateSparseInputNode(currentStreamInfo.m_name, inputTensorShape);
inputVariable = InputVariable(inputVariableShape, /*isSparse*/ true, DataType::Float, currentStreamInfo.m_name);
}
else
{
inputNode = builder.CreateInputNode(currentStreamInfo.m_name, inputTensorShape);
inputVariable = InputVariable(inputVariableShape, DataType::Float, currentStreamInfo.m_name);
}
allInputNodes.push_back(inputNode);
streamToInputNodeMap[currentStreamInfo] = inputNode;
streamToDummyInputVariableMap[currentStreamInfo] = inputVariable;
streamToDummyOutputVariableMap[currentStreamInfo] = OutputVariable(inputVariableShape, DataType::Float, {}, /*needsGradient =*/ false, currentStreamInfo.m_name);
streamToMeanNodeMap[currentStreamInfo] = builder.Mean(inputNode);
streamToInvStdDevNodeMap[currentStreamInfo] = builder.InvStdDev(inputNode);
}
computationNetwork->CompileNetwork();
computationNetwork->AllocateAllMatrices(computationNetwork->RootNodes(), {}, nullptr);
ScopedNetworkOperationMode modeGuard(computationNetwork, NetworkOperationMode::preComputing);
// initialize
auto preComputeNodes = computationNetwork->GetNodesRequiringPreComputation();
for (auto & preComputeNode : preComputeNodes)
dynamic_pointer_cast<IPreComputeNode>(preComputeNode)->MarkComputed(false /*begin accumulating*/);
std::unordered_map<MBLayoutPtr, Variable> layoutsPopulated;
const size_t maxMinibatchDataSize = (1 << 27); // 128 MB
const size_t minibatchSize = maxMinibatchDataSize / totalSizePerSample;
for (;;)
{
auto minibatchData = minibatchSource->GetNextMinibatch(minibatchSize, device);
if (minibatchData.empty())
break;
for (auto& currentStreamKV : computedMeanAndInvStdDevs)
CompositeFunction::PopulateComputationNodeValue<float>({ streamToDummyInputVariableMap[currentStreamKV.first], minibatchData[currentStreamKV.first].data }, streamToInputNodeMap[currentStreamKV.first], layoutsPopulated);
ComputationNetwork::BumpEvalTimeStamp(allInputNodes);
computationNetwork->ForwardProp(preComputeNodes);
}
// finalize
for (auto & preComputeNode : preComputeNodes)
dynamic_pointer_cast<IPreComputeNode>(preComputeNode)->MarkComputed(true /*done accumulating*/);
// Copy out the results
for (auto& currentStreamKV : computedMeanAndInvStdDevs)
{
ValuePtr mean, invStdDev;
if (computedMeanAndInvStdDevs[currentStreamKV.first].first != nullptr)
mean = MakeSharedObject<Value>(computedMeanAndInvStdDevs[currentStreamKV.first].first);
if (computedMeanAndInvStdDevs[currentStreamKV.first].second != nullptr)
invStdDev = MakeSharedObject<Value>(computedMeanAndInvStdDevs[currentStreamKV.first].second);
CompositeFunction::GetNodeOutputOrGradient(streamToDummyOutputVariableMap[currentStreamKV.first], mean, streamToMeanNodeMap[currentStreamKV.first], false /*getGradient*/);
CompositeFunction::GetNodeOutputOrGradient(streamToDummyOutputVariableMap[currentStreamKV.first], invStdDev, streamToInvStdDevNodeMap[currentStreamKV.first], false /*getGradient*/);
if (computedMeanAndInvStdDevs[currentStreamKV.first].first == nullptr)
computedMeanAndInvStdDevs[currentStreamKV.first].first = mean->Data();
if (computedMeanAndInvStdDevs[currentStreamKV.first].second == nullptr)
computedMeanAndInvStdDevs[currentStreamKV.first].second = invStdDev->Data();
}
}
}