Skip to content
Closed
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions pyiceberg/io/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Any,
Callable,
Dict,
Optional,
Union,
)
from urllib.parse import urlparse
Expand Down Expand Up @@ -194,13 +195,18 @@ def _gs(properties: Properties) -> AbstractFileSystem:
def _adls(properties: Properties) -> AbstractFileSystem:
from adlfs import AzureBlobFileSystem

for key, sas_token in {
key.replace(f"{ADLS_SAS_TOKEN}.", ""): value for key, value in properties.items() if key.startswith(ADLS_SAS_TOKEN)
}.items():
if ADLS_ACCOUNT_NAME not in properties:
properties[ADLS_ACCOUNT_NAME] = key.split(".")[0]
if ADLS_SAS_TOKEN not in properties:
properties[ADLS_SAS_TOKEN] = sas_token
# https://learn.microsoft.com/en-us/azure/storage/blobs/data-lake-storage-introduction-abfs-uri#uri-syntax
if netloc := properties.get("netloc"):
account_uri = netloc.split("@")[-1]
else:
account_uri = None

if not properties.get(ADLS_ACCOUNT_NAME) and account_uri:
properties[ADLS_ACCOUNT_NAME] = account_uri.split(".")[0]

# Fixes https://github.com/apache/iceberg-python/issues/1146
if not properties.get(ADLS_SAS_TOKEN) and account_uri:
properties[ADLS_SAS_TOKEN] = properties.get(f"{ADLS_SAS_TOKEN}.{account_uri}")

return AzureBlobFileSystem(
connection_string=properties.get(ADLS_CONNECTION_STRING),
Expand Down Expand Up @@ -340,7 +346,7 @@ class FsspecFileIO(FileIO):
def __init__(self, properties: Properties):
self._scheme_to_fs = {}
self._scheme_to_fs.update(SCHEME_TO_FS)
self.get_fs: Callable[[str], AbstractFileSystem] = lru_cache(self._get_fs)
self.get_fs: Callable[[str, Optional[str]], AbstractFileSystem] = lru_cache(self._get_fs)
super().__init__(properties=properties)

def new_input(self, location: str) -> FsspecInputFile:
Expand All @@ -353,7 +359,7 @@ def new_input(self, location: str) -> FsspecInputFile:
FsspecInputFile: An FsspecInputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
return FsspecInputFile(location=location, fs=fs)

def new_output(self, location: str) -> FsspecOutputFile:
Expand All @@ -366,7 +372,7 @@ def new_output(self, location: str) -> FsspecOutputFile:
FsspecOutputFile: An FsspecOutputFile instance for the given location.
"""
uri = urlparse(location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
return FsspecOutputFile(location=location, fs=fs)

def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
Expand All @@ -383,14 +389,17 @@ def delete(self, location: Union[str, InputFile, OutputFile]) -> None:
str_location = location

uri = urlparse(str_location)
fs = self.get_fs(uri.scheme)
fs = self.get_fs(uri.scheme, uri.netloc)
fs.rm(str_location)

def _get_fs(self, scheme: str) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme."""
def _get_fs(self, scheme: str, netloc: Optional[str] = None) -> AbstractFileSystem:
"""Get a filesystem for a specific scheme and netloc."""
if scheme not in self._scheme_to_fs:
raise ValueError(f"No registered filesystem for scheme: {scheme}")
return self._scheme_to_fs[scheme](self.properties)
properties = self.properties.copy()
if netloc:
properties["netloc"] = netloc
return self._scheme_to_fs[scheme](properties)

def __getstate__(self) -> Dict[str, Any]:
"""Create a dictionary of the FsSpecFileIO fields used when pickling."""
Expand Down
Loading