// // Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE.md file in the project root for full license information. // // TODO: // - remove empty-matrix checks: if an op is well-defined with empty matrices, then do it // - Resize() must be cheap if it does nothing (I already did that for CPU; already done for GPU?) // #pragma once #include "Basics.h" #include "File.h" #include "CommonMatrix.h" #include "TensorShape.h" // only for SmallVector; I was hoping to keep this out #include "RNGHandle.h" #include "DataTransferer.h" #include #include // for shared_ptr #include #include #include "QuantizedOperations.h" // Forward declarations namespace CNTK { class Value; } // This class is exported from the Math.dll namespace Microsoft { namespace MSR { namespace CNTK { enum CurrentDataLocation { NONE, CPU, GPU, BOTH }; enum MatrixType { UNDETERMINED, DENSE, SPARSE }; // avoid pulling in these header files for consumers of this class template class GPUMatrix; template class CPUMatrix; template class GPUSparseMatrix; template class CPUSparseMatrix; template class DeviceBoundNumber; // -agnostic base class struct /*interface*/ MATH_API MatrixBase { virtual int GetDeviceId() const = 0; virtual MatrixType GetMatrixType() const = 0; virtual MatrixFormat GetFormat() const = 0; virtual void CastAssignValuesOf(const MatrixBase& other) = 0; // allows for mixed assignment with conversion // TODO: Move more generic functions such as getting dims, resizing, and getting/setting as scalars in here. virtual ~MatrixBase(); }; typedef std::shared_ptr MatrixBasePtr; // Note: To comply with BLAS libraries, matrices are stored in ColMajor. However, by default C/C++/C# use RowMajor convertion. // !!!WARNING!!! This class is NOT THREAD SAFE. Test and add necessary modifications if using in multi-threaded environment template class MATH_API Matrix : public MatrixBase { friend class ::CNTK::Value; typedef MatrixBase Base; private: mutable BaseMatrix* m_baseMatrix; mutable shared_ptr> m_GPUMatrix; mutable shared_ptr> m_CPUMatrix; mutable shared_ptr> m_GPUSparseMatrix; mutable shared_ptr> m_CPUSparseMatrix; mutable MatrixType m_matrixType; mutable CurrentDataLocation m_currentDataLocation; // Indicates which matrix is current mutable DEVICEID_TYPE m_preferredDeviceId; mutable size_t m_numTimesDeviceChanged; mutable size_t m_numTimesMatrixTypeChanged; mutable int m_devicesTransferedTo[2]; // TODO: what is this for? Seems only diagnostics // Moves matrix from device id_from to device with id_to. This method doesn't change preferred device Id void _transferFromDeviceToDevice(int id_from, int id_to, bool isBeingMoved = true, bool emptyTransfer = false) const; // Moves matrix from current device to device with id_to. This method doesn't change preferred device Id void _transferToDevice(int id_to, bool isBeingMoved = true, bool emptyTransfer = false) const; template static void DecideAndMoveToRightDevice(const Matrix& a, const Matrix& b); static void DecideAndMoveToRightDevice(const Matrix& a, const Matrix& b, const Matrix& c); static void DecideAndMoveToRightDevice(const Matrix& a, const Matrix& b, const Matrix& c, const Matrix& d); static void CopyElementsFromDenseToSparse(CPUMatrix& from, CPUSparseMatrix& dest); public: // Constructors, destructors and other static matrix builders // Each constructor can take deviceId as parameter. // If deviceId<0 then the matrix will be based in RAM (CPUMatrix) // Elseif deviceId>=0 then the matrix will be based on GPU with specified deviceId explicit Matrix(DEVICEID_TYPE deviceId); // This constructor is not used, but it makes the ownership of baseMatrix ambiguous. If it's to be used, ensure that the semantics with external buffer are clear. #if 0 Matrix(shared_ptr> baseMatrix, ElemType* pArray, DEVICEID_TYPE deviceId); // constructor for setting Matrix from a base matrix (externally managed butter pArray) #endif Matrix(const size_t numRows, const size_t numCols, DEVICEID_TYPE deviceId, const MatrixType matrixType = DENSE, const MatrixFormat matrixFormat = matrixFormatDense); // TODO: Rewrite this constructor to eliminate the external buffers flag. Make a separate construction mechanism for Matrix objects that don't own their storage. Matrix(const size_t numRows, const size_t numCols, ElemType* pArray, DEVICEID_TYPE deviceId, const size_t matrixFlags = matrixFlagNormal, const size_t nnz = 0); Matrix(const Matrix& deepCopyFrom, DEVICEID_TYPE deviceId); Matrix(Matrix&& moveFrom); // move constructor, shallow copy Matrix& operator=(Matrix&& moveFrom); // move assignment operator, shallow copy Matrix DeepClone() const; // Disallow deep copy construction and assignment to avoid // inadvertent silent deep copying Matrix(const Matrix& deepCopyFrom) = delete; Matrix& operator=(const Matrix& deepCopyFrom) = delete; static Matrix Ones(const size_t rows, const size_t cols, DEVICEID_TYPE deviceId); static Matrix Zeros(const size_t rows, const size_t cols, DEVICEID_TYPE deviceId); static Matrix Eye(const size_t rows, DEVICEID_TYPE deviceId); #define USE_TIME_BASED_SEED ULONG_MAX static Matrix RandomUniform(const size_t rows, const size_t cols, DEVICEID_TYPE deviceId, const ElemType low, const ElemType high, unsigned long seed = USE_TIME_BASED_SEED); static Matrix RandomGaussian(const size_t rows, const size_t cols, DEVICEID_TYPE deviceId, const ElemType mean, const ElemType sigma, unsigned long seed = USE_TIME_BASED_SEED); static void SetDevice(DEVICEID_TYPE deviceId); // TODO: unify with PrepareDevice() void ReleaseMemory(); ~Matrix(); // workaround to bugs in BOTH implementation: force to collapse to home location void CollapseDataLocation() const { SetDataLocation(GetDeviceId() < 0 ? CurrentDataLocation::CPU : CurrentDataLocation::GPU, GetMatrixType()); } private: Matrix(const MatrixFlags matrixFlags, const MatrixType matrixType, const MatrixFormat matrixFormat, DEVICEID_TYPE deviceID); // only used internally to initialize a blank matrix Matrix(const MatrixFlags matrixFlags, const MatrixType matrixType, DEVICEID_TYPE deviceID); // only used internally to initialize a blank matrix Matrix(const MatrixFlags matrixFlags, DEVICEID_TYPE deviceID); // only used internally to initialize a blank matrix void Init(DEVICEID_TYPE deviceID); void SetDataLocation(CurrentDataLocation location, MatrixType type = UNDETERMINED) const; void ShallowCopyFrom(const Matrix& other); public: // down-cast to make life easier template static shared_ptr DownCast(shared_ptr> inode) { shared_ptr node = dynamic_pointer_cast(inode); if (!node) LogicError("A Matrix of mismatching type was passed."); return node; } MatrixType GetMatrixType() const override; MatrixFormat GetFormat() const override; bool OwnBuffer() const { return m_baseMatrix->OwnBuffer(); } int GetDeviceId() const; // -1 if CPU, otherwise GPU CUDA device id DEVICEID_TYPE GetPreferredDeviceId() const { return m_preferredDeviceId; }; // -1 if CPU, otherwise GPU CUDA device id void SetPreferredDeviceId(DEVICEID_TYPE preferredDeviceId) { m_preferredDeviceId = preferredDeviceId; } // Moves matrix from device id_from to device with id_to. // If emptyTransfer=true, then no data is ever moved, just corresponding GPU/CPU matrices are deleted and then created using empty constructor void TransferFromDeviceToDevice(int id_from, int id_to, bool isBeingMoved = false, /*if false then keep source and set location to BOTH*/ bool emptyTransfer = false, bool updatePreferredDevice = true) const; // Same as TransferFromDeviceToDevice() but moves only if it is currently not on the target device void TransferToDeviceIfNotThere(int id_to, bool isBeingMoved = false, bool emptyTransfer = false, bool updatePreferredDevice = true) const; CurrentDataLocation GetCurrentMatrixLocation() const { return m_currentDataLocation; }; void SwitchToMatrixType(MatrixType newMatrixType, MatrixFormat newMatrixFormat, bool keepValues); // sets matrix type between dense and sparse size_t GetNumRows() const; size_t GetNumCols() const; size_t GetNumElements() const; bool HasNoElements() const { return GetNumElements() == 0; } bool IsEmpty() const; size_t BufferSize() const; ElemType* Data() const; ElemType* CopyToArray() const; // allocated by the callee but need to be deleted by the caller size_t CopyToArray(ElemType*& arrayCopyTo, size_t& currentArraySize) const; // allocated by the callee but need to be deleted by the caller // colStride specifies leading dimension of dst. // REVIEW alexeyk: GPU version copies from device to host only, implement all versions (device <-> host). void CopySection(size_t numRows, size_t numCols, ElemType* dst, size_t colStride) const; Matrix ColumnSlice(size_t startColumn, size_t numCols) const; // note: 'const' is misleading here, as the returned matrix is a mutable reference // difference between AssignColumnSlice and SetColumnSlice // AssignColumnSlice : this(:, startColumn:startColumn+numCols-1) = fromMatrix(:, startColumn: startColumn+numCols-1) // SetColumnSlice : this(:, startColumn:startColumn+numCols-1) = fromMatrix(:, 0: startColumn+numCols-1) // AssignColumnSlice do not transfer data, it uses external data // SetColumnSlice copies data Matrix& AssignColumnSlice(const Matrix& fromMatrix, size_t startColumn, size_t numCols); Matrix& SetColumnSlice(const Matrix& fromMatrix, size_t startColumn, size_t numCols); void CopyColumnsStrided(const Matrix& fromMatrix, size_t numCols, size_t srcNumColsStride, size_t destNumColsStride); Matrix Diagonal() const; void AssignDiagonalValuesTo(Matrix& diag) const; void SGDUpdate(Matrix& gradients, ElemType learnRatePerSample); void MomentumSGDUpdate(Matrix& gradients, Matrix& smoothedGradients, ElemType learnRatePerSample, ElemType momentum, bool unitGainMomentum = true); void NesterovAcceleratedMomentumSGDUpdate(Matrix& gradients, Matrix& smoothedGradients, ElemType learnRatePerSample, ElemType momentum, bool unitGainMomentum = true); ElemType Adagrad(Matrix& gradients, const bool needAveMultiplier); void FSAdagradUpdate(size_t mbSize, Matrix& gradients, Matrix& functionValues, double& smoothedCount, const double learnRatePerSample, const double targetAdagradAvDenom, const double meanMomentum, const double varMomentum, bool unitGainMomentum = true); void AdamUpdate(Matrix& gradients, Matrix& functionValues, double& smoothedCount, const double learnRatePerSample, const double meanMomentum, const double varMomentum, bool unitGainMomentum = true); ElemType RmsProp(Matrix& gradients, ElemType RMS_GAMMA, ElemType RMS_WGT_INC, ElemType RMS_WGT_MAX, ElemType RMS_WGT_DEC, ElemType RMS_WGT_MIN, const bool needAveMultiplier); void Resize(const size_t numRows, const size_t numCols, const size_t numNZElemToReserve = 10000, bool growOnly = true); // by default we only reallocate if need to grow void Resize(const Matrix& other) // TODO: Should this carry over numNZElemToReserve for sparse matrices? { Resize(other.GetNumRows(), other.GetNumCols()); } void VerifySize(size_t rows, size_t cols) { m_baseMatrix->VerifySize(rows, cols); } // TODO: Call this ShallowClone instead? Matrix AsReference() const { return ColumnSlice(0, GetNumCols()); } // get a reference (e.g. this is not resizable but can be reshaped) void Reshape(const size_t numRows, const size_t numCols); // note: reshapes in place. To get a reshaped reference, use Reshaped() Matrix Reshaped(const size_t numRows, const size_t numCols) const // get a reshaped reference { Matrix result = AsReference(); result.Reshape(numRows, numCols); return result; } // update number of columns // TODO: a future version may want to enforce retaining the content, to allow dynamically growing layouts column by column (when size is not known upfront) void ResizeColumns(const size_t numCols) { Resize(GetNumRows(), numCols); } // similarl to the repmat operation in matlab or octave static Matrix RepMat(const Matrix& frmMat, const size_t rows, const size_t cols); size_t GetAllocatedSize() const; void Reset(); // reset for sparse matrix const ElemType operator()(const size_t row, const size_t col) const; ElemType& operator()(const size_t row, const size_t col); ElemType GetValue(const size_t row, const size_t col) const { return operator()(row, col); } // use this for reading on non-const objects to avoid inefficiency ElemType Get00Element() const; void SetValue(const ElemType v); void SetValue(const DeviceBoundNumber& db_number); //void SetValue (const Matrix& deepCopyFrom, const MatrixFormat format = matrixFormatSparseCSR); // BUGBUG: default for 'format' is unexpected // SetValue respects the source matrix's information. It moves the target's location (if necessary), and then copies the sources values. void SetValue (const Matrix& deepCopyFrom); // AssignValuesOf respects the target matrix's information. It copies the values from the target into the memory of the source. void AssignValuesOf(const Matrix& deepCopyFrom); void SetValue(const size_t numRows, const size_t numCols, int deviceId, ElemType* pArray, const size_t matrixFlags = matrixFlagNormal, DataTransferer* transferer = nullptr); void SetValue(const size_t rIdx, const size_t cIdx, ElemType val); // set matrix sparsely void SetValue(const size_t numRows, const size_t numCols, std::initializer_list l) // SetValue(2,3, {1,2,3, 4,5,6}); { std::vector vals(l); assert(vals.size() == numRows * numCols); SetValue(numRows, numCols, GetDeviceId(), vals.data(), matrixFormatRowMajor); } void CastAssignValuesOf(const MatrixBase& other) override; // allows for mixed assignment with conversion static ElemType MakeNan(size_t payload); void Invalidate() { SetValue(MakeNan(__LINE__)); } void SetMatrixFromCSCFormat(const CPUSPARSE_INDEX_TYPE* h_CSCCol, const CPUSPARSE_INDEX_TYPE* h_Row, const ElemType* h_Val, const size_t nz, const size_t numRows, const size_t numCols, DataTransferer* transferer = nullptr); void MaskColumnsValue(const Matrix& columnsMask, ElemType val, size_t numColsPerMaskEntry); void SetColumn(const ElemType* colPointer, size_t colInd); void SetColumn(const ElemType val, size_t colInd); void SetColumn(const Matrix& valMat, size_t colInd); void SetDiagonalValue(const ElemType v); void SetDiagonalValue(const Matrix& vector); void SetUniformRandomValue(const ElemType low, const ElemType high, unsigned long seed = USE_TIME_BASED_SEED); void SetGaussianRandomValue(const ElemType mean, const ElemType sigma, unsigned long seed = USE_TIME_BASED_SEED); void SetUniformRandomMask(const ElemType maskRate, const ElemType scaleValue, RNGHandle& rngHandle); void AddGaussianRandomValue(const ElemType mean, const ElemType sigma, unsigned long seed = USE_TIME_BASED_SEED); Matrix& AssignNoiseContrastiveEstimation(const Matrix& a, const Matrix& b, const Matrix& c, const Matrix& bias, Matrix& tmp); Matrix& AssignNCEDerivative(const Matrix& tmp, const Matrix& a, const Matrix& b, const Matrix& c, size_t inputIndex); Matrix& AssignSoftmaxSum(const Matrix& a, const Matrix& softmax); Matrix& AssignNceUnnormalizedEval(const Matrix& a, const Matrix& b, const Matrix& c, const Matrix& bias); Matrix Transpose(); // This method doesn't change state of Matrix. It should be a const function Matrix& AssignTransposeOf(const Matrix& a); Matrix& DoGatherColumnsOf (ElemType beta, const Matrix& idx, const Matrix& a, ElemType alpha); Matrix& DoScatterColumnsOf(ElemType beta, const Matrix& idx, const Matrix& a, ElemType alpha); Matrix& operator+=(const ElemType alpha); Matrix operator+(const ElemType alpha) const; Matrix& AssignSumOf(const ElemType alpha, const Matrix& a); Matrix& operator+=(const Matrix& a); Matrix operator+(const Matrix& a) const; Matrix& AssignSumOf(const Matrix& a, const Matrix& b); Matrix& operator-=(const ElemType alpha); Matrix operator-(const ElemType alpha) const; Matrix& AssignDifferenceOf(const ElemType alpha, const Matrix& a); Matrix& AssignDifferenceOf(const Matrix& a, const ElemType alpha); Matrix& operator-=(const Matrix& a); Matrix operator-(const Matrix& a) const; Matrix& AssignDifferenceOf(const Matrix& a, const Matrix& b); Matrix& operator*=(const ElemType alpha); Matrix operator*(const ElemType alpha) const; Matrix& AssignProductOf(const ElemType alpha, const Matrix& a); Matrix operator*(const Matrix& a) const; Matrix& AssignProductOf(const Matrix& a, const bool transposeA, const Matrix& b, const bool transposeB); // this = a * b Matrix& Assign1x1ProductOf(const Matrix& a1x1, const Matrix& b); // this = a * b, where a is 1x1 Matrix& operator/=(ElemType alpha); Matrix operator/(ElemType alpha) const; Matrix& operator^=(ElemType alpha); // element-wise power Matrix operator^(ElemType alpha) const; // element-wise power Matrix& AssignElementPowerOf(const Matrix& a, const ElemType power); // TODO: There are several functions below that perform an in-place operation // We should prepend the names of these functions with InPlace for clearly indicating // the semantics for callers. Matrix& ElementMultiplyWith(const Matrix& a); Matrix& AssignElementProductOf(const Matrix& a, const Matrix& b); Matrix& AddElementProductOf(const Matrix& a, const Matrix& b); Matrix& AssignElementDivisionOf(const Matrix& a, const Matrix& b); Matrix& ElementDivideBy(const Matrix& a); Matrix& ColumnElementMultiplyWith(const Matrix& a); Matrix& RowElementMultiplyWith(const Matrix& a); Matrix& ColumnElementDivideBy(const Matrix& a); Matrix& RowElementDivideBy(const Matrix& a); Matrix& ElementInverse(); Matrix& AssignElementInverseOf(const Matrix& a); Matrix& InplaceLinearRectifierDerivative(); Matrix& AssignLinearRectifierDerivativeOf(const Matrix& a); Matrix& InplaceSigmoidDerivative(); Matrix& AssignSigmoidDerivativeOf(const Matrix& a); Matrix& InplaceSigmoid(); Matrix& AssignSigmoidOf(const Matrix& a); Matrix& InplaceTanh(); Matrix& AssignTanhOf(const Matrix& a); Matrix& InplaceLogSoftmax(const bool isColWise); Matrix& AssignLogSoftmaxOf(const Matrix& a, const bool isColWise); Matrix& InplaceHardmax(const bool isColWise); Matrix& AssignHardmaxOf(const Matrix& a, const bool isColWise); // sequence training Matrix& DropFrame(const Matrix& label, const Matrix& gamma, const ElemType& threshhold); Matrix& AssignSequenceError(const ElemType hsmoothingWeight, const Matrix& label, const Matrix& dnnoutput, const Matrix& gamma, ElemType alpha); Matrix& AssignCTCScore(const Matrix& prob, Matrix& alpha, Matrix& beta, const Matrix& phoneSeq, const Matrix& phoneBound, ElemType &totalScore, const vector & extraUttMap, const vector & uttBeginFrame, const vector & uttFrameNum, const vector & uttPhoneNum, const size_t samplesInRecurrentStep, const size_t mbSize, const size_t blankTokenId, const int delayConstraint, const bool isColWise); Matrix& InplaceSqrt(); Matrix& AssignSqrtOf(const Matrix& a); Matrix& InplaceExp(); Matrix& AssignExpOf(const Matrix& a); Matrix& InplaceLog(); Matrix& AssignLogOf(const Matrix& a); Matrix& InplaceCosine(); Matrix& AssignCosineOf(const Matrix& a); Matrix& InplaceNegativeSine(); Matrix& AssignNegativeSineOf(const Matrix& a); Matrix& InplaceLog10(); Matrix& AssignLog10Of(const Matrix& a); Matrix& InplaceAbs(); Matrix& AssignAbsOf(const Matrix& a); // TODO: rename these to InPlaceFloor() and -Ceil() (I never know what it means to truncate a bottom) // And also document and implement that sparse matrices can only truncate towards 0. Matrix& InplaceTruncateBottom(const ElemType threshold); Matrix& AssignTruncateBottomOf(const Matrix& a, const ElemType threshold); Matrix& InplaceTruncateTop(const ElemType threshold); Matrix& AssignTruncateTopOf(const Matrix& a, const ElemType threshold); Matrix& InplaceTruncate(const ElemType threshold); Matrix& InplaceSoftThreshold(const ElemType threshold); void InplaceTranspose(); Matrix& SetToZeroIfAbsLessThan(const ElemType threshold); DeviceBoundNumber Sum_AsDeviceBoundNum() const; ElemType SumOfAbsElements() const; // sum of all abs(elements) ElemType SumOfElements() const; // sum of all elements Matrix& AssignSumOfElements(const Matrix& a); ElemType LogSumOfElements() const; Matrix& AssignToRowSliceValuesOf(const Matrix& a, const size_t startIndex, const size_t numRows); Matrix& AssignRowSliceValuesOf(const Matrix& a, const size_t startIndex, const size_t numRows); Matrix& AddToRowSliceValuesOf(const Matrix& a, const size_t startIndex, const size_t numRows); Matrix& AddWithRowSliceValuesOf(const Matrix& a, const size_t startIndex, const size_t numRows); // Matrix& AssignRowStackValuesOf(const std::vector*>& inputMatrices, const size_t sliceStartCol, const size_t sliceNumCols); Matrix& AssignRepeatOf(const Matrix& a, const size_t numRowRepeats, const size_t numColRepeats); Matrix& AddToRowRepeatValuesOf(const Matrix& a, const size_t numRepeats); Matrix& AssignPositiveAndShiftedNegSample(const Matrix& a, const size_t posNumber, const size_t negNumber, const size_t shiftNumber); Matrix& AddFoldedPositiveAndShiftedNegSample(const Matrix& a, const size_t posNumber, const size_t negNumber, const size_t shiftNumber); bool IsValid() const; bool IsEqualTo(const Matrix& a, const ElemType threshold = 1e-8) const; static void VectorSum(const Matrix& a, Matrix& c, const bool isColWise); void VectorNorm1(Matrix& c, const bool isColWise) const; Matrix& AssignVectorNorm1Of(Matrix& a, const bool isColWise); // TODO: arg should be const void VectorNorm2(Matrix& c, const bool isColWise) const; Matrix& AssignVectorNorm2Of(Matrix& a, const bool isColWise); // TODO: arg should be const void VectorNormInf(Matrix& c, const bool isColWise) const; Matrix& AssignVectorNormInfOf(Matrix& a, const bool isColWise); Matrix& AssignInnerProductOf(const Matrix& a, const Matrix& b, const bool isColWise); Matrix& AssignKhatriRaoProductOf(const Matrix& a, const Matrix& b); Matrix& AddColumnReshapeProductOf(const Matrix& a, const Matrix& b, const bool transposeAColumn); Matrix& AddWithScaleOf(ElemType alpha, const Matrix& a); // this += alpha * a ElemType FrobeniusNorm() const; Matrix& AssignFrobeniusNormOf(const Matrix& a); ElemType MatrixNormInf() const; ElemType MatrixNorm1() const; ElemType MatrixNorm0() const; // number of non-zero elemets Matrix& AssignSignOf(const Matrix& a); Matrix& AddSignOf(const Matrix& a); void VectorMax(Matrix& maxIndexes, Matrix& maxValues, const bool isColWise) const; void VectorMax(Matrix& maxIndexes, Matrix& maxValues, const bool isColWise, int topK) const; void VectorMin(Matrix& minIndexes, Matrix& minValues, const bool isColWise) const; Matrix& AssignNumOfDiff(const Matrix& a, const Matrix& b, bool searchInCol = false); Matrix& AssignInnerProductOfMatrices(const Matrix& a, const Matrix& b); // this method will resize(1,1) first bool HasNan(const char* name) const; size_t CountNanInf() const; void Print(const char* matrixName, ptrdiff_t rowFirst, ptrdiff_t rowLast, ptrdiff_t colFirst, ptrdiff_t colLast) const; void Print(const char* matrixName = nullptr) const; // print whole matrix. can be expensive Matrix& AssignPackedConvolutionInput(const Matrix& inputSubBatch, const size_t inputWidth, const size_t inputHeight, const size_t inputChannels, const size_t outputWidth, const size_t outputHeight, const size_t outputChannels, const size_t kernelWidth, const size_t kernelHeight, const size_t horizontalSubsample, const size_t verticalSubsample, const bool zeroPadding = false); Matrix& UnpackConvolutionInput(Matrix& inputSubBatch, const size_t inputWidth, const size_t inputHeight, const size_t inputChannels, const size_t outputWidth, const size_t outputHeight, const size_t outputChannels, const size_t kernelWidth, const size_t kernelHeight, const size_t horizontalSubsample, const size_t verticalSubsample, const bool zeroPadding = false) const; Matrix& AssignMaxPoolingResult(const Matrix& inputBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample); Matrix& AddMaxPoolingGradient(const Matrix& outputGradientBatch, const Matrix& inputBatch, const Matrix& outputBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample); Matrix& AssignAveragePoolingResult(const Matrix& inputBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample); Matrix& AddAveragePoolingGradient(const Matrix& outputGradientBatch, const size_t channels, const size_t inputWidth, const size_t inputHeight, const size_t inputSizePerSample, const size_t outputWidth, const size_t outputHeight, const size_t outputSizePerSample, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample); void ConvolutionForward(const Matrix& kernel, const Matrix& mpRowCol, const Matrix& mpRowIwht, const Matrix& mpRowRun, const Matrix& runs, Matrix& output) const; void ConvolutionBackwardData(const Matrix& kernel, const Matrix& mpRowCol, const Matrix& mpRowIwht, const Matrix& mpRowRun, const Matrix& runs, Matrix& grad) const; void ConvolutionBackwardKernel(const Matrix& in, const Matrix& mpRowCol, const Matrix& mpRowIwht, const Matrix& mpRowRun, const Matrix& runs, Matrix& kernelGrad) const; void UnrollConvolutionInput(size_t unrollCols, size_t mapOutSize, const Matrix& mpRowCol, const Matrix& mpRowRun, const Matrix& runs, Matrix& output) const; void UnrollConvolutionOutput(size_t unrollCols, size_t mapInCount, size_t mapOutCount, const Matrix& mpRowCol, const Matrix& mpRowRun, const Matrix& runs, Matrix& output) const; void UnrollConvolutionInputForKernelBackprop(size_t mapOutSize, const Matrix& mpRowCol, const Matrix& mpRowRun, const Matrix& runs, Matrix& output) const; void MaxPoolingForward(const Matrix& mpRowCol, const Matrix& mpRowIndices, const Matrix& indices, Matrix& output) const; void MaxPoolingBackward(const Matrix& out, const Matrix& in, const Matrix& mpRowCol, const Matrix& mpRowIndices, const Matrix& indices, Matrix& grad) const; void ROIPoolingForward(const size_t numRois, const size_t numImg, const size_t channels, const size_t width, const size_t height, const size_t pooledWidth, const size_t pooledHeight, const Matrix& roiData, Matrix& output, Matrix& argmax) const; void ROIPoolingBackward(const size_t numRois, const size_t numImg, const size_t channels, const size_t width, const size_t height, const size_t pooledWidth, const size_t pooledHeight, const Matrix& roiData, Matrix& grad, Matrix& argmax) const; void MaxUnpooling(const Matrix& mpRowCol, const Matrix& mpRowIndices, const Matrix& indices, const Matrix& poolInput, Matrix& input) const; void AveragePoolingForward(const Matrix& mpRowCol, const Matrix& mpRowIndices, const Matrix& indices, Matrix& output) const; void AveragePoolingBackward(const Matrix& mpRowCol, const Matrix& mpRowIndices, const Matrix& indices, Matrix& grad) const; void BatchNormalizationForward(const Matrix& scale, const Matrix& bias, bool inferenceOnly, double expAvgFactor, double blendFactor, Matrix& runMean, Matrix& runVariance, Matrix& out, double epsilon, Matrix& saveMean, Matrix& saveInvStdDev) const; void BatchNormalizationBackward(const Matrix& in, Matrix& grad, const Matrix& scale, double blendFactor, const Matrix& saveMean, const Matrix& saveInvStdDev, Matrix& scaleGrad, Matrix& biasGrad) const; void RNNForward(const Matrix& inputX, const Matrix& paramW, size_t xDim, size_t yDim, const vector& numSequencesForFrame, const struct RnnAttributes& rnnAttributes, Matrix& reserve, Matrix& workspace); void RNNBackwardData(const Matrix& outputDY, const Matrix& paramW, Matrix& outputDX, const struct RnnAttributes& rnnAttributes, Matrix& reserve, Matrix& workspace); void RNNBackwardWeights(const Matrix& inputX, const Matrix& outputY, Matrix& dw, const struct RnnAttributes& rnnAttributes, Matrix& reserve, Matrix& workspace); public: // TODO: why are these not static? And why are they here? ElemType Exp10(ElemType num); ElemType Mod(ElemType x, ElemType y); ElemType LogAdd(ElemType x, ElemType y); public: // static BLAS functions // singular value decomposition of A as A = U*SIGMA*VT static void SVD(const Matrix& A, Matrix& SIGMA, Matrix& U, Matrix& VT, Matrix& W); static void MultiplyAndWeightedAdd(ElemType alpha, const Matrix& a, const bool transposeA, const Matrix& b, const bool transposeB, ElemType beta, Matrix& c, shared_ptr> pQuantizedMultiplier=nullptr); // SGEMM static void MultiplyAndAdd(const Matrix& a, const bool transposeA, const Matrix& b, const bool transposeB, Matrix& c); static void Multiply(const Matrix& a, const bool transposeA, const Matrix& b, const bool transposeB, Matrix& c); static void Multiply(const Matrix& a, const Matrix& b, Matrix& c); static void Multiply1x1AndWeightedAdd(ElemType alpha, const Matrix& a, const Matrix& b, ElemType beta, Matrix& c); static void ConvolveAndWeightedAdd(ElemType alpha, const Matrix& a, const bool transposeA, const Matrix& b, const bool transposeB, ElemType beta, Matrix& c, size_t numChannels, size_t horizontalSubsample, bool padding, bool channelwise); static void ScaleAndAdd(ElemType alpha, const Matrix& a, Matrix& c); static void ScaleAndAdd(ElemType alpha, const Matrix& a, ElemType beta, Matrix& c); static void AddScaledDifference(const ElemType alpha, const Matrix& a, const Matrix& b, Matrix& c); static void AssignScaledDifference(const ElemType alpha, const Matrix& a, const Matrix& b, Matrix& c); static void AddScaledDifference(const Matrix& alpha, const Matrix& a, const Matrix& b, Matrix& c); // c += alpha * (a - b) static void AssignScaledDifference(const Matrix& alpha, const Matrix& a, const Matrix& b, Matrix& c); static void AddElementToElement(const Matrix& a, const size_t ai, const size_t aj, Matrix& c, const size_t ci, const size_t cj); // static void AddLogElementToElement(const Matrix& a, const size_t ai, const size_t aj, Matrix& c, const size_t ci, const size_t cj); static void AssignElementToElement(const Matrix& a, const size_t ai, const size_t aj, Matrix& c, const size_t ci, const size_t cj); static void MinusOneAt(Matrix& c, const size_t position); static void Scale(ElemType alpha, Matrix& a); static void Scale(const Matrix& alpha, Matrix& a); // In this case Matrix alpha must be 1x1 static void Scale(ElemType alpha, const Matrix& a, Matrix& c); static void InnerProduct(const Matrix& a, const Matrix& b, Matrix& c, const bool isColWise); static ElemType InnerProductOfMatrices(const Matrix& a, const Matrix& b); static void ElementWisePower(ElemType alpha, const Matrix& a, Matrix& c); static bool AreEqual(const Matrix& a, const Matrix& b, const ElemType threshold = 1e-8); static bool HasElement(const Matrix& a, const ElemType value = 0.0); static void TensorShuffleScaleAndAdd(ElemType keepWeight, const Matrix& a, size_t D, size_t S, size_t M, size_t K, size_t T, ElemType scaleFactor, const Matrix& b, Matrix& c); void TensorOp(ElemType beta, const Matrix& a, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, const std::array& offsets, const SmallVector& regularOpDims, const std::array, 2>& regularStrides, const SmallVector& reducingOpDims, const std::array, 2>& reducingStrides); void TensorOp(ElemType beta, const Matrix& a, const Matrix& b, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, const std::array& offsets, const SmallVector& regularOpDims, const std::array, 3>& regularStrides, const SmallVector& reducingOpDims, const std::array, 3>& reducingStrides); void TensorOp(ElemType beta, const Matrix& a, const Matrix& b, const Matrix& c, ElemType alpha, ElementWiseOperator op, ElementWiseOperator reductionOp, const std::array& offsets, const SmallVector& regularOpDims, const std::array, 4>& regularStrides, const SmallVector& reducingOpDims, const std::array, 4>& reducingStrides); void TensorArgOp(const Matrix& a, ElementWiseOperator reductionOp, const std::array& offsets, const SmallVector& regularOpDims, const std::array, 2>& regularStrides, const SmallVector& reducingOpDims, const std::array, 2>& reducingStrides); public: void Read(File& stream); void Write(File& stream) const; Matrix& Shift(const Matrix& a, int shift); Matrix& AssignElementProductOfWithShiftNeg(const Matrix& a, const Matrix& b, size_t shift, size_t negnumber); Matrix& AssignInnerProductOfWithShiftNeg(const Matrix& a, const Matrix& b, const bool isColWise, size_t shift, size_t negnumber); static void InnerProductWithShiftNeg(const Matrix& a, const Matrix& b, Matrix& c, const bool isColWise, size_t shift, size_t negnumber); Matrix& GetARowByIndex(const Matrix& a, size_t index); static void ConductRowElementMultiplyWithShift(const Matrix& a, const Matrix& b, Matrix& c, size_t shift, bool bFirstmatrixfixed); Matrix& AssignElementProductOfWithShift(const Matrix& a, const Matrix& b, size_t shift); public: static void RCRFBackwardCompute(const Matrix& alpha, Matrix& beta, Matrix& functionValues, const Matrix& lbls, const Matrix& pos_scores, const Matrix& pair_scores, const int shift); static void RCRFTransGrdCompute(const Matrix& lbls, const Matrix& alpha, const Matrix& beta, const Matrix& pair_scores, Matrix& grd, const int startLbl, // the time 0 start symbol in the output layer const int shift); template friend class MatrixQuantizer; template friend class QuantizedMatrix; template friend class Matrix; }; // overload I/O operators template File& operator>>(File& stream, Matrix& M) { M.Read(stream); return stream; } template File& operator<<(File& stream, const Matrix& M) { M.Write(stream); return stream; } typedef Matrix SingleMatrix; typedef Matrix DoubleMatrix; }}}