-
Notifications
You must be signed in to change notification settings - Fork 407
[OMNIML-4788] specdec_bench: configuration.json provenance + upload_to_s3 #1531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0b78aa7
3fe120e
58633f0
251afa5
02fc28c
dacabe3
8bda357
0b759a2
95e8040
70e8ceb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| boto3>=1.34.0 | ||
| botocore>=1.34.0 | ||
| datasets>=3.1.0 | ||
| rich>=14.2.0 | ||
| seaborn>=0.13.2 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,23 +55,31 @@ def _process_lengths(self, lengths): | |
| self.out["Conditional_Acceptance_Rate"][k] = running_len / sum_lengths / prev_ratio | ||
| prev_ratio = running_len / sum_lengths | ||
| running_len -= v | ||
| # Joint acceptance rate at step k = product of conditional acceptance | ||
| # rates at steps 1..k = probability that ≥k tokens are accepted in | ||
| # a row. The visualizer renders this as a separate panel. | ||
| self.out["Joint_Acceptance_Rate"] = {} | ||
| running_joint = 1.0 | ||
| for k, cond_ar in self.out["Conditional_Acceptance_Rate"].items(): | ||
| running_joint *= cond_ar | ||
| self.out["Joint_Acceptance_Rate"][k] = running_joint | ||
|
|
||
| def process_final(self, text_outputs): | ||
| all_ar = [] | ||
| lengths = {} | ||
| self.out["Request_AR"] = {} | ||
| self.out["Request_AL"] = {} | ||
| self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0])) | ||
| for request_id, turns in self.prompt_ar.items(): | ||
| self.out["Request_AR"][request_id] = {} | ||
| self.out["Request_AL"][request_id] = {} | ||
| for turn_id, turn in turns.items(): | ||
| ar = sum(turn) / len(turn) | ||
| self.out["Request_AR"][request_id][turn_id] = ar | ||
| self.out["Request_AL"][request_id][turn_id] = ar | ||
| all_ar.append(ar) | ||
| self._get_lengths(turn, lengths) | ||
| print(request_id, turn_id, self.out["Request_AR"][request_id][turn_id]) | ||
| print(request_id, turn_id, self.out["Request_AL"][request_id][turn_id]) | ||
| average_ar = sum(all_ar) / len(all_ar) | ||
| print("Average AR:", average_ar) | ||
| self.out["Average_AR"] = average_ar | ||
| self.out["Average_AL"] = average_ar | ||
|
Comment on lines
80
to
+82
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update the average label to match renamed key. Line 81 still prints Suggested patch- print("Average AR:", average_ar)
+ print("Average AL:", average_ar)🤖 Prompt for AI Agents |
||
| self._process_lengths(lengths) | ||
| self.write() | ||
| self._format_write_output(text_outputs) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,26 +44,26 @@ def __init__(self, requests): | |
|
|
||
| def process_final(self, text_outputs): | ||
| lengths = {} | ||
| self.out["Request_AR"] = {} | ||
| self.out["Request_AL"] = {} | ||
| for request_id, request in enumerate(self.requests): | ||
| turns = self.prompt_ar[request_id].values() | ||
| assert len(turns) == len(request.turns), ( | ||
| f"Number of turns {len(turns)} does not match number of turns in request {len(request.turns)}" | ||
| ) | ||
| self.out["Request_AR"][request.question_id] = mean(list(chain(*turns))) | ||
| self.out["Request_AL"][request.question_id] = mean(list(chain(*turns))) | ||
| for turn in turns: | ||
| self._get_lengths(turn, lengths) | ||
| print(request.category, self.out["Request_AR"][request.question_id]) | ||
| print(request.category, self.out["Request_AL"][request.question_id]) | ||
| per_category = defaultdict(list) | ||
| for request in self.requests: | ||
| per_category[request.category].append(self.out["Request_AR"][request.question_id]) | ||
| self.out["Category_AR"] = {} | ||
| per_category[request.category].append(self.out["Request_AL"][request.question_id]) | ||
| self.out["Category_AL"] = {} | ||
| for category_name, category_ar in per_category.items(): | ||
| if len(category_ar) > 0: | ||
| category_ar = mean(category_ar) | ||
| self.out["Category_AR"][category_name] = category_ar | ||
| average_ar = mean(self.out["Request_AR"].values()) | ||
| self.out["Average_AR"] = average_ar | ||
| self.out["Category_AL"][category_name] = category_ar | ||
| average_ar = mean(self.out["Request_AL"].values()) | ||
| self.out["Average_AL"] = average_ar | ||
| self._process_lengths(lengths) | ||
| self.write() | ||
| self._format_write_output(text_outputs) | ||
|
|
@@ -96,12 +96,12 @@ def _pretty_print_results(self): | |
| table.add_column("Average AR", justify="right", style="green") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
After the AR→AL rename the column header still reads |
||
|
|
||
| # Add category rows | ||
| for category_name, category_ar in sorted(self.out["Category_AR"].items()): | ||
| for category_name, category_ar in sorted(self.out["Category_AL"].items()): | ||
| table.add_row(category_name, f"{category_ar:.4f}") | ||
|
|
||
| # Add separator and summary row | ||
| table.add_section() | ||
| table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AR']:.4f}[/bold]") | ||
| table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AL']:.4f}[/bold]") | ||
|
|
||
| console.print(table) | ||
|
|
||
|
|
@@ -124,8 +124,8 @@ def _create_visualizations( | |
|
|
||
| df_clean = pd.DataFrame.from_dict( | ||
| { | ||
| "question_id": list(self.out["Request_AR"].keys()), | ||
| "acceptance_rate": list(self.out["Request_AR"].values()), | ||
| "question_id": list(self.out["Request_AL"].keys()), | ||
| "acceptance_rate": list(self.out["Request_AL"].values()), | ||
| "category": [request.category for request in self.requests], | ||
| "response_length": [ | ||
| mean([len(c["content"]) for c in messages if c["role"] == "assistant"]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
__version__ = "1.0.0"bump and the documentedRequest_AR → Request_AL/Average_AR → Average_ALrename are a breaking on-disk schema change that's not visible in this PR's title ("configuration.json provenance + upload_to_s3"). Any downstream consumer parsingRequest_ARfrom olderacceptance_rate.json/specbench_results.jsonfiles breaks silently. Please either split this into its own PR titled around the methodology bump, or call it out explicitly at the top of the PR body — and update README/SPECBENCH_PORTING.md if they reference the old field names.