Revision 34b3a7f106baa9099802f246fda6ad3476f32d26 authored by Qiao Zhang on 16 February 2022, 19:26:23 UTC, committed by jax authors on 16 February 2022, 19:41:41 UTC
1 parent e1fd630
Raw File
cusolver_kernels.h
/* Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_CUSOLVER_KERNELS_H_
#define JAXLIB_CUSOLVER_KERNELS_H_

#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "jaxlib/handle_pool.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"

namespace jax {

using SolverHandlePool = HandlePool<cusolverDnHandle_t, cudaStream_t>;

template <>
absl::StatusOr<SolverHandlePool::Handle> SolverHandlePool::Borrow(
    cudaStream_t stream);

// Set of types known to Cusolver.
enum class CusolverType {
  F32,
  F64,
  C64,
  C128,
};

// potrf: Cholesky decomposition

struct PotrfDescriptor {
  CusolverType type;
  cublasFillMode_t uplo;
  std::int64_t batch, n;
  int lwork;
};

void Potrf(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);
// getrf: LU decomposition

struct GetrfDescriptor {
  CusolverType type;
  int batch, m, n;
};

void Getrf(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

// geqrf: QR decomposition

struct GeqrfDescriptor {
  CusolverType type;
  int batch, m, n, lwork;
};

void Geqrf(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

// orgqr/ungqr: apply elementary Householder transformations

struct OrgqrDescriptor {
  CusolverType type;
  int batch, m, n, k, lwork;
};

void Orgqr(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

// Symmetric (Hermitian) eigendecomposition, QR algorithm: syevd/heevd

struct SyevdDescriptor {
  CusolverType type;
  cublasFillMode_t uplo;
  int batch, n;
  int lwork;
};

void Syevd(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

// Symmetric (Hermitian) eigendecomposition, Jacobi algorithm: syevj/heevj
// Supports batches of matrices up to size 32.

struct SyevjDescriptor {
  CusolverType type;
  cublasFillMode_t uplo;
  int batch, n;
  int lwork;
};

void Syevj(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

// Singular value decomposition using QR algorithm: gesvd

struct GesvdDescriptor {
  CusolverType type;
  int batch, m, n;
  int lwork;
  signed char jobu, jobvt;
};

void Gesvd(cudaStream_t stream, void** buffers, const char* opaque,
           size_t opaque_len, XlaCustomCallStatus* status);

// Singular value decomposition using Jacobi algorithm: gesvdj

struct GesvdjDescriptor {
  CusolverType type;
  int batch, m, n;
  int lwork;
  cusolverEigMode_t jobz;
};

void Gesvdj(cudaStream_t stream, void** buffers, const char* opaque,
            size_t opaque_len, XlaCustomCallStatus* status);

}  // namespace jax

#endif  // JAXLIB_CUSOLVER_KERNELS_H_

back to top