Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RAT/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from RAT.classlist import ClassList
from RAT.controls import Controls
from RAT.project import Project
import RAT.controls
import RAT.models
93 changes: 43 additions & 50 deletions RAT/controls.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import prettytable
from pydantic import BaseModel, Field, field_validator
from typing import Union
from pydantic import BaseModel, Field, field_validator, ValidationError
from typing import Literal, Union

from RAT.utils.enums import ParallelOptions, Procedures, DisplayOptions, BoundHandlingOptions, StrategyOptions


class BaseProcedure(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the base class with properties used in all five procedures."""
class Calculate(BaseModel, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure, which includes the properties used in all five procedures."""
procedure: Literal[Procedures.Calculate] = Procedures.Calculate
parallel: ParallelOptions = ParallelOptions.Single
calcSldDuringFit: bool = False
resamPars: list[float] = Field([0.9, 50], min_length=2, max_length=2)
Expand All @@ -21,15 +22,16 @@ def check_resamPars(cls, resamPars):
raise ValueError('resamPars[1] must be greater than or equal to 0')
return resamPars


class Calculate(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the calculate procedure."""
procedure: Procedures = Field(Procedures.Calculate, frozen=True)
def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self.__dict__.items()])
return table.get_string()


class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the simplex procedure."""
procedure: Procedures = Field(Procedures.Simplex, frozen=True)
class Simplex(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the simplex procedure."""
procedure: Literal[Procedures.Simplex] = Procedures.Simplex
tolX: float = Field(1.0e-6, gt=0.0)
tolFun: float = Field(1.0e-6, gt=0.0)
maxFunEvals: int = Field(10000, gt=0)
Expand All @@ -38,9 +40,9 @@ class Simplex(BaseProcedure, validate_assignment=True, extra='forbid'):
updatePlotFreq: int = -1


class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Differential Evolution procedure."""
procedure: Procedures = Field(Procedures.DE, frozen=True)
class DE(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Differential Evolution procedure."""
procedure: Literal[Procedures.DE] = Procedures.DE
populationSize: int = Field(20, ge=1)
fWeight: float = 0.5
crossoverProbability: float = Field(0.8, gt=0.0, lt=1.0)
Expand All @@ -49,52 +51,43 @@ class DE(BaseProcedure, validate_assignment=True, extra='forbid'):
numGenerations: int = Field(500, ge=1)


class NS(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Nested Sampler procedure."""
procedure: Procedures = Field(Procedures.NS, frozen=True)
class NS(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Nested Sampler procedure."""
procedure: Literal[Procedures.NS] = Procedures.NS
Nlive: int = Field(150, ge=1)
Nmcmc: float = Field(0.0, ge=0.0)
propScale: float = Field(0.1, gt=0.0, lt=1.0)
nsTolerance: float = Field(0.1, ge=0.0)


class Dream(BaseProcedure, validate_assignment=True, extra='forbid'):
"""Defines the class for the Dream procedure."""
procedure: Procedures = Field(Procedures.Dream, frozen=True)
class Dream(Calculate, validate_assignment=True, extra='forbid'):
"""Defines the additional fields for the Dream procedure."""
procedure: Literal[Procedures.Dream] = Procedures.Dream
nSamples: int = Field(50000, ge=0)
nChains: int = Field(10, gt=0)
jumpProb: float = Field(0.5, gt=0.0, lt=1.0)
pUnitGamma: float = Field(0.2, gt=0.0, lt=1.0)
boundHandling: BoundHandlingOptions = BoundHandlingOptions.Fold


class Controls:

def __init__(self,
procedure: Procedures = Procedures.Calculate,
**properties) -> None:

if procedure == Procedures.Calculate:
self.controls = Calculate(**properties)
elif procedure == Procedures.Simplex:
self.controls = Simplex(**properties)
elif procedure == Procedures.DE:
self.controls = DE(**properties)
elif procedure == Procedures.NS:
self.controls = NS(**properties)
elif procedure == Procedures.Dream:
self.controls = Dream(**properties)

@property
def controls(self) -> Union[Calculate, Simplex, DE, NS, Dream]:
return self._controls

@controls.setter
def controls(self, value: Union[Calculate, Simplex, DE, NS, Dream]) -> None:
self._controls = value

def __repr__(self) -> str:
table = prettytable.PrettyTable()
table.field_names = ['Property', 'Value']
table.add_rows([[k, v] for k, v in self._controls.__dict__.items()])
return table.get_string()
def set_controls(procedure: Procedures = Procedures.Calculate, **properties)\
-> Union[Calculate, Simplex, DE, NS, Dream]:
"""Returns the appropriate controls model given the specified procedure."""
controls = {
Procedures.Calculate: Calculate,
Procedures.Simplex: Simplex,
Procedures.DE: DE,
Procedures.NS: NS,
Procedures.Dream: Dream
}

try:
model = controls[procedure](**properties)
except KeyError:
members = list(Procedures.__members__.values())
allowed_values = f'{", ".join([repr(member.value) for member in members[:-1]])} or {members[-1].value!r}'
raise ValueError(f'The controls procedure must be one of: {allowed_values}') from None
except ValidationError:
raise

return model
10 changes: 3 additions & 7 deletions RAT/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import collections
import copy
import functools
import logging
import numpy as np
import os
from pydantic import BaseModel, ValidationInfo, field_validator, model_validator, ValidationError
from typing import Any, Callable

from RAT.classlist import ClassList
import RAT.models
from RAT.utils.custom_errors import formatted_pydantic_error
from RAT.utils.custom_errors import formatted_pydantic_error, formatted_traceback

try:
from enum import StrEnum
Expand Down Expand Up @@ -524,12 +525,7 @@ def wrapped_func(*args, **kwargs):
try:
return_value = func(*args, **kwargs)
Project.model_validate(self)
except ValidationError as e:
setattr(class_list, 'data', previous_state)
error_string = formatted_pydantic_error(e)
# Use ANSI escape sequences to print error text in red
print('\033[31m' + error_string + '\033[0m')
except (TypeError, ValueError):
except (TypeError, ValueError, ValidationError):
setattr(class_list, 'data', previous_state)
raise
finally:
Expand Down
24 changes: 21 additions & 3 deletions RAT/utils/custom_errors.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,44 @@
"""Defines routines for custom error handling in RAT."""

from pydantic import ValidationError
import traceback


def formatted_pydantic_error(error: ValidationError) -> str:
def formatted_pydantic_error(error: ValidationError, custom_error_messages: dict[str, str] = None) -> str:
"""Write a custom string format for pydantic validation errors.

Parameters
----------
error : pydantic.ValidationError
A ValidationError produced by a pydantic model
A ValidationError produced by a pydantic model.
custom_error_messages: dict[str, str], optional
A dict of custom error messages for given error types.

Returns
-------
error_str : str
A string giving details of the ValidationError in a custom format.
"""
if custom_error_messages is None:
custom_error_messages = {}
num_errors = error.error_count()
error_str = f'{num_errors} validation error{"s"[:num_errors!=1]} for {error.title}'

for this_error in error.errors():
error_type = this_error['type']
error_msg = custom_error_messages[error_type] if error_type in custom_error_messages else this_error["msg"]

error_str += '\n'
if this_error['loc']:
error_str += ' '.join(this_error['loc']) + '\n'
error_str += ' ' + this_error['msg']
error_str += f' {error_msg}'

return error_str


def formatted_traceback() -> str:
"""Takes the traceback obtained from "traceback.format_exc()" and removes the exception message for pydantic
ValidationErrors.
"""
traceback_string = traceback.format_exc()
return traceback_string.split('pydantic_core._pydantic_core.ValidationError:')[0]
Loading