//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "CNTKLibrary.h"
#include "CommonMatrix.h"
#include "TensorShape.h"
#include <string>
#include "Config.h"
#include "ConvolutionEngine.h"
#include "ReshapingNodes.h"

namespace CNTK
{
    // Forward declarations
    class Dictionary;

    // Helper to get the size of an element of the specified DataType
    inline size_t ElementSize(DataType dataType)
    {
        if (dataType == DataType::Float)
            return sizeof(float);
        else if (dataType == DataType::Double)
            return sizeof(double);
        else
            NOT_IMPLEMENTED;
    }

    // Checks whether 'ptrToObject' once referred to a live object that has since been destroyed,
    // as opposed to a weak_ptr that was never assigned an object at all.
    template <typename T>
    inline bool IsObjectExpired(std::weak_ptr<T> ptrToObject)
    {
        // owner_before in both directions distinguishes a default-constructed weak_ptr
        // (equivalent ownership) from one whose referent has expired
        return (ptrToObject.owner_before(std::weak_ptr<T>{}) || std::weak_ptr<T>{}.owner_before(ptrToObject)) && ptrToObject.expired();
    }

    inline DEVICEID_TYPE AsCNTKImplDeviceId(const DeviceDescriptor& device)
    {
        if (device.Type() == DeviceKind::CPU)
            return CPUDEVICE;
        if (device.Type() == DeviceKind::GPU)
            return device.Id();

        LogicError("Invalid device type (%u).", (unsigned int)device.Type());
    }

    inline Microsoft::MSR::CNTK::MatrixFormat AsCNTKImplMatrixFormat(StorageFormat storageFormat)
    {
        if (storageFormat == StorageFormat::Dense)
            return Microsoft::MSR::CNTK::MatrixFormat::matrixFormatDense;
        else if (storageFormat == StorageFormat::SparseCSC)
            return Microsoft::MSR::CNTK::MatrixFormat::matrixFormatSparseCSC;
        else if (storageFormat == StorageFormat::SparseBlockCol)
            return Microsoft::MSR::CNTK::MatrixFormat::matrixFormatSparseBlockCol;
        else
            NOT_IMPLEMENTED;
    }

    inline StorageFormat AsStorageFormat(Microsoft::MSR::CNTK::MatrixFormat matrixFormat)
    {
        if (matrixFormat == Microsoft::MSR::CNTK::MatrixFormat::matrixFormatDense)
            return StorageFormat::Dense;
        else if (matrixFormat == Microsoft::MSR::CNTK::MatrixFormat::matrixFormatSparseCSC)
            return StorageFormat::SparseCSC;
        else if (matrixFormat == Microsoft::MSR::CNTK::MatrixFormat::matrixFormatSparseBlockCol)
            return StorageFormat::SparseBlockCol;
        else
            NOT_IMPLEMENTED;
    }

    inline DeviceDescriptor AsDeviceDescriptor(DEVICEID_TYPE deviceId)
    {
        if (deviceId == CPUDEVICE)
            return DeviceDescriptor::CPUDevice();
        else
            return DeviceDescriptor::GPUDevice(deviceId);
    }

    inline NDShape AsNDShape(const Microsoft::MSR::CNTK::TensorShape& tensorShape, bool allowNonFlattenableTensorShapes = false)
    {
        if (!allowNonFlattenableTensorShapes)
        {
            // The TensorShape should be flattenable to 1D
            for (size_t i = 1; i < tensorShape.GetRank(); ++i)
            {
                if (!tensorShape.CanFlatten(i))
                    InvalidArgument("AsNDShape() can only be called for TensorShapes that can be flattened to 1D.");
            }
        }

        return std::vector<size_t>(tensorShape.GetDims().begin(), tensorShape.GetDims().end());
    }
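
    // AsNDShape converts, for example, a TensorShape [2 x 3 x 5] into the NDShape { 2, 3, 5 };
    // TensorShapes that cannot be flattened to 1D (e.g. strided views) are rejected unless
    // 'allowNonFlattenableTensorShapes' is true.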

    inline Microsoft::MSR::CNTK::TensorShape AsTensorShape(const NDShape& viewShape)
    {
        const size_t maxNumAxesSupportedByTensorView = 12;
        if (viewShape.Rank() > maxNumAxesSupportedByTensorView)
            LogicError("The number (%d) of requested axes exceeds the currently supported limit (%d)", (int)viewShape.Rank(), (int)maxNumAxesSupportedByTensorView);

        // Note: no minimum rank is enforced here (with minRankSize == 0, the std::max below is a no-op)
        size_t minRankSize = 0;
        Microsoft::MSR::CNTK::SmallVector<size_t> tensorViewShape(std::max<size_t>(minRankSize, viewShape.Rank()));
        for (size_t i = 0; i < tensorViewShape.size(); ++i)
            tensorViewShape[i] = (i < viewShape.Rank()) ? viewShape[i] : 1;

        return tensorViewShape;
    }

    inline Microsoft::MSR::CNTK::TensorShape AsTensorViewShape(const Microsoft::MSR::CNTK::TensorShape& viewShape)
    {
        // For TensorView shapes we pad the TensorShape to be at least rank 2
        return viewShape.PadRank(std::max<size_t>(2, viewShape.GetRank()));
    }

    inline Microsoft::MSR::CNTK::TensorShape AsTensorViewShape(const NDShape& viewShape)
    {
        return AsTensorViewShape(AsTensorShape(viewShape));
    }
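
    // AsTensorViewShape pads, for example, the 1D shape { 5 } to the rank-2 TensorShape [5 x 1];
    // shapes that already have rank 2 or higher pass through unchanged.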

    inline std::pair<size_t, size_t> GetMatrixDimensions(const NDShape& viewShape)
    {
        // Ensure none of the shape dimensions are unknown
        if (viewShape.HasUnboundDimension())
            InvalidArgument("Cannot create an NDArrayView using a view shape '%S' that has unknown dimensions for any of its axes.", viewShape.AsString().c_str());

        size_t matrixRowSize = (viewShape.Rank() > 0) ? viewShape[0] : 1;
        size_t matrixColSize = (viewShape.Rank() > 0) ? viewShape.SubShape(1).TotalSize() : 1;

        return{ matrixRowSize, matrixColSize };
    }
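
    // For example, GetMatrixDimensions({ 3, 4, 5 }) yields (3, 20): the first axis becomes the
    // matrix row dimension and all remaining axes are folded into the column dimension.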

    inline bool IsSparseInput(const Variable& var)
    {
        return var.IsInput() && var.IsSparse();
    }


    inline void AddIndentation(std::wstringstream& s, size_t numIndentationSpaces)
    {
        for (size_t i = 0; i < numIndentationSpaces; ++i)
            s << L" ";
    }

    static const size_t perLevelIndentSize = 4;
    inline void AddConfigString(std::wstringstream& s, const std::wstring& key, const DictionaryValue& value, size_t numIndentationSpaces);
    inline void AddConfigString(std::wstringstream& s, const DictionaryValue& value, size_t numIndentationSpaces)
    {
        switch (value.ValueType())
        {
        case DictionaryValue::Type::Bool:
            s << value.Value<bool>();
            break;
        case DictionaryValue::Type::Float:
            s << value.Value<float>();
            break;
        case DictionaryValue::Type::Double:
            s << value.Value<double>();
            break;
        case DictionaryValue::Type::String:
            s << value.Value<std::wstring>();
            break;
        case DictionaryValue::Type::Int:
            s << value.Value<int>();
            break;
        case DictionaryValue::Type::SizeT:
            s << value.Value<size_t>();
            break;
        case DictionaryValue::Type::Vector:
        {
            const auto& valueVector = value.Value<std::vector<DictionaryValue>>();
            s << L"(" << std::endl;
            AddIndentation(s, numIndentationSpaces + perLevelIndentSize);
            bool isFirst = true;
            for (const auto& val : valueVector)
            {
                if (!isFirst)
                    s << L":";
                else
                    isFirst = false;

                AddConfigString(s, val, numIndentationSpaces + perLevelIndentSize);
            }
            AddIndentation(s, numIndentationSpaces);
            s << L")";
            break;
        }
        case DictionaryValue::Type::Dictionary:
        {
            const auto& valueDictionary = value.Value<Dictionary>();
            s << L"[" << std::endl;
            for (const auto& keyValuePair : *(valueDictionary.m_dictionaryData))
            {
                AddConfigString(s, keyValuePair.first, keyValuePair.second, numIndentationSpaces + perLevelIndentSize);
            }
            AddIndentation(s, numIndentationSpaces);
            s << L"]";
            break;
        }
        default:
            LogicError("Unsupported DictionaryValue type");
        }
    }

    inline void AddConfigString(std::wstringstream& s, const std::wstring& key, const DictionaryValue& value, size_t numIndentationSpaces)
    {
        AddIndentation(s, numIndentationSpaces);
        s << key << L" = ";
        AddConfigString(s, value, numIndentationSpaces);
        s << std::endl;
    }

    template <typename T>
    inline std::vector<DictionaryValue> AsDictionaryValueVector(const std::vector<T>& elementVector)
    {
        static_assert(std::is_same<T, bool>::value ||
                      std::is_same<T, int>::value ||
                      std::is_same<T, size_t>::value ||
                      std::is_same<T, float>::value ||
                      std::is_same<T, double>::value ||
                      std::is_same<T, Axis>::value ||
                      std::is_same<T, std::wstring>::value,
                      "Unsupported ValueType");

        std::vector<DictionaryValue> dictionaryValueVector;
        for (auto value : elementVector)
            dictionaryValueVector.push_back(value);

        return dictionaryValueVector;
    }

    template <typename T>
    inline std::vector<T> AsVector(const std::vector<DictionaryValue>& dictionaryValueVector)
    {
        static_assert(std::is_same<T, bool>::value ||
                      std::is_same<T, int>::value || 
                      std::is_same<T, size_t>::value ||
                      std::is_same<T, float>::value ||
                      std::is_same<T, double>::value ||
                      std::is_same<T, Axis>::value ||
                      std::is_same<T, std::wstring>::value,
                      "Unsupported ValueType");

        std::vector<T> elementVector;
        for (auto value : dictionaryValueVector)
            elementVector.push_back(value.Value<T>());

        return elementVector;
    }

    inline PoolingType AsPoolingType(Microsoft::MSR::CNTK::PoolKind cntkPoolingKind)
    {
        switch (cntkPoolingKind)
        {
        case Microsoft::MSR::CNTK::PoolKind::Average:
            return PoolingType::Average;
        case Microsoft::MSR::CNTK::PoolKind::Max:
            return PoolingType::Max;
        default:
            LogicError("Unknown pooling type");
        }
    }

    inline Microsoft::MSR::CNTK::PoolKind AsCNTKPoolKind(PoolingType poolingType)
    {
        switch (poolingType)
        {
        case PoolingType::Average:
            return Microsoft::MSR::CNTK::PoolKind::Average;
        case PoolingType::Max:
            return Microsoft::MSR::CNTK::PoolKind::Max;
        default:
            LogicError("Unknown pooling type");
        }
    }

    static int const CNTKInternalIdxValueForAllStaticAxes = Microsoft::MSR::CNTK::ReduceElementsNode<float>::CNTKInternalIdxValueForAllStaticAxes;
    static int const CNTKInternalIdxValueForAllAxes = Microsoft::MSR::CNTK::ReduceElementsNode<float>::CNTKInternalIdxValueForAllAxes;
    static int const CNTKInternalIdxValueForSequenceAxis = Microsoft::MSR::CNTK::ReduceElementsNode<float>::CNTKInternalIdxValueForSequenceAxis;
    static int const CNTKInternalIdxValueForBatchAxis = Microsoft::MSR::CNTK::ReduceElementsNode<float>::CNTKInternalIdxValueForBatchAxis;

    inline Axis AsAxis(int CNTKInternalAxisIdx)
    {
        if (CNTKInternalAxisIdx == CNTKInternalIdxValueForAllStaticAxes)
            return Axis::AllStaticAxes();

        if (CNTKInternalAxisIdx == CNTKInternalIdxValueForAllAxes)
            return Axis::AllAxes();

        if (CNTKInternalAxisIdx == CNTKInternalIdxValueForSequenceAxis)
            return Axis::OperandSequenceAxis();

        if (CNTKInternalAxisIdx == CNTKInternalIdxValueForBatchAxis)
            return Axis::DefaultBatchAxis();

        return Axis(CNTKInternalAxisIdx - 1);
    }
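
    // Note that the static-axis mapping is off by one: e.g. AsAxis(1) yields Axis(0), the first
    // V2 static axis.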

    inline std::vector<Axis> AsAxis(std::vector<int> CNTKInternalAxis)
    {
        std::vector<Axis> retAxisVec; 
        for (auto& axisIdx : CNTKInternalAxis)
            retAxisVec.push_back(AsAxis(axisIdx));
        return retAxisVec;
    }

    inline int AsCNTKInternalAxisIdx(const Axis& axis)
    {
        if (axis == Axis::AllStaticAxes())
            return CNTKInternalIdxValueForAllStaticAxes;

        if (axis == Axis::AllAxes())
            return CNTKInternalIdxValueForAllAxes;

        if (axis.IsDynamicAxis() && axis.IsOrdered())
            return CNTKInternalIdxValueForSequenceAxis;

        if (axis == Axis::DefaultBatchAxis())
            return CNTKInternalIdxValueForBatchAxis;

        if (!axis.IsStaticAxis())
            LogicError("Only Static Axes can be converted to a CNTK internal axis index");

        return (int)(axis.StaticAxisIndex() + 1);
    }
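
    // Conversely, AsCNTKInternalAxisIdx(Axis(0)) yields 1; the special axis objects map to the
    // dedicated sentinel index values defined above.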

    inline std::vector<int> AsCNTKInternalAxisIdx(const std::vector<Axis>& axisVec)
    {
        std::vector<int> retAxisVec; 
        for (auto& axis : axisVec)
            retAxisVec.push_back(AsCNTKInternalAxisIdx(axis)); 
        return retAxisVec; 
    }

    inline std::pair<NDShape, NDShape> GetConvolutionOutputMapCountAndKernelShape(const NDShape& convolutionMapShape, const NDShape& operandShape, bool transpose)
    {
        NDShape kernelShape = convolutionMapShape.SubShape(0, operandShape.Rank());
        auto outputMapCount = convolutionMapShape.SubShape(kernelShape.Rank());
        auto shapeRank = operandShape.Rank(); 
        NDShape paddedOutputMapCount;
        if (shapeRank > outputMapCount.Rank())
            paddedOutputMapCount = NDShape(shapeRank - outputMapCount.Rank(), 1);

        paddedOutputMapCount = paddedOutputMapCount.AppendShape(outputMapCount);

        if (transpose && (shapeRank > 0) && (paddedOutputMapCount[shapeRank - 1] == NDShape::InferredDimension))  // convolution transpose, the mapCount in depth is derived from operandShape 
        {
            if (operandShape[shapeRank - 1] == NDShape::FreeDimension)
                InvalidArgument("Deconvolution: Output map count cannot be inferred from operand shape '%S' free dimension.", operandShape.AsString().c_str());

            paddedOutputMapCount[shapeRank - 1] = operandShape[shapeRank - 1];
        }

        return{ paddedOutputMapCount, kernelShape };
    }
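
    // For example, given convolutionMapShape { 5, 5, 1, 16 } and an operand of rank 3, the kernel
    // shape is { 5, 5, 1 } and the output map count { 16 } is padded to { 1, 1, 16 } so that its
    // rank matches the operand's.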



    template <typename SourceElementType, typename TargetElementType>
    inline TargetElementType* Copy(const SourceElementType* src, size_t srcSize)
    {
        // Allocate a new buffer, casting each element to the target element type
        TargetElementType* castValue = new TargetElementType[srcSize];
        for (size_t i = 0; i < srcSize; ++i)
            castValue[i] = (TargetElementType)src[i];

        return castValue;
    }

    inline NDArrayViewPtr CloneAsDataType(const NDArrayViewPtr& source, DataType targetDataType, bool readOnly)
    {
        if (source->Device() != DeviceDescriptor::CPUDevice())
            LogicError("CloneAsDataType currently does not support non-CPU source NDArrayView objects");

        auto sourceDataType = source->GetDataType();
        if (sourceDataType == targetDataType)
            LogicError("CloneAsDataType: Source and target DataTypes are the same");

        if (targetDataType != DataType::Double)
            LogicError("CloneAsDataType: Only Double target DataType is supported");

        auto sourceShape = source->Shape();
        auto sourceSize = sourceShape.TotalSize();

        // Cast to double
        double* castValue = Copy<float, double>(source->DataBuffer<float>(), sourceSize);
        return MakeSharedObject<NDArrayView>(sourceShape, castValue, sourceSize, DeviceDescriptor::CPUDevice(), readOnly);
    }

    template <typename T>
    inline std::string Typename(const T* = nullptr)
    {
        auto name = typeid(T).name(); 
        if (strncmp(name, "class ", 6) == 0)
        {
            // On Windows, the type name carries a "class " prefix;
            // return the actual name, omitting the prefix.
            return &name[6];
        }
        return name;
    }

    inline std::wstring ParanthesizedName(const std::wstring& name)
    {
        if (name.empty())
            return name;

        return L"(" + name + L")";
    }

    static const std::wstring UidPrefix = L"__v2libuid__";
    static const std::wstring NamePrefix = L"__v2libname__";

    // 'generateMangledNames' = true is used if we want to emit mangled names for the internal CNTK v1 nodes so that when
    // saving the model in V1 format and loading it back, we can retrieve the original V2 Variable/Function UID and Name.
    inline std::wstring CNTKInternalNodeNameFromUidAndName(const std::wstring& uid, const std::wstring& name, bool generateMangledNames = false)
    {
        if (generateMangledNames)
            return UidPrefix + uid + NamePrefix + name;
        else
            return uid;
    }
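
    // For example, CNTKInternalNodeNameFromUidAndName(L"Plus3", L"z", true) yields
    // L"__v2libuid__Plus3__v2libname__z", whereas with 'generateMangledNames' = false it is just L"Plus3".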

    inline std::pair<std::wstring, std::wstring> UidAndNameFromCNTKInternalNodeName(const std::wstring& CNTKInternalNodeName)
    {
        std::wstring uid, name;
        auto uidPrefixBeginPos = CNTKInternalNodeName.find(UidPrefix);
        if (uidPrefixBeginPos != std::wstring::npos)
        {
            auto uidBeginPos = uidPrefixBeginPos + UidPrefix.length();
            auto namePrefixBeginPos = CNTKInternalNodeName.find(NamePrefix, uidBeginPos);
            if (namePrefixBeginPos == std::wstring::npos)
                LogicError("CNTK internal node name contains a uid prefix but no name prefix!");

            auto nameBeginPos = namePrefixBeginPos + NamePrefix.length();
            uid = CNTKInternalNodeName.substr(uidBeginPos, namePrefixBeginPos - uidBeginPos);
            name = CNTKInternalNodeName.substr(nameBeginPos);
        }

        return{ uid, name };
    }

    inline std::pair<std::wstring, std::wstring> UidAndNameFromCNTKInternalNodeName(const std::wstring& CNTKInternalNodeName, VariableKind varKind)
    {
        std::wstring uid, name;
        std::tie(uid, name) = UidAndNameFromCNTKInternalNodeName(CNTKInternalNodeName);
        if (uid.empty())
        {
            name = CNTKInternalNodeName;
            uid = Internal::GenerateUid(varKind);
        }

        return{ uid, name };
    }
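
    // For example, L"__v2libuid__Plus3__v2libname__z" parses back into uid L"Plus3" and name L"z";
    // an unmangled node name yields a freshly generated uid, with the full node name used as the name.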

    std::pair<std::wstring, std::wstring> UidAndNameFromCNTKInternalNodeName(const std::wstring& CNTKInternalNodeName, const PrimitiveOpType& opType);

    inline std::vector<Axis> GetDerivedDynamicAxes(const Axis& sourceAxis, size_t multiplicativeFactor, int additiveFactor)
    {
        if (!sourceAxis.IsDynamicAxis())
            LogicError("New dynamic axes can only be derived from existing dynamic axes!");

        if ((multiplicativeFactor == 0) && (additiveFactor == 0))
            LogicError("Zero-size dynamic axes are not allowed!");

        // Slicing exactly one frame off of the source axis effectively deletes that axis
        if ((multiplicativeFactor == 0) && (additiveFactor == 1))
            return {};

        if ((multiplicativeFactor == 1) && (additiveFactor == 0))
            return {sourceAxis};

        std::wstring derivedDynamicAxisName = sourceAxis.Name();
        if (multiplicativeFactor > 0)
        {
            derivedDynamicAxisName += L"_times_" + std::to_wstring(multiplicativeFactor);
            if (additiveFactor > 0)
                derivedDynamicAxisName += L"_plus_" + std::to_wstring(additiveFactor);
            else
                derivedDynamicAxisName += L"_minus_" + std::to_wstring(-additiveFactor);
        }
        else
        {
            assert(additiveFactor > 0);
            derivedDynamicAxisName += L"_fixedSliceOf_" + std::to_wstring(additiveFactor);
        }

        return{ Axis(derivedDynamicAxisName, sourceAxis.IsOrdered()) };
    }
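
    // For example, deriving from an axis named L"x" with multiplicativeFactor 2 and additiveFactor 1
    // yields a new axis named L"x_times_2_plus_1"; factors (1, 0) return the source axis itself and
    // factors (0, 1) drop the axis entirely.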

    inline Axis& NormalizeStaticAxis(Axis& axis, size_t rank)
    {
        if (axis == Axis::EndStaticAxis())
            axis = Axis((int)rank);
        else if (axis.StaticAxisIndex() < 0)
        {
            auto normalizedAxis = Axis((int)rank + axis.StaticAxisIndex());
            if (normalizedAxis.StaticAxisIndex() < 0)
                InvalidArgument("Axis '%S' is out of bounds for the rank '%zd' it applies to.", axis.AsString().c_str(), rank);
            else
                axis = normalizedAxis;
        }
        return axis;
    }
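
    // For example, for rank 4, the negative index in Axis(-1) normalizes to Axis(3), the last
    // static axis.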

    inline Axis& NormalizeStaticAxis(Axis& axis, const NDShape& operandShape)
    {
        if (axis != Axis::AllStaticAxes() && axis != Axis::AllAxes())
        {
            assert(axis.IsStaticAxis());
            assert(!operandShape.IsUnknown());
            axis = NormalizeStaticAxis(axis, operandShape.Rank());
        }
        return axis;
    }

    inline Axis& NormalizeAxis(Axis& axis, const Variable& operand)
    {
        if (axis.IsDynamicAxis())
        {
            auto operandDynamicAxes = operand.DynamicAxes();
            if (axis == Axis::OperandSequenceAxis() && (operandDynamicAxes != Axis::UnknownDynamicAxes()))
            {
                auto numOrderedDynamicAxes = std::count_if(operandDynamicAxes.begin(), operandDynamicAxes.end(), [](const Axis& axis) { return axis.IsOrdered(); });
                if (numOrderedDynamicAxes != 1)
                    InvalidArgument("The Axis::OperandSequenceAxis() sentinel cannot be resolved unless the operand '%S' has exactly one ordered dynamic (sequence) axis.", operand.AsString().c_str());

                axis = *std::find_if(operandDynamicAxes.begin(), operandDynamicAxes.end(), [](const Axis& axis) { return axis.IsOrdered(); });
            }

            return axis;
        }
        else
            return NormalizeStaticAxis(axis, operand.Shape());
    }

    inline void VerifyStaticAxis(const Axis& axis, const NDShape& operandShape)
    {
        assert(axis.IsStaticAxis());
        assert(axis.StaticAxisIndex() >= 0);

        if (axis.StaticAxisIndex() >= (int)operandShape.Rank())
            InvalidArgument("The specified axis index (%d) exceeds the #static axes (%d) of the corresponding operand (shape='%S')",
                            (int)axis.StaticAxisIndex(), (int)operandShape.Rank(), operandShape.AsString().c_str());
    }

    bool IsFirstOutputOfMultiOutputFunction(const Variable& var);
    inline bool IsConstantScalar(const Variable& var)
    {
        return var.IsConstant() && (var.Shape().TotalSize() == 1);
    }

    inline Variable PlaceholderLike(const Variable& var)
    {
        return PlaceholderVariable(var.Shape(), var.GetDataType(), var.Name(), var.DynamicAxes());
    }

    std::vector<Axis> DynamicAxesFromInternalDynamicAxisName(const std::wstring& internalDynamicAxisName);

    // Construct the dynamic axis name to be used internally for the CNTK InputNodes
    std::wstring InternalDynamicAxisNameFromDynamicAxes(const std::vector<Axis>& dynamicAxes);

    std::shared_ptr<std::fstream> GetFstream(const std::wstring& filePath, bool readOnly);
    int GetFileDescriptor(const std::wstring& filePath, bool readOnly);

    std::string ToString(const std::wstring& wstring);
    std::wstring ToWString(const std::string& string);


    std::pair<size_t, size_t> GetNumTimeStepsAndSequences(const NDShape& maskShape, size_t numDynamicAxes);

    inline size_t ShapeRowColSplitPoint(const NDShape& varShape, bool isSparse, bool noDynamicAxes)
    {
        if (isSparse || noDynamicAxes)
            return std::min<size_t>(varShape.Rank(), 1);
        else
            return varShape.Rank();
    }
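
    // For example, a dense variable of shape { 3, 4 } with dynamic axes splits after rank 2 (the
    // full sample shape maps to matrix rows), whereas a sparse variable, or one without dynamic
    // axes, splits after the first axis at most.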

    inline size_t VariableRowColSplitPoint(const Variable& var)
    {
        return ShapeRowColSplitPoint(var.Shape(), var.IsSparse(), var.DynamicAxes().empty());
    }

    bool IsPackedValue(const ValuePtr& value);

    inline NDShape GetVariableShape(const NDShape& varShape, const Microsoft::MSR::CNTK::TensorShape& computationNodeShape)
    {
        auto fullyDefinedVarShape = varShape;
        if (computationNodeShape.GetRank() < fullyDefinedVarShape.Rank())
            LogicError("Computation node tensor shape '%s' must not be of lower rank than variable shape '%S'.", ((std::string)computationNodeShape).c_str(), fullyDefinedVarShape.AsString().c_str());

        for (size_t i = 0; i < fullyDefinedVarShape.Rank(); ++i)
        {
            if ((fullyDefinedVarShape[i] == NDShape::FreeDimension) || (fullyDefinedVarShape[i] == NDShape::InferredDimension))
                fullyDefinedVarShape[i] = computationNodeShape.GetDim(i);
            else if (fullyDefinedVarShape[i] != computationNodeShape.GetDim(i))
                LogicError("Computation node tensor shape '%s' does not match variable shape '%S'.", ((std::string)computationNodeShape).c_str(), fullyDefinedVarShape.AsString().c_str());
        }

        for (size_t i = fullyDefinedVarShape.Rank(); i < computationNodeShape.GetRank(); ++i)
        {
            if (computationNodeShape.GetDim(i) != 1)
                LogicError("Computation node tensor shape '%s' does not match variable shape '%S'.", ((std::string)computationNodeShape).c_str(), fullyDefinedVarShape.AsString().c_str());
        }

        return fullyDefinedVarShape;
    }
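
    // For example, GetVariableShape resolves a variable shape { FreeDimension, 3 } against a
    // computation node TensorShape [5 x 3] to the fully defined shape { 5, 3 }; any trailing node
    // axes beyond the variable's rank must be of dimension 1.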

    std::vector<Axis> GetSqueezableAxes(const NDShape& inputShape);

    NDShape GetSqueezedShape(const NDShape& inputShape, const std::vector<Axis>& axes);

    NDShape GetSqueezedShape(const NDShape& inputShape);

    NDShape GetSqueezedShape(const NDShape& inputShape, const Dictionary& squeezeConfig);

    NDMaskPtr CreateMask(const std::vector<size_t>& sequenceLengths, const std::vector<bool>& sequenceStartFlags = {}, const DeviceDescriptor& device = DeviceDescriptor::CPUDevice());

    double ReductionIdentityValue(const std::wstring& reductionOpName);

    // Helper class to manage a collection of learners.
    class Learners
    {
    public:
        explicit Learners(const std::vector<LearnerPtr>& learners);

        bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, size_t trainingSampleCount, bool sweepEnd);
        bool Update(std::unordered_map<Parameter, NDArrayViewPtr>& gradientValues, MinibatchInfo& minibatchInfo);

        std::vector<DictionaryValue> CreateCheckpoint();

        void RestoreFromCheckpoint(const std::vector<DictionaryValue>&);

        const std::vector<LearnerPtr>& ParameterLearners() const
        {
            return m_learners;
        }

        const LearnerPtr& GetMetricAggregatingLearner() const;

        std::unordered_set<Parameter> GetParameters() const
        {
            std::unordered_set<Parameter> result;
            for (auto l : m_learners)
            {
                const auto& p = l->Parameters();
                result.insert(p.begin(), p.end());
            }
            return result;
        }

        bool IsDistributed() const
        {
            return m_isDistributed;
        }

        std::function<void(NDArrayViewPtr&, NDArrayViewPtr&)> DoAggregateMetricsIfNeededLambda;
        
    private:
        void GetLearnerGradients(LearnerPtr learner, const std::unordered_map<Parameter, NDArrayViewPtr>& allGradients, std::unordered_map<Parameter, NDArrayViewPtr>& learnerGradients);
        void CheckDistributedLearners();

        std::vector<LearnerPtr> m_learners;
        bool m_isDistributed;
        LearnerPtr m_metricAggregatingLearner;
    };

    class Utils
    {
    public:
        static Axis NewDynamicAxisDerivedFromOperand(const std::wstring& axisNamePrefix, const Variable& operand)
        {
            std::function<Variable(const Variable&)> GetActualSourceVariable;
            GetActualSourceVariable = [&GetActualSourceVariable](const Variable& var) -> Variable {
                if (var.BlockFunctionVariableMapping() == Variable())
                    return var;
                else
                    return GetActualSourceVariable(var.BlockFunctionVariableMapping());
            };

            auto whereNodeConditionSourceVar = GetActualSourceVariable(operand);
            return Axis(axisNamePrefix + whereNodeConditionSourceVar.Uid());
        }
        static void VerifyVariableValueCompatibility(const Variable& var, const ValuePtr& value, NDShape* inferredVarShape = nullptr);

        template <typename ElementType>
        static std::pair<std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>>, Microsoft::MSR::CNTK::MBLayoutPtr>
        GetCNTKImplMatrixAndMBLayoutFromValueObject(const Variable& var, const ValuePtr& value, NDShape* inferredVarShape,
                                                    const std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>>& outputMatrixStorage,
                                                    const std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>>& tempIndicesStorage);

        template <typename ElementType>
        static std::pair<std::shared_ptr<const Microsoft::MSR::CNTK::Matrix<ElementType>>, Microsoft::MSR::CNTK::MBLayoutPtr>
        GetCNTKImplMatrixAndMBLayoutFromValueObject(const Variable& var, const ValuePtr& value, NDShape* inferredVarShape = nullptr)
        {
            auto nullSharedPtr = std::shared_ptr<Microsoft::MSR::CNTK::Matrix<ElementType>>(nullptr);
            return GetCNTKImplMatrixAndMBLayoutFromValueObject(var, value, inferredVarShape, nullSharedPtr, nullSharedPtr);
        }

        template <typename ElementType>
        static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(const NDShape& sampleShape, const std::vector<Axis>& sampleDynamicAxes, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true);

        template <typename ElementType>
        static ValuePtr GetValueObjectFromCNTKImplMatrixAndMBLayout(const Variable& var, const Microsoft::MSR::CNTK::ComputationNodeBasePtr& computationNode, const Microsoft::MSR::CNTK::Matrix<ElementType>& matrix, const Microsoft::MSR::CNTK::MBLayoutPtr& layout, bool readOnly = true);
    };

    template <typename Container>
    inline std::wstring NamedListString(const Container& namedList)
    {
        std::wstringstream wss;
        bool first = true;
        for (auto namedObject : namedList)
        {
            if (!first)
                wss << L", ";

            wss << namedObject.AsString();
            first = false;
        }

        return wss.str();
    }

    class Accumulator : public Value
    {
    public:
        Accumulator() : Value(nullptr), m_isInitialized(false), m_numUpdates(0) {}

        void Update(const ValuePtr& delta, const DeviceDescriptor& device);
        void Reset();
        bool IsInitialized() const { return m_isInitialized; }
    private:
        void ResetToZero();

        bool m_isInitialized;
        size_t m_numUpdates;
    };

    std::wstring DynamicAxesAsString(const std::vector<Axis>& da, bool rowMajor = false);

    template <typename T> // T can be Variable or StreamInfo
    static bool IsAtSweepEnd(const std::unordered_map<T, MinibatchData>& arguments)
    {
        if (arguments.empty()) return true;

        return std::any_of(arguments.begin(), arguments.end(), [](const std::pair<const T, MinibatchData>& kv)
        {
            return kv.second.sweepEnd;
        });
    }

    // half is V1 ElemType, so specialize here instead of in CNTKLibrary.h
    template<>
    inline DataType AsDataType<half>()
    {
        return DataType::Float16;
    }
}