From a2105c97b911043b8a43f1e28149765f54e9573a Mon Sep 17 00:00:00 2001 From: Austin Noto-Moniz Date: Fri, 23 Aug 2024 10:06:42 -0400 Subject: [PATCH] Deprecate training_data on subpredictors. For quite a while, the backend has only supported training data supplied on graph predictors. To support the existing interface where it can be provided on many subpredictors, the backend hoists any training data it finds up to the root. Deprecating the associated fields in the SDK is a step towards no longer needing to do that. --- src/citrine/__version__.py | 2 +- .../predictors/auto_ml_predictor.py | 22 +++++++- .../predictors/mean_property_predictor.py | 25 +++++++-- .../predictors/simple_mixture_predictor.py | 20 ++++++- tests/informatics/test_predictors.py | 56 +++++++++++++++++++ tests/serialization/test_predictors.py | 6 +- 6 files changed, 119 insertions(+), 12 deletions(-) diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index c2a4b6453..dcbfb52f6 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.4.8" +__version__ = "3.5.0" diff --git a/src/citrine/informatics/predictors/auto_ml_predictor.py b/src/citrine/informatics/predictors/auto_ml_predictor.py index c636e70e1..68b92483f 100644 --- a/src/citrine/informatics/predictors/auto_ml_predictor.py +++ b/src/citrine/informatics/predictors/auto_ml_predictor.py @@ -1,5 +1,6 @@ from typing import List, Optional, Set +from deprecation import deprecated from gemd.enumeration.base_enumeration import BaseEnumeration from citrine._rest.resource import Resource @@ -52,7 +53,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode): estimators: Optional[Set[AutoMLEstimator]] Set of estimators to consider during during AutoML model selection. If None is provided, defaults to AutoMLEstimator.RANDOM_FOREST. - training_data: Optional[List[DataSource]] + training_data: Optional[List[DataSource]] (deprecated) Sources of training data. Each can be either a CSV or an GEM Table. Candidates from multiple data sources will be combined into a flattened list and de-duplicated by uid and identifiers. De-duplication is performed if a uid or identifier is shared between two or @@ -69,7 +70,7 @@ class AutoMLPredictor(Resource["AutoMLPredictor"], PredictorNode): 'estimators', default={AutoMLEstimator.RANDOM_FOREST} ) - training_data = _properties.List( + _training_data = _properties.List( _properties.Object(DataSource), 'training_data', default=[] @@ -90,7 +91,22 @@ def __init__(self, self.inputs: List[Descriptor] = inputs self.estimators: Set[AutoMLEstimator] = estimators or {AutoMLEstimator.RANDOM_FOREST} self.outputs = outputs - self.training_data: List[DataSource] = training_data or [] + # self.training_data: List[DataSource] = training_data or [] + if training_data: + self.training_data: List[DataSource] = training_data + + @property + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'") + def training_data(self): + """[DEPRECATED] Retrieve training data associated with this node.""" + return self._training_data + + @training_data.setter + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'") + def training_data(self, value): + self._training_data = value def __str__(self): return ''.format(self.name) diff --git a/src/citrine/informatics/predictors/mean_property_predictor.py b/src/citrine/informatics/predictors/mean_property_predictor.py index 69c568c73..c71bb60df 100644 --- a/src/citrine/informatics/predictors/mean_property_predictor.py +++ b/src/citrine/informatics/predictors/mean_property_predictor.py @@ -1,10 +1,12 @@ -from typing import List, Optional, Mapping, Union +from typing import List, Mapping, Optional, Union + +from deprecation import deprecated from citrine._rest.resource import Resource from citrine._serialization import properties as _properties from citrine.informatics.data_sources import DataSource from citrine.informatics.descriptors import ( - FormulationDescriptor, RealDescriptor, CategoricalDescriptor + CategoricalDescriptor, FormulationDescriptor, RealDescriptor ) from citrine.informatics.predictors import PredictorNode @@ -79,7 +81,7 @@ class MeanPropertyPredictor(Resource["MeanPropertyPredictor"], PredictorNode): ), 'default_properties' ) - training_data = _properties.List( + _training_data = _properties.List( _properties.Object(DataSource), 'training_data', default=[] ) @@ -104,7 +106,22 @@ def __init__(self, self.impute_properties: bool = impute_properties self.label: Optional[str] = label self.default_properties: Optional[Mapping[str, Union[str, float]]] = default_properties - self.training_data: List[DataSource] = training_data or [] + # self.training_data: List[DataSource] = training_data or [] + if training_data: + self.training_data: List[DataSource] = training_data def __str__(self): return ''.format(self.name) + + @property + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'") + def training_data(self): + """[DEPRECATED] Retrieve training data associated with this node.""" + return self._training_data + + @training_data.setter + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'") + def training_data(self, value): + self._training_data = value diff --git a/src/citrine/informatics/predictors/simple_mixture_predictor.py b/src/citrine/informatics/predictors/simple_mixture_predictor.py index 5f803abe5..da2b1fe4b 100644 --- a/src/citrine/informatics/predictors/simple_mixture_predictor.py +++ b/src/citrine/informatics/predictors/simple_mixture_predictor.py @@ -1,5 +1,7 @@ from typing import List, Optional +from deprecation import deprecated + from citrine._rest.resource import Resource from citrine._serialization import properties from citrine.informatics.data_sources import DataSource @@ -28,7 +30,7 @@ class SimpleMixturePredictor(Resource["SimpleMixturePredictor"], PredictorNode): """ - training_data = properties.List(properties.Object(DataSource), 'training_data', default=[]) + _training_data = properties.List(properties.Object(DataSource), 'training_data', default=[]) typ = properties.String('type', default='SimpleMixture', deserializable=False) @@ -39,7 +41,8 @@ def __init__(self, training_data: Optional[List[DataSource]] = None): self.name: str = name self.description: str = description - self.training_data: List[DataSource] = training_data or [] + if training_data: + self.training_data: List[DataSource] = training_data def __str__(self): return ''.format(self.name) @@ -53,3 +56,16 @@ def input_descriptor(self) -> FormulationDescriptor: def output_descriptor(self) -> FormulationDescriptor: """The output formulation descriptor with key 'Flat Formulation'.""" return FormulationDescriptor.flat() + + @property + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data must be accessed through the top-level GraphPredictor.'") + def training_data(self): + """[DEPRECATED] Retrieve training data associated with this node.""" + return self._training_data + + @training_data.setter + @deprecated(deprecated_in="3.5.0", removed_in="4.0.0", + details="Training data should only be added to the top-level GraphPredictor.'") + def training_data(self, value): + self._training_data = value diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index b4250ece6..392438163 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -327,6 +327,24 @@ def test_auto_ml_multiple_outputs(auto_ml_multiple_outputs): assert built.dump()['outputs'] == [z.dump(), y.dump()] +def test_auto_ml_deprecated_training_data(auto_ml): + with pytest.deprecated_call(): + pred = AutoMLPredictor( + name='AutoML Predictor', + description='Predicts z from inputs w and x', + inputs=auto_ml.inputs, + outputs=auto_ml.outputs, + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + ) + + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] + with pytest.deprecated_call(): + pred.training_data = new_training_data + + with pytest.deprecated_call(): + assert pred.training_data == new_training_data + + def test_ing_to_formulation_initialization(ing_to_formulation_predictor): """Make sure the correct fields go to the correct places for an ingredients to formulation predictor.""" assert ing_to_formulation_predictor.name == 'Ingredients to formulation predictor' @@ -361,6 +379,28 @@ def test_mean_property_round_robin(mean_property_predictor): assert len(cat_props) == 1 +def test_mean_property_training_data_deprecated(mean_property_predictor): + with pytest.deprecated_call(): + pred = MeanPropertyPredictor( + name='Mean property predictor', + description='Computes mean ingredient properties', + input_descriptor=mean_property_predictor.input_descriptor, + properties=mean_property_predictor.properties, + p=2.5, + impute_properties=True, + default_properties=mean_property_predictor.default_properties, + label=mean_property_predictor.label, + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + ) + + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] + with pytest.deprecated_call(): + pred.training_data = new_training_data + + with pytest.deprecated_call(): + assert pred.training_data == new_training_data + + def test_label_fractions_property_initialization(label_fractions_predictor): """Make sure the correct fields go to the correct places for a label fraction predictor.""" assert label_fractions_predictor.name == 'Label fractions predictor' @@ -379,6 +419,22 @@ def test_simple_mixture_predictor_initialization(simple_mixture_predictor): assert str(simple_mixture_predictor) == expected_str +def test_simplex_mixture_training_data_deprecated(): + with pytest.deprecated_call(): + pred = SimpleMixturePredictor( + name='Simple mixture predictor', + description='Computes mean ingredient properties', + training_data=[GemTableDataSource(table_id=uuid.uuid4(), table_version=1)] + ) + + new_training_data = [GemTableDataSource(table_id=uuid.uuid4(), table_version=2)] + with pytest.deprecated_call(): + pred.training_data = new_training_data + + with pytest.deprecated_call(): + assert pred.training_data == new_training_data + + def test_ingredient_fractions_property_initialization(ingredient_fractions_predictor): """Make sure the correct fields go to the correct places for an ingredient fractions predictor.""" assert ingredient_fractions_predictor.name == 'Ingredient fractions predictor' diff --git a/tests/serialization/test_predictors.py b/tests/serialization/test_predictors.py index afdecdc6e..be55fd1c9 100644 --- a/tests/serialization/test_predictors.py +++ b/tests/serialization/test_predictors.py @@ -19,7 +19,8 @@ def test_auto_ml_deserialization(valid_auto_ml_predictor_data): assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="") assert len(predictor.outputs) == 1 assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="") - assert len(predictor.training_data) == 0 + with pytest.deprecated_call(): + assert len(predictor.training_data) == 0 def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data): @@ -31,7 +32,8 @@ def test_polymorphic_auto_ml_deserialization(valid_auto_ml_predictor_data): assert predictor.inputs[0] == RealDescriptor("x", lower_bound=0, upper_bound=100, units="") assert len(predictor.outputs) == 1 assert predictor.outputs[0] == RealDescriptor("z", lower_bound=0, upper_bound=100, units="") - assert len(predictor.training_data) == 0 + with pytest.deprecated_call(): + assert len(predictor.training_data) == 0 def test_legacy_serialization(valid_auto_ml_predictor_data):