Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/citrine/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.6.0"
__version__ = "3.7.0"
35 changes: 6 additions & 29 deletions src/citrine/informatics/workflows/design_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,17 @@ class DesignWorkflow(Resource['DesignWorkflow'], Workflow, AIResourceMetadata):
predictor_id = properties.Optional(properties.UUID, 'predictor_id')
predictor_version = properties.Optional(
properties.Union([properties.Integer(), properties.String()]), 'predictor_version')
_branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id')
branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id')
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version')
""":Optional[int]: Version number of the branch that contains this workflow."""

status_description = properties.String('status_description', serializable=False)
""":str: more detailed description of the workflow's status"""
typ = properties.String('type', default='DesignWorkflow', deserializable=False)

_branch_root_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_root_id',
serializable=False, deserializable=False)
""":Optional[UUID]: Root ID of the branch that contains this workflow."""
_branch_version: Optional[int] = properties.Optional(properties.Integer, 'branch_version',
serializable=False, deserializable=False)
""":Optional[int]: Version number of the branch that contains this workflow."""
_branch_id: Optional[UUID] = properties.Optional(properties.UUID, 'branch_id',
serializable=False)

def __init__(self,
name: str,
Expand All @@ -68,25 +67,3 @@ def design_executions(self) -> DesignExecutionCollection:
raise AttributeError('Cannot initialize execution without project reference!')
return DesignExecutionCollection(
project_id=self.project_id, session=self._session, workflow_id=self.uid)

@property
def branch_root_id(self):
"""Retrieve the root ID of the branch this workflow is on."""
return self._branch_root_id

@branch_root_id.setter
def branch_root_id(self, value):
"""Set the root ID of the branch this workflow is on."""
self._branch_root_id = value
self._branch_id = None

@property
def branch_version(self):
"""Retrieve the version of the branch this workflow is on."""
return self._branch_version

@branch_version.setter
def branch_version(self, value):
"""Set the version of the branch this workflow is on."""
self._branch_version = value
self._branch_id = None
47 changes: 10 additions & 37 deletions src/citrine/resources/design_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from citrine._rest.collection import Collection
from citrine._session import Session
from citrine.exceptions import NotFound
from citrine.informatics.workflows import DesignWorkflow
from citrine.resources.response import Response
from functools import partial
Expand All @@ -31,25 +30,6 @@ def __init__(self,
self.branch_root_id = branch_root_id
self.branch_version = branch_version

def _resolve_branch_root_and_version(self, workflow):
from citrine.resources.branch import BranchCollection

workflow_copy = deepcopy(workflow)
bc = BranchCollection(self.project_id, self.session)
branch = bc.get_by_version_id(version_id=workflow_copy._branch_id)
workflow_copy._branch_root_id = branch.root_id
workflow_copy._branch_version = branch.version
return workflow_copy

def _resolve_branch_id(self, root_id, version):
from citrine.resources.branch import BranchCollection

if root_id and version:
bc = BranchCollection(self.project_id, self.session)
branch = bc.get(root_id=root_id, version=version)
return branch.uid
return None

def register(self, model: DesignWorkflow) -> DesignWorkflow:
"""
Upload a new design workflow.
Expand Down Expand Up @@ -77,15 +57,15 @@ def register(self, model: DesignWorkflow) -> DesignWorkflow:
'project.design_workflows.register().')
raise RuntimeError(msg)
else:
# branch_id is in the body of design workflow endpoints, so it must be serialized.
# This means the collection branch_id might not match the workflow branch_id. The
# collection should win out, since the user is explicitly referencing the branch
# represented by this collection.
# To avoid modifying the parameter, and to ensure the only change is the branch_id, we
# deepcopy, modify, then register it.
# branch_root_id and branch_version are in the body of design workflow endpoints, so
# they must be serialized. This means the collection fields might not match the
# workflow fields. The collection should win out, since the user is explicitly
# referencing the branch represented by this collection.
# To avoid modifying the parameter, and to ensure the only changes are the
# branch_root_id and branch_version, we deepcopy, modify, then register it.
model_copy = deepcopy(model)
model_copy._branch_id = self._resolve_branch_id(self.branch_root_id,
self.branch_version)
model_copy.branch_root_id = self.branch_root_id
model_copy.branch_version = self.branch_version
return super().register(model_copy)

def build(self, data: dict) -> DesignWorkflow:
Expand All @@ -104,7 +84,6 @@ def build(self, data: dict) -> DesignWorkflow:

"""
workflow = DesignWorkflow.build(data)
workflow = self._resolve_branch_root_and_version(workflow)
workflow._session = self.session
workflow.project_id = self.project_id
return workflow
Expand Down Expand Up @@ -137,13 +116,6 @@ def update(self, model: DesignWorkflow) -> DesignWorkflow:
raise ValueError('Cannot update a design workflow unless its branch_root_id and '
'branch_version are set.')

try:
model._branch_id = self._resolve_branch_id(model.branch_root_id,
model.branch_version)
except NotFound:
raise ValueError('Cannot update a design workflow unless its branch_root_id and '
'branch_version exists.')

# If executions have already been done, warn about future behavior change
executions = model.design_executions.list()
if next(executions, None) is not None:
Expand Down Expand Up @@ -197,7 +169,8 @@ def _fetch_page(self,
additional_params: Optional[dict] = None,
) -> Tuple[Iterable[dict], str]:
params = additional_params or {}
params["branch"] = self._resolve_branch_id(self.branch_root_id, self.branch_version)
params["branch_root_id"] = self.branch_root_id
params["branch_version"] = self.branch_version
return super()._fetch_page(path=path,
fetch_func=fetch_func,
page=page,
Expand Down
4 changes: 2 additions & 2 deletions src/citrine/seeding/find_or_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def create_or_update(*,
# Locally created design workflows likely won't have a branch ID but
# need one to be updated.
if isinstance(old_resource, DesignWorkflow):
new_resource._branch_root_id = old_resource.branch_root_id
new_resource._branch_version = old_resource.branch_version
new_resource.branch_root_id = old_resource.branch_root_id
new_resource.branch_version = old_resource.branch_version
return collection.update(new_resource)
else:
logger.info("Registering new module: {}".format(resource.name))
Expand Down
23 changes: 23 additions & 0 deletions tests/resources/test_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ def test_branch_get(session, collection, branch_path):
assert session.last_call == FakeCall(method='GET', path=branch_path, params={'page': 1, 'per_page': 1, 'root': root_id, 'version': version})


def test_branch_get_not_found(session, collection, branch_path):
# Given
session.set_response({"response": []})

# When
with pytest.raises(NotFound):
collection.get(root_id=uuid.uuid4(), version=1)


def test_branch_get_by_version_id(session, collection, branch_path):
# Given
branch_data = BranchDataFactory()
version_id = branch_data['id']
session.set_response(branch_data)

# When
branch = collection.get_by_version_id(version_id=version_id)

# Then
assert session.num_calls == 1
assert session.last_call == FakeCall(method='GET', path=f"{branch_path}/{version_id}")


def test_branch_list(session, collection, branch_path):
# Given
branch_count = 5
Expand Down
64 changes: 18 additions & 46 deletions tests/resources/test_design_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def collection(branch_data, collection_without_branch) -> DesignWorkflowCollecti

@pytest.fixture
def workflow(collection, branch_data, design_workflow_dict) -> DesignWorkflow:
design_workflow_dict["branch_id"] = branch_data["id"]
design_workflow_dict["branch_root_id"] = branch_data["metadata"]["root_id"]
design_workflow_dict["branch_version"] = branch_data["metadata"]["version"]

collection.session.set_response(branch_data)
workflow = collection.build(design_workflow_dict)
Expand Down Expand Up @@ -71,12 +72,6 @@ def workflow_path(collection, workflow=None):
path = f'{path}/{workflow.uid}'
return path

def branches_path(collection, branch_id=None):
path = f'/projects/{collection.project_id}/branches'
if branch_id:
path = f'{path}/{branch_id}'
return path

def assert_workflow(actual, expected, *, include_branch=False):
assert actual.name == expected.name
assert actual.description == expected.description
Expand All @@ -86,7 +81,7 @@ def assert_workflow(actual, expected, *, include_branch=False):
assert actual.predictor_version == expected.predictor_version
assert actual.project_id == expected.project_id
if include_branch:
assert actual.branch_id == expected.branch_id
assert actual._branch_id == expected._branch_id
assert actual.branch_root_id == expected.branch_root_id
assert actual.branch_version == expected.branch_version

Expand All @@ -99,29 +94,22 @@ def test_basic_methods(workflow, collection, design_workflow_dict):
@pytest.mark.parametrize("optional_args", all_combination_lengths(OPTIONAL_ARGS))
def test_register(session, branch_data, workflow_minimal, collection, optional_args):
workflow = workflow_minimal
branch_id = branch_data['id']
branch_data_get_resp = {"response": [branch_data]}
branch_data_get_params = {
'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version
}
branch_root_id = branch_data['metadata']['root_id']
branch_version = branch_data['metadata']['version']

# Set a random value for all optional args selected for this run.
for name, factory in optional_args:
setattr(workflow, name, factory())

# Given
post_dict = {**workflow.dump(), "branch_id": str(branch_id)}
session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data)
post_dict = {**workflow.dump(), "branch_root_id": str(branch_root_id), "branch_version": branch_version}
session.set_responses({**post_dict, 'status_description': 'status'})

# When
new_workflow = collection.register(workflow)

# Then
assert session.calls == [
FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params),
FakeCall(method='POST', path=workflow_path(collection), json=post_dict),
FakeCall(method='GET', path=branches_path(collection, branch_id)),
]
assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)]

