Skip to content

Commit 3a9fc8e

Browse files
fix: unzip kaggle data (microsoft#464)
* unzip kaggle data
* read local_data_path from .env file
* fix build docs error
* recover azure-identity packages
* optimize code logic
* add error when downloading data from kaggle

---------

Co-authored-by: Xu Yang <peteryang@vip.qq.com>
1 parent 83b3f78 commit 3a9fc8e

File tree

5 files changed

+32
-6
lines changed

5 files changed

+32
-6
lines changed

constraints/3.10.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
azure-identity==1.17.1
12
dill==0.3.9
23
pillow==10.4.0
34
psutil==6.1.0

constraints/3.11.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
azure-identity==1.17.1
12
dill==0.3.9
23
pillow==10.4.0
34
psutil==6.1.0

rdagent/app/kaggle/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Config:
4444
competition: str = ""
4545
"""Kaggle competition name, e.g., 'sf-crime'"""
4646

47-
local_data_path: str = "/data/userdata/share/kaggle"
47+
local_data_path: str = ""
4848
"""Folder storing Kaggle competition data"""
4949

5050
if_action_choosing_based_on_UCB: bool = False

rdagent/core/exception.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,9 @@ class ModelEmptyError(Exception):
4242
"""
4343
Exceptions raised when no model is generated correctly
4444
"""
45+
46+
47+
class KaggleError(Exception):
48+
"""
49+
Exceptions raised when calling Kaggle API
50+
"""

rdagent/scenarios/kaggle/kaggle_crawler.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from selenium.webdriver.common.by import By
1616

1717
from rdagent.app.kaggle.conf import KAGGLE_IMPLEMENT_SETTING
18+
from rdagent.core.exception import KaggleError
1819
from rdagent.core.prompts import Prompts
1920
from rdagent.log import rdagent_logger as logger
2021
from rdagent.oai.llm_utils import APIBackend
@@ -99,11 +100,28 @@ def kaggle_description_css_selectors() -> tuple[str, str]:
99100
def download_data(competition: str, local_path: str = KAGGLE_IMPLEMENT_SETTING.local_data_path) -> None:
100101
data_path = f"{local_path}/{competition}"
101102
if not Path(data_path).exists():
102-
subprocess.run(["kaggle", "competitions", "download", "-c", competition, "-p", data_path])
103-
104-
# unzip data
105-
with zipfile.ZipFile(f"{data_path}/{competition}.zip", "r") as zip_ref:
106-
zip_ref.extractall(data_path)
103+
try:
104+
subprocess.run(
105+
["kaggle", "competitions", "download", "-c", competition, "-p", data_path],
106+
check=True,
107+
stderr=subprocess.PIPE,
108+
stdout=subprocess.PIPE,
109+
)
110+
except subprocess.CalledProcessError as e:
111+
logger.error(f"Download failed: {e}, stderr: {e.stderr}, stdout: {e.stdout}")
112+
raise KaggleError(f"Download failed: {e}, stderr: {e.stderr}, stdout: {e.stdout}")
113+
114+
# unzip data
115+
unzip_path = f"{local_path}/{competition}"
116+
if not Path(unzip_path).exists():
117+
unzip_data(unzip_file_path=f"{data_path}/{competition}.zip", unzip_target_path=unzip_path)
118+
for sub_zip_file in Path(unzip_path).rglob("*.zip"):
119+
unzip_data(sub_zip_file, unzip_target_path=unzip_path)
120+
121+
122+
def unzip_data(unzip_file_path: str, unzip_target_path: str) -> None:
123+
with zipfile.ZipFile(unzip_file_path, "r") as zip_ref:
124+
zip_ref.extractall(unzip_target_path)
107125

108126

109127
def leaderboard_scores(competition: str) -> list[float]:

0 commit comments

Comments
 (0)