Skip to content

Commit e521e95

Browse files
committed
Remove all RMM singletons except for the default resource, fix rmmInitialize/rmmFinalize
1 parent 6fcfb45 commit e521e95

File tree

3 files changed

+73
-159
lines changed

3 files changed

+73
-159
lines changed

include/rmm/mr/default_memory_resource.hpp

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -54,69 +54,5 @@ device_memory_resource* get_default_resource();
5454
device_memory_resource* set_default_resource(
5555
device_memory_resource* new_resource);
5656

57-
namespace detail{
58-
59-
60-
/**---------------------------------------------------------------------------*
61-
* @brief gets the default memory_resource when none is set
62-
*
63-
* A static function which will return a singleton cuda_memory_resource
64-
*
65-
*
66-
* @return device_memory_resource* a pointer to the static
67-
* cuda_memory_resource
68-
*---------------------------------------------------------------------------**/
69-
device_memory_resource* initial_resource();
70-
71-
/**---------------------------------------------------------------------------*
72-
* @brief gets a cuda_memory_resource
73-
*
74-
* A static function which will return a singleton cuda_memory_resource
75-
*
76-
*
77-
*
78-
* @return device_memory_resource* a pointer to the static
79-
* cuda_memory_resource
80-
*---------------------------------------------------------------------------**/
81-
device_memory_resource* cuda_resource();
82-
83-
/**---------------------------------------------------------------------------*
84-
* @brief gets a cnmem_memory_resource
85-
*
86-
* A static function which will return a singleton cnmem_memory_resource which
87-
* manages a pool
88-
*
89-
*
90-
* @param pool_size The initial size of the pool
91-
* @return device_memory_resource* a pointer to the static
92-
* cnmem_memory_resource
93-
*---------------------------------------------------------------------------**/
94-
device_memory_resource* pool_resource(std::size_t pool_size = 0);
95-
96-
/**---------------------------------------------------------------------------*
97-
* @brief gets a cnmem_managed_memory_resource
98-
*
99-
* A static function which will return a singleton cnmem_memory_resource which
100-
* manages a pool of UVM memory
101-
*
102-
* @param pool_size The initial size of the pool
103-
* @return device_memory_resource* a pointer to the static
104-
* cnmem_managed_memory_resource
105-
*---------------------------------------------------------------------------**/
106-
device_memory_resource* managed_pool_resource(std::size_t pool_size = 0);
107-
108-
109-
/**---------------------------------------------------------------------------*
110-
* @brief gets a managed_memory_resource
111-
*
112-
* A static function that returns a singleton managed_memory_resource.
113-
*
114-
*
115-
*
116-
* @return device_memory_resource* a pointer to the static
117-
* managed_memory_resource
118-
*---------------------------------------------------------------------------**/
119-
device_memory_resource* managed_resource();
120-
} // namespace detail
12157
} // namespace mr
12258
} // namespace rmm

src/mr/default_memory_resource.cpp

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -25,49 +25,20 @@
2525
namespace rmm {
2626
namespace mr {
2727

28-
namespace detail{
29-
device_memory_resource* initial_resource() {
30-
return cuda_resource();
31-
}
32-
33-
device_memory_resource* cuda_resource() {
34-
static cuda_memory_resource resource{};
35-
return &resource;
36-
}
37-
38-
device_memory_resource* pool_resource(std::size_t pool_size){
39-
static cnmem_memory_resource mr{pool_size};
40-
return &mr;
41-
}
42-
43-
device_memory_resource* managed_pool_resource(std::size_t pool_size){
44-
static cnmem_managed_memory_resource mr{pool_size};
45-
return &mr;
46-
}
47-
48-
device_memory_resource* managed_resource(){
49-
static managed_memory_resource mr{};
50-
return &mr;
51-
}
52-
} //namespace detail
5328
namespace {
54-
55-
56-
57-
58-
5929
// Use an atomic to guarantee thread safety
6030
std::atomic<device_memory_resource*>& get_default() {
61-
static std::atomic<device_memory_resource*> res{detail::initial_resource()};
31+
static std::atomic<device_memory_resource*> res{new cuda_memory_resource{}};
6232
return res;
6333
}
64-
} // namespace
34+
} // namespace anonymous
6535

6636
device_memory_resource* get_default_resource() { return get_default().load(); }
6737

6838
device_memory_resource* set_default_resource(
69-
device_memory_resource* new_resource) {
70-
new_resource = (new_resource == nullptr) ? detail::initial_resource() : new_resource;
39+
device_memory_resource* new_resource) {
40+
new_resource = (new_resource == nullptr) ?
41+
new cuda_memory_resource() : new_resource;
7142
return get_default().exchange(new_resource);
7243
}
7344
} // namespace mr

