Skip to content

Commit 01c2d30

Browse files
authored
Merge pull request #48 from mirecl/47-api-add-support-getsupportedevaluatortypes
Add support get supported evaluator types
2 parents 3ca68e5 + 6b8ab29 commit 01c2d30

File tree

6 files changed

+85
-3
lines changed

6 files changed

+85
-3
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,4 @@ jobs:
123123
go run example/metadata/metadata.go
124124
go run example/ranker/ranker.go
125125
go run example/survival/survival.go
126+
go run example/device/device.go

catboost/catboost.go

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
)
2020

2121
type PredictionType string
22+
type EvaluatorType uint64
2223

2324
const (
2425
RawFormulaVal PredictionType = "RawFormulaVal"
@@ -28,6 +29,12 @@ const (
2829
Exponent PredictionType = "Exponent"
2930
)
3031

32+
const (
33+
CPU EvaluatorType = iota
34+
GPU
35+
Unknown
36+
)
37+
3138
const formatErrorMessage = "%w: %v"
3239

3340
// https://catboost.ai/en/docs/concepts/python-reference_catboost_metadata
@@ -49,6 +56,7 @@ var (
4956
ErrLoadLibrary = errors.New("failed loading CatBoost shared library")
5057
ErrSetPredictionType = errors.New("failed set prediction type")
5158
ErrGetIndices = errors.New("failed get indices")
59+
ErrGetDevices = errors.New("failed get devices")
5260
)
5361

5462
var catboostSharedLibraryPath = ""
@@ -73,7 +81,6 @@ func initialization() error {
7381

7482
handle := C.dlopen(cName, C.RTLD_LAZY)
7583
if handle == nil {
76-
C.dlclose(handle)
7784
msg := C.GoString(C.dlerror())
7885
return fmt.Errorf("%w `%s`: %s", ErrLoadLibrary, catboostSharedLibraryPath, msg)
7986
}
@@ -94,6 +101,7 @@ func initialization() error {
94101
l.RegisterFn("GetModelInfoValue")
95102
l.RegisterFn("GetCatFeatureIndices")
96103
l.RegisterFn("GetFloatFeatureIndices")
104+
l.RegisterFn("GetSupportedEvaluatorTypes")
97105

98106
return nil
99107
}
@@ -132,6 +140,8 @@ func (l *library) RegisterFn(fnName string) {
132140
C.SetGetCatFeatureIndicesFn(fnC)
133141
case "GetFloatFeatureIndices":
134142
C.SetGetFloatFeatureIndicesFn(fnC)
143+
case "GetSupportedEvaluatorTypes":
144+
C.SetGetSupportedEvaluatorTypesFn(fnC)
135145
default:
136146
panic(fmt.Sprintf("not supported function from catboost library: %s", fnName))
137147
}
@@ -220,6 +230,27 @@ func (m *Model) SetPredictionType(p PredictionType) error {
220230
return nil
221231
}
222232

233+
// GetSupportedEvaluatorTypes returns supported formula evaluator types.
234+
func (m *Model) GetSupportedEvaluatorTypes() ([]EvaluatorType, error) {
235+
devicesNum := uint64(2)
236+
237+
devicesTmp := make([]*uint64, devicesNum)
238+
devicesC := (*C.size_t)(devicesTmp[0])
239+
defer C.free(unsafe.Pointer(devicesC))
240+
241+
if !C.WrapGetSupportedEvaluatorTypes(m.handler, &devicesC, (*C.size_t)(&devicesNum)) {
242+
return nil, fmt.Errorf(formatErrorMessage, ErrGetDevices, GetError())
243+
}
244+
245+
devicesCTmp := (*[1 << 28]C.int)(unsafe.Pointer(devicesC))[:devicesNum:devicesNum]
246+
247+
devices := make([]EvaluatorType, 0, len(devicesCTmp))
248+
for _, d := range devicesCTmp {
249+
devices = append(devices, EvaluatorType(d))
250+
}
251+
return devices, nil
252+
}
253+
223254
// GetModelUsedFeaturesNames returns names of features used in the model.
224255
func (m *Model) GetModelUsedFeaturesNames() ([]string, error) {
225256
featuresCount := m.GetFeaturesCount()
@@ -368,7 +399,7 @@ func (m *Model) GetCatFeatureIndices() ([]uint64, error) {
368399
defer C.free(unsafe.Pointer(catsFeatureIndicesC))
369400

370401
if !C.WrapGetCatFeatureIndices(m.handler, &catsFeatureIndicesC, (*C.size_t)(&catsFeatureNum)) {
371-
return []uint64{}, ErrGetIndices
402+
return nil, fmt.Errorf(formatErrorMessage, ErrGetIndices, GetError())
372403
}
373404

374405
indices := (*[1 << 28]uint64)(unsafe.Pointer(catsFeatureIndicesC))[:catsFeatureNum:catsFeatureNum]
@@ -387,7 +418,7 @@ func (m *Model) GetFloatFeatureIndices() ([]uint64, error) {
387418
defer C.free(unsafe.Pointer(floatsFeatureIndicesC))
388419

389420
if !C.WrapGetFloatFeatureIndices(m.handler, &floatsFeatureIndicesC, (*C.size_t)(&floatsFeatureNum)) {
390-
return []uint64{}, ErrGetIndices
421+
return nil, fmt.Errorf(formatErrorMessage, ErrGetIndices, GetError())
391422
}
392423

393424
indices := (*[1 << 28]uint64)(unsafe.Pointer(floatsFeatureIndicesC))[:floatsFeatureNum:floatsFeatureNum]

catboost/catboost_wrapper.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ static TypeGetModelUsedFeaturesNames GetModelUsedFeaturesNamesFn = NULL;
1313
static TypeGetModelInfoValue GetModelInfoValueFn = NULL;
1414
static TypeGetCatFeatureIndices GetCatFeatureIndicesFn = NULL;
1515
static TypeGetFloatFeatureIndices GetFloatFeatureIndicesFn = NULL;
16+
static TypeGetSupportedEvaluatorTypes GetSupportedEvaluatorTypesFn = NULL;
1617

1718
const char* WrapGetErrorString() {
1819
return GetErrorStringFn();
@@ -77,6 +78,10 @@ const char* WrapGetModelInfoValue(ModelCalcerHandle* modelHandle, const char* ke
7778
return GetModelInfoValueFn(modelHandle, keyPtr, keySize);
7879
}
7980

81+
bool WrapGetSupportedEvaluatorTypes(ModelCalcerHandle* modelHandle, size_t** formulaEvaluatorTypes, size_t* count) {
82+
return GetSupportedEvaluatorTypesFn(modelHandle, formulaEvaluatorTypes, count);
83+
}
84+
8085
void SetCalcModelPredictionSingleFn(void *fn) {
8186
CalcModelPredictionSingleFn = ((TypeCalcModelPredictionSingle) fn);
8287
}
@@ -129,6 +134,10 @@ void SetGetModelInfoValueFn(void *fn) {
129134
GetModelInfoValueFn = ((TypeGetModelInfoValue) fn);
130135
}
131136

137+
void SetGetSupportedEvaluatorTypesFn(void *fn) {
138+
GetSupportedEvaluatorTypesFn = ((TypeGetSupportedEvaluatorTypes) fn);
139+
}
140+
132141
char*** makeCharArray2D(int size) {
133142
return calloc(sizeof(char**), size);
134143
}

catboost/catboost_wrapper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ typedef bool (*TypeGetModelUsedFeaturesNames) (ModelCalcerHandle* modelHandle, c
2424
typedef const char* (*TypeGetModelInfoValue) (ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize);
2525
typedef bool (*TypeGetCatFeatureIndices) (ModelCalcerHandle* modelHandle, size_t** indices, size_t* count);
2626
typedef bool (*TypeGetFloatFeatureIndices) (ModelCalcerHandle* modelHandle, size_t** indices, size_t* count);
27+
typedef bool (*TypeGetSupportedEvaluatorTypes) (ModelCalcerHandle* modelHandle, size_t** formulaEvaluatorTypes, size_t* count);
2728

2829
void SetGetErrorStringFn(void *fn);
2930
void SetCalcModelPredictionSingleFn(void *fn);
@@ -38,6 +39,7 @@ void SetGetModelUsedFeaturesNamesFn(void *fn);
3839
void SetGetModelInfoValueFn(void *fn);
3940
void SetGetCatFeatureIndicesFn(void *fn);
4041
void SetGetFloatFeatureIndicesFn(void *fn);
42+
void SetGetSupportedEvaluatorTypesFn(void *fn);
4143

4244
const char* WrapGetErrorString();
4345
ModelCalcerHandle* WrapModelCalcerCreate();
@@ -61,6 +63,7 @@ bool WrapGetModelUsedFeaturesNames(ModelCalcerHandle* modelHandle, char*** featu
6163
const char* WrapGetModelInfoValue(ModelCalcerHandle* modelHandle, const char* keyPtr, size_t keySize);
6264
bool WrapGetCatFeatureIndices(ModelCalcerHandle* modelHandle, size_t** indices, size_t* count);
6365
bool WrapGetFloatFeatureIndices(ModelCalcerHandle* modelHandle, size_t** indices, size_t* count);
66+
bool WrapGetSupportedEvaluatorTypes(ModelCalcerHandle* modelHandle, size_t** formulaEvaluatorTypes, size_t* count);
6467

6568
void freeCharArray1D(char **a, int size);
6669
void freeCharArray2D(char ***a, int sizeX, int sizeY);

example/device/device.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"log"
6+
"path"
7+
"path/filepath"
8+
"runtime"
9+
10+
cb "github.com/mirecl/catboost-cgo/catboost"
11+
)
12+
13+
func main() {
14+
_, fileName, _, _ := runtime.Caller(0)
15+
modelPath := path.Join(filepath.Dir(fileName), "regressor.cbm")
16+
17+
// Initialize CatBoostRegressor
18+
model, err := cb.LoadFullModelFromFile(modelPath)
19+
if err != nil {
20+
log.Fatalln(err)
21+
}
22+
23+
devices, err := model.GetSupportedEvaluatorTypes()
24+
if err != nil {
25+
log.Fatalln(err)
26+
}
27+
28+
for _, device := range devices {
29+
switch device {
30+
case cb.CPU:
31+
fmt.Println("Supported CPU")
32+
case cb.GPU:
33+
fmt.Println("Supported GPU")
34+
default:
35+
fmt.Println("Unknown device")
36+
}
37+
}
38+
}

example/device/regressor.cbm

6.38 KB
Binary file not shown.

0 commit comments

Comments
 (0)