diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index c2a4b6453..dcbfb52f6 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.4.8" +__version__ = "3.5.0" diff --git a/src/citrine/informatics/predictors/auto_ml_predictor.py b/src/citrine/informatics/predictors/auto_ml_predictor.py index c636e70e1..68b92483f 100644 --- a/src/citrine/informatics/predictors/auto_ml_predictor.py +++ b/src/citrine/informatics/predictors/auto_ml_predictor.py @@ -1,5 +1,6 @@ from typing import List, Optional, Set +from deprecation import deprecated from gemd.enumeration.base_enumeration import BaseEnumeration from citrine._rest.resource import Resource @@ -52,7 +53,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode): estimators: Optional[Set[AutoMLEstimator]] Set of estimators to consider during during AutoML model selection. If None is provided, defaults to AutoMLEstimator.RANDOM_FOREST. - training_data: Optional[List[DataSource]] + training_data: Optional[List[DataSource]] (deprecated) Sources of training data. Each can be either a CSV or an GEM Table. Candidates from multiple data sources will be combined into a flattened list and de-duplicated by uid and identifiers. De-duplication is performed if a uid or identifier is shared between two or @@ -69,7 +70,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode): 'estimators', default={AutoMLEstimator.RANDOM_FOREST} ) - training_data = _properties.List( + _training_data = _properties.List( _properties.Object(DataSource), 'training_data', default=[] @@ -90,7 +91,22 @@ def __init__(self, self.inputs: List[Descriptor] = inputs self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST} self.outputs = outputs - self.training_data: List[DataSource] = training_data or [] + # self.training_data: List[DataSource] = training_data or [] + if training_data: + self.training_data: List[DataSource] = training_data + + @property + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'") + def training_data(self): + """[DEPRECATED] Retrieve training data associated with this node.""" + return self._training_data + + @training_data.setter + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'") + def training_data(self, value): + self._training_data = value def __str__(self): return ''.format(self.name) diff --git a/src/citrine/informatics/predictors/mean_property_predictor.py b/src/citrine/informatics/predictors/mean_property_predictor.py index 69c568c73..c71bb60df 100644 --- a/src/citrine/informatics/predictors/mean_property_predictor.py +++ b/src/citrine/informatics/predictors/mean_property_predictor.py @@ -1,10 +1,12 @@ -from typing import List, Optional, Mapping, Union +from typing import List, Mapping, Optional, Union + +from deprecation import deprecated from citrine._rest.resource import Resource from citrine._serialization import properties as _properties from citrine.informatics.data_sources import DataSource from citrine.informatics.descriptors import ( - FormulationDescriptor, RealDescriptor, CategoricalDescriptor + CategoricalDescriptor, FormulationDescriptor, RealDescriptor ) from citrine.informatics.predictors import PredictorNode @@ -79,7 +81,7 @@ class MeanPropertyPredictor(Resource["MeanPropertyPredictor"], PredictorNode): ), 'default_properties' ) - training_data = _properties.List( + _training_data = _properties.List( _properties.Object(DataSource), 'training_data', default=[] ) @@ -104,7 +106,22 @@ def __init__(self, self.impute_properties: bool = impute_properties self.label: Optional[str] = label self.default_properties: Optional[Mapping[str, Union[str, float]]] = default_properties - self.training_data: List[DataSource] = training_data or [] + # self.training_data: List[DataSource] = training_data or [] + if training_data: + self.training_data: List[DataSource] = training_data def __str__(self): return ''.format(self.name) + + @property + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'") + def training_data(self): + """[DEPRECATED] Retrieve training data associated with this node.""" + return self._training_data + + @training_data.setter + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'") + def training_data(self, value): + self._training_data = value diff --git a/src/citrine/informatics/predictors/simple_mixture_predictor.py b/src/citrine/informatics/predictors/simple_mixture_predictor.py index 5f803abe5..da2b1fe4b 100644 --- a/src/citrine/informatics/predictors/simple_mixture_predictor.py +++ b/src/citrine/informatics/predictors/simple_mixture_predictor.py @@ -1,5 +1,7 @@ from typing import List, Optional +from deprecation import deprecated + from citrine._rest.resource import Resource from citrine._serialization import properties from citrine.informatics.data_sources import DataSource @@ -28,7 +30,7 @@ class SimpleMixturePredictor(Resource["SimpleMixturePredictor"], PredictorNode): """ - training_data = properties.List(properties.Object(DataSource), 'training_data', default=[]) + _training_data = properties.List(properties.Object(DataSource), 'training_data', default=[]) typ = properties.String('type', default='SimpleMixture', deserializable=False) @@ -39,7 +41,8 @@ def __init__(self, training_data: Optional[List[DataSource]] = None): self.name: str = name self.description: str = description - self.training_data: List[DataSource] = training_data or [] + if training_data: + self.training_data: List[DataSource] = training_data def __str__(self): return ''.format(self.name) @@ -53,3 +56,16 @@ def input_descriptor(self) -> FormulationDescriptor: def output_descriptor(self) -> FormulationDescriptor: """The output formulation descriptor with key 'Flat Formulation'.""" return FormulationDescriptor.flat() + + @property + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'") + def training_data(self): + """[DEPRECATED] Retrieve training data associated with this node.""" + return self._training_data + + @training_data.setter + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'") + def training_data(self, value): + self._training_data = value diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index b4250ece6..392438163 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -327,6 +327,24 @@ def test_auto_ml_multiple_outputs(auto_ml_multiple_outputs): assert built.dump()['outputs'] == [z.dump(), y.dump()] +def test_auto_ml_deprecated_training_data(auto_ml): + with pytest.deprecated_call(): + pred = AutoMLPredictor( + name='AutoML Predictor', + description='Predicts z from inputs w and x', + inputs=auto_ml.inputs, + outputs=auto_ml.outputs, + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + ) + + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] + with pytest.deprecated_call(): + pred.training_data = new_training_data + + with pytest.deprecated_call(): + assert pred.training_data == new_training_data + + def test_ing_to_formulation_initialization(ing_to_formulation_predictor): """Make sure the correct fields go to the correct places for an ingredients to formulation predictor.""" assert ing_to_formulation_predictor.name == 'Ingredients to formulation predictor' @@ -361,6 +379,28 @@ def test_mean_property_round_robin(mean_property_predictor): assert len(cat_props) == 1 +def test_mean_property_training_data_deprecated(mean_property_predictor): + with pytest.deprecated_call(): + pred = MeanPropertyPredictor( + name='Mean property predictor', + description='Computes mean ingredient properties', + input_descriptor=mean_property_predictor.input_descriptor, + properties=mean_property_predictor.properties, + p=2.5, + impute_properties=True, + default_properties=mean_property_predictor.default_properties, + label=mean_property_predictor.label, + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + ) + + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] + with pytest.deprecated_call(): + pred.training_data = new_training_data + + with pytest.deprecated_call(): + assert pred.training_data == new_training_data + + def test_label_fractions_property_initialization(label_fractions_predictor): """Make sure the correct fields go to the correct places for a label fraction predictor.""" assert label_fractions_predictor.name == 'Label fractions predictor' @@ -379,6 +419,22 @@ def test_simple_mixture_predictor_initialization(simple_mixture_predictor): assert str(simple_mixture_predictor) == expected_str +def test_simplex_mixture_training_data_deprecated(): + with pytest.deprecated_call(): + pred = SimpleMixturePredictor( + name='Simple mixture predictor', + description='Computes mean ingredient properties', + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + ) + + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] + with pytest.deprecated_call(): + pred.training_data = new_training_data + + with pytest.deprecated_call(): + assert pred.training_data == new_training_data + + def test_ingredient_fractions_property_initialization(ingredient_fractions_predictor): """Make sure the correct fields go to the correct places for an ingredient fractions predictor.""" assert ingredient_fractions_predictor.name == 'Ingredient fractions predictor' diff --git a/tests/serialization/test_predictors.py b/tests/serialization/test_predictors.py index afdecdc6e..be55fd1c9 100644 --- a/tests/serialization/test_predictors.py +++ b/tests/serialization/test_predictors.py @@ -19,7 +19,8 @@ def test_auto_ml_deserialization(valid_auto_ml_predictor_data): assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="") assert len(predictor.outputs) == 1 assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="") - assert len(predictor.training_data) == 0 + with pytest.deprecated_call(): + assert len(predictor.training_data) == 0 def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data): @@ -31,7 +32,8 @@ def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data): assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="") assert len(predictor.outputs) == 1 assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="") - assert len(predictor.training_data) == 0 + with pytest.deprecated_call(): + assert len(predictor.training_data) == 0 def test_legacy_serialization(valid_auto_ml_predictor_data):