src/rmm.cpp

Lines changed: 68 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -59,39 +59,48 @@ const char * rmmGetErrorString(rmmError_t errcode) {
5959
// Initialize memory manager state and storage.
6060
rmmError_t rmmInitialize(rmmOptions_t *options)
6161
{
62+
rmm::Manager::getInstance().initialize(options);
63+
64+
rmm::mr::device_memory_resource * memory_resource = nullptr;
65+
66+
if (rmm::Manager::usePoolAllocator()) {
67+
if (rmm::Manager::useManagedMemory()) {
68+
memory_resource = new rmm::mr::cnmem_managed_memory_resource{
69+
rmm::Manager::getOptions().initial_pool_size};
70+
} else {
71+
memory_resource = new rmm::mr::cnmem_memory_resource{
72+
rmm::Manager::getOptions().initial_pool_size};
73+
}
74+
} else if (rmm::Manager::useManagedMemory()) {
75+
memory_resource = new rmm::mr::managed_memory_resource();
76+
}
6277

63-
rmm::Manager::getInstance().initialize(options);
64-
rmm::mr::device_memory_resource * memory_resource = rmm::mr::detail::initial_resource();
65-
66-
if (rmm::Manager::usePoolAllocator())
67-
{
68-
memory_resource = rmm::mr::detail::pool_resource(rmm::Manager::getOptions().initial_pool_size);
69-
70-
}else if(rmm::Manager::useManagedMemory()){
71-
memory_resource = rmm::mr::detail::managed_resource();
72-
73-
}else if(rmm::Manager::useManagedMemory() && rmm::Manager::usePoolAllocator()){
74-
memory_resource = rmm::mr::detail::managed_pool_resource(rmm::Manager::getOptions().initial_pool_size);
75-
}
76-
rmm::mr::set_default_resource(memory_resource);
78+
rmm::mr::set_default_resource(memory_resource);
7779

78-
return RMM_SUCCESS;
80+
return RMM_SUCCESS;
7981
}
8082

8183
// Shutdown memory manager.
8284
rmmError_t rmmFinalize()
8385
{
84-
rmm::Manager::getInstance().finalize();
85-
return RMM_SUCCESS;
86+
// delete the current default resource to reset it, and replace with the
87+
// initial resource, so that subsequent calls to initialize will act
88+
// just like the first call to initialize
89+
rmm::mr::device_memory_resource *mr = rmm::mr::get_default_resource();
90+
rmm::mr::set_default_resource(nullptr);
91+
delete mr;
92+
93+
rmm::Manager::getInstance().finalize();
94+
return RMM_SUCCESS;
8695
}
8796

8897
// Query the initialization state of RMM.
8998
bool rmmIsInitialized(rmmOptions_t *options)
9099
{
91-
if (nullptr != options) {
92-
*options = rmm::Manager::getOptions();
93-
}
94-
return rmm::Manager::getInstance().isInitialized();
100+
if (nullptr != options) {
101+
*options = rmm::Manager::getOptions();
102+
}
103+
return rmm::Manager::getInstance().isInitialized();
95104
}
96105

97106
// Allocate memory and return a pointer to device memory.
@@ -109,72 +118,70 @@ rmmError_t rmmAlloc(void **ptr, size_t size, cudaStream_t stream, const char* fi
109118
// Release device memory and recycle the associated memory.
110119
rmmError_t rmmFree(void *ptr, cudaStream_t stream, const char* file, unsigned int line)
111120
{
112-
return rmm::free(ptr, stream, file, line);
121+
return rmm::free(ptr, stream, file, line);
113122
}
114123

115124
// Get the offset of ptr from its base allocation
116125
rmmError_t rmmGetAllocationOffset(ptrdiff_t *offset,
117126
void *ptr,
118127
cudaStream_t stream)
119128
{
120-
void *base = (void*)0xffffffff;
121-
CUresult res = cuMemGetAddressRange((CUdeviceptr*)&base, nullptr,
122-
(CUdeviceptr)ptr);
123-
if (res != CUDA_SUCCESS)
124-
return RMM_ERROR_INVALID_ARGUMENT;
125-
*offset = reinterpret_cast<ptrdiff_t>(ptr) -
126-
reinterpret_cast<ptrdiff_t>(base);
127-
return RMM_SUCCESS;
129+
void *base = (void*)0xffffffff;
130+
CUresult res = cuMemGetAddressRange((CUdeviceptr*)&base, nullptr,
131+
(CUdeviceptr)ptr);
132+
if (res != CUDA_SUCCESS)
133+
return RMM_ERROR_INVALID_ARGUMENT;
134+
*offset = reinterpret_cast<ptrdiff_t>(ptr) -
135+
reinterpret_cast<ptrdiff_t>(base);
136+
return RMM_SUCCESS;
128137
}
129138

130139
// Get amounts of free and total memory managed by a manager associated
131140
// with the stream.
132141
rmmError_t rmmGetInfo(size_t *freeSize, size_t *totalSize, cudaStream_t stream)
133142
{
134-
try{
135-
std::pair<size_t,size_t> memInfo = rmm::mr::get_default_resource()->get_mem_info( stream);
136-
*freeSize = memInfo.first;
137-
*totalSize = memInfo.second;
138-
}catch(std::runtime_error){
139-
return RMM_ERROR_CUDA_ERROR;
140-
}
141-
return RMM_SUCCESS;
143+
try {
144+
std::pair<size_t,size_t> memInfo = rmm::mr::get_default_resource()->get_mem_info( stream);
145+
*freeSize = memInfo.first;
146+
*totalSize = memInfo.second;
147+
} catch(std::runtime_error){
148+
return RMM_ERROR_CUDA_ERROR;
149+
}
150+
return RMM_SUCCESS;
142151
}
143152

144153
// Write the memory event stats log to specified path/filename
145154
rmmError_t rmmWriteLog(const char* filename)
146155
{
147-
try
148-
{
149-
std::ofstream csv;
150-
csv.open(filename);
151-
rmm::Manager::getLogger().to_csv(csv);
152-
}
153-
catch (const std::ofstream::failure& e) {
154-
return RMM_ERROR_IO;
155-
}
156-
return RMM_SUCCESS;
156+
try {
157+
std::ofstream csv;
158+
csv.open(filename);
159+
rmm::Manager::getLogger().to_csv(csv);
160+
}
161+
catch (const std::ofstream::failure& e) {
162+
return RMM_ERROR_IO;
163+
}
164+
return RMM_SUCCESS;
157165
}
158166

159167
// Get the size opf the CSV log
160168
size_t rmmLogSize()
161169
{
162-
std::ostringstream csv;
163-
rmm::Manager::getLogger().to_csv(csv);
164-
return csv.str().size();
170+
std::ostringstream csv;
171+
rmm::Manager::getLogger().to_csv(csv);
172+
return csv.str().size();
165173
}
166174

167175
// Get the CSV log as a string
168176
rmmError_t rmmGetLog(char *buffer, size_t buffer_size)
169177
{
170-
try
171-
{
172-
std::ostringstream csv;
173-
rmm::Manager::getLogger().to_csv(csv);
174-
csv.str().copy(buffer, std::min(buffer_size, csv.str().size()));
175-
}
176-
catch (const std::ofstream::failure& e) {
177-
return RMM_ERROR_IO;
178-
}
179-
return RMM_SUCCESS;
178+
try {
179+
std::ostringstream csv;
180+
rmm::Manager::getLogger().to_csv(csv);
181+
csv.str().copy(buffer, std::min(buffer_size, csv.str().size()));
182+
}
183+
catch (const std::ofstream::failure& e) {
184+
return RMM_ERROR_IO;
185+
}
186+
return RMM_SUCCESS;
180187
}

0 commit comments

Comments
 (0)