Skip to content

Commit 52efc85

Browse files
authored
ui updates (microsoft#778)
1 parent 5421be9 commit 52efc85

File tree

1 file changed

+60
-9
lines changed

1 file changed

+60
-9
lines changed

rdagent/log/ui/ds_summary.py

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import pickle
23
import re
34
from collections import deque
45
from datetime import datetime, timedelta
@@ -31,6 +32,16 @@ def get_script_time(stdout_p: Path):
3132
return None
3233

3334

35+
def get_final_sota_exp(log_path: Path):
    """Load the pickled SOTA experiment written by the latest loop under *log_path*.

    Recursively searches ``log_path`` for ``*.pkl`` files located beneath a
    directory named ``SOTA experiment`` and unpickles the one whose path
    carries the highest ``Loop_<N>`` index.

    Args:
        log_path: Directory to search recursively.

    Returns:
        The unpickled object, or ``None`` when no candidate file whose path
        contains a ``Loop_<N>`` component is found.
    """
    loop_pattern = re.compile(r".*Loop_(\d+).*")

    # Pair every candidate with its loop index. Paths without a ``Loop_<N>``
    # component are skipped instead of crashing on ``None[1]``.
    indexed_paths = []
    for pkl_path in log_path.rglob("**/SOTA experiment/**/*.pkl"):
        matched = loop_pattern.match(str(pkl_path))
        if matched is not None:
            indexed_paths.append((int(matched[1]), pkl_path))

    if not indexed_paths:
        return None

    # On ties ``max`` keeps the first candidate encountered.
    _, final_sota_exp_path = max(indexed_paths, key=lambda item: item[0])

    # NOTE(security): unpickling can execute arbitrary code; only point this
    # at log folders produced by trusted runs.
    with final_sota_exp_path.open("rb") as f:
        return pickle.load(f)
43+
44+
3445
# @st.cache_data(persist=True)
3546
def get_summary_df(log_folders: list[str]) -> tuple[dict, pd.DataFrame]:
3647
summarys = {}
@@ -75,6 +86,12 @@ def get_summary_df(log_folders: list[str]) -> tuple[dict, pd.DataFrame]:
7586
v["exp_gen_time"] = str(exp_gen_time).split(".")[0]
7687
v["coding_time"] = str(coding_time).split(".")[0]
7788
v["running_time"] = str(running_time).split(".")[0]
89+
90+
final_sota_exp = get_final_sota_exp(Path(lf) / k)
91+
if final_sota_exp is not None:
92+
v["sota_exp_score_valid"] = final_sota_exp.result.loc["ensemble"].iloc[0]
93+
else:
94+
v["sota_exp_score_valid"] = None
7895
# Adjust the experiment name
7996
if "amlt" in lf:
8097
summary[f"{lf[lf.rfind('amlt')+5:].split('/')[0]} - {k}"] = v
@@ -104,6 +121,7 @@ def get_summary_df(log_folders: list[str]) -> tuple[dict, pd.DataFrame]:
104121
"Any Medal",
105122
"Best Result",
106123
"SOTA Exp",
124+
"SOTA Exp Score (valid)",
107125
"SOTA Exp Score",
108126
"Baseline Score",
109127
"Ours - Base",
@@ -178,6 +196,7 @@ def compare_score(s1, s2):
178196
base_df.loc[k, "Ours vs Silver"] = compare_score(v["sota_exp_score"], v.get("silver_threshold", None))
179197
base_df.loc[k, "Ours vs Gold"] = compare_score(v["sota_exp_score"], v.get("gold_threshold", None))
180198
base_df.loc[k, "SOTA Exp Score"] = v.get("sota_exp_score", None)
199+
base_df.loc[k, "SOTA Exp Score (valid)"] = v.get("sota_exp_score_valid", None)
181200
base_df.loc[k, "Baseline Score"] = baseline_score
182201
base_df.loc[k, "Bronze Threshold"] = v.get("bronze_threshold", None)
183202
base_df.loc[k, "Silver Threshold"] = v.get("silver_threshold", None)
@@ -199,6 +218,7 @@ def compare_score(s1, s2):
199218
"Ours - Base": float,
200219
"Ours vs Base": float,
201220
"SOTA Exp Score": float,
221+
"SOTA Exp Score (valid)": float,
202222
"Baseline Score": float,
203223
"Bronze Threshold": float,
204224
"Silver Threshold": float,
@@ -330,6 +350,7 @@ def shorten_folder_name(folder: str) -> str:
330350
"Best Result",
331351
"SOTA Exp",
332352
"SOTA Exp Score",
353+
"SOTA Exp Score (valid)",
333354
],
334355
axis=0,
335356
),
@@ -407,24 +428,54 @@ def shorten_folder_name(folder: str) -> str:
407428
for k, v in summary.items():
408429
with st.container(border=True):
409430
st.markdown(f"**:blue[{k}] - :violet[{v['competition']}]**")
410-
fc1, fc2 = st.columns(2)
411-
tscores = {f"loop {k-1}": v for k, v in v["test_scores"].items()}
412-
tdf = pd.Series(tscores, name="score")
413-
f2 = px.line(tdf, markers=True, title="Test scores")
414-
fc2.plotly_chart(f2, key=k)
415431
try:
416-
vscores = {k: v.iloc[:, 0] for k, v in v["valid_scores"].items()}
432+
tscores = {f"loop {k-1}": v for k, v in v["test_scores"].items()}
433+
vscores = {}
434+
for k, vs in v["valid_scores"].items():
435+
if not vs.index.is_unique:
436+
st.warning(
437+
f"Loop {k}'s valid scores index are not unique, only the last one will be kept to show."
438+
)
439+
st.write(vs)
440+
vscores[k] = vs[~vs.index.duplicated(keep="last")].iloc[:, 0]
417441

418442
if len(vscores) > 0:
419443
metric_name = list(vscores.values())[0].name
420444
else:
421445
metric_name = "None"
422446

447+
tdf = pd.Series(tscores, name="score")
423448
vdf = pd.DataFrame(vscores)
449+
if "ensemble" in vdf.index:
450+
ensemble_row = vdf.loc[["ensemble"]]
451+
vdf = pd.concat([ensemble_row, vdf.drop("ensemble")])
424452
vdf.columns = [f"loop {i}" for i in vdf.columns]
425-
f1 = px.line(vdf.T, markers=True, title=f"Valid scores (metric: {metric_name})")
426-
427-
fc1.plotly_chart(f1, key=f"{k}_v")
453+
fig = go.Figure()
454+
# Add test scores trace from tdf
455+
fig.add_trace(
456+
go.Scatter(
457+
x=tdf.index,
458+
y=tdf,
459+
mode="lines+markers",
460+
name="Test scores",
461+
marker=dict(symbol="diamond"),
462+
line=dict(shape="linear", dash="dash"),
463+
)
464+
)
465+
# Add valid score traces from vdf (transposed to have loops on x-axis)
466+
for column in vdf.T.columns:
467+
fig.add_trace(
468+
go.Scatter(
469+
x=vdf.T.index,
470+
y=vdf.T[column],
471+
mode="lines+markers",
472+
name=f"{column}",
473+
visible=("legendonly" if column != "ensemble" else None),
474+
)
475+
)
476+
fig.update_layout(title=f"Test and Valid scores (metric: {metric_name})")
477+
478+
st.plotly_chart(fig)
428479
except Exception as e:
429480
import traceback
430481

0 commit comments

Comments
 (0)