Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 65 additions & 10 deletions python/sglang/test/kits/mmmu_vlm_kit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import glob
import json
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from types import SimpleNamespace

from sglang.srt.environ import temp_set_env
Expand All @@ -18,6 +20,67 @@
DEFAULT_MEM_FRACTION_STATIC = 0.8


def _is_mmmu_parquet_corruption(error_output: str) -> bool:
    """Return True when ``error_output`` looks like a corrupted MMMU parquet file.

    The text must contain both the ``ArrowInvalid`` exception name and the
    ``Parquet magic bytes not found`` message, and must also mention the MMMU
    dataset (either the bare name or the ``lmms-lab--MMMU`` cache directory).
    """
    # Both Arrow failure markers are required.
    required_markers = ("ArrowInvalid", "Parquet magic bytes not found")
    if not all(marker in error_output for marker in required_markers):
        return False

    # At least one reference to the MMMU dataset must be present as well.
    dataset_markers = ("MMMU", "lmms-lab--MMMU")
    return any(marker in error_output for marker in dataset_markers)


def _cleanup_mmmu_dataset_cache():
    """Delete the cached MMMU dataset directory so it can be downloaded again.

    Lookup order for the cache directory:
      1. ``/hf_home/hub/datasets--lmms-lab--MMMU`` (CI / Docker convention).
      2. ``$HF_HOME/hub/datasets--lmms-lab--MMMU``, with ``HF_HOME`` defaulting
         to ``~/.cache/huggingface``.

    Returns:
        True if the directory existed and was removed; False if it was not
        found or removal failed with ``OSError``.
    """
    cache_dir = Path("/hf_home/hub/datasets--lmms-lab--MMMU")
    if not cache_dir.exists():
        # Fall back to the HF_HOME env var or the default per-user cache.
        hf_root = os.environ.get(
            "HF_HOME", os.path.expanduser("~/.cache/huggingface")
        )
        cache_dir = Path(hf_root) / "hub" / "datasets--lmms-lab--MMMU"

    # Nothing to clean up: report and bail out early.
    if not cache_dir.exists():
        print(f"MMMU cache not found at {cache_dir}, skipping cleanup")
        return False

    print(f"Detected corrupted MMMU parquet cache. Cleaning up: {cache_dir}")
    try:
        shutil.rmtree(cache_dir)
        print(f"Successfully removed corrupted cache: {cache_dir}")
        return True
    except OSError as err:
        # Best effort: a failed removal is reported, not raised.
        print(f"Warning: Failed to remove cache {cache_dir}: {err}")
        return False


def _run_lmms_eval_with_retry(cmd: list[str], timeout: int = 3600) -> None:
    """Run an lmms_eval command, retrying once on MMMU parquet corruption.

    The first attempt captures stdout/stderr so the failure text can be
    inspected. The captured output is always replayed to this process's
    stdout/stderr, so test logs still contain it on both success and failure.

    If the first attempt fails with an error that
    ``_is_mmmu_parquet_corruption`` recognises and
    ``_cleanup_mmmu_dataset_cache`` manages to remove the cache, the command
    is run a second time (output not captured) with environment overrides
    requesting an online, forced re-download.

    Args:
        cmd: Command line to execute, as an argument list.
        timeout: Per-attempt timeout in seconds passed to ``subprocess.run``.

    Raises:
        subprocess.CalledProcessError: If the first attempt fails for any
            other reason, if the cache cleanup did not succeed, or if the
            retry itself exits non-zero.
        subprocess.TimeoutExpired: If an attempt exceeds ``timeout``.
    """
    import sys

    try:
        completed = subprocess.run(
            cmd,
            check=True,
            timeout=timeout,
            capture_output=True,
            text=True,
        )
    except subprocess.CalledProcessError as e:
        # Streams are strings under capture_output/text; guard against None
        # anyway so error inspection never fails with a TypeError.
        captured_out = e.stdout or ""
        captured_err = e.stderr or ""

        # Replay what the child printed; CalledProcessError's own message
        # does not include it, so it would otherwise be lost from the logs.
        sys.stdout.write(captured_out)
        sys.stderr.write(captured_err)

        if not _is_mmmu_parquet_corruption(captured_err + captured_out):
            raise

        print("Detected MMMU parquet corruption error. Attempting recovery...")
        if not _cleanup_mmmu_dataset_cache():
            raise

        print("Retrying lmms_eval with fresh download...")
        # NOTE(review): presumably the child process honors these variables
        # to go online and re-download the dataset — confirm against the
        # huggingface_hub / datasets versions in use.
        with temp_set_env(
            HF_HUB_OFFLINE="0",
            HF_DATASETS_DOWNLOAD_MODE="force_redownload",
        ):
            subprocess.run(cmd, check=True, timeout=timeout)
    else:
        # Success path: surface the captured output instead of discarding it.
        sys.stdout.write(completed.stdout or "")
        sys.stderr.write(completed.stderr or "")


class MMMUMixin:
"""Mixin for MMMU evaluation.

Expand Down Expand Up @@ -81,11 +144,7 @@ def run_mmmu_eval(
OPENAI_API_KEY=self.api_key,
OPENAI_API_BASE=f"{self.base_url}/v1",
):
subprocess.run(
cmd,
check=True,
timeout=3600,
)
_run_lmms_eval_with_retry(cmd)

def test_mmmu(self: CustomTestCase):
"""Run MMMU evaluation test."""
Expand Down Expand Up @@ -209,11 +268,7 @@ def run_mmmu_eval(
*self.mmmu_args,
]

subprocess.run(
cmd,
check=True,
timeout=3600,
)
_run_lmms_eval_with_retry(cmd)

def _run_vlm_mmmu_test(
self,
Expand Down
Loading