diff --git a/providers/google/docs/connections/bigquery.rst b/providers/google/docs/connections/bigquery.rst index c596d3e86b1c4..669557ab449f7 100644 --- a/providers/google/docs/connections/bigquery.rst +++ b/providers/google/docs/connections/bigquery.rst @@ -60,3 +60,8 @@ API Resource Configs Labels A dictionary of labels to be applied on the BigQuery job. + +http_proxy + Optional HTTP proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTP proxy. Can also be supplied via environment variable or connection extra. +https_proxy + Optional HTTPS proxy to use when connecting to BigQuery. If not provided, the connection will not use an HTTPS proxy. Can also be supplied via environment variable or connection extra. diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index daf4636198564..650faf791dcee 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -30,11 +30,16 @@ from copy import deepcopy from datetime import datetime, timedelta from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload +from urllib.parse import urlparse +import google_auth_httplib2 +import httplib2 import pendulum +import requests from aiohttp import ClientSession as ClientSession from asgiref.sync import sync_to_async from gcloud.aio.bigquery import Job, Table as Table_async +from google.auth.transport.requests import AuthorizedSession, Request from google.cloud.bigquery import ( DEFAULT_RETRY, Client, @@ -42,6 +47,7 @@ ExtractJob, LoadJob, QueryJob, + QueryJobConfig, SchemaField, UnknownJob, ) @@ -57,10 +63,12 @@ ) from google.cloud.exceptions import NotFound from googleapiclient.discovery import build +from googleapiclient.http import set_user_agent from pandas_gbq import read_gbq from pandas_gbq.gbq import GbqConnector # noqa: 
F401 used in ``airflow.contrib.hooks.bigquery`` from sqlalchemy import create_engine +from airflow import version from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector from airflow.providers.common.compat.sdk import AirflowException, AirflowOptionalProviderFeatureException @@ -166,8 +174,11 @@ def get_connection_form_widgets(cls) -> dict[str, Any]: connection_form_widgets["labels"] = StringField( lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] ) - connection_form_widgets["labels"] = StringField( - lazy_gettext("Labels"), widget=BS3TextFieldWidget(), validators=[ValidJson()] + connection_form_widgets["http_proxy"] = StringField( + lazy_gettext("HTTP Proxy"), widget=BS3TextFieldWidget() + ) + connection_form_widgets["https_proxy"] = StringField( + lazy_gettext("HTTPS Proxy"), widget=BS3TextFieldWidget() ) return connection_form_widgets @@ -184,6 +195,8 @@ def __init__( api_resource_configs: dict | None | object = _UNSET, impersonation_scopes: str | Sequence[str] | None = None, labels: dict | None | object = _UNSET, + http_proxy: str | None | object = _UNSET, + https_proxy: str | None | object = _UNSET, **kwargs, ) -> None: super().__init__(**kwargs) @@ -219,6 +232,16 @@ def __init__( else: self.labels = labels or {} # type: ignore[assignment] + if http_proxy is _UNSET: + self.http_proxy: str | None = self._get_field("http_proxy", None) + else: + self.http_proxy = http_proxy # type: ignore[assignment] + + if https_proxy is _UNSET: + self.https_proxy: str | None = self._get_field("https_proxy", None) + else: + self.https_proxy = https_proxy # type: ignore[assignment] + self.impersonation_scopes: str | Sequence[str] | None = impersonation_scopes def get_conn(self) -> BigQueryConnection: @@ -240,6 +263,22 @@ def get_conn(self) -> BigQueryConnection: hook=self, ) + def _authorize(self) -> google_auth_httplib2.AuthorizedHttp: + """Return an 
authorized HTTP object, optionally configured with a proxy.""" + proxy_url = self.http_proxy or self.https_proxy + if not proxy_url: + return super()._authorize() + parsed = urlparse(proxy_url) + proxy_info = httplib2.ProxyInfo( + proxy_type=httplib2.socks.PROXY_TYPE_HTTP, + proxy_host=parsed.hostname, + proxy_port=parsed.port or 80, + proxy_user=parsed.username, + proxy_pass=parsed.password, + ) + http = set_user_agent(httplib2.Http(proxy_info=proxy_info), "airflow/" + version.version) + return google_auth_httplib2.AuthorizedHttp(self.get_credentials(), http=http) + def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None = None) -> Client: """ Get an authenticated BigQuery Client. @@ -247,13 +286,25 @@ def get_client(self, project_id: str = PROVIDE_PROJECT_ID, location: str | None :param project_id: Project ID for the project which the client acts on behalf of. :param location: Default location for jobs / datasets / tables. """ - return Client( - client_info=CLIENT_INFO, - project=project_id, - location=location, - credentials=self.get_credentials(), - client_options=self.get_client_options(), - ) + credentials = self.get_credentials() + kwargs: dict[str, Any] = { + "client_info": CLIENT_INFO, + "project": project_id, + "location": location, + "credentials": credentials, + "client_options": getattr(self, "get_client_options", lambda: None)(), + } + if self.http_proxy or self.https_proxy: + session = requests.Session() + session.proxies = {} + if self.http_proxy: + session.proxies["http"] = self.http_proxy + if self.https_proxy: + session.proxies["https"] = self.https_proxy + authorized_session = AuthorizedSession(credentials, auth_request=Request(session=session)) + authorized_session.proxies = dict(session.proxies) + kwargs["_http"] = authorized_session + return Client(**kwargs) def get_uri(self) -> str: """Override from ``DbApiHook`` for ``get_sqlalchemy_engine()``.""" @@ -340,11 +391,20 @@ def _get_pandas_df( sql: str, parameters: 
Iterable | Mapping[str, Any] | None = None, dialect: str | None = None, + timeout: float | None = None, **kwargs, ) -> pd.DataFrame: if dialect is None: dialect = "legacy" if self.use_legacy_sql else "standard" + if self.http_proxy or self.https_proxy: + job_config = QueryJobConfig(use_legacy_sql=(dialect == "legacy")) + return ( + self.get_client() + .query(sql, job_config=job_config, timeout=timeout if timeout is not None else 60) + .to_dataframe(create_bqstorage_client=False) + ) + credentials, project_id = self.get_credentials_and_project_id() return read_gbq(sql, project_id=project_id, dialect=dialect, credentials=credentials, **kwargs) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index ddacebbdbe18a..7091a8db735e8 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -32,6 +32,7 @@ CopyJob, DatasetReference, QueryJob, + QueryJobConfig, Table, TableReference, ) @@ -2313,3 +2314,136 @@ def test_insert_job_hook_lineage(self, mock_client, mock_query_job, mock_send_li ) mock_send_lineage.assert_called_once_with(context=self.hook, job=mock_job_instance) + + +@pytest.mark.db_test +class TestBigQueryHookProxy: + """Tests that HTTP/HTTPS proxy settings are propagated through _authorize, get_client, and _get_pandas_df.""" + + def _make_hook(self, http_proxy=None, https_proxy=None): + class MockedBigQueryHook(BigQueryHook): + def get_credentials_and_project_id(self): + return CREDENTIALS, PROJECT_ID + + def get_credentials(self): + return mock.MagicMock(name="credentials") + + return MockedBigQueryHook(http_proxy=http_proxy, https_proxy=https_proxy) + + # --- _authorize --- + + @mock.patch("airflow.providers.google.common.hooks.base_google.GoogleBaseHook._authorize") + def test_authorize_without_proxy_delegates_to_base(self, mock_base_authorize): + hook = self._make_hook() + 
result = hook._authorize() + mock_base_authorize.assert_called_once() + assert result == mock_base_authorize.return_value + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_with_http_proxy_creates_proxy_info( + self, mock_proxy_info, mock_http, _mock_set_user_agent, mock_authorized_http, _mock_socks + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + result = hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=3128, + proxy_user=None, + proxy_pass=None, + ) + mock_http.assert_called_once_with(proxy_info=mock_proxy_info.return_value) + mock_authorized_http.assert_called_once() + assert result == mock_authorized_http.return_value + + @mock.patch("httplib2.socks") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.google_auth_httplib2.AuthorizedHttp") + @mock.patch("googleapiclient.http.set_user_agent") + @mock.patch("httplib2.Http") + @mock.patch("httplib2.ProxyInfo") + def test_authorize_proxy_without_port_defaults_to_80( + self, mock_proxy_info, _mock_http, _mock_set_user_agent, _mock_authorized_http, _mock_socks + ): + hook = self._make_hook(http_proxy="http://proxy.example.com") + hook._authorize() + + mock_proxy_info.assert_called_once_with( + proxy_type=mock.ANY, + proxy_host="proxy.example.com", + proxy_port=80, + proxy_user=None, + proxy_pass=None, + ) + + # --- get_client --- + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_without_proxy_omits_http_kwarg(self, mock_client): + hook = self._make_hook() + hook.get_client(project_id=PROJECT_ID) + assert "_http" not in mock_client.call_args.kwargs + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Request") + 
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_http_proxy_sets_session_http_proxy( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["http"] == "http://proxy.example.com:3128" + assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Request") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.AuthorizedSession") + @mock.patch("requests.Session") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Client") + def test_get_client_with_https_proxy_sets_session_https_proxy( + self, mock_client, mock_session_cls, mock_authorized_session_cls, mock_request_cls + ): + hook = self._make_hook(https_proxy="https://proxy.example.com:3129") + hook.get_client(project_id=PROJECT_ID) + + session_instance = mock_session_cls.return_value + assert session_instance.proxies["https"] == "https://proxy.example.com:3129" + assert mock_client.call_args.kwargs.get("_http") == mock_authorized_session_cls.return_value + + # --- _get_pandas_df --- + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_with_http_proxy_uses_get_client(self, mock_get_client): + import pandas as pd + + mock_get_client.return_value.query.return_value.to_dataframe.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook(http_proxy="http://proxy.example.com:3128") + result = hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_called_once() + call_args = mock_get_client.return_value.query.call_args + assert call_args.args == ("SELECT 
1",) + assert call_args.kwargs["timeout"] == 60 + assert isinstance(call_args.kwargs["job_config"], QueryJobConfig) + mock_get_client.return_value.query.return_value.to_dataframe.assert_called_once_with( + create_bqstorage_client=False + ) + assert isinstance(result, pd.DataFrame) + + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.read_gbq") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") + def test_get_pandas_df_without_proxy_uses_read_gbq(self, mock_get_client, mock_read_gbq): + import pandas as pd + + mock_read_gbq.return_value = pd.DataFrame({"a": [1]}) + hook = self._make_hook() + hook._get_pandas_df("SELECT 1") + + mock_get_client.assert_not_called() + mock_read_gbq.assert_called_once()