From ede810b150dddde29eec87f15f40b96e0fad6a7f Mon Sep 17 00:00:00 2001 From: Austin Noto-Moniz Date: Fri, 6 Sep 2024 16:44:28 -0400 Subject: [PATCH] Address DW with branch root ID and version. Now that the design workflow APIs support branch root ID and version explicitly, the SDK should use them. This cuts down on excess requests, resulting in much faster executions. --- src/citrine/__version__.py | 2 +- .../informatics/workflows/design_workflow.py | 35 ++-------- src/citrine/resources/design_workflow.py | 47 +++----------- src/citrine/seeding/find_or_create.py | 4 +- tests/resources/test_branch.py | 23 +++++++ tests/resources/test_design_workflows.py | 64 ++++++------------- tests/resources/test_workflow.py | 8 +-- tests/seeding/test_find_or_create.py | 6 -- 8 files changed, 62 insertions(+), 127 deletions(-) diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index 85197cb4a..46f67e7f8 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0" +__version__ = "3.7.0" diff --git a/src/citrine/informatics/workflows/design_workflow.py b/src/citrine/informatics/workflows/design_workflow.py index 6b869e2e0..a1073b5b6 100644 --- a/src/citrine/informatics/workflows/design_workflow.py +++ b/src/citrine/informatics/workflows/design_workflow.py @@ -32,18 +32,17 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata): predictor_id = properties.Optional(properties.UUID, 'predictor_id') predictor_version = properties.Optional( properties.Union([properties.Integer(), properties.String()]), 'predictor_version') - _branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id') + branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id') + """:Optional[UUID]: Root ID of the branch that contains this workflow.""" + branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version') + """:Optional[int]: Version number of the branch that contains this workflow.""" status_description = properties.String('status_description', serializable=False) """:str: more detailed description of the workflow's status""" typ = properties.String('type', default='DesignWorkflow', deserializable=False) - _branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id', - serializable=False, deserializable=False) - """:Optional[UUID]: Root ID of the branch that contains this workflow.""" - _branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version', - serializable=False, deserializable=False) - """:Optional[int]: Version number of the branch that contains this workflow.""" + _branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id', + serializable=False) def __init__(self, name: str, @@ -68,25 +67,3 @@ def design_executions(self) -> DesignExecutionCollection: raise AttributeError('Cannot initialize execution without project reference!') return DesignExecutionCollection( project_id=self.project_id, session=self._session, workflow_id=self.uid) - - @property - def branch_root_id(self): - """Retrieve the root ID of the branch this workflow is on.""" - return self._branch_root_id - - @branch_root_id.setter - def branch_root_id(self, value): - """Set the root ID of the branch this workflow is on.""" - self._branch_root_id = value - self._branch_id = None - - @property - def branch_version(self): - """Retrieve the version of the branch this workflow is on.""" - return self._branch_version - - @branch_version.setter - def branch_version(self, value): - """Set the version of the branch this workflow is on.""" - self._branch_version = value - self._branch_id = None diff --git a/src/citrine/resources/design_workflow.py b/src/citrine/resources/design_workflow.py index 48952874e..4fffa021a 100644 --- a/src/citrine/resources/design_workflow.py +++ b/src/citrine/resources/design_workflow.py @@ -4,7 +4,6 @@ from citrine._rest.collection import Collection from citrine._session import Session -from citrine.exceptions import NotFound from citrine.informatics.workflows import DesignWorkflow from citrine.resources.response import Response from functools import partial @@ -31,25 +30,6 @@ def __init__(self, self.branch_root_id = branch_root_id self.branch_version = branch_version - def _resolve_branch_root_and_version(self, workflow): - from citrine.resources.branch import BranchCollection - - workflow_copy = deepcopy(workflow) - bc = BranchCollection(self.project_id, self.session) - branch = bc.get_by_version_id(version_id=workflow_copy._branch_id) - workflow_copy._branch_root_id = branch.root_id - workflow_copy._branch_version = branch.version - return workflow_copy - - def _resolve_branch_id(self, root_id, version): - from citrine.resources.branch import BranchCollection - - if root_id and version: - bc = BranchCollection(self.project_id, self.session) - branch = bc.get(root_id=root_id, version=version) - return branch.uid - return None - def register(self, model: DesignWorkflow) -> DesignWorkflow: """ Upload a new design workflow. @@ -77,15 +57,15 @@ def register(self, model: DesignWorkflow) -> DesignWorkflow: 'project.design_workflows.register().') raise RuntimeError(msg) else: - # branch_id is in the body of design workflow endpoints, so it must be serialized. - # This means the collection branch_id might not match the workflow branch_id. The - # collection should win out, since the user is explicitly referencing the branch - # represented by this collection. - # To avoid modifying the parameter, and to ensure the only change is the branch_id, we - # deepcopy, modify, then register it. + # branch_root_id and branch_version are in the body of design workflow endpoints, so + # they must be serialized. This means the collection fields might not match the + # workflow fields. The collection should win out, since the user is explicitly + # referencing the branch represented by this collection. + # To avoid modifying the parameter, and to ensure the only changes are the + # branch_root_id and branch_version, we deepcopy, modify, then register it. model_copy = deepcopy(model) - model_copy._branch_id = self._resolve_branch_id(self.branch_root_id, - self.branch_version) + model_copy.branch_root_id = self.branch_root_id + model_copy.branch_version = self.branch_version return super().register(model_copy) def build(self, data: dict) -> DesignWorkflow: @@ -104,7 +84,6 @@ def build(self, data: dict) -> DesignWorkflow: """ workflow = DesignWorkflow.build(data) - workflow = self._resolve_branch_root_and_version(workflow) workflow._session = self.session workflow.project_id = self.project_id return workflow @@ -137,13 +116,6 @@ def update(self, model: DesignWorkflow) -> DesignWorkflow: raise ValueError('Cannot update a design workflow unless its branch_root_id and ' 'branch_version are set.') - try: - model._branch_id = self._resolve_branch_id(model.branch_root_id, - model.branch_version) - except NotFound: - raise ValueError('Cannot update a design workflow unless its branch_root_id and ' - 'branch_version exists.') - # If executions have already been done, warn about future behavior change executions = model.design_executions.list() if next(executions, None) is not None: @@ -197,7 +169,8 @@ def _fetch_page(self, additional_params: Optional[dict] = None, ) -> Tuple[Iterable[dict], str]: params = additional_params or {} - params["branch"] = self._resolve_branch_id(self.branch_root_id, self.branch_version) + params["branch_root_id"] = self.branch_root_id + params["branch_version"] = self.branch_version return super()._fetch_page(path=path, fetch_func=fetch_func, page=page, diff --git a/src/citrine/seeding/find_or_create.py b/src/citrine/seeding/find_or_create.py index 7e7cd8045..b8ff71a69 100644 --- a/src/citrine/seeding/find_or_create.py +++ b/src/citrine/seeding/find_or_create.py @@ -175,8 +175,8 @@ def create_or_update(*, # Locally created design workflows likely won't have a branch ID but # need one to be updated. if isinstance(old_resource, DesignWorkflow): - new_resource._branch_root_id = old_resource.branch_root_id - new_resource._branch_version = old_resource.branch_version + new_resource.branch_root_id = old_resource.branch_root_id + new_resource.branch_version = old_resource.branch_version return collection.update(new_resource) else: logger.info("Registering new module: {}".format(resource.name)) diff --git a/tests/resources/test_branch.py b/tests/resources/test_branch.py index f46a1ee82..c0feea86b 100644 --- a/tests/resources/test_branch.py +++ b/tests/resources/test_branch.py @@ -107,6 +107,29 @@ def test_branch_get(session, collection, branch_path): assert session.last_call == FakeCall(method='GET', path=branch_path, params={'page': 1, 'per_page': 1, 'root': root_id, 'version': version}) +def test_branch_get_not_found(session, collection, branch_path): + # Given + session.set_response({"response": []}) + + # When + with pytest.raises(NotFound): + collection.get(root_id=uuid.uuid4(), version=1) + + +def test_branch_get_by_version_id(session, collection, branch_path): + # Given + branch_data = BranchDataFactory() + version_id = branch_data['id'] + session.set_response(branch_data) + + # When + branch = collection.get_by_version_id(version_id=version_id) + + # Then + assert session.num_calls == 1 + assert session.last_call == FakeCall(method='GET', path=f"{branch_path}/{version_id}") + + def test_branch_list(session, collection, branch_path): # Given branch_count = 5 diff --git a/tests/resources/test_design_workflows.py b/tests/resources/test_design_workflows.py index b5279d9e0..d37eb0443 100644 --- a/tests/resources/test_design_workflows.py +++ b/tests/resources/test_design_workflows.py @@ -43,7 +43,8 @@ def collection(branch_data, collection_without_branch) -> DesignWorkflowCollecti @pytest.fixture def workflow(collection, branch_data, design_workflow_dict) -> DesignWorkflow: - design_workflow_dict["branch_id"] = branch_data["id"] + design_workflow_dict["branch_root_id"] = branch_data["metadata"]["root_id"] + design_workflow_dict["branch_version"] = branch_data["metadata"]["version"] collection.session.set_response(branch_data) workflow = collection.build(design_workflow_dict) @@ -71,12 +72,6 @@ def workflow_path(collection, workflow=None): path = f'{path}/{workflow.uid}' return path -def branches_path(collection, branch_id=None): - path = f'/projects/{collection.project_id}/branches' - if branch_id: - path = f'{path}/{branch_id}' - return path - def assert_workflow(actual, expected, *, include_branch=False): assert actual.name == expected.name assert actual.description == expected.description @@ -86,7 +81,7 @@ def assert_workflow(actual, expected, *, include_branch=False): assert actual.predictor_version == expected.predictor_version assert actual.project_id == expected.project_id if include_branch: - assert actual.branch_id == expected.branch_id + assert actual._branch_id == expected._branch_id assert actual.branch_root_id == expected.branch_root_id assert actual.branch_version == expected.branch_version @@ -99,29 +94,22 @@ def test_basic_methods(workflow, collection, design_workflow_dict): @pytest.mark.parametrize("optional_args", all_combination_lengths(OPTIONAL_ARGS)) def test_register(session, branch_data, workflow_minimal, collection, optional_args): workflow = workflow_minimal - branch_id = branch_data['id'] - branch_data_get_resp = {"response": [branch_data]} - branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version - } + branch_root_id = branch_data['metadata']['root_id'] + branch_version = branch_data['metadata']['version'] # Set a random value for all optional args selected for this run. for name, factory in optional_args: setattr(workflow, name, factory()) # Given - post_dict = {**workflow.dump(), "branch_id": str(branch_id)} - session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data) + post_dict = {**workflow.dump(), "branch_root_id": str(branch_root_id), "branch_version": branch_version} + session.set_responses({**post_dict, 'status_description': 'status'}) # When new_workflow = collection.register(workflow) # Then - assert session.calls == [ - FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params), - FakeCall(method='POST', path=workflow_path(collection), json=post_dict), - FakeCall(method='GET', path=branches_path(collection, branch_id)), - ] + assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)] assert new_workflow.branch_root_id == collection.branch_root_id assert new_workflow.branch_version == collection.branch_version @@ -133,23 +121,18 @@ def test_register_conflicting_branches(session, branch_data, workflow, collectio old_branch_root_id = uuid.uuid4() workflow.branch_root_id = old_branch_root_id assert workflow.branch_root_id != collection.branch_root_id + + new_branch_root_id = str(branch_data["metadata"]["root_id"]) + new_branch_version = branch_data["metadata"]["version"] - branch_data_get_resp = {"response": [branch_data]} - branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version - } - post_dict = {**workflow.dump(), "branch_id": str(branch_data["id"])} - session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data) + post_dict = {**workflow.dump(), "branch_root_id": new_branch_root_id, "branch_version": new_branch_version} + session.set_responses({**post_dict, 'status_description': 'status'}) # When new_workflow = collection.register(workflow) # Then - assert session.calls == [ - FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params), - FakeCall(method='POST', path=workflow_path(collection), json=post_dict), - FakeCall(method='GET', path=branches_path(collection, branch_data["id"])), - ] + assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)] assert workflow.branch_root_id == old_branch_root_id assert new_workflow.branch_root_id == collection.branch_root_id @@ -180,10 +163,10 @@ def test_delete(collection): def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollection): - branch_data_get_resp = {"response": [branch_data]} - branch_id = uuid.UUID(branch_data['id']) + branch_root_id = uuid.UUID(branch_data['metadata']['root_id']) + branch_version = branch_data['metadata']['version'] - collection.session.set_responses(branch_data_get_resp, {"response": []}) + collection.session.set_responses({"response": []}) lst = list(collection.list_archived(per_page=10)) assert len(lst) == 0 @@ -192,7 +175,7 @@ def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollecti assert collection.session.last_call == FakeCall( method='GET', path=expected_path, - params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch': branch_id}, + params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch_root_id': branch_root_id, 'branch_version': branch_version}, json=None ) @@ -213,17 +196,10 @@ def test_missing_project(design_workflow_dict): def test_update(session, branch_data, workflow, collection_without_branch): # Given - branch_data_get_resp = {"response": [branch_data]} - branch_data_get_params = { - 'page': 1, 'per_page': 1, 'root': str(workflow.branch_root_id), 'version': workflow.branch_version - } - post_dict = workflow.dump() session.set_responses( - branch_data_get_resp, {"per_page": 1, "next": "", "response": []}, {**post_dict, 'status_description': 'status'}, - branch_data ) # When @@ -232,20 +208,16 @@ def test_update(session, branch_data, workflow, collection_without_branch): # Then executions_path = f'/projects/{collection_without_branch.project_id}/design-workflows/{workflow.uid}/executions' assert session.calls == [ - FakeCall(method='GET', path=branches_path(collection_without_branch), params=branch_data_get_params), FakeCall(method='GET', path=executions_path, params={'page': 1, 'per_page': 100}), FakeCall(method='PUT', path=workflow_path(collection_without_branch, workflow), json=post_dict), - FakeCall(method='GET', path=branches_path(collection_without_branch, branch_data["id"])), ] assert_workflow(new_workflow, workflow) def test_update_failure_with_existing_execution(session, branch_data, workflow, collection_without_branch, design_execution_dict): - branch_data_get_resp = {"response": [branch_data]} workflow.branch_root_id = uuid.uuid4() post_dict = workflow.dump() session.set_responses( - branch_data_get_resp, {"per_page": 1, "next": "", "response": [design_execution_dict]}, {**post_dict, 'status_description': 'status'}) diff --git a/tests/resources/test_workflow.py b/tests/resources/test_workflow.py index 2151ac551..dafffd672 100644 --- a/tests/resources/test_workflow.py +++ b/tests/resources/test_workflow.py @@ -62,12 +62,8 @@ def test_build_design_workflow(session, basic_design_workflow_data): def test_list_workflows(session, basic_design_workflow_data): #Given - branch_data = BranchDataFactory() - branch_data_get_resp = {"response": [branch_data]} - session.set_response(branch_data) - workflow_collection = DesignWorkflowCollection(project_id=uuid.uuid4(), session=session) - session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20}, branch_data) + session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20}) # When workflows = list(workflow_collection.list(per_page=20)) @@ -75,6 +71,6 @@ def test_list_workflows(session, basic_design_workflow_data): # Then expected_design_call = FakeCall(method='GET', path='/projects/{}/modules'.format(workflow_collection.project_id), params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'}) - assert 2 == session.num_calls + assert 1 == session.num_calls assert len(workflows) == 1 assert isinstance(workflows[0], DesignWorkflow) diff --git a/tests/seeding/test_find_or_create.py b/tests/seeding/test_find_or_create.py index ac89532de..897f82529 100644 --- a/tests/seeding/test_find_or_create.py +++ b/tests/seeding/test_find_or_create.py @@ -353,17 +353,11 @@ def test_create_or_update_unique_found_design_workflow(session): dw2_dict = DesignWorkflowDataFactory(branch_root_id=root_id, branch_version=version) dw3_dict = DesignWorkflowDataFactory() session.set_responses( - # Build (setup) - branch_data, # Find the model's branch root ID and version # List - {"response": [branch_data]}, # Find the collection's branch version ID {"response": [dw1_dict, dw2_dict, dw3_dict]}, # Return the design workflows - branch_data, branch_data, branch_data, # Lookup the branch root ID and version of each design workflow. # Update - {"response": [branch_data]}, # Lookup the module's branch version ID {"response": []}, # Check if there are any executions dw2_dict, # Return the updated design workflow - branch_data # Lookup the updated design workflow branch root ID and version ) collection = LocalDesignWorkflowCollection(project_id=uuid4(), session=session, branch_root_id=root_id, branch_version=version)