Review notes (this text sits before the first "diff --git" header, outside every hunk).

Add TPU7X test flavors to CI and make the small-model embedding dim always
divisible by the device count.

Changes relative to the patch as submitted:
  * tpu7x-tests and maxtext_tpu_pathways_integration_tests both gate on
    needs.analyze_code_changes.outputs.run_tests, but neither listed
    analyze_code_changes in `needs`. The `needs` context only exposes direct
    dependencies, so the expression evaluated to an empty string and both jobs
    would always be skipped. analyze_code_changes is now a direct dependency
    of both jobs.
  * NOTE(review): leading whitespace and blank context lines were
    reconstructed from a whitespace-collapsed copy of this patch. Confirm the
    indentation against the target files, or apply with --ignore-whitespace.
  * NOTE(review): the `index` hash pair for ci_pipeline.yml predates the
    `needs` edits above and no longer describes the resulting blob.
  * NOTE(review): run_tests_coordinator.yml defines
    tpu7x-post-training-integration, but the tpu7x-tests matrix does not run
    it. Confirm this is intentional.

diff --git a/.github/workflows/ci_pipeline.yml b/.github/workflows/ci_pipeline.yml
index c0908c4cdd..2820cb53f4 100644
--- a/.github/workflows/ci_pipeline.yml
+++ b/.github/workflows/ci_pipeline.yml
@@ -169,6 +169,21 @@ jobs:
       is_scheduled_run: ${{ github.event_name == 'schedule' }}
       maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
 
+  tpu7x-tests:
+    name: ${{ matrix.flavor || 'TPU7X' }} tests
+    needs: [analyze_code_changes, build_and_upload_maxtext_package]
+    if: needs.analyze_code_changes.outputs.run_tests == 'true' && github.ref == 'refs/heads/main' && (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch')
+    uses: ./.github/workflows/run_tests_coordinator.yml
+    strategy:
+      fail-fast: false
+      matrix:
+        flavor: [tpu7x-unit, tpu7x-integration, tpu7x-post-training-unit]
+    with:
+      flavor: ${{ matrix.flavor }}
+      base_image: maxtext-unit-test-tpu:py312
+      is_scheduled_run: ${{ github.event_name == 'schedule' }}
+      maxtext_sha: ${{ needs.build_and_upload_maxtext_package.outputs.maxtext_sha }}
+
   gpu-tests:
     name: ${{ matrix.flavor || 'GPU' }} tests
     needs: [build_and_upload_maxtext_package]
@@ -220,6 +235,7 @@ jobs:
 
   maxtext_tpu_pathways_integration_tests:
-    needs: build_and_upload_maxtext_package
+    needs: [analyze_code_changes, build_and_upload_maxtext_package]
+    if: needs.analyze_code_changes.outputs.run_tests == 'true'
     uses: ./.github/workflows/run_pathways_tests.yml
     strategy:
       fail-fast: false
@@ -304,7 +320,7 @@ jobs:
 
   notify_failure:
     name: Notify failed build # creates an issue or modifies last open existing issue for failed build
-    needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests]
+    needs: [tpu-tests, tpu7x-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests]
     if: ${{ always() }}
     runs-on: ubuntu-latest
     permissions:
@@ -318,7 +334,7 @@ jobs:
 
   investigate_failure:
     name: Investigate failed build # investigates failure of scheduled run and comments on tracking issue
-    needs: [tpu-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, notify_failure]
+    needs: [tpu-tests, tpu7x-tests, gpu-tests, cpu-tests, maxtext_jupyter_notebooks, maxtext_tpu_pathways_unit_tests, maxtext_tpu_pathways_integration_tests, notify_failure]
     if: ${{ always() && contains(needs.*.result, 'failure') && github.event_name == 'schedule' }}
     uses: ./.github/workflows/gemini_investigate.yml
     permissions:
