Skip to content

Commit c3854cf

Browse files
authored
chore: withdraw for policy error (microsoft#907)
* chore: withdraw for policy error Co-authored-by: you-n-g <you-n-g@users.noreply.github.com> * Apply suggestions from code review
1 parent b87de56 commit c3854cf

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

rdagent/core/exception.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,9 @@ class KaggleError(Exception):
5757
"""
5858
Exceptions raised when calling Kaggle API
5959
"""
60+
61+
62+
class PolicyError(Exception):
63+
"""
64+
Exceptions raised due to content management policy
65+
"""

rdagent/oai/backend/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pytz
1515
from pydantic import TypeAdapter
1616

17+
from rdagent.core.exception import PolicyError
1718
from rdagent.core.utils import LLM_CACHE_SEED_GEN, SingletonBaseClass
1819
from rdagent.log import LogColors
1920
from rdagent.log import rdagent_logger as logger
@@ -368,7 +369,7 @@ def _try_create_chat_completion_or_embedding( # type: ignore[no-untyped-def]
368369
violation_count += 1
369370
if violation_count >= LLM_SETTINGS.violation_fail_limit:
370371
logger.warning("Content policy violation detected.")
371-
raise e
372+
raise PolicyError(e)
372373

373374
if (
374375
openai_imported

rdagent/scenarios/data_science/loop.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from rdagent.components.workflow.conf import BasePropSetting
2222
from rdagent.components.workflow.rd_loop import RDLoop
2323
from rdagent.core.conf import RD_AGENT_SETTINGS
24-
from rdagent.core.exception import CoderError, RunnerError
24+
from rdagent.core.exception import CoderError, PolicyError, RunnerError
2525
from rdagent.core.proposal import ExperimentFeedback, ExpGen
2626
from rdagent.core.scenario import Scenario
2727
from rdagent.core.utils import import_class
@@ -36,6 +36,7 @@
3636
class DataScienceRDLoop(RDLoop):
3737
# NOTE: we move the DataScienceRDLoop here to be easier to be imported
3838
skip_loop_error = (CoderError, RunnerError)
39+
withdraw_loop_error = (PolicyError,)
3940

4041
@staticmethod
4142
def _get_exp_gen(class_uri: str, scen: Scenario):

rdagent/utils/workflow.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class LoopBase:
8888
loop_trace: dict[int, list[LoopTrace]]
8989

9090
skip_loop_error: tuple[type[BaseException], ...] = () # you can define a list of error that will skip current loop
91+
withdraw_loop_error: tuple[
92+
type[BaseException], ...
93+
] = () # you can define a list of error that will withdraw current loop
9194

9295
EXCEPTION_KEY = "_EXCEPTION"
9396

@@ -183,6 +186,11 @@ def run(self, step_n: int | None = None, loop_n: int | None = None, all_duration
183186
self.step_idx = len(self.steps) - 1 # directly jump to the last step.
184187
self.loop_prev_out[self.EXCEPTION_KEY] = e
185188
continue
189+
elif isinstance(e, self.withdraw_loop_error):
190+
logger.warning(f"Withdraw loop {li} due to {e}")
191+
# Back to previous loop
192+
self.step_backward(li - 1)
193+
continue
186194
else:
187195
raise
188196
finally:
@@ -207,6 +215,27 @@ def run(self, step_n: int | None = None, loop_n: int | None = None, all_duration
207215

208216
self.dump(self.session_folder / f"{li}" / f"{si}_{name}") # save a snapshot after the session
209217

218+
def step_backward(self, li: int) -> None:
219+
prev_session_dir = self.session_folder / str(li)
220+
prev_path = min(
221+
(p for p in prev_session_dir.glob("*_*") if p.is_file()),
222+
key=lambda item: int(item.name.split("_", 1)[0]),
223+
default=None,
224+
)
225+
if prev_path:
226+
loaded = type(self).load(
227+
prev_path,
228+
output_path=self.session_folder.parent,
229+
do_truncate=False,
230+
replace_timer=True,
231+
)
232+
logger.info(f"Load previous session from {prev_path}")
233+
# Overwrite current instance state
234+
self.__dict__ = loaded.__dict__
235+
else:
236+
logger.error(f"No previous dump found at {prev_session_dir}, cannot withdraw loop {li}")
237+
raise
238+
210239
def dump(self, path: str | Path) -> None:
211240
if RD_Agent_TIMER_wrapper.timer.started:
212241
RD_Agent_TIMER_wrapper.timer.update_remain_time()

0 commit comments

Comments
 (0)