From fd2ea1f719e2432df306d24b64ececb630bead79 Mon Sep 17 00:00:00 2001 From: haseebmalik18 Date: Thu, 2 Apr 2026 15:26:35 -0400 Subject: [PATCH 1/2] Add auth_protocol support to SambaHook for Kerberos authentication #29590 --- providers/samba/pyproject.toml | 4 + .../airflow/providers/samba/hooks/samba.py | 55 ++++++-- .../tests/unit/samba/hooks/test_samba.py | 117 +++++++++++++++++- uv.lock | 15 ++- 4 files changed, 181 insertions(+), 10 deletions(-) diff --git a/providers/samba/pyproject.toml b/providers/samba/pyproject.toml index 268199698f8e1..d5105dd606d57 100644 --- a/providers/samba/pyproject.toml +++ b/providers/samba/pyproject.toml @@ -70,6 +70,10 @@ dependencies = [ "google" = [ "apache-airflow-providers-google" ] +"kerberos" = [ + "krb5", + "smbprotocol[kerberos]>=1.5.0", +] [dependency-groups] dev = [ diff --git a/providers/samba/src/airflow/providers/samba/hooks/samba.py b/providers/samba/src/airflow/providers/samba/hooks/samba.py index c36f0b06dc929..34dd209d3dbc6 100644 --- a/providers/samba/src/airflow/providers/samba/hooks/samba.py +++ b/providers/samba/src/airflow/providers/samba/hooks/samba.py @@ -24,7 +24,7 @@ import smbclient -from airflow.providers.common.compat.sdk import BaseHook +from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException, BaseHook if TYPE_CHECKING: import smbprotocol.connection @@ -43,6 +43,8 @@ class SambaHook(BaseHook): the connection is used in its place. :param share_type: An optional share type name. If this is unset then it will assume a posix share type. + :param auth_protocol: + An optional authentication protocol. If this is unset then it defaults to negotiate. """ conn_name_attr = "samba_conn_id" @@ -50,22 +52,50 @@ class SambaHook(BaseHook): conn_type = "samba" hook_name = "Samba" + VALID_AUTH_PROTOCOLS = {"negotiate", "ntlm", "kerberos"} + def __init__( self, samba_conn_id: str = default_conn_name, share: str | None = None, share_type: Literal["posix", "windows"] | None = None, + auth_protocol: Literal["negotiate", "ntlm", "kerberos"] | None = None, ) -> None: super().__init__() conn = self.get_connection(samba_conn_id) + extra = conn.extra_dejson + + legacy_auth = extra.get("auth") + legacy_auth_protocol = ( + legacy_auth if isinstance(legacy_auth, str) and legacy_auth in self.VALID_AUTH_PROTOCOLS else None + ) + self._auth_protocol: str = ( + auth_protocol or extra.get("auth_protocol") or legacy_auth_protocol or "negotiate" + ) + if self._auth_protocol not in self.VALID_AUTH_PROTOCOLS: + raise ValueError( + f"Invalid auth_protocol '{self._auth_protocol}'. " + f"Must be one of {sorted(self.VALID_AUTH_PROTOCOLS)}." + ) + + uses_kerberos = self._auth_protocol == "kerberos" - if not conn.login: + if uses_kerberos: + try: + import krb5 # noqa: F401 + except ImportError: + raise AirflowOptionalProviderFeatureException( + "Kerberos authentication requires the 'krb5' package. " + "Install it with: pip install 'apache-airflow-providers-samba[kerberos]'" + ) + + if not conn.login and not uses_kerberos: self.log.info("Login not provided") - if not conn.password: + if not conn.password and not uses_kerberos: self.log.info("Password not provided") - self._share_type = share_type or conn.extra_dejson.get("share_type", "posix") + self._share_type = share_type or extra.get("share_type", "posix") if self._share_type not in {"posix", "windows"}: self._share_type = "posix" self.log.warning( @@ -77,11 +107,12 @@ def __init__( self._host = conn.host self._share = share or conn.schema self._connection_cache = connection_cache - self._conn_kwargs = { + self._conn_kwargs: dict[str, Any] = { "username": conn.login, "password": conn.password, "port": conn.port or 445, "connection_cache": connection_cache, + "auth_protocol": self._auth_protocol, } def __enter__(self): @@ -325,12 +356,22 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]: def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to connection form.""" from flask_babel import lazy_gettext - from wtforms import StringField + from wtforms import SelectField, StringField return { "share_type": StringField( label=lazy_gettext("Share Type"), description="The share OS type (`posix` or `windows`). Used to determine the formatting of file and folder paths.", default="posix", - ) + ), + "auth_protocol": SelectField( + label=lazy_gettext("Auth Protocol"), + description=( + "Authentication protocol: `negotiate` (auto-select, default), " + "`ntlm`, or `kerberos`. When using `kerberos`, the system's " + "Kerberos ticket cache is used and username/password are optional." + ), + choices=["negotiate", "ntlm", "kerberos"], + default="negotiate", + ), } diff --git a/providers/samba/tests/unit/samba/hooks/test_samba.py b/providers/samba/tests/unit/samba/hooks/test_samba.py index 44d79c8830353..06a5e3fb147eb 100644 --- a/providers/samba/tests/unit/samba/hooks/test_samba.py +++ b/providers/samba/tests/unit/samba/hooks/test_samba.py @@ -23,7 +23,10 @@ import pytest from airflow.models import Connection -from airflow.providers.common.compat.sdk import AirflowNotFoundException +from airflow.providers.common.compat.sdk import ( + AirflowNotFoundException, + AirflowOptionalProviderFeatureException, +) from airflow.providers.samba.hooks.samba import SambaHook try: @@ -63,6 +66,7 @@ def test_context_manager(self, get_conn_mock, register_session): "password": CONNECTION.password, "port": 445, "connection_cache": {}, + "auth_protocol": "negotiate", } cache = kwargs.get("connection_cache") mock_connection = mock.Mock() @@ -117,6 +121,7 @@ def test_method(self, get_conn_mock, name): "username": CONNECTION.login, "password": CONNECTION.password, "port": 445, + "auth_protocol": "negotiate", } with mock.patch("smbclient." + name) as p: kwargs = {} @@ -191,6 +196,116 @@ def test__join_path( hook = SambaHook("samba_default", share_type=path_type) assert hook._join_path(path) == full_path + @mock.patch.dict("sys.modules", {"krb5": mock.MagicMock()}) + @mock.patch("smbclient.register_session") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_kerberos_auth_via_extra(self, get_conn_mock, register_session): + """Test that auth_protocol='kerberos' from extra is passed to smbclient.""" + connection = Connection( + host="kerb-host.example.com", + schema="share", + extra='{"auth_protocol": "kerberos"}', + ) + get_conn_mock.return_value = connection + register_session.return_value = None + with SambaHook("samba_default"): + _, kwargs = tuple(register_session.call_args_list[0]) + assert kwargs["auth_protocol"] == "kerberos" + assert kwargs["username"] is None + assert kwargs["password"] is None + + @mock.patch.dict("sys.modules", {"krb5": mock.MagicMock()}) + @mock.patch("smbclient.register_session") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_kerberos_auth_via_legacy_auth_key(self, get_conn_mock, register_session): + """Test backward compat: extra {"auth": "kerberos"} is recognized.""" + connection = Connection( + host="kerb-host.example.com", + schema="share", + extra='{"auth": "kerberos"}', + ) + get_conn_mock.return_value = connection + register_session.return_value = None + with SambaHook("samba_default"): + _, kwargs = tuple(register_session.call_args_list[0]) + assert kwargs["auth_protocol"] == "kerberos" + + @mock.patch.dict("sys.modules", {"krb5": mock.MagicMock()}) + @mock.patch("smbclient.register_session") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_kerberos_auth_via_constructor(self, get_conn_mock, register_session): + """Test that constructor auth_protocol overrides extra.""" + connection = Connection( + host="kerb-host.example.com", + schema="share", + login="user", + password="pass", + ) + get_conn_mock.return_value = connection + register_session.return_value = None + with SambaHook("samba_default", auth_protocol="kerberos"): + _, kwargs = tuple(register_session.call_args_list[0]) + assert kwargs["auth_protocol"] == "kerberos" + + @mock.patch("smbclient.register_session") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_legacy_auth_key_ignored_when_invalid(self, get_conn_mock, register_session): + """Test that extra {"auth": "basic"} is ignored and defaults to negotiate.""" + connection = Connection( + host="host", + schema="share", + login="user", + password="pass", + extra='{"auth": "basic"}', + ) + get_conn_mock.return_value = connection + register_session.return_value = None + with SambaHook("samba_default"): + _, kwargs = tuple(register_session.call_args_list[0]) + assert kwargs["auth_protocol"] == "negotiate" + + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_invalid_auth_protocol_raises(self, get_conn_mock): + """Test that an invalid auth_protocol raises ValueError.""" + connection = Connection( + host="host", + schema="share", + extra='{"auth_protocol": "invalid"}', + ) + get_conn_mock.return_value = connection + with pytest.raises(ValueError, match="Invalid auth_protocol 'invalid'"): + SambaHook("samba_default") + + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_kerberos_without_dependency_raises(self, get_conn_mock): + """Test that kerberos auth without krb5 installed raises AirflowOptionalProviderFeatureException.""" + connection = Connection( + host="kerb-host.example.com", + schema="share", + extra='{"auth_protocol": "kerberos"}', + ) + get_conn_mock.return_value = connection + with mock.patch.dict("sys.modules", {"krb5": None}): + with pytest.raises(AirflowOptionalProviderFeatureException, match="krb5"): + SambaHook("samba_default") + + @mock.patch("smbclient.register_session") + @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") + def test_ntlm_auth(self, get_conn_mock, register_session): + """Test that auth_protocol='ntlm' is passed correctly.""" + connection = Connection( + host="host", + schema="share", + login="user", + password="pass", + extra='{"auth_protocol": "ntlm"}', + ) + get_conn_mock.return_value = connection + register_session.return_value = None + with SambaHook("samba_default"): + _, kwargs = tuple(register_session.call_args_list[0]) + assert kwargs["auth_protocol"] == "ntlm" + @mock.patch("airflow.providers.samba.hooks.samba.smbclient.open_file", return_value=mock.Mock()) @mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection") def test_open_file(self, get_conn_mock, open_file_mock): diff --git a/uv.lock b/uv.lock index 3bd2a71763024..f0bf12f53427a 100644 --- a/uv.lock +++ b/uv.lock @@ -6638,7 +6638,7 @@ docs = [ requires-dist = [ { name = "apache-airflow", editable = "." }, { name = "apache-airflow-providers-common-compat", editable = "providers/common/compat" }, - { name = "openai", extras = ["datalib"], specifier = ">=1.66.0" }, + { name = "openai", extras = ["datalib"], specifier = ">=2.37.0" }, ] [package.metadata.requires-dev] @@ -7351,6 +7351,10 @@ dependencies = [ google = [ { name = "apache-airflow-providers-google" }, ] +kerberos = [ + { name = "krb5" }, + { name = "smbprotocol", extra = ["kerberos"] }, +] [package.dev-dependencies] dev = [ @@ -7369,9 +7373,11 @@ requires-dist = [ { name = "apache-airflow", editable = "." }, { name = "apache-airflow-providers-common-compat", editable = "providers/common/compat" }, { name = "apache-airflow-providers-google", marker = "extra == 'google'", editable = "providers/google" }, + { name = "krb5", marker = "extra == 'kerberos'" }, { name = "smbprotocol", specifier = ">=1.5.0" }, + { name = "smbprotocol", extras = ["kerberos"], marker = "extra == 'kerberos'", specifier = ">=1.5.0" }, ] -provides-extras = ["google"] +provides-extras = ["google", "kerberos"] [package.metadata.requires-dev] dev = [ @@ -21938,6 +21944,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/43/8acc58e24f78af82300a8ebdd457face4c6e4b27f7b52885a239724d3e21/smbprotocol-1.16.1-py3-none-any.whl", hash = "sha256:0f9c35b2a0103314da06aecaffa8821f060a9773ad82cc0d7f29130083d598b4", size = 126532, upload-time = "2026-04-02T03:01:27.718Z" }, ] +[package.optional-dependencies] +kerberos = [ + { name = "pyspnego", extra = ["kerberos"] }, +] + [[package]] name = "smmap" version = "5.0.3" From 7d121832e5810a61ec978a15a2bbad7e5906079e Mon Sep 17 00:00:00 2001 From: haseebmalik18 Date: Mon, 29 Jun 2026 20:34:34 -0400 Subject: [PATCH 2/2] Simplify auth_protocol resolution in SambaHook --- .../samba/src/airflow/providers/samba/hooks/samba.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/providers/samba/src/airflow/providers/samba/hooks/samba.py b/providers/samba/src/airflow/providers/samba/hooks/samba.py index 34dd209d3dbc6..1e6bf5a9cded8 100644 --- a/providers/samba/src/airflow/providers/samba/hooks/samba.py +++ b/providers/samba/src/airflow/providers/samba/hooks/samba.py @@ -67,11 +67,11 @@ def __init__( legacy_auth = extra.get("auth") legacy_auth_protocol = ( - legacy_auth if isinstance(legacy_auth, str) and legacy_auth in self.VALID_AUTH_PROTOCOLS else None - ) - self._auth_protocol: str = ( - auth_protocol or extra.get("auth_protocol") or legacy_auth_protocol or "negotiate" + legacy_auth + if isinstance(legacy_auth, str) and legacy_auth in self.VALID_AUTH_PROTOCOLS + else "negotiate" ) + self._auth_protocol: str = auth_protocol or extra.get("auth_protocol", legacy_auth_protocol) if self._auth_protocol not in self.VALID_AUTH_PROTOCOLS: raise ValueError( f"Invalid auth_protocol '{self._auth_protocol}'. "