https://github.com/Microsoft/CNTK
Raw File
Tip revision: 36f457ca562c618b1ebe0e4f1eed0f96b9c775bc authored by Mark Hillebrand on 01 February 2017, 17:11:45 UTC
Fix link
Tip revision: 36f457c
MatrixPool.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include <string>
#include <stdexcept>
#include <vector>
#include <algorithm>
#include <stdlib.h>

#include "Basics.h"
#include "Matrix.h"
#include "ComputationNode.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// MatrixPool -- class to support memory sharing
// Despite the gather general name of this class, it is specifically designed to support the memory sharing of ComputationNodes.
// Note: see #define SUPRESS_MEMSHARING below as for how to temporarily disable memory sharing altogether, for debugging
class MatrixPool
{
    vector<shared_ptr<Matrix<float>>>  m_releasedFloatMatrices;
    vector<shared_ptr<Matrix<double>>> m_releasedDoubleMatrices;

    template <class ElemType>
    vector<shared_ptr<Matrix<ElemType>>>& GetReleasedMatrices();

public:
    // release here means the matrix can be put back and shared by others
    template <class ElemType>
    void Release(shared_ptr<Matrix<ElemType>> freeMatrix)
    {
        if (freeMatrix == nullptr || freeMatrix->GetMatrixType() == SPARSE)
            LogicError("MatrixPool::Release: freeMatrix should not be null or sparse.");
//#define SUPRESS_MEMSHARING // #define this to disable memory sharing through this structure
        // TODO: Make this a runtime option.
#ifndef SUPRESS_MEMSHARING
        vector<shared_ptr<Matrix<ElemType>>>& releasedMatrices = GetReleasedMatrices<ElemType>();
#ifdef _DEBUG
        for (int i = 0; i < releasedMatrices.size(); i++)
        {
            if (releasedMatrices[i] == freeMatrix)
                RuntimeError("MatrixPool::Release: freeMatrix is already in the released pool.");
        }

#endif
        releasedMatrices.push_back(freeMatrix);
#endif
    }

    template <class ElemType>
    shared_ptr<Matrix<ElemType>> Request(DEVICEID_TYPE deviceId)
    {
        vector<shared_ptr<Matrix<ElemType>>>& releasedMatrices = GetReleasedMatrices<ElemType>();
        shared_ptr<Matrix<ElemType>> matrixPtr;
        if (releasedMatrices.empty())
        {
            matrixPtr = make_shared<Matrix<ElemType>>(deviceId);
        }
        else
        {
            matrixPtr = releasedMatrices.back();
            releasedMatrices.pop_back();
        }

        if (!matrixPtr) // this can't really happen
            LogicError("MatrixPool::Request: failed to get a valid matrix.");

        return matrixPtr;
    }
};

}}}
back to top