From 80aaa1a25880d77a80a008995e5b0651d6158e47 Mon Sep 17 00:00:00 2001 From: Nitin Srinivasan Date: Tue, 3 Dec 2024 15:52:55 -0800 Subject: [PATCH] Add Github action workflow for building JAX artifacts PiperOrigin-RevId: 702497163 --- .github/workflows/build_artifacts.yml | 79 +++++++++++++++++++++++++++ ci/build_artifacts.sh | 0 ci/utilities/run_auditwheel.sh | 0 3 files changed, 79 insertions(+) create mode 100644 .github/workflows/build_artifacts.yml mode change 100644 => 100755 ci/build_artifacts.sh mode change 100644 => 100755 ci/utilities/run_auditwheel.sh diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml new file mode 100644 index 000000000000..226d428a4e6d --- /dev/null +++ b/.github/workflows/build_artifacts.yml @@ -0,0 +1,79 @@ +name: CI - Build JAX Artifacts + +on: + pull_request: + branches: + - main + workflow_dispatch: + inputs: + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: true + default: 'no' + options: + - 'yes' + - 'no' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build-artifacts: + if: github.event.repository.fork == false + + defaults: + run: + # Explicitly set the shell to bash to override the default Windows environment, i.e, cmd. + shell: bash + + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + runner: ["windows-x86-n2-16", "linux-x86-n2-16", "linux-arm64-t2a-16"] + artifact: ["jaxlib", "jax-cuda-pjrt", "jax-cuda-plugin"] + python: ["3.10", "3.11", "3.12", "3.13"] + exclude: + # Don't build jax-cuda-pjrt and jax-cuda-plugin on windows-x86-n2-16 + - runner: "windows-x86-n2-64" + artifact: "jax-cuda-pjrt", + - runner: "windows-x86-n2-64" + artifact: "jax-cuda-plugin" + # Don't build jax-cuda-pjrt for each python version + - artifact: "jax-cuda-pjrt" + python: 3.10 + - artifact: "jax-cuda-pjrt" + python: 3.11 + - artifact: "jax-cuda-pjrt" + python: 3.12 + + runs-on: ${{ matrix.runner }} + + container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(matrix.runner, 'windows-x86') && null) }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ matrix.python }}" + + name: Build ${{ matrix.artifact }} on ${{ matrix.runner }} with Python ${{ matrix.python }} + + steps: + - uses: actions/checkout@v3 + - name: Enable RBE on platforms where its supported + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Enable RBE if building on Linux x86 or Windows x86 + if [[ ($os == "linux" || $os =~ "msys_nt" ) && $arch == "x86_64" ]]; then + echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + fi + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Build ${{ matrix.artifact }} + run: ./ci/build_artifacts.sh "${{ matrix.artifact }}" \ No newline at end of file diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh old mode 100644 new mode 100755 diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh old mode 100644 new mode 100755