diff --git a/.github/workflows/run_tests_coordinator.yml b/.github/workflows/run_tests_coordinator.yml
index d76eaec3bd..91e3cfe4ec 100644
--- a/.github/workflows/run_tests_coordinator.yml
+++ b/.github/workflows/run_tests_coordinator.yml
@@ -25,6 +25,8 @@ on:
           Test flavor (
             tpu-unit, tpu-integration,
             tpu-post-training-unit, tpu-post-training-integration,
+            tpu7x-unit, tpu7x-integration,
+            tpu7x-post-training-unit, tpu7x-post-training-integration,
             gpu-unit, gpu-integration,
             cpu-unit,
             cpu-post-training-unit
@@ -77,6 +79,10 @@ jobs:
           "tpu-integration": "tpu",
           "tpu-post-training-unit": "tpu",
           "tpu-post-training-integration": "tpu",
+          "tpu7x-unit": "tpu",
+          "tpu7x-integration": "tpu",
+          "tpu7x-post-training-unit": "tpu",
+          "tpu7x-post-training-integration": "tpu",
           "gpu-unit": "cuda12",
           "gpu-integration": "cuda12",
           "cpu-unit": "cpu",
@@ -89,6 +95,10 @@ jobs:
           "tpu-integration": "v6e-4",
           "tpu-post-training-unit": "v6e-4",
           "tpu-post-training-integration": "v6e-4",
+          "tpu7x-unit": "tpu7x-8",
+          "tpu7x-integration": "tpu7x-8",
+          "tpu7x-post-training-unit": "tpu7x-8",
+          "tpu7x-post-training-integration": "tpu7x-8",
           "gpu-unit": "a100-40gb-4",
           "gpu-integration": "a100-40gb-4",
           "cpu-unit": "X64",
@@ -101,6 +111,10 @@ jobs:
           "tpu-integration": "linux-x86-ct6e-180-4tpu",
           "tpu-post-training-unit": "linux-x86-ct6e-180-4tpu",
           "tpu-post-training-integration": "linux-x86-ct6e-180-4tpu",
+          "tpu7x-unit": "linux-x86-tpu7x-224-4tpu",
+          "tpu7x-integration": "linux-x86-tpu7x-224-4tpu",
+          "tpu7x-post-training-unit": "linux-x86-tpu7x-224-4tpu",
+          "tpu7x-post-training-integration": "linux-x86-tpu7x-224-4tpu",
           "gpu-unit": "linux-x86-a2-48-a100-4gpu",
           "gpu-integration": "linux-x86-a2-48-a100-4gpu",
           "cpu-unit": "linux-x86-n2-32",
@@ -113,6 +127,10 @@ jobs:
           "tpu-integration": "not cpu_only and not gpu_only and integration_test and not post_training",
           "tpu-post-training-unit": "not cpu_only and not gpu_only and not integration_test and post_training",
           "tpu-post-training-integration": "not cpu_only and not gpu_only and integration_test",
+          "tpu7x-unit": "not cpu_only and not gpu_only and not integration_test and not post_training",
+          "tpu7x-integration": "not cpu_only and not gpu_only and integration_test and not post_training",
+          "tpu7x-post-training-unit": "not cpu_only and not gpu_only and not integration_test and post_training",
+          "tpu7x-post-training-integration": "not cpu_only and not gpu_only and integration_test",
           "gpu-unit": "not cpu_only and not tpu_only and not integration_test and not post_training",
           "gpu-integration": "not cpu_only and not tpu_only and integration_test and not post_training",
           "cpu-unit": "cpu_only and not post_training",
@@ -125,6 +143,10 @@ jobs:
           "tpu-integration": "",
           "tpu-post-training-unit": "tests/post_training/unit tests/unit",
           "tpu-post-training-integration": "tests/post_training/integration",
+          "tpu7x-unit": "",
+          "tpu7x-integration": "",
+          "tpu7x-post-training-unit": "tests/post_training/unit tests/unit",
+          "tpu7x-post-training-integration": "tests/post_training/integration",
           "gpu-unit": "",
           "gpu-integration": "",
           "cpu-unit": "",
@@ -137,6 +159,10 @@ jobs:
           "tpu-integration": "--ignore=tests/post_training",
           "tpu-post-training-unit": "",
           "tpu-post-training-integration": "",
+          "tpu7x-unit": "--ignore=tests/post_training --ignore=tests/inference/kvcache_test.py",
+          "tpu7x-integration": "--ignore=tests/post_training",
+          "tpu7x-post-training-unit": "",
+          "tpu7x-post-training-integration": "",
           "gpu-unit": "--ignore=tests/post_training",
           "gpu-integration": "--ignore=tests/post_training",
           "cpu-unit": "--ignore=tests/post_training",
diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py
index 0fad677333..f85e871a39 100644
--- a/tests/integration/train_tests.py
+++ b/tests/integration/train_tests.py
@@ -30,10 +30,8 @@
 )
 
 
-def _small_model_base_emb_dim(decoupled, device_count):
-  """Return a tiny embedding dim divisible by local decoupled devices."""
-  if not decoupled:
-    return 28
+def _small_model_base_emb_dim(device_count):
+  """Return a tiny embedding dim divisible by local devices."""
   return ((28 + device_count - 1) // device_count) * device_count
 
 
@@ -46,7 +44,7 @@ class TrainTests(unittest.TestCase):
   dataset_path = get_test_dataset_path()
 
   _small_model_overrides = [
-      f"base_emb_dim={_small_model_base_emb_dim(decoupled, dev_count)}",
+      f"base_emb_dim={_small_model_base_emb_dim(dev_count)}",
       "base_num_query_heads=4",
       "base_num_kv_heads=4",
       "base_mlp_dim=32",