assert new_workflow.branch_root_id == collection.branch_root_id
assert new_workflow.branch_version == collection.branch_version
Expand All @@ -133,23 +121,18 @@ def test_register_conflicting_branches(session, branch_data, workflow, collectio
old_branch_root_id = uuid.uuid4()
workflow.branch_root_id = old_branch_root_id
assert workflow.branch_root_id != collection.branch_root_id

new_branch_root_id = str(branch_data["metadata"]["root_id"])
new_branch_version = branch_data["metadata"]["version"]

branch_data_get_resp = {"response": [branch_data]}
branch_data_get_params = {
'page': 1, 'per_page': 1, 'root': str(collection.branch_root_id), 'version': collection.branch_version
}
post_dict = {**workflow.dump(), "branch_id": str(branch_data["id"])}
session.set_responses(branch_data_get_resp, {**post_dict, 'status_description': 'status'}, branch_data)
post_dict = {**workflow.dump(), "branch_root_id": new_branch_root_id, "branch_version": new_branch_version}
session.set_responses({**post_dict, 'status_description': 'status'})

# When
new_workflow = collection.register(workflow)

# Then
assert session.calls == [
FakeCall(method='GET', path=branches_path(collection), params=branch_data_get_params),
FakeCall(method='POST', path=workflow_path(collection), json=post_dict),
FakeCall(method='GET', path=branches_path(collection, branch_data["id"])),
]
assert session.calls == [FakeCall(method='POST', path=workflow_path(collection), json=post_dict)]

assert workflow.branch_root_id == old_branch_root_id
assert new_workflow.branch_root_id == collection.branch_root_id
Expand Down Expand Up @@ -180,10 +163,10 @@ def test_delete(collection):


def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollection):
branch_data_get_resp = {"response": [branch_data]}
branch_id = uuid.UUID(branch_data['id'])
branch_root_id = uuid.UUID(branch_data['metadata']['root_id'])
branch_version = branch_data['metadata']['version']

