diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index fcbd6954b..595a2b3bb 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.11.0" +__version__ = "3.11.1" diff --git a/src/citrine/informatics/predictors/__init__.py b/src/citrine/informatics/predictors/__init__.py index e22ecf3d0..db1909480 100644 --- a/src/citrine/informatics/predictors/__init__.py +++ b/src/citrine/informatics/predictors/__init__.py @@ -1,6 +1,7 @@ # flake8: noqa from .predictor import * from .node import * +from .attribute_accumulation_predictor import * from .expression_predictor import * from .graph_predictor import * from .ingredient_fractions_predictor import * diff --git a/src/citrine/informatics/predictors/attribute_accumulation_predictor.py b/src/citrine/informatics/predictors/attribute_accumulation_predictor.py new file mode 100644 index 000000000..dd5b40a3f --- /dev/null +++ b/src/citrine/informatics/predictors/attribute_accumulation_predictor.py @@ -0,0 +1,45 @@ +from typing import List + +from citrine._rest.resource import Resource +from citrine._serialization import properties as _properties +from citrine.informatics.descriptors import Descriptor +from citrine.informatics.predictors import PredictorNode + +__all__ = ['AttributeAccumulationPredictor'] + + +class AttributeAccumulationPredictor(Resource["AttributeAccumulationPredictor"], PredictorNode): + """A predictor that computes an output from an expression and set of bounded inputs. + + For a discussion of expression syntax and a list of allowed symbols, + please see the :ref:`documentation`. + + Parameters + ---------- + name: str + name of the configuration + description: str + the description of the predictor + attributes: List[Descriptor] + the attributes that are accumulated from ancestor nodes + + """ + + attributes = _properties.List(_properties.Object(Descriptor), 'attributes') + sequential = _properties.Boolean('sequential') + + typ = _properties.String('type', default='AttributeAccumulation', deserializable=False) + + def __init__(self, + name: str, + *, + description: str, + attributes: List[Descriptor], + sequential: bool): + self.name: str = name + self.description: str = description + self.attributes: List[Descriptor] = attributes + self.sequential: bool = sequential + + def __str__(self): + return ''.format(self.name) diff --git a/src/citrine/informatics/predictors/graph_predictor.py b/src/citrine/informatics/predictors/graph_predictor.py index 2e2cebdca..6f1405abc 100644 --- a/src/citrine/informatics/predictors/graph_predictor.py +++ b/src/citrine/informatics/predictors/graph_predictor.py @@ -118,3 +118,72 @@ def predict(self, predict_request: SinglePredictRequest) -> SinglePrediction: path = self._path() + '/predict' res = self._session.post_resource(path, predict_request.dump(), version=self._api_version) return SinglePrediction.build(res) + + def _convert_to_multistep(self) -> "GraphPredictor": + """Make the GraphPredictor look as if generated with a MULTISTEP_MATERIALS datasource.""" + from citrine.informatics.predictors import ( + AttributeAccumulationPredictor, MolecularStructureFeaturizer, + LabelFractionsPredictor, SimpleMixturePredictor, IngredientFractionsPredictor, + AutoMLPredictor, MeanPropertyPredictor, ChemicalFormulaFeaturizer + ) + + automl_outputs = {} + featurizer_outputs = set() + automl_inputs = {} + + for predictor in self.predictors: + if isinstance(predictor, AttributeAccumulationPredictor): + raise ValueError("Graph already contains Attribute Accumulation nodes") + elif isinstance(predictor, AutoMLPredictor): + for descriptor in predictor.outputs: + automl_outputs[descriptor.key] = descriptor + for descriptor in predictor.inputs: + automl_inputs[descriptor.key] = descriptor + elif isinstance(predictor, MeanPropertyPredictor): + for descriptor in predictor.properties: + featurizer_outputs.add( + f"mean of property {descriptor.key} in {predictor.input_descriptor.key}" + ) + elif isinstance(predictor, IngredientFractionsPredictor): + for ingredient in predictor.ingredients: + featurizer_outputs.add( + f"{ingredient} share in {predictor.input_descriptor.key}" + ) + elif isinstance(predictor, LabelFractionsPredictor): + for label in predictor.labels: + featurizer_outputs.add( + f"{label} share in {predictor.input_descriptor.key}" + ) + elif isinstance(predictor, (SimpleMixturePredictor, ChemicalFormulaFeaturizer, + MolecularStructureFeaturizer)): + pass + else: + # IngredientsToFormulationRelation, ExpressionPredictor, + # IngredientsToFormulationPredictor + raise NotImplementedError(f"Unhandled predictor type: {type(predictor)}") + + output_accumulator = AttributeAccumulationPredictor( + name="Output variable accumulation", + description="Output variables encountered in the material history. " + "Only sequential mixing steps are considered.", + attributes=list(automl_outputs.values()), + sequential=True + ) + input_accumulator = AttributeAccumulationPredictor( + name="Attribute accumulation", + description="Parameters/conditions encountered in the material history. " + "Most recent values are selected first.", + attributes=[automl_inputs[key] for key in automl_inputs + if key not in featurizer_outputs], + sequential=False + ) + + update = GraphPredictor( + name=self.name, + description=self.description, + predictors=self.predictors + [output_accumulator, input_accumulator], + training_data=self.training_data + ) + update.uid = self.uid + + return update diff --git a/src/citrine/informatics/predictors/node.py b/src/citrine/informatics/predictors/node.py index 637b43400..85cf142f9 100644 --- a/src/citrine/informatics/predictors/node.py +++ b/src/citrine/informatics/predictors/node.py @@ -19,6 +19,7 @@ class PredictorNode(PolymorphicSerializable["PredictorNode"], Predictor): @classmethod def get_type(cls, data) -> Type['PredictorNode']: """Return the subtype.""" + from .attribute_accumulation_predictor import AttributeAccumulationPredictor from .expression_predictor import ExpressionPredictor from .molecular_structure_featurizer import MolecularStructureFeaturizer from .ingredients_to_formulation_predictor import IngredientsToFormulationPredictor @@ -30,6 +31,7 @@ def get_type(cls, data) -> Type['PredictorNode']: from .chemical_formula_featurizer import ChemicalFormulaFeaturizer type_dict = { "AnalyticExpression": ExpressionPredictor, + "AttributeAccumulation": AttributeAccumulationPredictor, "MoleculeFeaturizer": MolecularStructureFeaturizer, "IngredientsToSimpleMixture": IngredientsToFormulationPredictor, "MeanProperty": MeanPropertyPredictor, diff --git a/src/citrine/resources/table_config.py b/src/citrine/resources/table_config.py index 3a5726709..f6d632ddc 100644 --- a/src/citrine/resources/table_config.py +++ b/src/citrine/resources/table_config.py @@ -23,7 +23,8 @@ from citrine.gemtables.variables import ( Variable, IngredientIdentifierByProcessTemplateAndName, IngredientQuantityByProcessAndName, IngredientQuantityDimension, IngredientIdentifierInOutput, IngredientQuantityInOutput, - IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput + IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput, + AttributeByTemplateAndObjectTemplate, LocalAttributeAndObject ) from typing import TYPE_CHECKING @@ -429,6 +430,28 @@ def add_all_ingredients_in_output(self, *, new_config.version_uid = copy(self.version_uid) return new_config + def _convert_to_multistep(self) -> "TableConfig": + """Convert the TableConfig to look like something generated by MULTISTEP_MATERIALS.""" + dup: TableConfig = TableConfig.build(self.dump()) + + def _convert_local(old: Variable) -> Variable: + if isinstance(old, AttributeByTemplateAndObjectTemplate): + return LocalAttributeAndObject( + name=old.name, + headers=old.headers, + template=old.attribute_template, + object_template=old.object_template, + attribute_constraints=old.attribute_constraints, + type_selector=old.type_selector, + ) + else: + return old + + dup.variables = [_convert_local(x) for x in dup.variables] + dup.generation_algorithm = TableFromGemdQueryAlgorithm.MULTISTEP_MATERIALS + + return dup + class TableConfigCollection(Collection[TableConfig]): """Represents the collection of all Table Configs associated with a project.""" diff --git a/tests/informatics/test_predictors.py b/tests/informatics/test_predictors.py index 0645fe4a7..aa1bf29f3 100644 --- a/tests/informatics/test_predictors.py +++ b/tests/informatics/test_predictors.py @@ -200,6 +200,28 @@ def ingredient_fractions_predictor() -> IngredientFractionsPredictor: ) +@pytest.fixture +def input_accumulation_predictor(auto_ml) -> AttributeAccumulationPredictor: + """Build an accumulation node for model inputs.""" + return AttributeAccumulationPredictor( + name='Input accumulation predictor', + description='Bubbles attributes up through the graph', + attributes=auto_ml.inputs, + sequential=False + ) + + +@pytest.fixture +def output_accumulation_predictor(auto_ml) -> AttributeAccumulationPredictor: + """Build an accumulation node for model outputs.""" + return AttributeAccumulationPredictor( + name='Output accumulation predictor', + description='Bubbles attributes up through the graph', + attributes=auto_ml.outputs, + sequential=True + ) + + def test_simple_report(graph_predictor): """Ensures we get a report from a simple predictor post_build call""" with pytest.raises(ValueError): @@ -453,6 +475,17 @@ def test_ingredient_fractions_property_initialization(ingredient_fractions_predi assert str(ingredient_fractions_predictor) == expected_str +def test_attribute_accumulation_predictor_initialization(input_accumulation_predictor, output_accumulation_predictor): + """Make sure the correct fields go to the correct places for an attribute accumulation predictor.""" + assert len(input_accumulation_predictor.attributes) == 2 + expected_input = f"" + assert str(input_accumulation_predictor) == expected_input + + assert len(output_accumulation_predictor.attributes) == 1 + expected_output = f"" + assert str(output_accumulation_predictor) == expected_output + + def test_status(graph_predictor, valid_graph_predictor_data): """Ensure we can check the status of predictor validation.""" # A locally built predictor should be "False" for all status checks @@ -485,3 +518,31 @@ def test_single_predict(graph_predictor): prediction_out = graph_predictor.predict(request) assert prediction_out.dump() == prediction_in.dump() assert session.post_resource.call_count == 1 + +def test__convert_to_multistep(molecule_featurizer, auto_ml, mean_property_predictor, ingredient_fractions_predictor, + label_fractions_predictor, expression_predictor, output_accumulation_predictor, + input_accumulation_predictor): + """Verify graph predictor multistep material update.""" + graph_predictor = GraphPredictor( + name='Graph predictor', + description='description', + predictors=[molecule_featurizer, auto_ml, mean_property_predictor, ingredient_fractions_predictor, label_fractions_predictor], + training_data=[data_source, formulation_data_source] + ) + updated = graph_predictor._convert_to_multistep() + assert len(updated.predictors) == len(graph_predictor.predictors) + 2 + generated_accumulation = [p for p in updated.predictors if isinstance(p, AttributeAccumulationPredictor)] + assert generated_accumulation[0].attributes == output_accumulation_predictor.attributes + assert generated_accumulation[1].attributes == input_accumulation_predictor.attributes + + with pytest.raises(ValueError): + updated._convert_to_multistep() + + + with pytest.raises(NotImplementedError): + GraphPredictor( + name='Graph predictor', + description='description', + predictors=[expression_predictor], + training_data=[data_source, formulation_data_source] + )._convert_to_multistep() diff --git a/tests/resources/test_table_config.py b/tests/resources/test_table_config.py index 7be645f05..5930bf590 100644 --- a/tests/resources/test_table_config.py +++ b/tests/resources/test_table_config.py @@ -9,7 +9,8 @@ IngredientQuantityDimension, IngredientQuantityByProcessAndName, \ IngredientIdentifierByProcessTemplateAndName, TerminalMaterialIdentifier, \ IngredientQuantityInOutput, IngredientIdentifierInOutput, \ - IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput + IngredientLabelsSetByProcessAndName, IngredientLabelsSetInOutput, AttributeByTemplateAndObjectTemplate, \ + LocalAttribute, LocalAttributeAndObject from citrine.resources.table_config import TableConfig, TableConfigCollection, TableBuildAlgorithm, \ TableFromGemdQueryAlgorithm from citrine.resources.data_concepts import CITRINE_SCOPE @@ -900,3 +901,20 @@ def test_update_unregistered_fail(collection, session): def test_delete(collection): with pytest.raises(NotImplementedError): collection.delete(empty_defn().config_uid) + + +def test__convert_to_multistep(): + variables = [ + AttributeByTemplate("One", headers=["one"], template=uuid4()), + AttributeByTemplateAndObjectTemplate("Two", headers=["two"], attribute_template=uuid4(), object_template=uuid4()), + LocalAttribute("Three", headers=["three"], template=uuid4()), + LocalAttributeAndObject("Four", headers=["four"], template=uuid4(), object_template=uuid4()), + ] + columns = [MeanColumn(data_source=v.name, target_units="") for v in variables] + config: TableConfig = TableConfig.build(TableConfigDataFactory( + variables=[v.dump() for v in variables], + columns=[c.dump() for c in columns], + )) + updated = config._convert_to_multistep() + assert len(config.variables) == len(config.variables) + assert not any(isinstance(x, AttributeByTemplateAndObjectTemplate) for x in updated.variables)