Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 39 additions & 39 deletions helpers/data_backend/csv.py → helpers/data_backend/csv_.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import fnmatch
import io
from datetime import datetime
from urllib.request import url2pathname
import hashlib

import pandas as pd
import requests
from PIL import Image

from helpers.data_backend.base import BaseDataBackend
from helpers.image_manipulation.load import load_image
Expand All @@ -28,23 +25,21 @@ def url_to_filename(url: str) -> str:
return url.split("/")[-1]


def shorten_and_clean_filename(filename: str, no_op: bool):
if no_op:
return filename
filename = filename.replace("%20", "-").replace(" ", "-")
if len(filename) > 250:
filename = filename[:120] + "---" + filename[126:]
return filename
def str_hash(filename: str) -> str:
return str(hashlib.sha256(str(filename).encode()).hexdigest())


def html_to_file_loc(parent_directory: Path, url: str, shorten_filenames: bool) -> str:
def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path:
path = Path(path).resolve()
if hash_filenames:
return path.parent.joinpath(str_hash(path.stem) + path.suffix)
return path


def html_to_file_loc(parent_directory: Path, url: str, hash_filenames: bool) -> str:
filename = url_to_filename(url)
cached_loc = str(
parent_directory.joinpath(
shorten_and_clean_filename(filename, no_op=shorten_filenames)
)
)
return cached_loc
cached_loc = path_to_hashed_path(parent_directory.joinpath(filename), hash_filenames)
return str(cached_loc.resolve())


class CSVDataBackend(BaseDataBackend):
Expand All @@ -54,29 +49,28 @@ def __init__(
id: str,
csv_file: Path,
compress_cache: bool = False,
image_url_col: str = "url",
caption_column: str = "caption",
url_column: str = "url",
caption_column: str = "caption",
image_cache_loc: Optional[str] = None,
shorten_filenames: bool = False,
hash_filenames: bool = True,
):
self.id = id
self.type = "csv"
self.compress_cache = compress_cache
self.shorten_filenames = shorten_filenames
self.hash_filenames = hash_filenames
self.csv_file = csv_file
self.accelerator = accelerator
self.image_url_col = image_url_col
self.df = pd.read_csv(csv_file, index_col=image_url_col)
self.url_column = url_column
self.df = pd.read_csv(csv_file, index_col=url_column)
self.df = self.df.groupby(level=0).last() # deduplicate by index (image loc)
self.caption_column = caption_column
self.url_column = url_column
self.image_cache_loc = (
Path(image_cache_loc) if image_cache_loc is not None else None
)

def read(self, location, as_byteIO: bool = False):
"""Read and return the content of the file."""
already_hashed = False
if isinstance(location, Path):
location = str(location.resolve())
if location.startswith("http"):
Expand All @@ -85,11 +79,12 @@ def read(self, location, as_byteIO: bool = False):
cached_loc = html_to_file_loc(
self.image_cache_loc,
location,
shorten_filenames=self.shorten_filenames,
self.hash_filenames,
)
if os.path.exists(cached_loc):
# found cache
location = cached_loc
already_hashed = True
else:
# actually go to website
data = requests.get(location, stream=True).raw.data
Expand All @@ -99,8 +94,13 @@ def read(self, location, as_byteIO: bool = False):
data = requests.get(location, stream=True).raw.data
if not location.startswith("http"):
# read from local file
with open(location, "rb") as file:
data = file.read()
hashed_location = path_to_hashed_path(location, hash_filenames=self.hash_filenames and not already_hashed)
try:
with open(hashed_location, "rb") as file:
data = file.read()
except FileNotFoundError as e:
print(f'ask was for file {location} bound to {hashed_location}')
Comment thread
williamzhuk marked this conversation as resolved.
Outdated
raise e
if not as_byteIO:
return data
return BytesIO(data)
Expand All @@ -114,9 +114,7 @@ def write(self, filepath: Union[str, Path], data: Any) -> None:
filepath = Path(filepath)
# Not a huge fan of auto-shortening filenames, as we hash things for that in other cases.
# However, this is copied in from the original Arcade-AI CSV backend implementation for compatibility.
filepath = filepath.parent.joinpath(
shorten_and_clean_filename(filepath.name, no_op=self.shorten_filenames)
)
filepath = path_to_hashed_path(filepath, self.hash_filenames)
filepath.parent.mkdir(parents=True, exist_ok=True)
with open(filepath, "wb") as file:
# Check if data is a Tensor, and if so, save it appropriately
Expand All @@ -137,6 +135,7 @@ def delete(self, filepath):
if filepath in self.df.index:
self.df.drop(filepath, inplace=True)
# self.save_state()
filepath = path_to_hashed_path(filepath, self.hash_filenames)
if os.path.exists(filepath):
logger.debug(f"Deleting file: {filepath}")
os.remove(filepath)
Expand All @@ -146,20 +145,21 @@ def delete(self, filepath):

