Raw File
CudaMemoryProvider.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include <limits>
#include <memory>
#include <new>
#include <CUDAPageLockedMemAllocator.h>

#include "MemoryProvider.h"

namespace CNTK {

/// Memory provider that obtains and releases buffers through a
/// Microsoft::MSR::CNTK::CUDAPageLockedMemAllocator created for a single device id.
///
/// TODO: Memory provider should reside on the matrix. It is the responsibility of the network
/// to decide what memory to use per stream. This class will be moved in the near future.
class CudaMemoryProvider : public MemoryProvider
{
    // Owns the underlying allocator; it is destroyed together with this provider.
    // NOTE(review): buffers handed out by Alloc presumably must be released through Free
    // before the provider (and thus the allocator) is destroyed — confirm with callers.
    std::unique_ptr<Microsoft::MSR::CNTK::CUDAPageLockedMemAllocator> m_allocator;

public:
    /// Creates a provider whose allocator is constructed for the given device.
    /// @param deviceId  Device id forwarded unchanged to the CUDAPageLockedMemAllocator constructor.
    explicit CudaMemoryProvider(int deviceId)
        : m_allocator(std::make_unique<Microsoft::MSR::CNTK::CUDAPageLockedMemAllocator>(deviceId))
    {
    }

    /// Allocates storage for numberOfElements items of elementSize bytes each.
    /// @param elementSize       Size of one element, in bytes.
    /// @param numberOfElements  Number of elements to allocate room for.
    /// @return The pointer returned by the allocator's Malloc for elementSize * numberOfElements bytes.
    /// @throws std::bad_array_new_length if elementSize * numberOfElements does not fit in size_t.
    ///         Any exception thrown by the allocator's Malloc propagates unchanged.
    virtual void* Alloc(size_t elementSize, size_t numberOfElements) override
    {
        // Guard against size_t wrap-around: an overflowed product would silently request a
        // buffer far smaller than the caller expects, leading to out-of-bounds writes later.
        if (numberOfElements != 0 && elementSize > std::numeric_limits<size_t>::max() / numberOfElements)
        {
            throw std::bad_array_new_length();
        }

        const size_t totalSize = elementSize * numberOfElements;
        return m_allocator->Malloc(totalSize);
    }

    /// Returns a buffer previously obtained from Alloc to the allocator.
    /// @param p  Pointer returned by Alloc; a null pointer is ignored.
    virtual void Free(void* p) override
    {
        if (p == nullptr)
        {
            return;
        }

        // void* -> char* is a well-defined static_cast; reinterpret_cast is not needed.
        m_allocator->Free(static_cast<char*>(p));
    }
};
}
back to top