csv backend updates #645
Changes from 5 commits
helpers/data_backend/csv.py (renamed to csv_.py in this PR; see the factory import further down):

```diff
@@ -1,11 +1,8 @@
 import fnmatch
 import io
 from datetime import datetime
-from urllib.request import url2pathname
-
+import hashlib
 import pandas as pd
 import requests
-from PIL import Image
-
 from helpers.data_backend.base import BaseDataBackend
 from helpers.image_manipulation.load import load_image
```
```diff
@@ -28,23 +25,21 @@ def url_to_filename(url: str) -> str:
     return url.split("/")[-1]


-def shorten_and_clean_filename(filename: str, no_op: bool):
-    if no_op:
-        return filename
-    filename = filename.replace("%20", "-").replace(" ", "-")
-    if len(filename) > 250:
-        filename = filename[:120] + "---" + filename[126:]
-    return filename
+def str_hash(filename: str) -> str:
+    return str(hashlib.sha256(str(filename).encode()).hexdigest())


-def html_to_file_loc(parent_directory: Path, url: str, shorten_filenames: bool) -> str:
+def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path:
+    path = Path(path).resolve()
+    if hash_filenames:
+        return path.parent.joinpath(str_hash(path.stem) + path.suffix)
+    return path
+
+
+def html_to_file_loc(parent_directory: Path, url: str, hash_filenames: bool) -> str:
     filename = url_to_filename(url)
-    cached_loc = str(
-        parent_directory.joinpath(
-            shorten_and_clean_filename(filename, no_op=shorten_filenames)
-        )
-    )
-    return cached_loc
+    cached_loc = path_to_hashed_path(parent_directory.joinpath(filename), hash_filenames)
+    return str(cached_loc.resolve())


 class CSVDataBackend(BaseDataBackend):
```
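The two helpers above are self-contained, so their effect is easy to preview. This sketch copies `str_hash` and `path_to_hashed_path` verbatim from the diff and feeds them an illustrative path:

```python
import hashlib
from pathlib import Path

def str_hash(filename: str) -> str:
    # hex digest of the sha256 of the filename
    return str(hashlib.sha256(str(filename).encode()).hexdigest())

def path_to_hashed_path(path: Path, hash_filenames: bool) -> Path:
    path = Path(path).resolve()
    if hash_filenames:
        # only the stem is hashed; directory and extension are preserved
        return path.parent.joinpath(str_hash(path.stem) + path.suffix)
    return path

print(path_to_hashed_path("/cache/my photo.png", hash_filenames=True))
# -> /cache/<64 hex chars>.png
print(path_to_hashed_path("/cache/my photo.png", hash_filenames=False))
# -> /cache/my photo.png
```

Hashing only the stem sidesteps the old length/whitespace cleanup entirely: the cache name no longer depends on how messy the source filename is.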
```diff
@@ -54,29 +49,28 @@ def __init__(
         id: str,
         csv_file: Path,
         compress_cache: bool = False,
-        image_url_col: str = "url",
-        caption_column: str = "caption",
+        url_column: str = "url",
+        caption_column: str = "caption",
         image_cache_loc: Optional[str] = None,
-        shorten_filenames: bool = False,
+        hash_filenames: bool = True,
     ):
         self.id = id
         self.type = "csv"
         self.compress_cache = compress_cache
-        self.shorten_filenames = shorten_filenames
+        self.hash_filenames = hash_filenames
         self.csv_file = csv_file
         self.accelerator = accelerator
-        self.image_url_col = image_url_col
-        self.df = pd.read_csv(csv_file, index_col=image_url_col)
+        self.url_column = url_column
+        self.df = pd.read_csv(csv_file, index_col=url_column)
         self.df = self.df.groupby(level=0).last()  # deduplicate by index (image loc)
         self.caption_column = caption_column
-        self.url_column = url_column
         self.image_cache_loc = (
             Path(image_cache_loc) if image_cache_loc is not None else None
         )

     def read(self, location, as_byteIO: bool = False):
         """Read and return the content of the file."""
+        already_hashed = False
         if isinstance(location, Path):
             location = str(location.resolve())
         if location.startswith("http"):
```
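A hypothetical construction call under the new signature; the paths and column names are placeholders, and `accelerator` would normally be the trainer's Accelerator instance:

```python
from pathlib import Path
from helpers.data_backend.csv_ import CSVDataBackend

backend = CSVDataBackend(
    accelerator=None,              # placeholder for the trainer's Accelerator
    id="csv-dataset",
    csv_file=Path("dataset.csv"),  # indexed by url_column, then deduplicated
    url_column="url",              # renamed from image_url_col
    caption_column="caption",
    image_cache_loc="/tmp/csv-image-cache",
    hash_filenames=True,           # note the constructor default is now True
)
```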
```diff
@@ -85,11 +79,12 @@ def read(self, location, as_byteIO: bool = False):
             cached_loc = html_to_file_loc(
                 self.image_cache_loc,
                 location,
-                shorten_filenames=self.shorten_filenames,
+                self.hash_filenames,
             )
             if os.path.exists(cached_loc):
                 # found cache
                 location = cached_loc
+                already_hashed = True
             else:
                 # actually go to website
                 data = requests.get(location, stream=True).raw.data
```
```diff
@@ -99,8 +94,13 @@ def read(self, location, as_byteIO: bool = False):
                 data = requests.get(location, stream=True).raw.data
         if not location.startswith("http"):
             # read from local file
-            with open(location, "rb") as file:
-                data = file.read()
+            hashed_location = path_to_hashed_path(location, hash_filenames=self.hash_filenames and not already_hashed)
+            try:
+                with open(hashed_location, "rb") as file:
+                    data = file.read()
+            except FileNotFoundError as e:
+                print(f"ask was for file {location} bound to {hashed_location}")
+                raise e
         if not as_byteIO:
             return data
         return BytesIO(data)
```
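Putting the two `read()` hunks together, the flow for a URL row appears to be as follows. The values here are illustrative, `html_to_file_loc` is the helper from the diff above, and `already_hashed` is what keeps the cached path from being hashed a second time further down:

```python
from pathlib import Path

location = "https://example.com/images/cat.png"

# 1. compute where a download of this URL would live in the cache:
cached_loc = html_to_file_loc(Path("/tmp/csv-image-cache"), location, True)
# -> "/tmp/csv-image-cache/<sha256('cat')>.png"

# 2. if that file exists, read() swaps location for cached_loc and sets
#    already_hashed=True, so path_to_hashed_path becomes a no-op later;
# 3. otherwise the URL is fetched with requests and used directly.
```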
```diff
@@ -114,9 +114,7 @@ def write(self, filepath: Union[str, Path], data: Any) -> None:
         filepath = Path(filepath)
         # Not a huge fan of auto-shortening filenames, as we hash things for that in other cases.
         # However, this is copied in from the original Arcade-AI CSV backend implementation for compatibility.
-        filepath = filepath.parent.joinpath(
-            shorten_and_clean_filename(filepath.name, no_op=self.shorten_filenames)
-        )
+        filepath = path_to_hashed_path(filepath, self.hash_filenames)
         filepath.parent.mkdir(parents=True, exist_ok=True)
         with open(filepath, "wb") as file:
             # Check if data is a Tensor, and if so, save it appropriately
```
```diff
@@ -137,6 +135,7 @@ def delete(self, filepath):
         if filepath in self.df.index:
             self.df.drop(filepath, inplace=True)
             # self.save_state()
+        filepath = path_to_hashed_path(filepath, self.hash_filenames)
         if os.path.exists(filepath):
             logger.debug(f"Deleting file: {filepath}")
             os.remove(filepath)
```
```diff
@@ -146,20 +145,21 @@ def delete(self, filepath):
     def exists(self, filepath):
         """Check if the file exists."""
-        if isinstance(filepath, Path):
-            filepath = str(filepath.resolve())
-        return filepath in self.df.index or os.path.exists(filepath)
+        if isinstance(filepath, str) and "http" in filepath:
+            return filepath in self.df.index
+        else:
+            filepath = path_to_hashed_path(filepath, self.hash_filenames)
+            return os.path.exists(filepath)

     def open_file(self, filepath, mode):
         """Open the file in the specified mode."""
-        return open(filepath, mode)
+        return open(path_to_hashed_path(filepath, self.hash_filenames), mode)

     def list_files(self, str_pattern: str, instance_data_dir: str = None) -> tuple:
         """
         List all files matching the pattern.
         Creates Path objects of each file found.
         """
         # print frame contents
         logger.debug(
             f"CSVDataBackend.list_files: str_pattern={str_pattern}, instance_data_dir={instance_data_dir}"
         )
```
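With this change, `exists()` answers two different questions depending on the argument; hypothetical calls, assuming the backend object from earlier:

```python
# URL: answered from the CSV index, no filesystem access
backend.exists("https://example.com/cat.png")   # True iff the URL is in self.df.index

# local path: answered from disk, after hashing
backend.exists("/tmp/csv-image-cache/cat.png")
# checks /tmp/csv-image-cache/<sha256('cat')>.png when hash_filenames=True
```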
```diff
@@ -279,9 +279,9 @@ def torch_save(self, data, location: Union[str, Path, BytesIO]):
         """
         Save a torch tensor to a file.
         """
         if isinstance(location, str) or isinstance(location, Path):
             if location not in self.df.index:
                 self.df.loc[location] = pd.Series()
+            location = path_to_hashed_path(location, self.hash_filenames)
```

Owner
can we update this method to use clone() before saving, like the local backend does?

```diff
             location = self.open_file(location, "wb")

         if self.compress_cache:
```
```diff
@@ -297,7 +297,7 @@ def write_batch(self, filepaths: list, data_list: list) -> None:
             self.write(filepath, data)

     def save_state(self):
-        self.df.to_csv(self.csv_file, index_label=self.image_url_col)
+        self.df.to_csv(self.csv_file, index_label=self.url_column)

     def get_caption(self, image_path: str) -> str:
         if self.caption_column is None:
```
The data-backend factory module:

```diff
@@ -1,6 +1,6 @@
 from helpers.data_backend.local import LocalDataBackend
 from helpers.data_backend.aws import S3DataBackend
-from helpers.data_backend.csv import CSVDataBackend
+from helpers.data_backend.csv_ import CSVDataBackend
```

Owner
i would prefer we call it csv_url_list if we're changing it

Owner
Suggested change

```diff
 from helpers.data_backend.base import BaseDataBackend
 from helpers.training.default_settings import default, latest_config_version
 from helpers.caching.text_embeds import TextEmbeddingCache
```
```diff
@@ -99,6 +99,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
         output["config"]["csv_file"] = backend["csv_file"]
     if "csv_caption_column" in backend:
         output["config"]["csv_caption_column"] = backend["csv_caption_column"]
+    if "csv_url_column" in backend:
+        output["config"]["csv_url_column"] = backend["csv_url_column"]
     if "crop_aspect" in backend:
         choices = ["square", "preserve", "random"]
         if backend.get("crop_aspect", None) not in choices:
```
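Together with `check_csv_config` further down, a csv backend entry in the multi-backend config would carry keys along these lines (a sketch; every value is a placeholder):

```python
backend = {
    "id": "csv-urls",
    "type": "csv",
    "csv_file": "/data/dataset.csv",       # required
    "csv_cache_dir": "/data/image-cache",  # required
    "csv_caption_column": "caption",       # required
    "csv_url_column": "url",               # required after this PR
    "hash_filenames": True,                # optional, csv-specific override
}
```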
```diff
@@ -147,8 +149,8 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
         )
     if "hash_filenames" in backend:
         output["config"]["hash_filenames"] = backend["hash_filenames"]
-    if "shorten_filenames" in backend and backend.get("type") == "csv":
-        output["config"]["shorten_filenames"] = backend["shorten_filenames"]
+    if "hash_filenames" in backend and backend.get("type") == "csv":
+        output["config"]["hash_filenames"] = backend["hash_filenames"]

     # check if caption_strategy=parquet with metadata_backend=json
     if (
```
```diff
@@ -593,7 +595,7 @@ def configure_multi_databackend(
             csv_file=backend["csv_file"],
             csv_cache_dir=backend["csv_cache_dir"],
             compress_cache=args.compress_disk_cache,
-            shorten_filenames=backend.get("shorten_filenames", False),
+            hash_filenames=backend.get("hash_filenames", False),
         )
         # init_backend["instance_data_dir"] = backend.get("instance_data_dir", backend.get("instance_data_root", backend.get("csv_cache_dir")))
         init_backend["instance_data_dir"] = None
```
```diff
@@ -1038,8 +1040,10 @@ def get_csv_backend(
     id: str,
     csv_file: str,
     csv_cache_dir: str,
+    url_column: str,
+    caption_column: str,
     compress_cache: bool = False,
-    shorten_filenames: bool = False,
+    hash_filenames: bool = False,
 ) -> CSVDataBackend:
     from pathlib import Path
```
```diff
@@ -1048,8 +1052,11 @@
         id=id,
         csv_file=Path(csv_file),
         image_cache_loc=csv_cache_dir,
+        url_column=url_column,
+        caption_column=caption_column,
         compress_cache=compress_cache,
-        shorten_filenames=shorten_filenames,
+        hash_filenames=hash_filenames,
     )
```
```diff
@@ -1058,6 +1065,7 @@ def check_csv_config(backend: dict, args) -> None:
         "csv_file": "This is the path to the CSV file containing your image URLs.",
         "csv_cache_dir": "This is the path to your temporary cache files where images will be stored. This can grow quite large.",
         "csv_caption_column": "This is the column in your csv which contains the caption(s) for the samples.",
+        "csv_url_column": "This is the column in your csv that contains image urls or paths.",
     }
     for key in required_keys.keys():
         if key not in backend:
```
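For completeness, a minimal CSV that would satisfy the backend, with the same indexing and dedup steps the constructor applies; the rows are invented for illustration:

```python
import io
import pandas as pd

csv_text = """url,caption
https://example.com/a.png,a photo of a cat
https://example.com/b.png,a photo of a dog
https://example.com/a.png,a better caption for a
"""

df = pd.read_csv(io.StringIO(csv_text), index_col="url")
df = df.groupby(level=0).last()  # duplicate URLs: the last row wins
print(df.loc["https://example.com/a.png", "caption"])
# -> a better caption for a
```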