def exists(self, filepath):
"""Check if the file exists."""
if isinstance(filepath, Path):
filepath = str(filepath.resolve())
return filepath in self.df.index or os.path.exists(filepath)
if isinstance(filepath, str) and "http" in filepath:
return filepath in self.df.index
else:
filepath = path_to_hashed_path(filepath, self.hash_filenames)
return os.path.exists(filepath)

def open_file(self, filepath, mode):
"""Open the file in the specified mode."""
return open(filepath, mode)
return open(path_to_hashed_path(filepath, self.hash_filenames), mode)

def list_files(self, str_pattern: str, instance_data_dir: str = None) -> tuple:
"""
List all files matching the pattern.
Creates Path objects of each file found.
"""
# print frame contents
logger.debug(
f"CSVDataBackend.list_files: str_pattern={str_pattern}, instance_data_dir={instance_data_dir}"
)
Expand Down Expand Up @@ -279,9 +279,9 @@ def torch_save(self, data, location: Union[str, Path, BytesIO]):
"""
Save a torch tensor to a file.
"""

if isinstance(location, str) or isinstance(location, Path):
if location not in self.df.index:
self.df.loc[location] = pd.Series()
location = path_to_hashed_path(location, self.hash_filenames)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we update this method to use clone() before saving, like the local backend does?

location = self.open_file(location, "wb")

if self.compress_cache:
Expand All @@ -297,7 +297,7 @@ def write_batch(self, filepaths: list, data_list: list) -> None:
self.write(filepath, data)

def save_state(self):
self.df.to_csv(self.csv_file, index_label=self.image_url_col)
self.df.to_csv(self.csv_file, index_label=self.url_column)

def get_caption(self, image_path: str) -> str:
if self.caption_column is None:
Expand Down
18 changes: 13 additions & 5 deletions helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from helpers.data_backend.local import LocalDataBackend
from helpers.data_backend.aws import S3DataBackend
from helpers.data_backend.csv import CSVDataBackend
from helpers.data_backend.csv_ import CSVDataBackend
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would prefer we call it csv_url_list if we're changing it

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from helpers.data_backend.csv_ import CSVDataBackend
from helpers.data_backend.csv_url import CSVDataBackend

from helpers.data_backend.base import BaseDataBackend
from helpers.training.default_settings import default, latest_config_version
from helpers.caching.text_embeds import TextEmbeddingCache
Expand Down Expand Up @@ -99,6 +99,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
output["config"]["csv_file"] = backend["csv_file"]
if "csv_caption_column" in backend:
output["config"]["csv_caption_column"] = backend["csv_caption_column"]
if "csv_url_column" in backend:
output["config"]["csv_url_column"] = backend["csv_url_column"]
if "crop_aspect" in backend:
choices = ["square", "preserve", "random"]
if backend.get("crop_aspect", None) not in choices:
Expand Down Expand Up @@ -147,8 +149,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
)
if "hash_filenames" in backend:
output["config"]["hash_filenames"] = backend["hash_filenames"]
if "shorten_filenames" in backend and backend.get("type") == "csv":
output["config"]["shorten_filenames"] = backend["shorten_filenames"]
if "hash_filenames" in backend and backend.get("type") == "csv":
output["config"]["hash_filenames"] = backend["hash_filenames"]

# check if caption_strategy=parquet with metadata_backend=json
if (
Expand Down Expand Up @@ -593,7 +595,7 @@ def configure_multi_databackend(
csv_file=backend["csv_file"],
csv_cache_dir=backend["csv_cache_dir"],
compress_cache=args.compress_disk_cache,
shorten_filenames=backend.get("shorten_filenames", False),
hash_filenames=backend.get("hash_filenames", False),
)
# init_backend["instance_data_dir"] = backend.get("instance_data_dir", backend.get("instance_data_root", backend.get("csv_cache_dir")))
init_backend["instance_data_dir"] = None
Expand Down Expand Up @@ -1038,8 +1040,10 @@ def get_csv_backend(
id: str,
csv_file: str,
csv_cache_dir: str,
url_column: str,
caption_column: str,
compress_cache: bool = False,
shorten_filenames: bool = False,
hash_filenames: bool = False,
) -> CSVDataBackend:
from pathlib import Path

Expand All @@ -1048,8 +1052,11 @@ def get_csv_backend(
id=id,
csv_file=Path(csv_file),
image_cache_loc=csv_cache_dir,
url_column=url_column,
caption_column=caption_column,
compress_cache=compress_cache,
shorten_filenames=shorten_filenames,
hash_filenames=hash_filenames,
)


Expand All @@ -1058,6 +1065,7 @@ def check_csv_config(backend: dict, args) -> None:
"csv_file": "This is the path to the CSV file containing your image URLs.",
"csv_cache_dir": "This is the path to your temporary cache files where images will be stored. This can grow quite large.",
"csv_caption_column": "This is the column in your csv which contains the caption(s) for the samples.",
"csv_url_column": "This is the column in your csv that contains image urls or paths.",
}
for key in required_keys.keys():
if key not in backend:
Expand Down