Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions providers/samba/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ dependencies = [
"google" = [
"apache-airflow-providers-google"
]
"kerberos" = [
"krb5",
"smbprotocol[kerberos]>=1.5.0",
]

[dependency-groups]
dev = [
Expand Down
55 changes: 48 additions & 7 deletions providers/samba/src/airflow/providers/samba/hooks/samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import smbclient

from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException, BaseHook

if TYPE_CHECKING:
import smbprotocol.connection
Expand All @@ -43,29 +43,59 @@ class SambaHook(BaseHook):
the connection is used in its place.
:param share_type:
An optional share type name. If this is unset then it will assume a posix share type.
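:param auth_protocol:
An optional authentication protocol (``negotiate``, ``ntlm`` or ``kerberos``). If this is unset then
the ``auth_protocol`` connection extra is used, then a valid legacy ``auth`` extra, and finally ``negotiate``.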
"""

conn_name_attr = "samba_conn_id"
default_conn_name = "samba_default"
conn_type = "samba"
hook_name = "Samba"

VALID_AUTH_PROTOCOLS = {"negotiate", "ntlm", "kerberos"}

def __init__(
self,
samba_conn_id: str = default_conn_name,
share: str | None = None,
share_type: Literal["posix", "windows"] | None = None,
auth_protocol: Literal["negotiate", "ntlm", "kerberos"] | None = None,
) -> None:
super().__init__()
conn = self.get_connection(samba_conn_id)
extra = conn.extra_dejson

legacy_auth = extra.get("auth")
legacy_auth_protocol = (
Comment thread
haseebmalik18 marked this conversation as resolved.
legacy_auth
if isinstance(legacy_auth, str) and legacy_auth in self.VALID_AUTH_PROTOCOLS
else "negotiate"
)
self._auth_protocol: str = auth_protocol or extra.get("auth_protocol", legacy_auth_protocol)
if self._auth_protocol not in self.VALID_AUTH_PROTOCOLS:
raise ValueError(
f"Invalid auth_protocol '{self._auth_protocol}'. "
f"Must be one of {sorted(self.VALID_AUTH_PROTOCOLS)}."
)

if not conn.login:
uses_kerberos = self._auth_protocol == "kerberos"

if uses_kerberos:
try:
import krb5 # noqa: F401
except ImportError:
raise AirflowOptionalProviderFeatureException(
"Kerberos authentication requires the 'krb5' package. "
"Install it with: pip install 'apache-airflow-providers-samba[kerberos]'"
)

if not conn.login and not uses_kerberos:
self.log.info("Login not provided")

if not conn.password:
if not conn.password and not uses_kerberos:
self.log.info("Password not provided")

self._share_type = share_type or conn.extra_dejson.get("share_type", "posix")
self._share_type = share_type or extra.get("share_type", "posix")
if self._share_type not in {"posix", "windows"}:
self._share_type = "posix"
self.log.warning(
Expand All @@ -77,11 +107,12 @@ def __init__(
self._host = conn.host
self._share = share or conn.schema
self._connection_cache = connection_cache
self._conn_kwargs = {
self._conn_kwargs: dict[str, Any] = {
"username": conn.login,
"password": conn.password,
"port": conn.port or 445,
"connection_cache": connection_cache,
"auth_protocol": self._auth_protocol,
}

def __enter__(self):
Expand Down Expand Up @@ -325,12 +356,22 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
def get_connection_form_widgets(cls) -> dict[str, Any]:
    """
    Return connection widgets to add to connection form.

    :return: A mapping of extra-field name to the form field used to edit it:
        ``share_type`` (free-text, defaulting to ``posix``) and ``auth_protocol``
        (a drop-down limited to ``negotiate``, ``ntlm`` and ``kerberos``).
    """
    # Imported lazily so the hook module does not require the UI packages at import time.
    from flask_babel import lazy_gettext
    from wtforms import SelectField, StringField

    return {
        "share_type": StringField(
            label=lazy_gettext("Share Type"),
            description="The share OS type (`posix` or `windows`). Used to determine the formatting of file and folder paths.",
            default="posix",
        ),
        "auth_protocol": SelectField(
            label=lazy_gettext("Auth Protocol"),
            description=(
                "Authentication protocol: `negotiate` (auto-select, default), "
                "`ntlm`, or `kerberos`. When using `kerberos`, the system's "
                "Kerberos ticket cache is used and username/password are optional."
            ),
            choices=["negotiate", "ntlm", "kerberos"],
            default="negotiate",
        ),
    }
117 changes: 116 additions & 1 deletion providers/samba/tests/unit/samba/hooks/test_samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import pytest

from airflow.models import Connection
from airflow.providers.common.compat.sdk import AirflowNotFoundException
from airflow.providers.common.compat.sdk import (
AirflowNotFoundException,
AirflowOptionalProviderFeatureException,
)
from airflow.providers.samba.hooks.samba import SambaHook

try:
Expand Down Expand Up @@ -63,6 +66,7 @@ def test_context_manager(self, get_conn_mock, register_session):
"password": CONNECTION.password,
"port": 445,
"connection_cache": {},
"auth_protocol": "negotiate",
}
cache = kwargs.get("connection_cache")
mock_connection = mock.Mock()
Expand Down Expand Up @@ -117,6 +121,7 @@ def test_method(self, get_conn_mock, name):
"username": CONNECTION.login,
"password": CONNECTION.password,
"port": 445,
"auth_protocol": "negotiate",
}
with mock.patch("smbclient." + name) as p:
kwargs = {}
Expand Down Expand Up @@ -191,6 +196,116 @@ def test__join_path(
hook = SambaHook("samba_default", share_type=path_type)
assert hook._join_path(path) == full_path

@mock.patch.dict("sys.modules", {"krb5": mock.MagicMock()})
@mock.patch("smbclient.register_session")
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_kerberos_auth_via_extra(self, get_conn_mock, register_session):
    """Kerberos selected through the connection extra is forwarded to ``smbclient.register_session``."""
    register_session.return_value = None
    # No login/password on purpose: both must be forwarded as None.
    get_conn_mock.return_value = Connection(
        host="kerb-host.example.com",
        schema="share",
        extra='{"auth_protocol": "kerberos"}',
    )

    with SambaHook("samba_default"):
        session_kwargs = register_session.call_args_list[0].kwargs
        assert session_kwargs["auth_protocol"] == "kerberos"
        assert session_kwargs["username"] is None
        assert session_kwargs["password"] is None

@mock.patch.dict("sys.modules", {"krb5": mock.MagicMock()})
@mock.patch("smbclient.register_session")
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_kerberos_auth_via_legacy_auth_key(self, get_conn_mock, register_session):
    """Backward compatibility: the legacy ``auth`` extra key still selects kerberos."""
    register_session.return_value = None
    get_conn_mock.return_value = Connection(
        host="kerb-host.example.com",
        schema="share",
        extra='{"auth": "kerberos"}',
    )

    with SambaHook("samba_default"):
        session_kwargs = register_session.call_args_list[0].kwargs
        assert session_kwargs["auth_protocol"] == "kerberos"

@mock.patch.dict("sys.modules", {"krb5": mock.MagicMock()})
@mock.patch("smbclient.register_session")
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_kerberos_auth_via_constructor(self, get_conn_mock, register_session):
    """Test that constructor auth_protocol overrides extra."""
    # The extra names a *different* protocol so the test fails unless the
    # constructor argument actually takes precedence over the connection.
    connection = Connection(
        host="kerb-host.example.com",
        schema="share",
        login="user",
        password="pass",
        extra='{"auth_protocol": "ntlm"}',
    )
    get_conn_mock.return_value = connection
    register_session.return_value = None
    with SambaHook("samba_default", auth_protocol="kerberos"):
        _, kwargs = tuple(register_session.call_args_list[0])
        assert kwargs["auth_protocol"] == "kerberos"

@mock.patch("smbclient.register_session")
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_legacy_auth_key_ignored_when_invalid(self, get_conn_mock, register_session):
    """An unrecognised legacy ``auth`` value (here ``basic``) falls back to negotiate."""
    register_session.return_value = None
    get_conn_mock.return_value = Connection(
        host="host",
        schema="share",
        login="user",
        password="pass",
        extra='{"auth": "basic"}',
    )

    with SambaHook("samba_default"):
        session_kwargs = register_session.call_args_list[0].kwargs
        assert session_kwargs["auth_protocol"] == "negotiate"

@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_invalid_auth_protocol_raises(self, get_conn_mock):
    """Constructing the hook with an unknown ``auth_protocol`` extra raises ValueError."""
    get_conn_mock.return_value = Connection(
        host="host",
        schema="share",
        extra='{"auth_protocol": "invalid"}',
    )
    expect_value_error = pytest.raises(ValueError, match="Invalid auth_protocol 'invalid'")
    with expect_value_error:
        SambaHook("samba_default")

@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_kerberos_without_dependency_raises(self, get_conn_mock):
    """Kerberos auth raises AirflowOptionalProviderFeatureException when krb5 cannot be imported."""
    get_conn_mock.return_value = Connection(
        host="kerb-host.example.com",
        schema="share",
        extra='{"auth_protocol": "kerberos"}',
    )
    # A None entry in sys.modules makes ``import krb5`` fail with ImportError.
    krb5_missing = mock.patch.dict("sys.modules", {"krb5": None})
    expect_failure = pytest.raises(AirflowOptionalProviderFeatureException, match="krb5")
    with krb5_missing, expect_failure:
        SambaHook("samba_default")

@mock.patch("smbclient.register_session")
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_ntlm_auth(self, get_conn_mock, register_session):
    """NTLM selected through the connection extra is forwarded to ``smbclient.register_session``."""
    register_session.return_value = None
    get_conn_mock.return_value = Connection(
        host="host",
        schema="share",
        login="user",
        password="pass",
        extra='{"auth_protocol": "ntlm"}',
    )

    with SambaHook("samba_default"):
        session_kwargs = register_session.call_args_list[0].kwargs
        assert session_kwargs["auth_protocol"] == "ntlm"

@mock.patch("airflow.providers.samba.hooks.samba.smbclient.open_file", return_value=mock.Mock())
@mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
def test_open_file(self, get_conn_mock, open_file_mock):
Expand Down
13 changes: 12 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading