diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml index 0f66a51e91..773b358bef 100644 --- a/.github/workflows/ci_pipeline.yml +++ b/.github/workflows/ci_pipeline.yml @@ -351,3 +351,68 @@ jobs: with: failed_run_id: '${{ github.run_id }}' secrets: inherit + + track_performance: + name: Track Test Performance + needs: [tpu-tests, gpu-tests, cpu-tests] + if: ${{ always() && !cancelled() }} + runs-on: linux-x86-ct6e-180-4tpu + container: google/cloud-sdk:524.0.0 + permissions: + contents: write + id-token: write + pull-requests: write + steps: + - uses: actions/checkout@v5 + + - name: Mark git repositories as safe + run: git config --global --add safe.directory ${GITHUB_WORKSPACE} + + - name: Download all test results + uses: actions/download-artifact@v4 + with: + path: test-results + pattern: test-results-* + merge-multiple: true + + - name: Parse JUnit XML to Benchmark format + run: | + mkdir -p ./cache + python3 tests/utils/parse_junit_to_benchmark.py test-results benchmark-results.json + echo "Parsed Benchmark Results:" + cat benchmark-results.json + + - name: Fetch Baseline Benchmark Data from GCS + run: | + mkdir -p ./cache + gcloud storage cp gs://maxtext-test-assets/benchmark-data.json ./cache/benchmark-data.json || true + + - name: Track Test Durations (Main) + if: github.ref == 'refs/heads/main' + uses: benchmark-action/github-action-benchmark@v1 + with: + name: MaxText Test Execution Times + tool: 'customSmallerIsBetter' + output-file-path: benchmark-results.json + external-data-json-path: ./cache/benchmark-data.json + github-token: ${{ secrets.GITHUB_TOKEN }} + alert-threshold: '150%' + comment-on-alert: true + fail-on-alert: false + + - name: Verify Test Durations (PR) + if: github.ref != 'refs/heads/main' + uses: benchmark-action/github-action-benchmark@v1 + with: + name: MaxText Test Execution Times + tool: 'customSmallerIsBetter' + output-file-path: benchmark-results.json + external-data-json-path: ./cache/benchmark-data.json + github-token: ${{ secrets.GITHUB_TOKEN }} + alert-threshold: '150%' + comment-on-alert: true + fail-on-alert: true + + - name: Upload Updated Baseline to GCS + run: | + gcloud storage cp ./cache/benchmark-data.json gs://maxtext-test-assets/benchmark-data.json diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 4b82a118e0..a019cf6725 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -195,6 +195,7 @@ jobs: -v \ -m "${FINAL_PYTEST_MARKER}" \ --durations=0 \ + --junitxml=test-results-${INPUTS_DEVICE_TYPE}-${INPUTS_WORKER_GROUP}.xml \ $PYTEST_COV_ARGS \ $SPLIT_ARGS \ ${INPUTS_PYTEST_EXTRA_ARGS} @@ -227,3 +228,10 @@ jobs: # If scheduled, upload to scheduled flag only. If PR, upload to regular flag only. flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }} verbose: true + - name: Upload Test Results XML + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results-${{ inputs.device_type }}-${{ inputs.worker_group }} + path: test-results-*.xml + if-no-files-found: ignore diff --git a/tests/utils/parse_junit_to_benchmark.py b/tests/utils/parse_junit_to_benchmark.py new file mode 100644 index 0000000000..87c6efb6f3 --- /dev/null +++ b/tests/utils/parse_junit_to_benchmark.py @@ -0,0 +1,65 @@ +import xml.etree.ElementTree as ET +import glob +import json +import sys +import os + +def main(): + if len(sys.argv) < 3: + print("Usage: python parse_junit_to_benchmark.py ") + sys.exit(1) + + xml_dir = sys.argv[1] + output_json = sys.argv[2] + + benchmarks = [] + total_times_by_device = {} + + xml_files = glob.glob(os.path.join(xml_dir, "*.xml")) + for xml_file in xml_files: + basename = os.path.basename(xml_file) + # e.g., test-results-tpu-1.xml -> device = tpu + device = "unknown" + parts = basename.replace(".xml", "").split("-") + if len(parts) >= 3: + device = parts[2] + + try: + tree = ET.parse(xml_file) + except Exception as e: + print(f"Error parsing {xml_file}: {e}") + continue + + root = tree.getroot() + + for testsuite in root.iter('testsuite'): + for testcase in testsuite.iter('testcase'): + name = testcase.get('name') + classname = testcase.get('classname') + time_val = float(testcase.get('time', 0.0)) + + # Prefix with device to distinguish test times on different hardware + full_name = f"[{device.upper()}] {classname}::{name}" + + benchmarks.append({ + "name": full_name, + "unit": "sec", + "value": time_val + }) + + total_times_by_device[device] = total_times_by_device.get(device, 0.0) + time_val + + for device, total_time in total_times_by_device.items(): + benchmarks.append({ + "name": f"Total {device.upper()} Test Suite Time", + "unit": "sec", + "value": total_time + }) + + with open(output_json, "w") as f: + json.dump(benchmarks, f, indent=2) + + print(f"Parsed {len(xml_files)} XML files and extracted {len(benchmarks)} duration metrics.") + +if __name__ == "__main__": + main()