diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index 54d4896e1277f..5aac2f485e445 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -585,6 +585,8 @@ class RdsCreateDbInstanceOperator(RdsBaseOperator): :param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True) """ + template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs") + def __init__( self, *, @@ -637,6 +639,8 @@ class RdsDeleteDbInstanceOperator(RdsBaseOperator): :param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True) """ + template_fields = ("db_instance_identifier", "rds_kwargs") + def __init__( self, *, diff --git a/tests/system/providers/amazon/aws/example_rds_event.py b/tests/system/providers/amazon/aws/example_rds_event.py index bbb24037a207f..1f1bc6ec08f9d 100644 --- a/tests/system/providers/amazon/aws/example_rds_event.py +++ b/tests/system/providers/amazon/aws/example_rds_event.py @@ -23,9 +23,10 @@ from airflow import DAG from airflow.decorators import task from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.operators.rds import ( + RdsCreateDbInstanceOperator, RdsCreateEventSubscriptionOperator, + RdsDeleteDbInstanceOperator, RdsDeleteEventSubscriptionOperator, ) from airflow.utils.trigger_rule import TriggerRule @@ -41,35 +42,6 @@ def create_sns_topic(env_id) -> str: return boto3.client('sns').create_topic(Name=f'{env_id}-topic')['TopicArn'] -@task -def create_rds_instance(db_name, instance_name) -> None: - rds_client = RdsHook().get_conn() - rds_client.create_db_instance( - DBName=db_name, - DBInstanceIdentifier=instance_name, - AllocatedStorage=20, - DBInstanceClass='db.t3.micro', - Engine='postgres', - MasterUsername='username', - # NEVER store your production password in plaintext in a DAG like this. - # Use Airflow Secrets or a secret manager for this in production. - MasterUserPassword='rds_password', - ) - - rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=instance_name) - - -@task(trigger_rule=TriggerRule.ALL_DONE) -def delete_db_instance(instance_name) -> None: - rds_client = RdsHook().get_conn() - rds_client.delete_db_instance( - DBInstanceIdentifier=instance_name, - SkipFinalSnapshot=True, - ) - - rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=instance_name) - - @task(trigger_rule=TriggerRule.ALL_DONE) def delete_sns_topic(topic_arn) -> None: boto3.client('sns').delete_topic(TopicArn=topic_arn) @@ -90,6 +62,21 @@ def delete_sns_topic(topic_arn) -> None: sns_topic = create_sns_topic(test_context[ENV_ID_KEY]) + create_db_instance = RdsCreateDbInstanceOperator( + task_id="create_db_instance", + db_instance_identifier=rds_instance_name, + db_instance_class="db.t4g.micro", + engine="postgres", + rds_kwargs={ + "MasterUsername": "rds_username", + # NEVER store your production password in plaintext in a DAG like this. + # Use Airflow Secrets or a secret manager for this in production. + "MasterUserPassword": "rds_password", + "AllocatedStorage": 20, + "DBName": rds_db_name, + }, + ) + # [START howto_operator_rds_create_event_subscription] create_subscription = RdsCreateEventSubscriptionOperator( task_id='create_subscription', @@ -108,16 +95,23 @@ def delete_sns_topic(topic_arn) -> None: ) # [END howto_operator_rds_delete_event_subscription] + delete_db_instance = RdsDeleteDbInstanceOperator( + task_id="delete_db_instance", + db_instance_identifier=rds_instance_name, + rds_kwargs={"SkipFinalSnapshot": True}, + trigger_rule=TriggerRule.ALL_DONE, + ) + chain( # TEST SETUP test_context, sns_topic, - create_rds_instance(rds_db_name, rds_instance_name), + create_db_instance, # TEST BODY create_subscription, delete_subscription, # TEST TEARDOWN - delete_db_instance(rds_instance_name), + delete_db_instance, delete_sns_topic(sns_topic), ) diff --git a/tests/system/providers/amazon/aws/example_rds_export.py b/tests/system/providers/amazon/aws/example_rds_export.py index 356789b76a9b1..f0090470c7e03 100644 --- a/tests/system/providers/amazon/aws/example_rds_export.py +++ b/tests/system/providers/amazon/aws/example_rds_export.py @@ -24,7 +24,9 @@ from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.operators.rds import ( RdsCancelExportTaskOperator, + RdsCreateDbInstanceOperator, RdsCreateDbSnapshotOperator, + RdsDeleteDbInstanceOperator, RdsDeleteDbSnapshotOperator, RdsStartExportTaskOperator, ) @@ -44,41 +46,12 @@ ) -@task -def create_rds_instance(db_name: str, instance_name: str) -> None: - rds_client = RdsHook().conn - rds_client.create_db_instance( - DBName=db_name, - DBInstanceIdentifier=instance_name, - AllocatedStorage=20, - DBInstanceClass='db.t3.micro', - Engine='postgres', - MasterUsername='username', - # NEVER store your production password in plaintext in a DAG like this. - # Use Airflow Secrets or a secret manager for this in production. - MasterUserPassword='rds_password', - ) - - rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=instance_name) - - @task def get_snapshot_arn(snapshot_name: str) -> str: result = RdsHook().conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_name) return result['DBSnapshots'][0]['DBSnapshotArn'] -@task(trigger_rule=TriggerRule.ALL_DONE) -def delete_rds_instance(instance_name) -> None: - rds_client = RdsHook().get_conn() - rds_client.delete_db_instance( - DBInstanceIdentifier=instance_name, - SkipFinalSnapshot=True, - ) - - rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=instance_name) - - with DAG( dag_id=DAG_ID, schedule='@once', @@ -101,6 +74,21 @@ def delete_rds_instance(instance_name) -> None: bucket_name=bucket_name, ) + create_db_instance = RdsCreateDbInstanceOperator( + task_id="create_db_instance", + db_instance_identifier=rds_instance_name, + db_instance_class="db.t4g.micro", + engine="postgres", + rds_kwargs={ + "MasterUsername": "rds_username", + # NEVER store your production password in plaintext in a DAG like this. + # Use Airflow Secrets or a secret manager for this in production. + "MasterUserPassword": "rds_password", + "AllocatedStorage": 20, + "DBName": rds_db_name, + }, + ) + create_snapshot = RdsCreateDbSnapshotOperator( task_id='create_snapshot', db_type='instance', @@ -160,11 +148,18 @@ def delete_rds_instance(instance_name) -> None: force_delete=True, ) + delete_db_instance = RdsDeleteDbInstanceOperator( + task_id="delete_db_instance", + db_instance_identifier=rds_instance_name, + rds_kwargs={"SkipFinalSnapshot": True}, + trigger_rule=TriggerRule.ALL_DONE, + ) + chain( # TEST SETUP test_context, create_bucket, - create_rds_instance(rds_db_name, rds_instance_name), + create_db_instance, create_snapshot, await_snapshot, snapshot_arn, @@ -175,7 +170,7 @@ def delete_rds_instance(instance_name) -> None: # TEST TEARDOWN delete_snapshot, delete_bucket, - delete_rds_instance(rds_instance_name), + delete_db_instance, ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/amazon/aws/example_rds_snapshot.py b/tests/system/providers/amazon/aws/example_rds_snapshot.py index eb634a7099136..05b2f6f22bd88 100644 --- a/tests/system/providers/amazon/aws/example_rds_snapshot.py +++ b/tests/system/providers/amazon/aws/example_rds_snapshot.py @@ -19,12 +19,12 @@ from datetime import datetime from airflow import DAG -from airflow.decorators import task from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.hooks.rds import RdsHook from airflow.providers.amazon.aws.operators.rds import ( RdsCopyDbSnapshotOperator, + RdsCreateDbInstanceOperator, RdsCreateDbSnapshotOperator, + RdsDeleteDbInstanceOperator, RdsDeleteDbSnapshotOperator, ) from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor @@ -36,35 +36,6 @@ sys_test_context_task = SystemTestContextBuilder().build() -@task -def create_rds_instance(db_name, instance_name) -> None: - rds_client = RdsHook().get_conn() - rds_client.create_db_instance( - DBName=db_name, - DBInstanceIdentifier=instance_name, - AllocatedStorage=20, - DBInstanceClass='db.t3.micro', - Engine='postgres', - MasterUsername='username', - # NEVER store your production password in plaintext in a DAG like this. - # Use Airflow Secrets or a secret manager for this in production. - MasterUserPassword='rds_password', - ) - - rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=instance_name) - - -@task(trigger_rule=TriggerRule.ALL_DONE) -def delete_rds_instance(instance_name) -> None: - rds_client = RdsHook().get_conn() - rds_client.delete_db_instance( - DBInstanceIdentifier=instance_name, - SkipFinalSnapshot=True, - ) - - rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=instance_name) - - with DAG( dag_id=DAG_ID, schedule='@once', @@ -79,6 +50,21 @@ def delete_rds_instance(instance_name) -> None: rds_snapshot_name = f'{test_context[ENV_ID_KEY]}-snapshot' rds_snapshot_copy_name = f'{rds_snapshot_name}-copy' + create_db_instance = RdsCreateDbInstanceOperator( + task_id="create_db_instance", + db_instance_identifier=rds_instance_name, + db_instance_class="db.t4g.micro", + engine="postgres", + rds_kwargs={ + "MasterUsername": "rds_username", + # NEVER store your production password in plaintext in a DAG like this. + # Use Airflow Secrets or a secret manager for this in production. + "MasterUserPassword": "rds_password", + "AllocatedStorage": 20, + "DBName": rds_db_name, + }, + ) + # [START howto_operator_rds_create_db_snapshot] create_snapshot = RdsCreateDbSnapshotOperator( task_id='create_snapshot', @@ -127,10 +113,17 @@ def delete_rds_instance(instance_name) -> None: db_snapshot_identifier=rds_snapshot_copy_name, ) + delete_db_instance = RdsDeleteDbInstanceOperator( + task_id="delete_db_instance", + db_instance_identifier=rds_instance_name, + rds_kwargs={"SkipFinalSnapshot": True}, + trigger_rule=TriggerRule.ALL_DONE, + ) + chain( # TEST SETUP test_context, - create_rds_instance(rds_db_name, rds_instance_name), + create_db_instance, # TEST BODY create_snapshot, snapshot_sensor, @@ -139,7 +132,7 @@ def delete_rds_instance(instance_name) -> None: # TEST TEARDOWN snapshot_copy_sensor, delete_snapshot_copy, - delete_rds_instance(rds_instance_name), + delete_db_instance, ) from tests.system.utils.watcher import watcher