Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Cache trial when set_trial_state_values is called
  • Loading branch information
c-bata committed May 12, 2023
commit fb6391e5f46d7953959710fbb2d8d0fca27b7683
2 changes: 1 addition & 1 deletion .github/workflows/tests-storage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
# RDB. Since current name "tests-rdbstorage" is required in the Branch protection rules, you
# need to modify the Branch protection rules as well.
tests-rdbstorage:
if: (github.event_name == 'schedule' && github.repository == 'optuna/optuna') || (github.event_name != 'schedule')
if: (github.event_name == 'schedule') || (github.event_name != 'schedule')
runs-on: ubuntu-latest

strategy:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ concurrency:

jobs:
tests:
if: (github.event_name == 'schedule' && github.repository == 'optuna/optuna') || (github.event_name != 'schedule')
if: (github.event_name == 'schedule') || (github.event_name != 'schedule')
runs-on: ubuntu-latest

strategy:
Expand Down
10 changes: 9 additions & 1 deletion optuna/storages/_cached_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,15 @@ def get_best_trial(self, study_id: int) -> FrozenTrial:
def set_trial_state_values(
self, trial_id: int, state: TrialState, values: Optional[Sequence[float]] = None
) -> bool:
return self._backend.set_trial_state_values(trial_id, state=state, values=values)
ret = self._backend.set_trial_state_values(trial_id, state=state, values=values)
if state.is_finished() and trial_id in self._trial_id_to_study_id_and_number:
backend_trial = self._backend.get_trial(trial_id)
study_id, trial_number = self._trial_id_to_study_id_and_number[trial_id]
with self._lock:
study = self._studies[study_id]
study.trials[trial_number] = backend_trial
study.finished_trial_ids.add(trial_id)
return ret

def set_trial_intermediate_value(
self, trial_id: int, step: int, intermediate_value: float
Expand Down
4 changes: 0 additions & 4 deletions tests/storages_tests/test_cached_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,6 @@ def test_delete_study() -> None:
trial_id2 = storage.create_new_trial(study_id2)
storage.set_trial_state_values(trial_id2, state=TrialState.COMPLETE)

# Update _StudyInfo.finished_trial_ids
storage.read_trials_from_remote_storage(study_id1)
storage.read_trials_from_remote_storage(study_id2)

storage.delete_study(study_id1)
assert storage._get_cached_trial(trial_id1) is None
assert storage._get_cached_trial(trial_id2) is not None