openjij
Framework for the Ising model and QUBO.
Loading...
Searching...
No Matches
cublas.hpp
Go to the documentation of this file.
1// Copyright 2023 Jij Inc.
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef OPENJIJ_UTILITY_GPU_CUBLAS_HPP__
16#define OPENJIJ_UTILITY_GPU_CUBLAS_HPP__
17
18#ifdef USE_CUDA
19
20#include <cublas_v2.h>
21#include <cuda_runtime.h>
22
25
26namespace openjij {
27namespace utility {
28namespace cuda {
29
30// cuda datatype
31template <typename FloatType> struct cudaDataType_impl;
32
33template <> struct cudaDataType_impl<float> {
34 constexpr static cudaDataType_t type = CUDA_R_32F;
35};
36
37template <> struct cudaDataType_impl<double> {
38 constexpr static cudaDataType_t type = CUDA_R_64F;
39};
40
41// cublas get maximal value
42template <typename FloatType>
43inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle, int n,
44 const FloatType *x, int incx,
45 int *result);
46
47template <>
48inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle, int n,
49 const float *x, int incx, int *result) {
50 return cublasIsamax(handle, n, x, incx, result);
51}
52
53template <>
54inline cublasStatus_t cublas_Iamax_impl(cublasHandle_t handle, int n,
55 const double *x, int incx,
56 int *result) {
57 return cublasIdamax(handle, n, x, incx, result);
58}
59
60// cublas dot product
61template <typename FloatType>
62inline cublasStatus_t
63cublas_dot_impl(cublasHandle_t handle, int n, const FloatType *x, int incx,
64 const FloatType *y, int incy, FloatType *result);
65
66template <>
67inline cublasStatus_t cublas_dot_impl(cublasHandle_t handle, int n,
68 const float *x, int incx, const float *y,
69 int incy, float *result) {
70 return cublasSdot(handle, n, x, incx, y, incy, result);
71}
72
73template <>
74inline cublasStatus_t
75cublas_dot_impl(cublasHandle_t handle, int n, const double *x, int incx,
76 const double *y, int incy, double *result) {
77 return cublasDdot(handle, n, x, incx, y, incy, result);
78}
79
83class CuBLASWrapper {
84public:
85 CuBLASWrapper() {
86 // generate cuBLAS instance
87 HANDLE_ERROR_CUBLAS(cublasCreate(&_handle));
88 // use tensor core if possible
89 HANDLE_ERROR_CUBLAS(cublasSetMathMode(_handle, CUBLAS_TENSOR_OP_MATH));
90 }
91
92 CuBLASWrapper(CuBLASWrapper &&obj) noexcept {
93 // move cuBLAS handler
94 this->_handle = obj._handle;
95 obj._handle = NULL;
96 }
97
98 ~CuBLASWrapper() {
99 // destroy generator
100 if (_handle != NULL)
101 HANDLE_ERROR_CUBLAS(cublasDestroy(_handle));
102 }
103
104 template <typename FloatType>
105 inline void SgemmEx(cublasOperation_t transa, cublasOperation_t transb, int m,
106 int n, int k, const float *alpha,
107 const utility::cuda::unique_dev_ptr<FloatType> &A,
108 int lda,
109 const utility::cuda::unique_dev_ptr<FloatType> &B,
110 int ldb, const float *beta,
111 utility::cuda::unique_dev_ptr<FloatType> &C, int ldc) {
112
113 cublasPointerMode_t mode;
114 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
115 HANDLE_ERROR_CUBLAS(
116 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_HOST));
117 HANDLE_ERROR_CUBLAS(cublasSgemmEx(
118 _handle, transa, transb, m, n, k, alpha, A.get(),
119 cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
120 lda, B.get(),
121 cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
122 ldb, beta, C.get(),
123 cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
124 ldc));
125 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
126 }
127
140 template <typename FloatType>
141 inline void matmul(int m, int k, int n,
142 const utility::cuda::unique_dev_ptr<FloatType> &A,
143 const utility::cuda::unique_dev_ptr<FloatType> &B,
144 utility::cuda::unique_dev_ptr<FloatType> &C,
145 cublasOperation_t transa = CUBLAS_OP_N,
146 cublasOperation_t transb = CUBLAS_OP_N) {
147 typename std::remove_extent<FloatType>::type alpha = 1.0;
148 typename std::remove_extent<FloatType>::type beta = 0;
149
150 cublasPointerMode_t mode;
151 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
152 HANDLE_ERROR_CUBLAS(
153 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_HOST));
154 HANDLE_ERROR_CUBLAS(cublasSgemmEx(
155 _handle, transa, transb, m, n, k, &alpha, A.get(),
156 cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
157 m, B.get(),
158 cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
159 k, &beta, C.get(),
160 cudaDataType_impl<typename std::remove_extent<FloatType>::type>::type,
161 m));
162 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
163 }
164
175 template <typename FloatType>
176 inline void Iamax(int n, const FloatType *x, int incx, int *result) {
177 cublasPointerMode_t mode;
178 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
179 // set pointermode to device
180 HANDLE_ERROR_CUBLAS(
181 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_DEVICE));
182 HANDLE_ERROR_CUBLAS(cublas_Iamax_impl(_handle, n, x, incx, result));
183 // reset pointermode
184 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
185 }
186
196 template <typename FloatType>
197 inline void
198 absmax_val_index(int n, const utility::cuda::unique_dev_ptr<FloatType[]> &x,
199 utility::cuda::unique_dev_ptr<int[]> &result) {
200 Iamax(n, x.get(), 1, result.get());
201 }
202
214 template <typename FloatType>
215 inline void dot(int n, const FloatType *x, int incx, const FloatType *y,
216 int incy, FloatType *result) {
217 cublasPointerMode_t mode;
218 HANDLE_ERROR_CUBLAS(cublasGetPointerMode(_handle, &mode));
219 HANDLE_ERROR_CUBLAS(
220 cublasSetPointerMode(_handle, CUBLAS_POINTER_MODE_DEVICE));
221 // set pointermode to device
222 HANDLE_ERROR_CUBLAS(cublas_dot_impl(_handle, n, x, incx, y, incy, result));
223 // reset pointermode
224 HANDLE_ERROR_CUBLAS(cublasSetPointerMode(_handle, mode));
225 }
226
236 template <typename FloatType>
237 inline void dot(int n, const utility::cuda::unique_dev_ptr<FloatType[]> &x,
238 const utility::cuda::unique_dev_ptr<FloatType[]> &y,
239 utility::cuda::unique_dev_ptr<FloatType[]> &result) {
240 dot(n, x.get(), 1, y.get(), 1, result.get());
241 }
242
243private:
244 cublasHandle_t _handle;
245};
246
247} // namespace cuda
248} // namespace utility
249} // namespace openjij
250
251#endif
252#endif
Definition algorithm.hpp:24