collection.session.set_responses(branch_data_get_resp, {"response": []})
collection.session.set_responses({"response": []})

lst = list(collection.list_archived(per_page=10))
assert len(lst) == 0
Expand All @@ -192,7 +175,7 @@ def test_list_archived(branch_data, workflow, collection: DesignWorkflowCollecti
assert collection.session.last_call == FakeCall(
method='GET',
path=expected_path,
params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch': branch_id},
params={'page': 1, 'per_page': 10, 'filter': "archived eq 'true'", 'branch_root_id': branch_root_id, 'branch_version': branch_version},
json=None
)

Expand All @@ -213,17 +196,10 @@ def test_missing_project(design_workflow_dict):

def test_update(session, branch_data, workflow, collection_without_branch):
# Given
branch_data_get_resp = {"response": [branch_data]}
branch_data_get_params = {
'page': 1, 'per_page': 1, 'root': str(workflow.branch_root_id), 'version': workflow.branch_version
}

post_dict = workflow.dump()
session.set_responses(
branch_data_get_resp,
{"per_page": 1, "next": "", "response": []},
{**post_dict, 'status_description': 'status'},
branch_data
)

# When
Expand All @@ -232,20 +208,16 @@ def test_update(session, branch_data, workflow, collection_without_branch):
# Then
executions_path = f'/projects/{collection_without_branch.project_id}/design-workflows/{workflow.uid}/executions'
assert session.calls == [
FakeCall(method='GET', path=branches_path(collection_without_branch), params=branch_data_get_params),
FakeCall(method='GET', path=executions_path, params={'page': 1, 'per_page': 100}),
FakeCall(method='PUT', path=workflow_path(collection_without_branch, workflow), json=post_dict),
FakeCall(method='GET', path=branches_path(collection_without_branch, branch_data["id"])),
]
assert_workflow(new_workflow, workflow)


def test_update_failure_with_existing_execution(session, branch_data, workflow, collection_without_branch, design_execution_dict):
branch_data_get_resp = {"response": [branch_data]}
workflow.branch_root_id = uuid.uuid4()
post_dict = workflow.dump()
session.set_responses(
branch_data_get_resp,
{"per_page": 1, "next": "", "response": [design_execution_dict]},
{**post_dict, 'status_description': 'status'})

Expand Down
8 changes: 2 additions & 6 deletions tests/resources/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,15 @@ def test_build_design_workflow(session, basic_design_workflow_data):

def test_list_workflows(session, basic_design_workflow_data):
#Given
branch_data = BranchDataFactory()
branch_data_get_resp = {"response": [branch_data]}
session.set_response(branch_data)

workflow_collection = DesignWorkflowCollection(project_id=uuid.uuid4(), session=session)
session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20}, branch_data)
session.set_responses({'response': [basic_design_workflow_data], 'page': 1, 'per_page': 20})

