Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ init-qlib-env:

dev:
$(PIPRUN) pip install -e .[docs,lint,package,test] -c $(CONSTRAINTS_FILE)
$(PIPRUN) pip install -U kaggle
if [ "$(CI)" != "true" ] && command -v pre-commit > /dev/null 2>&1; then pre-commit install --hook-type pre-push; fi

# Generate constraints for current Python version.
Expand Down
13 changes: 11 additions & 2 deletions rdagent/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def multiprocessing_wrapper(func_calls: list[tuple[Callable, tuple]], n: int) ->
return [result.get() for result in results]


def cache_with_pickle(hash_func: Callable, post_process_func: Callable | None = None) -> Callable:
def cache_with_pickle(hash_func: Callable, post_process_func: Callable | None = None, force: bool = False) -> Callable:
"""
This decorator will cache the return value of the function with pickle.
The cache key is generated by the hash_func. The hash function returns a string or None.
Expand All @@ -162,12 +162,21 @@ def cache_with_pickle(hash_func: Callable, post_process_func: Callable | None =
The post_process_func will be called with the original arguments and the cached result
to give each caller a chance to process the cached result. The post_process_func should
return the final result.

Parameters
----------
hash_func : Callable
The function to generate the hash key for the cache.
post_process_func : Callable | None, optional
The function to process the cached result, by default None.
force : bool, optional
If True, the cache will be used even if RD_AGENT_SETTINGS.cache_with_pickle is False, by default False.
"""

def cache_decorator(func: Callable) -> Callable:
@functools.wraps(func)
def cache_wrapper(*args: Any, **kwargs: Any) -> Any:
if not RD_AGENT_SETTINGS.cache_with_pickle:
if not RD_AGENT_SETTINGS.cache_with_pickle and not force:
return func(*args, **kwargs)

target_folder = Path(RD_AGENT_SETTINGS.pickle_cache_folder_path_str) / f"{func.__module__}.{func.__name__}"
Expand Down
3 changes: 2 additions & 1 deletion rdagent/scenarios/kaggle/kaggle_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from rdagent.app.kaggle.conf import KAGGLE_IMPLEMENT_SETTING
from rdagent.core.conf import ExtendedBaseSettings
from rdagent.core.exception import KaggleError
from rdagent.core.prompts import Prompts
from rdagent.core.utils import cache_with_pickle
from rdagent.log import rdagent_logger as logger
from rdagent.oai.llm_utils import APIBackend
from rdagent.scenarios.data_science.debug.data import create_debug_data
Expand Down Expand Up @@ -186,6 +186,7 @@ def unzip_data(unzip_file_path: str, unzip_target_path: str) -> None:
zip_ref.extractall(unzip_target_path)


@cache_with_pickle(hash_func=lambda x: x, force=True)
def leaderboard_scores(competition: str) -> list[float]:
from kaggle.api.kaggle_api_extended import KaggleApi

Expand Down
Loading