Skip to content

Commit 57c9406

Browse files
committed
updated enums in controls and tests
1 parent d591f71 commit 57c9406

3 files changed

Lines changed: 122 additions & 118 deletions

File tree

RAT/controls.py

Lines changed: 70 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self,
1313
parallel: str = ParallelOptions.Single.value,
1414
calcSldDuringFit: bool = False,
1515
resamPars: list[Union[int, float]] = [0.9, 50],
16-
display: str = DisplayOptions.Iter.value) -> None:
16+
display: str = DisplayOptions.Iter) -> None:
1717

1818
self._parallel = parallel
1919
self._calcSldDuringFit = calcSldDuringFit
@@ -293,10 +293,10 @@ class Calculate(BaseProcedure):
293293
"""Defines the class for the calculate procedure"""
294294

295295
def __init__(self,
296-
parallel: str = ParallelOptions.Single.value,
296+
parallel: str = ParallelOptions.Single,
297297
calcSldDuringFit: bool = False,
298298
resamPars: list[Union[int, float]] = [0.9, 50],
299-
display: str = DisplayOptions.Iter.value) -> None:
299+
display: str = DisplayOptions.Iter) -> None:
300300

301301
# call the constructor of the parent class
302302
super().__init__(parallel = parallel,
@@ -314,24 +314,24 @@ def procedure(self) -> str:
314314
procedure : str
315315
The value of the procedure property.
316316
"""
317-
return Procedures.Calculate.value
317+
return Procedures.Calculate
318318

319319
def __repr__(self):
320320
"""
321321
Defines the display method for Calculate class
322322
"""
323-
table = super().__repr__(Procedures.Calculate.value)
323+
table = super().__repr__(Procedures.Calculate)
324324
return table
325325

326326

327327
class Simplex(BaseProcedure):
328328
"""Defines the class for the simplex procedure"""
329329

330330
def __init__(self,
331-
parallel: str = ParallelOptions.Single.value,
331+
parallel: str = ParallelOptions.Single,
332332
calcSldDuringFit: bool = False,
333333
resamPars: list[Union[int, float]] = [0.9, 50],
334-
display: str = DisplayOptions.Iter.value,
334+
display: str = DisplayOptions.Iter,
335335
tolX: float = 1e-6,
336336
tolFun: float = 1e-6,
337337
maxFunEvals: int = 10000,
@@ -362,7 +362,7 @@ def procedure(self) -> str:
362362
procedure : str
363363
The value of the procedure property.
364364
"""
365-
return Procedures.Simplex.value
365+
return Procedures.Simplex
366366

367367
@property
368368
def tolX(self) -> float:
@@ -554,22 +554,22 @@ def __repr__(self):
554554
"""
555555
Defines the display method for Simplex class
556556
"""
557-
table = super().__repr__(Procedures.Simplex.value)
557+
table = super().__repr__(Procedures.Simplex)
558558
return table
559559

560560

561561
class DE(BaseProcedure):
562562
"""Defines the class for the Differential Evolution procedure"""
563563

564564
def __init__(self,
565-
parallel: str = ParallelOptions.Single.value,
565+
parallel: str = ParallelOptions.Single,
566566
calcSldDuringFit: bool = False,
567567
resamPars: list[Union[int, float]] = [0.9, 50],
568-
display: str = DisplayOptions.Iter.value,
568+
display: str = DisplayOptions.Iter,
569569
populationSize: int = 20,
570570
fWeight: float = 0.5,
571571
crossoverProbability: float = 0.8,
572-
strategy: int = StrategyOptions.RandomWithPerVectorDither.value,
572+
strategy: int = StrategyOptions.RandomWithPerVectorDither,
573573
targetValue: Union[int, float] = 1,
574574
numGenerations: int = 500) -> None:
575575

@@ -596,7 +596,7 @@ def procedure(self) -> str:
596596
procedure : str
597597
The value of the procedure property.
598598
"""
599-
return Procedures.DE.value
599+
return Procedures.DE
600600

601601
@property
602602
def populationSize(self) -> int:
@@ -808,18 +808,18 @@ def __repr__(self):
808808
"""
809809
Defines the display method for DE class
810810
"""
811-
table = super().__repr__(Procedures.DE.value)
811+
table = super().__repr__(Procedures.DE)
812812
return table
813813

814814

815815
class NS(BaseProcedure):
816816
"""Defines the class for the Nested Sampler procedure"""
817817

818818
def __init__(self,
819-
parallel: str = ParallelOptions.Single.value,
819+
parallel: str = ParallelOptions.Single,
820820
calcSldDuringFit: bool = False,
821821
resamPars: list[Union[int, float]] = [0.9, 50],
822-
display: str = DisplayOptions.Iter.value,
822+
display: str = DisplayOptions.Iter,
823823
Nlive: int = 150,
824824
Nmcmc: Union[float, int] = 0,
825825
propScale: float = 0.1,
@@ -846,7 +846,7 @@ def procedure(self) -> str:
846846
procedure : str
847847
The value of the procedure property.
848848
"""
849-
return Procedures.NS.value
849+
return Procedures.NS
850850

851851
@property
852852
def Nlive(self) -> int:
@@ -992,23 +992,23 @@ def __repr__(self):
992992
"""
993993
Defines the display method for NS class
994994
"""
995-
table = super().__repr__(Procedures.NS.value)
995+
table = super().__repr__(Procedures.NS)
996996
return table
997997

998998

999999
class Dream(BaseProcedure):
10001000
"""Defines the class for the Dream procedure"""
10011001

10021002
def __init__(self,
1003-
parallel: str = ParallelOptions.Single.value,
1003+
parallel: str = ParallelOptions.Single,
10041004
calcSldDuringFit: bool = False,
10051005
resamPars: list[Union[int, float]] = [0.9, 50],
1006-
display: str = DisplayOptions.Iter.value,
1006+
display: str = DisplayOptions.Iter,
10071007
nSamples: int = 50000,
10081008
nChains: int = 10,
10091009
jumpProb: float = 0.5,
10101010
pUnitGamma:float = 0.2,
1011-
boundHandling: str = BoundHandlingOptions.Fold.value) -> None:
1011+
boundHandling: str = BoundHandlingOptions.Fold) -> None:
10121012

