diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py index b06e001a0701c..8213b9af149bc 100644 --- a/airflow/providers/apache/spark/hooks/spark_submit.py +++ b/airflow/providers/apache/spark/hooks/spark_submit.py @@ -82,9 +82,9 @@ class SparkSubmitHook(BaseHook, LoggingMixin): (will overwrite any spark_binary defined in the connection's extra JSON) :param properties_file: Path to a file from which to load extra properties. If not specified, this will look for conf/spark-defaults.conf. - :param queue: The name of the YARN queue to which the application is submitted. + :param yarn_queue: The name of the YARN queue to which the application is submitted. (will overwrite any yarn queue defined in the connection's extra JSON) - :param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as an client. + :param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
(will overwrite any deployment mode defined in the connection's extra JSON) :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying on keytab for Kerberos login @@ -165,7 +165,7 @@ def __init__( verbose: bool = False, spark_binary: str | None = None, properties_file: str | None = None, - queue: str | None = None, + yarn_queue: str | None = None, deploy_mode: str | None = None, *, use_krb5ccache: bool = False, @@ -201,7 +201,7 @@ def __init__( self._kubernetes_driver_pod: str | None = None self.spark_binary = spark_binary self._properties_file = properties_file - self._queue = queue + self._yarn_queue = yarn_queue self._deploy_mode = deploy_mode self._connection = self._resolve_connection() self._is_yarn = "yarn" in self._connection["master"] @@ -231,7 +231,7 @@ def _resolve_connection(self) -> dict[str, Any]: # Build from connection master or default to yarn if not available conn_data = { "master": "yarn", - "queue": None, + "queue": None, # yarn queue "deploy_mode": None, "spark_binary": self.spark_binary or DEFAULT_SPARK_BINARY, "namespace": None, @@ -248,7 +248,7 @@ def _resolve_connection(self) -> dict[str, Any]: # Determine optional yarn queue from the extra field extra = conn.extra_dejson - conn_data["queue"] = self._queue if self._queue else extra.get("queue") + conn_data["queue"] = self._yarn_queue if self._yarn_queue else extra.get("queue") conn_data["deploy_mode"] = self._deploy_mode if self._deploy_mode else extra.get("deploy-mode") if not self.spark_binary: self.spark_binary = extra.get("spark-binary", DEFAULT_SPARK_BINARY) diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py index 62f7918fcf993..281919b2b14c8 100644 --- a/airflow/providers/apache/spark/operators/spark_submit.py +++ b/airflow/providers/apache/spark/operators/spark_submit.py @@ -72,7 +72,7 @@ class SparkSubmitOperator(BaseOperator): (will overwrite any spark_binary defined in the 
connection's extra JSON) :param properties_file: Path to a file from which to load extra properties. If not specified, this will look for conf/spark-defaults.conf. - :param queue: The name of the YARN queue to which the application is submitted. + :param yarn_queue: The name of the YARN queue to which the application is submitted. (will overwrite any yarn queue defined in the connection's extra JSON) :param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client. (will overwrite any deployment mode defined in the connection's extra JSON) @@ -129,7 +129,7 @@ def __init__( verbose: bool = False, spark_binary: str | None = None, properties_file: str | None = None, - queue: str | None = None, + yarn_queue: str | None = None, deploy_mode: str | None = None, use_krb5ccache: bool = False, **kwargs: Any, @@ -161,7 +161,7 @@ def __init__( self._verbose = verbose self._spark_binary = spark_binary self.properties_file = properties_file - self._queue = queue + self._yarn_queue = yarn_queue self._deploy_mode = deploy_mode self._hook: SparkSubmitHook | None = None self._conn_id = conn_id @@ -206,7 +206,7 @@ def _get_hook(self) -> SparkSubmitHook: verbose=self._verbose, spark_binary=self._spark_binary, properties_file=self.properties_file, - queue=self._queue, + yarn_queue=self._yarn_queue, deploy_mode=self._deploy_mode, use_krb5ccache=self._use_krb5ccache, ) diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py index 4f8cb7d5486de..4e1cbf89d3735 100644 --- a/tests/providers/apache/spark/operators/test_spark_submit.py +++ b/tests/providers/apache/spark/operators/test_spark_submit.py @@ -67,8 +67,9 @@ class TestSparkSubmitOperator: "args should keep embedded spaces", ], "use_krb5ccache": True, - "queue": "yarn_dev_queue2", + "yarn_queue": "yarn_dev_queue2", "deploy_mode": "client2", + "queue": "airflow_custom_queue", } def setup_method(self): @@ -122,10 +123,11 
@@ def test_execute(self): "args should keep embedded spaces", ], "spark_binary": "sparky", - "queue": "yarn_dev_queue2", + "yarn_queue": "yarn_dev_queue2", "deploy_mode": "client2", "use_krb5ccache": True, "properties_file": "conf/spark-custom.conf", + "queue": "airflow_custom_queue", } assert conn_id == operator._conn_id @@ -153,10 +155,11 @@ def test_execute(self): assert expected_dict["driver_memory"] == operator._driver_memory assert expected_dict["application_args"] == operator.application_args assert expected_dict["spark_binary"] == operator._spark_binary - assert expected_dict["queue"] == operator._queue assert expected_dict["deploy_mode"] == operator._deploy_mode assert expected_dict["properties_file"] == operator.properties_file assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache + assert expected_dict["queue"] == operator.queue + assert expected_dict["yarn_queue"] == operator._yarn_queue @pytest.mark.db_test def test_spark_submit_cmd_connection_overrides(self): @@ -168,18 +171,21 @@ def test_spark_submit_cmd_connection_overrides(self): task_id="spark_submit_job", spark_binary="sparky", dag=self.dag, **config ) cmd = " ".join(operator._get_hook()._build_spark_submit_command("test")) - assert "--queue yarn_dev_queue2" in cmd + assert "--queue yarn_dev_queue2" in cmd # yarn queue assert "--deploy-mode client2" in cmd assert "sparky" in cmd + assert operator.queue == "airflow_custom_queue" # airflow queue - # if we don't pass any overrides in arguments - config["queue"] = None + # if we don't pass any overrides in arguments, default values + config["yarn_queue"] = None config["deploy_mode"] = None + config.pop("queue", None) # using default airflow queue operator2 = SparkSubmitOperator(task_id="spark_submit_job2", dag=self.dag, **config) cmd2 = " ".join(operator2._get_hook()._build_spark_submit_command("test")) - assert "--queue root.default" in cmd2 + assert "--queue root.default" in cmd2 # yarn queue assert "--deploy-mode client2" not in 
cmd2 assert "spark-submit" in cmd2 + assert operator2.queue == "default" # airflow queue @pytest.mark.db_test def test_render_template(self):