2020#include < rmm/mr/device/cuda_memory_resource.hpp>
2121#include < rmm/mr/device/default_memory_resource.hpp>
2222#include < rmm/mr/device/managed_memory_resource.hpp>
23+ #include < rmm/mr/device/logging_resource_adaptor.hpp>
2324
2425namespace rmm {
26+
27+ using cuda_mr = rmm::mr::cuda_memory_resource;
28+ using pool_mr = rmm::mr::cnmem_memory_resource;
29+ using managed_mr = rmm::mr::cnmem_managed_memory_resource;
30+ using pool_managed_mr = rmm::mr::cnmem_managed_memory_resource;
31+ using logging_pool_mr = rmm::mr::logging_resource_adaptor<pool_mr>;
32+ using logging_pool_managed_mr = rmm::mr::logging_resource_adaptor<pool_managed_mr>;
33+
2534/* *
2635 * Record a memory manager event in the log.
2736 *
@@ -93,6 +102,25 @@ rmmError_t Manager::registerStream(cudaStream_t stream) {
93102 return RMM_SUCCESS;
94103}
95104
105+ // reset the initialized resource, optionally enabling logging via logging_resource_adaptor
106+ // note this is a template function so we avoid a vtable lookup on every logging allocation
107+ // That means it needs to be a free function since memory_manager.hpp cannot include
108+ // `logging_resource_adaptor.hpp` (since it depends on spdlog)
109+ template <typename MemoryResource>
110+ void reset_resource (std::unique_ptr<mr::device_memory_resource>& initialized_resource,
111+ std::unique_ptr<mr::device_memory_resource>& logging_resource,
112+ MemoryResource *mr,
113+ bool enable_logging) {
114+ initialized_resource.reset (mr);
115+ if (enable_logging) {
116+ auto lmr = new rmm::mr::logging_resource_adaptor<MemoryResource>(mr);
117+ logging_resource.reset (lmr);
118+ rmm::mr::set_default_resource (lmr);
119+ } else {
120+ rmm::mr::set_default_resource (mr);
121+ }
122+ }
123+
96124// Initialize the manager
97125void Manager::initialize (const rmmOptions_t* new_options) {
98126 std::lock_guard<std::mutex> guard (manager_mutex);
@@ -102,20 +130,23 @@ void Manager::initialize(const rmmOptions_t* new_options) {
102130
103131 if (nullptr != new_options) options = *new_options;
104132
133+ bool enable_logging = getOptions ().enable_logging ;
134+
105135 if (usePoolAllocator ()) {
106136 auto pool_size = getOptions ().initial_pool_size ;
107137 auto const & devices = getOptions ().devices ;
138+
108139 if (useManagedMemory ()) {
109- initialized_resource. reset (
110- new rmm::mr::cnmem_managed_memory_resource (pool_size, devices));
140+ reset_resource (initialized_resource, logging_adaptor,
141+ new pool_managed_mr (pool_size, devices), enable_logging );
111142 } else {
112- initialized_resource. reset (
113- new rmm::mr::cnmem_memory_resource (pool_size, devices));
143+ reset_resource (initialized_resource, logging_adaptor,
144+ new pool_mr (pool_size, devices), enable_logging );
114145 }
115146 } else if (rmm::Manager::useManagedMemory ()) {
116- initialized_resource. reset ( new rmm::mr::managed_memory_resource () );
147+ reset_resource (initialized_resource, logging_adaptor, new managed_mr (), enable_logging );
117148 } else {
118- initialized_resource. reset ( new rmm::mr::cuda_memory_resource () );
149+ reset_resource (initialized_resource, logging_adaptor, new cuda_mr (), enable_logging );
119150 }
120151
121152 rmm::mr::set_default_resource (initialized_resource.get ());
0 commit comments