@@ -19,6 +19,7 @@ import (
1919)
2020
2121type PredictionType string
22+ type EvaluatorType uint64
2223
2324const (
2425 RawFormulaVal PredictionType = "RawFormulaVal"
@@ -28,6 +29,12 @@ const (
2829 Exponent PredictionType = "Exponent"
2930)
3031
32+ const (
33+ CPU EvaluatorType = iota
34+ GPU
35+ Unknown
36+ )
37+
3138const formatErrorMessage = "%w: %v"
3239
3340// https://catboost.ai/en/docs/concepts/python-reference_catboost_metadata
4956 ErrLoadLibrary = errors .New ("failed loading CatBoost shared library" )
5057 ErrSetPredictionType = errors .New ("failed set prediction type" )
5158 ErrGetIndices = errors .New ("failed get indices" )
59+ ErrGetDevices = errors .New ("failed get devices" )
5260)
5361
5462var catboostSharedLibraryPath = ""
@@ -73,7 +81,6 @@ func initialization() error {
7381
7482 handle := C .dlopen (cName , C .RTLD_LAZY )
7583 if handle == nil {
76- C .dlclose (handle )
7784 msg := C .GoString (C .dlerror ())
7885 return fmt .Errorf ("%w `%s`: %s" , ErrLoadLibrary , catboostSharedLibraryPath , msg )
7986 }
@@ -94,6 +101,7 @@ func initialization() error {
94101 l .RegisterFn ("GetModelInfoValue" )
95102 l .RegisterFn ("GetCatFeatureIndices" )
96103 l .RegisterFn ("GetFloatFeatureIndices" )
104+ l .RegisterFn ("GetSupportedEvaluatorTypes" )
97105
98106 return nil
99107}
@@ -132,6 +140,8 @@ func (l *library) RegisterFn(fnName string) {
132140 C .SetGetCatFeatureIndicesFn (fnC )
133141 case "GetFloatFeatureIndices" :
134142 C .SetGetFloatFeatureIndicesFn (fnC )
143+ case "GetSupportedEvaluatorTypes" :
144+ C .SetGetSupportedEvaluatorTypesFn (fnC )
135145 default :
136146 panic (fmt .Sprintf ("not supported function from catboost library: %s" , fnName ))
137147 }
@@ -220,6 +230,27 @@ func (m *Model) SetPredictionType(p PredictionType) error {
220230 return nil
221231}
222232
233+ // GetSupportedEvaluatorTypes returns supported formula evaluator types.
234+ func (m * Model ) GetSupportedEvaluatorTypes () ([]EvaluatorType , error ) {
235+ devicesNum := uint64 (2 )
236+
237+ devicesTmp := make ([]* uint64 , devicesNum )
238+ devicesC := (* C .size_t )(devicesTmp [0 ])
239+ defer C .free (unsafe .Pointer (devicesC ))
240+
241+ if ! C .WrapGetSupportedEvaluatorTypes (m .handler , & devicesC , (* C .size_t )(& devicesNum )) {
242+ return nil , fmt .Errorf (formatErrorMessage , ErrGetDevices , GetError ())
243+ }
244+
245+ devicesCTmp := (* [1 << 28 ]C.int )(unsafe .Pointer (devicesC ))[:devicesNum :devicesNum ]
246+
247+ devices := make ([]EvaluatorType , 0 , len (devicesCTmp ))
248+ for _ , d := range devicesCTmp {
249+ devices = append (devices , EvaluatorType (d ))
250+ }
251+ return devices , nil
252+ }
253+
223254// GetModelUsedFeaturesNames returns names of features used in the model.
224255func (m * Model ) GetModelUsedFeaturesNames () ([]string , error ) {
225256 featuresCount := m .GetFeaturesCount ()
@@ -368,7 +399,7 @@ func (m *Model) GetCatFeatureIndices() ([]uint64, error) {
368399 defer C .free (unsafe .Pointer (catsFeatureIndicesC ))
369400
370401 if ! C .WrapGetCatFeatureIndices (m .handler , & catsFeatureIndicesC , (* C .size_t )(& catsFeatureNum )) {
371- return [] uint64 {}, ErrGetIndices
402+ return nil , fmt . Errorf ( formatErrorMessage , ErrGetIndices , GetError ())
372403 }
373404
374405 indices := (* [1 << 28 ]uint64 )(unsafe .Pointer (catsFeatureIndicesC ))[:catsFeatureNum :catsFeatureNum ]
@@ -387,7 +418,7 @@ func (m *Model) GetFloatFeatureIndices() ([]uint64, error) {
387418 defer C .free (unsafe .Pointer (floatsFeatureIndicesC ))
388419
389420 if ! C .WrapGetFloatFeatureIndices (m .handler , & floatsFeatureIndicesC , (* C .size_t )(& floatsFeatureNum )) {
390- return [] uint64 {}, ErrGetIndices
421+ return nil , fmt . Errorf ( formatErrorMessage , ErrGetIndices , GetError ())
391422 }
392423
393424 indices := (* [1 << 28 ]uint64 )(unsafe .Pointer (floatsFeatureIndicesC ))[:floatsFeatureNum :floatsFeatureNum ]
0 commit comments