diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py index 8f59ff36b3..824734d294 100644 --- a/src/sagemaker/experiments/experiment.py +++ b/src/sagemaker/experiments/experiment.py @@ -15,6 +15,8 @@ import time +from botocore.exceptions import ClientError + from sagemaker.apiutils import _base_types from sagemaker.experiments.trial import _Trial from sagemaker.experiments.trial_component import _TrialComponent @@ -154,10 +156,7 @@ def _load_or_create( Returns: experiments.experiment._Experiment: A SageMaker `_Experiment` object """ - sagemaker_client = sagemaker_session.sagemaker_client try: - experiment = _Experiment.load(experiment_name, sagemaker_session) - except sagemaker_client.exceptions.ResourceNotFound: experiment = _Experiment.create( experiment_name=experiment_name, display_name=display_name, @@ -165,6 +164,13 @@ def _load_or_create( tags=tags, sagemaker_session=sagemaker_session, ) + except ClientError as ce: + error_code = ce.response["Error"]["Code"] + error_message = ce.response["Error"]["Message"] + if not (error_code == "ValidationException" and "already exists" in error_message): + raise ce + # already exists + experiment = _Experiment.load(experiment_name, sagemaker_session) return experiment def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py index 146b24f18b..ce8deb4862 100644 --- a/src/sagemaker/experiments/trial.py +++ b/src/sagemaker/experiments/trial.py @@ -13,6 +13,8 @@ """Contains the Trial class.""" from __future__ import absolute_import +from botocore.exceptions import ClientError + from sagemaker.apiutils import _base_types from sagemaker.experiments import _api_types from sagemaker.experiments.trial_component import _TrialComponent @@ -268,8 +270,20 @@ def _load_or_create( Returns: experiments.trial._Trial: A SageMaker `_Trial` object """ - sagemaker_client = sagemaker_session.sagemaker_client try: + trial = _Trial.create( + experiment_name=experiment_name, + trial_name=trial_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + except ClientError as ce: + error_code = ce.response["Error"]["Code"] + error_message = ce.response["Error"]["Message"] + if not (error_code == "ValidationException" and "already exists" in error_message): + raise ce + # already exists trial = _Trial.load(trial_name, sagemaker_session) if trial.experiment_name != experiment_name: # pylint: disable=no-member raise ValueError( @@ -278,12 +292,4 @@ def _load_or_create( trial.experiment_name # pylint: disable=no-member ) ) - except sagemaker_client.exceptions.ResourceNotFound: - trial = _Trial.create( - experiment_name=experiment_name, - trial_name=trial_name, - display_name=display_name, - tags=tags, - sagemaker_session=sagemaker_session, - ) return trial diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py index e5701b2119..a85a3dba8a 100644 --- a/src/sagemaker/experiments/trial_component.py +++ b/src/sagemaker/experiments/trial_component.py @@ -15,6 +15,8 @@ import time +from botocore.exceptions import ClientError + from sagemaker.apiutils import _base_types from sagemaker.experiments import _api_types from sagemaker.experiments._api_types import TrialComponentSearchResult @@ -326,16 +328,20 @@ def _load_or_create( experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. bool: A boolean variable indicating whether the trail component already exists """ - sagemaker_client = sagemaker_session.sagemaker_client is_existed = False try: - run_tc = _TrialComponent.load(trial_component_name, sagemaker_session) - is_existed = True - except sagemaker_client.exceptions.ResourceNotFound: run_tc = _TrialComponent.create( trial_component_name=trial_component_name, display_name=display_name, tags=tags, sagemaker_session=sagemaker_session, ) + except ClientError as ce: + error_code = ce.response["Error"]["Code"] + error_message = ce.response["Error"]["Message"] + if not (error_code == "ValidationException" and "already exists" in error_message): + raise ce + # already exists + run_tc = _TrialComponent.load(trial_component_name, sagemaker_session) + is_existed = True return run_tc, is_existed diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 9f0bd3293c..2f5191bc30 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3223,14 +3223,11 @@ def create_model_package_from_containers( def submit(request): if model_package_group_name is not None: - try: - self.sagemaker_client.describe_model_package_group( - ModelPackageGroupName=request["ModelPackageGroupName"] - ) - except ClientError: - self.sagemaker_client.create_model_package_group( + _create_resource( + lambda: self.sagemaker_client.create_model_package_group( ModelPackageGroupName=request["ModelPackageGroupName"] ) + ) return self.sagemaker_client.create_model_package(**request) return self._intercept_create_request( @@ -3918,33 +3915,22 @@ def endpoint_from_model_data( name = name or name_from_image(image_uri) model_vpc_config = vpc_utils.sanitize(model_vpc_config) - if _deployment_entity_exists( - lambda: self.sagemaker_client.describe_endpoint(EndpointName=name) - ): - raise ValueError( - 'Endpoint with name "{}" already exists; please pick a different name.'.format(name) - ) + primary_container = container_def( + image_uri=image_uri, + model_data_url=model_s3_location, + env=model_environment_vars, + ) - if not _deployment_entity_exists( - lambda: self.sagemaker_client.describe_model(ModelName=name) - ): - primary_container = container_def( - image_uri=image_uri, - model_data_url=model_s3_location, - env=model_environment_vars, - ) - self.create_model( - name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config - ) + self.create_model( + name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config + ) data_capture_config_dict = None if data_capture_config is not None: data_capture_config_dict = data_capture_config._to_request_dict() - if not _deployment_entity_exists( - lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name) - ): - self.create_endpoint_config( + _create_resource( + lambda: self.create_endpoint_config( name=name, model_name=name, initial_instance_count=initial_instance_count, @@ -3952,8 +3938,17 @@ def endpoint_from_model_data( accelerator_type=accelerator_type, data_capture_config_dict=data_capture_config_dict, ) + ) + + # to make change backwards compatible + response = _create_resource( + lambda: self.create_endpoint(endpoint_name=name, config_name=name, wait=wait) + ) + if not response: + raise ValueError( + 'Endpoint with name "{}" already exists; please pick a different name.'.format(name) + ) - self.create_endpoint(endpoint_name=name, config_name=name, wait=wait) return name def endpoint_from_production_variants( @@ -5452,34 +5447,54 @@ def _deployment_entity_exists(describe_fn): return False +def _create_resource(create_fn): + """Call create function and accepts/pass when resource already exists. + + This is a helper function to use an existing resource if found when creating. + + Args: + create_fn: Create resource function. + + Returns: + (bool): True if new resource was created, False if resource already exists. + """ + try: + create_fn() + # create function succeeded, resource does not exist already + return True + except ClientError as ce: + error_code = ce.response["Error"]["Code"] + error_message = ce.response["Error"]["Message"] + already_exists_exceptions = ["ValidationException", "ResourceInUse"] + already_exists_msg_patterns = ["Cannot create already existing", "already exists"] + if not ( + error_code in already_exists_exceptions + and any(p in error_message for p in already_exists_msg_patterns) + ): + raise ce + # no new resource created as resource already exists + return False + + def _train_done(sagemaker_client, job_name, last_desc): """Placeholder docstring""" in_progress_statuses = ["InProgress", "Created"] - for _ in retries( - max_retry_count=10, # 10*30 = 5min - exception_message_prefix="Waiting for schedule to leave 'Pending' status", - seconds_to_sleep=30, - ): - try: - desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) - status = desc["TrainingJobStatus"] + desc = sagemaker_client.describe_training_job(TrainingJobName=job_name) + status = desc["TrainingJobStatus"] - if secondary_training_status_changed(desc, last_desc): - print() - print(secondary_training_status_message(desc, last_desc), end="") - else: - print(".", end="") - sys.stdout.flush() + if secondary_training_status_changed(desc, last_desc): + print() + print(secondary_training_status_message(desc, last_desc), end="") + else: + print(".", end="") + sys.stdout.flush() - if status in in_progress_statuses: - return desc, False + if status in in_progress_statuses: + return desc, False - print() - return desc, True - except botocore.exceptions.ClientError as err: - if err.response["Error"]["Code"] == "AccessDeniedException": - pass + print() + return desc, True def _processing_job_status(sagemaker_client, job_name): @@ -5799,19 +5814,54 @@ def _deploy_done(sagemaker_client, endpoint_name): def _wait_until_training_done(callable_fn, desc, poll=5): """Placeholder docstring""" - job_desc, finished = callable_fn(desc) + elapsed_time = 0 + finished = None + job_desc = desc while not finished: - time.sleep(poll) - job_desc, finished = callable_fn(job_desc) + try: + elapsed_time += poll + time.sleep(poll) + job_desc, finished = callable_fn(job_desc) + except botocore.exceptions.ClientError as err: + # For initial 5 mins we accept/pass AccessDeniedException. + # The reason is to await tag propagation to avoid false AccessDenied claims for an + # access policy based on resource tags, The caveat here is for true AccessDenied + # cases the routine will fail after 5 mins + if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300: + LOGGER.warning( + "Received AccessDeniedException. This could mean the IAM role does not " + "have the resource permissions, in which case please add resource access " + "and retry. For cases where the role has tag based resource policy, " + "continuing to wait for tag propagation.." + ) + continue + raise err return job_desc def _wait_until(callable_fn, poll=5): """Placeholder docstring""" - result = callable_fn() + elapsed_time = 0 + result = None while result is None: - time.sleep(poll) - result = callable_fn() + try: + elapsed_time += poll + time.sleep(poll) + result = callable_fn() + except botocore.exceptions.ClientError as err: + # For initial 5 mins we accept/pass AccessDeniedException. + # The reason is to await tag propagation to avoid false AccessDenied claims for an + # access policy based on resource tags, The caveat here is for true AccessDenied + # cases the routine will fail after 5 mins + if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300: + LOGGER.warning( + "Received AccessDeniedException. This could mean the IAM role does not " + "have the resource permissions, in which case please add resource access " + "and retry. For cases where the role has tag based resource policy, " + "continuing to wait for tag propagation.." + ) + continue + raise err return result diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 2f1870f1fc..7da9ced131 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -604,12 +604,17 @@ def retries( ) -def retry_with_backoff(callable_func, num_attempts=8): +def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None): """Retry with backoff until maximum attempts are reached Args: callable_func (callable): The callable function to retry. - num_attempts (int): The maximum number of attempts to retry. + num_attempts (int): The maximum number of attempts to retry.(Default: 8) + botocore_client_error_code (str): The specific Botocore ClientError exception error code + on which to retry on. + If provided other exceptions will be raised directly w/o retry. + If not provided, retry on any exception. + (Default: None) """ if num_attempts < 1: raise ValueError( @@ -619,7 +624,15 @@ def retry_with_backoff(callable_func, num_attempts=8): try: return callable_func() except Exception as ex: # pylint: disable=broad-except - if i == num_attempts - 1: + if not botocore_client_error_code or ( + botocore_client_error_code + and isinstance(ex, botocore.exceptions.ClientError) + and ex.response["Error"]["Code"] # pylint: disable=no-member + == botocore_client_error_code + ): + if i == num_attempts - 1: + raise ex + else: raise ex logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex)) time.sleep(2**i) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index aa9fac6216..95d0702ec8 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -26,6 +26,7 @@ from sagemaker import s3 from sagemaker._studio import _append_project_tags from sagemaker.session import Session +from sagemaker.utils import retry_with_backoff from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep from sagemaker.workflow.entities import ( @@ -241,20 +242,16 @@ def upsert( Returns: response dict from service """ - exists = True try: - self.describe() - except ClientError as e: - err = e.response.get("Error", {}) - if err.get("Code", None) == "ResourceNotFound": - exists = False - else: - raise e - - if not exists: response = self.create(role_arn, description, tags, parallelism_config) - else: + except ClientError as ce: + error_code = ce.response["Error"]["Code"] + error_message = ce.response["Error"]["Message"] + if not (error_code == "ValidationException" and "already exists" in error_message): + raise ce + # already exists response = self.update(role_arn, description) + # add new tags to existing resource if tags is not None: old_tags = self.sagemaker_session.sagemaker_client.list_tags( ResourceArn=response["PipelineArn"] @@ -310,7 +307,12 @@ def start( update_args(kwargs, PipelineParameters=parameters) return self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs) update_args(kwargs, PipelineParameters=format_start_parameters(parameters)) - response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs) + + # retry on AccessDeniedException to cover case of tag propagation delay + response = retry_with_backoff( + lambda: self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs), + botocore_client_error_code="AccessDeniedException", + ) return _PipelineExecution( arn=response["PipelineExecutionArn"], sagemaker_session=self.sagemaker_session, diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index a26d8c9101..eb429e5e79 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -28,7 +28,7 @@ from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.sparkml.model import SparkMLModel -from sagemaker.utils import sagemaker_timestamp +from sagemaker.utils import unique_name_from_base SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model") XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model") @@ -60,7 +60,7 @@ def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"), key_prefix="integ-test-data/xgboost/model", ) - batch_job_name = "test-inference-pipeline-batch-{}".format(sagemaker_timestamp()) + batch_job_name = unique_name_from_base("test-inference-pipeline-batch") sparkml_model = SparkMLModel( model_data=sparkml_model_data, env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA}, @@ -99,7 +99,7 @@ def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type): sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model") - endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-inference-pipeline-deploy") sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), key_prefix="integ-test-data/sparkml/model", @@ -156,7 +156,7 @@ def test_inference_pipeline_model_deploy_and_update_endpoint( ): sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model") xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model") - endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-inference-pipeline-deploy") sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"), key_prefix="integ-test-data/sparkml/model", diff --git a/tests/integ/test_mxnet.py b/tests/integ/test_mxnet.py index dac9221745..3d33816822 100644 --- a/tests/integ/test_mxnet.py +++ b/tests/integ/test_mxnet.py @@ -24,7 +24,7 @@ from sagemaker.mxnet.model import MXNetModel from sagemaker.mxnet.processing import MXNetProcessor from sagemaker.serverless import ServerlessInferenceConfig -from sagemaker.utils import sagemaker_timestamp +from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES from tests.integ.kms_utils import get_or_create_kms_key from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name @@ -98,7 +98,7 @@ def test_framework_processing_job_with_deps( @pytest.mark.release def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type): - endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-attach-deploy") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session) @@ -165,7 +165,7 @@ def test_deploy_model( mxnet_inference_latest_py_version, cpu_instance_type, ): - endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -200,7 +200,7 @@ def test_register_model_package( mxnet_inference_latest_py_version, cpu_instance_type, ): - endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -216,7 +216,7 @@ def test_register_model_package( sagemaker_session=sagemaker_session, framework_version=mxnet_inference_latest_version, ) - model_package_name = "register-model-package-{}".format(sagemaker_timestamp()) + model_package_name = unique_name_from_base("register-model-package") model_pkg = model.register( content_types=["application/json"], response_types=["application/json"], @@ -239,13 +239,13 @@ def test_register_model_package_versioned( mxnet_inference_latest_py_version, cpu_instance_type, ): - endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=mxnet_training_job ) - model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp()) + model_package_group_name = unique_name_from_base("register-model-package") sagemaker_session.sagemaker_client.create_model_package_group( ModelPackageGroupName=model_package_group_name ) @@ -287,7 +287,7 @@ def test_deploy_model_with_tags_and_kms( mxnet_inference_latest_py_version, cpu_instance_type, ): - endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -347,7 +347,7 @@ def test_deploy_model_and_update_endpoint( cpu_instance_type, alternative_cpu_instance_type, ): - endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -395,7 +395,7 @@ def test_deploy_model_with_accelerator( mxnet_eia_latest_py_version, cpu_instance_type, ): - endpoint_name = "test-mxnet-deploy-model-ei-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model-ei") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -426,7 +426,7 @@ def test_deploy_model_with_serverless_inference_config( mxnet_inference_latest_version, mxnet_inference_latest_py_version, ): - endpoint_name = "test-mxnet-deploy-model-serverless-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-deploy-model-serverless") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -465,7 +465,7 @@ def test_async_fit( mxnet_inference_latest_py_version, cpu_instance_type, ): - endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-mxnet-attach-deploy") with timeout(minutes=5): script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py") diff --git a/tests/integ/test_pytorch.py b/tests/integ/test_pytorch.py index 5e3f227e58..3829e1331e 100644 --- a/tests/integ/test_pytorch.py +++ b/tests/integ/test_pytorch.py @@ -20,7 +20,7 @@ from sagemaker.pytorch.model import PyTorchModel from sagemaker.pytorch.processing import PyTorchProcessor from sagemaker.serverless import ServerlessInferenceConfig -from sagemaker.utils import sagemaker_timestamp +from sagemaker.utils import unique_name_from_base from tests.integ import ( test_region, DATA_DIR, @@ -130,7 +130,7 @@ def test_framework_processing_job_with_deps( def test_fit_deploy( pytorch_training_job_with_latest_infernce_version, sagemaker_session, cpu_instance_type ): - endpoint_name = "test-pytorch-sync-fit-attach-deploy{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-pytorch-sync-fit-attach-deploy") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): estimator = PyTorch.attach( pytorch_training_job_with_latest_infernce_version, sagemaker_session=sagemaker_session @@ -180,7 +180,7 @@ def test_deploy_model( pytorch_inference_latest_version, pytorch_inference_latest_py_version, ): - endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-pytorch-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( @@ -210,7 +210,7 @@ def test_deploy_packed_model_with_entry_point_name( pytorch_inference_latest_version, pytorch_inference_latest_py_version, ): - endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-pytorch-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): model_data = sagemaker_session.upload_data(path=PACKED_MODEL) @@ -240,7 +240,7 @@ def test_deploy_model_with_accelerator( pytorch_eia_latest_version, pytorch_eia_latest_py_version, ): - endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-pytorch-deploy-eia") model_data = sagemaker_session.upload_data(path=EIA_MODEL) pytorch = PyTorchModel( model_data, @@ -272,7 +272,7 @@ def test_deploy_model_with_serverless_inference_config( pytorch_inference_latest_version, pytorch_inference_latest_py_version, ): - endpoint_name = "test-pytorch-deploy-model-serverless-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-pytorch-deploy-model-serverless") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( diff --git a/tests/integ/test_sklearn.py b/tests/integ/test_sklearn.py index ad05acdb75..839e601d34 100644 --- a/tests/integ/test_sklearn.py +++ b/tests/integ/test_sklearn.py @@ -20,7 +20,7 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnProcessor -from sagemaker.utils import sagemaker_timestamp, unique_name_from_base +from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name @@ -119,7 +119,7 @@ def test_training_with_network_isolation( "This test should be fixed. Details in https://github.com/aws/sagemaker-python-sdk/pull/968" ) def test_attach_deploy(sklearn_training_job, sagemaker_session, cpu_instance_type): - endpoint_name = "test-sklearn-attach-deploy-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-sklearn-attach-deploy") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): estimator = SKLearn.attach(sklearn_training_job, sagemaker_session=sagemaker_session) @@ -138,7 +138,7 @@ def test_deploy_model( sklearn_latest_version, sklearn_latest_py_version, ): - endpoint_name = "test-sklearn-deploy-model-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-sklearn-deploy-model") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=sklearn_training_job @@ -162,7 +162,7 @@ def test_deploy_model_with_serverless_inference_config( sklearn_latest_version, sklearn_latest_py_version, ): - endpoint_name = "test-sklearn-deploy-model-serverless-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-sklearn-deploy-model-serverless") with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session): desc = sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=sklearn_training_job @@ -192,7 +192,7 @@ def test_async_fit( sklearn_latest_version, sklearn_latest_py_version, ): - endpoint_name = "test-sklearn-attach-deploy-{}".format(sagemaker_timestamp()) + endpoint_name = unique_name_from_base("test-sklearn-attach-deploy") with timeout(minutes=5): training_job_name = _run_mnist_training_job( diff --git a/tests/unit/common.py b/tests/unit/common.py new file mode 100644 index 0000000000..2fb6abd3be --- /dev/null +++ b/tests/unit/common.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + + +from botocore.exceptions import ClientError + + +def _raise_unexpected_client_error(**kwargs): + response = { + "Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."} + } + raise ClientError(error_response=response, operation_name="foo") + + +def _raise_does_not_exist_client_error(**kwargs): + response = {"Error": {"Code": "ValidationException", "Message": "Could not find entity."}} + raise ClientError(error_response=response, operation_name="foo") + + +def _raise_does_already_exists_client_error(**kwargs): + response = {"Error": {"Code": "ValidationException", "Message": "Resource already exists."}} + raise ClientError(error_response=response, operation_name="foo") + + +def _raise_access_denied_client_error(**kwargs): + response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}} + raise ClientError(error_response=response, operation_name="foo") diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py index b0ad55c27f..e6cac54a92 100644 --- a/tests/unit/sagemaker/experiments/test_experiment.py +++ b/tests/unit/sagemaker/experiments/test_experiment.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import botocore.exceptions import pytest import unittest.mock import datetime @@ -78,11 +79,29 @@ def test_delete(sagemaker_session): @patch("sagemaker.experiments.experiment._Experiment.load") -def test_load_or_create_when_exist(mock_load, sagemaker_session): +@patch("sagemaker.experiments.experiment._Experiment.create") +def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session): exp_name = "exp_name" + exists_error = botocore.exceptions.ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Experiment with name (experiment-xyz) already exists.", + } + }, + operation_name="foo", + ) + mock_create.side_effect = exists_error experiment._Experiment._load_or_create( experiment_name=exp_name, sagemaker_session=sagemaker_session ) + mock_create.assert_called_once_with( + experiment_name=exp_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=sagemaker_session, + ) mock_load.assert_called_once_with(exp_name, sagemaker_session) @@ -90,18 +109,10 @@ def test_load_or_create_when_exist(mock_load, sagemaker_session): @patch("sagemaker.experiments.experiment._Experiment.create") def test_load_or_create_when_not_exist(mock_create, mock_load): sagemaker_session = Session() - client = sagemaker_session.sagemaker_client exp_name = "exp_name" - not_found_err = client.exceptions.ResourceNotFound( - error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, - operation_name="foo", - ) - mock_load.side_effect = not_found_err - experiment._Experiment._load_or_create( experiment_name=exp_name, sagemaker_session=sagemaker_session ) - mock_create.assert_called_once_with( experiment_name=exp_name, display_name=None, @@ -109,6 +120,7 @@ def test_load_or_create_when_not_exist(mock_create, mock_load): tags=None, sagemaker_session=sagemaker_session, ) + mock_load.assert_not_called() def test_list_trials_empty(sagemaker_session): diff --git a/tests/unit/sagemaker/experiments/test_trial.py b/tests/unit/sagemaker/experiments/test_trial.py index f6996fefc3..efb29a1161 100644 --- a/tests/unit/sagemaker/experiments/test_trial.py +++ b/tests/unit/sagemaker/experiments/test_trial.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import botocore import pytest import datetime @@ -133,11 +134,21 @@ def test_remove_trial_component(sagemaker_session): @patch("sagemaker.experiments.trial._Trial.load") -def test_load_or_create_when_exist(mock_load): +@patch("sagemaker.experiments.trial._Trial.create") +def test_load_or_create_when_exist(mock_create, mock_load): sagemaker_session = Session() trial_name = "trial_name" exp_name = "exp_name" - + exists_error = botocore.exceptions.ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Experiment with name (experiment-xyz) already exists.", + } + }, + operation_name="foo", + ) + mock_create.side_effect = exists_error # The trial exists and experiment matches mock_load.return_value = _Trial( trial_name=trial_name, @@ -147,6 +158,13 @@ def test_load_or_create_when_exist(mock_load): _Trial._load_or_create( trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session ) + mock_create.assert_called_once_with( + trial_name=trial_name, + experiment_name=exp_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) mock_load.assert_called_once_with(trial_name, sagemaker_session) # The trial exists but experiment does not match @@ -168,14 +186,8 @@ def test_load_or_create_when_exist(mock_load): @patch("sagemaker.experiments.trial._Trial.create") def test_load_or_create_when_not_exist(mock_create, mock_load): sagemaker_session = Session() - client = sagemaker_session.sagemaker_client trial_name = "trial_name" exp_name = "exp_name" - not_found_err = client.exceptions.ResourceNotFound( - error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, - operation_name="foo", - ) - mock_load.side_effect = not_found_err _Trial._load_or_create( trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session @@ -188,6 +200,7 @@ def test_load_or_create_when_not_exist(mock_create, mock_load): tags=None, sagemaker_session=sagemaker_session, ) + mock_load.assert_not_called() def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj): diff --git a/tests/unit/sagemaker/experiments/test_trial_component.py b/tests/unit/sagemaker/experiments/test_trial_component.py index c14663893e..c75a76a556 100644 --- a/tests/unit/sagemaker/experiments/test_trial_component.py +++ b/tests/unit/sagemaker/experiments/test_trial_component.py @@ -17,6 +17,8 @@ from unittest.mock import patch +import botocore + from sagemaker import Session from sagemaker.experiments import _api_types from sagemaker.experiments._api_types import ( @@ -300,11 +302,28 @@ def test_list_trial_components_call_args(sagemaker_session): @patch("sagemaker.experiments.trial_component._TrialComponent.load") -def test_load_or_create_when_exist(mock_load, sagemaker_session): +@patch("sagemaker.experiments.trial_component._TrialComponent.create") +def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session): tc_name = "tc_name" + exists_error = botocore.exceptions.ClientError( + error_response={ + "Error": { + "Code": "ValidationException", + "Message": "Experiment with name (experiment-xyz) already exists.", + } + }, + operation_name="foo", + ) + mock_create.side_effect = exists_error _, is_existed = _TrialComponent._load_or_create( trial_component_name=tc_name, sagemaker_session=sagemaker_session ) + mock_create.assert_called_once_with( + trial_component_name=tc_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) assert is_existed mock_load.assert_called_once_with( tc_name, @@ -316,13 +335,7 @@ def test_load_or_create_when_exist(mock_load, sagemaker_session): @patch("sagemaker.experiments.trial_component._TrialComponent.create") def test_load_or_create_when_not_exist(mock_create, mock_load): sagemaker_session = Session() - client = sagemaker_session.sagemaker_client tc_name = "tc_name" - not_found_err = client.exceptions.ResourceNotFound( - error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, - operation_name="foo", - ) - mock_load.side_effect = not_found_err _, is_existed = _TrialComponent._load_or_create( trial_component_name=tc_name, sagemaker_session=sagemaker_session @@ -335,6 +348,7 @@ def test_load_or_create_when_not_exist(mock_create, mock_load): tags=None, sagemaker_session=sagemaker_session, ) + mock_load.assert_not_called() def test_search(sagemaker_session): diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index f0cb2e5234..813a945cf4 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -33,6 +33,7 @@ from sagemaker.workflow.step_collections import StepCollection from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep from sagemaker.local.local_session import LocalSession +from botocore.exceptions import ClientError @pytest.fixture @@ -173,10 +174,17 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn): ) -def test_pipeline_upsert(sagemaker_session_mock, role_arn): - sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = { - "PipelineArn": "pipeline-arn" - } +def test_pipeline_upsert_resource_already_exists(sagemaker_session_mock, role_arn): + + # case 1: resource already exists + def _raise_does_already_exists_client_error(**kwargs): + response = {"Error": {"Code": "ValidationException", "Message": "Resource already exists."}} + raise ClientError(error_response=response, operation_name="create_pipeline") + + sagemaker_session_mock.sagemaker_client.create_pipeline = Mock( + name="create_pipeline", side_effect=_raise_does_already_exists_client_error + ) + sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = { "PipelineArn": "pipeline-arn" } @@ -197,9 +205,14 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn): ] pipeline.upsert(role_arn=role_arn, tags=tags) - sagemaker_session_mock.sagemaker_client.create_pipeline.assert_not_called() + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with( + PipelineName="MyPipeline", + RoleArn=role_arn, + PipelineDefinition=pipeline.definition(), + Tags=tags, + ) - assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) assert sagemaker_session_mock.sagemaker_client.list_tags.called_with( @@ -212,6 +225,93 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn): ) +def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn): + + # case 2: unexpected failure on create + def _raise_unexpected_client_error(**kwargs): + response = { + "Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."} + } + raise ClientError(error_response=response, operation_name="foo") + + sagemaker_session_mock.sagemaker_client.create_pipeline = Mock( + name="create_pipeline", side_effect=_raise_unexpected_client_error + ) + + sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = { + "PipelineArn": "pipeline-arn" + } + sagemaker_session_mock.sagemaker_client.list_tags.return_value = { + "Tags": [{"Key": "dummy", "Value": "dummy_tag"}] + } + + tags = [ + {"Key": "foo", "Value": "abc"}, + {"Key": "bar", "Value": "xyz"}, + ] + + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[], + sagemaker_session=sagemaker_session_mock, + ) + + with pytest.raises(ClientError): + pipeline.upsert(role_arn=role_arn, tags=tags) + + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with( + PipelineName="MyPipeline", + RoleArn=role_arn, + PipelineDefinition=pipeline.definition(), + Tags=tags, + ) + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called() + sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called() + sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called() + + +def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn): + + # case 3: resource does not exist + sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline") + + sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = { + "PipelineArn": "pipeline-arn" + } + sagemaker_session_mock.sagemaker_client.list_tags.return_value = { + "Tags": [{"Key": "dummy", "Value": "dummy_tag"}] + } + + tags = [ + {"Key": "foo", "Value": "abc"}, + {"Key": "bar", "Value": "xyz"}, + ] + + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[], + sagemaker_session=sagemaker_session_mock, + ) + + try: + pipeline.upsert(role_arn=role_arn, tags=tags) + except ClientError: + assert False, "Unexpected ClientError raised" + + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with( + PipelineName="MyPipeline", + RoleArn=role_arn, + PipelineDefinition=pipeline.definition(), + Tags=tags, + ) + + sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called() + sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called() + sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called() + + def test_pipeline_delete(sagemaker_session_mock): pipeline = Pipeline( name="MyPipeline", diff --git a/tests/unit/test_endpoint_from_model_data.py b/tests/unit/test_endpoint_from_model_data.py index 64804e2f7d..a8b852836a 100644 --- a/tests/unit/test_endpoint_from_model_data.py +++ b/tests/unit/test_endpoint_from_model_data.py @@ -17,6 +17,11 @@ from mock import MagicMock, Mock from mock import patch +from .common import ( + _raise_unexpected_client_error, + _raise_does_already_exists_client_error, + _raise_does_not_exist_client_error, +) import sagemaker ENDPOINT_NAME = "myendpoint" @@ -39,15 +44,6 @@ def sagemaker_session(): ims = sagemaker.Session( sagemaker_client=MagicMock(name="sagemaker_client"), boto_session=boto_mock ) - ims.sagemaker_client.describe_model = Mock( - name="describe_model", side_effect=_raise_does_not_exist_client_error - ) - ims.sagemaker_client.describe_endpoint_config = Mock( - name="describe_endpoint_config", side_effect=_raise_does_not_exist_client_error - ) - ims.sagemaker_client.describe_endpoint = Mock( - name="describe_endpoint", side_effect=_raise_does_not_exist_client_error - ) ims.create_model = Mock(name="create_model") ims.create_endpoint_config = Mock(name="create_endpoint_config") ims.create_endpoint = Mock(name="create_endpoint") @@ -64,16 +60,6 @@ def test_all_defaults_no_existing_entities(name_from_image_mock, sagemaker_sessi role=DEPLOY_ROLE, wait=False, ) - - sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with( - EndpointName=NAME_FROM_IMAGE - ) - sagemaker_session.sagemaker_client.describe_model.assert_called_once_with( - ModelName=NAME_FROM_IMAGE - ) - sagemaker_session.sagemaker_client.describe_endpoint_config.assert_called_once_with( - EndpointConfigName=NAME_FROM_IMAGE - ) sagemaker_session.create_model.assert_called_once_with( name=NAME_FROM_IMAGE, role=DEPLOY_ROLE, container_defs=CONTAINER_DEF, vpc_config=None ) @@ -108,16 +94,6 @@ def test_no_defaults_no_existing_entities(name_from_image_mock, sagemaker_sessio model_vpc_config=VPC_CONFIG, accelerator_type=ACCELERATOR_TYPE, ) - - sagemaker_session.sagemaker_client.describe_endpoint.assert_called_once_with( - EndpointName=ENDPOINT_NAME - ) - sagemaker_session.sagemaker_client.describe_model.assert_called_once_with( - ModelName=ENDPOINT_NAME - ) - sagemaker_session.sagemaker_client.describe_endpoint_config.assert_called_once_with( - EndpointConfigName=ENDPOINT_NAME - ) sagemaker_session.create_model.assert_called_once_with( name=ENDPOINT_NAME, role=DEPLOY_ROLE, @@ -140,27 +116,93 @@ def test_no_defaults_no_existing_entities(name_from_image_mock, sagemaker_sessio @patch("sagemaker.session.name_from_image", return_value=NAME_FROM_IMAGE) def test_model_and_endpoint_config_exist(name_from_image_mock, sagemaker_session): - sagemaker_session.sagemaker_client.describe_model = Mock(name="describe_model") - sagemaker_session.sagemaker_client.describe_endpoint_config = Mock( - name="describe_endpoint_config" + container_def_with_env = CONTAINER_DEF.copy() + + sagemaker_session.create_endpoint_config = Mock( + name="create_endpoint_config", side_effect=_raise_does_already_exists_client_error ) - sagemaker_session.endpoint_from_model_data( - model_s3_location=S3_MODEL_ARTIFACTS, - image_uri=DEPLOY_IMAGE, + try: + sagemaker_session.endpoint_from_model_data( + model_s3_location=S3_MODEL_ARTIFACTS, + image_uri=DEPLOY_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + wait=False, + ) + except ClientError: + assert False, "Unexpected ClientError raised for resource already exists scenario" + + sagemaker_session.create_model.assert_called_once_with( + name=NAME_FROM_IMAGE, + role=None, + container_defs=container_def_with_env, + vpc_config=None, + ) + sagemaker_session.create_endpoint_config.assert_called_once_with( + name=NAME_FROM_IMAGE, + model_name=NAME_FROM_IMAGE, initial_instance_count=INITIAL_INSTANCE_COUNT, instance_type=INSTANCE_TYPE, - wait=False, + accelerator_type=None, + data_capture_config_dict=None, ) - - sagemaker_session.create_model.assert_not_called() - sagemaker_session.create_endpoint_config.assert_not_called() sagemaker_session.create_endpoint.assert_called_once_with( endpoint_name=NAME_FROM_IMAGE, config_name=NAME_FROM_IMAGE, wait=False ) -def test_entity_exists(): +@patch("sagemaker.session.name_from_image", return_value=NAME_FROM_IMAGE) +def test_model_and_endpoint_config_raises_unexpected_error(name_from_image_mock, sagemaker_session): + container_def_with_env = CONTAINER_DEF.copy() + + sagemaker_session.create_endpoint_config = Mock( + name="create_endpoint_config", side_effect=_raise_unexpected_client_error + ) + + with pytest.raises(ClientError): + sagemaker_session.endpoint_from_model_data( + model_s3_location=S3_MODEL_ARTIFACTS, + image_uri=DEPLOY_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + wait=False, + ) + + sagemaker_session.create_model.assert_called_once_with( + name=NAME_FROM_IMAGE, + role=None, + container_defs=container_def_with_env, + vpc_config=None, + ) + sagemaker_session.create_endpoint_config.assert_called_once_with( + name=NAME_FROM_IMAGE, + model_name=NAME_FROM_IMAGE, + initial_instance_count=INITIAL_INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + data_capture_config_dict=None, + ) + sagemaker_session.create_endpoint.assert_not_called() + + +def test_create_resource_entity_exists(): + # _create_resource returns False + assert not sagemaker.session._create_resource(_raise_does_already_exists_client_error) + + +def test_create_resource_unexpected_error(): + # _create_resource returns ClientError + with pytest.raises(ClientError): + sagemaker.session._create_resource(_raise_unexpected_client_error) + + +def test_create_resource_entity_doesnt_exist(): + # _create_resource returns True + assert sagemaker.session._create_resource(lambda: None) + + +def test_deployment_entity_exists(): assert sagemaker.session._deployment_entity_exists(lambda: None) @@ -169,16 +211,5 @@ def test_entity_doesnt_exist(): def test_describe_failure(): - def _raise_unexpected_client_error(): - response = { - "Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."} - } - raise ClientError(error_response=response, operation_name="foo") - with pytest.raises(ClientError): sagemaker.session._deployment_entity_exists(_raise_unexpected_client_error) - - -def _raise_does_not_exist_client_error(**kwargs): - response = {"Error": {"Code": "ValidationException", "Message": "Could not find entity."}} - raise ClientError(error_response=response, operation_name="foo") diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 20c6cb1222..b9be3fb285 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -22,6 +22,7 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, call, mock_open +from .common import _raise_unexpected_client_error import sagemaker from sagemaker import TrainingInput, Session, get_execution_role, exceptions from sagemaker.async_inference import AsyncInferenceConfig @@ -29,6 +30,8 @@ _tuning_job_status, _transform_job_status, _train_done, + _wait_until, + _wait_until_training_done, NOTEBOOK_METADATA_FILE, ) from sagemaker.tuner import WarmStartConfig, WarmStartTypes @@ -2342,6 +2345,81 @@ def test_train_done_in_progress(sagemaker_session): assert training_finished is False +@patch("time.sleep", return_value=None) +def test_wait_until_training_done_raises_other_exception(patched_sleep): + response = {"Error": {"Code": "ValidationException", "Message": "Could not access entity."}} + mock_func = Mock( + name="describe_training_job", + side_effect=ClientError(error_response=response, operation_name="foo"), + ) + desc = "dummy" + with pytest.raises(ClientError) as error: + _wait_until_training_done(mock_func, desc) + + mock_func.assert_called_once() + assert "ValidationException" in str(error) + + +@patch("time.sleep", return_value=None) +def test_wait_until_training_done_tag_propagation(patched_sleep): + response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}} + side_effect_iter = [ClientError(error_response=response, operation_name="foo")] * 3 + side_effect_iter.append(("result", "result")) + mock_func = Mock(name="describe_training_job", side_effect=side_effect_iter) + desc = "dummy" + result = _wait_until_training_done(mock_func, desc) + assert result == "result" + assert mock_func.call_count == 4 + + +@patch("time.sleep", return_value=None) +def test_wait_until_training_done_fail_access_denied_after_5_mins(patched_sleep): + response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}} + side_effect_iter = [ClientError(error_response=response, operation_name="foo")] * 70 + mock_func = Mock(name="describe_training_job", side_effect=side_effect_iter) + desc = "dummy" + with pytest.raises(ClientError) as error: + _wait_until_training_done(mock_func, desc) + + # mock_func should be retried 300(elapsed time)/5(default poll delay) = 60 times + assert mock_func.call_count == 61 + assert "AccessDeniedException" in str(error) + + +@patch("time.sleep", return_value=None) +def test_wait_until_raises_other_exception(patched_sleep): + mock_func = Mock(name="describe_training_job", side_effect=_raise_unexpected_client_error) + with pytest.raises(ClientError) as error: + _wait_until(mock_func) + + mock_func.assert_called_once() + assert "ValidationException" in str(error) + + +@patch("time.sleep", return_value=None) +def test_wait_until_tag_propagation(patched_sleep): + response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}} + side_effect_iter = [ClientError(error_response=response, operation_name="foo")] * 3 + side_effect_iter.append("result") + mock_func = Mock(name="describe_training_job", side_effect=side_effect_iter) + result = _wait_until(mock_func) + assert result == "result" + assert mock_func.call_count == 4 + + +@patch("time.sleep", return_value=None) +def test_wait_until_fail_access_denied_after_5_mins(patched_sleep): + response = {"Error": {"Code": "AccessDeniedException", "Message": "Could not access entity."}} + side_effect_iter = [ClientError(error_response=response, operation_name="foo")] * 70 + mock_func = Mock(name="describe_training_job", side_effect=side_effect_iter) + with pytest.raises(ClientError) as error: + _wait_until(mock_func) + + # mock_func should be retried 300(elapsed time)/5(default poll delay) = 60 times + assert mock_func.call_count == 61 + assert "AccessDeniedException" in str(error) + + DEFAULT_EXPECTED_AUTO_ML_JOB_ARGS = { "AutoMLJobName": JOB_NAME, "InputDataConfig": [ @@ -3354,8 +3432,8 @@ def test_wait_for_inference_recommendations_job_completed(sleep, sm_session_infe 4 == sm_session_inference_recommender.sagemaker_client.describe_inference_recommendations_job.call_count ) - assert 2 == sleep.call_count - sleep.assert_has_calls([call(120), call(120)]) + assert 3 == sleep.call_count + sleep.assert_has_calls([call(120), call(120), call(120)]) def test_wait_for_inference_recommendations_job_failed(sagemaker_session): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 683e20f7c8..be15f0f932 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -795,7 +795,8 @@ def test_to_string(): } -def test_start_waiting(capfd): +@patch("time.sleep", return_value=None) +def test_start_waiting(patched_sleep, capfd): waiting_time = 1 sagemaker.utils._start_waiting(waiting_time) out, _ = capfd.readouterr() @@ -803,7 +804,8 @@ def test_start_waiting(capfd): assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out -def test_retry_with_backoff(): +@patch("time.sleep", return_value=None) +def test_retry_with_backoff(patched_sleep): callable_func = Mock() # Invalid input @@ -824,6 +826,25 @@ def test_retry_with_backoff(): callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val] assert retry_with_backoff(callable_func, 2) == func_return_val + # when retry on specific error, fail for other error on 1st try + func_return_val = "Test Return" + response = {"Error": {"Code": "ValidationException", "Message": "Could not find entity."}} + error = botocore.exceptions.ClientError(error_response=response, operation_name="foo") + callable_func.side_effect = [error, func_return_val] + with pytest.raises(botocore.exceptions.ClientError) as run_err: + retry_with_backoff(callable_func, 2, botocore_client_error_code="AccessDeniedException") + assert "ValidationException" in str(run_err) + + # when retry on specific error, One retry passes + func_return_val = "Test Return" + response = {"Error": {"Code": "AccessDeniedException", "Message": "Access denied."}} + error = botocore.exceptions.ClientError(error_response=response, operation_name="foo") + callable_func.side_effect = [error, func_return_val] + assert ( + retry_with_backoff(callable_func, 2, botocore_client_error_code="AccessDeniedException") + == func_return_val + ) + # No retry callable_func.side_effect = None callable_func.return_value = func_return_val