Fix regression with metrics passed to compile. (#22663)
#2255
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Runs the Keras test suite inside a python:3.11-slim container on TPU
# runners, once per `matrix.multi_device` value: the full suite (minus
# keras/src/applications) on the single-device runner, and only tests
# marked `multi_device` on the multi-device runner.
name: Keras TPU Tests

on:  # yamllint disable-line rule:truthy
  push:
    branches: [master]
  # PR runs are opt-in: they fire on label removal, and the job-level `if`
  # below only lets the `kokoro:force-run` label through.
  pull_request:
    types: [unlabeled]
  release:
    types: [created]

permissions:
  contents: read

concurrency:
  # Removing any *other* label from a PR also starts a run of this workflow;
  # that run's job is skipped by the job-level `if`, but workflow-level
  # concurrency applies before the job is evaluated. Give such no-op runs a
  # unique group (the run id) so they cannot cancel an in-progress TPU run
  # for the same branch. Pushes and releases have no head_ref and also fall
  # back to the run id, exactly as before.
  group: ${{ github.workflow }}-${{ github.event.action == 'unlabeled' && github.event.label.name != 'kokoro:force-run' && github.run_id || github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  test-in-container:
    strategy:
      fail-fast: false
      matrix:
        multi_device: [false, true]
        backend: [jax]
    name: ${{ format('Run tests on {0}TPU', matrix.multi_device && 'multi-' || '') }}
    runs-on: ${{ matrix.multi_device && 'linux-x86-ct5lp-112-4tpu' || 'linux-x86-ct6e-44-1tpu' }}
    # Only run on pushes to master, releases or "kokoro:force-run" unlabel
    if: |
      github.event_name == 'push' ||
      github.event_name == 'release' ||
      (github.event.action == 'unlabeled' && github.event.label.name == 'kokoro:force-run')
    container:
      image: python:3.11-slim
      # NOTE(review): presumably --privileged is required for TPU device
      # access from inside the container — confirm before tightening.
      options: --privileged --network host
    steps:
      - name: Checkout ${{ github.ref }}
        uses: actions/checkout@v6
      - name: Install Dependencies
        run: pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt
      # Persist the backend for all later steps; taken from the matrix so it
      # always matches the requirements file installed above.
      - name: Set Keras Backend
        run: echo "KERAS_BACKEND=${{ matrix.backend }}" >> "$GITHUB_ENV"
      # Fail fast if JAX did not pick up the TPU (e.g. a CPU-only install).
      - name: Verify JAX Installation
        run: python3 -c "import jax; print('JAX devices:', jax.devices()); assert jax.default_backend() == 'tpu'"
      - name: Run Tests
        if: ${{ !matrix.multi_device }}
        run: pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
      - name: Run Multi-device Tests
        if: ${{ matrix.multi_device }}
        run: pytest keras -m multi_device --cov=keras --cov-config=pyproject.toml