https://github.com/Microsoft/CNTK
Raw File
Tip revision: c256b2de4d3ba6da64f1bf0ef0e48234c311ef8b authored by Jacob DeWitt on 31 October 2017, 23:09:23 UTC
Placeholder header file for plain C API that will come in a later commit.
Tip revision: c256b2d
CuDnnFactories.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "ConvolutionEngine.h"
#include "BatchNormalizationEngine.h"

namespace Microsoft { namespace MSR { namespace CNTK {

template <class ElemType>
class CuDnnConvolutionEngineFactory
{
public:
    static std::unique_ptr<ConvolutionEngine<ElemType>> Create(ConvolveGeometryPtr geometry, DEVICEID_TYPE deviceId,
                                                               ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples,
                                                               PoolKind poolKind, bool forceDeterministicAlgorithms, 
                                                               bool poolIncludePad, bool inputHasFreeDimension);
    static bool IsSupported(DEVICEID_TYPE deviceId, ConvolveGeometryPtr geometry, PoolKind poolKind);
};

template <class ElemType>
class CuDnnBatchNormEngineFactory
{
public:
    static std::unique_ptr<BatchNormEngine<ElemType>> Create(DEVICEID_TYPE deviceId, const TensorShape& inOutT,
                                                             bool spatial, ImageLayoutKind imageLayout);
};

// REVIEW alexeyk: wrong place? It is currently used only in unit tests but I can't add it there because of the build issues.
// Timer that can be used to measure CUDA calls. 
// Uses CUDA event and will synchronize(!) the stream when Stop is called.
class MATH_API CudaTimer
{
public:
    CudaTimer(): m_start(nullptr), m_stop(nullptr)
    {
    }
    ~CudaTimer();
    void Start();
    void Stop();
    float Elapsed();

    DISABLE_COPY_AND_MOVE(CudaTimer);
private:
    void* m_start;
    void* m_stop;
};

} } }
back to top