15#ifndef OPENJIJ_UTILITY_GPU_CUBLAS_HPP__
16#define OPENJIJ_UTILITY_GPU_CUBLAS_HPP__
21#include <cuda_runtime.h>
31template <
typename FloatType>
struct cudaDataType_impl;
33template <>
struct cudaDataType_impl<float> {
34 constexpr static cudaDataType_t type = CUDA_R_32F;
37template <>
struct cudaDataType_impl<double> {
38 constexpr static cudaDataType_t type = CUDA_R_64F;
42template <
typename FloatType>
43inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle,
int n,
44 const FloatType *x,
int incx,
48inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle,
int n,
49 const float *x,
int incx,
int *result) {
50 return cublasIsamax(handle, n, x, incx, result);
54inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle,
int n,
55 const double *x,
int incx,
57 return cublasIdamax(handle, n, x, incx, result);
61template <
typename FloatType>
63cublas_dot_impl(cublasHandle_t handle,
int n,
const FloatType *x,
int incx,
64 const FloatType *y,
int incy, FloatType *result);
67inline cublasStatus_t cublas_dot_impl(cublasHandle_t handle,
int n,
68 const float *x,
int incx,
const float *y,
69 int incy,
float *result) {
70 return cublasSdot(handle, n, x, incx, y, incy, result);
75cublas_dot_impl(cublasHandle_t handle,
int n,
const double *x,
int incx,
76 const double *y,
int incy,
double *result) {
77 return cublasDdot(handle, n, x, incx, y, incy, result);
87 HANDLE_ERROR_CUBLAS(cublasCreate(&_handle));
89 HANDLE_ERROR_CUBLAS(cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH));
92 CuBLASWrapper(CuBLASWrapper &&obj)
noexcept {
94 this->_handle = obj._handle;
101 HANDLE_ERROR_CUBLAS(cublasDestroy(_handle));
104 template <
typename FloatType>
105 inline void SgemmEx(cublasOperation_t transa, cublasOperation_t transb,
int m,
106 int n,
int k,
const float *alpha,
107 const utility::cuda::unique_dev_ptr<FloatType> &A,
109 const utility::cuda::unique_dev_ptr<FloatType> &B,
110 int ldb,
const float *beta,
111 utility::cuda::unique_dev_ptr<FloatType> &C,
int ldc) {
113 cublasPointerMode_t mode;
114 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
116 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_HOST));
117 HANDLE_ERROR_CUBLAS(cublasSgemmEx(
118 _handle, transa, transb, m, n, k, alpha, A.get(),
119 cudaDataType_impl<
typename std::remove_extent<FloatType>::type>::type,
121 cudaDataType_impl<
typename std::remove_extent<FloatType>::type>::type,
123 cudaDataType_impl<
typename std::remove_extent<FloatType>::type>::type,
125 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
140 template <
typename FloatType>
141 inline void matmul(
int m,
int k,
int n,
142 const utility::cuda::unique_dev_ptr<FloatType> &A,
143 const utility::cuda::unique_dev_ptr<FloatType> &B,
144 utility::cuda::unique_dev_ptr<FloatType> &C,
145 cublasOperation_t transa = CUBLAS_OP_N,
146 cublasOperation_t transb = CUBLAS_OP_N) {
147 typename std::remove_extent<FloatType>::type alpha = 1.0;
148 typename std::remove_extent<FloatType>::type beta = 0;
150 cublasPointerMode_t mode;
151 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
153 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_HOST));
154 HANDLE_ERROR_CUBLAS(cublasSgemmEx(
155 _handle, transa, transb, m, n, k, &alpha, A.get(),
156 cudaDataType_impl<
typename std::remove_extent<FloatType>::type>::type,
158 cudaDataType_impl<
typename std::remove_extent<FloatType>::type>::type,
160 cudaDataType_impl<
typename std::remove_extent<FloatType>::type>::type,
162 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
175 template <
typename FloatType>
176 inline void Iamax(
int n,
const FloatType *x,
int incx,
int *result) {
177 cublasPointerMode_t mode;
178 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
181 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_DEVICE));
182 HANDLE_ERROR_CUBLAS(cublas_Iamax_impl(_handle, n, x, incx, result));
184 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
196 template <
typename FloatType>
198 absmax_val_index(
int n,
const utility::cuda::unique_dev_ptr<FloatType[]> &x,
199 utility::cuda::unique_dev_ptr<
int[]> &result) {
200 Iamax(n, x.get(), 1, result.get());
214 template <
typename FloatType>
215 inline void dot(
int n,
const FloatType *x,
int incx,
const FloatType *y,
216 int incy, FloatType *result) {
217 cublasPointerMode_t mode;
218 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
220 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_DEVICE));
222 HANDLE_ERROR_CUBLAS(cublas_dot_impl(_handle, n, x, incx, y, incy, result));
224 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
236 template <
typename FloatType>
237 inline void dot(
int n,
const utility::cuda::unique_dev_ptr<FloatType[]> &x,
238 const utility::cuda::unique_dev_ptr<FloatType[]> &y,
239 utility::cuda::unique_dev_ptr<FloatType[]> &result) {
240 dot(n, x.get(), 1, y.get(), 1, result.get());
244 cublasHandle_t _handle;
// Definition algorithm.hpp:24