@@ -59,39 +59,48 @@ const char * rmmGetErrorString(rmmError_t errcode) {
5959// Initialize memory manager state and storage.
6060rmmError_t rmmInitialize (rmmOptions_t *options)
6161{
62+ rmm::Manager::getInstance ().initialize (options);
63+
64+ rmm::mr::device_memory_resource * memory_resource = nullptr ;
65+
66+ if (rmm::Manager::usePoolAllocator ()) {
67+ if (rmm::Manager::useManagedMemory ()) {
68+ memory_resource = new rmm::mr::cnmem_managed_memory_resource{
69+ rmm::Manager::getOptions ().initial_pool_size };
70+ } else {
71+ memory_resource = new rmm::mr::cnmem_memory_resource{
72+ rmm::Manager::getOptions ().initial_pool_size };
73+ }
74+ } else if (rmm::Manager::useManagedMemory ()) {
75+ memory_resource = new rmm::mr::managed_memory_resource ();
76+ }
6277
63- rmm::Manager::getInstance ().initialize (options);
64- rmm::mr::device_memory_resource * memory_resource = rmm::mr::detail::initial_resource ();
65-
66- if (rmm::Manager::usePoolAllocator ())
67- {
68- memory_resource = rmm::mr::detail::pool_resource (rmm::Manager::getOptions ().initial_pool_size );
69-
70- }else if (rmm::Manager::useManagedMemory ()){
71- memory_resource = rmm::mr::detail::managed_resource ();
72-
73- }else if (rmm::Manager::useManagedMemory () && rmm::Manager::usePoolAllocator ()){
74- memory_resource = rmm::mr::detail::managed_pool_resource (rmm::Manager::getOptions ().initial_pool_size );
75- }
76- rmm::mr::set_default_resource (memory_resource);
78+ rmm::mr::set_default_resource (memory_resource);
7779
78- return RMM_SUCCESS;
80+ return RMM_SUCCESS;
7981}
8082
8183// Shutdown memory manager.
8284rmmError_t rmmFinalize ()
8385{
84- rmm::Manager::getInstance ().finalize ();
85- return RMM_SUCCESS;
86+ // delete the current default resource to reset it, and replace with the
87+ // initial resource, so that subsequent calls to initialize will act
88+ // just like the first call to initialize
89+ rmm::mr::device_memory_resource *mr = rmm::mr::get_default_resource ();
90+ rmm::mr::set_default_resource (nullptr );
91+ delete mr;
92+
93+ rmm::Manager::getInstance ().finalize ();
94+ return RMM_SUCCESS;
8695}
8796
8897// Query the initialization state of RMM.
8998bool rmmIsInitialized (rmmOptions_t *options)
9099{
91- if (nullptr != options) {
92- *options = rmm::Manager::getOptions ();
93- }
94- return rmm::Manager::getInstance ().isInitialized ();
100+ if (nullptr != options) {
101+ *options = rmm::Manager::getOptions ();
102+ }
103+ return rmm::Manager::getInstance ().isInitialized ();
95104}
96105
97106// Allocate memory and return a pointer to device memory.
@@ -109,72 +118,70 @@ rmmError_t rmmAlloc(void **ptr, size_t size, cudaStream_t stream, const char* fi
109118// Release device memory and recycle the associated memory.
110119rmmError_t rmmFree (void *ptr, cudaStream_t stream, const char * file, unsigned int line)
111120{
112- return rmm::free (ptr, stream, file, line);
121+ return rmm::free (ptr, stream, file, line);
113122}
114123
115124// Get the offset of ptr from its base allocation
116125rmmError_t rmmGetAllocationOffset (ptrdiff_t *offset,
117126 void *ptr,
118127 cudaStream_t stream)
119128{
120- void *base = (void *)0xffffffff ;
121- CUresult res = cuMemGetAddressRange ((CUdeviceptr*)&base, nullptr ,
122- (CUdeviceptr)ptr);
123- if (res != CUDA_SUCCESS)
124- return RMM_ERROR_INVALID_ARGUMENT;
125- *offset = reinterpret_cast <ptrdiff_t >(ptr) -
126- reinterpret_cast <ptrdiff_t >(base);
127- return RMM_SUCCESS;
129+ void *base = (void *)0xffffffff ;
130+ CUresult res = cuMemGetAddressRange ((CUdeviceptr*)&base, nullptr ,
131+ (CUdeviceptr)ptr);
132+ if (res != CUDA_SUCCESS)
133+ return RMM_ERROR_INVALID_ARGUMENT;
134+ *offset = reinterpret_cast <ptrdiff_t >(ptr) -
135+ reinterpret_cast <ptrdiff_t >(base);
136+ return RMM_SUCCESS;
128137}
129138
130139// Get amounts of free and total memory managed by a manager associated
131140// with the stream.
132141rmmError_t rmmGetInfo (size_t *freeSize, size_t *totalSize, cudaStream_t stream)
133142{
134- try {
135- std::pair<size_t ,size_t > memInfo = rmm::mr::get_default_resource ()->get_mem_info ( stream);
136- *freeSize = memInfo.first ;
137- *totalSize = memInfo.second ;
138- } catch (std::runtime_error){
139- return RMM_ERROR_CUDA_ERROR;
140- }
141- return RMM_SUCCESS;
143+ try {
144+ std::pair<size_t ,size_t > memInfo = rmm::mr::get_default_resource ()->get_mem_info ( stream);
145+ *freeSize = memInfo.first ;
146+ *totalSize = memInfo.second ;
147+ } catch (std::runtime_error){
148+ return RMM_ERROR_CUDA_ERROR;
149+ }
150+ return RMM_SUCCESS;
142151}
143152
144153// Write the memory event stats log to specified path/filename
145154rmmError_t rmmWriteLog (const char * filename)
146155{
147- try
148- {
149- std::ofstream csv;
150- csv.open (filename);
151- rmm::Manager::getLogger ().to_csv (csv);
152- }
153- catch (const std::ofstream::failure& e) {
154- return RMM_ERROR_IO;
155- }
156- return RMM_SUCCESS;
156+ try {
157+ std::ofstream csv;
158+ csv.open (filename);
159+ rmm::Manager::getLogger ().to_csv (csv);
160+ }
161+ catch (const std::ofstream::failure& e) {
162+ return RMM_ERROR_IO;
163+ }
164+ return RMM_SUCCESS;
157165}
158166
159167// Get the size opf the CSV log
160168size_t rmmLogSize ()
161169{
162- std::ostringstream csv;
163- rmm::Manager::getLogger ().to_csv (csv);
164- return csv.str ().size ();
170+ std::ostringstream csv;
171+ rmm::Manager::getLogger ().to_csv (csv);
172+ return csv.str ().size ();
165173}
166174
167175// Get the CSV log as a string
168176rmmError_t rmmGetLog (char *buffer, size_t buffer_size)
169177{
170- try
171- {
172- std::ostringstream csv;
173- rmm::Manager::getLogger ().to_csv (csv);
174- csv.str ().copy (buffer, std::min (buffer_size, csv.str ().size ()));
175- }
176- catch (const std::ofstream::failure& e) {
177- return RMM_ERROR_IO;
178- }
179- return RMM_SUCCESS;
178+ try {
179+ std::ostringstream csv;
180+ rmm::Manager::getLogger ().to_csv (csv);
181+ csv.str ().copy (buffer, std::min (buffer_size, csv.str ().size ()));
182+ }
183+ catch (const std::ofstream::failure& e) {
184+ return RMM_ERROR_IO;
185+ }
186+ return RMM_SUCCESS;
180187}
0 commit comments