openjij
Framework for the Ising model and QUBO.
Loading...
Searching...
No Matches
gpu/memory.hpp
Go to the documentation of this file.
1// Copyright 2023 Jij Inc.
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#ifndef OPENJIJ_UTILITY_GPU_MEMORY_HPP__
16#define OPENJIJ_UTILITY_GPU_MEMORY_HPP__
17
18#ifdef USE_CUDA
19
20#include <memory>
21#include <type_traits>
22
23#include <cuda_runtime.h>
24
26
27namespace openjij {
28namespace utility {
29namespace cuda {
30
34struct deleter_dev {
35 void operator()(void *ptr) const { HANDLE_ERROR_CUDA(cudaFree(ptr)); }
36};
37
41struct deleter_host {
42 void operator()(void *ptr) const { HANDLE_ERROR_CUDA(cudaFreeHost(ptr)); }
43};
44
50template <typename T> using unique_dev_ptr = std::unique_ptr<T, deleter_dev>;
51
57template <typename T> using unique_host_ptr = std::unique_ptr<T, deleter_host>;
58
67template <typename T> cuda::unique_dev_ptr<T> make_dev_unique(std::size_t n) {
68 static_assert(std::is_array<T>::value, "T must be an array.");
69 using U = typename std::remove_extent<T>::type;
70 U *ptr;
71 HANDLE_ERROR_CUDA(cudaMalloc(reinterpret_cast<void **>(&ptr), sizeof(U) * n));
72 return cuda::unique_dev_ptr<T>{ptr};
73}
74
83template <typename T> cuda::unique_host_ptr<T> make_host_unique(std::size_t n) {
84 static_assert(std::is_array<T>::value, "T must be an array.");
85 using U = typename std::remove_extent<T>::type;
86 U *ptr;
87 HANDLE_ERROR_CUDA(
88 cudaMallocHost(reinterpret_cast<void **>(&ptr), sizeof(U) * n));
89 return cuda::unique_host_ptr<T>{ptr};
90}
91
92} // namespace cuda
93} // namespace utility
94} // namespace openjij
95
96#endif
97#endif
Definition algorithm.hpp:24