48 changes: 48 additions & 0 deletions examples/specdec_bench/README.md
@@ -145,6 +145,54 @@ python3 run.py \
--runtime_params runtime_args_long_context.yaml
```

## Uploading results to S3

Each `run.py` invocation writes a result directory containing `configuration.json`,
`timing.json`, `acceptance_rate.json`, and (when applicable) `mtbench.json` / `specbench.json`.
`upload_to_s3.py` is a single-file, drop-in tool that uploads one run — or an entire sweep —
to any S3-compatible bucket:

```bash
python upload_to_s3.py /path/to/run_or_sweep_dir s3://your-bucket/some/prefix \
--endpoint https://your-s3-endpoint \
--key-id YOUR_KEY_ID \
--secret YOUR_SECRET
```

`--endpoint`, `--key-id`, and `--secret` default to the `S3_ENDPOINT`, `S3_KEY_ID`, and
`S3_SECRET` environment variables. Omit `--endpoint` (or set `S3_ENDPOINT=""`) to use AWS S3's
default endpoint. Use `--dry-run` to preview the upload plan, and `--skip-existing` to skip
runs already present at the destination instead of failing.

The tool handles two directory layouts and mirrors them into S3:

- **Flat** — `LOCAL_DIR/run_name/{configuration,timing,...}.json`
- **Sweep** — `LOCAL_DIR/sweep_name/run_name/{configuration,timing,...}.json`

`LOCAL_DIR`'s basename is preserved in the destination prefix, so re-uploads from the same
source land in the same place.
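
As a rough illustration of how the two layouts could map onto object keys, the sketch below walks a local directory and builds destination keys that keep `LOCAL_DIR`'s basename. The helper name `plan_upload_keys`, the rule that a directory counts as a run when it contains `configuration.json`, and the choice to pass only the key prefix (without `s3://bucket`) are assumptions made for this example; they are not taken from `upload_to_s3.py`.

```python
from pathlib import Path


def plan_upload_keys(local_dir: str, prefix: str) -> list[tuple[str, str]]:
    """Return (local_path, object_key) pairs for every JSON file in every run.

    Hypothetical helper: a directory is treated as a run if it holds
    configuration.json, which covers both the flat and the sweep layout.
    """
    root = Path(local_dir)
    plan = []
    for config in sorted(root.rglob("configuration.json")):
        run_dir = config.parent
        for path in sorted(run_dir.glob("*.json")):
            # Keep LOCAL_DIR's basename so re-uploads land under the same prefix.
            rel = Path(root.name) / path.relative_to(root)
            plan.append((str(path), f"{prefix.strip('/')}/{rel.as_posix()}"))
    return plan
```

Calling `plan_upload_keys("/data/results", "some/prefix")` on a directory holding `run_a/` and `sweep1/run_b/` would yield keys such as `some/prefix/results/run_a/timing.json` and `some/prefix/results/sweep1/run_b/configuration.json`.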

### Optional attestation fields

`run.py` reads two environment variables when writing `configuration.json`; they're optional
provenance hints that downstream consumers (dashboards, leaderboards) can use to attest a run:

| Env var | Purpose |
|---|---|
| `JIRA_TICKET` | A tracking ID for the run (your tracker — JIRA key, GitHub issue, etc.) |
| `HUGGINGFACE_MODEL_ID` | The public model id on the Hugging Face Hub, so the model used can be independently fetched |

Set them in the same shell that launches `run.py`:

```bash
export JIRA_TICKET=MYTRACK-1234
export HUGGINGFACE_MODEL_ID=meta-llama/Llama-3.3-70B-Instruct
python3 run.py ...
```

Both fields appear in the run's `configuration.json` as `jira_ticket` and `huggingface_model_id`
(or `null` when unset). They have no effect on the benchmark itself.
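
For readers who want to see the shape of those two fields, here is a minimal sketch of how they could be collected. The function name `provenance_fields` is hypothetical, and treating an empty string the same as an unset variable is an assumption of this example rather than documented behavior of `run.py`.

```python
import json
import os


def provenance_fields(env=None) -> dict:
    """Collect the optional attestation fields; unset variables become None."""
    env = os.environ if env is None else env
    return {
        # `or None` also maps an empty string to None (an assumption of this sketch).
        "jira_ticket": env.get("JIRA_TICKET") or None,
        "huggingface_model_id": env.get("HUGGINGFACE_MODEL_ID") or None,
    }


# None is serialized as JSON null, matching the README's "null when unset".
print(json.dumps(provenance_fields({"JIRA_TICKET": "MYTRACK-1234"})))
```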

## Notes

The goal of this benchmark is to provide an easy way to configure, run, and compare speculative decoding implementations across frameworks on an apples-to-apples basis.
2 changes: 2 additions & 0 deletions examples/specdec_bench/requirements_speed.txt
@@ -1,3 +1,5 @@
boto3>=1.34.0
botocore>=1.34.0
datasets>=3.1.0
rich>=14.2.0
seaborn>=0.13.2
5 changes: 5 additions & 0 deletions examples/specdec_bench/run.py
@@ -20,6 +20,7 @@
from specdec_bench import datasets, metrics, models, runners
from specdec_bench.utils import (
decode_chat,
dump_env,
encode_chat,
get_tokenizer,
postprocess_base,
@@ -174,6 +175,10 @@ def run_simple(args):
if args.save_dir is not None:
for metric in metrics_list:
metric.update_directory(args.save_dir)
# Stamp configuration.json BEFORE the run loop so the file lands even
# when the run crashes mid-way. Engine init is already done, so the
# live serving_config from the model is available.
dump_env(args, args.save_dir, overrides={"serving_config": model.get_serving_config()})

runner = runners.SimpleRunner(model, metrics=metrics_list)

15 changes: 15 additions & 0 deletions examples/specdec_bench/specdec_bench/__init__.py
@@ -13,3 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Methodology version. Bump:
# - minor (0.X.0) when adding a new metric or strictly-additive provenance field
# - major (X.0.0) when changing how an existing metric is computed OR its
# on-disk field names (incompatible with prior consumers / visualizers)
# The visualizer aggregates runs by major version to avoid apple-to-orange
# comparisons across methodology changes.
#
# 1.0.0: rename Request_AR / Category_AR / Average_AR → *_AL across the
# SpecBench / AcceptanceRate / MTBench metric writers, AND add
# Joint_Acceptance_Rate to the AcceptanceRate metric. The renamed
# values were always acceptance LENGTH (mean tokens generated per
# inference step), not a rate, and the visualizer reads *_AL.
# Pre-1.0.0 runs in S3 have *_AR and no Joint_AR; they must be
# re-run or post-processed before comparing.
__version__ = "1.0.0"
Collaborator

Bot comment.

The __version__ = "1.0.0" bump and the documented Request_AR → Request_AL / Average_AR → Average_AL rename are a breaking on-disk schema change that's not visible in this PR's title ("configuration.json provenance + upload_to_s3"). Any downstream consumer parsing Request_AR from older acceptance_rate.json / specbench_results.json files breaks silently. Please either split this into its own PR titled around the methodology bump, or call it out explicitly at the top of the PR body — and update README/SPECBENCH_PORTING.md if they reference the old field names.
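
One way a downstream consumer could tolerate both schemas during the transition is to normalize legacy keys on read. The sketch below is an illustration only: `normalize_metric_keys` and `LEGACY_KEYS` are hypothetical names, and the mapping covers just the three renames listed in the version comment. It does not backfill `Joint_Acceptance_Rate`, which pre-1.0.0 files lack.

```python
# Hypothetical mapping from pre-1.0.0 field names to the renamed ones.
LEGACY_KEYS = {
    "Request_AR": "Request_AL",
    "Category_AR": "Category_AL",
    "Average_AR": "Average_AL",
}


def normalize_metric_keys(results: dict) -> dict:
    """Return a copy of `results` with legacy *_AR keys renamed to *_AL."""
    # Copy everything that is not a legacy key first, so new names win.
    out = {k: v for k, v in results.items() if k not in LEGACY_KEYS}
    for old, new in LEGACY_KEYS.items():
        if old in results:
            # Only fill in the new name if the file did not already provide it.
            out.setdefault(new, results[old])
    return out
```

A reader would then look up `Average_AL` on the normalized dict regardless of which version wrote the file.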

18 changes: 13 additions & 5 deletions examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
@@ -55,23 +55,31 @@ def _process_lengths(self, lengths):
             self.out["Conditional_Acceptance_Rate"][k] = running_len / sum_lengths / prev_ratio
             prev_ratio = running_len / sum_lengths
             running_len -= v
+        # Joint acceptance rate at step k = product of conditional acceptance
+        # rates at steps 1..k = probability that ≥k tokens are accepted in
+        # a row. The visualizer renders this as a separate panel.
+        self.out["Joint_Acceptance_Rate"] = {}
+        running_joint = 1.0
+        for k, cond_ar in self.out["Conditional_Acceptance_Rate"].items():
+            running_joint *= cond_ar
+            self.out["Joint_Acceptance_Rate"][k] = running_joint

     def process_final(self, text_outputs):
         all_ar = []
         lengths = {}
-        self.out["Request_AR"] = {}
+        self.out["Request_AL"] = {}
         self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
         for request_id, turns in self.prompt_ar.items():
-            self.out["Request_AR"][request_id] = {}
+            self.out["Request_AL"][request_id] = {}
             for turn_id, turn in turns.items():
                 ar = sum(turn) / len(turn)
-                self.out["Request_AR"][request_id][turn_id] = ar
+                self.out["Request_AL"][request_id][turn_id] = ar
                 all_ar.append(ar)
                 self._get_lengths(turn, lengths)
-                print(request_id, turn_id, self.out["Request_AR"][request_id][turn_id])
+                print(request_id, turn_id, self.out["Request_AL"][request_id][turn_id])
         average_ar = sum(all_ar) / len(all_ar)
         print("Average AR:", average_ar)
-        self.out["Average_AR"] = average_ar
+        self.out["Average_AL"] = average_ar
Comment on lines 80 to +82
Contributor

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Update the average label to match renamed key.

Line 81 still prints "Average AR" while Line 82 writes Average_AL, which makes logs inconsistent with output fields.

Suggested patch
-        print("Average AR:", average_ar)
+        print("Average AL:", average_ar)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py` around lines
80 - 82, The print label is inconsistent with the output key: code computes
average_ar from all_ar and currently prints "Average AR" but stores it under
self.out["Average_AL"]; update the print statement to use the renamed key (e.g.,
print("Average_AL:", average_ar)) so the log message matches
self.out["Average_AL"] and keep references to the variables average_ar, all_ar
and the dict key self.out["Average_AL"] in the same block.

self._process_lengths(lengths)
self.write()
self._format_write_output(text_outputs)
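
To make the `Joint_Acceptance_Rate` computation in this hunk concrete, the standalone sketch below reproduces the running product on a toy input. The function name `joint_from_conditional` is hypothetical; like the code in the diff, it relies on the dict iterating in step order.

```python
def joint_from_conditional(conditional: dict) -> dict:
    """Cumulative product of conditional acceptance rates, keyed by step.

    joint[k] = cond[1] * cond[2] * ... * cond[k], i.e. the probability that
    at least k tokens are accepted in a row.
    """
    joint = {}
    running = 1.0
    for k, cond in conditional.items():
        running *= cond
        joint[k] = running
    return joint


# Toy input: 80% of drafts pass step 1; of the survivors, half pass each later step.
print(joint_from_conditional({1: 0.8, 2: 0.5, 3: 0.5}))
```

Because every factor is at most 1, the joint rate is non-increasing in `k`, which is a quick sanity check on any emitted `acceptance_rate.json`.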
14 changes: 7 additions & 7 deletions examples/specdec_bench/specdec_bench/metrics/mtbench.py
@@ -34,29 +34,29 @@ class MTBench(AcceptanceRate):
     def process_final(self, text_outputs):
         i = 0
         lengths = {}
-        self.out["Request_AR"] = {}
+        self.out["Request_AL"] = {}
         self.prompt_ar = dict(sorted(self.prompt_ar.items(), key=lambda x: x[0]))
         for request_id, turns in self.prompt_ar.items():
             turn_1 = turns[0]
             turn_2 = turns[1]
             q_id = request_id
             mtbench_topic = MTBENCH_TOPICS[q_id // 10]
-            self.out["Request_AR"][request_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
+            self.out["Request_AL"][request_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2)
             self._get_lengths(turn_1, lengths)
             self._get_lengths(turn_2, lengths)
             print(mtbench_topic, sum(turn_1 + turn_2) / len(turn_1 + turn_2))
         per_category = [[] for _ in range(len(MTBENCH_TOPICS))]
-        for q_id, ar in self.out["Request_AR"].items():
+        for q_id, ar in self.out["Request_AL"].items():
             per_category[q_id // 10].append(ar)
-        self.out["Category_AR"] = {}
+        self.out["Category_AL"] = {}
         for i, category in enumerate(per_category):
             if len(category) > 0:
                 category_ar = sum(category) / len(category)
-                self.out["Category_AR"][MTBENCH_TOPICS[i]] = category_ar
+                self.out["Category_AL"][MTBENCH_TOPICS[i]] = category_ar
                 print(f"{MTBENCH_TOPICS[i]} Average AR: {category_ar}")
-        average_ar = sum(self.out["Request_AR"].values()) / len(self.out["Request_AR"])
+        average_ar = sum(self.out["Request_AL"].values()) / len(self.out["Request_AL"])
         print("Average AR:", average_ar)
-        self.out["Average_AR"] = average_ar
+        self.out["Average_AL"] = average_ar
         self._process_lengths(lengths)
         self.write()
         self._format_write_output(text_outputs)
24 changes: 12 additions & 12 deletions examples/specdec_bench/specdec_bench/metrics/specbench.py
@@ -44,26 +44,26 @@ def __init__(self, requests):

     def process_final(self, text_outputs):
         lengths = {}
-        self.out["Request_AR"] = {}
+        self.out["Request_AL"] = {}
         for request_id, request in enumerate(self.requests):
             turns = self.prompt_ar[request_id].values()
             assert len(turns) == len(request.turns), (
                 f"Number of turns {len(turns)} does not match number of turns in request {len(request.turns)}"
             )
-            self.out["Request_AR"][request.question_id] = mean(list(chain(*turns)))
+            self.out["Request_AL"][request.question_id] = mean(list(chain(*turns)))
             for turn in turns:
                 self._get_lengths(turn, lengths)
-            print(request.category, self.out["Request_AR"][request.question_id])
+            print(request.category, self.out["Request_AL"][request.question_id])
         per_category = defaultdict(list)
         for request in self.requests:
-            per_category[request.category].append(self.out["Request_AR"][request.question_id])
-        self.out["Category_AR"] = {}
+            per_category[request.category].append(self.out["Request_AL"][request.question_id])
+        self.out["Category_AL"] = {}
         for category_name, category_ar in per_category.items():
             if len(category_ar) > 0:
                 category_ar = mean(category_ar)
-                self.out["Category_AR"][category_name] = category_ar
-        average_ar = mean(self.out["Request_AR"].values())
-        self.out["Average_AR"] = average_ar
+                self.out["Category_AL"][category_name] = category_ar
+        average_ar = mean(self.out["Request_AL"].values())
+        self.out["Average_AL"] = average_ar
         self._process_lengths(lengths)
         self.write()
         self._format_write_output(text_outputs)
@@ -96,12 +96,12 @@ def _pretty_print_results(self):
         table.add_column("Average AR", justify="right", style="green")
Collaborator

Bot comment.

After the AR→AL rename the column header still reads "Average AR" while the values come from Category_AL. Also mtbench.py:55 still prints "... Average AR: {category_ar}". Either rename the user-facing strings to match ("Average AL" / "Acceptance Length") or revert the dict-key rename. Mixed terminology in the same output is more confusing than either consistent choice.


         # Add category rows
-        for category_name, category_ar in sorted(self.out["Category_AR"].items()):
+        for category_name, category_ar in sorted(self.out["Category_AL"].items()):
             table.add_row(category_name, f"{category_ar:.4f}")

         # Add separator and summary row
         table.add_section()
-        table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AR']:.4f}[/bold]")
+        table.add_row("[bold]Overall Average[/bold]", f"[bold]{self.out['Average_AL']:.4f}[/bold]")

         console.print(table)

@@ -124,8 +124,8 @@ def _create_visualizations(

         df_clean = pd.DataFrame.from_dict(
             {
-                "question_id": list(self.out["Request_AR"].keys()),
-                "acceptance_rate": list(self.out["Request_AR"].values()),
+                "question_id": list(self.out["Request_AL"].keys()),
+                "acceptance_rate": list(self.out["Request_AL"].values()),
                 "category": [request.category for request in self.requests],
                 "response_length": [
                     mean([len(c["content"]) for c in messages if c["role"] == "assistant"])
9 changes: 9 additions & 0 deletions examples/specdec_bench/specdec_bench/models/base.py
@@ -27,5 +27,14 @@ async def run(self, prompt_ids, sampling_params, request_id, turn_id):
"""
raise NotImplementedError

def get_serving_config(self):
"""Return a JSON-serializable dict describing the engine's effective config.

Captured into configuration.json's `serving_config` for reproducibility.
Subclasses override to surface engine-specific defaults (max_model_len,
kv_cache_dtype, etc.) that don't appear in the CLI args. Default: empty.
"""
return {}

def stop(self):
pass
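
The docstring above asks subclasses to return a JSON-serializable dict. As a hedged illustration of one way to honor that contract, the sketch below round-trips a config through `json` and stringifies anything the encoder cannot handle. `ToyEngineArgs`, `ToyModel`, and `to_json_safe` are hypothetical stand-ins invented for this example and are not part of `specdec_bench` or of any engine's API.

```python
import dataclasses
import json
from typing import Optional


@dataclasses.dataclass
class ToyEngineArgs:  # hypothetical stand-in for an engine's argument dataclass
    model: str
    max_model_len: Optional[int] = None
    dtype: object = float  # deliberately not JSON-serializable


def to_json_safe(value):
    """Round-trip through json, converting unencodable objects with str()."""
    return json.loads(json.dumps(value, default=str))


class ToyModel:  # hypothetical; mirrors the get_serving_config() hook only
    def __init__(self, engine_args):
        self.engine_args = engine_args

    def get_serving_config(self):
        # asdict() flattens the dataclass; to_json_safe() guards the leftovers.
        return to_json_safe(dataclasses.asdict(self.engine_args))


print(ToyModel(ToyEngineArgs(model="toy")).get_serving_config())
```

Stringifying unknown values loses type information, so this trades fidelity for a `configuration.json` that always writes; whether that trade-off suits a given engine is a judgment call.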
9 changes: 9 additions & 0 deletions examples/specdec_bench/specdec_bench/models/sglang.py
@@ -91,6 +91,7 @@ def __init__(
if "mamba_scheduler_strategy" in kwargs:
engine_kwargs["mamba_scheduler_strategy"] = kwargs["mamba_scheduler_strategy"]

self.engine_kwargs = engine_kwargs
self.model = sgl.Engine(**engine_kwargs)

self.sampling_config = sampling_kwargs
@@ -129,3 +130,11 @@ async def run(self, prompt_ids, max_length, end_id, request_id, turn_id):
output_dict["output_logits"] = None
output_dict["token_times"] = timing
return output_dict

def get_serving_config(self):
"""Dump the engine_kwargs dict supplied to sgl.Engine()."""
try:
# engine_kwargs is plain dict of scalars/None — already JSON-safe.
return dict(self.engine_kwargs)
except Exception:
return {}
20 changes: 20 additions & 0 deletions examples/specdec_bench/specdec_bench/models/vllm.py
@@ -89,6 +89,7 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
async_scheduling=kwargs.get("async_scheduling", True),
enforce_eager=False,
)
self.engine_args = engine_args
self.model = AsyncLLM.from_engine_args(engine_args)
self.sampling_kwargs = sampling_kwargs
# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
@@ -151,6 +152,25 @@ async def generate(self, prompt_ids, request_id, turn_id):
break
return outputs, timing, full_tokens

def get_serving_config(self):
"""Dump the AsyncEngineArgs dataclass plus the runtime vllm_config when available."""
try:
import dataclasses

cfg = dataclasses.asdict(self.engine_args)
except Exception:
cfg = {}
# vllm exposes the resolved engine config on the AsyncLLM instance once
# initialized — capture max_model_len / kv cache / dtype defaults that
# don't appear in AsyncEngineArgs.
try:
vllm_config = getattr(self.model, "vllm_config", None)
if vllm_config is not None and hasattr(vllm_config, "to_dict"):
cfg["vllm_config"] = vllm_config.to_dict()
except Exception:
pass
return cfg

def stop(self):
try:
self.loop.run_until_complete(self.model.shutdown())