From b14cfa497e917827b6e10b9021a8292799432fa5 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Tue, 19 May 2026 13:12:47 +1000 Subject: [PATCH 01/14] initial publish --- .../providers/google/cloud/hooks/bigquery.py | 74 ++++++++++++++++--- 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index daf4636198564..aa9accea853f0 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -30,7 +30,9 @@ from copy import deepcopy from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload +from urllib.parse import urlparse +import google_auth_httplib2 import pendulum from aiohttp import ClientSession as ClientSession from asgiref.sync import sync_to_async @@ -166,8 +168,11 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: connection_form_widgets["labels"] = StringField( lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] ) - connection_form_widgets["labels"] = StringField( - lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + connection_form_widgets["http_proxy"] = StringField( + lazy_gettext("HTTP Proxy"), widget=BS3TextFieldWidget() + ) + connection_form_widgets["https_proxy"] = StringField( + lazy_gettext("HTTPS Proxy"), widget=BS3TextFieldWidget() ) return connection_form_widgets @@ -184,6 +189,8 @@ def __init__( api_resource_configs: dict | None | object = _UNSET, impersonation_scopes: str | Sequence[str] | None = None, labels: dict | None | object = _UNSET, + http_proxy: str | None | object = _UNSET, + https_proxy: str | None | object = _UNSET, **kwargs, ) -> None: super().__init__(**kwargs) @@ -219,6 +226,16 @@ def __init__( else: self.labels = labels or {} # type: ignore[assignment] + if http_proxy is _UNSET: + self.http_proxy: str | None = self._get_field("http_proxy", None) + else: + self.http_proxy = http_proxy # type: ignore[assignment] + + if https_proxy is _UNSET: + self.https_proxy: str | None = self._get_field("https_proxy", None) + else: + self.https_proxy = https_proxy # type: ignore[assignment] + self.impersonation_scopes: str | Sequence[str] | None = impersonation_scopes def get_conn(self) -> BigQueryConnection: @@ -240,6 +257,30 @@ def get_conn(self) -> BigQueryConnection: hook=self, ) + def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: + """Return an authorized HTTP object, optionally configured with a proxy.""" + proxy_url = self.http_proxy or self.https_proxy + if not proxy_url: + return super()._authorize() + + import httplib2 + from googleapiclient.http import set_user_agent + + from airflow import version + + parsed = urlparse(proxy_url) + proxy_info = httplib2.ProxyInfo( + proxy_type=httplib2.socks.PROXY_TYPE_HTTP, + proxy_host=parsed.hostname, + proxy_port=parsed.port or 80, + proxy_user=parsed.username, + proxy_pass=parsed.password, + ) + http = set_user_agent( + httplib2.Http(proxy_info=proxy_info), "airflow/" + version.version + ) + return google_auth_httplib2.AuthorizedHttp(self.get_credentials(), http=http) + def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client: """ Get an authenticated BigQuery Client. @@ -247,13 +288,28 @@ def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None :param project_id: Project ID for the project which the client acts on behalf of. :param location: Default location for jobs / datasets / tables. """ - return Client( - client_info=CLIENT_INFO, - project=project_id, - location=location, - credentials=self.get_credentials(), - client_options=self.get_client_options(), - ) + credentials = self.get_credentials() + kwargs: dict[str, Any] = { + "client_info": CLIENT_INFO, + "project": project_id, + "location": location, + "credentials": credentials, + "client_options": self.get_client_options(), + } + if self.http_proxy or self.https_proxy: + import requests + from google.auth.transport.requests import AuthorizedSession, Request + + session = requests.Session() + session.proxies = {} + if self.http_proxy: + session.proxies["http"] = self.http_proxy + if self.https_proxy: + session.proxies["https"] = self.https_proxy + kwargs["_http"] = AuthorizedSession( + credentials, auth_request=Request(session=session) + ) + return Client(**kwargs) def get_uri(self) -> str: """Override from ``DbApiHook`` for ``get_sqlalchemy_engine()``.""" From a8123b5780b9554bbce415c28f9b0653a2148d62 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Wed, 20 May 2026 09:09:18 +1000 Subject: [PATCH 02/14] proxy support + documentation --- .../google/docs/connections/bigquery.rst | 5 + .../providers/google/cloud/hooks/bigquery.py | 178 ++++++++++++++---- 2 files changed, 147 insertions(+), 36 deletions(-) diff --git a/providers/google/docs/connections/bigquery.rst b/providers/google/docs/connections/bigquery.rst index c596d3e86b1c4..dea823cb5cb72 100644 --- a/providers/google/docs/connections/bigquery.rst +++ b/providers/google/docs/connections/bigquery.rst @@ -60,3 +60,8 @@ API Resource Configs Labels A dictionary of labels to be applied on the BigQuery job. + +http_proxy + Optional HTTP proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTP proxy. Can also be supplied via environmental variable or connection extra. +https_proxy + Optional HTTPS proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTPS proxy. Can also be supplied via environmental variable or connection extra. \ No newline at end of file diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index aa9accea853f0..89e5da2653840 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -47,7 +47,12 @@ SchemaField, UnknownJob, ) -from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference +from google.cloud.bigquery.dataset import ( + AccessEntry, + Dataset, + DatasetListItem, + DatasetReference, +) from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import Routine, RoutineReference from google.cloud.bigquery.table import ( @@ -65,7 +70,10 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector -from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException +from airflow.providers.common.compat.sdk import ( + AirflowException, + AirflowOptionalProviderFeatureException, +) from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast @@ -163,10 +171,14 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: validators=[validators.AnyOf(["INTERACTIVE", "BATCH"])], ) connection_form_widgets["api_resource_configs"] = StringField( - lazy_gettext("API Resource Configs"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + lazy_gettext("API Resource Configs"), + widget=BS3TextFieldWidget(), + validators=[ValidJson()], ) connection_form_widgets["labels"] = StringField( - lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + lazy_gettext("Labels"), + widget=BS3TextFieldWidget(), + validators=[ValidJson()], ) connection_form_widgets["http_proxy"] = StringField( lazy_gettext("HTTP Proxy"), widget=BS3TextFieldWidget() @@ -246,7 +258,7 @@ def get_conn(self) -> BigQueryConnection: "v2", http=http_authorized, cache_discovery=False, - client_options=self.get_client_options(), + client_options=getattr(self, "get_client_options", lambda: None)(), ) return BigQueryConnection( service=service, @@ -276,9 +288,7 @@ def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: proxy_user=parsed.username, proxy_pass=parsed.password, ) - http = set_user_agent( - httplib2.Http(proxy_info=proxy_info), "airflow/" + version.version - ) + http = set_user_agent(httplib2.Http(proxy_info=proxy_info), "airflow/" + version.version) return google_auth_httplib2.AuthorizedHttp(self.get_credentials(), http=http) def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client: @@ -294,7 +304,7 @@ def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None "project": project_id, "location": location, "credentials": credentials, - "client_options": self.get_client_options(), + "client_options": getattr(self, "get_client_options", lambda: None)(), } if self.http_proxy or self.https_proxy: import requests @@ -306,9 +316,9 @@ def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None session.proxies["http"] = self.http_proxy if self.https_proxy: session.proxies["https"] = self.https_proxy - kwargs["_http"] = AuthorizedSession( - credentials, auth_request=Request(session=session) - ) + authorized_session = AuthorizedSession(credentials, auth_request=Request(session=session)) + authorized_session.proxies = dict(session.proxies) + kwargs["_http"] = authorized_session return Client(**kwargs) def get_uri(self) -> str: @@ -361,7 +371,11 @@ def _resolve_table_reference( except KeyError: # Something is wrong so we try to build the reference table_resource["tableReference"] = table_resource.get("tableReference", {}) - values = [("projectId", project_id), ("tableId", table_id), ("datasetId", dataset_id)] + values = [ + ("projectId", project_id), + ("tableId", table_id), + ("datasetId", dataset_id), + ] for key, value in values: # Check if value is already present if no use the provided one resolved_value = table_resource["tableReference"].get(key, value) @@ -401,9 +415,18 @@ def _get_pandas_df( if dialect is None: dialect = "legacy" if self.use_legacy_sql else "standard" + if self.http_proxy or self.https_proxy: + return self.get_client().query(sql, timeout=10).to_dataframe(create_bqstorage_client=False) + credentials, project_id = self.get_credentials_and_project_id() - return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + return read_gbq( + sql, + project_id=project_id, + dialect=dialect, + credentials=credentials, + **kwargs, + ) def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame: try: @@ -418,17 +441,35 @@ def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.Dat credentials, project_id = self.get_credentials_and_project_id() - pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) + pandas_df = read_gbq( + sql, + project_id=project_id, + dialect=dialect, + credentials=credentials, + **kwargs, + ) return pl.from_pandas(pandas_df) @overload def get_df( - self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs + self, + sql, + parameters=None, + dialect=None, + *, + df_type: Literal["pandas"] = "pandas", + **kwargs, ) -> pd.DataFrame: ... @overload def get_df( - self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs + self, + sql, + parameters=None, + dialect=None, + *, + df_type: Literal["polars"], + **kwargs, ) -> pl.DataFrame: ... def get_df( @@ -623,7 +664,9 @@ def create_empty_dataset( ) # dataset_reference has no param but we can fallback to default value self.log.info( - "%s was not specified in `dataset_reference`. Will use default value %s.", param, value + "%s was not specified in `dataset_reference`. Will use default value %s.", + param, + value, ) dataset_reference["datasetReference"][param] = value @@ -729,7 +772,10 @@ def update_table( """ fields = fields or list(table_resource.keys()) table_resource = self._resolve_table_reference( - table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, ) table = Table.from_api_repr(table_resource) @@ -785,7 +831,13 @@ def insert_all( The default value is false, which indicates the task should not fail even if any insertion errors occur. """ - self.log.info("Inserting %s row(s) into table %s:%s.%s", len(rows), project_id, dataset_id, table_id) + self.log.info( + "Inserting %s row(s) into table %s:%s.%s", + len(rows), + project_id, + dataset_id, + table_id, + ) table_ref = TableReference(dataset_ref=DatasetReference(project_id, dataset_id), table_id=table_id) bq_client = self.get_client(project_id=project_id) @@ -811,7 +863,12 @@ def insert_all( if fail_on_error: raise AirflowException(f"BigQuery job failed. Error was: {error_msg}") else: - self.log.info("All row(s) inserted successfully: %s:%s.%s", project_id, dataset_id, table_id) + self.log.info( + "All row(s) inserted successfully: %s:%s.%s", + project_id, + dataset_id, + table_id, + ) @GoogleBaseHook.fallback_to_default_project_id def update_dataset( @@ -960,7 +1017,11 @@ def run_grant_dataset_view_access( view_access = AccessEntry( role=None, entity_type="view", - entity_id={"projectId": view_project, "datasetId": view_dataset, "tableId": view_table}, + entity_id={ + "projectId": view_project, + "datasetId": view_dataset, + "tableId": view_table, + }, ) dataset = self.get_dataset(project_id=project_id, dataset_id=source_dataset) @@ -977,7 +1038,9 @@ def run_grant_dataset_view_access( ) dataset.access_entries += [view_access] dataset = self.update_dataset( - fields=["access"], dataset_resource=dataset.to_api_repr(), project_id=project_id + fields=["access"], + dataset_resource=dataset.to_api_repr(), + project_id=project_id, ) else: self.log.info( @@ -992,7 +1055,10 @@ def run_grant_dataset_view_access( @GoogleBaseHook.fallback_to_default_project_id def run_table_upsert( - self, dataset_id: str, table_resource: dict[str, Any], project_id: str = PROVIDE_PROJECT_ID + self, + dataset_id: str, + table_resource: dict[str, Any], + project_id: str = PROVIDE_PROJECT_ID, ) -> dict[str, Any]: """ Update a table if it exists, otherwise create a new one. @@ -1008,7 +1074,10 @@ def run_table_upsert( """ table_id = table_resource["tableReference"]["tableId"] table_resource = self._resolve_table_reference( - table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id + table_resource=table_resource, + project_id=project_id, + dataset_id=dataset_id, + table_id=table_id, ) tables_list_resp = self.get_dataset_tables(dataset_id=dataset_id, project_id=project_id) @@ -1016,9 +1085,17 @@ def run_table_upsert( self.log.info("Table %s:%s.%s exists, updating.", project_id, dataset_id, table_id) table = self.update_table(table_resource=table_resource) else: - self.log.info("Table %s:%s.%s does not exist. creating.", project_id, dataset_id, table_id) + self.log.info( + "Table %s:%s.%s does not exist. creating.", + project_id, + dataset_id, + table_id, + ) table = self.create_table( - dataset_id=dataset_id, table_id=table_id, table_resource=table_resource, project_id=project_id + dataset_id=dataset_id, + table_id=table_id, + table_resource=table_resource, + project_id=project_id, ).to_api_repr() return table @@ -1192,7 +1269,8 @@ def update_table_schema( """ def _build_new_schema( - current_schema: list[dict[str, Any]], schema_fields_updates: list[dict[str, Any]] + current_schema: list[dict[str, Any]], + schema_fields_updates: list[dict[str, Any]], ) -> list[dict[str, Any]]: # Turn schema_field_updates into a dict keyed on field names schema_fields_updates_dict = {field["name"]: field for field in deepcopy(schema_fields_updates)} @@ -1359,7 +1437,12 @@ def update_routine( merged, list(_ROUTINE_WRITABLE_PROPERTIES), retry=retry, timeout=timeout ) out_ref = result.reference - self.log.info("Updated routine: %s.%s.%s", out_ref.project, out_ref.dataset_id, out_ref.routine_id) + self.log.info( + "Updated routine: %s.%s.%s", + out_ref.project, + out_ref.dataset_id, + out_ref.routine_id, + ) return result @GoogleBaseHook.fallback_to_default_project_id @@ -1613,7 +1696,11 @@ def insert_job( client = self.get_client(project_id=project_id, location=location) job_data = { "configuration": configuration, - "jobReference": {"jobId": job_id, "projectId": project_id, "location": location}, + "jobReference": { + "jobId": job_id, + "projectId": project_id, + "location": location, + }, } supported_jobs: dict[str, type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob]] = { @@ -2168,7 +2255,10 @@ def _prepare_query_configuration( # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions - allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"] + allowed_schema_update_options = [ + "ALLOW_FIELD_ADDITION", + "ALLOW_FIELD_RELAXATION", + ] if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): raise ValueError( @@ -2178,7 +2268,8 @@ def _prepare_query_configuration( if destination_dataset_table: destination_project, destination_dataset, destination_table = self.hook.split_tablename( - table_input=destination_dataset_table, default_project_id=self.project_id + table_input=destination_dataset_table, + default_project_id=self.project_id, ) destination_dataset_table = { # type: ignore @@ -2229,7 +2320,10 @@ def _prepare_query_configuration( _validate_value(param_name, configuration["query"][param_name], param_type) if param_name == "schemaUpdateOptions" and param: - self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) + self.log.info( + "Adding experimental 'schemaUpdateOptions': %s", + schema_update_options, + ) if param_name == "destinationTable": for key in ["projectId", "datasetId", "tableId"]: @@ -2405,7 +2499,10 @@ async def get_job_instance( ) async def _get_job( - self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None + self, + job_id: str | None, + project_id: str = PROVIDE_PROJECT_ID, + location: str | None = None, ) -> BigQueryJob | UnknownJob: """Get BigQuery job by its ID, project ID and location.""" sync_hook = await self.get_sync_hook() @@ -2413,7 +2510,10 @@ async def _get_job( return job async def get_job_status( - self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None + self, + job_id: str | None, + project_id: str = PROVIDE_PROJECT_ID, + location: str | None = None, ) -> dict[str, str]: job = await self._get_job(job_id=job_id, project_id=project_id, location=location) if job.state == "DONE": @@ -2464,7 +2564,13 @@ async def cancel_job(self, job_id: str, project_id: str | None, location: str | """ async with ClientSession() as session: token = await self.get_token(session=session) - job = Job(job_id=job_id, project=project_id, location=location, token=token, session=session) # type: ignore[arg-type] + job = Job( + job_id=job_id, + project=project_id, + location=location, + token=token, + session=session, + ) # type: ignore[arg-type] self.log.info( "Attempting to cancel BigQuery job: %s in project: %s, location: %s", From c8c509cfe8d6b5287f37c2fc32db3b817760e8a9 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Wed, 20 May 2026 10:56:51 +1000 Subject: [PATCH 03/14] newline at end of file, does this fix? --- providers/google/docs/connections/bigquery.rst | 2 +- .../src/airflow/providers/google/cloud/hooks/bigquery.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/providers/google/docs/connections/bigquery.rst b/providers/google/docs/connections/bigquery.rst index dea823cb5cb72..669557ab449f7 100644 --- a/providers/google/docs/connections/bigquery.rst +++ b/providers/google/docs/connections/bigquery.rst @@ -64,4 +64,4 @@ Labels http_proxy Optional HTTP proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTP proxy. Can also be supplied via environmental variable or connection extra. https_proxy - Optional HTTPS proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTPS proxy. Can also be supplied via environmental variable or connection extra. \ No newline at end of file + Optional HTTPS proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTPS proxy. Can also be supplied via environmental variable or connection extra. diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 89e5da2653840..c95924c704f34 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -2564,13 +2564,7 @@ async def cancel_job(self, job_id: str, project_id: str | None, location: str | """ async with ClientSession() as session: token = await self.get_token(session=session) - job = Job( - job_id=job_id, - project=project_id, - location=location, - token=token, - session=session, - ) # type: ignore[arg-type] + job = Job(job_id=job_id, project=project_id, location=location, token=token, session=session) # type: ignore[arg-type] self.log.info( "Attempting to cancel BigQuery job: %s in project: %s, location: %s", From 444c211d4f5018da13fe9f88fe43ae71f438ebac Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Wed, 20 May 2026 14:44:44 +1000 Subject: [PATCH 04/14] add unit tests --- .../unit/google/cloud/hooks/test_bigquery.py | 254 ++++++++++++++++++ 1 file changed, 254 insertions(+) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index ddacebbdbe18a..1477fb7a02698 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -2313,3 +2313,257 @@ def test_insert_job_hook_lineage(self, mock_client, mock_query_job, mock_send_li ) mock_send_lineage.assert_called_once_with(context=self.hook, job=mock_job_instance) + + +@pytest.mark.db_test +class TestBigQueryHookProxy: + """Tests that HTTP/HTTPS proxy settings are propagated through _authorize, get_client, and _get_pandas_df.""" + + def _make_hook(self, http_proxy=None, https_proxy=None): + class MockedBigQueryHook(BigQueryHook): + def get_credentials_and_project_id(self): + return CREDENTIALS, PROJECT_ID + + def get_credentials(self): + return mock.MagicMock(name="credentials") + + return MockedBigQueryHook(http_proxy=http_proxy, https_proxy=https_proxy) + + # --- __init__ --- + + def test_http_proxy_stored_from_constructor(self): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + assert hook.http_proxy == "http://proxy.example.com:3128" + assert hook.https_proxy is None + + def test_https_proxy_stored_from_constructor(self): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + assert hook.http_proxy is None + assert hook.https_proxy == "https://proxy.example.com:3129" + + def test_both_proxies_stored_from_constructor(self): + hook = self._make_hook( + http_proxy="http://proxy.example.com:3128", + https_proxy="https://proxy.example.com:3129", + ) + assert hook.http_proxy == "http://proxy.example.com:3128" + assert hook.https_proxy == "https://proxy.example.com:3129" + + def test_no_proxy_defaults_to_none(self): + hook = self._make_hook() + assert hook.http_proxy is None + assert hook.https_proxy is None + + # --- _authorize --- + + @mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._authorize") + def test_authorize_without_proxy_delegates_to_base(self, mock_base_authorize): + hook = self._make_hook() + result = hook._authorize() + mock_base_authorize.assert_called_once() + assert result == mock_base_authorize.return_value + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_with_http_proxy_creates_proxy_info( + self, mock_proxy_info, mock_http, _mock_set_user_agent, mock_authorized_http + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + result = hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3128, + proxy_user=None, + proxy_pass=None, + ) + mock_http.assert_called_once_with(proxy_info=mock_proxy_info.return_value) + mock_authorized_http.assert_called_once() + assert result == mock_authorized_http.return_value + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_with_https_proxy_creates_proxy_info( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, mock_authorized_http + ): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + result = hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3129, + proxy_user=None, + proxy_pass=None, + ) + assert result == mock_authorized_http.return_value + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_proxy_with_username_and_password( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http + ): + hook = self._make_hook(http_proxy="http://user:secret@proxy.example.com:3128") + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3128, + proxy_user="user", + proxy_pass="secret", + ) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_proxy_without_port_defaults_to_80( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http + ): + hook = self._make_hook(http_proxy="http://proxy.example.com") + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=80, + proxy_user=None, + proxy_pass=None, + ) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_http_proxy_used_when_both_proxies_set( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http + ): + hook = self._make_hook( + http_proxy="http://http-proxy.example.com:3128", + https_proxy="https://https-proxy.example.com:3129", + ) + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="http-proxy.example.com", + proxy_port=3128, + proxy_user=None, + proxy_pass=None, + ) + + # --- get_client --- + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_without_proxy_omits_http_kwarg(self, mock_client): + hook = self._make_hook() + hook.get_client(project_id=PROJECT_ID) + assert "_http" not in mock_client.call_args.kwargs + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_http_proxy_sets_session_http_proxy( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["http"] == "http://proxy.example.com:3128" + assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_https_proxy_sets_session_https_proxy( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["https"] == "https://proxy.example.com:3129" + assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_both_proxies_sets_both_in_session( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook( + http_proxy="http://proxy.example.com:3128", + https_proxy="https://proxy.example.com:3129", + ) + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["http"] == "http://proxy.example.com:3128" + assert session_instance.proxies["https"] == "https://proxy.example.com:3129" + + @mock.patch("google.auth.transport.requests.Request") + @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_passes_authorized_session_built_with_proxy_session( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + mock_request_cls.assert_called_once_with(session=session_instance) + mock_authorized_session_cls.assert_called_once_with( + hook.get_credentials(), auth_request=mock_request_cls.return_value + ) + + # --- _get_pandas_df --- + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_with_http_proxy_uses_get_client(self, mock_get_client): + import pandas as pd + + mock_get_client.return_value.query.return_value.to_dataframe.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + result = hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_called_once() + mock_get_client.return_value.query.assert_called_once_with("SELECT 1", timeout=10) + mock_get_client.return_value.query.return_value.to_dataframe.assert_called_once_with( + create_bqstorage_client=False + ) + assert isinstance(result, pd.DataFrame) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_with_https_proxy_uses_get_client(self, mock_get_client): + import pandas as pd + + mock_get_client.return_value.query.return_value.to_dataframe.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_called_once() + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_without_proxy_uses_read_gbq(self, mock_get_client, mock_read_gbq): + import pandas as pd + + mock_read_gbq.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook() + hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_not_called() + mock_read_gbq.assert_called_once() From 6afecdcae901338440bf8bf173a8c9b1fa378d6d Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Wed, 20 May 2026 16:00:06 +1000 Subject: [PATCH 05/14] unit test revisions --- .../unit/google/cloud/hooks/test_bigquery.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index 1477fb7a02698..2d13ffb7ab7b0 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -2363,12 +2363,13 @@ def test_authorize_without_proxy_delegates_to_base(self, mock_base_authorize): mock_base_authorize.assert_called_once() assert result == mock_base_authorize.return_value + @mock.patch("httplib2.socks") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") @mock.patch("googleapiclient.http.set_user_agent") @mock.patch("httplib2.Http") @mock.patch("httplib2.ProxyInfo") def test_authorize_with_http_proxy_creates_proxy_info( - self, mock_proxy_info, mock_http, _mock_set_user_agent, mock_authorized_http + self, mock_proxy_info, mock_http, _mock_set_user_agent, mock_authorized_http, _mock_socks ): hook = self._make_hook(http_proxy="http://proxy.example.com:3128") result = hook._authorize() @@ -2384,12 +2385,13 @@ def test_authorize_with_http_proxy_creates_proxy_info( mock_authorized_http.assert_called_once() assert result == mock_authorized_http.return_value + @mock.patch("httplib2.socks") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") @mock.patch("googleapiclient.http.set_user_agent") @mock.patch("httplib2.Http") @mock.patch("httplib2.ProxyInfo") def test_authorize_with_https_proxy_creates_proxy_info( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, mock_authorized_http + self, mock_proxy_info, _mock_http, _mock_set_user_agent, mock_authorized_http, _mock_socks ): hook = self._make_hook(https_proxy="https://proxy.example.com:3129") result = hook._authorize() @@ -2403,12 +2405,13 @@ def test_authorize_with_https_proxy_creates_proxy_info( ) assert result == mock_authorized_http.return_value + @mock.patch("httplib2.socks") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") @mock.patch("googleapiclient.http.set_user_agent") @mock.patch("httplib2.Http") @mock.patch("httplib2.ProxyInfo") def test_authorize_proxy_with_username_and_password( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks ): hook = self._make_hook(http_proxy="http://user:secret@proxy.example.com:3128") hook._authorize() @@ -2421,12 +2424,13 @@ def test_authorize_proxy_with_username_and_password( proxy_pass="secret", ) + @mock.patch("httplib2.socks") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") @mock.patch("googleapiclient.http.set_user_agent") @mock.patch("httplib2.Http") @mock.patch("httplib2.ProxyInfo") def test_authorize_proxy_without_port_defaults_to_80( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks ): hook = self._make_hook(http_proxy="http://proxy.example.com") hook._authorize() @@ -2439,12 +2443,13 @@ def test_authorize_proxy_without_port_defaults_to_80( proxy_pass=None, ) + @mock.patch("httplib2.socks") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") @mock.patch("googleapiclient.http.set_user_agent") @mock.patch("httplib2.Http") @mock.patch("httplib2.ProxyInfo") def test_authorize_http_proxy_used_when_both_proxies_set( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks ): hook = self._make_hook( http_proxy="http://http-proxy.example.com:3128", @@ -2526,7 +2531,7 @@ def test_get_client_passes_authorized_session_built_with_proxy_session( session_instance = mock_session_cls.return_value mock_request_cls.assert_called_once_with(session=session_instance) mock_authorized_session_cls.assert_called_once_with( - hook.get_credentials(), auth_request=mock_request_cls.return_value + mock.ANY, auth_request=mock_request_cls.return_value ) # --- _get_pandas_df --- From e04af476a44a21d439c6368f3ca50275cf9bfa7a Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Thu, 18 Jun 2026 15:52:57 +1000 Subject: [PATCH 06/14] reformat --- .../providers/google/cloud/hooks/bigquery.py | 146 ++++-------------- 1 file changed, 27 insertions(+), 119 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index c95924c704f34..1ebe17f279c72 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -30,6 +30,7 @@ from copy import deepcopy from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload + from urllib.parse import urlparse import google_auth_httplib2 @@ -47,12 +48,7 @@ SchemaField, UnknownJob, ) -from google.cloud.bigquery.dataset import ( - AccessEntry, - Dataset, - DatasetListItem, - DatasetReference, -) +from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import Routine, RoutineReference from google.cloud.bigquery.table import ( @@ -70,10 +66,7 @@ from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector -from airflow.providers.common.compat.sdk import ( - AirflowException, - AirflowOptionalProviderFeatureException, -) +from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast @@ -171,14 +164,10 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: validators=[validators.AnyOf(["INTERACTIVE", "BATCH"])], ) connection_form_widgets["api_resource_configs"] = StringField( - lazy_gettext("API Resource Configs"), - widget=BS3TextFieldWidget(), - validators=[ValidJson()], + lazy_gettext("API Resource Configs"), widget=BS3TextFieldWidget(), validators=[ValidJson()] ) connection_form_widgets["labels"] = StringField( - lazy_gettext("Labels"), - widget=BS3TextFieldWidget(), - validators=[ValidJson()], + lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] ) connection_form_widgets["http_proxy"] = StringField( lazy_gettext("HTTP Proxy"), widget=BS3TextFieldWidget() @@ -371,11 +360,7 @@ def _resolve_table_reference( except KeyError: # Something is wrong so we try to build the reference table_resource["tableReference"] = table_resource.get("tableReference", {}) - values = [ - ("projectId", project_id), - ("tableId", table_id), - ("datasetId", dataset_id), - ] + values = [("projectId", project_id), ("tableId", table_id), ("datasetId", dataset_id)] for key, value in values: # Check if value is already present if no use the provided one resolved_value = table_resource["tableReference"].get(key, value) @@ -441,35 +426,17 @@ def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.Dat credentials, project_id = self.get_credentials_and_project_id() - pandas_df = read_gbq( - sql, - project_id=project_id, - dialect=dialect, - credentials=credentials, - **kwargs, - ) + pandas_df = read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) return pl.from_pandas(pandas_df) @overload def get_df( - self, - sql, - parameters=None, - dialect=None, - *, - df_type: Literal["pandas"] = "pandas", - **kwargs, + self, sql, parameters=None, dialect=None, *, df_type: Literal["pandas"] = "pandas", **kwargs ) -> pd.DataFrame: ... @overload def get_df( - self, - sql, - parameters=None, - dialect=None, - *, - df_type: Literal["polars"], - **kwargs, + self, sql, parameters=None, dialect=None, *, df_type: Literal["polars"], **kwargs ) -> pl.DataFrame: ... def get_df( @@ -664,9 +631,7 @@ def create_empty_dataset( ) # dataset_reference has no param but we can fallback to default value self.log.info( - "%s was not specified in `dataset_reference`. Will use default value %s.", - param, - value, + "%s was not specified in `dataset_reference`. Will use default value %s.", param, value ) dataset_reference["datasetReference"][param] = value @@ -772,10 +737,7 @@ def update_table( """ fields = fields or list(table_resource.keys()) table_resource = self._resolve_table_reference( - table_resource=table_resource, - project_id=project_id, - dataset_id=dataset_id, - table_id=table_id, + table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id ) table = Table.from_api_repr(table_resource) @@ -831,13 +793,7 @@ def insert_all( The default value is false, which indicates the task should not fail even if any insertion errors occur. """ - self.log.info( - "Inserting %s row(s) into table %s:%s.%s", - len(rows), - project_id, - dataset_id, - table_id, - ) + self.log.info("Inserting %s row(s) into table %s:%s.%s", len(rows), project_id, dataset_id, table_id) table_ref = TableReference(dataset_ref=DatasetReference(project_id, dataset_id), table_id=table_id) bq_client = self.get_client(project_id=project_id) @@ -863,12 +819,7 @@ def insert_all( if fail_on_error: raise AirflowException(f"BigQuery job failed. Error was: {error_msg}") else: - self.log.info( - "All row(s) inserted successfully: %s:%s.%s", - project_id, - dataset_id, - table_id, - ) + self.log.info("All row(s) inserted successfully: %s:%s.%s", project_id, dataset_id, table_id) @GoogleBaseHook.fallback_to_default_project_id def update_dataset( @@ -1017,11 +968,7 @@ def run_grant_dataset_view_access( view_access = AccessEntry( role=None, entity_type="view", - entity_id={ - "projectId": view_project, - "datasetId": view_dataset, - "tableId": view_table, - }, + entity_id={"projectId": view_project, "datasetId": view_dataset, "tableId": view_table}, ) dataset = self.get_dataset(project_id=project_id, dataset_id=source_dataset) @@ -1038,9 +985,7 @@ def run_grant_dataset_view_access( ) dataset.access_entries += [view_access] dataset = self.update_dataset( - fields=["access"], - dataset_resource=dataset.to_api_repr(), - project_id=project_id, + fields=["access"], dataset_resource=dataset.to_api_repr(), project_id=project_id ) else: self.log.info( @@ -1055,10 +1000,7 @@ def run_grant_dataset_view_access( @GoogleBaseHook.fallback_to_default_project_id def run_table_upsert( - self, - dataset_id: str, - table_resource: dict[str, Any], - project_id: str = PROVIDE_PROJECT_ID, + self, dataset_id: str, table_resource: dict[str, Any], project_id: str = PROVIDE_PROJECT_ID ) -> dict[str, Any]: """ Update a table if it exists, otherwise create a new one. @@ -1074,10 +1016,7 @@ def run_table_upsert( """ table_id = table_resource["tableReference"]["tableId"] table_resource = self._resolve_table_reference( - table_resource=table_resource, - project_id=project_id, - dataset_id=dataset_id, - table_id=table_id, + table_resource=table_resource, project_id=project_id, dataset_id=dataset_id, table_id=table_id ) tables_list_resp = self.get_dataset_tables(dataset_id=dataset_id, project_id=project_id) @@ -1085,17 +1024,9 @@ def run_table_upsert( self.log.info("Table %s:%s.%s exists, updating.", project_id, dataset_id, table_id) table = self.update_table(table_resource=table_resource) else: - self.log.info( - "Table %s:%s.%s does not exist. creating.", - project_id, - dataset_id, - table_id, - ) + self.log.info("Table %s:%s.%s does not exist. creating.", project_id, dataset_id, table_id) table = self.create_table( - dataset_id=dataset_id, - table_id=table_id, - table_resource=table_resource, - project_id=project_id, + dataset_id=dataset_id, table_id=table_id, table_resource=table_resource, project_id=project_id ).to_api_repr() return table @@ -1269,8 +1200,7 @@ def update_table_schema( """ def _build_new_schema( - current_schema: list[dict[str, Any]], - schema_fields_updates: list[dict[str, Any]], + current_schema: list[dict[str, Any]], schema_fields_updates: list[dict[str, Any]] ) -> list[dict[str, Any]]: # Turn schema_field_updates into a dict keyed on field names schema_fields_updates_dict = {field["name"]: field for field in deepcopy(schema_fields_updates)} @@ -1437,12 +1367,7 @@ def update_routine( merged, list(_ROUTINE_WRITABLE_PROPERTIES), retry=retry, timeout=timeout ) out_ref = result.reference - self.log.info( - "Updated routine: %s.%s.%s", - out_ref.project, - out_ref.dataset_id, - out_ref.routine_id, - ) + self.log.info("Updated routine: %s.%s.%s", out_ref.project, out_ref.dataset_id, out_ref.routine_id) return result @GoogleBaseHook.fallback_to_default_project_id @@ -1696,11 +1621,7 @@ def insert_job( client = self.get_client(project_id=project_id, location=location) job_data = { "configuration": configuration, - "jobReference": { - "jobId": job_id, - "projectId": project_id, - "location": location, - }, + "jobReference": {"jobId": job_id, "projectId": project_id, "location": location}, } supported_jobs: dict[str, type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob]] = { @@ -2255,10 +2176,7 @@ def _prepare_query_configuration( # for more details: # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions - allowed_schema_update_options = [ - "ALLOW_FIELD_ADDITION", - "ALLOW_FIELD_RELAXATION", - ] + allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"] if not set(allowed_schema_update_options).issuperset(set(schema_update_options)): raise ValueError( @@ -2268,8 +2186,7 @@ def _prepare_query_configuration( if destination_dataset_table: destination_project, destination_dataset, destination_table = self.hook.split_tablename( - table_input=destination_dataset_table, - default_project_id=self.project_id, + table_input=destination_dataset_table, default_project_id=self.project_id ) destination_dataset_table = { # type: ignore @@ -2320,10 +2237,7 @@ def _prepare_query_configuration( _validate_value(param_name, configuration["query"][param_name], param_type) if param_name == "schemaUpdateOptions" and param: - self.log.info( - "Adding experimental 'schemaUpdateOptions': %s", - schema_update_options, - ) + self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options) if param_name == "destinationTable": for key in ["projectId", "datasetId", "tableId"]: @@ -2499,10 +2413,7 @@ async def get_job_instance( ) async def _get_job( - self, - job_id: str | None, - project_id: str = PROVIDE_PROJECT_ID, - location: str | None = None, + self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None ) -> BigQueryJob | UnknownJob: """Get BigQuery job by its ID, project ID and location.""" sync_hook = await self.get_sync_hook() @@ -2510,10 +2421,7 @@ async def _get_job( return job async def get_job_status( - self, - job_id: str | None, - project_id: str = PROVIDE_PROJECT_ID, - location: str | None = None, + self, job_id: str | None, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None ) -> dict[str, str]: job = await self._get_job(job_id=job_id, project_id=project_id, location=location) if job.state == "DONE": From 10fdc9bd0880be5325a97d540fdd905e24852fcd Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Thu, 18 Jun 2026 15:54:24 +1000 Subject: [PATCH 07/14] return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) --- .../src/airflow/providers/google/cloud/hooks/bigquery.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 1ebe17f279c72..9fd9b993533cd 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -405,13 +405,7 @@ def _get_pandas_df( credentials, project_id = self.get_credentials_and_project_id() - return read_gbq( - sql, - project_id=project_id, - dialect=dialect, - credentials=credentials, - **kwargs, - ) + return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) def _get_polars_df(self, sql, parameters=None, dialect=None, **kwargs) -> pl.DataFrame: try: From f2092f17f5545e8480b54b16dccc7d78819549c4 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Fri, 19 Jun 2026 08:52:50 +1000 Subject: [PATCH 08/14] remove space --- .../google/src/airflow/providers/google/cloud/hooks/bigquery.py | 1 - 1 file changed, 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 9fd9b993533cd..3a7c190775799 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -30,7 +30,6 @@ from copy import deepcopy from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload - from urllib.parse import urlparse import google_auth_httplib2 From d79b98d2489ffc1c4e18aa02f8bc2c83a4e9a6a6 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Tue, 23 Jun 2026 09:25:46 +1000 Subject: [PATCH 09/14] remove unit tests, consolidate imports, remove hardcoded timeout + add dialect + kwargs, revert unecessary client options change --- .../providers/google/cloud/hooks/bigquery.py | 48 ++----- .../unit/google/cloud/hooks/test_bigquery.py | 129 ------------------ 2 files changed, 12 insertions(+), 165 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 3a7c190775799..fc1cabf7548ea 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -31,32 +31,21 @@ from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload from urllib.parse import urlparse - +import requests +from google.auth.transport.requests import AuthorizedSession, Request import google_auth_httplib2 import pendulum from aiohttp import ClientSession as ClientSession from asgiref.sync import sync_to_async from gcloud.aio.bigquery import Job, Table as Table_async -from google.cloud.bigquery import ( - DEFAULT_RETRY, - Client, - CopyJob, - ExtractJob, - LoadJob, - QueryJob, - SchemaField, - UnknownJob, -) +from google.cloud.bigquery import DEFAULT_RETRY, Client, CopyJob, ExtractJob, LoadJob, QueryJob, QueryJobConfig, SchemaField, UnknownJob +import httplib2 +from googleapiclient.http import set_user_agent +from airflow import version from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import Routine, RoutineReference -from google.cloud.bigquery.table import ( - Row, - RowIterator, - Table, - TableListItem, - TableReference, -) +from google.cloud.bigquery.table import Row, RowIterator, Table, TableListItem, TableReference from google.cloud.exceptions import NotFound from googleapiclient.discovery import build from pandas_gbq import read_gbq @@ -72,13 +61,7 @@ from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes from airflow.providers.google.cloud.utils.lineage import send_hook_lineage_for_bq_job from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import ( - _UNSET, - PROVIDE_PROJECT_ID, - GoogleBaseAsyncHook, - GoogleBaseHook, - get_field, -) +from airflow.providers.google.common.hooks.base_google import _UNSET, PROVIDE_PROJECT_ID, GoogleBaseAsyncHook, GoogleBaseHook, get_field from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.hashlib_wrapper import md5 from airflow.utils.helpers import convert_camel_to_snake @@ -246,7 +229,7 @@ def get_conn(self) -> BigQueryConnection: "v2", http=http_authorized, cache_discovery=False, - client_options=getattr(self, "get_client_options", lambda: None)(), + client_options=self.get_client_options(), ) return BigQueryConnection( service=service, @@ -262,12 +245,6 @@ def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: proxy_url = self.http_proxy or self.https_proxy if not proxy_url: return super()._authorize() - - import httplib2 - from googleapiclient.http import set_user_agent - - from airflow import version - parsed = urlparse(proxy_url) proxy_info = httplib2.ProxyInfo( proxy_type=httplib2.socks.PROXY_TYPE_HTTP, @@ -295,9 +272,6 @@ def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None "client_options": getattr(self, "get_client_options", lambda: None)(), } if self.http_proxy or self.https_proxy: - import requests - from google.auth.transport.requests import AuthorizedSession, Request - session = requests.Session() session.proxies = {} if self.http_proxy: @@ -394,13 +368,15 @@ def _get_pandas_df( sql: str, parameters: Iterable | Mapping[str, Any] | None = None, dialect: str | None = None, + timeout: float | None = None, **kwargs, ) -> pd.DataFrame: if dialect is None: dialect = "legacy" if self.use_legacy_sql else "standard" if self.http_proxy or self.https_proxy: - return self.get_client().query(sql, timeout=10).to_dataframe(create_bqstorage_client=False) + job_config = QueryJobConfig(use_legacy_sql=(dialect == "legacy")) + return (self.get_client().query(sql, job_config=job_config, timeout=timeout, **kwargs).to_dataframe(create_bqstorage_client=False)) credentials, project_id = self.get_credentials_and_project_id() diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index 2d13ffb7ab7b0..b2d1825251fcb 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -2329,31 +2329,6 @@ def get_credentials(self): return MockedBigQueryHook(http_proxy=http_proxy, https_proxy=https_proxy) - # --- __init__ --- - - def test_http_proxy_stored_from_constructor(self): - hook = self._make_hook(http_proxy="http://proxy.example.com:3128") - assert hook.http_proxy == "http://proxy.example.com:3128" - assert hook.https_proxy is None - - def test_https_proxy_stored_from_constructor(self): - hook = self._make_hook(https_proxy="https://proxy.example.com:3129") - assert hook.http_proxy is None - assert hook.https_proxy == "https://proxy.example.com:3129" - - def test_both_proxies_stored_from_constructor(self): - hook = self._make_hook( - http_proxy="http://proxy.example.com:3128", - https_proxy="https://proxy.example.com:3129", - ) - assert hook.http_proxy == "http://proxy.example.com:3128" - assert hook.https_proxy == "https://proxy.example.com:3129" - - def test_no_proxy_defaults_to_none(self): - hook = self._make_hook() - assert hook.http_proxy is None - assert hook.https_proxy is None - # --- _authorize --- @mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._authorize") @@ -2385,45 +2360,6 @@ def test_authorize_with_http_proxy_creates_proxy_info( mock_authorized_http.assert_called_once() assert result == mock_authorized_http.return_value - @mock.patch("httplib2.socks") - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") - @mock.patch("googleapiclient.http.set_user_agent") - @mock.patch("httplib2.Http") - @mock.patch("httplib2.ProxyInfo") - def test_authorize_with_https_proxy_creates_proxy_info( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, mock_authorized_http, _mock_socks - ): - hook = self._make_hook(https_proxy="https://proxy.example.com:3129") - result = hook._authorize() - - mock_proxy_info.assert_called_once_with( - proxy_type=mock.ANY, - proxy_host="proxy.example.com", - proxy_port=3129, - proxy_user=None, - proxy_pass=None, - ) - assert result == mock_authorized_http.return_value - - @mock.patch("httplib2.socks") - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") - @mock.patch("googleapiclient.http.set_user_agent") - @mock.patch("httplib2.Http") - @mock.patch("httplib2.ProxyInfo") - def test_authorize_proxy_with_username_and_password( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks - ): - hook = self._make_hook(http_proxy="http://user:secret@proxy.example.com:3128") - hook._authorize() - - mock_proxy_info.assert_called_once_with( - proxy_type=mock.ANY, - proxy_host="proxy.example.com", - proxy_port=3128, - proxy_user="user", - proxy_pass="secret", - ) - @mock.patch("httplib2.socks") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") @mock.patch("googleapiclient.http.set_user_agent") @@ -2443,28 +2379,6 @@ def test_authorize_proxy_without_port_defaults_to_80( proxy_pass=None, ) - @mock.patch("httplib2.socks") - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") - @mock.patch("googleapiclient.http.set_user_agent") - @mock.patch("httplib2.Http") - @mock.patch("httplib2.ProxyInfo") - def test_authorize_http_proxy_used_when_both_proxies_set( - self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks - ): - hook = self._make_hook( - http_proxy="http://http-proxy.example.com:3128", - https_proxy="https://https-proxy.example.com:3129", - ) - hook._authorize() - - mock_proxy_info.assert_called_once_with( - proxy_type=mock.ANY, - proxy_host="http-proxy.example.com", - proxy_port=3128, - proxy_user=None, - proxy_pass=None, - ) - # --- get_client --- @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") @@ -2501,39 +2415,6 @@ def test_get_client_with_https_proxy_sets_session_https_proxy( assert session_instance.proxies["https"] == "https://proxy.example.com:3129" assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value - @mock.patch("google.auth.transport.requests.Request") - @mock.patch("google.auth.transport.requests.AuthorizedSession") - @mock.patch("requests.Session") - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") - def test_get_client_with_both_proxies_sets_both_in_session( - self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls - ): - hook = self._make_hook( - http_proxy="http://proxy.example.com:3128", - https_proxy="https://proxy.example.com:3129", - ) - hook.get_client(project_id=PROJECT_ID) - - session_instance = mock_session_cls.return_value - assert session_instance.proxies["http"] == "http://proxy.example.com:3128" - assert session_instance.proxies["https"] == "https://proxy.example.com:3129" - - @mock.patch("google.auth.transport.requests.Request") - @mock.patch("google.auth.transport.requests.AuthorizedSession") - @mock.patch("requests.Session") - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") - def test_get_client_passes_authorized_session_built_with_proxy_session( - self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls - ): - hook = self._make_hook(http_proxy="http://proxy.example.com:3128") - hook.get_client(project_id=PROJECT_ID) - - session_instance = mock_session_cls.return_value - mock_request_cls.assert_called_once_with(session=session_instance) - mock_authorized_session_cls.assert_called_once_with( - mock.ANY, auth_request=mock_request_cls.return_value - ) - # --- _get_pandas_df --- @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") @@ -2551,16 +2432,6 @@ def test_get_pandas_df_with_http_proxy_uses_get_client(self, mock_get_client): ) assert isinstance(result, pd.DataFrame) - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") - def test_get_pandas_df_with_https_proxy_uses_get_client(self, mock_get_client): - import pandas as pd - - mock_get_client.return_value.query.return_value.to_dataframe.return_value = pd.DataFrame({"a": [1]}) - hook = self._make_hook(https_proxy="https://proxy.example.com:3129") - hook._get_pandas_df("SELECT 1") - - mock_get_client.assert_called_once() - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") def test_get_pandas_df_without_proxy_uses_read_gbq(self, mock_get_client, mock_read_gbq): From 52905256a8b065103f951c33d49e7ce82fcb1a44 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Tue, 23 Jun 2026 09:27:37 +1000 Subject: [PATCH 10/14] format imports --- .../providers/google/cloud/hooks/bigquery.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index fc1cabf7548ea..26be53c4cfbda 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -38,14 +38,29 @@ from aiohttp import ClientSession as ClientSession from asgiref.sync import sync_to_async from gcloud.aio.bigquery import Job, Table as Table_async -from google.cloud.bigquery import DEFAULT_RETRY, Client, CopyJob, ExtractJob, LoadJob, QueryJob, QueryJobConfig, SchemaField, UnknownJob +from google.cloud.bigquery import ( + DEFAULT_RETRY, + Client, + CopyJob, + ExtractJob, + LoadJob, + QueryJob, + SchemaField, + UnknownJob, +) import httplib2 from googleapiclient.http import set_user_agent from airflow import version from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import Routine, RoutineReference -from google.cloud.bigquery.table import Row, RowIterator, Table, TableListItem, TableReference +from google.cloud.bigquery.table import ( + Row, + RowIterator, + Table, + TableListItem, + TableReference, +) from google.cloud.exceptions import NotFound from googleapiclient.discovery import build from pandas_gbq import read_gbq @@ -61,7 +76,13 @@ from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes from airflow.providers.google.cloud.utils.lineage import send_hook_lineage_for_bq_job from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import _UNSET, PROVIDE_PROJECT_ID, GoogleBaseAsyncHook, GoogleBaseHook, get_field +from airflow.providers.google.common.hooks.base_google import ( + _UNSET, + PROVIDE_PROJECT_ID, + GoogleBaseAsyncHook, + GoogleBaseHook, + get_field, +) from airflow.providers.google.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils.hashlib_wrapper import md5 from airflow.utils.helpers import convert_camel_to_snake From 8e6fe1f494cd8b4c5ce86aaf75a91b7d267419d0 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Tue, 23 Jun 2026 09:29:02 +1000 Subject: [PATCH 11/14] fix broken bit --- .../google/src/airflow/providers/google/cloud/hooks/bigquery.py | 1 + 1 file changed, 1 insertion(+) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 26be53c4cfbda..7e1be1e60368b 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -61,6 +61,7 @@ TableListItem, TableReference, ) +from google.cloud.bigquery import QueryJobConfig from google.cloud.exceptions import NotFound from googleapiclient.discovery import build from pandas_gbq import read_gbq From 64e8aa38985dead1b3bafc2434e251064bab6b5f Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Tue, 23 Jun 2026 09:38:07 +1000 Subject: [PATCH 12/14] makes more sense there --- .../google/src/airflow/providers/google/cloud/hooks/bigquery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 7e1be1e60368b..deddc8ceace1e 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -45,6 +45,7 @@ ExtractJob, LoadJob, QueryJob, + QueryJobConfig, SchemaField, UnknownJob, ) @@ -61,7 +62,6 @@ TableListItem, TableReference, ) -from google.cloud.bigquery import QueryJobConfig from google.cloud.exceptions import NotFound from googleapiclient.discovery import build from pandas_gbq import read_gbq From efe76c4cd5d261ecdb57f4c16ee35ad55c089697 Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Tue, 23 Jun 2026 13:26:27 +1000 Subject: [PATCH 13/14] reorder --- .../providers/google/cloud/hooks/bigquery.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index deddc8ceace1e..ec3288b95d8cf 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -31,13 +31,15 @@ from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload from urllib.parse import urlparse -import requests -from google.auth.transport.requests import AuthorizedSession, Request + import google_auth_httplib2 +import httplib2 import pendulum +import requests from aiohttp import ClientSession as ClientSession from asgiref.sync import sync_to_async from gcloud.aio.bigquery import Job, Table as Table_async +from google.auth.transport.requests import AuthorizedSession, Request from google.cloud.bigquery import ( DEFAULT_RETRY, Client, @@ -49,9 +51,6 @@ SchemaField, UnknownJob, ) -import httplib2 -from googleapiclient.http import set_user_agent -from airflow import version from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY from google.cloud.bigquery.routine import Routine, RoutineReference @@ -64,10 +63,12 @@ ) from google.cloud.exceptions import NotFound from googleapiclient.discovery import build +from googleapiclient.http import set_user_agent from pandas_gbq import read_gbq from pandas_gbq.gbq import GbqConnector # noqa: F401 used in ``airflow.contrib.hooks.bigquery`` from sqlalchemy import create_engine +from airflow import version from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException @@ -398,7 +399,11 @@ def _get_pandas_df( if self.http_proxy or self.https_proxy: job_config = QueryJobConfig(use_legacy_sql=(dialect == "legacy")) - return (self.get_client().query(sql, job_config=job_config, timeout=timeout, **kwargs).to_dataframe(create_bqstorage_client=False)) + return ( + self.get_client() + .query(sql, job_config=job_config, timeout=timeout, **kwargs) + .to_dataframe(create_bqstorage_client=False) + ) credentials, project_id = self.get_credentials_and_project_id() From f76840b0f95700106b109852283a9656460526ff Mon Sep 17 00:00:00 2001 From: Jesse Mansfield Date: Wed, 24 Jun 2026 09:20:59 +1000 Subject: [PATCH 14/14] unit tests rework --- .../providers/google/cloud/hooks/bigquery.py | 2 +- .../tests/unit/google/cloud/hooks/test_bigquery.py | 14 +++++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index ec3288b95d8cf..650faf791dcee 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -401,7 +401,7 @@ def _get_pandas_df( job_config = QueryJobConfig(use_legacy_sql=(dialect == "legacy")) return ( self.get_client() - .query(sql, job_config=job_config, timeout=timeout, **kwargs) + .query(sql, job_config=job_config, timeout=timeout if timeout is not None else 60) .to_dataframe(create_bqstorage_client=False) ) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index b2d1825251fcb..7091a8db735e8 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -32,6 +32,7 @@ CopyJob, DatasetReference, QueryJob, + QueryJobConfig, Table, TableReference, ) @@ -2387,8 +2388,8 @@ def test_get_client_without_proxy_omits_http_kwarg(self, mock_client): hook.get_client(project_id=PROJECT_ID) assert "_http" not in mock_client.call_args.kwargs - @mock.patch("google.auth.transport.requests.Request") - @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Request") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.AuthorizedSession") @mock.patch("requests.Session") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") def test_get_client_with_http_proxy_sets_session_http_proxy( @@ -2401,8 +2402,8 @@ def test_get_client_with_http_proxy_sets_session_http_proxy( assert session_instance.proxies["http"] == "http://proxy.example.com:3128" assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value - @mock.patch("google.auth.transport.requests.Request") - @mock.patch("google.auth.transport.requests.AuthorizedSession") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Request") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.AuthorizedSession") @mock.patch("requests.Session") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") def test_get_client_with_https_proxy_sets_session_https_proxy( @@ -2426,7 +2427,10 @@ def test_get_pandas_df_with_http_proxy_uses_get_client(self, mock_get_client): result = hook._get_pandas_df("SELECT 1") mock_get_client.assert_called_once() - mock_get_client.return_value.query.assert_called_once_with("SELECT 1", timeout=10) + call_args = mock_get_client.return_value.query.call_args + assert call_args.args == ("SELECT 1",) + assert call_args.kwargs["timeout"] == 60 + assert isinstance(call_args.kwargs["job_config"], QueryJobConfig) mock_get_client.return_value.query.return_value.to_dataframe.assert_called_once_with( create_bqstorage_client=False )