@@ -106,47 +106,46 @@ void* Caffe::RNG::generator() {
106106
107107Caffe::Caffe ()
108108 // TODO: HIP Equivalent
109- : hipblas_handle_(NULL ), random_generator_(),
109+ : hipblas_handle_(NULL ), hiprng_generator_( NULL ), random_generator_(),
110110 mode_(Caffe::CPU), solver_count_(1 ), root_solver_(true ) {
111111 // Try to create a hipblas handler, and report an error if failed (but we will
112112 // keep the program running as one might just want to run CPU code).
113113 if (hipblasCreate (&hipblas_handle_) != HIPBLAS_STATUS_SUCCESS) {
114114 LOG (ERROR) << " Cannot create Cublas handle. Cublas won't be available." ;
115115 }
116- // Try to create a curand handler.
117- /* if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT )
118- != CURAND_STATUS_SUCCESS ||
119- curandSetPseudoRandomGeneratorSeed(curand_generator_ , cluster_seedgen())
120- != CURAND_STATUS_SUCCESS ) {
116+ // Try to create a hiprng handler.
117+ if (hiprngCreateGenerator (&hiprng_generator_, HIPRNG_RNG_PSEUDO_MRG32K3A )
118+ != HIPRNG_STATUS_SUCCESS ||
119+ hiprngSetPseudoRandomGeneratorSeed (hiprng_generator_ , cluster_seedgen ())
120+ != HIPRNG_STATUS_SUCCESS ) {
121121 LOG (ERROR) << " Cannot create Curand generator. Curand won't be available." ;
122- }*/
122+ }
123123}
124124
125125Caffe::~Caffe () {
126- // TODO: HIP Equivalent
127126 if (hipblas_handle_) HIPBLAS_CHECK (hipblasDestroy (hipblas_handle_));
128- /* if (curand_generator_ ) {
129- CURAND_CHECK(curandDestroyGenerator(curand_generator_ ));
130- }*/
127+ if (hiprng_generator_ ) {
128+ HIPRNG_CHECK ( hiprngDestroyGenerator (hiprng_generator_ ));
129+ }
131130}
132131
133132void Caffe::set_random_seed (const unsigned int seed) {
134133 // Curand seed
135- // TODO HIP Equivalent
136- /* static bool g_curand_availability_logged = false;
137- if (Get().curand_generator_) {
138- CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator(),
139- seed));
140- CURAND_CHECK(curandSetGeneratorOffset(curand_generator (), 0));
134+ static bool g_hiprng_availability_logged = false ;
135+ if ( Get (). hiprng_generator_ ) {
136+ HIPRNG_CHECK ( hiprngSetPseudoRandomGeneratorSeed ( hiprng_generator (),
137+ seed));
138+ // TODO: support in HIP equivalent
139+ // HIPRNG_CHECK(hiprngSetGeneratorOffset(hiprng_generator (), 0));
141140 } else {
142- if (!g_curand_availability_logged ) {
141+ if (!g_hiprng_availability_logged ) {
143142 LOG (ERROR) <<
144- "Curand not available. Skipping setting the curand seed.";
145- g_curand_availability_logged = true;
143+ " Curand not available. Skipping setting the hiprng seed." ;
144+ g_hiprng_availability_logged = true ;
146145 }
147146 }
148147 // RNG seed
149- Get().random_generator_.reset(new RNG(seed));*/
148+ Get ().random_generator_ .reset (new RNG (seed));
150149}
151150
152151void Caffe::SetDevice (const int device_id) {
@@ -158,16 +157,15 @@ void Caffe::SetDevice(const int device_id) {
158157 // The call to hipSetDevice must come before any calls to Get, which
159158 // may perform initialization using the GPU.
160159 HIP_CHECK (hipSetDevice (device_id));
161- // TODO HIP equivalent
162160 if (Get ().hipblas_handle_ ) HIPBLAS_CHECK (hipblasDestroy (Get ().hipblas_handle_ ));
163- /* if (Get().curand_generator_ ) {
164- CURAND_CHECK(curandDestroyGenerator (Get().curand_generator_ ));
165- }*/
161+ if (Get ().hiprng_generator_ ) {
162+ HIPRNG_CHECK ( hiprngDestroyGenerator (Get ().hiprng_generator_ ));
163+ }
166164 HIPBLAS_CHECK (hipblasCreate (&Get ().hipblas_handle_ ));
167- /* CURAND_CHECK(curandCreateGenerator (&Get().curand_generator_ ,
168- CURAND_RNG_PSEUDO_DEFAULT ));
169- CURAND_CHECK(curandSetPseudoRandomGeneratorSeed (Get().curand_generator_ ,
170- cluster_seedgen()));*/
165+ HIPRNG_CHECK ( hiprngCreateGenerator (&Get ().hiprng_generator_ ,
166+ HIPRNG_RNG_PSEUDO_MRG32K3A ));
167+ HIPRNG_CHECK ( hiprngSetPseudoRandomGeneratorSeed (Get ().hiprng_generator_ ,
168+ cluster_seedgen ()));
171169}
172170
173171void Caffe::DeviceQuery () {
@@ -263,7 +261,7 @@ void* Caffe::RNG::generator() {
263261
264262
265263const char * hipblasGetErrorString (hipblasStatus_t error) {
266- /* switch (error) {
264+ switch (error) {
267265 case HIPBLAS_STATUS_SUCCESS:
268266 return " HIPBLAS_STATUS_SUCCESS" ;
269267 case HIPBLAS_STATUS_NOT_INITIALIZED:
@@ -278,46 +276,49 @@ const char* hipblasGetErrorString(hipblasStatus_t error) {
278276 return " HIPBLAS_STATUS_EXECUTION_FAILED" ;
279277 case HIPBLAS_STATUS_INTERNAL_ERROR:
280278 return " HIPBLAS_STATUS_INTERNAL_ERROR" ;
279+ case HIPBLAS_STATUS_NOT_SUPPORTED:
280+ return " HIPBLAS_STATUS_NOT_SUPPORTED" ;
281281#if HIP_VERSION >= 6000
282282 case HIPBLAS_STATUS_INTERNAL_ERROR:
283283 return " HIPBLAS_STATUS_INTERNAL_ERROR" ;
284284#endif
285- }*/
285+ }
286286 return " Unknown hipblas status" ;
287287}
288288
289- // TODO HIP Equivalent
290- /* const char* curandGetErrorString(curandStatus_t error) {
289+ const char * hiprngGetErrorString (hiprngStatus_t error) {
291290 switch (error) {
292- case CURAND_STATUS_SUCCESS:
293- return "CURAND_STATUS_SUCCESS";
294- case CURAND_STATUS_VERSION_MISMATCH:
295- return "CURAND_STATUS_VERSION_MISMATCH";
296- case CURAND_STATUS_NOT_INITIALIZED:
297- return "CURAND_STATUS_NOT_INITIALIZED";
298- case CURAND_STATUS_ALLOCATION_FAILED:
299- return "CURAND_STATUS_ALLOCATION_FAILED";
300- case CURAND_STATUS_TYPE_ERROR:
301- return "CURAND_STATUS_TYPE_ERROR";
302- case CURAND_STATUS_OUT_OF_RANGE:
303- return "CURAND_STATUS_OUT_OF_RANGE";
304- case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
305- return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
306- case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
307- return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
308- case CURAND_STATUS_LAUNCH_FAILURE:
309- return "CURAND_STATUS_LAUNCH_FAILURE";
310- case CURAND_STATUS_PREEXISTING_FAILURE:
311- return "CURAND_STATUS_PREEXISTING_FAILURE";
312- case CURAND_STATUS_INITIALIZATION_FAILED:
313- return "CURAND_STATUS_INITIALIZATION_FAILED";
314- case CURAND_STATUS_ARCH_MISMATCH:
315- return "CURAND_STATUS_ARCH_MISMATCH";
316- case CURAND_STATUS_INTERNAL_ERROR:
317- return "CURAND_STATUS_INTERNAL_ERROR";
291+ case HIPRNG_STATUS_INVALID_STREAM_CREATOR:
292+ return " HIPRNG_STATUS_INVALID_STREAM_CREATOR" ;
293+ case HIPRNG_STATUS_SUCCESS:
294+ return " HIPRNG_STATUS_SUCCESS" ;
295+ case HIPRNG_STATUS_VERSION_MISMATCH:
296+ return " HIPRNG_STATUS_VERSION_MISMATCH" ;
297+ // case HIPRNG_STATUS_NOT_INITIALIZED:
298+ // return "HIPRNG_STATUS_NOT_INITIALIZED";
299+ case HIPRNG_STATUS_ALLOCATION_FAILED:
300+ return " HIPRNG_STATUS_ALLOCATION_FAILED" ;
301+ case HIPRNG_STATUS_TYPE_ERROR:
302+ return " HIPRNG_STATUS_TYPE_ERROR" ;
303+ // case HIPRNG_STATUS_OUT_OF_RANGE:
304+ // return "HIPRNG_STATUS_OUT_OF_RANGE";
305+ // case HIPRNG_STATUS_LENGTH_NOT_MULTIPLE:
306+ // return "HIPRNG_STATUS_LENGTH_NOT_MULTIPLE";
307+ // case HIPRNG_STATUS_DOUBLE_PRECISION_REQUIRED:
308+ // return "HIPRNG_STATUS_DOUBLE_PRECISION_REQUIRED";
309+ // case HIPRNG_STATUS_LAUNCH_FAILURE:
310+ // return "HIPRNG_STATUS_LAUNCH_FAILURE";
311+ // case HIPRNG_STATUS_PREEXISTING_FAILURE:
312+ // return "HIPRNG_STATUS_PREEXISTING_FAILURE";
313+ case HIPRNG_STATUS_INITIALIZATION_FAILED:
314+ return " HIPRNG_STATUS_INITIALIZATION_FAILED" ;
315+ // case HIPRNG_STATUS_ARCH_MISMATCH:
316+ // return "HIPRNG_STATUS_ARCH_MISMATCH";
317+ // case HIPRNG_STATUS_INTERNAL_ERROR:
318+ // return "HIPRNG_STATUS_INTERNAL_ERROR";
318319 }
319- return "Unknown curand status";
320- }*/
320+ return " Unknown hiprng status" ;
321+ }
321322
322323#endif // CPU_ONLY
323324
0 commit comments