https://github.com/Microsoft/CNTK
Raw File
Tip revision: c48e1a4ab7b4ed309492623fd936467a9662b890 authored by Frank Seide on 29 January 2016, 20:27:10 UTC
merged with pkranen/doc
Tip revision: c48e1a4
CuDnnConvolutionEngine.h
//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//

#pragma once

#include "ConvolutionEngine.h"

namespace Microsoft { namespace MSR { namespace CNTK {

// ----------------------------------------------------------------------------
// CuDnnConvolutionEngineFactory
//
// Concrete ConvolutionEngineFactory<ElemType> whose overrides are declared
// here. Their definitions are not part of this header. Judging by the
// class/file name this is the cuDNN-based backend -- NOTE(review): confirm
// against the implementation file.
// ----------------------------------------------------------------------------
template <class ElemType>
class CuDnnConvolutionEngineFactory : public ConvolutionEngineFactory<ElemType>
{
public:
    using Base = ConvolutionEngineFactory<ElemType>;

    // Member types re-exported from the dependent base class. Without these,
    // the names would need explicit `typename Base::` qualification everywhere
    // below, because names from a dependent base are not found by unqualified
    // lookup.
    using Tensor4D      = typename Base::Tensor4D;
    using Tensor4DPtr   = typename Base::Tensor4DPtr;
    using Filter        = typename Base::Filter;
    using FilterPtr     = typename Base::FilterPtr;
    using ConvDesc      = typename Base::ConvDesc;
    using ConvDescPtr   = typename Base::ConvDescPtr;
    using PoolDesc      = typename Base::PoolDesc;
    using PoolDescPtr   = typename Base::PoolDescPtr;
    using ConvEnginePtr = typename Base::ConvEnginePtr;
    using PoolEnginePtr = typename Base::PoolEnginePtr;

    // --- Descriptor construction (overrides of the base-class interface) ---

    // Returns a Tensor4D handle for the given dimensions.
    // Presumably w/h/c/n = width, height, channels, sample count -- TODO confirm.
    Tensor4DPtr CreateTensor(size_t w, size_t h, size_t c, size_t n) override;

    // Returns a Filter handle for the given dimensions.
    // Presumably w/h/c/k = width, height, input channels, kernel count -- TODO confirm.
    FilterPtr CreateFilter(size_t w, size_t h, size_t c, size_t k) override;

    // Returns a convolution descriptor derived from an input tensor and a
    // filter, with the given strides and padding flag.
    ConvDescPtr CreateConvDescriptor(const Tensor4D& inT, const Filter& filterT, size_t wStride, size_t hStride, bool padding) override;

    // Returns a pooling descriptor of the given PoolDesc::PoolKind with the
    // given window size, strides, and padding.
    PoolDescPtr CreatePoolDescriptor(typename PoolDesc::PoolKind kind,
                                     size_t w, size_t h,
                                     size_t wStride, size_t hStride,
                                     size_t wPad, size_t hPad) override;

    // --- Engine construction (overrides of the base-class interface) ---

    // Returns a convolution engine bound to deviceId. The meaning of
    // maxTempMemSizeInSamples (a workspace-size cap, presumably) is defined by
    // the implementation -- NOTE(review): confirm.
    ConvEnginePtr CreateConvEngine(DEVICEID_TYPE deviceId, size_t maxTempMemSizeInSamples) override;

    // Returns a pooling engine bound to deviceId.
    PoolEnginePtr CreatePoolEngine(DEVICEID_TYPE deviceId) override;

    // --- Capability query ---

    // Static, and not an override of the base interface. Reports whether this
    // factory can be used for deviceId; the exact criteria live in the
    // implementation file.
    static bool IsSupported(DEVICEID_TYPE deviceId);
};
} } }
back to top