Skip to content

Commit 6b8ab29

Browse files
committed
Fix device
1 parent bad4160 commit 6b8ab29

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,4 @@ jobs:
123123
go run example/metadata/metadata.go
124124
go run example/ranker/ranker.go
125125
go run example/survival/survival.go
126+
go run example/device/device.go

catboost/catboost.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ func initialization() error {
8181

8282
handle := C.dlopen(cName, C.RTLD_LAZY)
8383
if handle == nil {
84-
C.dlclose(handle)
8584
msg := C.GoString(C.dlerror())
8685
return fmt.Errorf("%w `%s`: %s", ErrLoadLibrary, catboostSharedLibraryPath, msg)
8786
}
@@ -235,16 +234,21 @@ func (m *Model) SetPredictionType(p PredictionType) error {
235234
func (m *Model) GetSupportedEvaluatorTypes() ([]EvaluatorType, error) {
236235
devicesNum := uint64(2)
237236

238-
devices := make([]*EvaluatorType, devicesNum)
239-
devicesC := (*C.size_t)(devices[0])
237+
devicesTmp := make([]*uint64, devicesNum)
238+
devicesC := (*C.size_t)(devicesTmp[0])
240239
defer C.free(unsafe.Pointer(devicesC))
241240

242241
if !C.WrapGetSupportedEvaluatorTypes(m.handler, &devicesC, (*C.size_t)(&devicesNum)) {
243-
return nil, ErrGetDevices
242+
return nil, fmt.Errorf(formatErrorMessage, ErrGetDevices, GetError())
244243
}
245244

246-
supportDevices := (*[1 << 28]EvaluatorType)(unsafe.Pointer(devicesC))[:devicesNum:devicesNum]
247-
return supportDevices, nil
245+
devicesCTmp := (*[1 << 28]C.int)(unsafe.Pointer(devicesC))[:devicesNum:devicesNum]
246+
247+
devices := make([]EvaluatorType, 0, len(devicesCTmp))
248+
for _, d := range devicesCTmp {
249+
devices = append(devices, EvaluatorType(d))
250+
}
251+
return devices, nil
248252
}
249253

250254
// GetModelUsedFeaturesNames returns names of features used in the model.
@@ -395,7 +399,7 @@ func (m *Model) GetCatFeatureIndices() ([]uint64, error) {
395399
defer C.free(unsafe.Pointer(catsFeatureIndicesC))
396400

397401
if !C.WrapGetCatFeatureIndices(m.handler, &catsFeatureIndicesC, (*C.size_t)(&catsFeatureNum)) {
398-
return []uint64{}, ErrGetIndices
402+
return nil, fmt.Errorf(formatErrorMessage, ErrGetIndices, GetError())
399403
}
400404

401405
indices := (*[1 << 28]uint64)(unsafe.Pointer(catsFeatureIndicesC))[:catsFeatureNum:catsFeatureNum]
@@ -414,7 +418,7 @@ func (m *Model) GetFloatFeatureIndices() ([]uint64, error) {
414418
defer C.free(unsafe.Pointer(floatsFeatureIndicesC))
415419

416420
if !C.WrapGetFloatFeatureIndices(m.handler, &floatsFeatureIndicesC, (*C.size_t)(&floatsFeatureNum)) {
417-
return []uint64{}, ErrGetIndices
421+
return nil, fmt.Errorf(formatErrorMessage, ErrGetIndices, GetError())
418422
}
419423

420424
indices := (*[1 << 28]uint64)(unsafe.Pointer(floatsFeatureIndicesC))[:floatsFeatureNum:floatsFeatureNum]

0 commit comments

Comments
 (0)