# When
workflows = list(workflow_collection.list(per_page=20))

# Then
expected_design_call = FakeCall(method='GET', path='/projects/{}/modules'.format(workflow_collection.project_id),
params={'per_page': 20, 'module_type': 'DESIGN_WORKFLOW'})
assert 2 == session.num_calls
assert 1 == session.num_calls
assert len(workflows) == 1
assert isinstance(workflows[0], DesignWorkflow)
6 changes: 0 additions & 6 deletions tests/seeding/test_find_or_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,11 @@ def test_create_or_update_unique_found_design_workflow(session):
dw2_dict = DesignWorkflowDataFactory(branch_root_id=root_id, branch_version=version)
dw3_dict = DesignWorkflowDataFactory()
session.set_responses(
# Build (setup)
branch_data, # Find the model's branch root ID and version
# List
{"response": [branch_data]}, # Find the collection's branch version ID
{"response": [dw1_dict, dw2_dict, dw3_dict]}, # Return the design workflows
branch_data, branch_data, branch_data, # Lookup the branch root ID and version of each design workflow.
# Update
{"response": [branch_data]}, # Lookup the module's branch version ID
{"response": []}, # Check if there are any executions
dw2_dict, # Return the updated design workflow
branch_data # Lookup the updated design workflow branch root ID and version
)

collection = LocalDesignWorkflowCollection(project_id=uuid4(), session=session, branch_root_id=root_id, branch_version=version)
Expand Down