From 14d64692c681742162ebfb210663b3fc3d783679 Mon Sep 17 00:00:00 2001 From: "Alden S. Page" Date: Tue, 2 Jul 2024 14:31:10 -0400 Subject: [PATCH 1/3] Add support for query parameters to BigQueryCheckOperator (#40556) Remove unnecessary space --- airflow/providers/google/cloud/operators/bigquery.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index d3f79c9bbde4d..6b4276a87d5e2 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -222,6 +222,12 @@ class BigQueryCheckOperator( :param deferrable: Run operator in the deferrable mode. :param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job. + :param query_params: a list of dictionary containing query parameter types and + values, passed to BigQuery. The structure of dictionary should look like + 'queryParameters' in Google BigQuery Jobs API: + https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs. + For example, [{ 'name': 'corpus', 'parameterType': { 'type': 'STRING' }, + 'parameterValue': { 'value': 'romeoandjuliet' } }]. (templated) """ template_fields: Sequence[str] = ( @@ -229,6 +235,7 @@ class BigQueryCheckOperator( "gcp_conn_id", "impersonation_chain", "labels", + "query_params", ) template_ext: Sequence[str] = (".sql",) ui_color = BigQueryUIColors.CHECK.value @@ -246,6 +253,7 @@ def __init__( encryption_configuration: dict | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, + query_params: list | None = None, **kwargs, ) -> None: super().__init__(sql=sql, **kwargs) @@ -257,6 +265,7 @@ def __init__( self.encryption_configuration = encryption_configuration self.deferrable = deferrable self.poll_interval = poll_interval + self.query_params = query_params def _submit_job( self, @@ -265,6 +274,8 @@ def _submit_job( ) -> BigQueryJob: """Submit a new job and get the job id for polling the status using Trigger.""" configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}} + if self.query_params: + configuration["query"]["queryParameters"] = query_params self.include_encryption_configuration(configuration, "query") From 2e2b3e955697503b67bf310f08741e19af2be02b Mon Sep 17 00:00:00 2001 From: "Alden S. Page" Date: Tue, 2 Jul 2024 15:24:30 -0400 Subject: [PATCH 2/3] Add a unit test for BigQueryCheckOperator query params; fix missing 'self' reference --- .../google/cloud/operators/bigquery.py | 2 +- .../google/cloud/operators/test_bigquery.py | 31 ++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 6b4276a87d5e2..43131c549a3c3 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -275,7 +275,7 @@ def _submit_job( """Submit a new job and get the job id for polling the status using Trigger.""" configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}} if self.query_params: - configuration["query"]["queryParameters"] = query_params + configuration["query"]["queryParameters"] = self.query_params self.include_encryption_configuration(configuration, "query") diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index febcfe48712e7..38115afcf37a3 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -24,7 +24,7 @@ import pandas as pd import pytest -from google.cloud.bigquery import DEFAULT_RETRY +from google.cloud.bigquery import DEFAULT_RETRY, ScalarQueryParameter from google.cloud.exceptions import Conflict from openlineage.client.facet import ErrorMessageRunFacet, ExternalQueryRunFacet, SqlJobFacet from openlineage.client.run import Dataset @@ -2293,6 +2293,35 @@ def test_bigquery_check_operator_async_finish_before_deferred( mock_defer.assert_not_called() mock_validate_records.assert_called_once_with((1, 2, 3)) + @pytest.mark.db_test + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator._validate_records") + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + def test_bigquery_check_operator_query_parameters_passing( + self, mock_hook, mock_validate_records, create_task_instance_of_operator + ): + job_id = "123456" + hash_ = "hash" + real_job_id = f"{job_id}_{hash_}" + query_params = [ScalarQueryParameter("test_param", "INT64", 1)] + + mocked_job = MagicMock(job_id=real_job_id, error_result=False) + mocked_job.result.return_value = iter([(1, 2, 3)]) # mock rows generator + mock_hook.return_value.insert_job.return_value = mocked_job + mock_hook.return_value.insert_job.return_value.running.return_value = False + + ti = create_task_instance_of_operator( + BigQueryCheckOperator, + dag_id="dag_id", + task_id="bq_check_operator_query_params_job", + sql="SELECT * FROM any WHERE test_param = @test_param", + location=TEST_DATASET_LOCATION, + deferrable=True, + query_params=query_params + ) + + ti.task.execute(MagicMock()) + mock_validate_records.assert_called_once_with((1, 2, 3)) + @pytest.mark.db_test @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") def test_bigquery_check_operator_async_finish_with_error_before_deferred( From ee9e2d5ef682ca14ec7097011dd759d6130220bb Mon Sep 17 00:00:00 2001 From: "Alden S. Page" Date: Tue, 2 Jul 2024 17:06:56 -0400 Subject: [PATCH 3/3] Fix lint (#40558) --- tests/providers/google/cloud/operators/test_bigquery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 38115afcf37a3..d49e75c95070c 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -2316,7 +2316,7 @@ def test_bigquery_check_operator_query_parameters_passing( sql="SELECT * FROM any WHERE test_param = @test_param", location=TEST_DATASET_LOCATION, deferrable=True, - query_params=query_params + query_params=query_params, ) ti.task.execute(MagicMock())