Skip to content

Commit bbc3e97

Browse files
authored
Remove C++ Wrappers in memory_resource_adaptors.hpp Needed by Cython (#662)
Depends on #661 While working on PR #661, it looked like it was possible to remove the "owning" C++ memory resource adaptors in `memory_resource_adaptors.hpp`. This PR is a quick implementation of what that would look like to run through CI and get feedback. The main driving factor of this PR is to eliminate the need for 2 layers of wrappers around every memory resource in the library. When adding new memory resources, C++ wrappers must be created in `memory_resource_adaptors.hpp` and Cython wrappers must be created in `memory_resource.pyx`, for any property/function that needs to be exposed at the python level. This removes the C++ wrappers in favor of using pythons reference counting for lifetime management. A few notes: 1. `MemoryResource` was renamed `DeviceMemoryResource` to more closely match the C++ class names. Easily can be changed back 1. Upstream MR are kept alive by a base class `UpstreamResourceAdaptor` that stores a single property `upstream_mr`. Any MR that has an upstream, needs to inherit from this class. 1. Once the `UpstreamResourceAdaptor` was created, most of the work/changes were updating the Cython imports to use the C++ classes instead of the C++ wrappers. 1. This should make it easier to expose more methods/properties at the python layer in the future. Would appreciate any feedback. Authors: - Michael Demoret (@mdemoret-nv) Approvers: - Christopher Harris (@cwharris) - Keith Kraus (@kkraus14) - Ashwin Srinath (@shwina) URL: #662
1 parent 869d374 commit bbc3e97

File tree

6 files changed

+280
-387
lines changed

6 files changed

+280
-387
lines changed

python/rmm/_lib/device_buffer.pxd

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from libcpp.memory cimport unique_ptr
1615
from libc.stdint cimport uintptr_t
16+
from libcpp.memory cimport unique_ptr
1717

18-
from rmm._lib.cuda_stream_view cimport cuda_stream_view
1918
from rmm._cuda.stream cimport Stream
20-
from rmm._lib.memory_resource cimport MemoryResource
19+
from rmm._lib.cuda_stream_view cimport cuda_stream_view
20+
from rmm._lib.memory_resource cimport DeviceMemoryResource
2121

2222

2323
cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
@@ -39,10 +39,10 @@ cdef extern from "rmm/device_buffer.hpp" namespace "rmm" nogil:
3939
cdef class DeviceBuffer:
4040
cdef unique_ptr[device_buffer] c_obj
4141

42-
# Holds a reference to the MemoryResource used for allocation. Ensures the
43-
# MR does not get destroyed before this DeviceBuffer. `mr` is needed for
44-
# deallocation
45-
cdef MemoryResource mr
42+
# Holds a reference to the DeviceMemoryResource used for allocation.
43+
# Ensures the MR does not get destroyed before this DeviceBuffer. `mr` is
44+
# needed for deallocation
45+
cdef DeviceMemoryResource mr
4646

4747
# Holds a reference to the stream used by the underlying `device_buffer`.
4848
# Ensures the stream does not get destroyed before this DeviceBuffer
Lines changed: 31 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,50 @@
11
# Copyright (c) 2020, NVIDIA CORPORATION.
22

33
from libc.stdint cimport int8_t
4-
from libcpp.vector cimport vector
5-
from libcpp.string cimport string
64
from libcpp.memory cimport shared_ptr
5+
from libcpp.string cimport string
6+
from libcpp.vector cimport vector
7+
8+
9+
cdef extern from "rmm/mr/device/device_memory_resource.hpp" \
10+
namespace "rmm::mr" nogil:
11+
cdef cppclass device_memory_resource:
12+
pass
713

14+
cdef class DeviceMemoryResource:
15+
cdef shared_ptr[device_memory_resource] c_obj
816

9-
cdef extern from "memory_resource_wrappers.hpp" nogil:
10-
cdef cppclass device_memory_resource_wrapper:
11-
shared_ptr[device_memory_resource_wrapper] get_mr() except +
12-
13-
cdef cppclass default_memory_resource_wrapper(
14-
device_memory_resource_wrapper
15-
):
16-
default_memory_resource_wrapper(int device) except +
17-
18-
cdef cppclass cuda_memory_resource_wrapper(device_memory_resource_wrapper):
19-
cuda_memory_resource_wrapper() except +
20-
21-
cdef cppclass managed_memory_resource_wrapper(
22-
device_memory_resource_wrapper
23-
):
24-
managed_memory_resource_wrapper() except +
25-
26-
cdef cppclass pool_memory_resource_wrapper(device_memory_resource_wrapper):
27-
pool_memory_resource_wrapper(
28-
shared_ptr[device_memory_resource_wrapper] upstream_mr,
29-
size_t initial_pool_size,
30-
size_t maximum_pool_size
31-
) except +
32-
33-
cdef cppclass fixed_size_memory_resource_wrapper(
34-
device_memory_resource_wrapper
35-
):
36-
fixed_size_memory_resource_wrapper(
37-
shared_ptr[device_memory_resource_wrapper] upstream_mr,
38-
size_t block_size,
39-
size_t blocks_to_preallocate
40-
) except +
41-
42-
cdef cppclass binning_memory_resource_wrapper(
43-
device_memory_resource_wrapper
44-
):
45-
binning_memory_resource_wrapper(
46-
shared_ptr[device_memory_resource_wrapper] upstream_mr
47-
) except +
48-
binning_memory_resource_wrapper(
49-
shared_ptr[device_memory_resource_wrapper] upstream_mr,
50-
int8_t min_size_exponent,
51-
int8_t max_size_exponent
52-
) except +
53-
void add_bin(
54-
size_t allocation_size,
55-
shared_ptr[device_memory_resource_wrapper] bin_mr
56-
) except +
57-
void add_bin(
58-
size_t allocation_size
59-
) except +
60-
61-
cdef cppclass logging_resource_adaptor_wrapper(
62-
device_memory_resource_wrapper
63-
):
64-
logging_resource_adaptor_wrapper(
65-
shared_ptr[device_memory_resource_wrapper] upstream_mr,
66-
string filename
67-
) except +
68-
void flush() except +
69-
70-
cdef cppclass thread_safe_resource_adaptor_wrapper(
71-
device_memory_resource_wrapper
72-
):
73-
thread_safe_resource_adaptor_wrapper(
74-
shared_ptr[device_memory_resource_wrapper] upstream_mr,
75-
) except +
76-
77-
void set_per_device_resource(
78-
int device,
79-
shared_ptr[device_memory_resource_wrapper] new_resource
80-
) except +
81-
82-
83-
cdef class MemoryResource:
84-
cdef shared_ptr[device_memory_resource_wrapper] c_obj
85-
86-
cdef class CudaMemoryResource(MemoryResource):
17+
cdef device_memory_resource* get_mr(self)
18+
19+
cdef class UpstreamResourceAdaptor(DeviceMemoryResource):
20+
cdef readonly DeviceMemoryResource upstream_mr
21+
22+
cpdef DeviceMemoryResource get_upstream(self)
23+
24+
cdef class CudaMemoryResource(DeviceMemoryResource):
8725
pass
8826

89-
cdef class ManagedMemoryResource(MemoryResource):
27+
cdef class ManagedMemoryResource(DeviceMemoryResource):
9028
pass
9129

92-
cdef class PoolMemoryResource(MemoryResource):
30+
cdef class PoolMemoryResource(UpstreamResourceAdaptor):
9331
pass
9432

95-
cdef class FixedSizeMemoryResource(MemoryResource):
33+
cdef class FixedSizeMemoryResource(UpstreamResourceAdaptor):
9634
pass
9735

98-
cdef class BinningMemoryResource(MemoryResource):
99-
cpdef add_bin(self, size_t allocation_size, object bin_resource=*)
36+
cdef class BinningMemoryResource(UpstreamResourceAdaptor):
37+
38+
cdef readonly list bin_mrs
39+
40+
cpdef add_bin(
41+
self,
42+
size_t allocation_size,
43+
DeviceMemoryResource bin_resource=*)
10044

101-
cdef class LoggingResourceAdaptor(MemoryResource):
102-
cdef MemoryResource upstream
45+
cdef class LoggingResourceAdaptor(UpstreamResourceAdaptor):
10346
cdef object _log_file_name
104-
cpdef MemoryResource get_upstream(self)
10547
cpdef get_file_name(self)
10648
cpdef flush(self)
10749

108-
cpdef MemoryResource get_current_device_resource()
50+
cpdef DeviceMemoryResource get_current_device_resource()

0 commit comments

Comments
 (0)