10131013
# call the constructor of the parent class
10141014
super().__init__(parallel=parallel,
@@ -1032,7 +1032,7 @@ def procedure(self) -> str:
10321032
procedure : str
10331033
The value of the procedure property.
10341034
"""
1035-
return Procedures.Dream.value
1035+
return Procedures.Dream
10361036

10371037
@property
10381038
def nSamples(self) -> int:
@@ -1210,32 +1210,32 @@ def __repr__(self):
12101210
"""
12111211
Defines the display method for Dream class
12121212
"""
1213-
table = super().__repr__(Procedures.Dream.value)
1213+
table = super().__repr__(Procedures.Dream)
12141214
return table
12151215

12161216

12171217
class ControlsClass:
12181218

12191219
def __init__(self,
1220-
procedure: str = Procedures.Calculate.value,
1220+
procedure: str = Procedures.Calculate,
12211221
**properties) -> None:
12221222

12231223
self._procedure = procedure
12241224
self._validate_properties(**properties)
12251225

1226-
if self._procedure == Procedures.Calculate.value:
1226+
if self._procedure == Procedures.Calculate:
12271227
self._controls = Calculate(**properties)
12281228

1229-
elif self._procedure == Procedures.Simplex.value:
1229+
elif self._procedure == Procedures.Simplex:
12301230
self._controls = Simplex(**properties)
12311231

1232-
elif self._procedure == Procedures.DE.value:
1232+
elif self._procedure == Procedures.DE:
12331233
self._controls = DE(**properties)
12341234

1235-
elif self._procedure == Procedures.NS.value:
1235+
elif self._procedure == Procedures.NS:
12361236
self._controls = NS(**properties)
12371237

1238-
elif self._procedure == Procedures.Dream.value:
1238+
elif self._procedure == Procedures.Dream:
12391239
self._controls = Dream(**properties)
12401240

12411241
@property
@@ -1273,47 +1273,47 @@ def _validate_properties(self, **properties) -> bool:
12731273
ValueError
12741274
Raised if properties are not validated.
12751275
"""
1276-
property_names = {Procedures.Calculate.value: {'parallel',
1277-
'calcSLdDuringFit',
1278-
'resamPars',
1279-
'display'},
1280-
Procedures.Simplex.value: {'parallel',
1281-
'calcSLdDuringFit',
1282-
'resamPars',
1283-
'display',
1284-
'tolX',
1285-
'tolFun',
1286-
'maxFunEvals',
1287-
'maxIter',
1288-
'updateFreq',
1289-
'updatePlotFreq'},
1290-
Procedures.DE.value: {'parallel',
1291-
'calcSLdDuringFit',
1292-
'resamPars',
1293-
'display',
1294-
'populationSize',
1295-
'fWeight',
1296-
'crossoverProbability',
1297-
'strategy',
1298-
'targetValue',
1299-
'numGenerations'},
1300-
Procedures.NS.value: {'parallel',
1301-
'calcSLdDuringFit',
1302-
'resamPars',
1303-
'display',
1304-
'Nlive',
1305-
'Nmcmc',
1306-
'propScale',
1307-
'nsTolerance'},
1308-
Procedures.Dream.value: {'parallel',
1309-
'calcSLdDuringFit',
1310-
'resamPars',
1311-
'display',
1312-
'nSamples',
1313-
'nChains',
1314-
'jumpProb',
1315-
'pUnitGamma',
1316-
'boundHandling'}}
1276+
property_names = {Procedures.Calculate: {'parallel',
1277+
'calcSLdDuringFit',
1278+
'resamPars',
1279+
'display'},
1280+
Procedures.Simplex: {'parallel',
1281+
'calcSLdDuringFit',
1282+
'resamPars',
1283+
'display',
1284+
'tolX',
1285+
'tolFun',
1286+
'maxFunEvals',
1287+
'maxIter',
1288+
'updateFreq',
1289+
'updatePlotFreq'},
1290+
Procedures.DE: {'parallel',
1291+
'calcSLdDuringFit',
1292+
'resamPars',
1293+
'display',
1294+
'populationSize',
1295+
'fWeight',
1296+
'crossoverProbability',
1297+
'strategy',
1298+
'targetValue',
1299+
'numGenerations'},
1300+
Procedures.NS: {'parallel',
1301+
'calcSLdDuringFit',
1302+
'resamPars',
1303+
'display',
1304+
'Nlive',
1305+
'Nmcmc',
1306+
'propScale',
1307+
'nsTolerance'},
1308+
Procedures.Dream: {'parallel',
1309+
'calcSLdDuringFit',
1310+
'resamPars',
1311+
'display',
1312+
'nSamples',
1313+
'nChains',
1314+
'jumpProb',
1315+
'pUnitGamma',
1316+
'boundHandling'}}
13171317
expected_properties = property_names[self._procedure]
13181318
input_properties = properties.keys()
13191319
if not (expected_properties | input_properties == expected_properties):

RAT/utils/enums.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
from enum import Enum
2+
try:
3+
from enum import StrEnum
4+
except ImportError:
5+
from strenum import StrEnum
26

37

4-
class ParallelOptions(Enum):
8+
class ParallelOptions(StrEnum):
59
"""Defines the avaliable options for parallelization"""
610
Single = 'single'
711
Points = 'points'
812
Contrasts = 'contrasts'
913
All = 'all'
1014

1115

12-
class Procedures(Enum):
16+
class Procedures(StrEnum):
1317
"""Defines the avaliable options for procedures"""
1418
Calculate = 'calculate'
1519
Simplex = 'simplex'
@@ -18,15 +22,15 @@ class Procedures(Enum):
1822
Dream = 'dream'
1923

2024

21-
class DisplayOptions(Enum):
25+
class DisplayOptions(StrEnum):
2226
"""Defines the avaliable options for display"""
2327
Off = 'off'
2428
Iter = 'iter'
2529
Notify = 'notify'
2630
Final = 'final'
2731

2832

29-
class BoundHandlingOptions(Enum):
33+
class BoundHandlingOptions(StrEnum):
3034
"""Defines the avaliable options for bound handling"""
3135
No = 'no'
3236
Reflect = 'reflect'

0 commit comments

Comments
 (0)