//
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
//
#include "stdafx.h"
#include "Basics.h"
#include "BestGpu.h"

#ifndef CPUONLY

#include "GPUMatrix.h"
#include "GPUMatrixCUDAKernels.cuh"
//#include "GPUSparseMatrix.h"
#include "GPUTensor.h"
#include "CommonMatrix.h"
#define TENSOR_OPS_DECL __device__ __host__
#include "TensorOps.h"
#include "device_launch_parameters.h"
// NOTE: the names of the angle-bracket system includes were lost in extraction; the headers
// below are an assumed reconstruction covering what this translation unit's code requires.
#include <cuda.h>
#include <cuda_runtime.h>
#include <curand.h>
#include <curand_kernel.h>
#include <assert.h>
#include "cublas_v2.h"
#include <memory>
#include <vector>
#include "CntkBatchNormalization.cuh"
#include "Convolution.cuh"
#include "CuDnnRNN.h"

#pragma comment(lib, "cudart.lib") // instruct linker to reference these libs
#pragma comment(lib, "cublas.lib")
#pragma comment(lib, "cusparse.lib")
#pragma comment(lib, "curand.lib")

#pragma warning(disable : 4267) // conversion from 'size_t' to 'unsigned int'; happens in CUDA <<<a, b>>> syntax if a and b are size_t
#pragma warning(disable : 4127) // conditional expression is constant; "if (sizeof(ElemType)==sizeof(float))" triggers this
#pragma warning(disable : 4702) // unreachable code; triggered for unknown reasons

#define DEFAULT_THREAD_PER_DIM 16

#define UNCONST(t, c, uc) GPUMatrix<t>& uc = const_cast<GPUMatrix<t>&>(c);

#ifdef _WIN32
// thread local storage to access the current stream, initialize to default stream
__declspec(thread)
#endif
cudaStream_t t_stream = cudaStreamDefault;

extern int _ConvertSMVer2Cores(int major, int minor); // forward declaration

// SetStream - set the stream that will be used by the GPU routines
void MATH_API SetStream(cudaStream_t stream)
{
    t_stream = stream;
}

// GetStream - get the stream that will be used by the GPU routines
cudaStream_t MATH_API GetStream()
{
    return t_stream;
}

// Helper macro patterns for elementwise methods
#define DEF_ELEMWISE_INPLACE_FUNC(f)                                    \
    template <class ElemType>                                           \
    GPUMatrix<ElemType>& GPUMatrix<ElemType>::Inplace##f()              \
    {                                                                   \
        performElementWiseFunction(ElementWiseOperator::op##f, Data()); \
        return *this;                                                   \
    }
#define DEF_ELEMWISE_ASSIGN_FUNC(f)                                                       \
    template <class ElemType>                                                             \
    GPUMatrix<ElemType>& GPUMatrix<ElemType>::Assign##f##Of(const GPUMatrix<ElemType>& a) \
    {                                                                                     \
        if (a.IsEmpty())                                                                  \
            LogicError("Assign##f##Of: Matrix a is empty.");                              \
        if (this != &a)                                                                   \
            RequireSize(a.GetNumRows(), a.GetNumCols());                                  \
        performElementWiseFunction(ElementWiseOperator::op##f, a.Data());                 \
        return *this;                                                                     \
    }

template <>
const char* CudaErrString<cudaError_t>(cudaError_t x)
{
    cudaDeviceSynchronize();
    return cudaGetErrorString(x);
}
template <>
const char* CudaErrString<cublasStatus_t>(cublasStatus_t e)
{
    cudaDeviceSynchronize();
    switch (e)
    {
    case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
    case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
    case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
    case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
    case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
    case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
    case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
    case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
    case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
    case CUBLAS_STATUS_LICENSE_ERROR:    return "CUBLAS_STATUS_LICENSE_ERROR";
    default:                             return "(look for CUBLAS_STATUS_xxx in cublas_api.h)";
    }
}
template <>
const char* CudaErrString<curandStatus>(curandStatus)
{
    cudaDeviceSynchronize();
    return "(see curand.h & look for curandStatus or CURAND_STATUS_xxx)";
}

namespace Microsoft { namespace MSR { namespace CNTK {
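// For illustration only: a sketch of the code DEF_ELEMWISE_INPLACE_FUNC generates when
// instantiated as DEF_ELEMWISE_INPLACE_FUNC(Sigmoid) (mechanical expansion of the macro
// above, shown in a comment so it is not compiled a second time):
//
//   template <class ElemType>
//   GPUMatrix<ElemType>& GPUMatrix<ElemType>::InplaceSigmoid()
//   {
//       performElementWiseFunction(ElementWiseOperator::opSigmoid, Data());
//       return *this;
//   }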
/*static*/ std::vector<cudaDeviceProp> GridDim::s_cachedDeviceProps;
/*static*/ std::once_flag GridDim::s_cachedDevicePropsInitFlag;

/*static*/ bool SyncGuard::s_isSyncEnabled = false;

/*static*/ void SyncGuard::EnableSync()
{
    s_isSyncEnabled = true;
}

/*static*/ bool SyncGuard::IsSyncEnabled()
{
    return s_isSyncEnabled;
}

SyncGuard::SyncGuard(bool forceSync /*= false*/)
    : m_forceSync(forceSync)
{
    m_done = nullptr;
    if (m_forceSync || s_isSyncEnabled)
    {
        CUDA_CALL(cudaGetLastError());
        CUDA_CALL(cudaEventCreate(&m_done));
    }
}

SyncGuard::~SyncGuard()
{
    if (m_forceSync || s_isSyncEnabled)
    {
        // The regular use of this destructor is to synchronize the GPU, but also
        // to check for errors. So this destructor is where CUDA errors would be thrown.
        // If this destructor runs during stack unwinding, then a different error has
        // already happened that should be reported; so we only clean up the resource.
        if (std::uncaught_exception())
            cudaEventDestroy(m_done);
        else
        {
            // failures in a prior launch might be reported here
            CUDA_CALL(cudaEventRecord(m_done));
            CUDA_CALL(cudaEventSynchronize(m_done));
            CUDA_CALL(cudaEventDestroy(m_done));
        }
    }
}

template <typename AllocatedElemType>
AllocatedElemType* TracingGPUMemoryAllocator::Allocate(int deviceId, size_t numRows, size_t numCols)
{
    if (IsTraceEnabled())
    {
        auto freeAndTotalMemory = GetFreeAndTotalMemoryInMBs(deviceId);
        fprintf(stderr, "Allocating Matrix<%s> (Rows = %d, Cols = %d) buffer on DeviceId = %d; GPU Memory Free = %d MB of %d MB\n",
                typeid(AllocatedElemType).name(), (int)numRows, (int)numCols, (int)deviceId, (int)freeAndTotalMemory.first, (int)freeAndTotalMemory.second);
        Microsoft::MSR::CNTK::DebugUtil::PrintCallStack();
    }

    AllocatedElemType* deviceBufferPtr = AllocateNoTrace<AllocatedElemType>(deviceId, numRows * numCols);

    if (IsTraceEnabled())
    {
        fprintf(stderr, "Allocated DeviceData = %p\n", (void*) deviceBufferPtr);
    }

    return deviceBufferPtr;
}

template <typename AllocatedElemType>
AllocatedElemType* TracingGPUMemoryAllocator::Allocate(int deviceId, size_t numElements)
{
    if (IsTraceEnabled())
    {
        auto freeAndTotalMemory = GetFreeAndTotalMemoryInMBs(deviceId);
        fprintf(stderr, "Allocating array<%s> (NumElements = %d) on DeviceId = %d; GPU Memory Free = %d MB of %d MB\n",
                typeid(AllocatedElemType).name(), (int)numElements, (int)deviceId, (int)freeAndTotalMemory.first, (int)freeAndTotalMemory.second);
        Microsoft::MSR::CNTK::DebugUtil::PrintCallStack();
    }

    AllocatedElemType* deviceBufferPtr = AllocateNoTrace<AllocatedElemType>(deviceId, numElements);

    if (IsTraceEnabled())
    {
        fprintf(stderr, "Allocated DeviceData = %p\n", (void*)deviceBufferPtr);
    }

    return deviceBufferPtr;
}

template <typename AllocatedElemType>
void TracingGPUMemoryAllocator::Free(int deviceId, AllocatedElemType* bufferPtr, bool ignoreCUDARetCode /*= false*/)
{
    PrepareDevice(deviceId);
    if (ignoreCUDARetCode)
        cudaFree((void*) bufferPtr);
    else
        CUDA_CALL(cudaFree((void*) bufferPtr));

    if (IsTraceEnabled())
    {
        auto freeAndTotalMemory = GetFreeAndTotalMemoryInMBs(deviceId);
        fprintf(stderr, "Freed buffer<%s> DeviceData = %p on DeviceId = %d; GPU Memory Free = %d MB of %d MB\n",
                typeid(AllocatedElemType).name(), (void*) bufferPtr, (int) deviceId, (int) freeAndTotalMemory.first, (int) freeAndTotalMemory.second);
        Microsoft::MSR::CNTK::DebugUtil::PrintCallStack();
    }
}

template <typename AllocatedElemType>
AllocatedElemType* TracingGPUMemoryAllocator::AllocateNoTrace(int deviceId, size_t numElements)
{
    AllocatedElemType* deviceBufferPtr;

    PrepareDevice(deviceId);
    // In case numElements is odd we allocate a buffer with one more element. The reason is
    // we might call curandGenerateNormal (e.g. for Gaussian noise injection) which would fail
    // if the number of elements it needs to generate is odd.
    CUDA_CALL(cudaMalloc((void**) &deviceBufferPtr, sizeof(AllocatedElemType) * AsMultipleOf(numElements, 2)));

    return deviceBufferPtr;
}

std::pair<size_t, size_t> TracingGPUMemoryAllocator::GetFreeAndTotalMemoryInMBs(int deviceId)
{
    PrepareDevice(deviceId);

    size_t free, total;
    CUDA_CALL(cudaMemGetInfo(&free, &total));

    size_t numBytesPerMB = 1 << 20;
    return {free / numBytesPerMB, total / numBytesPerMB};
}

// PrepareDevice - Setup the correct cuda context for an operation
// deviceId - the device on which the operation will take place
void PrepareDevice(DEVICEID_TYPE deviceId)
{
    THREAD_LOCAL static DEVICEID_TYPE currentDevice = DEVICEID_NOTYETDETERMINED;
    // and if we last set the device to be this device we are good
    if (deviceId == currentDevice)
        return;
    CUDA_CALL(cudaSetDevice(deviceId));
    currentDevice = deviceId;
}

#pragma region DeviceBoundNumber class

template <class ElemType>
DeviceBoundNumber<ElemType>::DeviceBoundNumber(const DeviceBoundNumber<ElemType>& /*deepCopy*/)
{
    NOT_IMPLEMENTED;
}

template <class ElemType>
DeviceBoundNumber<ElemType>::DeviceBoundNumber(DeviceBoundNumber<ElemType>&& shallowCopy)
{
    ShallowCopyFrom(shallowCopy.m_data, shallowCopy.m_computeDevice);
    shallowCopy.m_data = NULL;
}

template <class ElemType>
void DeviceBoundNumber<ElemType>::ShallowCopyFrom(ElemType* newVal, int newValsDevceId)
{
    m_computeDevice = newValsDevceId;
    m_data = newVal;
}

template <class ElemType>
DeviceBoundNumber<ElemType>::~DeviceBoundNumber()
{
    if (m_data != NULL)
    {
        if (m_computeDevice < 0)
        {
            delete m_data;
            m_data = NULL;
        }
        else
        {
            TracingGPUMemoryAllocator::Free(m_computeDevice, m_data);
        }
    }
}

#pragma endregion DeviceBoundNumber class

#pragma region Helper functions
template <class ElemType>
cublasHandle_t _initCUBLAS(int devId)
{
    PrepareDevice((DEVICEID_TYPE) devId);
    cublasHandle_t cuHandle;
    CUBLAS_CALL(cublasCreate(&cuHandle));
    return cuHandle;
}

template <class ElemType>
void GPUMatrix<ElemType>::SetDevice(DEVICEID_TYPE deviceId)
{
    assert(deviceId >= 0);
    CUDA_CALL(cudaSetDevice(deviceId));
}

// PrepareDevice - Setup the correct cuda context for an operation
// deviceId - the device on which the operation will take place
// defaults to -1, which means use matrices current device
template <class ElemType>
DEVICEID_TYPE GPUMatrix<ElemType>::PrepareDevice(DEVICEID_TYPE deviceId /*=-1*/) const
{
    // if default value use current compute device
    DEVICEID_TYPE newId = deviceId >= 0 ?
deviceId : GetComputeDeviceId(); Microsoft::MSR::CNTK::PrepareDevice(newId); return newId; } template ElemType* GPUMatrix::CopyToArray() const { size_t numElements = GetNumElements(); if (numElements != 0) { PrepareDevice(); ElemType* pArray = new ElemType[numElements]; CUDA_CALL(cudaMemcpy(pArray, Data(), sizeof(ElemType) * m_numRows * m_numCols, cudaMemcpyDeviceToHost)); return pArray; } else { return NULL; } } //memory will be allocated by the callee if not enough but need to be deleted by the caller after it's done //return number of elements copied template size_t GPUMatrix::CopyToArray(ElemType*& arrayCopyTo, size_t& currentArraySize) const { size_t numElements = GetNumElements(); if (numElements > currentArraySize) { delete arrayCopyTo; arrayCopyTo = new ElemType[numElements]; currentArraySize = numElements; } if (numElements != 0) { PrepareDevice(); CUDA_CALL(cudaMemcpy(arrayCopyTo, Data(), sizeof(ElemType) * numElements, cudaMemcpyDeviceToHost)); } return numElements; } template void GPUMatrix::CopySection(size_t numRows, size_t numCols, ElemType* dst, size_t colStride) const { CUBLAS_CALL(cublasGetMatrix((int) numRows, (int) numCols, sizeof(ElemType), Data(), (int) GetNumRows(), dst, (int) colStride)); } template void GPUMatrix::ChangeDeviceTo(DEVICEID_TYPE to_id) { if (to_id == CPUDEVICE) LogicError("to_id must be valid GPU"); if (GetComputeDeviceId() == to_id) return; ElemType* d_dst = TracingGPUMemoryAllocator::Allocate(to_id, m_numRows, m_numCols); SetSizeAllocated(m_numRows * m_numCols); // check to make sure we have something to copy (on init we often have zero sized allocations) if (GetSizeAllocated() > 0) { #if 0 // see the backlog item # 1220 // IOMMU DMAR needs to be disabled for CUDA P2P, otherwise it will silently hang. // Unfortunately, cudaDeviceCanAccessPeer returns true irrespective of the IOMMU settings. // More details: https://bugzilla.kernel.org/show_bug.cgi?id=188271 // http://docs.nvidia.com/cuda/gpudirect-rdma/#supported-systems // TODO: enable UVA p2p access once this is fixed. 
// first try peer access int canAccessPeer = false; CUDA_CALL(cudaDeviceCanAccessPeer(&canAccessPeer, to_id, GetComputeDeviceId())); if (canAccessPeer) { cudaError_t cudaStatus = cudaDeviceEnablePeerAccess(GetComputeDeviceId(), 0); if (cudaStatus != cudaErrorPeerAccessAlreadyEnabled) { CUDA_CALL(cudaStatus); } CUDA_CALL(cudaMemcpyPeer(d_dst, to_id, Data(), GetComputeDeviceId(), sizeof(ElemType) * m_numRows * m_numCols)); } else #endif { // peer access didn't work, just copy normal // make this more efficient by keeping some buffers available for each copy ElemType* h_dst = NULL; PrepareDevice(); CUDA_CALL(cudaMallocHost((void**) &h_dst, sizeof(ElemType) * m_numRows * m_numCols)); CUDA_CALL(cudaMemcpy(h_dst, Data(), sizeof(ElemType) * m_numRows * m_numCols, cudaMemcpyDeviceToHost)); PrepareDevice((DEVICEID_TYPE) to_id); CUDA_CALL(cudaMemcpy(d_dst, h_dst, sizeof(ElemType) * m_numRows * m_numCols, cudaMemcpyHostToDevice)); CUDA_CALL(cudaFreeHost(h_dst)); } } TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), Buffer()); SetBuffer(d_dst, m_numRows * m_numCols * sizeof(ElemType)); PrepareDevice((DEVICEID_TYPE) to_id); SetComputeDeviceId(to_id); } template void GPUMatrix::performElementWiseFunction(ElementWiseOperator kind, const ElemType* src) { PrepareDevice(); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; switch (kind) { case ElementWiseOperator::opSigmoid: return _elementWiseSigmoidOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opTanh: return _elementWiseTanhOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opSqrt: return _elementWiseSqrtOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opExp: return _elementWiseExpOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opLog: return _elementWiseLogOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opAbs: return _elementWiseAbsOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opLinearRectifierDerivative: return _elementWiseLinRectDerivativeOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opCosine: return _elementWiseCosineOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opNegativeSine: return _elementWiseNegativeSineOnCuda<<>>(src, Data(), N); case ElementWiseOperator::opSigmoidDerivative: return _elementWiseSigmoidDerivativeOnCuda<<>>(src, Data(), N); default: LogicError("performElementWiseFunction: unexpected op code %d", (int)kind); } } #pragma endregion Helper functions #pragma region Constructors and Destructor // should only be used by constructors template void GPUMatrix::ZeroInit(int deviceId) { BaseMatrix::ZeroInit(); SetComputeDeviceId(deviceId); } template GPUMatrix::GPUMatrix(int deviceId) { ZeroInit(deviceId); }; template GPUMatrix::GPUMatrix(const size_t numRows, const size_t numCols, int deviceId) { ZeroInit(deviceId); m_numRows = numRows; m_numCols = numCols; SetSizeAllocated(GetNumElements()); if (GetNumElements() != 0) { SetBuffer(TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), m_numRows, m_numCols), GetNumElements() * sizeof(ElemType)); CUDA_CALL(cudaMemset(Buffer(), 0, sizeof(ElemType) * GetSizeAllocated())); } }; template GPUMatrix::GPUMatrix(const size_t numRows, const size_t numCols, int deviceId, ElemType* pArray, const size_t matrixFlags) { ZeroInit(deviceId); SetValue(numRows, numCols, deviceId, pArray, matrixFlags); }; template GPUMatrix::GPUMatrix(const GPUMatrix& deepCopyFrom) { ZeroInit(); SetValue(deepCopyFrom); } template GPUMatrix::GPUMatrix(GPUMatrix&& moveFrom) { 
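    // Move construction: shallow-copy the storage bookkeeping from moveFrom so this matrix
    // refers to the same device buffer, then reset moveFrom's header values so the moved-from
    // object is left empty.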
ShallowCopyFrom(moveFrom); moveFrom.ZeroValues(); } //assignment operator, deep copy template GPUMatrix& GPUMatrix::operator=(const GPUMatrix& deepCopyFrom) { if (this != &deepCopyFrom) { SetValue(deepCopyFrom); } return *this; } //move assignment operator, shallow copy template GPUMatrix& GPUMatrix::operator=(GPUMatrix&& moveFrom) { if (this != &moveFrom) { ShallowCopyFrom(moveFrom); moveFrom.ZeroValues(); } return *this; } template GPUMatrix::~GPUMatrix(void) { } // TODO: This should be in the storage object. // Clear will clear your storage, zeroinit just drops it on the ground. template void GPUMatrix::Clear() { VerifyWritable(__FUNCTION__); //if (OwnBuffer() && m_pArray != NULL) if (m_sob != nullptr) { if (GetComputeDeviceId()>= 0) { // BUG: We do not check the CUDA return code for cudaFree here since this may get called // during processExit when cudaFree will fail. The destruction of CUDA objects during // process exit must be avoided ReleaseStorageMemory(); } } ZeroInit(GetComputeDeviceId()); } #pragma endregion Constructors and Destructor template std::unique_ptr> GPUMatrix::GetOrCreateWorkspace() const { // REVIEW alexeyk: not thread-safe, fine for now. if (m_workspace == nullptr) m_workspace = std::make_unique>>>(); assert(m_workspace != nullptr); auto deviceId = GetComputeDeviceId(); return m_workspace->pop_or_create([deviceId]() { return std::make_unique>(deviceId); }); } template void GPUMatrix::ReleaseWorkspace(std::unique_ptr> src) const { assert(m_workspace != nullptr); m_workspace->push(std::move(src)); } #pragma region Basic Operators template GPUMatrix GPUMatrix::ColumnSlice(size_t startColumn, size_t numCols) const { if (startColumn + numCols > GetNumCols()) InvalidArgument("The slice (%d+%d) is out of range of the source matrix (%d).", (int) startColumn, (int) numCols, (int) GetNumCols()); GPUMatrix slice(GetComputeDeviceId()); slice.ShallowCopyFrom(*this); slice.m_numCols = numCols; slice.m_sliceViewOffset = m_sliceViewOffset + startColumn * GetNumRows(); return slice; } template GPUMatrix& GPUMatrix::AssignColumnSlice(const GPUMatrix& fromMatrix, size_t startColumn, size_t numCols) { if (numCols == 0) LogicError("The slice cannot have 0 columns."); if (startColumn + numCols > fromMatrix.GetNumCols()) InvalidArgument("The slice (%d+%d) is out of range of the source matrix (%d).", (int) startColumn, (int) numCols, (int) fromMatrix.GetNumCols()); Clear(); ShallowCopyFrom(fromMatrix); m_numCols = numCols; m_sliceViewOffset = fromMatrix.m_sliceViewOffset + startColumn * GetNumRows(); return *this; } template GPUMatrix& GPUMatrix::SetColumnSlice(const GPUMatrix& fromMatrix, size_t startColumn, size_t numCols) { if (startColumn + numCols > GetNumCols()) LogicError("The slice is out of range of the destination matrix."); if (numCols > fromMatrix.GetNumCols()) InvalidArgument("The slice (%d) is out of range of the source matrix (%d).", (int) numCols, (int) fromMatrix.GetNumCols()); if (m_numRows != fromMatrix.m_numRows) LogicError("The number of rows in source and destination matrices do not match"); if (m_numRows * numCols > 0) // TODO: remove if unnecessary CUDA_CALL(cudaMemcpy(Data() + LocateColumn(startColumn), fromMatrix.Data(), sizeof(ElemType) * m_numRows * numCols, cudaMemcpyDeviceToDevice)); return *this; } template void GPUMatrix::CopyColumnsStrided(const GPUMatrix& fromMatrix, size_t numCols, size_t srcNumColsStride, size_t destNumColsStride) { if ((((numCols - 1) * srcNumColsStride) + 1) > fromMatrix.m_numCols) LogicError("The numCols to copy and srcNumColsStride 
specified is out of range of the source matrix."); if ((((numCols - 1) * destNumColsStride) + 1) > m_numCols) LogicError("The numCols to copy and srcNumColsStride specified is out of range of the destination matrix."); if (m_numRows != fromMatrix.m_numRows) LogicError("The number of rows in source and destination matrices do not match"); if ((m_numRows * numCols) > 0) { // Launch a kernel to do the strided copy CUDA_LONG N = (CUDA_LONG)(m_numRows * numCols); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _copyColumnsStrided<<>>(Data(), fromMatrix.Data(), N, (CUDA_LONG) m_numRows, (CUDA_LONG) destNumColsStride, (CUDA_LONG) srcNumColsStride); } } //for each column of a, we assign all rows of a to this starting from startIndex template GPUMatrix& GPUMatrix::AssignToRowSliceValuesOf(const GPUMatrix& a, const size_t startIndex, const size_t numRows) { if (a.IsEmpty()) LogicError("AddToRowSliceValuesOf: input matrix a is empty."); if (a.GetNumRows() != numRows) LogicError("AddToRowSliceValuesOf: a.GetNumRows() != numRows."); if (startIndex + numRows > GetNumRows()) LogicError("AddToRowSliceValuesOf: startIndex + numRows exceeds GetNumRows()."); if (a.GetNumCols() != GetNumCols()) LogicError("AddToRowSliceValuesOf: columns does not match."); CUDA_LONG N = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _assignToRowSliceValuesOf<<>>(Data(), a.Data(), N, (CUDA_LONG) startIndex, (CUDA_LONG) GetNumRows(), (CUDA_LONG) a.GetNumRows()); return *this; } //for each column of a, we assign numRows starting from startIndex to this template GPUMatrix& GPUMatrix::AssignRowSliceValuesOf(const GPUMatrix& a, const size_t startIndex, const size_t numRows) { if (a.IsEmpty()) LogicError("AssignRowSliceValuesOf: input matrix a is empty."); if (startIndex + numRows > a.GetNumRows()) LogicError("AssignRowSliceValuesOf: startIndex + numRows exceeds a.GetNumRows()."); RequireSize(numRows, a.GetNumCols()); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _assignRowSliceValuesOf<<>>(Data(), a.Data(), N, (CUDA_LONG) startIndex, (CUDA_LONG) numRows, (CUDA_LONG) a.GetNumRows()); return *this; } //for the row slice of this starting from startIndex we add a to it. 
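// In index notation (a sketch of the intended semantics only, not the kernel implementation):
//   for each column j of a, and for i in [0, a.GetNumRows()):
//       (*this)(startIndex + i, j) += a(i, j)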
template GPUMatrix& GPUMatrix::AddToRowSliceValuesOf(const GPUMatrix& a, const size_t startIndex, const size_t numRows) { if (a.IsEmpty()) LogicError("AddToRowSliceValuesOf: input matrix a is empty."); if (a.GetNumRows() != numRows) LogicError("AddToRowSliceValuesOf: a.GetNumRows() != numRows."); if (startIndex + numRows > GetNumRows()) LogicError("AddToRowSliceValuesOf: startIndex + numRows exceeds GetNumRows()."); if (a.GetNumCols() != GetNumCols()) LogicError("AddToRowSliceValuesOf: columns does not match."); CUDA_LONG N = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _addToRowSliceValuesOf<<>>(Data(), a.Data(), N, (CUDA_LONG) startIndex, (CUDA_LONG) GetNumRows(), (CUDA_LONG) a.GetNumRows()); return *this; } //for each column of this, we add row slice of a starting from startIndex template GPUMatrix& GPUMatrix::AddWithRowSliceValuesOf(const GPUMatrix& a, const size_t startIndex, const size_t numRows) { if (a.IsEmpty()) LogicError("AddWithRowSliceValuesOf: input matrix a is empty."); if (GetNumRows() != numRows) LogicError("AddWithRowSliceValuesOf: GetNumRows() != numRows."); if (startIndex + numRows > a.GetNumRows()) LogicError("AddWithRowSliceValuesOf: startIndex + numRows exceeds a.GetNumRows()."); if (a.GetNumCols() != GetNumCols()) LogicError("AddWithRowSliceValuesOf: columns does not match."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _addWithRowSliceValuesOf<<>>(Data(), a.Data(), N, (CUDA_LONG) startIndex, (CUDA_LONG) GetNumRows(), (CUDA_LONG) a.GetNumRows()); return *this; } template GPUMatrix GPUMatrix::Diagonal() const { size_t m = GetNumRows(); size_t n = GetNumCols(); if (m != n) LogicError("Diagonal can be called only for square matrix. (rows=%d, cols=%d)", (int) m, (int) n); GPUMatrix diag(1, n, GetComputeDeviceId()); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _assignToDiagonalValuesOf<<>>(diag.Data(), Data(), N, (CUDA_LONG) n); return diag; } // c = c - 1.0 for a specific position template void GPUMatrix::MinusOneAt(GPUMatrix& c, const size_t position) { assert(position < c.GetNumElements()); CUDA_LONG n = (CUDA_LONG) c.GetNumElements(); CUDA_LONG p = (CUDA_LONG) position; int blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); // BUGBUG: PrepareDevice() missing? SyncGuard syncGuard; _minusOneAt<<>>(c.Data(), p, n); } template GPUMatrix& GPUMatrix::AssignRepeatOf(const GPUMatrix& a, const size_t numRowRepeats, const size_t numColRepeats) { if (this == &a) LogicError("AssignRepeatOf: a is the same as [this]. 
Does not support inplace repeat."); if (a.IsEmpty()) LogicError("AssignRepeatOf: Matrix a is empty."); RequireSize(a.GetNumRows() * numRowRepeats, a.GetNumCols() * numColRepeats); CUDA_LONG N = (CUDA_LONG) GetNumElements(); CUDA_LONG n = (CUDA_LONG) a.GetNumCols(), m = (CUDA_LONG) a.GetNumRows(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _assignRepeatOf<<>>(Data(), a.Data(), N, m, n, (CUDA_LONG) GetNumRows()); return *this; } template GPUMatrix& GPUMatrix::AddToRowRepeatValuesOf(const GPUMatrix& a, const size_t numRepeats) { if (a.IsEmpty()) LogicError("AddToRowRepeatValuesOf: input matrix a is empty."); if (a.GetNumRows() != GetNumRows() * numRepeats) LogicError("AddToRowSliceValuesOf: a.GetNumRows() != GetNumRows() * numRepeats."); RequireSize(a.GetNumRows() / numRepeats, a.GetNumCols()); CUDA_LONG N = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _addToRowRepeatValuesOf<<>>(Data(), a.Data(), N, (CUDA_LONG) a.GetNumRows(), (CUDA_LONG) a.GetNumCols(), (CUDA_LONG) GetNumRows()); return *this; } template GPUMatrix& GPUMatrix::AssignPositiveAndShiftedNegSample(const GPUMatrix& a, const size_t posNumber, const size_t negNumber, const size_t shiftNumber) { if (this == &a) LogicError("AssignPositiveAndShiftedNegSample: a is the same as [this]. Does not support inplace assignment."); if (a.IsEmpty()) LogicError("AssignPositiveAndShiftedNegSample: Matrix a is empty."); RequireSize(a.GetNumRows() * (posNumber + negNumber), a.GetNumCols()); CUDA_LONG N = (CUDA_LONG) GetNumElements(); CUDA_LONG n = (CUDA_LONG) a.GetNumCols(), m = (CUDA_LONG) a.GetNumRows(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _assignPositiveAndShiftedNegSample<<>>(Data(), a.Data(), N, m, n, (CUDA_LONG) GetNumRows(), posNumber, shiftNumber); return *this; } template GPUMatrix& GPUMatrix::AddFoldedPositiveAndShiftedNegSample(const GPUMatrix& a, const size_t posNumber, const size_t negNumber, const size_t shiftNumber) { if (this == &a) LogicError("AddFoldedPositiveAndShiftedNegSample: a is the same as [this]. 
Does not support inplace assignment."); if (a.IsEmpty()) LogicError("AddFoldedPositiveAndShiftedNegSample: Matrix a is empty."); if (a.GetNumRows() != GetNumRows() * (posNumber + negNumber) || a.GetNumCols() != GetNumCols()) LogicError("AddFoldedPositiveAndShiftedNegSample: dimensions mismatch."); CUDA_LONG N = (CUDA_LONG) a.GetNumElements(); CUDA_LONG n = (CUDA_LONG) a.GetNumCols(), m = (CUDA_LONG) a.GetNumRows(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _addFoldedPositiveAndShiftedNegSample<<>>(Data(), a.Data(), N, m, n, (CUDA_LONG) GetNumRows(), posNumber, shiftNumber); return *this; } template GPUMatrix GPUMatrix::Transpose() const { if (IsEmpty()) LogicError("Transpose: Matrix is empty."); GPUMatrix c(GetComputeDeviceId()); c.AssignTransposeOf(*this); return c; } // GetCublasHandle - get a cublas handle for the given GPU, should only need one per GPU // computeDevice - The compute device for which the cublas handle is desired // returns: cublas handle // NOTE: we currently don't bother to ever free the CUBLAS handle, it will be freed automatically by CUDA when the process ends template cublasHandle_t GPUMatrix::GetCublasHandle(int computeDevice /*=-1*/) { // if the compute device is not passed, get the current device from CUDA if (computeDevice < 0) cudaGetDevice(&computeDevice); if (computeDevice < 0 || computeDevice >= MaxGpus) LogicError("GetCublasHandle: Maximum GPU exceeded"); cublasHandle_t cuHandle = s_cuHandle[computeDevice]; if (cuHandle == NULL) { s_cuHandle[computeDevice] = cuHandle = _initCUBLAS(computeDevice); } CUBLAS_CALL(cublasSetStream(cuHandle, t_stream)); return cuHandle; } template GPUMatrix& GPUMatrix::AssignTransposeOf(const GPUMatrix& a) { if (this == &a) LogicError("AssignTransposeOf: a is the same as [this]. Does not support inplace transpose."); if (a.IsEmpty()) LogicError("AssignTransposeOf: Matrix a is empty."); if (GetNumRows() != a.GetNumCols() || GetNumCols() != a.GetNumRows()) RequireSize(a.GetNumCols(), a.GetNumRows()); cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId()); cublasOperation_t transA = CUBLAS_OP_T; cublasOperation_t transB = CUBLAS_OP_T; int m = (int) a.m_numCols; int n = (int) a.m_numRows; ElemType alpha = 1; ElemType beta = 0; cublasStatus_t st; if (sizeof(ElemType) == sizeof(float)) st = cublasSgeam(cuHandle, transA, transB, m, n, reinterpret_cast(&alpha), reinterpret_cast(a.Data()), (int) a.m_numRows, reinterpret_cast(&beta), reinterpret_cast(a.Data()), (int) a.m_numRows, reinterpret_cast(Data()), (int) m_numRows); else if (sizeof(ElemType) == sizeof(double)) st = cublasDgeam(cuHandle, transA, transB, m, n, reinterpret_cast(&alpha), reinterpret_cast(a.Data()), (int) a.m_numRows, reinterpret_cast(&beta), reinterpret_cast(a.Data()), (int) a.m_numRows, reinterpret_cast(Data()), (int) m_numRows); else RuntimeError("Unsupported template argument in GPUMatrix"); if (st != CUBLAS_STATUS_SUCCESS) RuntimeError("AssignTransposeOf failed"); m_numRows = a.m_numCols; m_numCols = a.m_numRows; return *this; } template __global__ void _doGatherColumnsOf(ElemType* us, size_t usStride, const ElemType beta, const ElemType* idx, size_t idxStride, const ElemType* a, size_t aStride, size_t aCols, const ElemType alpha, CUDA_LONG numElements) { CUDA_LONG id = GridDim::GetLinearThreadId(); if (id >= numElements) // note: there are no __syncthread() calls inside return; // id = i + jOut * usStride; // Each thread processes one element of the output matrix. 
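    // With column-major storage and usStride == number of rows of 'us', the linear id
    // decomposes as id = i + jOut * usStride, hence row i = id % usStride and output
    // column jOut = id / usStride below.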
CUDA_LONG i = id % usStride; // row index into 'us' and 'a' CUDA_LONG jOut = id / usStride; // col index into 'us' and 'idx' auto jInF = idx[jOut * idxStride]; // this is the column we need to get if (::isnan(jInF) || jInF < 0) // negative index means gap return; size_t jIn = (size_t)jInF; //if (jIn >= aCols) // return; // actually a failure const ElemType& ra = a[ i + jIn * aStride ]; ElemType& rus = us[id/*i + jOut * usStride*/]; ElemType res = ra * alpha; if (beta != 0) res += rus * beta; rus = res; } // *this[:,j] = a[:,idx[j]] * alpha + *this[:,j] * beta template GPUMatrix& GPUMatrix::DoGatherColumnsOf(ElemType beta, const GPUMatrix& idx, const GPUMatrix& a, ElemType alpha) { if (idx.GetNumRows() != 1) // index is 1-dimensional only InvalidArgument("DoGatherColumnsOf: Map must be a row vector."); if (beta == 0) RequireSize(a.GetNumRows(), idx.GetNumCols()); // output has same column format as a, but number of columns comes from idx else VerifySize(a.GetNumRows(), idx.GetNumCols()); if (idx.GetComputeDeviceId() != a.GetComputeDeviceId() || GetComputeDeviceId() != a.GetComputeDeviceId()) InvalidArgument("All matrices must be on the same GPU"); a.PrepareDevice(); // launch the kernel CUDA_LONG NN = (CUDA_LONG)GetNumElements(); // linear space identifying each individual input element SyncGuard syncGuard; GridDim grid(NN); _doGatherColumnsOf<<>>(Data(), GetNumRows(), beta, idx.Data(), idx.GetNumRows(), a.Data(), a.GetNumRows(), a.GetNumCols(), alpha, grid.m_N); // Note: The following fails silently (no error, immediate or delayed) for numcols = 10000 under CUDA 7.0. //_doGatherColumnsOf<<>>(Data(), GetNumRows(), beta, idx.Data(), idx.GetNumRows(), a.Data(), a.GetNumRows(), a.GetNumCols(), alpha); return *this; } // little helper for debugging template static void Peek(const GPUMatrix& m, const char* which) { size_t rows = m.GetNumRows(); size_t cols = m.GetNumCols(); ElemType buf[10000] = { 0 }; size_t n = min(rows * cols, _countof(buf)); CUDA_CALL(cudaMemcpy(buf, m.Data(), sizeof(ElemType) * n, cudaMemcpyDeviceToHost)); UNUSED(which); UNUSED(rows); UNUSED(cols); sin(1.0f); // set breakpoint here //CUDA_CALL(cudaMemcpy(const_cast(m.Data()), buf, sizeof(ElemType) * n, cudaMemcpyHostToDevice)); } #define ALLOW_ATOMIC_SCATTER // allow to disable this, until we know atomicAdd() works properly here template __global__ void _doScatterColumnsOf(ElemType* us, size_t usStride, size_t usCols, const ElemType* idx, size_t idxStride, const ElemType* a, size_t aStride, const ElemType alpha, CUDA_LONG numElements) { CUDA_LONG id = GridDim::GetLinearThreadId(); if (id >= numElements) // note: there are no __syncthread() calls inside return; // id = i + jIn * aStride // Each thread processes one element of a CUDA_LONG i = id % aStride; // row index into 'a' and 'us' CUDA_LONG jIn = id / aStride; // col index into 'a' and 'idx' auto jOutF = idx[jIn * idxStride]; // this is the column we copy/add into if (::isnan(jOutF) || jOutF < 0) // negative index means gap return; size_t jOut = (size_t)jOutF; //if (jOut >= usCols) // return; // actually a failure --TODO: This should not be necessary. Why is it? const ElemType& ra = a[id/*i + jIn * aStride*/]; ElemType& rus = us[ i + jOut * usStride ]; ElemType res = ra * alpha; if (res != 0) // avoid memory conflict if e.g. 
an entire column has no gradient #ifdef ALLOW_ATOMIC_SCATTER atomicAdd(&rus, res); // rus += res; #else rus += res; #endif // Note: atomicAdd() is supposed to be fast in case of no conflict (the simple case of Scatter()) } // *this[:,idx[j]] = a[:,j] * alpha + *this[:,idx[j]] * beta template GPUMatrix& GPUMatrix::DoScatterColumnsOf(ElemType beta, const GPUMatrix& idx, const GPUMatrix& a, ElemType alpha) { if (idx.GetNumRows() != 1) // index is 1-dimensional only InvalidArgument("DoScatterColumnsOf: Map must be a row vector."); if (idx.GetNumCols() != a.GetNumCols()) InvalidArgument("DoScatterColumnsOf: Map must have width of input vector."); if (a.GetNumRows() != GetNumRows()) InvalidArgument("DoScatterColumnsOf: Output must have same height as input vector."); if (idx.GetComputeDeviceId() != a.GetComputeDeviceId() || GetComputeDeviceId() != a.GetComputeDeviceId()) InvalidArgument("All matrices must be on the same GPU"); a.PrepareDevice(); auto& us = *this; #ifndef ALLOW_ATOMIC_SCATTER // verify that atomicAdd is not needed --this is not efficient { vector buf(idx.GetNumRows() * idx.GetNumCols()); // idx(,)are the column(s) we copy/add into CUDA_CALL(cudaMemcpy(buf.data(), idx.Data(), sizeof(ElemType) * buf.size(), cudaMemcpyDeviceToHost)); vector writtenTo(GetNumCols(), false); // remember whether an output column is in fact a target for (size_t i = 0; i < buf.size(); i++) { auto colF = buf[i]; if (std::isnan(colF) || colF < 0) continue; size_t col = (size_t)colF; if (col >= GetNumCols()) LogicError("DoScatterColumnsOf: Index value out of bounds."); if (writtenTo[col]) LogicError("DoScatterColumnsOf: #ifndef ALLOW_ATOMIC_SCATTER then columns must be unique. Column idx(%d,%d)=%d is used twice.", (int)(i % idx.GetNumCols()), (int)(i / idx.GetNumCols()), (int)col); else writtenTo[col] = true; } } #endif // pre-scale with beta upfront // Scatter may add more than one source column to the same target, so we must pre-scale with beta, and then just keep adding. 
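// Equivalent column-level semantics (sketch only, not the kernel):
//   us *= beta;                                        // pre-scale once, up front
//   for each column j of a:
//       if idx(0, j) is a valid (non-gap) column index
//           us(:, (size_t)idx(0, j)) += alpha * a(:, j);   // the same target column may be hit more than once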
Scale(beta, us); // if beta is 0, then this will be a memset() // launch the kernel CUDA_LONG NN = (CUDA_LONG)(a.GetNumElements()); // linear space identifying each individual input element SyncGuard syncGuard; GridDim grid(NN); _doScatterColumnsOf<<>>(Data(), GetNumRows(), GetNumCols(), idx.Data(), idx.GetNumRows(), a.Data(), a.GetNumRows(), alpha, NN); //SyncGuard syncGuard; //_doScatterColumnsOf<<>>(Data(), GetNumRows(), GetNumCols(), idx.Data(), idx.GetNumRows(), a.Data(), a.GetNumRows(), alpha, NN); return *this; } template void GPUMatrix::SetValue(const ElemType v) { if (IsEmpty()) return; CUDA_LONG N = (CUDA_LONG) GetNumElements(); // Check if value is zero, which can be set using cudaMemset bool isZero = true; const char* valArray = reinterpret_cast(&v); for (int i = 0; i < sizeof(ElemType); i++) { if (valArray[i] != 0) { isZero = false; break; } } if (isZero) { CUDA_CALL(cudaMemset(Data(), 0, N * sizeof(ElemType))); } else { int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _setValue<<>>(Data(), v, N); } } template void GPUMatrix::SetValue(const ElemType* d_v) // d_v is pointer to the the value in GPU memory { if (IsEmpty()) LogicError("SetValue: Matrix is empty."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _setValue<<>>(Data(), d_v, N); } template void GPUMatrix::MaskColumnsValue(const GPUMatrix& columnsMask, ElemType val, size_t numColsPerMaskEntry) { if (GetNumCols() != (columnsMask.GetNumCols() * numColsPerMaskEntry)) RuntimeError("Matrix number of columns must equal 'number of columns in column mask * numColsPerMaskEntry'."); if (GetComputeDeviceId() != columnsMask.GetComputeDeviceId()) RuntimeError("Matrix and column mask must be on the same device"); int blocksPerGrid = (int)columnsMask.GetNumCols(); PrepareDevice(); SyncGuard syncGuard; _maskColumnsValue<<>>(Data(), columnsMask.Data(), (CUDA_LONG) GetNumCols(), (CUDA_LONG) GetNumRows(), val, numColsPerMaskEntry); } template void GPUMatrix::SetColumn(const ElemType* colPointer, size_t colInd) { if (IsEmpty()) LogicError("SetValue: Matrix is empty."); if (colPointer == NULL) return; CUDA_CALL(cudaMemcpy(Data() + LocateColumn(colInd), colPointer, sizeof(ElemType) * m_numRows, cudaMemcpyHostToDevice)); } template void GPUMatrix::SetColumn(const GPUMatrix& valMat, size_t colInd) { if (IsEmpty()) LogicError("SetColumn: Matrix is empty."); if (valMat.GetNumCols() != 1) LogicError("SetColumn: only support one column matrix now."); CUDA_CALL(cudaMemcpy(Data() + LocateColumn(colInd), valMat.Data(), sizeof(ElemType) * m_numRows, cudaMemcpyDeviceToDevice)); } template void GPUMatrix::SetValue(const GPUMatrix& deepCopyFrom) { if (this == &deepCopyFrom) return; SetValue(deepCopyFrom.GetNumRows(), deepCopyFrom.GetNumCols(), deepCopyFrom.GetComputeDeviceId(), deepCopyFrom.Data(), matrixFlagSetValueOnDevice); } #if 0 template void GPUMatrix::SetValue(const CPUMatrix& /*deepCopyFrom*/) { NOT_IMPLEMENTED; } template void GPUMatrix::SetValue(const CPUSparseMatrix& /*deepCopyFrom*/) { NOT_IMPLEMENTED; } template void GPUMatrix::SetValue(const GPUSparseMatrix& deepCopyFrom) { deepCopyFrom.CopyToDenseMatrix(*this); } #endif template void GPUMatrix::SetValue(const size_t numRows, const size_t numCols, int deviceId, ElemType* pArray, size_t matrixFlags, DataTransferer* transferer) { // handle externally managed case // BUGBUG: This is super super ugly, and needs to be fixed, but 
if matrixFlags has the right value, then we can't free anything, // and everything gets wonky. This should be fixed, and would go away if it is made a shared_ptr. if (matrixFlags & matrixFlagDontOwnBuffer) { // free the existing array if it used to be an owned array if ( Buffer() != NULL) { TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), Buffer()); } m_numRows = numRows; m_numCols = numCols; SetBuffer(pArray, GetNumElements() * sizeof(ElemType), true); SetSizeAllocated(GetNumElements()); SetFormat(matrixFormatDense); SetComputeDeviceId(deviceId); } else { if (transferer && (matrixFlags & matrixFlagSetValueOnDevice)) RuntimeError("Asynchronous data copy from device to device is currently not supported."); // if the devices are different move it now if (GetComputeDeviceId() != deviceId && deviceId >= 0) { Clear(); ZeroInit(deviceId); } // now RequireSize/allocate as necessary RequireSize(numRows, numCols); // copy over the content to the buffer PrepareDevice(); if (pArray != NULL) { if (!(matrixFlags & matrixFormatRowMajor)) { if (transferer) transferer->CopyCPUToGPUAsync(pArray, GetNumElements(), sizeof(ElemType), Data()); else CUDA_CALL(cudaMemcpy(Data(), pArray, sizeof(ElemType) * GetNumElements(), (matrixFlags & matrixFlagSetValueOnDevice) ? cudaMemcpyDeviceToDevice : cudaMemcpyHostToDevice)); } else // row major: must transpose (this is not meant to be efficient, but very useful for defining inline matrices for test code) { vector transposed(GetNumElements()); for (size_t i = 0; i < numRows; i++) for (size_t j = 0; j < numCols; j++) transposed[i + numRows * j] = pArray[j + numCols * i]; if (transferer) transferer->CopyCPUToGPUAsync(transposed.data(), GetNumElements(), sizeof(ElemType), Data()); else CUDA_CALL(cudaMemcpy(Data(), transposed.data(), sizeof(ElemType) * GetNumElements(), (matrixFlags & matrixFlagSetValueOnDevice) ? 
cudaMemcpyDeviceToDevice : cudaMemcpyHostToDevice)); } } } SetFormat(matrixFormatDense); } template void GPUMatrix::SetDiagonalValue(const ElemType v) { CUDA_LONG N = (CUDA_LONG) GetNumRows(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _setDiagonalValue<<>>(Data(), v, N, (CUDA_LONG) GetNumRows()); } template void GPUMatrix::SetDiagonalValue(const GPUMatrix& vector) { if (IsEmpty() || vector.IsEmpty()) LogicError("SetDiagonalValue: Matrix is empty."); if (GetNumRows() != GetNumCols()) LogicError("SetDiagonalValue: NumRows and NumCols do not agree."); if (vector.GetNumRows() != 1 && vector.GetNumCols() != 1) LogicError("SetDiagonalValue: input vector must be a vector."); if (vector.GetNumElements() == 1) // reduce to simple form SetDiagonalValue(vector.Data()[0]); else if (vector.GetNumRows() != GetNumRows() && vector.GetNumCols() != GetNumRows()) LogicError("SetDiagonalValue: input vector's dimension does not agree with [this]."); else { CUDA_LONG N = (CUDA_LONG) GetNumRows(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _setDiagonalValueFromVector<<>>(Data(), vector.Data(), N); } } template void RescaleToRange(const GPUMatrix& matrix, const ElemType low, const ElemType high) { size_t N = matrix.GetNumElements(); size_t blocksPerGrid = (size_t)ceil(N / (double)GridDim::maxThreadsPerBlock); //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; _rescaleToRange << > > (matrix.Data(), N, low, high); } template void GPUMatrix::SetUniformRandomValue(const ElemType low, const ElemType high, unsigned long seed) { PrepareDevice(); CreateCurandObject(seed, __FUNCTION__); // TODO call ResetCurandObject() instead? { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; if (sizeof(ElemType) == sizeof(float)) CURAND_CALL(curandGenerateUniform(((curandGenerator_t*) s_curandGenerator)[0], reinterpret_cast(Data()), GetNumElements())); else CURAND_CALL(curandGenerateUniformDouble(((curandGenerator_t*) s_curandGenerator)[0], reinterpret_cast(Data()), GetNumElements())); } RescaleToRange(*this, low, high); } template void GPUMatrix::SetUniformRandomValue(RNGHandle& rngHandle, const ElemType low, const ElemType high) { PrepareDevice(); GPURNGHandle* gpuRNGHandle = dynamic_cast(&rngHandle); assert(gpuRNGHandle != nullptr); { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; if (sizeof(ElemType) == sizeof(float)) CURAND_CALL(curandGenerateUniform(gpuRNGHandle->Generator(), reinterpret_cast(Data()), GetNumElements())); else CURAND_CALL(curandGenerateUniformDouble(gpuRNGHandle->Generator(), reinterpret_cast(Data()), GetNumElements())); } RescaleToRange(*this, low, high); } template void SetNormalRandomValue(const GPUMatrix& matrix, const curandGenerator_t& generator, const ElemType mean, const ElemType stdev) { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; // curandGenerateNormal can return the error CURAND_STATUS_LENGTH_NOT_MULTIPLE if GetNumElements() is odd. // To avoid this we always allocate a buffer of even size and potentially generate one more random element. 
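    // AsMultipleOf(n, k) is assumed here to round n up to the next multiple of k, e.g.
    //   size_t AsMultipleOf(size_t n, size_t k) { return ((n + k - 1) / k) * k; }
    // For an odd element count we therefore request one extra random value from cuRAND;
    // AllocateNoTrace() over-allocates by the same rule, so that extra element is backed
    // by real device memory.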
auto n = AsMultipleOf(matrix.GetNumElements(), 2); if (sizeof(ElemType) == sizeof(float)) CURAND_CALL(curandGenerateNormal(generator, reinterpret_cast(matrix.Data()), n, (float)mean, (float)stdev)); else CURAND_CALL(curandGenerateNormalDouble(generator, reinterpret_cast(matrix.Data()), n, (double)mean, (double)stdev)); } template void GPUMatrix::SetGaussianRandomValue(RNGHandle& rngHandle, const ElemType mean, const ElemType stdev) { PrepareDevice(); GPURNGHandle* gpuRNGHandle = dynamic_cast(&rngHandle); assert(gpuRNGHandle != nullptr); SetNormalRandomValue(*this, gpuRNGHandle->Generator(), mean, stdev); } template void GPUMatrix::SetGumbelRandomValue(RNGHandle& rngHandle, const ElemType loc, const ElemType scale) { PrepareDevice(); GPURNGHandle* gpuRNGHandle = dynamic_cast(&rngHandle); assert(gpuRNGHandle != nullptr); { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; if (sizeof(ElemType) == sizeof(float)) CURAND_CALL(curandGenerateUniform(gpuRNGHandle->Generator(), reinterpret_cast(Data()), GetNumElements())); else CURAND_CALL(curandGenerateUniformDouble(gpuRNGHandle->Generator(), reinterpret_cast(Data()), GetNumElements())); } size_t N = GetNumElements(); size_t blocksPerGrid = (size_t)ceil(N / (double)GridDim::maxThreadsPerBlock); { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; _gumbelFromUniform << > > (Data(), N, loc, scale); } } template void GPUMatrix::SetGaussianRandomValue(const ElemType mean, const ElemType sigma, unsigned long seed) { PrepareDevice(); CreateCurandObject(seed, __FUNCTION__); // TODO call ResetCurandObject() instead? SetNormalRandomValue(*this, ((curandGenerator_t*)s_curandGenerator)[0], mean, sigma); } template void GPUMatrix::SetTruncatedNormalRandomValue(const ElemType mean, const ElemType sigma, unsigned long seed) { // We use the method described in https://en.wikipedia.org/wiki/Truncated_normal_distribution // i.e. generate uniform, scale it to the right range, pass it through the inverse cdf, scale by sigma, and add the mean PrepareDevice(); CreateCurandObject(seed, __FUNCTION__); // TODO call ResetCurandObject() instead? { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; if (sizeof(ElemType) == sizeof(float)) CURAND_CALL(curandGenerateUniform(((curandGenerator_t*)s_curandGenerator)[0], reinterpret_cast(Data()), GetNumElements())); else CURAND_CALL(curandGenerateUniformDouble(((curandGenerator_t*)s_curandGenerator)[0], reinterpret_cast(Data()), GetNumElements())); } size_t N = GetNumElements(); size_t blocksPerGrid = (size_t)ceil(N / (double)GridDim::maxThreadsPerBlock); { //Nobody is ever calling SetStream so all work is done one the same stream //Therefore we don't need to sync //SyncGuard syncGuard; _truncated_normal_transform << > > (Data(), N, mean, sigma); } } //maskRate: percentage of values masked out (similar to dropout rate) //scaleValue: which scale value to set to the left ones (unmasked items). template void GPUMatrix::SetUniformRandomMask(const ElemType maskRate, const ElemType scaleValue, RNGHandle& rngHandle) { PrepareDevice(); GPURNGHandle* gpuRNGHandle = dynamic_cast(&rngHandle); assert(gpuRNGHandle != nullptr); cudaEvent_t done = nullptr; CUDA_CALL(cudaEventCreate(&done)); // TODO: why not condition on do_sync, so that we can use SyncGuard? 
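    // Fill the matrix with uniform random draws; _setMaskAndScale below is then expected to
    // turn each draw into either 0 (masked, with probability ~maskRate) or scaleValue (kept),
    // i.e. a dropout-style mask.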
if (sizeof(ElemType) == sizeof(float)) CURAND_CALL(curandGenerateUniform(gpuRNGHandle->Generator(), reinterpret_cast(Data()), GetNumElements())); else CURAND_CALL(curandGenerateUniformDouble(gpuRNGHandle->Generator(), reinterpret_cast(Data()), GetNumElements())); CUDA_CALL(cudaEventRecord(done)); CUDA_CALL(cudaEventSynchronize(done)); CUDA_CALL(cudaEventDestroy(done)); size_t N = GetNumElements(); size_t blocksPerGrid = (size_t) ceil(N / (double) GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _setMaskAndScale<<>>(Data(), N, maskRate, scaleValue); } template ElemType GPUMatrix::Adagrad(GPUMatrix& gradients, const bool needAveMultiplier) { size_t numColsNeeded = gradients.GetNumCols(); if (needAveMultiplier) numColsNeeded += gradients.GetNumCols(); if (IsEmpty() || GetNumCols() < numColsNeeded) { RequireSize(gradients.GetNumRows(), numColsNeeded); SetValue(0.0); } assert(GetNumRows() == gradients.GetNumRows() && GetNumCols() == numColsNeeded); size_t n = gradients.GetNumElements(); ElemType* multipliers = nullptr; if (needAveMultiplier) multipliers = Data() + n; // temp memory used to store multipliers, int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; _adagrad<<>>(Data(), gradients.Data(), n, multipliers); if (!needAveMultiplier) return 1; cublasHandle_t cuHandle = GetCublasHandle(GetComputeDeviceId()); if (sizeof(ElemType) == sizeof(float)) { float aveMultiplier = 0; CUBLAS_CALL(cublasSasum(cuHandle, (CUDA_LONG) n, reinterpret_cast(multipliers), 1, &aveMultiplier)); return (ElemType) aveMultiplier / n; } else { double aveMultiplier = 0; CUBLAS_CALL(cublasDasum(cuHandle, (CUDA_LONG) n, reinterpret_cast(multipliers), 1, &aveMultiplier)); return (ElemType) aveMultiplier / n; } } template void GPUMatrix::FSAdagrad(GPUMatrix& gradients, GPUMatrix& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, bool unitGainMomentum) { size_t numColsNeeded = 2 * gradients.GetNumCols(); if (IsEmpty() || (GetNumCols() < numColsNeeded)) { RequireSize(gradients.GetNumRows(), numColsNeeded); SetValue(0.0); } assert((GetNumRows() == gradients.GetNumRows()) && (GetNumCols() == numColsNeeded)); size_t n = gradients.GetNumElements(); int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; _fsadagrad<<>>(n, gradients.Data(), Data(), Data()+ n, functionValues.Data(), learnRatePerSample, momentum, adaWeight, adaMul, unitGainMomentum); } template void GPUMatrix::Adam(GPUMatrix& gradients, GPUMatrix& functionValues, ElemType learnRatePerSample, ElemType momentum, ElemType adaWeight, ElemType adaMul, ElemType epsilon, bool unitGainMomentum, bool adamax) { size_t numColsNeeded = 2 * gradients.GetNumCols(); if (IsEmpty() || (GetNumCols() < numColsNeeded)) { RequireSize(gradients.GetNumRows(), numColsNeeded); SetValue(0.0); } assert((GetNumRows() == gradients.GetNumRows()) && (GetNumCols() == numColsNeeded)); size_t n = gradients.GetNumElements(); int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; _adam << > >(n, gradients.Data(), Data(), Data() + n, functionValues.Data(), learnRatePerSample, momentum, adaWeight, adaMul, epsilon, unitGainMomentum, adamax); } template ElemType GPUMatrix::RmsProp(GPUMatrix& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier, const bool initialized) { const ElemType floor = 1e-6f; static ElemType* upd_gpu = (ElemType*) 0; size_t n = 
gradients.GetNumElements(); int blocksPerGrid = (GetNumElements() + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; size_t numColsNeeded = gradients.GetNumCols() * 3; if (needAveMultiplier) numColsNeeded += gradients.GetNumCols(); if (IsEmpty() || GetNumCols() < numColsNeeded || !initialized) { RequireSize(gradients.GetNumRows(), numColsNeeded); SetValue(0.0); ElemType* avars = Data(); // accumulated variances for RMS scaling ElemType* signs = Data() + n; // sign of previous gradient ElemType* steps = Data() + 2 * n; // current step size // Data()+3*n is temp memory used to store multipliers, no need to initialize _rmsprop_init<<>>(avars, signs, steps, gradients.Data(), n); } assert(GetNumRows() == gradients.GetNumRows() && GetNumCols() == numColsNeeded); ElemType* avars = Data(); // accumulated variances for RMS scaling ElemType* signs = Data() + n; // sign of previous gradient ElemType* steps = Data() + 2 * n; // current step size ElemType* multipliers = nullptr; if (needAveMultiplier) multipliers = Data() + 3 * n; // temp memory used to store multipliers, if (!upd_gpu) { const ElemType upd[] = { 2, 2, 0, 2, 2, 0, 1, 1, 1, 2, 2, 0, 1, 2, 1, 0, 2, 2, 1, 1, 1, 0, 2, 2, 0, 2, 2, }; upd_gpu = TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), 27); CUDA_CALL(cudaMemcpy(upd_gpu, upd, sizeof(ElemType) * _countof(upd), cudaMemcpyHostToDevice)); } _rmsprop<<>>(avars, signs, steps, gradients.Data(), n, RMS_GAMMA, RMS_WGT_INC, RMS_WGT_MAX, RMS_WGT_DEC, RMS_WGT_MIN, floor, upd_gpu, multipliers); if (!needAveMultiplier) return 1; cublasHandle_t cuHandle = GetCublasHandle(GetComputeDeviceId()); if (sizeof(ElemType) == sizeof(float)) { float aveMultiplier = 0; CUBLAS_CALL(cublasSasum(cuHandle, (CUDA_LONG) n, reinterpret_cast(multipliers), 1, &aveMultiplier)); return aveMultiplier / n; } else { double aveMultiplier = 0; CUBLAS_CALL(cublasDasum(cuHandle, (CUDA_LONG) n, reinterpret_cast(multipliers), 1, &aveMultiplier)); return (ElemType) aveMultiplier / n; } } template void GPUMatrix::AdaDelta(GPUMatrix& gradients, GPUMatrix& functionValues, ElemType learningRate, ElemType rho, ElemType epsilon) { size_t numColsNeeded = 2 * gradients.GetNumCols(); if (IsEmpty() || (GetNumCols() < numColsNeeded)) { RequireSize(gradients.GetNumRows(), numColsNeeded); SetValue(0.0); } assert((GetNumRows() == gradients.GetNumRows()) && (GetNumCols() == numColsNeeded)); size_t n = gradients.GetNumElements(); int blocksPerGrid = (n + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; _adadelta << > >(n, gradients.Data(), Data(), Data() + n, functionValues.Data(), learningRate, rho, epsilon); } template void GPUMatrix::Reshape(const size_t numRows, const size_t numCols) { assert(numRows * numCols == GetNumElements()); if (numRows * numCols != GetNumElements()) InvalidArgument("Reshape: total number of elements does not match."); m_numRows = numRows; m_numCols = numCols; } template void GPUMatrix::RequireSize(const size_t numRows, const size_t numCols, bool growOnly) { if (GetNumRows() != numRows || GetNumCols() != numCols) Resize(numRows, numCols, growOnly); } template void GPUMatrix::Resize(const size_t numRows, const size_t numCols, bool growOnly) { if (GetNumRows() == numRows && GetNumCols() == numCols) return; VerifyResizable(__FUNCTION__); size_t numElements = numRows * numCols; if (numElements > GetSizeAllocated() || // grow allocation (!growOnly && numElements != GetSizeAllocated())) // shrink allocation if not growOnly { // If the buffer exists, free it before allocate if (Buffer()) { 
TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), Buffer()); } // reallocate buffer if numElements > 0 ElemType* pArray = nullptr; if (numElements > 0) { pArray = TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), numRows, numCols); } SetBuffer(pArray, numElements * sizeof(ElemType)); SetSizeAllocated(numElements); } // success m_sliceViewOffset = 0; m_numRows = numRows; m_numCols = numCols; } template size_t GPUMatrix::LocateElement(const size_t row, const size_t col) const { assert(row < m_numRows && col < m_numCols); return LocateColumn(col) + row; // matrix in column-wise storage } template size_t GPUMatrix::LocateColumn(const size_t col) const { assert(col < GetNumCols()); return col * m_numRows; // matrix in column-wise storage } template ElemType GPUMatrix::Get00Element() const { ElemType res = 0; CUDA_CALL(cudaMemcpy(&res, Data(), sizeof(ElemType), cudaMemcpyDeviceToHost)); return res; } #pragma endregion Basic Operators #pragma region Member BLAS Functions template GPUMatrix& GPUMatrix::operator+=(ElemType alpha) { if (IsEmpty()) LogicError("operator+=: Matrix is empty."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _addValue<<>>(Data(), alpha, N); return *this; } template GPUMatrix GPUMatrix::operator+(ElemType alpha) const { if (IsEmpty()) LogicError("operator+: Matrix is empty."); GPUMatrix c(*this); c += alpha; return c; } template GPUMatrix& GPUMatrix::AssignSumOf(const ElemType alpha, const GPUMatrix& a) { SetValue(a); (*this) += alpha; return (*this); } template GPUMatrix& GPUMatrix::operator+=(const GPUMatrix& a) { ScaleAndAdd(1, a, *this); return *this; } template GPUMatrix GPUMatrix::operator+(const GPUMatrix& a) const { if (GetNumElements() == 1) { GPUMatrix c(a); c += Get00Element(); return c; } else if (a.GetNumElements() == 1) { GPUMatrix c(*this); c += a.Get00Element(); return c; } else { GPUMatrix c(*this); // this implementation will introduce a copy overhead. but make resue of the code c += a; return c; } } template GPUMatrix& GPUMatrix::AssignSumOf(const GPUMatrix& a, const GPUMatrix& b) { SetValue(a); (*this) += b; return (*this); } template GPUMatrix& GPUMatrix::operator-=(ElemType alpha) { if (IsEmpty()) LogicError("operato-=: Matrix is empty."); return operator+=(-1 * alpha); } template GPUMatrix GPUMatrix::operator-(ElemType alpha) const { if (IsEmpty()) LogicError("operator-: Matrix is empty."); return operator+(-1 * alpha); } template GPUMatrix& GPUMatrix::AssignDifferenceOf(const ElemType alpha, const GPUMatrix& a) { RequireSize(a.m_numRows, a.m_numCols); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignDifferenceOf1<<>>(Data(), alpha, a.Data(), N); return *this; } template GPUMatrix& GPUMatrix::AssignDifferenceOf(const GPUMatrix& a, const ElemType alpha) { RequireSize(a.m_numRows, a.m_numCols); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignDifferenceOf2<<>>(Data(), alpha, a.Data(), N); return *this; } template GPUMatrix& GPUMatrix::operator-=(const GPUMatrix& a) { ScaleAndAdd(-1, a, *this); return *this; } template GPUMatrix GPUMatrix::operator-(const GPUMatrix& a) const { GPUMatrix c(*this); // this implementation will introduce a copy overhead. 
but make resue of the code c -= a; return c; } template GPUMatrix& GPUMatrix::AssignDifferenceOf(const GPUMatrix& a, const GPUMatrix& b) { if (this != &a) { RequireSize(a.GetNumRows(), a.GetNumCols()); SetValue(a); } (*this) -= b; return *this; } template GPUMatrix& GPUMatrix::operator*=(ElemType alpha) { Scale(alpha, *this); return *this; } template GPUMatrix GPUMatrix::operator*(ElemType alpha) const { GPUMatrix c(GetNumRows(), GetNumCols(), GetComputeDeviceId()); Scale(alpha, *this, c); return c; } template GPUMatrix& GPUMatrix::AssignProductOf(const ElemType alpha, const GPUMatrix& a) { Scale(alpha, a, *this); return *this; } template GPUMatrix& GPUMatrix::AssignProductOf(const GPUMatrix& a, const bool transposeA, const GPUMatrix& b, const bool transposeB) { if (a.GetNumElements() == 1) { if (transposeB) AssignTransposeOf(b); (*this) *= a.Get00Element(); } else if (b.GetNumElements() == 1) { if (transposeA) AssignTransposeOf(a); (*this) *= b.Get00Element(); } else Multiply(a, transposeA, b, transposeB, *this); return *this; } template GPUMatrix GPUMatrix::operator*(const GPUMatrix& a) const { const GPUMatrix& us = *this; if (GetNumElements() == 1) { GPUMatrix c(GetComputeDeviceId()); c.AssignProductOf(Get00Element(), a); return c; } else if (a.GetNumElements() == 1) { GPUMatrix c(GetComputeDeviceId()); c.AssignProductOf(a.Get00Element(), us); return c; } else { GPUMatrix c(GetNumRows(), a.GetNumCols(), GetComputeDeviceId()); Multiply(*this, a, c); return c; } } template GPUMatrix& GPUMatrix::operator/=(ElemType alpha) { (*this) *= 1 / alpha; return (*this); } template GPUMatrix GPUMatrix::operator/(ElemType alpha) const { return ((*this) * (1 / alpha)); } //element-wise power template GPUMatrix& GPUMatrix::operator^=(ElemType alpha) { GPUMatrix& us = *this; ElementWisePower(alpha, us, us); return us; } template GPUMatrix GPUMatrix::operator^(ElemType alpha) const { GPUMatrix c(GetNumRows(), GetNumCols(), GetComputeDeviceId()); ElementWisePower(alpha, *this, c); return c; } template GPUMatrix& GPUMatrix::AssignElementPowerOf(const GPUMatrix& a, const ElemType power) { ElementWisePower(power, a, *this); return *this; } template GPUMatrix& GPUMatrix::AddElementProductOf(const GPUMatrix& a, const GPUMatrix& b) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AddElementProductOf: Matrix is empty."); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols())) InvalidArgument("The input matrix dimensions do not match."); if (!(a.GetNumRows() == GetNumRows() && a.GetNumCols() == GetNumCols())) InvalidArgument("The input matrix dimensions do not match [this]."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _addElementProductOf<<>>(Data(), a.Data(), b.Data(), N); return *this; } template GPUMatrix& GPUMatrix::ColumnElementMultiplyWith(const GPUMatrix& a) { if (a.IsEmpty() || IsEmpty()) LogicError("ColumnElementMultiplyWith: Matrix is empty."); if (!(a.GetNumRows() == GetNumRows() && a.GetNumCols() == 1)) InvalidArgument("ColumnElementMultiplyWith: The input matrix should be a col vector and match [this]'s rows."); CUDA_LONG N = (CUDA_LONG) a.GetNumRows(); CUDA_LONG M = (CUDA_LONG) GetNumCols(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _columnElementMultiplyWith<<>>(Data(), a.Data(), N, M); return *this; } template GPUMatrix& 
GPUMatrix::RowElementMultiplyWith(const GPUMatrix& a) { if (a.IsEmpty() || IsEmpty()) LogicError("RowElementMultiplyWith: Matrix is empty."); if (!(a.GetNumRows() == 1 && a.GetNumCols() == GetNumCols())) InvalidArgument("RowElementMultiplyWith: The input matrix should be a row vector and match [this]'s columns."); CUDA_LONG N = (CUDA_LONG) GetNumRows(); CUDA_LONG M = (CUDA_LONG) a.GetNumCols(); int blocksPerGrid = (int) ceil(1.0 * M / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _rowElementMultiplyWith<<>>(Data(), a.Data(), N, M); return *this; } template GPUMatrix& GPUMatrix::RowElementDivideBy(const GPUMatrix& a) { if (a.IsEmpty() || IsEmpty()) LogicError("RowElementDivideBy: Matrix is empty."); if (!(a.GetNumRows() == 1 && a.GetNumCols() == GetNumCols())) InvalidArgument("RowElementDivideBy: The input matrix should be a row vector and match [this]'s columns."); CUDA_LONG N = (CUDA_LONG) GetNumRows(); CUDA_LONG M = (CUDA_LONG) a.GetNumCols(); int blocksPerGrid = (int) ceil(1.0 * M / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _rowElementDivideBy<<>>(Data(), a.Data(), N, M); return *this; } template GPUMatrix& GPUMatrix::ColumnElementDivideBy(const GPUMatrix& a) { if (a.IsEmpty() || IsEmpty()) LogicError("ColumnElementDivideBy: Matrix is empty."); if (!(a.GetNumRows() == GetNumRows() && a.GetNumCols() == 1)) InvalidArgument("ColumnElementDivideBy: The input matrix should be a col vector and match [this]'s rows."); CUDA_LONG N = (CUDA_LONG) a.GetNumRows(); CUDA_LONG M = (CUDA_LONG) GetNumCols(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _ColumnElementDivideBy<<>>(Data(), a.Data(), N, M); return *this; } template GPUMatrix& GPUMatrix::ElementInverse() { if (IsEmpty()) LogicError("ElementInverse: Matrix is empty."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _elemInverse<<>>(Data(), N); return *this; } template GPUMatrix& GPUMatrix::AssignElementInverseOf(const GPUMatrix& a) { SetValue(a); return ElementInverse(); } DEF_ELEMWISE_INPLACE_FUNC(Sigmoid) template GPUMatrix& GPUMatrix::AssignSigmoidOf(const GPUMatrix& a) { RequireSize(a.GetNumRows(), a.GetNumCols()); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; // _elementWIseSigmoidOnCuda has an implementation that avoids possible overflow errors, but has a slight accuracy regression. 
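// A standard way to get that overflow safety is to keep exp()'s argument non-positive; below is a
// minimal host-side sketch of the usual formulation (illustrative only, and possibly different in
// detail from what _elementWiseSigmoidOnCuda actually does):
#if 0
static double StableSigmoid(double x)
{
    // x >= 0: 1 / (1 + e^(-x));  x < 0: e^x / (1 + e^x). Both branches keep exp()'s argument <= 0.
    if (x >= 0)
        return 1.0 / (1.0 + exp(-x));
    double e = exp(x);
    return e / (1.0 + e);
}
#endif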
#if 0 _elementWiseSigmoidOnCuda<<>>(a.Data(), Data(), N); #else _assignSigmoidOf<<>>(a.Data(), Data(), N); #endif return *this; } DEF_ELEMWISE_INPLACE_FUNC(SigmoidDerivative) DEF_ELEMWISE_ASSIGN_FUNC(SigmoidDerivative) template void GPUMatrix::AssignNoiseContrastiveEstimation(const GPUMatrix& a, const GPUMatrix& b, const GPUMatrix& bias, size_t sampleCount, GPUMatrix& tmp, GPUMatrix& c) //this: samples+probs // a : hidden // b : embedding // tmp: softmax // c : loglikelihood { UNCONST(ElemType, a, my_a); UNCONST(ElemType, b, my_b); UNCONST(ElemType, bias, my_bias); SyncGuard syncGuard; // a: dim * minibatch // b: dim * |vocab| int p = 512; int width = a.GetNumRows(); // dimension of hidden vector while (p / 2 > width) p = p / 2; // note: kernel has hard-coded dimension of 512 _computeNceOutputMax512Threads << > >( Data(), sampleCount, m_numRows / 2, my_a.Data(), // a a.GetNumRows(), my_b.Data(), // b my_bias.Data(), tmp.Data()); // tmp p = 512; while (p / 2 > GetNumElements() / 2) p = p / 2; // summing up objective must be done in one block // note: kernel has hard-coded dimension of 512 _assignNoiseContrastiveEstimationMax512Threads << <1, p >> >( Data(), sampleCount, m_numRows / 2, my_a.Data(), a.GetNumCols(), my_b.Data(), tmp.Data(), c.Data()); } template void GPUMatrix::AssignNCEDerivative(GPUMatrix& tmp, const GPUMatrix& a, const GPUMatrix& b, size_t inputIndex, GPUMatrix& c) { UNCONST(ElemType, a, my_a); UNCONST(ElemType, b, my_b); SyncGuard syncGuard; int p = 512; int width = a.GetNumRows(); while (p / 2 > width) p = p / 2; _assignNceDerivativeNew<<<(tmp.GetNumElements() + p - 1) / p, p>>>( Data(), tmp.GetNumCols(), m_numRows / 2, my_a.Data(), a.GetNumRows(), my_b.Data(), tmp.Data(), c.Data(), inputIndex); } template void GPUMatrix::AssignSoftmaxSum(const GPUMatrix& a, GPUMatrix& c) { UNCONST(ElemType, a, my_a); SyncGuard syncGuard; int p = 512; int width = a.GetNumRows(); while (p / 2 > width) p = p / 2; // note: kernel has hard-coded dimension of 512 _assignSoftmaxSumMax512Threads << <1, p >> >( my_a.Data(), width, Data(), c.Data()); } template void GPUMatrix::AssignNCEUnnormalizedEval(const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c) { assert(a.GetComputeDeviceId() == b.GetComputeDeviceId()); assert(GetNumRows() == a.GetNumRows()); assert(GetNumCols() == b.GetNumRows()); assert(a.GetNumCols() == b.GetNumRows()); UNUSED(a); UNUSED(b); UNUSED(c); // TODO: this function seems like a stub /* EnsureAuxMemory(); int p = 512; int width = a.GetNumCols(); while (p / 2 > width) p = p / 2; // this kernel need be launched in nnz blocks _sparseInnerProductDenseTimesDense << > >( m_dVal, m_buf, m_dCol, m_nz, GetNumRows(), a.Buffer(), b.Buffer(), b.GetNumRows(), m_res); // sum up the results _reductionSum32 << <1, 32 >> >(m_res, c.Buffer(), m_nz);*/ } DEF_ELEMWISE_INPLACE_FUNC(Tanh) DEF_ELEMWISE_ASSIGN_FUNC(Tanh) template GPUMatrix& GPUMatrix::InplaceLogSoftmax(const bool isColWise) { if (IsEmpty()) LogicError("InplaceLogSoftmax: Matrix is empty."); PrepareDevice(); if (isColWise) { CUDA_LONG N = (CUDA_LONG) GetNumCols(); // one kernel per column int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _logSoftMaxColWise<<>>(Data(), (CUDA_LONG) m_numCols, (CUDA_LONG) m_numRows); } else { CUDA_LONG N = (CUDA_LONG) GetNumRows(); // one kernel per column int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _logSoftMaxRowWise<<>>(Data(), (CUDA_LONG) m_numCols, (CUDA_LONG) m_numRows); } return *this; } template GPUMatrix& 
GPUMatrix::AssignLogSoftmaxOf(const GPUMatrix& a, const bool isColWise) { RequireSize(a.GetNumRows(), a.GetNumCols()); if (isColWise) { PrepareDevice(); CUDA_LONG N = (CUDA_LONG) GetNumCols(); CUDA_LONG M = (CUDA_LONG) GetNumRows(); SyncGuard syncGuard; // note: kernel uses hard-coded thread dimension _assignColumnwiseLogSoftmaxOf512Threads<<>>(a.Data(), Data(), N, M); } else { NOT_IMPLEMENTED; } return *this; } template GPUMatrix& GPUMatrix::InplaceHardmax(const bool isColWise) { return AssignHardmaxOf(*this, isColWise); } template GPUMatrix& GPUMatrix::AssignHardmaxOf(const GPUMatrix& a, const bool isColWise) { RequireSize(a.GetNumRows(), a.GetNumCols()); if (isColWise) { PrepareDevice(); CUDA_LONG N = (CUDA_LONG) GetNumCols(); CUDA_LONG M = (CUDA_LONG) GetNumRows(); SyncGuard syncGuard; // note: kernel uses hard-coded thread dimension _assignColumnwiseHardmaxOf512Threads << > >(a.Data(), Data(), N, M); } else { NOT_IMPLEMENTED; } return *this; } DEF_ELEMWISE_INPLACE_FUNC(Sqrt) DEF_ELEMWISE_ASSIGN_FUNC(Sqrt) DEF_ELEMWISE_INPLACE_FUNC(Exp) DEF_ELEMWISE_ASSIGN_FUNC(Exp) DEF_ELEMWISE_INPLACE_FUNC(Log) DEF_ELEMWISE_ASSIGN_FUNC(Log) DEF_ELEMWISE_INPLACE_FUNC(Abs) DEF_ELEMWISE_ASSIGN_FUNC(Abs) DEF_ELEMWISE_INPLACE_FUNC(LinearRectifierDerivative) DEF_ELEMWISE_ASSIGN_FUNC(LinearRectifierDerivative) DEF_ELEMWISE_INPLACE_FUNC(Cosine) DEF_ELEMWISE_ASSIGN_FUNC(Cosine) DEF_ELEMWISE_INPLACE_FUNC(NegativeSine) DEF_ELEMWISE_ASSIGN_FUNC(NegativeSine) template GPUMatrix& GPUMatrix::InplaceTruncateBottom(const ElemType threshold) { return AssignTruncateBottomOf(*this, threshold); } template GPUMatrix& GPUMatrix::AssignTruncateBottomOf(const GPUMatrix& a, const ElemType threshold) { if (a.IsEmpty()) LogicError("AssignTruncateBottomOf: Matrix a is empty."); if (this != &a) { RequireSize(a.GetNumRows(), a.GetNumCols()); } CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _assignTruncateBottom<<>>(Data(), a.Data(), threshold, N); return *this; } template GPUMatrix& GPUMatrix::InplaceTruncateTop(const ElemType threshold) { return AssignTruncateTopOf(*this, threshold); } template GPUMatrix& GPUMatrix::AssignTruncateTopOf(const GPUMatrix& a, const ElemType threshold) { if (a.IsEmpty()) LogicError("AssignTruncateTopOf: Matrix a is empty."); if (this != &a) { RequireSize(a.GetNumRows(), a.GetNumCols()); } CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignTruncateTop<<>>(Data(), a.Data(), threshold, N); return *this; } template GPUMatrix& GPUMatrix::InplaceTruncate(const ElemType threshold) { if (IsEmpty()) LogicError("InplaceTruncate: Matrix is empty."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _inplaceTruncate<<>>(Data(), threshold, N); return *this; } template GPUMatrix& GPUMatrix::InplaceSoftThreshold(const ElemType threshold) { if (IsEmpty()) LogicError("InplaceSoftThreshold: Matrix is empty."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); PrepareDevice(); SyncGuard syncGuard; _inplaceSoftThreshold<<>>(Data(), threshold, N); return *this; } template GPUMatrix& GPUMatrix::SetToZeroIfAbsLessThan(const ElemType threshold) { if (IsEmpty()) LogicError("SetToZeroIfAbsLessThan: Matrix is empty."); CUDA_LONG N = (CUDA_LONG) 
GetNumElements();
    int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock);
    PrepareDevice();
    SyncGuard syncGuard;
    _setToZeroIfAbsLessThan<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(Data(), threshold, N);
    return *this;
}

template <class ElemType>
ElemType GPUMatrix<ElemType>::SumOfAbsElements() const
{
    if (IsEmpty())
        LogicError("SumOfAbsElements: Matrix is empty");

    cublasHandle_t cuHandle = GetCublasHandle(GetComputeDeviceId());
    if (sizeof(ElemType) == sizeof(float))
    {
        float res = 0;
        CUBLAS_CALL(cublasSasum(cuHandle, (CUDA_LONG) GetNumElements(), reinterpret_cast<float*>(Data()), 1, &res));
        return res;
    }
    else
    {
        double res = 0;
        CUBLAS_CALL(cublasDasum(cuHandle, (CUDA_LONG) GetNumElements(), reinterpret_cast<double*>(Data()), 1, &res));
        return ElemType(res);
    }
}

template <class ElemType>
ElemType GPUMatrix<ElemType>::SumOfElements() const
{
    if (IsEmpty())
        LogicError("SumOfElements: Matrix is empty");

    ElemType* d_sum = TracingGPUMemoryAllocator::Allocate<ElemType>(GetComputeDeviceId(), 1);
    ElemType h_sum;
    // WARNING: THIS kernel is not the most efficient way!
    // note: kernel has hard-coded dimension of 1024
    _reductionSum1024Threads<<<1, 1024, 0, t_stream>>>(Data(), d_sum, (CUDA_LONG) GetNumElements());
    CUDA_CALL(cudaMemcpy(&h_sum, d_sum, sizeof(ElemType), cudaMemcpyDeviceToHost));
    TracingGPUMemoryAllocator::Free<ElemType>(GetComputeDeviceId(), d_sum);
    return h_sum;
}

template <class ElemType>
GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSumOfElements(const GPUMatrix<ElemType>& a)
{
    if (a.IsEmpty())
        LogicError("AssignSumOfElements: Matrix a is empty");

    RequireSize(1, 1);
    PrepareDevice();
    SyncGuard syncGuard;
    // WARNING: THIS kernel is not the most efficient way!
    // note: kernel has hard-coded dimension of 1024
    _reductionSumAndAssign1024Threads<<<1, 1024>>>(Data(), a.Data(), (CUDA_LONG) a.GetNumElements(), (CUDA_LONG) GetNumElements());
    return (*this);
}

template <class ElemType>
DeviceBoundNumber<ElemType> GPUMatrix<ElemType>::Sum_AsDeviceBoundNum() const
{
    if (IsEmpty())
        LogicError("Matrix is empty");

    ElemType* d_sum = TracingGPUMemoryAllocator::Allocate<ElemType>(GetComputeDeviceId(), 1);
    // WARNING: THIS kernel is not the most efficient way!
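// Note: these reductions are "not the most efficient way" because a single block of 1024 threads
// strides over the entire array and then combines the per-thread partial sums with a shared-memory
// tree reduction, so only one SM does the work. Illustrative shape of such a kernel (a sketch with
// a hypothetical name, not the actual _reductionSum1024Threads implementation):
#if 0
template <class ElemType>
__global__ void ReduceSumSketch(const ElemType* data, ElemType* result, CUDA_LONG n)
{
    __shared__ ElemType partial[1024];
    ElemType sum = 0;
    for (CUDA_LONG i = threadIdx.x; i < n; i += blockDim.x) // block-stride accumulation
        sum += data[i];
    partial[threadIdx.x] = sum;
    __syncthreads();
    for (int stride = blockDim.x / 2; stride > 0; stride /= 2) // shared-memory tree reduction
    {
        if (threadIdx.x < stride)
            partial[threadIdx.x] += partial[threadIdx.x + stride];
        __syncthreads();
    }
    if (threadIdx.x == 0)
        *result = partial[0];
}
#endif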
// note: kernel has hard-coded dimension of 1024 _reductionSum1024Threads << <1, 1024, 0, t_stream >> >(Data(), d_sum, (CUDA_LONG)GetNumElements()); DeviceBoundNumber result; result.ShallowCopyFrom(d_sum, GetComputeDeviceId()); return result; } template ElemType GPUMatrix::AbsoluteMax() const { cublasHandle_t cuHandle = GetCublasHandle(GetComputeDeviceId()); ElemType res; if (sizeof(ElemType) == sizeof(float)) { int resInd = 0; cublasIsamax(cuHandle, (CUDA_LONG)GetNumElements(), reinterpret_cast(Data()), 1, &resInd); resInd--; CUDA_CALL(cudaMemcpy(reinterpret_cast(&res), reinterpret_cast(Data() + resInd), sizeof(float), cudaMemcpyDeviceToHost)); return res; } else { int resInd = 0; cublasIdamax(cuHandle, (CUDA_LONG)GetNumElements(), reinterpret_cast(Data()), 1, &resInd); resInd--; CUDA_CALL(cudaMemcpy(reinterpret_cast(&res), Data() + resInd, sizeof(double), cudaMemcpyDeviceToHost)); return res; } } template GPUMatrix& GPUMatrix::ElementMultiplyWith(const GPUMatrix& a) { if (IsEmpty() || a.IsEmpty()) LogicError("ElementMultiplyWith: Matrix is empty."); GPUMatrix& us = *this; assert(us.GetNumRows() == a.GetNumRows() && us.GetNumCols() == a.GetNumCols()); if (us.GetNumRows() != a.GetNumRows() || us.GetNumCols() != a.GetNumCols()) InvalidArgument("The matrix dimensions do not match."); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _elemMul<<>>(Data(), a.Data(), N); return *this; } template GPUMatrix& GPUMatrix::AssignElementProductOf(const GPUMatrix& a, const GPUMatrix& b) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AssignElementProductOf: Matrix is empty."); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols())) InvalidArgument("The input matrix dimensions do not match."); RequireSize(a.GetNumRows(), a.GetNumCols()); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignElementProductOf<<>>(Data(), a.Data(), b.Data(), N); return *this; } template GPUMatrix& GPUMatrix::ElementDivideBy(const GPUMatrix& a) { return AssignElementDivisionOf(*this, a); } template GPUMatrix& GPUMatrix::AssignElementDivisionOf(const GPUMatrix& a, const GPUMatrix& b) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AssignElementDivisionOf: Matrix is empty."); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols())) InvalidArgument("The input matrix dimensions do not match."); RequireSize(a.GetNumRows(), a.GetNumCols()); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignElementDivisionOf<<>>(Data(), a.Data(), b.Data(), N); return *this; } template bool GPUMatrix::IsEqualTo(const GPUMatrix& a, const ElemType threshold /*= 1e-8*/) const { return AreEqual(*this, a, threshold); } template void GPUMatrix::VectorSum(const GPUMatrix& a, GPUMatrix& c, const bool isColWise) { if (a.GetComputeDeviceId() != c.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } a.PrepareDevice(); if (a.IsEmpty()) LogicError("VectorSum: Input matrix is empty."); const CUDA_LONG n = (CUDA_LONG) a.GetNumRows(); const CUDA_LONG m = (CUDA_LONG) a.GetNumCols(); assert(m > 0 && n > 0); // 
converting from size_t to int may cause overflow int blocksPerGrid = 0; if (isColWise) // col-wise { c.RequireSize(1, m); blocksPerGrid = (int) ceil(1.0 * m / GridDim::maxThreadsPerBlock); } else { c.RequireSize(n, 1); blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); } SyncGuard syncGuard; _vectorSum<<>>(c.Data(), a.Data(), n, m, isColWise); } template void GPUMatrix::VectorNorm1(GPUMatrix& c, const bool isColWise) const { if (IsEmpty()) LogicError("VectorNorm1: Matrix is empty."); const CUDA_LONG n = (CUDA_LONG) GetNumRows(); const CUDA_LONG m = (CUDA_LONG) GetNumCols(); assert(m > 0 && n > 0); // converting from size_t to int may cause overflow PrepareDevice(); c.ChangeDeviceTo(GetComputeDeviceId()); int blocksPerGrid = 0; if (isColWise) // col-wise { c.RequireSize(1, m); blocksPerGrid = (int) ceil(1.0 * m / GridDim::maxThreadsPerBlock); } else { c.RequireSize(n, 1); blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); } SyncGuard syncGuard; _vectorNorm1<<>>(c.Data(), Data(), n, m, isColWise); } template GPUMatrix& GPUMatrix::AssignVectorNorm1Of(GPUMatrix& a, const bool isColWise) { a.VectorNorm1(*this, isColWise); return *this; } template void GPUMatrix::VectorNorm2(GPUMatrix& c, const bool isColWise) const { if (IsEmpty()) LogicError("VectorNorm2: Matrix is empty."); const CUDA_LONG n = (CUDA_LONG) GetNumRows(); const CUDA_LONG m = (CUDA_LONG) GetNumCols(); assert(m > 0 && n > 0); // converting from size_t to int may cause overflow PrepareDevice(); c.ChangeDeviceTo(GetComputeDeviceId()); int blocksPerGrid = 0; if (isColWise) // col-wise { c.RequireSize(1, m); blocksPerGrid = (int) ceil(1.0 * m / GridDim::maxThreadsPerBlock); } else { c.RequireSize(n, 1); c.ChangeDeviceTo(GetComputeDeviceId()); blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); } SyncGuard syncGuard; _vectorNorm2<<>>(c.Data(), Data(), n, m, isColWise); } template GPUMatrix& GPUMatrix::AssignVectorNorm2Of(GPUMatrix& a, const bool isColWise) { a.VectorNorm2(*this, isColWise); return *this; } template void GPUMatrix::VectorNormInf(GPUMatrix& c, const bool isColWise) const { if (IsEmpty()) LogicError("VectorMax: Matrix is empty."); // this implementation is not efficient GPUMatrix tmp(GetComputeDeviceId()); GPUMatrix tmp1(GetComputeDeviceId()); tmp.AssignAbsOf((*this)); tmp.VectorMax(tmp1, c, isColWise); } template GPUMatrix& GPUMatrix::AssignVectorNormInfOf(GPUMatrix& a, const bool isColWise) { a.VectorNormInf(*this, isColWise); return *this; } template GPUMatrix& GPUMatrix::AssignInnerProductOf(const GPUMatrix& a, const GPUMatrix& b, const bool isColWise) { InnerProduct(a, b, *this, isColWise); return *this; } template GPUMatrix& GPUMatrix::AssignKhatriRaoProductOf(const GPUMatrix& a, const GPUMatrix& b) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AssignKhatriRaoProductOf: Matrix is empty."); CUDA_LONG cols = a.GetNumCols(); assert(cols == b.GetNumCols()); if (!(cols == b.GetNumCols())) InvalidArgument("AssignKhatriRaoProductOf: The input matrix dimensions do not match."); CUDA_LONG rowsA = (CUDA_LONG) a.GetNumRows(); CUDA_LONG rowsB = (CUDA_LONG) b.GetNumRows(); RequireSize(rowsA * rowsB, cols); float N = (float) GetNumElements(); int blocksPerGrid = (int) ceil(N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignKhatriRaoProductOf<<>>(Data(), a.Data(), b.Data(), rowsA, rowsB, cols); return *this; } //column-wise reshaped product. 
Used to compute KhatriRaoProduct Gradient // this = reshape each column of a from (K1xK2,1) to (K1, K2) // if each column of a is not transposed, each (K1, K2) times each column of b (K2, frames). // the output is a (K1, frames) matrix // if each column of a is tranposed, each (K1, K2)^T times each column of b(K1, frames) and output is (K2, frames) template GPUMatrix& GPUMatrix::AddColumnReshapeProductOf(const GPUMatrix& a, const GPUMatrix& b, const bool transposeAColumn) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AddColumnReshapeProductOf: Matrix is empty."); CUDA_LONG cols = a.GetNumCols(); assert(cols == b.GetNumCols()); if (!(cols == b.GetNumCols())) InvalidArgument("AddColumnReshapeProductOf: The input matrix dimensions do not match."); CUDA_LONG rowsA = (CUDA_LONG) a.GetNumRows(); CUDA_LONG rowsB = (CUDA_LONG) b.GetNumRows(); if (rowsA % rowsB != 0) InvalidArgument("AddColumnReshapeProductOf: number of rows in a should be multiples of that in b."); CUDA_LONG rowsC = rowsA / rowsB; if (rowsC != GetNumRows() || cols != GetNumCols()) InvalidArgument("AddColumnReshapeProductOf: This matrix does not have the right size."); float N = (float) GetNumElements(); int blocksPerGrid = (int) ceil(N / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _addColumnReshapeProductOf<<>>(Data(), a.Data(), b.Data(), rowsB, rowsC, cols, transposeAColumn); return *this; } template GPUMatrix& GPUMatrix::AddWithScaleOf(ElemType alpha, const GPUMatrix& a) { ScaleAndAdd(alpha, a, *this); return *this; } template ElemType GPUMatrix::FrobeniusNorm() const { if (IsEmpty()) LogicError("FrobeniusNorm: Matrix is empty."); ElemType* d_sum = TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), 1); ElemType h_sum = 0; // WARNING: THIS kernel is not the most efficient way! // note: kernel has hard-coded dimension of 1024 _reductionSum21024Threads << <1, 1024, 0, t_stream >> >(Data(), d_sum, (CUDA_LONG)GetNumElements(), true); CUDA_CALL(cudaMemcpy(&h_sum, d_sum, sizeof(ElemType), cudaMemcpyDeviceToHost)); TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), d_sum); return (h_sum); } template GPUMatrix& GPUMatrix::AssignFrobeniusNormOf(const GPUMatrix& a) { if (a.IsEmpty()) LogicError("AssignFrobeniusNormOf: Matrix a is empty."); RequireSize(1, 1); PrepareDevice(); // WARNING: THIS kernel is not the most efficient way! // note: kernel has hard-coded dimension of 1024 _reductionSum21024Threads << <1, 1024, 0, t_stream >> >(a.Data(), Data(), (CUDA_LONG)a.GetNumElements(), true); return *this; } template ElemType GPUMatrix::MatrixNormInf() const { if (IsEmpty()) LogicError("MatrixNormInf: Matrix is empty."); ElemType* d_maxAbs = TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), 1); ElemType h_maxAbs = 0; // WARNING: THIS kernel is not the most efficient way! // note: kernel has hard-coded dimension of 1024 _reductionMatrixNormInf1024Threads << <1, 1024, 0, t_stream >> >(Data(), d_maxAbs, (CUDA_LONG)GetNumElements()); CUDA_CALL(cudaMemcpy(&h_maxAbs, d_maxAbs, sizeof(ElemType), cudaMemcpyDeviceToHost)); TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), d_maxAbs); return h_maxAbs; } template ElemType GPUMatrix::MatrixNorm1() const { if (IsEmpty()) LogicError("MatrixNorm1: Matrix is empty."); return SumOfAbsElements(); } template ElemType GPUMatrix::MatrixNorm0() const { if (IsEmpty()) LogicError("MatrixNorm0: Matrix is empty."); ElemType* d_nz = TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), 1); ElemType h_nz = 0; // WARNING: THIS kernel is not the most efficient way! 
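// Note: for reference, MatrixNormInf above computes max_i |x_i|, MatrixNorm1 computes sum_i |x_i|
// (via SumOfAbsElements), and MatrixNorm0 below counts the nonzero entries; all three rely on the
// same single-block, 1024-thread reduction pattern flagged in the warnings.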
// note: kernel has hard-coded dimension of 1024 _reductionMatrixNorm01024Threads << <1, 1024, 0, t_stream >> >(Data(), d_nz, (CUDA_LONG)GetNumElements()); CUDA_CALL(cudaMemcpy(&h_nz, d_nz, sizeof(ElemType), cudaMemcpyDeviceToHost)); TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), d_nz); return h_nz; } template GPUMatrix& GPUMatrix::AssignSignOf(const GPUMatrix& a) { if (a.IsEmpty()) LogicError("AssignSignOf: Matrix a is empty."); if (this != &a) RequireSize(a.GetNumRows(), a.GetNumCols()); PrepareDevice(); int blocksPerGrid = (int) ceil(1.0 * GetNumElements() / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _assignSignOf<<>>(Data(), a.Data(), (CUDA_LONG) GetNumElements()); return *this; } template GPUMatrix& GPUMatrix::AddSignOf(const GPUMatrix& a) { if (a.IsEmpty()) LogicError("AddSignOf: Matrix a is empty."); if (this != &a) RequireSize(a.GetNumRows(), a.GetNumCols()); PrepareDevice(); int blocksPerGrid = (int) ceil(1.0 * GetNumElements() / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _addSignOf<<>>(Data(), a.Data(), (CUDA_LONG) GetNumElements()); return *this; } template void GPUMatrix::VectorMax(GPUMatrix& maxIndexes, GPUMatrix& maxValues, const bool isColWise) const { if (IsEmpty()) LogicError("VectorMax: Matrix is empty."); const GPUMatrix& us = *this; const CUDA_LONG m = (CUDA_LONG) GetNumRows(); const CUDA_LONG n = (CUDA_LONG) GetNumCols(); assert(m > 0 && n > 0); // converting from size_t to int may cause overflow PrepareDevice(); SyncGuard syncGuard; if (isColWise) { maxValues.RequireSize(1, n); maxIndexes.RequireSize(1, n); int blocksPerGrid = n; // we'll have 1 block processing 1 column // note: kernel has hard-coded dimension of 512 _vectorMaxMinReduce512Threads<<>>(us.Data(), maxIndexes.Data(), maxValues.Data(), m, n); /*int blocksPerGrid=(int)ceil(1.0*n/GridDim::maxThreadsPerBlock); _vectorMax<<>>(us.Data(),maxIndexes.Data(),maxValues.Data(),m,n,isColWise);*/ } else { maxValues.RequireSize(m, 1); maxIndexes.RequireSize(m, 1); int blocksPerGrid = (int) ceil(1.0 * m / GridDim::maxThreadsPerBlock); _vectorMax<<>>(us.Data(), maxIndexes.Data(), maxValues.Data(), m, n, isColWise); } } __global__ void _initIndicesForSort(uint64_t* indexes, CUDA_LONG crow, CUDA_LONG ccol) { CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; if (id >= crow * ccol) return; uint32_t irow = id % crow; uint32_t icol = id / crow; indexes[id] = (static_cast(irow) << 32) | icol; } template void GPUMatrix::VectorMax(GPUMatrix& maxIndexes, GPUMatrix& maxValues, const bool isColWise, int topK) const { if (IsEmpty()) LogicError("VectorMax: Matrix is empty."); if (topK == 1) { VectorMax(maxIndexes, maxValues, isColWise); return; } if (!isColWise) RuntimeError("Row-wise TopK max is not supported."); const GPUMatrix& us = *this; const CUDA_LONG m = (CUDA_LONG) GetNumRows(); const CUDA_LONG n = (CUDA_LONG) GetNumCols(); assert(topK <= m); assert(m > 0 && n > 0); // converting from size_t to int may cause overflow PrepareDevice(); SyncGuard syncGuard; maxValues.RequireSize(topK, n); maxIndexes.RequireSize(topK, n); // To sort matrix columns we use 2-pass _stable_ sort algorithm: // 1. Sort by values (descending) with corresponding row/col indexes. // 2. Sort by col indices (ascending) with corresponding values/row indices. // Indices are stored as 64-bit ints where low 32 bits represent column and high 32 bits - row index. // On the second pass only first 32 bits of the index are used in sorting, so SortPairs has // begin_bit and end_bit set accordingly. 
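// A small host-side illustration of the key packing produced by _initIndicesForSort and consumed by
// the two cub::DeviceRadixSort passes below (helper names are hypothetical, for clarity only):
#if 0
inline uint64_t PackRowCol(uint32_t row, uint32_t col)
{
    return (static_cast<uint64_t>(row) << 32) | col; // high 32 bits: row, low 32 bits: column
}
inline uint32_t UnpackRow(uint64_t key) { return static_cast<uint32_t>(key >> 32); }
inline uint32_t UnpackCol(uint64_t key) { return static_cast<uint32_t>(key & 0xffffffffu); }
// Pass 1: SortPairsDescending on the values carries the packed (row, col) keys along.
// Pass 2: SortPairs on bits [0, 32) of the key regroups entries by column; because radix sort is
//         stable, the descending value order from pass 1 is preserved within each column.
#endif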
CUDA_LONG celt = static_cast(GetNumElements()); ElemType* inVal = us.Data(); ElemType* outVal1 = nullptr; ElemType* outVal2 = nullptr; uint64_t* inIdx = nullptr; uint64_t* outIdx = nullptr; // Determine temp buffer size needed for SortPairsDescending to sort values on the first pass. size_t cbtemp = 0; // If first param is nullptr then no actual work is done except writing result to cbtemp. CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(nullptr, cbtemp, inVal, outVal1, inIdx, outIdx, celt, 0, sizeof(ElemType) * 8, t_stream)); size_t ctemp1 = (cbtemp + sizeof(ElemType) - 1) / sizeof(ElemType); // Determine temp buffer size needed for SortPairs to sort indices on the second pass. cbtemp = 0; CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, cbtemp, outIdx, inIdx, outVal1, outVal2, celt, 0, 32, t_stream)); size_t ctemp2 = (cbtemp + sizeof(ElemType) - 1) / sizeof(ElemType); size_t ctemp = std::max(ctemp1, ctemp2); cbtemp = ctemp * sizeof(ElemType); // ElemType count needed to store indices, accounting for natural alignment for uint64_t type. size_t cidx = ((celt + 1) * sizeof(uint64_t) - 1 + sizeof(ElemType) - 1) / sizeof(ElemType); // Get temp workspace. auto workspace = GetOrCreateWorkspace(); // RequireSize to store: output values for the 1st and 2nd passes, input indices, output indices, and temp storage. workspace->RequireSize(m, 2 * n + (2 * cidx + ctemp + m - 1) / m); outVal1 = workspace->Data(); outVal2 = outVal1 + celt; inIdx = reinterpret_cast(outVal2 + celt); // Align indices pointer if needed. size_t cbAlign = reinterpret_cast(inIdx) % sizeof(uint64_t); if (cbAlign != 0) reinterpret_cast(inIdx) += sizeof(uint64_t) - cbAlign; outIdx = inIdx + celt; void* ptmp = outIdx + celt; assert(reinterpret_cast(reinterpret_cast(ptmp) + cbtemp) <= workspace->Data() + workspace->GetNumElements()); // Initialize indices. const int ThreadsPerBlock = 128; int cblock = (celt + ThreadsPerBlock - 1) / ThreadsPerBlock; _initIndicesForSort<<>>(inIdx, m, n); // Sort by values. CUDA_CALL(cub::DeviceRadixSort::SortPairsDescending(ptmp, cbtemp, inVal, outVal1, inIdx, outIdx, celt, 0, sizeof(ElemType) * 8, t_stream)); // Sort by column indices. outIdx contains indices after the first pass so it's used as an input. CUDA_CALL(cub::DeviceRadixSort::SortPairs(ptmp, cbtemp, outIdx, inIdx, outVal1, outVal2, celt, 0, 32, t_stream)); // Copy results. 
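// After the stable second pass each column's m entries are contiguous and still sorted by
// descending value, so the top-K items of column j are simply the first topK entries of its
// segment; _copyTopKResults below gathers those values together with the row indices recovered
// from the packed 64-bit keys.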
cblock = (topK * n + ThreadsPerBlock - 1) / ThreadsPerBlock; _copyTopKResults<<>>(inIdx, outVal2, maxIndexes.Data(), maxValues.Data(), m, n, topK); ReleaseWorkspace(std::move(workspace)); } template void GPUMatrix::VectorMin(GPUMatrix& minIndexes, GPUMatrix& minValues, const bool isColWise) const { if (IsEmpty()) LogicError("VectorMax: Matrix is empty."); const GPUMatrix& us = *this; const int m = (int) GetNumRows(); const int n = (int) GetNumCols(); assert(m > 0 && n > 0); // converting from size_t to int may cause overflow PrepareDevice(); SyncGuard syncGuard; if (isColWise) { minValues.RequireSize(1, n); minIndexes.RequireSize(1, n); int blocksPerGrid = n; // we'll have 1 block processing 1 column // note: kernel has hard-coded dimension of 512 _vectorMaxMinReduce512Threads << > >(us.Data(), minIndexes.Data(), minValues.Data(), m, n); /* int blocksPerGrid=(int)ceil(1.0*n/GridDim::maxThreadsPerBlock); _vectorMin<<>>(us.Data(),minIndexes.Data(),minValues.Data(),m,n,isColWise);*/ } else { minValues.RequireSize(m, 1); minIndexes.RequireSize(m, 1); int blocksPerGrid = (int) ceil(1.0 * m / GridDim::maxThreadsPerBlock); _vectorMin<<>>(us.Data(), minIndexes.Data(), minValues.Data(), m, n, isColWise); } } template GPUMatrix& GPUMatrix::AssignNumOfDiff(const GPUMatrix& a, const GPUMatrix& b, bool searchInCol) { if (a.GetNumCols() != b.GetNumCols()) InvalidArgument("AssignNumOfDiff: a and b must have the same number of columns."); if (!searchInCol && a.GetNumRows() != b.GetNumRows()) InvalidArgument("AssignNumOfDiff: a and b must have the same number of rows."); RequireSize(1, 1); // result should be one element PrepareDevice(); SyncGuard syncGuard; if (!searchInCol) { // int blocksPerGrid=(int)ceil(1.0*a.GetNumElements()/GridDim::maxThreadsPerBlock); // _assignNumOfDiff1024Threads<<>>(a.Data(), b.Data(), Data(), a.GetNumElements()); // note: kernel has hard-coded dimension of 1024 _assignNumOfDiff1024Threads << <1, 1024, 0, t_stream >> >(a.Data(), b.Data(), Data(), (CUDA_LONG)a.GetNumElements()); } else { const int blockSize = 1024; _assignNumOfDiffCol<<<1, blockSize, 0, t_stream>>>(a.Data(), b.Data(), Data(), static_cast(b.GetNumRows()), static_cast(a.GetNumCols())); } return *this; } #pragma endregion Member BLAS Functions #pragma region Other helper functions template void GPUMatrix::Print(const char* /*matrixName*/, size_t /*rowStart*/, size_t /*rowEnd*/, size_t /*colStart*/, size_t /*colEnd*/) const { NOT_IMPLEMENTED; } template void GPUMatrix::Print(const char* matrixName /*=nullptr*/) const { size_t elemCount = GetNumRows() * GetNumCols(); vector localCopy(elemCount); cudaMemcpy(localCopy.data(), Data(), elemCount * sizeof(ElemType), cudaMemcpyDeviceToHost); fprintf(stderr, "\n###### "); if (matrixName != nullptr) fprintf(stderr, "%s ", matrixName); fprintf(stderr, "(%lu, %lu) ######\n\n", (unsigned long)GetNumRows(), (unsigned long)GetNumCols()); if (IsEmpty()) { fprintf(stderr, "(empty)\n"); return; } // CNTK is using column-major storage for (size_t i = 0; i < GetNumRows(); i++) { for (size_t j = 0; j < GetNumCols(); j++) { fprintf(stderr, "%.10f\t", localCopy[i + j * GetNumRows()]); } fprintf(stderr, "\n"); } } //helpfer function used for convolution neural network template GPUMatrix& GPUMatrix::AssignPackedConvolutionInput(const GPUMatrix& inputSubBatch, const size_t inputWidth, const size_t inputHeight, const size_t inputChannels, const size_t outputWidth, const size_t outputHeight, const size_t outputChannels, const size_t kernelWidth, const size_t kernelHeight, const size_t 
horizontalSubsample, const size_t verticalSubsample, const bool zeroPadding) { assert(verticalSubsample <= kernelHeight && horizontalSubsample <= kernelWidth); size_t packedInputRows = kernelWidth * kernelHeight * inputChannels; size_t packedInputColsPerSample = outputWidth * outputHeight; size_t smallBatchSize = inputSubBatch.GetNumCols(); RequireSize(packedInputRows, packedInputColsPerSample * smallBatchSize); if (zeroPadding) SetValue((ElemType) 0); PrepareDevice(); int numThreadPerBlock = GridDim::maxThreadsPerBlock; #if 1 int blocksPerGrid = (smallBatchSize * inputWidth * inputHeight * inputChannels + numThreadPerBlock - 1) / numThreadPerBlock; #else dim3 blocksPerGrid((inputWidth * inputHeight * inputChannels + numThreadPerBlock - 1) / numThreadPerBlock, smallBatchSize); #endif SyncGuard syncGuard; _assignPackedConvolutionInput<<>>(Data(), inputSubBatch.Data(), smallBatchSize, inputWidth, inputHeight, inputChannels, outputWidth, outputHeight, outputChannels, kernelWidth, kernelHeight, horizontalSubsample, verticalSubsample, zeroPadding); return *this; } //helpfer function used for convolution neural network template GPUMatrix& GPUMatrix::UnpackConvolutionInput(GPUMatrix& inputSubBatch, const size_t inputWidth, const size_t inputHeight, const size_t inputChannels, const size_t outputWidth, const size_t outputHeight, const size_t outputChannels, const size_t kernelWidth, const size_t kernelHeight, const size_t horizontalSubsample, const size_t verticalSubsample, const bool zeroPadding) const { assert(verticalSubsample <= kernelHeight && horizontalSubsample <= kernelWidth); size_t smallBatchSize = inputSubBatch.GetNumCols(); PrepareDevice(); int numThreadPerBlock = GridDim::maxThreadsPerBlock; #if 1 int blocksPerGrid = (smallBatchSize * inputWidth * inputHeight * inputChannels + numThreadPerBlock - 1) / numThreadPerBlock; #else dim3 blocksPerGrid((inputWidth * inputHeight * inputChannels + numThreadPerBlock - 1) / numThreadPerBlock, smallBatchSize); #endif SyncGuard syncGuard; _unpackConvolutionInput<<>>(Data(), inputSubBatch.Data(), smallBatchSize, inputWidth, inputHeight, inputChannels, outputWidth, outputHeight, outputChannels, kernelWidth, kernelHeight, horizontalSubsample, verticalSubsample, zeroPadding); return inputSubBatch; } template GPUMatrix& GPUMatrix::AssignMaxPoolingResult(const GPUMatrix& inputBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample) { assert(verticalSubsample <= windowHeight && horizontalSubsample <= windowWidth); unsigned int batchSize = inputBatch.GetNumCols(); RequireSize(outputSizePerSample, batchSize); int numThreadPerBlock = GridDim::maxThreadsPerBlock; int blocksPerGrid = (batchSize * outputSizePerSample + numThreadPerBlock - 1) / numThreadPerBlock; PrepareDevice(); SyncGuard syncGuard; _assignMaxPoolingResult<<>>(Data(), inputBatch.Data(), batchSize, channels, inputWidth, inputHeight, inputSizePerSample, outputWidth, outputHeight, outputSizePerSample, windowWidth, windowHeight, horizontalSubsample, verticalSubsample); return *this; } template GPUMatrix& GPUMatrix::AddMaxPoolingGradient(const GPUMatrix& outputGradientBatch, const GPUMatrix& inputBatch, const GPUMatrix& outputBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, 
const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample) { assert(verticalSubsample <= windowHeight && horizontalSubsample <= windowWidth); unsigned int batchSize = outputGradientBatch.GetNumCols(); int numThreadPerBlock = GridDim::maxThreadsPerBlock; PrepareDevice(); SyncGuard syncGuard; int blocksPerGrid = (batchSize * inputSizePerSample + numThreadPerBlock - 1) / numThreadPerBlock; _addMaxPoolingGradient<<>>(Data(), outputGradientBatch.Data(), inputBatch.Data(), outputBatch.Data(), batchSize, channels, inputWidth, inputHeight, inputSizePerSample, outputWidth, outputHeight, outputSizePerSample, windowWidth, windowHeight, horizontalSubsample, verticalSubsample); return *this; } template GPUMatrix& GPUMatrix::AssignAveragePoolingResult(const GPUMatrix& inputBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample) { assert(verticalSubsample <= windowHeight && horizontalSubsample <= windowWidth); unsigned int batchSize = inputBatch.GetNumCols(); RequireSize(outputSizePerSample, batchSize); int numThreadPerBlock = GridDim::maxThreadsPerBlock; int blocksPerGrid = (batchSize * outputSizePerSample + numThreadPerBlock - 1) / numThreadPerBlock; PrepareDevice(); SyncGuard syncGuard; _assignAveragePoolingResult<<>>(Data(), inputBatch.Data(), batchSize, channels, inputWidth, inputHeight, inputSizePerSample, outputWidth, outputHeight, outputSizePerSample, windowWidth, windowHeight, horizontalSubsample, verticalSubsample); return *this; } template GPUMatrix& GPUMatrix::AddAveragePoolingGradient(const GPUMatrix& outputGradientBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample) { assert(verticalSubsample <= windowHeight && horizontalSubsample <= windowWidth); size_t batchSize = outputGradientBatch.GetNumCols(); int numThreadPerBlock = GridDim::maxThreadsPerBlock; PrepareDevice(); SyncGuard syncGuard; size_t blocksPerGrid = (batchSize * inputSizePerSample + numThreadPerBlock - 1) / numThreadPerBlock; _addAveragePoolingGradient<<>>(Data(), outputGradientBatch.Data(), (CUDA_LONG) batchSize, channels, inputWidth, inputHeight, inputSizePerSample, outputWidth, outputHeight, outputSizePerSample, windowWidth, windowHeight, horizontalSubsample, verticalSubsample); return *this; } #pragma endregion Other helper functions template void GPUMatrix::ConvolutionForward(const GPUMatrix& kernel, const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIwht, const GPUMatrix& mpRowRun, const GPUMatrix& runs, GPUMatrix& output) const { const int BlockSize = 128; auto gdim = dim3((output.GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kConvolutionForward<<>>((int)GetNumCols(), kernel.Data(), mpRowCol.Data(), mpRowIwht.Data(), mpRowRun.Data(), runs.Data(), Data(), (int)GetNumRows(), output.Data(), (int)output.GetNumRows()); } template void GPUMatrix::ConvolutionBackwardData(const GPUMatrix& kernel, 
const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIwht, const GPUMatrix& mpRowRun, const GPUMatrix& runs, GPUMatrix& grad) const { const int BlockSize = 128; auto gdim = dim3((GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kConvolutionBackwardData<<>>((int)GetNumCols(), kernel.Data(), mpRowCol.Data(), mpRowIwht.Data(), mpRowRun.Data(), runs.Data(), Data(), (int)GetNumRows(), grad.Data(), (int)grad.GetNumRows()); } template void GPUMatrix::ConvolutionBackwardKernel(const GPUMatrix& in, const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIwht, const GPUMatrix& mpRowRun, const GPUMatrix& runs, GPUMatrix& kernelGrad) const { const int BlockSize = 128; auto gdim = dim3((GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kConvolutionBackwardKernel<<>>((int)GetNumCols(), (int)in.GetNumRows(), (int)GetNumRows(), in.Data(), mpRowCol.Data(), mpRowIwht.Data(), mpRowRun.Data(), runs.Data(), Data(), kernelGrad.Data()); } template void GPUMatrix::MaxPoolingForward(const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIndices, const GPUMatrix& indices, GPUMatrix& output) const { const int BlockSize = 128; auto gdim = dim3((output.GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kMaxPoolingForward<<>>((int)GetNumCols(), mpRowCol.Data(), mpRowIndices.Data(), indices.Data(), Data(), (int)GetNumRows(), output.Data(), (int)output.GetNumRows()); } template void GPUMatrix::MaxPoolingBackward(const GPUMatrix& out, const GPUMatrix& in, const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIndices, const GPUMatrix& indices, GPUMatrix& grad) const { const int BlockSize = 128; auto gdim = dim3((GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kMaxPoolingBackward<<>>((int)GetNumCols(), out.Data(), in.Data(), mpRowCol.Data(), mpRowIndices.Data(), indices.Data(), Data(), (int)GetNumRows(), grad.Data(), (int)grad.GetNumRows()); } template void GPUMatrix::MaxROIPoolingForward(const size_t numRois, const size_t numImg, const size_t channels, const size_t width, const size_t height, const size_t pooledWidth, const size_t pooledHeight, const GPUMatrix& roiData, GPUMatrix& output, GPUMatrix& argmax, double spatialScale) const { PrepareDevice(); SyncGuard syncGuard; int count = numRois * numImg * channels * pooledHeight * pooledWidth; const int blockSize = GridDim::maxThreadsPerBlock; auto numThreads = dim3((int)floor((double)(count + blockSize - 1) / blockSize)); kMaxROIPoolingForward<<>>(count, numRois, numImg, channels, width, height, pooledWidth, pooledHeight, Data(), roiData.Data(), output.Data(), argmax.Data(), spatialScale); } template void GPUMatrix::MaxROIPoolingBackward(const size_t numRois, const size_t numImg, const size_t channels, const size_t width, const size_t height, const size_t pooledWidth, const size_t pooledHeight, const GPUMatrix& roiData, GPUMatrix& grad, GPUMatrix& argmax, double spatialScale) const { PrepareDevice(); SyncGuard syncGuard; int count = numImg * channels * height * width; const int blockSize = GridDim::maxThreadsPerBlock; auto numThreads = dim3((int)floor((double)(count + blockSize - 1) / blockSize)); kMaxROIPoolingBackward<<>>(count, numRois, numImg, channels, width, height, pooledWidth, pooledHeight, Data(), roiData.Data(), grad.Data(), argmax.Data(), spatialScale); } template void GPUMatrix::MaxUnpooling(const GPUMatrix& mpRowCol, 
const GPUMatrix& mpRowIndices, const GPUMatrix& indices, const GPUMatrix& poolInput, GPUMatrix& input) const { const int BlockSize = 128; auto gdim = dim3((GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kMaxUnpooling<<>>((int)GetNumCols(), mpRowCol.Data(), mpRowIndices.Data(), indices.Data(), Data(), poolInput.Data(), (int)GetNumRows(), input.Data(), (int)input.GetNumRows()); } template void GPUMatrix::AveragePoolingForward(const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIndices, const GPUMatrix& indices, GPUMatrix& output) const { const int BlockSize = 128; auto gdim = dim3((output.GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kAveragePoolingForward<<>>((int)GetNumCols(), mpRowCol.Data(), mpRowIndices.Data(), indices.Data(), Data(), (int)GetNumRows(), output.Data(), (int)output.GetNumRows()); } template void GPUMatrix::AveragePoolingBackward(const GPUMatrix& mpRowCol, const GPUMatrix& mpRowIndices, const GPUMatrix& indices, GPUMatrix& grad) const { const int BlockSize = 128; auto gdim = dim3((GetNumRows() + BlockSize - 1)/ BlockSize, std::min((int)GetNumCols(), 65535)); PrepareDevice(); SyncGuard syncGuard; kAveragePoolingBackward<<>>((int)GetNumCols(), mpRowCol.Data(), mpRowIndices.Data(), indices.Data(), Data(), (int)GetNumRows(), grad.Data(), (int)grad.GetNumRows()); } // returns savedMean/savedInvStdDev which are the actual values used to perform the normalization, except for blendFactor 1, in which case they are unused and set to empty template void GPUMatrix::BatchNormalizationForward(const GPUMatrix& scale, const GPUMatrix& bias, bool inferenceOnly, double expAvgFactor, double blendFactor, GPUMatrix& runMean, GPUMatrix& runVariance, GPUMatrix& out, double epsilon, GPUMatrix& savedMean, GPUMatrix& savedInvStdDev) const { assert((GetNumRows() % scale.GetNumRows()) == 0); bool spatial = GetNumRows() != scale.GetNumRows(); size_t vectorSize = GetNumRows(); size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1; size_t batchSize = GetNumCols(); bool normalizeRunningStats; assert(0 < vectorSize && vectorSize <= std::numeric_limits::max()); assert(0 < batchSize && batchSize <= std::numeric_limits::max()); SyncGuard syncGuard; if (inferenceOnly) { // Pick running statistics for normalizing. No update reuqired, and // saved statistics do not need to be produced. assert(expAvgFactor == 0 && blendFactor == 1); normalizeRunningStats = true; savedMean.RequireSize(0, 0); savedInvStdDev.RequireSize(0, 0); } else { // Compute data mean and inverse standard deviation (into savedMean and // savedInvStdDev), and update running mean and variance. // TODO expAvgFactor == 0 && blendFactor == 1 can be optimized (no need for update). normalizeRunningStats = false; savedMean.RequireSize(runMean); savedInvStdDev.RequireSize(runMean); if (spatial) { Call(spatialSize, vectorSize, spatialSize, batchSize, Data(), expAvgFactor, blendFactor, runMean.Data(), runVariance.Data(), epsilon, savedMean.Data(), savedInvStdDev.Data(), GetStream()); } else { Call(vectorSize, vectorSize, batchSize, Data(), expAvgFactor, blendFactor, runMean.Data(), runVariance.Data(), epsilon, savedMean.Data(), savedInvStdDev.Data(), GetStream()); } } Call(spatial ? 
spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial, normalizeRunningStats, epsilon, Data(), out.Data(), scale.Data(), bias.Data(), runMean.Data(), runVariance.Data(), savedMean.Data(), savedInvStdDev.Data(), GetStream()); } // savedMean/savedInvStdDev are the interpolated mean/inverse standard deviation as used in ForwardProp(). // For blendFactor=1, they are not used and can be uninitialized or empty. template void GPUMatrix::BatchNormalizationBackward(const GPUMatrix& in, GPUMatrix& grad, const GPUMatrix& scale, double blendFactor, const GPUMatrix& savedMean, const GPUMatrix& savedInvStdDev, GPUMatrix& scaleGrad, GPUMatrix& biasGrad) const { assert((GetNumRows() % scale.GetNumRows()) == 0); bool spatial = GetNumRows() != scale.GetNumRows(); size_t vectorSize = GetNumRows(); size_t spatialSize = spatial ? (GetNumRows() / scale.GetNumRows()) : 1; size_t batchSize = GetNumCols(); assert(0 < vectorSize && vectorSize <= std::numeric_limits::max()); assert(0 < batchSize && batchSize <= std::numeric_limits::max()); SyncGuard syncGuard; if (spatial) { Call(spatialSize, vectorSize, spatialSize, batchSize, in.Data(), Data(), scaleGrad.Data(), biasGrad.Data(), savedMean.Data(), savedInvStdDev.Data(), GetStream()); } else { Call(vectorSize, vectorSize, batchSize, in.Data(), Data(), scaleGrad.Data(), biasGrad.Data(), savedMean.Data(), savedInvStdDev.Data(), GetStream()); } ElemType mbStatsWeight = (ElemType)(1 - blendFactor); // weight for contribution from actual MB stats (0 if none, e.g. locked BN node) Call(spatial ? spatialSize : vectorSize, vectorSize, spatialSize, batchSize, spatial, in.Data(), Data(), grad.Data(), scale.Data(), mbStatsWeight, scaleGrad.Data(), biasGrad.Data(), savedMean.Data(), savedInvStdDev.Data(), GetStream()); } #pragma region RNN Functions template void GPUMatrix::RNNForward(const GPUMatrix &inputX, const GPUMatrix ¶mW, size_t xDim, size_t yDim, const vector& numSequencesForFrame, const RnnAttributes& rnnAttributes, GPUMatrix& reserve, GPUMatrix& workspace) { // numLayers, hiddenSize are input parameters if (!m_rnnExecutor) m_rnnExecutor = std::make_unique>(xDim, yDim, rnnAttributes); m_rnnExecutor->ForwardCore(paramW, inputX, *this, numSequencesForFrame, rnnAttributes, reserve, workspace); } template void GPUMatrix::RNNBackwardData(const GPUMatrix& outputDY, const GPUMatrix& paramW, GPUMatrix& outputDX, const RnnAttributes& rnnAttributes, GPUMatrix& reserve, GPUMatrix& workspace) { if (!m_rnnExecutor) LogicError("RNNBackwardData called, but RNNWrapper object is not yet initialized"); m_rnnExecutor->BackwardDataCore(*this, outputDY, paramW, outputDX, rnnAttributes, reserve, workspace); } template void GPUMatrix::RNNBackwardWeights(const GPUMatrix& inputX, const GPUMatrix& outputY, GPUMatrix& dw, const RnnAttributes& rnnAttributes, GPUMatrix& reserve, GPUMatrix& workspace) { if (!m_rnnExecutor) LogicError("RNNBackwardWeights called, but RNNWrapper object is not yet initialized"); m_rnnExecutor->BackwardWeightsCore(inputX, outputY, dw, rnnAttributes, reserve, workspace); } #pragma region Static BLAS Functions // float/double overloads of cublasSgemm()/cublasDgemm() static cublasStatus_t cublas_gemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* A, int lda, const float* B, int ldb, const float* beta, float* C, int ldc) { return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } static cublasStatus_t cublas_gemm(cublasHandle_t handle, 
cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const double* alpha, const double* A, int lda, const double* B, int ldb, const double* beta, double* C, int ldc) { return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } static cublasStatus_t cublas_axpy(cublasHandle_t handle, int n, const float* alpha, const float* x, int incx, float* y, int incy) { return cublasSaxpy(handle, n, alpha, x, incx, y, incy); } static cublasStatus_t cublas_axpy(cublasHandle_t handle, int n, const double* alpha, const double* x, int incx, double* y, int incy) { return cublasDaxpy(handle, n, alpha, x, incx, y, incy); } template void GPUMatrix::MultiplyAndWeightedAdd(ElemType alpha, const GPUMatrix& a, const bool transposeA, const GPUMatrix& b, const bool transposeB, ElemType beta, GPUMatrix& c) { a.PrepareDevice(); if ((a.GetComputeDeviceId() != b.GetComputeDeviceId()) || (b.GetComputeDeviceId() != c.GetComputeDeviceId())) // different GPUs InvalidArgument("All matrices must be on the same GPU"); cublasHandle_t cuHandle = GetCublasHandle(b.GetComputeDeviceId()); cublasOperation_t transA = transposeA ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t transB = transposeB ? CUBLAS_OP_T : CUBLAS_OP_N; int m = int(transposeA ? a.m_numCols : a.m_numRows); int n = int(transposeB ? b.m_numRows : b.m_numCols); int k = int(transposeA ? a.m_numRows : a.m_numCols); int l = int(transposeB ? b.m_numCols : b.m_numRows); if (beta == 0) c.RequireSize(m, n); else c.VerifySize(m, n); // Can't resize if beta != 0 if (!(m > 0 && k > 0 && l > 0 && n > 0)) RuntimeError("!(m>0 && k>0 && l>0 && n>0)"); // converting from size_t to int may cause overflow if (k != l) RuntimeError("matrix dim mismatch in MultiplyAndWeightedAdd"); CUBLAS_CALL(cublas_gemm(cuHandle, transA, transB, m, n, k, &alpha, a.Data(), (int) a.m_numRows, b.Data(), (int) b.m_numRows, &beta, c.Data(), (int) c.m_numRows)); } template void GPUMatrix::Multiply1x1AndWeightedAdd(ElemType alpha, const GPUMatrix& a, const GPUMatrix& b, ElemType beta, GPUMatrix& c) { a.PrepareDevice(); if ((a.GetComputeDeviceId() != b.GetComputeDeviceId()) || (b.GetComputeDeviceId() != c.GetComputeDeviceId())) // different GPUs InvalidArgument("All matrices must be on the same GPU"); CUDA_LONG N = (CUDA_LONG) c.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _multiply1x1AndWeightedAdd<<>>(alpha, a.Data(), b.Data(), beta, c.Data(), N); } template void GPUMatrix::MultiplyAndAdd(const GPUMatrix& a, const bool transposeA, const GPUMatrix& b, const bool transposeB, GPUMatrix& c) { return GPUMatrix::MultiplyAndWeightedAdd(1, a, transposeA, b, transposeB, 1, c); } template void GPUMatrix::Multiply(const GPUMatrix& a, const bool transposeA, const GPUMatrix& b, const bool transposeB, GPUMatrix& c) { return GPUMatrix::MultiplyAndWeightedAdd(1, a, transposeA, b, transposeB, 0, c); } template void GPUMatrix::Multiply(const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c) { return GPUMatrix::MultiplyAndWeightedAdd(1, a, false, b, false, 0, c); } template void GPUMatrix::ColumnwiseScaleAndWeightedAdd(ElemType alpha, const GPUMatrix& a, const GPUMatrix& v, ElemType beta, GPUMatrix& c) { if (v.GetNumRows() != 1 && v.GetNumCols() != 1) InvalidArgument("the argument v must be a vector"); // v is a vector if (beta == 0) c.RequireSize(a.GetNumRows(), a.GetNumCols()); else c.VerifySize(a.GetNumRows(), a.GetNumCols()); // Can't resize if beta != 0 int blocksPerGrid = (int)ceil(1.0 * c.GetNumElements() / 
GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _columnwiseScaleAndWeightedAdd<<>>(alpha, a.Data(), v.Data(), beta, c.Data(), a.GetNumRows(), a.GetNumCols()); } /// Matrix-scalar multiply with col-major matrices: c = alpha * a + c /// if a is a column vector, add to all columns of c /// if a is a row vector, add to all rows of c /// if a is a scalar, add to all elements of c /// Scalar /// Input matrix /// Resulting matrix, user is responsible for allocating this template /*static*/ void GPUMatrix::ScaleAndAdd(ElemType alpha, const GPUMatrix& a, GPUMatrix& c) { if (a.GetComputeDeviceId() != c.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } else { if (a.IsEmpty() && c.IsEmpty()) return; a.PrepareDevice(); if (a.IsEmpty() || c.IsEmpty()) LogicError("ScaleAndAdd: one of the input matrices is empty."); // if (a.GetNumRows() != 1 && a.GetNumCols() != 1) // a is not a col or row vector if (a.GetNumRows() == c.GetNumRows() && a.GetNumCols() == c.GetNumCols()) // dimensions match { const int m = (int) a.GetNumRows(); const int n = (int) a.GetNumCols(); const int len = m * n; const int incx = 1; const int incy = 1; assert(m > 0 && n > 0 && len > 0); // converting from size_t to int may cause overflow assert((int) c.GetNumRows() == m && (int) c.GetNumCols() == n); if ((int) c.GetNumRows() != m || (int) c.GetNumCols() != n) InvalidArgument("dimension of matrix c does not match dimension of matrix a."); cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId()); // TODO: Overload the call to cublas_axpy to remove these ugly if/else statements. if (sizeof(ElemType) == sizeof(float)) { CUBLAS_CALL(cublasSaxpy(cuHandle, len, reinterpret_cast(&alpha), reinterpret_cast(a.Data()), incx, reinterpret_cast(c.Data()), incy)); } else if (sizeof(ElemType) == sizeof(double)) { CUBLAS_CALL(cublasDaxpy(cuHandle, len, reinterpret_cast(&alpha), reinterpret_cast(a.Data()), incx, reinterpret_cast(c.Data()), incy)); } else { RuntimeError("Unsupported template argument in GPUMatrix"); } } else if (a.GetNumElements() == 1) { CUDA_LONG N = (CUDA_LONG) c.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); c.PrepareDevice(); SyncGuard syncGuard; _scaleAndAddScalar<<>>(c.Data(), N, alpha, a.Data(), c.Data()); } else if (a.GetNumCols() == 1) // col vector, add it to all columns { CUDA_LONG m = (CUDA_LONG) c.GetNumRows(); CUDA_LONG n = (CUDA_LONG) c.GetNumCols(); if (m != (CUDA_LONG) a.GetNumRows()) InvalidArgument("To add column vector, rows should match."); int blocksPerGrid = (int) (ceil(1.0 * m * n / GridDim::maxThreadsPerBlock)); SyncGuard syncGuard; #ifdef VALIDATION printf(">>>> CUDA compute device is %d\n", a.GetComputeDeviceId()); printf(">>>> a.Data()= %p, c.Data()= %p, alpha = %f, m = %ld, n = %ld\n", a.Data(), c.Data(), alpha, m, n); for (int i = 0; i < 2; i++) { ElemType buffer[10] = {-1.234f}; cudaError_t error = cudaMemcpy(buffer, !i ? 
a.Data(): c.Data(), sizeof(buffer), cudaMemcpyKind::cudaMemcpyDeviceToHost); if (error == cudaError::cudaSuccess) printf("buffer valid\n"); } #endif _matrixVectorColumnWiseAddWithThreadPerElem<<>>(a.Data(), c.Data(), c.Data(), alpha, m, n); } else if (a.GetNumRows() == 1) // row vector, add it to all rows { cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId()); int m = (int) c.GetNumRows(); int n = (int) c.GetNumCols(); assert(n == (int) a.GetNumCols()); if (n != (int) a.GetNumCols()) InvalidArgument("To add row vector, cols should match."); // TODO: Overload the call to cublas_axpy to remove these ugly if/else statements. if (sizeof(ElemType) == sizeof(double)) { foreach_row (i, c) { CUBLAS_CALL(cublasDaxpy(cuHandle, n, reinterpret_cast(&alpha), reinterpret_cast(a.Data()), 1, reinterpret_cast(c.Data()+ i), m)); } } else { foreach_row (i, c) { CUBLAS_CALL(cublasSaxpy(cuHandle, n, reinterpret_cast(&alpha), reinterpret_cast(a.Data()), 1, reinterpret_cast(c.Data()+ i), m)); } } } else InvalidArgument("dimension of matrix c does not match dimension of matrix a."); } } /// Matrix-scalar multiply with col-major matrices: c = alpha * a + b /// if a is a column vector, add to all columns of b /// if a is a row vector, add to all rows of b /// if a is a scalar, add to all elements of b /// Scalar /// Input matrix /// Input matrix /// Resulting matrix, user is responsible for allocating this template /*static*/ void GPUMatrix::ScaleAndAdd(ElemType alpha, const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c) { if (a.GetComputeDeviceId() != c.GetComputeDeviceId() || a.GetComputeDeviceId() != b.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } else { if (a.IsEmpty() && b.IsEmpty()) return; a.PrepareDevice(); if (a.IsEmpty() || b.IsEmpty()) LogicError("ScaleAndAdd: One of the input matrices is empty."); c.RequireSize(b.GetNumRows(), b.GetNumCols()); // if (a.GetNumRows() != 1 && a.GetNumCols() != 1) // a is not a col or row vector if (a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()) // dimensions match { /* const int m = (int)a.GetNumRows(); const int n = (int)a.GetNumCols(); const int len = m * n; const int incx = 1; const int incy = 1; assert (m>0 && n>0 && len>0); // converting from size_t to int may cause overflow */ CUDA_LONG N = (CUDA_LONG) c.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); c.PrepareDevice(); SyncGuard syncGuard; _matrixMatrixAddOnCuda<<>>(alpha, a.Data(), b.Data(), c.Data(), N); } else if (a.GetNumElements() == 1) { CUDA_LONG N = (CUDA_LONG) c.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); c.PrepareDevice(); SyncGuard syncGuard; _scaleAndAddScalar<<>>(c.Data(), N, alpha, a.Data(), b.Data()); } else if (a.GetNumCols() == 1) // col vector, add it to all columns { CUDA_LONG m = (CUDA_LONG) c.GetNumRows(); CUDA_LONG n = (CUDA_LONG) c.GetNumCols(); if (m != (CUDA_LONG) a.GetNumRows()) InvalidArgument("To add column vector, rows should match."); int blocksPerGrid = (int) (ceil(1.0 * m * n / GridDim::maxThreadsPerBlock)); SyncGuard syncGuard; _matrixVectorColumnWiseAddWithThreadPerElem<<>>(a.Data(), b.Data(), c.Data(), alpha, m, n); } else if (a.GetNumRows() == 1) // row vector, add it to all rows { CUDA_LONG m = (CUDA_LONG) c.GetNumRows(); CUDA_LONG n = (CUDA_LONG) c.GetNumCols(); if (m != (CUDA_LONG) a.GetNumRows()) InvalidArgument("To add column vector, rows should match."); int blocksPerGrid = (int) (ceil(1.0 * m * n / 
GridDim::maxThreadsPerBlock)); SyncGuard syncGuard; _matrixVectorRowWiseAddWithThreadPerElem<<>>(a.Data(), b.Data(), c.Data(), alpha, m, n); } else InvalidArgument("Dimension of matrix c does not match dimension of matrix a."); } } /// c += alpha * (a-b) /// if a, b, c must have same dim /// Scalar /// Input matrix /// Input matrix /// Resulting matrix, user is responsible for allocating this template void GPUMatrix::AddScaledDifference(const ElemType alpha, const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c) { if (a.GetComputeDeviceId() != c.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } else { a.PrepareDevice(); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumRows() == c.GetNumRows() && a.GetNumCols() == b.GetNumCols() && a.GetNumCols() == c.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumRows() == c.GetNumRows() && a.GetNumCols() == b.GetNumCols() && a.GetNumCols() == c.GetNumCols())) { InvalidArgument("AddScaledDifference: a, b, and c must have same dimension."); } if (a.IsEmpty()) LogicError("AddScaledDifference: Input matrix a is empty."); CUDA_LONG n = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _addScaledDifference<<>>(alpha, a.Data(), b.Data(), c.Data(), n); } } /// c = alpha * (a-b) /// if a, b, c must have same dim /// Scalar /// Input matrix /// Input matrix /// Resulting matrix, user is responsible for allocating this template void GPUMatrix::AssignScaledDifference(const ElemType alpha, const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c) { if (a.GetComputeDeviceId() != c.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } else { a.PrepareDevice(); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols())) InvalidArgument("AssignScaledDifference: a, b must have same dimension."); if (a.IsEmpty()) LogicError("AssignScaledDifference: Input matrix a is empty."); if (&c != &a && &c != &b) c.RequireSize(a.GetNumRows(), a.GetNumCols()); CUDA_LONG n = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _assignScaledDifference<<>>(alpha, a.Data(), b.Data(), c.Data(), n); } } /// c += alpha * (a-b) /// if a, b, c must have same dim /// 1X1 matrix /// Input matrix /// Input matrix /// Resulting matrix, user is responsible for allocating this template void GPUMatrix::AddScaledDifference(const GPUMatrix& alpha, const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c) { assert(alpha.GetNumElements() == 1); if (!(alpha.GetNumElements() == 1)) InvalidArgument("AddScaledDifference: alpha must be a 1X1 matrix."); if (a.GetComputeDeviceId() != c.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } else { a.PrepareDevice(); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumRows() == c.GetNumRows() && a.GetNumCols() == b.GetNumCols() && a.GetNumCols() == c.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumRows() == c.GetNumRows() && a.GetNumCols() == b.GetNumCols() && a.GetNumCols() == c.GetNumCols())) { InvalidArgument("AddScaledDifference: a, b, and c must have same dimension."); } if (a.IsEmpty()) LogicError("AddScaledDifference: Input matrix a is empty."); CUDA_LONG n = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; 
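        // Note: unlike the host-scalar overload above, alpha is a 1x1 device-resident matrix here,
        // so the kernel dereferences alpha.Data() on the GPU and no device-to-host copy of the
        // scale factor is needed before launching.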
        _addScaledDifference<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(alpha.Data(), a.Data(), b.Data(), c.Data(), n);
    }
}

/// <summary>c = alpha * (a-b)</summary>
/// a, b, and c must have the same dimensions
/// <param name="alpha">1X1 matrix holding the scale factor (device-resident)</param>
/// <param name="a">Input matrix</param>
/// <param name="b">Input matrix</param>
/// <param name="c">Resulting matrix, user is responsible for allocating this</param>
template <class ElemType>
void GPUMatrix<ElemType>::AssignScaledDifference(const GPUMatrix<ElemType>& alpha, const GPUMatrix<ElemType>& a, const GPUMatrix<ElemType>& b, GPUMatrix<ElemType>& c)
{
    assert(alpha.GetNumElements() == 1);
    if (!(alpha.GetNumElements() == 1))
        InvalidArgument("AssignScaledDifference: alpha must be a 1X1 matrix.");

    if (a.GetComputeDeviceId() != c.GetComputeDeviceId())
    {
        InvalidArgument("All matrices must be on the same GPU");
    }
    else
    {
        a.PrepareDevice();

        assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols());
        if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()))
        {
            InvalidArgument("AssignScaledDifference: a, b must have same dimension.");
        }

        if (a.IsEmpty())
            LogicError("AssignScaledDifference: Input matrix a is empty.");

        c.RequireSize(a.GetNumRows(), a.GetNumCols());

        CUDA_LONG n = (CUDA_LONG) a.GetNumElements();
        int blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock);
        SyncGuard syncGuard;
        _assignScaledDifference<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(alpha.Data(), a.Data(), b.Data(), c.Data(), n);
    }
}

// c[ci,cj] = beta * c[ci,cj] + a[ai,aj]
template <class ElemType>
void GPUMatrix<ElemType>::AddElementToElement(ElemType beta, const GPUMatrix<ElemType>& a, const size_t ai, const size_t aj, GPUMatrix<ElemType>& c, const size_t ci, const size_t cj)
{
    if (ai >= a.GetNumRows() || aj >= a.GetNumCols() ||
        ci >= c.GetNumRows() || cj >= c.GetNumCols())
        InvalidArgument("AddElementToElement: Index out of range.");

    a.PrepareDevice();
    SyncGuard syncGuard;
    _addElementToElement<<<1, 1, 0, t_stream>>>(beta, a.Data(), (CUDA_LONG) a.LocateElement(ai, aj), c.Data(), (CUDA_LONG) c.LocateElement(ci, cj));
}

template <class ElemType>
/*static*/ void GPUMatrix<ElemType>::Scale(ElemType alpha, GPUMatrix<ElemType>& a)
{
    if (alpha == 0) // if 0 then do not access the value, so that we can use this to multiply uninitialized matrices with beta=0
    {
        CUDA_CALL(cudaMemset(a.Data(), 0, a.m_numRows * a.m_numCols * sizeof(ElemType)));
        return;
    }

    cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId());
    // TODO: Add a cublas_scal overload (like cublas_axpy above) to remove these ugly if/else statements.
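    // A minimal sketch of such an overload (hypothetical, not yet defined in this file), mirroring the
    // cublas_gemm/cublas_axpy wrappers near the top of this file; with it, the if/else below would
    // collapse into a single CUBLAS_CALL(cublas_scal(cuHandle, n, &alpha, a.Data(), 1)):
    //
    //     static cublasStatus_t cublas_scal(cublasHandle_t handle, int n, const float* alpha, float* x, int incx)
    //     {
    //         return cublasSscal(handle, n, alpha, x, incx);
    //     }
    //     static cublasStatus_t cublas_scal(cublasHandle_t handle, int n, const double* alpha, double* x, int incx)
    //     {
    //         return cublasDscal(handle, n, alpha, x, incx);
    //     }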
if (sizeof(ElemType) == sizeof(float)) { float alph = (float) alpha; CUBLAS_CALL(cublasSscal(cuHandle, int(a.m_numRows * a.m_numCols), &alph, (float*) a.Data(), 1)); } else if (sizeof(ElemType) == sizeof(double)) { double alph = alpha; CUBLAS_CALL(cublasDscal(cuHandle, int(a.m_numRows * a.m_numCols), &alph, (double*) a.Data(), 1)); } else { RuntimeError("Unsupported template argument in GPUMatrix"); } } template /*static*/ void GPUMatrix::Scale(GPUMatrix& alpha, GPUMatrix& a) { if (alpha.GetNumElements() != 1) { RuntimeError("Matrix alpha must be 1x1"); } cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId()); cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_DEVICE); if (sizeof(ElemType) == sizeof(float)) { CUBLAS_CALL(cublasSscal(cuHandle, int(a.m_numRows * a.m_numCols), (float*) alpha.Data(), (float*) a.Data(), 1)); } else if (sizeof(ElemType) == sizeof(double)) { CUBLAS_CALL(cublasDscal(cuHandle, int(a.m_numRows * a.m_numCols), (double*) alpha.Data(), (double*) a.Data(), 1)); } else { cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_HOST); RuntimeError("Unsupported template argument in GPUMatrix"); } cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_HOST); } template // c = alpha * a /*static*/ void GPUMatrix::Scale(ElemType alpha, const GPUMatrix& a, GPUMatrix& c) { c = a; Scale(alpha, c); } template void GPUMatrix::InnerProduct(const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c, const bool isColWise) { if (a.GetComputeDeviceId() != b.GetComputeDeviceId() || b.GetComputeDeviceId() != c.GetComputeDeviceId()) // different GPUs InvalidArgument("All matrices must be on the same GPU"); if (a.IsEmpty() || b.IsEmpty()) LogicError("Scale: one of the input matrices is empty."); const int m = (int) a.GetNumRows(); const int n = (int) a.GetNumCols(); const int k = (int) b.GetNumRows(); const int l = (int) b.GetNumCols(); assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow assert(m == k && n == l); // converting from size_t to int may cause overflow if (m != k || n != l) InvalidArgument("Matrices a and b should have same dimension."); if (isColWise) c.RequireSize(1, n); else c.RequireSize(m, 1); if ((isColWise && m == 1) || (!isColWise && n == 1)) // in this case it's equivalent to element-wise product { c.AssignElementProductOf(a, b); } else { c.PrepareDevice(); int blocksPerGrid = 0; if (isColWise) // col-wise { c.RequireSize(1, n); blocksPerGrid = (int) ceil(1.0 * n / GridDim::maxThreadsPerBlock); } else { c.RequireSize(m, 1); blocksPerGrid = (int) ceil(1.0 * m / GridDim::maxThreadsPerBlock); } SyncGuard syncGuard; _innerProduct<<>>(c.Data(), a.Data(), b.Data(), m, n, isColWise); } } template ElemType GPUMatrix::InnerProductOfMatrices(const GPUMatrix& a, const GPUMatrix& b) { if (a.IsEmpty() || b.IsEmpty()) LogicError("InnerProductOfMatrices: one of the input matrices is empty."); const int m = (int) a.GetNumRows(); const int n = (int) a.GetNumCols(); const int k = (int) b.GetNumRows(); const int l = (int) b.GetNumCols(); assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow assert(m == k && n == l); // converting from size_t to int may cause overflow if (m != k || n != l) InvalidArgument("InnerProductOfMatrices: Matrices a and b should have same dimension."); cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId()); if (sizeof(ElemType) == sizeof(double)) { double tmp = 0; CUBLAS_CALL(cublasDdot(cuHandle, m * n, reinterpret_cast(a.Data()), 1, reinterpret_cast(b.Data()), 1, &tmp)); 
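        // cublasDdot runs in the default CUBLAS_POINTER_MODE_HOST here, so the call blocks until the
        // scalar result has been written into the host variable tmp; no explicit synchronization is needed.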
return ElemType(tmp); // return (ElemType)ddot((int)a.GetNumElements(), reinterpret_cast (a.Data()), 1, reinterpret_cast (b.Data()), 1); } else { float tmp = 0; CUBLAS_CALL(cublasSdot(cuHandle, m * n, reinterpret_cast(a.Data()), 1, reinterpret_cast(b.Data()), 1, &tmp)); return tmp; // return (ElemType)sdot((int)a.GetNumElements(), reinterpret_cast (a.Data()), 1, reinterpret_cast (b.Data()), 1); } } template GPUMatrix& GPUMatrix::AssignInnerProductOfMatrices(const GPUMatrix& a, const GPUMatrix& b) { if (a.IsEmpty() || b.IsEmpty()) LogicError("InnerProductOfMatrices: one of the input matrices is empty."); RequireSize(1, 1); const int m = (int) a.GetNumRows(); const int n = (int) a.GetNumCols(); const int k = (int) b.GetNumRows(); const int l = (int) b.GetNumCols(); assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow assert(m == k && n == l); // converting from size_t to int may cause overflow if (m != k || n != l) InvalidArgument("InnerProductOfMatrices: Matrices a and b should have same dimension."); cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId()); cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_DEVICE); if (sizeof(ElemType) == sizeof(double)) { CUBLAS_CALL(cublasDdot(cuHandle, m * n, reinterpret_cast(a.Data()), 1, reinterpret_cast(b.Data()), 1, reinterpret_cast(Data()))); } else { CUBLAS_CALL(cublasSdot(cuHandle, m * n, reinterpret_cast(a.Data()), 1, reinterpret_cast(b.Data()), 1, reinterpret_cast(Data()))); } cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_HOST); return *this; } template void GPUMatrix::ElementWisePower(ElemType alpha, const GPUMatrix& a, GPUMatrix& c) { if (a.GetComputeDeviceId() != c.GetComputeDeviceId()) { InvalidArgument("All matrices must be on the same GPU"); } else { if (a.IsEmpty()) LogicError("ElementWisePower: The input matrix a is empty."); c.RequireSize(a.GetNumRows(), a.GetNumCols()); a.PrepareDevice(); SyncGuard syncGuard; CUDA_LONG N = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); _elementWisePowerOnCuda<<>>(alpha, a.Data(), c.Data(), N); } } template bool GPUMatrix::AreEqual(const GPUMatrix& a, const GPUMatrix& b, const ElemType threshold /*= 1e-8*/) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AreEqual: one of the input matrices is empty."); if (a.GetNumRows() != b.GetNumRows() || a.GetNumCols() != b.GetNumCols()) return false; bool bResult = false; long* res = new long[1]; res[0] = 1; long* d_res = TracingGPUMemoryAllocator::Allocate(a.GetComputeDeviceId(), 1); CUDA_CALL(cudaMemcpy(d_res, res, sizeof(long) * 1, cudaMemcpyHostToDevice)); CUDA_LONG N = (CUDA_LONG) a.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); _areEqual<<>>(a.Data(), b.Data(), N, threshold, d_res); CUDA_CALL(cudaMemcpy(res, d_res, sizeof(long) * 1, cudaMemcpyDeviceToHost)); TracingGPUMemoryAllocator::Free(a.GetComputeDeviceId(), d_res); if (res[0] != 0) bResult = true; delete[] res; return bResult; } // see Matrix::TensorShuffleScaleAndAdd() for comments template void GPUMatrix::TensorShuffleScaleAndAdd(ElemType keepWeight, const GPUMatrix& a, size_t D, size_t S, size_t M, size_t K, size_t T, ElemType scaleFactor, const GPUMatrix& b, GPUMatrix& c) { CUDA_LONG N = (CUDA_LONG) c.GetNumElements(); assert(N == (CUDA_LONG) a.GetNumElements() && N == (CUDA_LONG) b.GetNumElements()); assert(a.GetComputeDeviceId() == c.GetComputeDeviceId() && b.GetComputeDeviceId() == c.GetComputeDeviceId()); a.PrepareDevice(); SyncGuard syncGuard; 
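    // One thread per output element: blocksPerGrid below is sized so that
    // blocksPerGrid * GridDim::maxThreadsPerBlock >= N. For example, with maxThreadsPerBlock == 1024
    // (the usual value), N = 1,000,000 elements yields ceil(1e6 / 1024) = 977 blocks.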
    int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock);
    _tensorShuffleScaleAndAdd<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(keepWeight, a.Data(), D, S, M, K, T, scaleFactor, b.Data(), c.Data());
}

template <class ElemType>
bool GPUMatrix<ElemType>::HasElement(const GPUMatrix<ElemType>& a, const ElemType v)
{
    if (a.IsEmpty())
        LogicError("HasElement: the input matrix is empty.");

    ElemType* res = new ElemType[2];
    res[0] = v; // the value to search for
    res[1] = 0; // result flag, set by the kernel if v is found
    ElemType* d_res = TracingGPUMemoryAllocator<ElemType>::Allocate(a.GetComputeDeviceId(), 2);
    CUDA_CALL(cudaMemcpy(d_res, res, sizeof(ElemType) * 2, cudaMemcpyHostToDevice));
    CUDA_LONG N = (CUDA_LONG) a.GetNumElements();
    int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock);
    _hasElement<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(a.Data(), N, d_res);
    CUDA_CALL(cudaMemcpy(res, d_res, sizeof(ElemType) * 2, cudaMemcpyDeviceToHost));
    TracingGPUMemoryAllocator<ElemType>::Free(a.GetComputeDeviceId(), d_res);
    bool bResult = (res[1] != 0);
    delete[] res;
    return bResult;
}

template <class ElemType>
void GPUMatrix<ElemType>::CreateCurandObject(unsigned long seed, const char* caller)
{
    assert(caller != nullptr);

    if (s_curandGenerator == NULL)
    {
        unsigned long long cudaSeed = (seed == USE_TIME_BASED_SEED) ? time(NULL) : seed;
        if (GetMathLibTraceLevel() > 0)
        {
            fprintf(stderr, "%s (GPU): creating curand object with seed %llu, sizeof(ElemType)==%lu\n",
                    caller, cudaSeed, (unsigned long) sizeof(ElemType));
        }
        s_curandGenerator = new curandGenerator_t;
        // Create pseudo-random number generator
        CURAND_CALL(curandCreateGenerator(&(((curandGenerator_t*) s_curandGenerator)[0]), CURAND_RNG_PSEUDO_XORWOW));
        CURAND_CALL(curandSetPseudoRandomGeneratorSeed(((curandGenerator_t*) s_curandGenerator)[0], cudaSeed));
        CURAND_CALL(curandSetGeneratorOrdering(((curandGenerator_t*) s_curandGenerator)[0], CURAND_ORDERING_PSEUDO_SEEDED));
    }
}

template <class ElemType>
void GPUMatrix<ElemType>::ResetCurandObject(unsigned long seed, const char* caller)
{
    assert(caller != nullptr);

    if (s_curandGenerator && (seed != USE_TIME_BASED_SEED))
    {
        // Note: this might be slow.
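        // Re-seeding re-initializes the XORWOW generator state (hence the cost noted above); the offset
        // is also rewound to 0 below so that repeated runs with the same seed reproduce the same sequence.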
CURAND_CALL(curandSetPseudoRandomGeneratorSeed(((curandGenerator_t*) s_curandGenerator)[0], seed)); CURAND_CALL(curandSetGeneratorOffset(((curandGenerator_t*) s_curandGenerator)[0], 0)); } else { CreateCurandObject(seed, caller); } } template GPUMatrix GPUMatrix::Ones(const size_t rows, const size_t cols, int deviceId) { GPUMatrix c(rows, cols, deviceId); // will initialize to 0 c.SetValue(1); return c; } template GPUMatrix GPUMatrix::Zeros(const size_t rows, const size_t cols, int deviceId) { GPUMatrix c(rows, cols, deviceId); // will initialize to 0 // c.SetValue(0); return c; } template GPUMatrix GPUMatrix::Eye(const size_t rows, int deviceId) { GPUMatrix c(rows, rows, deviceId); // will initialize to 0 c.SetDiagonalValue(1); return c; } template GPUMatrix GPUMatrix::RandomUniform(const size_t rows, const size_t cols, int deviceId, const ElemType low, const ElemType high, unsigned long seed) { GPUMatrix c(rows, cols, deviceId); // will initialize to 0 c.SetUniformRandomValue(low, high, seed); return c; } template GPUMatrix GPUMatrix::RandomGaussian(const size_t rows, const size_t cols, int deviceId, const ElemType mean, const ElemType sigma, unsigned long seed) { GPUMatrix c(rows, cols, deviceId); // will initialize to 0 c.SetGaussianRandomValue(mean, sigma, seed); return c; } template ElemType GPUMatrix::GetLearnRateForBlock_Helper(const GPUMatrix& Gradients, const GPUMatrix& SmoothedGradients) { ElemType* d_res = TracingGPUMemoryAllocator::Allocate(Gradients.GetComputeDeviceId(), 1); // Compute inner product of matrices and keep it on device const int m = (int) Gradients.GetNumRows(); const int n = (int) Gradients.GetNumCols(); const int k = (int) SmoothedGradients.GetNumRows(); const int l = (int) SmoothedGradients.GetNumCols(); assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow assert(m == k && n == l); // converting from size_t to int may cause overflow if (m != k || n != l) InvalidArgument("InnerProductOfMatrices: Matrices a and b should have same dimension."); if (sizeof(ElemType) == sizeof(double)) { cublasHandle_t cuHandle = GetCublasHandle(Gradients.GetComputeDeviceId()); cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_DEVICE); CUBLAS_CALL(cublasDdot(cuHandle, m * n, reinterpret_cast(Gradients.Data()), 1, reinterpret_cast(SmoothedGradients.Data()), 1, reinterpret_cast(d_res))); cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_HOST); } else { cublasHandle_t cuHandle = GetCublasHandle(Gradients.GetComputeDeviceId()); cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_DEVICE); CUBLAS_CALL(cublasSdot(cuHandle, m * n, reinterpret_cast(Gradients.Data()), 1, reinterpret_cast(SmoothedGradients.Data()), 1, reinterpret_cast(d_res))); cublasSetPointerMode(cuHandle, CUBLAS_POINTER_MODE_HOST); } // d_res[0] should now contain inner product of matrices // Compute squared Frobenius norms (squared sums of elements) // note: kernel has hard-coded dimension of 512 _lrHelper512Threads << <1, 512, 0, t_stream >> >(Gradients.Data(), SmoothedGradients.Data(), (CUDA_LONG)Gradients.GetNumElements(), d_res); ElemType res; CUDA_CALL(cudaMemcpy(&res, d_res, sizeof(ElemType), cudaMemcpyDeviceToHost)); TracingGPUMemoryAllocator::Free(Gradients.GetComputeDeviceId(), d_res); return res; } // The inputs are two row vectors [a1 a2 a3 a4] [b1 b2 b3 b4] // The outputs are one matrix of size (nt+1)*4 // The first row is just element multiplication // The rest rows will be with shift template GPUMatrix& GPUMatrix::AssignElementProductOfWithShiftNeg(const GPUMatrix& 
a, const GPUMatrix& b, const size_t shift, const size_t nt) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AssignElementProductOf: Matrix is empty."); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols())) InvalidArgument("The input matrix dimensions do not match."); if (!(a.GetNumRows() == 1)) InvalidArgument("The input matrix must be a row vector."); RequireSize(nt + 1, a.GetNumCols()); int BS = a.GetNumCols(); // the output matrix is of size (nt+1, BS) dim3 thread_tail(DEFAULT_THREAD_PER_DIM, DEFAULT_THREAD_PER_DIM); dim3 block_tail((nt + 1 + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM, (BS + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM); a.PrepareDevice(); SyncGuard syncGuard; _assignElementProductOfWithShiftNeg<<>>(Data(), a.Data(), b.Data(), shift, nt + 1, BS); // _assignElementProductOf << > >(Data(), a.Data(), b.Data(), nt); return *this; } template GPUMatrix& GPUMatrix::AssignOneHot(const GPUMatrix& a, vector& shape, size_t axis) { if (a.IsEmpty()) LogicError("AssignOneHot: Matrix a is empty."); if (axis >= shape.size()) LogicError("AssignOneHot: axis is not correct"); size_t item_size = 1; for (size_t i = 0; i < shape.size() && i < axis; i++) item_size *= shape[i]; size_t num_class = shape[axis]; auto nCols = a.GetNumCols(); auto nRows = num_class * a.GetNumRows(); this->RequireSize(nRows, nCols); this->PrepareDevice(); CUDA_CALL(cudaMemset(Data(), 0, nCols * nRows * sizeof(ElemType))); CUDA_LONG N = (CUDA_LONG)a.GetNumElements(); int blocksPerGrid = (int)ceil(((double)N) / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _assignOneHot << > > (a.Data(), Data(), num_class, item_size, N); return *this; } template GPUMatrix& GPUMatrix::GatherFromTarget(const GPUMatrix& indices, const GPUMatrix& target, size_t row_elements) { if (indices.IsEmpty() || target.IsEmpty()) LogicError("GatherFromTarget: input matrix is empty."); if (row_elements == 0) LogicError("GatherFromTarget: target matrix at least need 1 dim."); auto nCols = indices.GetNumCols(); auto nRows = indices.GetNumRows() * row_elements; this->RequireSize(nRows, nCols); this->PrepareDevice(); ElemType* indicesBufPtr = indices.Data(); ElemType* targetBufPtr = target.Data(); ElemType* buffer = Data(); size_t num_indices = indices.GetNumElements(); CUDA_LONG N = (CUDA_LONG)num_indices * row_elements; int blocksPerGrid = (int)ceil(((double)N) / GridDim::maxThreadsPerBlock); _gatherFromTarget <<>> (indicesBufPtr, targetBufPtr, buffer, row_elements, num_indices, N); return *this; } template GPUMatrix& GPUMatrix::ScatterToIndices(const GPUMatrix& values, const GPUMatrix& indices, size_t row_elements) { if (indices.IsEmpty() || values.IsEmpty()) LogicError("ScatterToIndices: input matrix is empty."); ElemType* indicesBufPtr = indices.Data(); ElemType* valueBufPtr = values.Data(); ElemType* buffer = Data(); size_t num_indices = indices.GetNumElements(); CUDA_LONG N = (CUDA_LONG)num_indices * row_elements; int blocksPerGrid = (int)ceil(((double)N) / GridDim::maxThreadsPerBlock); _scatterToIndices << > > (indicesBufPtr, valueBufPtr, buffer, row_elements, num_indices, N); return *this; } template void GPUMatrix::InnerProductWithShiftNeg(const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c, const size_t shift, const size_t nt) { if (a.GetComputeDeviceId() != b.GetComputeDeviceId() || b.GetComputeDeviceId() != c.GetComputeDeviceId()) // different GPUs InvalidArgument("All matrices must be on the same GPU"); if (a.IsEmpty() 
|| b.IsEmpty()) LogicError("Scale: one of the input matrices is empty."); const int m = (int) a.GetNumRows(); const int n = (int) a.GetNumCols(); const int k = (int) b.GetNumRows(); const int l = (int) b.GetNumCols(); assert(m > 0 && n > 0 && k > 0 && l > 0); // converting from size_t to int may cause overflow assert(m == k && n == l); // converting from size_t to int may cause overflow if (m != k || n != l) InvalidArgument("Matrices a and b should have same dimension."); c.RequireSize(nt + 1, n); if (true) { c.PrepareDevice(); dim3 thread_tail(DEFAULT_THREAD_PER_DIM, DEFAULT_THREAD_PER_DIM); dim3 block_tail((nt + 1 + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM, (n + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM); SyncGuard syncGuard; _innerProductWithShiftNeg<<>>(c.Data(), a.Data(), b.Data(), m, n, shift, nt + 1); } } template GPUMatrix& GPUMatrix::GetARowByIndex(const GPUMatrix& a, const size_t m) { if (a.IsEmpty()) LogicError("GetARowByIndex: Matrix is empty."); RequireSize(1, a.GetNumCols()); int n = a.GetNumRows(); int P = a.GetNumCols(); if (m >= n) LogicError("GetARowByIndex: m is out of range."); int blocksPerGrid = (int) ceil(((double) P) / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _getARowByIndex<<>>(Data(), a.Data(), n, P, m); // _assignElementProductOf << > >(Data(), a.Data(), b.Data(), nt); return *this; } // Calculate CTC score // prob (input): the posterior output from the network // alpha, beta (output): alpha and beta for forward-backward calculation. // phoneSeq (input): phone ID sequence for each utterance in this minibatch, each col is one utterance // phoneBoundary (input): phone boundary (frame index) of each phone for each utterance in this minibatch, each col is one utterance // totalScore (output): total CTC score // uttToChanInd (input): map from utterance ID to minibatch channel ID. We need this because each channel may contain more than one utterance. // uttBeginFrame(input): the position of the first frame of each utterance in the minibatch channel. We need this because each channel may contain more than one utterance. // uttFrameNum (input): the frame number of each utterance. The size of this vector = the number of all utterances in this minibatch // uttPhoneNum (input): the phone number of each utterance. The size of this vector = the number of all utterances in this minibatch // numParallelSequences (input): channel number in this minibatch // maxFrameNum (input): the maximum channel frame number // delayConstraint -- label output delay constraint introduced during training that allows to have shorter delay during inference. // Alpha and Beta scores outside of the delay boundary are set to zero. 
// Setting this parameter smaller will result in shorted delay between label output during decoding, yet may hurt accuracy // delayConstraint=-1 means no constraint template GPUMatrix& GPUMatrix::AssignCTCScore(const GPUMatrix& prob, GPUMatrix& alpha, GPUMatrix& beta, const GPUMatrix phoneSeq, const GPUMatrix phoneBoundary, GPUMatrix &totalScore, const std::vector& uttToChanInd, const std::vector & uttBeginFrame, const std::vector & uttFrameNum, const std::vector & uttPhoneNum, const size_t numParallelSequences, const size_t maxFrameNum, const size_t blankTokenId, const int delayConstraint, const bool isColWise) { if (isColWise) { PrepareDevice(); // Total number of phones long totalPhoneNum = prob.GetNumRows(); size_t uttNum = uttFrameNum.size(); // Max number of phones in utterances in this minibatch size_t maxPhoneNum = phoneSeq.GetNumRows(); size_t *gpuFrameNum; CUDA_CALL(cudaMalloc((void **)&gpuFrameNum, uttNum * sizeof(size_t))); CUDA_CALL(cudaMemcpy(gpuFrameNum, uttFrameNum.data(), uttNum * sizeof(size_t), cudaMemcpyHostToDevice)); size_t *gpuPhoneNum; CUDA_CALL(cudaMalloc((void **)&gpuPhoneNum, uttNum * sizeof(size_t))); CUDA_CALL(cudaMemcpy(gpuPhoneNum, uttPhoneNum.data(), uttNum * sizeof(size_t), cudaMemcpyHostToDevice)); size_t *gpuBeginFrame; CUDA_CALL(cudaMalloc((void **)&gpuBeginFrame, uttNum * sizeof(size_t))); CUDA_CALL(cudaMemcpy(gpuBeginFrame, uttBeginFrame.data(), uttNum * sizeof(size_t), cudaMemcpyHostToDevice)); size_t *gpuUttToChanInd; CUDA_CALL(cudaMalloc((void **)&gpuUttToChanInd, uttNum * sizeof(size_t))); CUDA_CALL(cudaMemcpy(gpuUttToChanInd, uttToChanInd.data(), uttNum * sizeof(size_t), cudaMemcpyHostToDevice)); cudaEvent_t done = nullptr; CUDA_CALL(cudaEventCreate(&done)); dim3 thread_tail(DEFAULT_THREAD_PER_DIM, DEFAULT_THREAD_PER_DIM); // x dimension is for utterances // y dimention is for phone sequence in each utterance // Ensure that we allocate correct number of blocks for given number of utterances and max number of phones in those utterances dim3 block_tail((uttNum + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM, (maxPhoneNum + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM); for (long t = 0; t < maxFrameNum; t++) { _assignAlphaScore << > >(prob.Data(), alpha.Data(), phoneSeq.Data(), phoneBoundary.Data(), gpuUttToChanInd, gpuFrameNum, gpuBeginFrame, gpuPhoneNum, numParallelSequences, uttNum, t, maxPhoneNum, totalPhoneNum, blankTokenId, delayConstraint); } for (long t = maxFrameNum - 1; t >= 0; t--) { _assignBetaScore << > >(prob.Data(), beta.Data(), phoneSeq.Data(), phoneBoundary.Data(), gpuUttToChanInd, gpuFrameNum, gpuBeginFrame, gpuPhoneNum, numParallelSequences, uttNum, t, maxPhoneNum, totalPhoneNum, blankTokenId, delayConstraint); } ElemType zerVar = 0.0; totalScore.SetColumn(&zerVar, 0); _assignTotalScore << > > (beta.Data(), totalScore.Data(), uttNum, gpuUttToChanInd, gpuBeginFrame, numParallelSequences, maxPhoneNum); dim3 block_tail_2((uttNum + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM, (maxFrameNum + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM); _assignCTCScore << < block_tail_2, thread_tail, 0, t_stream >> >(Data(), prob.Data(), alpha.Data(), beta.Data(), phoneSeq.Data(), uttNum, gpuUttToChanInd, gpuBeginFrame, gpuPhoneNum, gpuFrameNum, numParallelSequences, maxPhoneNum, totalPhoneNum); CUDA_CALL(cudaFree(gpuFrameNum)); CUDA_CALL(cudaFree(gpuPhoneNum)); CUDA_CALL(cudaFree(gpuBeginFrame)); CUDA_CALL(cudaFree(gpuUttToChanInd)); CUDA_CALL(cudaEventRecord(done)); CUDA_CALL(cudaEventSynchronize(done)); 
CUDA_CALL(cudaEventDestroy(done)); } else { NOT_IMPLEMENTED; } return *this; } template void GPUMatrix::ConductRowElementMultiplyWithShift(const GPUMatrix& a, const GPUMatrix& b, GPUMatrix& c, const size_t shift, const bool isafixed) { if (a.GetComputeDeviceId() != b.GetComputeDeviceId() || b.GetComputeDeviceId() != c.GetComputeDeviceId()) // different GPUs InvalidArgument("All matrices must be on the same GPU"); if (a.IsEmpty() || b.IsEmpty()) LogicError("Scale: one of the input matrices is empty."); const int m = (int) a.GetNumRows(); const int n = (int) a.GetNumCols(); const int O = (int) b.GetNumRows(); const int P = (int) b.GetNumCols(); assert(m > 0 && n > 0 && O > 0 && P > 0); // converting from size_t to int may cause overflow if (m != 1 || n != P) InvalidArgument("Matrices a and b should have same dimension."); c.RequireSize(O, P); if (true) { c.PrepareDevice(); dim3 thread_tail(DEFAULT_THREAD_PER_DIM, DEFAULT_THREAD_PER_DIM); dim3 block_tail((O + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM, (P + DEFAULT_THREAD_PER_DIM - 1) / DEFAULT_THREAD_PER_DIM); SyncGuard syncGuard; _conductRowElementMultiplyWithShift<<>>(c.Data(), a.Data(), b.Data(), O, P, shift, isafixed); } } template GPUMatrix& GPUMatrix::AssignElementProductOfWithShift(const GPUMatrix& a, const GPUMatrix& b, const size_t shift) { if (a.IsEmpty() || b.IsEmpty()) LogicError("AssignElementProductOfWithShift: Matrix is empty."); assert(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols()); if (!(a.GetNumRows() == b.GetNumRows() && a.GetNumCols() == b.GetNumCols())) InvalidArgument("The input matrix dimensions do not match."); // int O = a.GetNumRows(); int P = a.GetNumCols(); RequireSize(1, P); CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock); a.PrepareDevice(); SyncGuard syncGuard; _assignElementProductOfWithShift<<>>(Data(), a.Data(), b.Data(), shift, N); return *this; } //sequence training template GPUMatrix& GPUMatrix::DropFrame(const GPUMatrix& label, const GPUMatrix& gamma, const ElemType& threshhold) { if (IsEmpty()) LogicError("DropFrame: Matrix is empty."); PrepareDevice(); long N = (long) GetNumCols(); // one kernel per column int blocksPerGrid = (int) ceil(N * 1.0 / GridDim::maxThreadsPerBlock); SyncGuard syncGuard; _DropFrame<<>>(Data(), label.Data(), gamma.Data(), threshhold, (long) m_numCols, (long) m_numRows); return *this; } template GPUMatrix& GPUMatrix::AssignSequenceError(const ElemType hsmoothingWeight, const GPUMatrix& label, const GPUMatrix& dnnoutput, const GPUMatrix& gamma, ElemType alpha) { if (IsEmpty()) LogicError("AssignSequenceError: Matrix is empty."); PrepareDevice(); SyncGuard syncGuard; long N = (LONG64) label.GetNumElements(); int blocksPerGrid = (int) ceil(1.0 * N / GridDim::maxThreadsPerBlock); _AssignSequenceError<<>>(hsmoothingWeight, Data(), label.Data(), dnnoutput.Data(), gamma.Data(), alpha, N); return *this; } #pragma endregion Static BLAS Functions /// f = logadd(f, vec) to get the logadd sum of vector elments template ElemType GPUMatrix::LogSumOfElements() const { if (IsEmpty()) LogicError("SumOfElements: Matrix is empty"); ElemType* d_sum = TracingGPUMemoryAllocator::Allocate(GetComputeDeviceId(), 1); ElemType h_sum; CUDA_LONG N = (CUDA_LONG) GetNumElements(); int blocksPerGrid = (int) ceil(((double) N) / GridDim::maxThreadsPerBlock); _reductionLogAddSum<<>>(Data(), d_sum, 1, N); CUDA_CALL(cudaMemcpy(&h_sum, d_sum, sizeof(ElemType), cudaMemcpyDeviceToHost)); 
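    // Note: the reduction is carried out in log space, i.e. logadd(x, y) = log(exp(x) + exp(y)),
    // computed stably as max(x, y) + log1p(exp(-|x - y|)), which avoids the overflow a naive
    // log(sum(exp(x_i))) would incur.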
TracingGPUMemoryAllocator::Free(GetComputeDeviceId(), d_sum); return h_sum; } template void GPUMatrix::RCRFBackwardCompute( const GPUMatrix& alpha, GPUMatrix& beta, const GPUMatrix& /*lbls*/, const GPUMatrix& pos_scores, const GPUMatrix& pair_scores, const int shift) { if (alpha.IsEmpty() || pos_scores.IsEmpty() || pair_scores.IsEmpty()) LogicError("RCRFBackwardCompute: one of the input matrices is empty."); if (alpha.GetNumRows() != pos_scores.GetNumRows() || alpha.GetNumCols() != pos_scores.GetNumCols()) LogicError("RCRFBackwardCompute: matrix dimensions mismatched."); size_t iNumLab = alpha.GetNumRows(); size_t iNumPos = alpha.GetNumCols(); alpha.PrepareDevice(); beta.RequireSize(iNumLab, iNumPos); ElemType* d_zeta = TracingGPUMemoryAllocator::Allocate(alpha.GetComputeDeviceId(), iNumLab); CUDA_LONG N = iNumLab; // TODO: change all three '512' to 'GridDim::maxThreadsPerBlock' (not doing this now since I cannot test it) int blocksPerGrid = (int) ceil(1.0 * N / 512); size_t szMemSize; for (int t = iNumPos - 1; t >= 0; t--) { szMemSize = sizeof(ElemType) * iNumLab; // This function assumes iNumLab <= 1024 and that shared memory == total (!) number of threads == iNumLab. assert(iNumLab <= 1024); _rcrfBackwardComputeZetaMax1024Labels << > >(t, iNumPos, alpha.Data(), d_zeta, pair_scores.Data(), iNumLab, shift); szMemSize = iNumLab * 3; szMemSize *= sizeof(ElemType); // This function assumes iNumLab <= 1024 and that shared memory == total (!) number of threads == 3 * iNumLab. assert(iNumLab <= 1024); _rcrfBackwardComputeMax1024Labels << > >(t, iNumPos, alpha.Data(), beta.Data(), d_zeta, pair_scores.Data(), iNumLab, shift); } /* error = cudaGetErrorString(cudaPeekAtLastError()); printf("%s\n", error); error = cudaGetErrorString(cudaThreadSynchronize()); printf("%s\n", error); */ TracingGPUMemoryAllocator::Free(alpha.GetComputeDeviceId(), d_zeta); } /** Compute the gradient for the first order Markov transition probabilities It uses equations derived in R. Collobert's paper "Natural language processing (almost) from scratch" */ template void GPUMatrix::RCRFTransGrdCompute(const GPUMatrix& lbls, const GPUMatrix& alpha, const GPUMatrix& beta, const GPUMatrix& pair_scores, GPUMatrix& grd, const int startLbl, const int shift) { assert(shift == 1); int iNumPos = alpha.GetNumCols(); int iNumLab = alpha.GetNumRows(); ElemType* d_zeta = TracingGPUMemoryAllocator::Allocate(alpha.GetComputeDeviceId(), iNumLab); CUDA_LONG N = iNumLab; // TODO: change all three '512' to 'GridDim::maxThreadsPerBlock' (not doing this now since I cannot test it) int blocksPerGrid = (int)ceil(1.0 * N / 512); size_t szMemSize; for (int t = 0; t < iNumPos; t++) { szMemSize = sizeof(ElemType) * iNumLab; // This function assumes iNumLab <= 1024 and that shared memory == total (!) number of threads == iNumLab. assert(iNumLab <= 1024); // BUGBUG: This is launched with 512 threads per block, but allocates shared mem as if there is only one block. Likewise for all 4 of these functions. _rcrfTransGrdComputeZetaMax1024Labels << > >(t - 1, iNumPos, alpha.Data(), d_zeta, pair_scores.Data(), iNumLab, startLbl, shift); szMemSize = iNumLab * 3; szMemSize *= sizeof(ElemType); // This function assumes iNumLab <= 1024 and that shared memory == total (!) number of threads == iNumLab. 
assert(iNumLab <= 1024); _rcrfTransGrdComputeMax1024Labels << > >(t, startLbl, alpha.Data(), beta.Data(), d_zeta, pair_scores.Data(), lbls.Data(), grd.Data(), iNumPos, iNumLab, shift); } TracingGPUMemoryAllocator::Free(alpha.GetComputeDeviceId(), d_zeta); }; // ----------------------------------------------------------------------- // TensorView entry points from Matrix.cpp // ----------------------------------------------------------------------- // helper to provide a vector of ones of at least the given number of elements // TODO: Use this to implement ComputationNode::ConstOnes? Or do we even need that anymore? template static shared_ptr> GetOnesVector(size_t N, DEVICEID_TYPE deviceId) { // using a dynamically allocated array so this will never get freed, avoiding free-after-DLL-unload issues. // and using shared_ptrs since we don't want to leak more than CacheSize elements // when using a plain array we would have to control lifetime of the object and destructor would be called for every element in the array at the end const int CacheSize = 32; static shared_ptr> * onesCache = new shared_ptr>[CacheSize]; // cache of objects if (deviceId >= CacheSize){ LogicError("GetOnesVector: onesCache[] too small (%d entries), increase (you need %d) and recompile.", CacheSize, (int)deviceId + 1); } auto p = onesCache[deviceId]; if (!p || p->GetNumRows() < N) // must (re-)allocate { p = make_shared>(GPUMatrix::Ones(N, 1, deviceId)); onesCache[deviceId] = p; // this will replace the pointer thread-safely (although weird race conditions may happen where a larger entry is overwritten by a smaller one; will still run correctly) } return p; } // perform unary operation 'op' on a giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides // This binds the N-ariness to a template parameter N, and gets the data pointers out from the matrix objects. template void GPUMatrix::TensorOp(ElemType beta, const GPUMatrix& a, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, const array& offsets, const SmallVector& regularOpDims, const array, 2>& regularStrides, const SmallVector& reducingOpDims, const array, 2>& reducingStrides) { if (reductionOp != ElementWiseOperator::opSum && reductionOp != ElementWiseOperator::opLogSum && reductionOp != ElementWiseOperator::opMin && reductionOp != ElementWiseOperator::opMax && reductionOp != ElementWiseOperator::opElementwiseProduct) InvalidArgument("TensorOp: Unary reduction operations other than opMax, opMin, opSum, and opLogSum are not implemented."); a.PrepareDevice(); if (a.GetComputeDeviceId() != GetComputeDeviceId()) InvalidArgument("All matrices must be on the same GPU"); // special case: linear processing // The case statement has measurable impact for unary ops (but not for binary ops it seems, due to double mem access). // Linear gap-free unary ops happen so regularly that we will eliminate the case statement from the CUDA kernel, and instead expand all. if (regularOpDims.size() == 1 && regularStrides[0][0] == 1 && regularStrides[1][0] == 1 && reducingOpDims.size() == 0) { // special case: for copy, use cudaMemcpy() instead, or cublas_axpy() // TODO: We should observe if these actually make a speed difference, and if not, remove these special cases. 
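        // For example: a gap-free copy of regularOpDims[0] contiguous elements with beta == 0 and
        // alpha == 1 takes the plain cudaMemcpy path below, while beta == 1 turns the same copy into
        // an axpy accumulation into the target; everything else falls through to LaunchUnaryTensorOp.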
        if (op == ElementWiseOperator::opCopy && beta == 0 && alpha == 1)
            return CUDA_CALL(cudaMemcpy(Data() + offsets[1], a.Data() + offsets[0], sizeof(ElemType) * regularOpDims[0], cudaMemcpyDeviceToDevice));
        else if (op == ElementWiseOperator::opCopy && beta == 1)
            return CUBLAS_CALL(cublas_axpy(GetCublasHandle(GetComputeDeviceId()), (int) regularOpDims[0], &alpha, a.Data() + offsets[0], 1, Data() + offsets[1], 1));
        else
            return LaunchUnaryTensorOp<ElemType>(beta, a.Data() + offsets[0], Data() + offsets[1], alpha, op, regularOpDims[0]);
    }
    // special case: sum-reducing a matrix onto a column vector; can be done with SGEMM
    // Note: A minor risk is that with this, our own reduction function will rarely be used.
    // That function was tested to give the same results with 'double', and nearly the same with 'float' (different summation order matters).
    else if (op == ElementWiseOperator::opCopy && // we are just adding to target without any further operation
             reductionOp == ElementWiseOperator::opSum &&
#ifdef _DEBUG
             sizeof(ElemType) == sizeof(float) && // in debug don't shortcut 'double' so we have some test of our own codepath
#endif
             regularOpDims.size() == 1 && regularStrides[0][0] == 1 && regularStrides[1][0] == 1 && // we are processing a column
             reducingOpDims.size() == 1 && reducingStrides[0][0] >= (ptrdiff_t) regularOpDims[0])   // reducing across columns and no overlap
    {
        assert(reducingStrides[1][0] == 0);
        auto ARows = regularOpDims[0];    // vertical steps
        auto ACols = reducingOpDims[0];   // horizontal steps (reduction)
        auto ALd = reducingStrides[0][0]; // horizontal step width through matrix
        cublasHandle_t cuHandle = GetCublasHandle(a.GetComputeDeviceId());
        CUBLAS_CALL(cublas_gemm(cuHandle, CUBLAS_OP_N, CUBLAS_OP_N, (int) /*CRows=*/ARows, /*CCols=*/1, (int) ACols, &alpha,
                                /*A00=*/a.Data() + offsets[0], (int) ALd,
                                /*B00=*/GetOnesVector<ElemType>(ACols, a.GetComputeDeviceId())->Data(), (int) /*BRows=*/ACols, &beta,
                                /*C00=*/Data() + offsets[1], (int) /*CRows=*/ARows));
        return;
    }
    // TODO: Add a special case for tensor bias reduction. cudnn is ~7% faster on Image/QuickE2E.
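    // The branch above folds the column-sum reduction into a single GEMM against a cached ones vector:
    // for each output row i,  C[i] = beta * C[i] + alpha * sum_j A[i, j],
    // which is exactly  C = alpha * A * ones(ACols x 1) + beta * C,  with A read using leading
    // dimension ALd (the stride between reduced columns). GetOnesVector() supplies the ones column.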
// regular case else return TensorOpN(beta, array{a.Data(), Data()}, alpha, op, reductionOp, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides); } // perform binary operation 'op' on a and b giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides template void GPUMatrix::TensorOp(ElemType beta, const GPUMatrix& a, const GPUMatrix& b, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, const array& offsets, const SmallVector& regularOpDims, const array, 3>& regularStrides, const SmallVector& reducingOpDims, const array, 3>& reducingStrides) { if (reductionOp != ElementWiseOperator::opSum) InvalidArgument("TensorOp: The only permitted binary reduction operation is opSum."); a.PrepareDevice(); if (a.GetComputeDeviceId() != GetComputeDeviceId() || b.GetComputeDeviceId() != GetComputeDeviceId()) InvalidArgument("All matrices must be on the same GPU"); return TensorOpN(beta, array{a.Data(), b.Data(), Data()}, alpha, op, reductionOp, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides); } // perform ternary operation 'op' on a, and c giving 'this', reinterpreting the matrices as tensors as specified by the dims and strides template void GPUMatrix::TensorOp(ElemType beta, const GPUMatrix& a, const GPUMatrix& b, const GPUMatrix& c, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, const array& offsets, const SmallVector& regularOpDims, const array, 4>& regularStrides, const SmallVector& reducingOpDims, const array, 4>& reducingStrides) { if (reductionOp != ElementWiseOperator::opSum) InvalidArgument("TensorOp: The only permitted ternary reduction operation is opSum."); a.PrepareDevice(); if (a.GetComputeDeviceId() != GetComputeDeviceId() || b.GetComputeDeviceId() != GetComputeDeviceId() || c.GetComputeDeviceId() != GetComputeDeviceId()) InvalidArgument("All matrices must be on the same GPU"); return TensorOpN(beta, array{a.Data(), b.Data(), c.Data(), Data()}, alpha, op, reductionOp, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides); } template void GPUMatrix::TensorArgOp(const GPUMatrix& a, ElementWiseOperator reductionOp, const array& offsets, const SmallVector& regularOpDims, const array, 2>& regularStrides, const SmallVector& reducingOpDims, const array, 2>& reducingStrides) { if (reductionOp != ElementWiseOperator::opArgmin && reductionOp != ElementWiseOperator::opArgmax) InvalidArgument("TensorOp: Arg reduction operations other than opArgmax, and opArgmin are not implemented."); a.PrepareDevice(); if (a.GetComputeDeviceId() != GetComputeDeviceId()) InvalidArgument("All matrices must be on the same GPU"); return TensorOpN((ElemType) 0, array{a.Data(), Data()}, (ElemType) 1, ElementWiseOperator::opCopy, reductionOp, offsets, regularOpDims, regularStrides, reducingOpDims, reducingStrides); } // ======================================================================= // explicit instantiations business // ======================================================================= template class GPUMatrix; template class GPUMatrix; template class DeviceBoundNumber; template class DeviceBoundNumber; template cublasHandle_t GPUMatrix::s_cuHandle[GPUMatrix::MaxGpus] = {0}; template void* GPUMatrix::s_curandGenerator = NULL; // We use Matrix as the backing store for QuantizedMatrix // Let's explicitly instantiate the methods we need for that purpose template GPUMatrix::GPUMatrix(const size_t numRows, const size_t numCols, int deviceId); template 
GPUMatrix::GPUMatrix(const size_t numRows, const size_t numCols, int deviceId, char* pArray, const size_t matrixFlags); template GPUMatrix::GPUMatrix(const GPUMatrix&); template GPUMatrix::GPUMatrix(GPUMatrix&&); template char* GPUMatrix::CopyToArray() const; template void GPUMatrix::ChangeDeviceTo(int); template void GPUMatrix::Resize(size_t, size_t, bool); template void GPUMatrix::RequireSize(size_t, size_t, bool); template GPUMatrix::~GPUMatrix(); template GPUMatrix GPUMatrix::ColumnSlice(size_t startColumn, size_t numCols) const; template GPUMatrix& GPUMatrix::operator=(GPUMatrix&&); template GPUMatrix::GPUMatrix(int); template void GPUMatrix::SetValue(const char); template void GPUMatrix::SetValue(const size_t numRows, const size_t numCols, int deviceId, char* pArray, size_t matrixFlags, DataTransferer* transferer); //template void GPUMatrix::SetValue(CPUMatrix const&); template void GPUMatrix::SetValue(GPUMatrix const&); //template void GPUMatrix::SetValue(CPUSparseMatrix const&); //template void GPUMatrix::SetValue(GPUSparseMatrix const&); template void GPUMatrix::CopySection(size_t numRows, size_t numCols, char* dst, size_t colStride) const; template void GPUMatrix::Reshape(const size_t, const size_t); template GPUMatrix& GPUMatrix::operator*=(char); template DEVICEID_TYPE GPUMatrix::PrepareDevice(DEVICEID_TYPE deviceId) const; // Support template GPUMatrix::GPUMatrix(const size_t numRows, const size_t numCols, int deviceId); template GPUMatrix::GPUMatrix(const size_t numRows, const size_t numCols, int deviceId, short* pArray, const size_t matrixFlags); template GPUMatrix::GPUMatrix(const GPUMatrix&); template GPUMatrix::GPUMatrix(GPUMatrix&&); template short* GPUMatrix::CopyToArray() const; template void GPUMatrix::ChangeDeviceTo(int); template void GPUMatrix::Resize(size_t, size_t, bool); template void GPUMatrix::RequireSize(size_t, size_t, bool); template GPUMatrix::~GPUMatrix(); template GPUMatrix GPUMatrix::ColumnSlice(size_t startColumn, size_t numCols) const; template GPUMatrix& GPUMatrix::operator=(GPUMatrix&&); template GPUMatrix::GPUMatrix(int); template void GPUMatrix::SetValue(const short); template void GPUMatrix::SetValue(const size_t numRows, const size_t numCols, int deviceId, short* pArray, size_t matrixFlags, DataTransferer* transferer); //template void GPUMatrix::SetValue(CPUMatrix const&); template void GPUMatrix::SetValue(GPUMatrix const&); //template void GPUMatrix::SetValue(CPUSparseMatrix const&); //template void GPUMatrix::SetValue(GPUSparseMatrix const&); template void GPUMatrix::CopySection(size_t numRows, size_t numCols, short* dst, size_t colStride) const; template void GPUMatrix::Reshape(const size_t, const size_t); template GPUMatrix& GPUMatrix::operator*=(short); template DEVICEID_TYPE GPUMatrix::PrepareDevice(DEVICEID_TYPE deviceId) const; template GPUMatrix::GPUMatrix(const size_t, const size_t, int, int*, const size_t); template GPUMatrix::~GPUMatrix(); template int* TracingGPUMemoryAllocator::Allocate(int, size_t); template size_t* TracingGPUMemoryAllocator::Allocate(int, size_t); template long* TracingGPUMemoryAllocator::Allocate(int, size_t); template short* TracingGPUMemoryAllocator::Allocate(int, size_t); template char* TracingGPUMemoryAllocator::Allocate(int, size_t); template float* TracingGPUMemoryAllocator::Allocate(int, size_t); template double* TracingGPUMemoryAllocator::Allocate(int, size_t); template void TracingGPUMemoryAllocator::Free(int, int*, bool); template void TracingGPUMemoryAllocator::Free(int, size_t*, bool); template void 
TracingGPUMemoryAllocator::Free(int, short*, bool); template void TracingGPUMemoryAllocator::Free(int, char*, bool); template void TracingGPUMemoryAllocator::Free(int, float*, bool); template void TracingGPUMemoryAllocator::Free(int, double*, bool); }}} // !!!!This is from helper_cuda.h which comes with CUDA samples!!!! Consider if it is beneficial to just include all helper_cuda.h // TODO: This is duplicated in BestGpu.cpp // Beginning of GPU Architecture definitions int _ConvertSMVer2Cores(int major, int minor) { // Defines for GPU Architecture types (using the SM version to determine the # of cores per SM typedef struct { int SM; // 0xMm (hexidecimal notation), M = SM Major version, and m = SM minor version int Cores; } sSMtoCores; sSMtoCores nGpuArchCoresPerSM[] = { {0x10, 8}, // Tesla Generation (SM 1.0) G80 class {0x11, 8}, // Tesla Generation (SM 1.1) G8x class {0x12, 8}, // Tesla Generation (SM 1.2) G9x class {0x13, 8}, // Tesla Generation (SM 1.3) GT200 class {0x20, 32}, // Fermi Generation (SM 2.0) GF100 class {0x21, 48}, // Fermi Generation (SM 2.1) GF10x class {0x30, 192}, // Kepler Generation (SM 3.0) GK10x class {0x35, 192}, // Kepler Generation (SM 3.5) GK11x class {-1, -1}}; int index = 0; while (nGpuArchCoresPerSM[index].SM != -1) { if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { return nGpuArchCoresPerSM[index].Cores; } index++; } return nGpuArchCoresPerSM[7].Cores; }; // end of GPU Architecture definitions //inline CUDA_LONG _GetFreeMemoryOnCUDADevice(int devId) //{ // CUdevice cudaDevice; // CUresult result = cuDeviceGet(&cudaDevice, devId); // if(result!= CUDA_SUCCESS) // { // return 0; // } // // // create cuda context // CUcontext cudaContext; // result = cuCtxCreate(&cudaContext, CU_CTX_SCHED_AUTO, cudaDevice); // if(result != CUDA_SUCCESS) // { // return 0; // } // // // get the amount of free memory on the graphics card // size_t free; // size_t total; // result = cuMemGetInfo(&free, &total); // if (result!=CUDA_SUCCESS) // { // return 0; // } // else // return (CUDA_LONG)free; //} #endif // CPUONLY