Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,25 @@ def delete_file(self, path: str) -> None:
"""
self.conn.remove(path) # type: ignore[arg-type, union-attr]

@staticmethod
def _validate_within_directory(base_dir: str, candidate: str) -> str:
    """
    Ensure ``candidate`` resolves to a path inside ``base_dir``.

    Directory-entry names are returned by the remote SFTP server and may
    contain ``..`` components; joining them into the local destination path
    could otherwise write outside it. Both paths are canonicalised with
    ``os.path.realpath`` before they are compared, so ``..`` segments and
    symlinks are resolved first.

    :param base_dir: Local destination directory the result must stay inside.
    :param candidate: Local path to check, typically ``base_dir`` joined with
        a name derived from a server-supplied entry.
    :return: ``candidate`` unchanged when it resolves to ``base_dir`` itself
        or to a location below it.
    :raises ValueError: If ``candidate`` resolves outside ``base_dir``.
    """
    base_real = os.path.realpath(base_dir)
    candidate_real = os.path.realpath(candidate)
    if candidate_real != base_real:
        try:
            contained = os.path.commonpath([base_real, candidate_real]) == base_real
        except ValueError:
            # commonpath() itself raises ValueError when the two paths cannot
            # share a prefix (for example different drives on Windows). Such a
            # candidate is by definition outside base_dir, so report it with
            # the same refusal message as every other escape.
            contained = False
        if not contained:
            raise ValueError(
                f"Refusing to write outside the destination directory: "
                f"{candidate!r} resolves outside {base_dir!r}"
            )
    return candidate

def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefetch: bool = True) -> None:
"""
Transfer the remote directory to a local location.
Expand All @@ -400,10 +419,14 @@ def retrieve_directory(self, remote_full_path: str, local_full_path: str, prefet
Path(local_full_path).mkdir(parents=True)
files, dirs, _ = self.get_tree_map(remote_full_path)
for dir_path in dirs:
new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
new_local_path = self._validate_within_directory(
local_full_path, os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
)
Path(new_local_path).mkdir(parents=True, exist_ok=True)
for file_path in files:
new_local_path = os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
new_local_path = self._validate_within_directory(
local_full_path, os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
)
self.retrieve_file(file_path, new_local_path, prefetch)

def retrieve_directory_concurrently(
Expand Down Expand Up @@ -438,12 +461,18 @@ def retrieve_file_chunk(
new_local_file_paths, remote_file_paths = [], []
files, dirs, _ = self.get_tree_map(remote_full_path)
for dir_path in dirs:
new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
new_local_path = self._validate_within_directory(
local_full_path,
os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path)),
)
Path(new_local_path).mkdir(parents=True, exist_ok=True)
for file in files:
remote_file_paths.append(file)
new_local_file_paths.append(
os.path.join(local_full_path, os.path.relpath(file, remote_full_path))
self._validate_within_directory(
local_full_path,
os.path.join(local_full_path, os.path.relpath(file, remote_full_path)),
)
)
remote_file_chunks = [remote_file_paths[i::workers] for i in range(workers)]
local_file_chunks = [new_local_file_paths[i::workers] for i in range(workers)]
Expand Down
23 changes: 23 additions & 0 deletions providers/sftp/tests/unit/sftp/hooks/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,29 @@ def test_store_and_retrieve_directory_concurrently(self):
)
assert retrieved_dir_name in os.listdir(os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS))

def test_validate_within_directory_rejects_escape(self):
    """Check that the containment helper refuses ``..`` escapes and passes in-bounds paths through."""
    download_root = os.path.join(self.temp_dir, "download")
    escaping_target = os.path.join(download_root, "..", "victim")
    nested_target = os.path.join(download_root, "sub", "file")

    # A candidate that climbs out of the download root must be refused.
    with pytest.raises(ValueError, match="outside the destination directory"):
        SFTPHook._validate_within_directory(download_root, escaping_target)

    # An in-bounds candidate is handed back exactly as it was passed in.
    returned = SFTPHook._validate_within_directory(download_root, nested_target)
    assert returned == nested_target

def test_retrieve_directory_rejects_server_path_traversal(self):
    """Check that ``retrieve_directory`` refuses a server-supplied entry containing ``..``."""
    # A remote SFTP server can return a directory-entry name containing ".."
    # so the recursive download would escape the local destination directory.
    remote_root = "/srv/export"
    local_root = os.path.join(self.temp_dir, "download_traversal")
    hostile_entry = "/srv/export/../victim/payload"
    # Stand-in for get_tree_map's three-element result, with the hostile entry
    # as the only item in the first (files) position.
    fake_tree = ([hostile_entry], [], [])

    with (
        patch.object(SFTPHook, "get_tree_map", return_value=fake_tree),
        patch.object(SFTPHook, "retrieve_file") as retrieve_spy,
        pytest.raises(ValueError, match="outside the destination directory"),
    ):
        self.hook.retrieve_directory(remote_full_path=remote_root, local_full_path=local_root)

    # Nothing may have been downloaded, and nothing created next to the destination.
    retrieve_spy.assert_not_called()
    victim_path = os.path.join(self.temp_dir, "victim")
    assert not os.path.exists(victim_path)

@patch("paramiko.SSHClient")
@patch("paramiko.ProxyCommand")
@patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection")
Expand Down
Loading