From 22c12d66fcdcf4c8b36357c03d085b50395b61b3 Mon Sep 17 00:00:00 2001 From: subham611 Date: Thu, 4 Apr 2024 15:31:18 +0530 Subject: [PATCH 1/7] Update ACL during job reset --- .../providers/databricks/hooks/databricks.py | 11 +++ .../databricks/operators/databricks.py | 4 ++ .../databricks/operators/test_databricks.py | 72 +++++++++++++++++++ 3 files changed, 87 insertions(+) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index c052164214269..78677934528d3 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -51,6 +51,7 @@ REPAIR_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/repair") OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output") CANCEL_ALL_RUNS_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel-all") +UPDATE_PERMISSION_ENDPOINT = ("PATCH", "/api/2.0/permissions/jobs") INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install") UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall") @@ -655,6 +656,16 @@ def get_repo_by_path(self, path: str) -> str | None: return None + def update_job_permission(self, json: dict[str, Any]) -> dict: + """ + Update databricks job permission + + :param json: acl payload + :return: + """ + return self._do_api_call(UPDATE_PERMISSION_ENDPOINT, json) + + def test_connection(self) -> tuple[bool, str]: """Test the Databricks connectivity from UI.""" hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 247d810a6bf3d..98877f8b5f7d8 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -316,6 +316,10 @@ def execute(self, context: Context) -> int: if job_id is None: return self._hook.create_job(self.json) self._hook.reset_job(str(job_id), self.json) + if "access_control_list" in self.json and self.json["access_control_list"] is not None: + acl_json = {"access_control_list": self.json["access_control_list"]} + self._hook.update_job_permission(normalise_json_content(acl_json)) + return job_id diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index f2a3441f435cf..e175022eb931d 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -504,6 +504,78 @@ def test_exec_reset(self, db_mock_class): db_mock.reset_job.assert_called_once_with(JOB_ID, expected) assert JOB_ID == return_result + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_update_job_permission(self, db_mock_class): + """ + Test job permission update. + """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + "access_control_list": ACCESS_CONTROL_LIST, + } + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.find_job_id_by_name.return_value = JOB_ID + + op.execute({}) + + expected = utils.normalise_json_content( + { + "access_control_list": ACCESS_CONTROL_LIST + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.update_job_permission.assert_called_once_with(expected) + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_update_job_permission_with_empty_acl(self, db_mock_class): + """ + Test job permission update. + """ + json = { + "name": JOB_NAME, + "tags": TAGS, + "tasks": TASKS, + "job_clusters": JOB_CLUSTERS, + "email_notifications": EMAIL_NOTIFICATIONS, + "webhook_notifications": WEBHOOK_NOTIFICATIONS, + "timeout_seconds": TIMEOUT_SECONDS, + "schedule": SCHEDULE, + "max_concurrent_runs": MAX_CONCURRENT_RUNS, + "git_source": GIT_SOURCE, + } + op = DatabricksCreateJobsOperator(task_id=TASK_ID, json=json) + db_mock = db_mock_class.return_value + db_mock.find_job_id_by_name.return_value = JOB_ID + + op.execute({}) + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksCreateJobsOperator", + ) + + db_mock.update_job_permission.assert_not_called() + class TestDatabricksSubmitRunOperator: def test_init_with_notebook_task_named_parameters(self): From 7c238cf4bbd1378e195189a38d80c72fdb553724 Mon Sep 17 00:00:00 2001 From: subham611 Date: Sat, 6 Apr 2024 00:23:43 +0530 Subject: [PATCH 2/7] Fix static checks --- airflow/providers/databricks/hooks/databricks.py | 2 +- tests/providers/databricks/operators/test_databricks.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 78677934528d3..ea4a90285f7d4 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -660,7 +660,7 @@ def update_job_permission(self, json: dict[str, Any]) -> dict: """ Update databricks job permission - :param json: acl payload + :param json: payload :return: """ return self._do_api_call(UPDATE_PERMISSION_ENDPOINT, json) diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index e175022eb931d..278f95bf016b4 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -528,11 +528,8 @@ def test_exec_update_job_permission(self, db_mock_class): op.execute({}) - expected = utils.normalise_json_content( - { - "access_control_list": ACCESS_CONTROL_LIST - } - ) + expected = utils.normalise_json_content({"access_control_list": ACCESS_CONTROL_LIST}) + db_mock_class.assert_called_once_with( DEFAULT_CONN_ID, retry_limit=op.databricks_retry_limit, From dd84555ac55e84ce55453b39757ce72960adb67d Mon Sep 17 00:00:00 2001 From: subham611 Date: Mon, 8 Apr 2024 20:19:44 +0530 Subject: [PATCH 3/7] Fix static checks --- airflow/providers/databricks/hooks/databricks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index ea4a90285f7d4..f62e7bb0fa0ec 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -661,11 +661,10 @@ def update_job_permission(self, json: dict[str, Any]) -> dict: Update databricks job permission :param json: payload - :return: + :return: json containing permission specification """ return self._do_api_call(UPDATE_PERMISSION_ENDPOINT, json) - def test_connection(self) -> tuple[bool, str]: """Test the Databricks connectivity from UI.""" hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id) From 0d08818a18a173c4c2a06fbf7193858b6292648e Mon Sep 17 00:00:00 2001 From: subham611 Date: Mon, 8 Apr 2024 20:20:58 +0530 Subject: [PATCH 4/7] Fix static check --- airflow/providers/databricks/hooks/databricks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index f62e7bb0fa0ec..6911f1a64cc06 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -663,6 +663,7 @@ def update_job_permission(self, json: dict[str, Any]) -> dict: :param json: payload :return: json containing permission specification """ + return self._do_api_call(UPDATE_PERMISSION_ENDPOINT, json) def test_connection(self) -> tuple[bool, str]: From ac0a9204ec1bcf4c249e562610022c263b8e8f8e Mon Sep 17 00:00:00 2001 From: subham611 Date: Mon, 8 Apr 2024 20:50:56 +0530 Subject: [PATCH 5/7] Fix static check --- airflow/providers/databricks/hooks/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 6911f1a64cc06..e07b658a8311b 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -658,7 +658,7 @@ def get_repo_by_path(self, path: str) -> str | None: def update_job_permission(self, json: dict[str, Any]) -> dict: """ - Update databricks job permission + Update databricks job permission. :param json: payload :return: json containing permission specification From 421a40f08e2c4c94934b739d16d629a2882c937c Mon Sep 17 00:00:00 2001 From: subham611 Date: Mon, 8 Apr 2024 21:16:32 +0530 Subject: [PATCH 6/7] Fix static checks --- airflow/providers/databricks/hooks/databricks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index e07b658a8311b..8074566e8b976 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -663,7 +663,6 @@ def update_job_permission(self, json: dict[str, Any]) -> dict: :param json: payload :return: json containing permission specification """ - return self._do_api_call(UPDATE_PERMISSION_ENDPOINT, json) def test_connection(self) -> tuple[bool, str]: From 11e54ab93cc2b25bc661008520092cccfbe46e24 Mon Sep 17 00:00:00 2001 From: subham611 Date: Tue, 9 Apr 2024 13:33:12 +0530 Subject: [PATCH 7/7] Refactor code --- airflow/providers/databricks/operators/databricks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 98877f8b5f7d8..3d95c61f6abb3 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -316,8 +316,8 @@ def execute(self, context: Context) -> int: if job_id is None: return self._hook.create_job(self.json) self._hook.reset_job(str(job_id), self.json) - if "access_control_list" in self.json and self.json["access_control_list"] is not None: - acl_json = {"access_control_list": self.json["access_control_list"]} + if (access_control_list := self.json.get("access_control_list")) is not None: + acl_json = {"access_control_list": access_control_list} self._hook.update_job_permission(normalise_json_content(acl_json)) return job_id