diff --git a/.github/workflows/build_artifacts.yml b/.github/workflows/build_artifacts.yml new file mode 100644 index 000000000000..6f5ba3f629cd --- /dev/null +++ b/.github/workflows/build_artifacts.yml @@ -0,0 +1,137 @@ +# CI - Build JAX Artifacts +# This workflow builds JAX wheels (jax, jaxlib, jax-cuda-plugin, and jax-cuda-pjrt) with a set of +# configuration options (platform, python version, whether to use latest XLA, etc). It can be +# triggered manually via workflow_dispatch or called by other workflows via workflow_call. When a +# workflow call is made, this workflow will build the artifacts and upload it to a GCS bucket so +# that other workflows (e.g. Pytest workflows) can use it. +name: CI - Build JAX Artifacts + +on: + workflow_dispatch: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: choice + required: true + default: "linux-x86-n2-16" + options: + - "linux-x86-n2-16" + - "linux-arm64-c4a-64" + - "windows-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" + type: choice + required: true + default: "jaxlib" + options: + - "jax" + - "jaxlib" + - "jax-cuda-plugin" + - "jax-cuda-pjrt" + python: + description: "Which python version should the artifact be built for?" + type: choice + required: false + default: "3.12" + options: + - "3.10" + - "3.11" + - "3.12" + - "3.13" + clone_main_xla: + description: "Should latest XLA be used?" + type: choice + required: false + default: "0" + options: + - "1" + - "0" + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: choice + required: false + default: 'no' + options: + - 'yes' + - 'no' + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + artifact: + description: "Which JAX artifact to build?" + type: string + required: true + default: "jaxlib" + python: + description: "Which python version should the artifact be built for?" + type: string + required: false + default: "3.12" + clone_main_xla: + description: "Should latest XLA be used?" + type: string + required: false + default: "0" + upload_artifacts_to_gcs: + description: "Should the artifacts be uploaded to a GCS bucket?" + required: true + default: true + type: boolean + gcs_upload_uri: + description: "GCS location prefix to where the artifacts should be uploaded" + required: true + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + outputs: + gcs_upload_uri: + description: "GCS location prefix to where the artifacts were uploaded" + value: ${{ inputs.gcs_upload_uri }} + +permissions: + contents: read + +jobs: + build-artifacts: + defaults: + run: + # Explicitly set the shell to bash to override Windows's default (cmd) + shell: bash + + runs-on: ${{ inputs.runner }} + + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_CLONE_MAIN_XLA: "${{ inputs.clone_main_xla }}" + + name: Build ${{ inputs.artifact }} (${{ inputs.runner }}, Python ${{ inputs.python }}, clone main XLA=${{ inputs.clone_main_xla }}) + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Enable RBE if building on Linux x86 or Windows x86 + if: contains(inputs.runner, 'linux-x86') || contains(inputs.runner, 'windows-x86') + run: echo "JAXCI_BUILD_ARTIFACT_WITH_RBE=1" >> $GITHUB_ENV + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Build ${{ inputs.artifact }} + run: ./ci/build_artifacts.sh "${{ inputs.artifact }}" + - name: Upload artifacts to a GCS bucket (non-Windows runs) + if: >- + ${{ inputs.upload_artifacts_to_gcs && !contains(inputs.runner, 'windows-x86') }} + run: gsutil -m cp -r "$(pwd)/dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ + # Set shell to cmd to avoid path errors when using gcloud commands on Windows + - name: Upload artifacts to a GCS bucket (Windows runs) + if: >- + ${{ inputs.upload_artifacts_to_gcs && contains(inputs.runner, 'windows-x86') }} + shell: cmd + run: gsutil -m cp -r "dist/*.whl" "${{ inputs.gcs_upload_uri }}"/ \ No newline at end of file diff --git a/.github/workflows/pytest_cpu.yml b/.github/workflows/pytest_cpu.yml new file mode 100644 index 000000000000..4246a00d99db --- /dev/null +++ b/.github/workflows/pytest_cpu.yml @@ -0,0 +1,99 @@ +# CI - Pytest CPU +# +# This workflow runs the CPU tests with Pytest. It can only be triggered by other workflows via +# `workflow_call`. It is used by the "CI - Wheel Tests" workflows to run the Pytest CPU tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib wheel from a GCS bucket. +# - Executes the `run_pytest_cpu.sh` script, which performs the following actions: +# - Installs the downloaded jaxlib wheel. +# - Runs the CPU tests with Pytest. +name: CI - Pytest CPU + +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + python: + description: "Which python version should the artifact be built for?" + type: string + required: true + default: "3.12" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + required: true + default: "0" + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + required: true + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false + +jobs: + run-tests: + defaults: + run: + # Explicitly set the shell to bash to override Windows's default (cmd) + shell: bash + runs-on: ${{ inputs.runner }} + container: ${{ (contains(inputs.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || + (contains(inputs.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') || + (contains(inputs.runner, 'windows-x86') && null) }} + + name: "Pytest CPU (${{ inputs.runner }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set env vars for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Adjust name for Windows + if [[ $os =~ "msys_nt" ]]; then + os="windows" + fi + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV + - name: Download jaxlib wheel from GCS (non-Windows runs) + if: ${{ !contains(matrix.runner, 'windows-x86') }} + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ + - name: Download jaxlib wheel from GCS (Windows runs) + if: ${{ contains(matrix.runner, 'windows-x86') }} + shell: cmd + run: >- + mkdir dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" dist/ + - name: Install Python dependencies + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CPU tests + run: ./ci/run_pytest_cpu.sh diff --git a/.github/workflows/pytest_cuda.yml b/.github/workflows/pytest_cuda.yml new file mode 100644 index 000000000000..4b180dbd9f31 --- /dev/null +++ b/.github/workflows/pytest_cuda.yml @@ -0,0 +1,89 @@ +# CI - Pytest CUDA +# +# This workflow runs the CUDA tests with Pytest. It can only be triggered by other workflows via +# `workflow_call`. It is used by the `CI - Wheel Tests` workflows to run the Pytest CUDA tests. +# +# It consists of the following job: +# run-tests: +# - Downloads the jaxlib and CUDA artifacts from a GCS bucket. +# - Executes the `run_pytest_cuda.sh` script, which performs the following actions: +# - Installs the downloaded jaxlib wheel. +# - Runs the CUDA tests with Pytest. +name: CI - Pytest CUDA + +on: + workflow_call: + inputs: + runner: + description: "Which runner should the workflow run on?" + type: string + required: true + default: "linux-x86-n2-16" + python: + description: "Which python version to test?" + type: string + required: true + default: "3.12" + cuda: + description: "Which CUDA version to test?" + type: string + required: true + default: "12.3" + enable-x64: + description: "Should x64 mode be enabled?" + type: string + required: true + default: "0" + gcs_download_uri: + description: "GCS location prefix from where the artifacts should be downloaded" + required: true + default: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + type: string + halt-for-connection: + description: 'Should this workflow run wait for a remote connection?' + type: boolean + required: false + default: false + +jobs: + run-tests: + runs-on: ${{ inputs.runner }} + # TODO: Update to the generic ML ecosystem test containers when they are ready. + container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') || + (contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/nosla-cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest') }} + name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})" + + env: + JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}" + JAXCI_PYTHON: "python${{ inputs.python }}" + JAXCI_ENABLE_X64: "${{ inputs.enable-x64 }}" + + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + # Halt for testing + - name: Wait For Connection + uses: google-ml-infra/actions/ci_connection@main + with: + halt-dispatch-input: ${{ inputs.halt-for-connection }} + - name: Set env vars for use in artifact download URL + run: | + os=$(uname -s | awk '{print tolower($0)}') + arch=$(uname -m) + + # Get the major and minor version of Python. + # E.g if JAXCI_HERMETIC_PYTHON_VERSION=3.10, then python_major_minor=310 + python_major_minor=$(echo "$JAXCI_HERMETIC_PYTHON_VERSION" | tr -d '.') + + echo "OS=${os}" >> $GITHUB_ENV + echo "ARCH=${arch}" >> $GITHUB_ENV + echo "PYTHON_MAJOR_MINOR=${python_major_minor}" >> $GITHUB_ENV + - name: Download the wheel artifacts from GCS + run: >- + mkdir -p $(pwd)/dist && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jaxlib*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*plugin*${PYTHON_MAJOR_MINOR}*${OS}*${ARCH}*.whl" $(pwd)/dist/ && + gsutil -m cp -r "${{ inputs.gcs_download_uri }}/jax*cuda*pjrt*${OS}*${ARCH}*.whl" $(pwd)/dist/ + - name: Install Python dependencies + run: $JAXCI_PYTHON -m pip install -r build/requirements.in + - name: Run Pytest CUDA tests + run: ./ci/run_pytest_cuda.sh \ No newline at end of file diff --git a/.github/workflows/wheel_tests_continuous.yml b/.github/workflows/wheel_tests_continuous.yml new file mode 100644 index 000000000000..8f00159d65aa --- /dev/null +++ b/.github/workflows/wheel_tests_continuous.yml @@ -0,0 +1,106 @@ +# CI - Wheel Tests (Continuous) +# +# This workflow builds JAX artifacts and runs CPU/CUDA tests. +# +# It orchestrates the following: +# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and +# uploads it to a GCS bucket. +# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow to download the jaxlib wheel that was built +# in the previous step and runs CPU tests. +# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and +# uploads them to a GCS bucket. +# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts +# that were built in the previous steps and runs the CUDA tests. +name: CI - Wheel Tests (Continuous) + +on: + # schedule: + # - cron: "0 */2 * * *" # Run once every 2 hours + # TODO: For testing purposes, remove pull_request event before submitting + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + +jobs: + build-jaxlib-artifact: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy in the CPU tests job + # Enable Windows after we have fixed the runner issue + runner: ["linux-x86-n2-16", "linux-arm64-c4a-64",] # "windows-x86-n2-16"] + artifact: ["jaxlib"] + python: ["3.10"] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + build-cuda-artifacts: + uses: ./.github/workflows/build_artifacts.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the GPU tests job below + runner: ["linux-x86-n2-16"] + artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"] + python: ["3.10",] + with: + runner: ${{ matrix.runner }} + artifact: ${{ matrix.artifact }} + python: ${{ matrix.python }} + clone_main_xla: 1 + upload_artifacts_to_gcs: true + gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}' + + run-pytest-cpu: + needs: build-jaxlib-artifact + uses: ./.github/workflows/pytest_cpu.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Runner OS and Python values need to match the matrix stategy in the + # build_jaxlib_artifact job above + runner: ["linux-x86-n2-64", "linux-arm64-c4a-64",] + python: ["3.10",] + enable-x64: [1, 0] + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + enable-x64: ${{ matrix.enable-x64 }} + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} + + run-pytest-cuda: + needs: [build-jaxlib-artifact, build-cuda-artifacts] + uses: ./.github/workflows/pytest_cuda.yml + strategy: + fail-fast: false # don't cancel all jobs on failure + matrix: + # Python values need to match the matrix stategy in the artifact build jobs above + runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"] + python: ["3.10",] + cuda: ["12.3", "12.1"] + enable-x64: [1, 0] + exclude: + # Run only a single configuration on H100 to save resources + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.10" + cuda: "12.1" + - runner: "linux-x86-a3-8g-h100-8gpu" + python: "3.10" + enable-x64: 0 + with: + runner: ${{ matrix.runner }} + python: ${{ matrix.python }} + cuda: ${{ matrix.cuda }} + enable-x64: ${{ matrix.enable-x64 }} + # GCS upload URI is the same for both artifact build jobs + gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }} \ No newline at end of file diff --git a/ci/build_artifacts.sh b/ci/build_artifacts.sh old mode 100644 new mode 100755 diff --git a/ci/run_pytest_cpu.sh b/ci/run_pytest_cpu.sh old mode 100644 new mode 100755 index 2b19ca5ddaa5..0b045bdc7927 --- a/ci/run_pytest_cpu.sh +++ b/ci/run_pytest_cpu.sh @@ -39,6 +39,7 @@ source "ci/utilities/setup_build_environment.sh" export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_64="$JAXCI_ENABLE_X64" # End of test environment variable setup echo "Running CPU tests..." diff --git a/ci/run_pytest_gpu.sh b/ci/run_pytest_cuda.sh old mode 100644 new mode 100755 similarity index 94% rename from ci/run_pytest_gpu.sh rename to ci/run_pytest_cuda.sh index 7bc2492781b2..e6dc3c18dead --- a/ci/run_pytest_gpu.sh +++ b/ci/run_pytest_cuda.sh @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt +# Runs Pyest CUDA tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt # wheels to be present inside $JAXCI_OUTPUT_DIR (../dist) # # -e: abort script if one command fails @@ -43,6 +43,7 @@ export PY_COLORS=1 export JAX_SKIP_SLOW_TESTS=true export NCCL_DEBUG=WARN export TF_CPP_MIN_LOG_LEVEL=0 +export JAX_ENABLE_64="$JAXCI_ENABLE_X64" # Set the number of processes to run to be 4x the number of GPUs. export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) @@ -52,7 +53,7 @@ export XLA_PYTHON_CLIENT_ALLOCATOR=platform export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 # End of test environment variable setup -echo "Running GPU tests..." +echo "Running CUDA tests..." "$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \ tests examples \ --deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \ diff --git a/ci/utilities/install_wheels_locally.sh b/ci/utilities/install_wheels_locally.sh index 181256b90804..4af679c9d079 100644 --- a/ci/utilities/install_wheels_locally.sh +++ b/ci/utilities/install_wheels_locally.sh @@ -26,7 +26,13 @@ fi echo "Installing the following wheels:" echo "${WHEELS[@]}" -"$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" + +# On Windows, convert MSYS Linux-like paths to Windows paths. +if [[ $(uname -s) =~ "MSYS_NT" ]]; then + "$JAXCI_PYTHON" -m pip install $(cygpath -w "${WHEELS[@]}") +else + "$JAXCI_PYTHON" -m pip install "${WHEELS[@]}" +fi echo "Installing the JAX package in editable mode at the current commit..." # Install JAX package at the current commit. diff --git a/ci/utilities/run_auditwheel.sh b/ci/utilities/run_auditwheel.sh old mode 100644 new mode 100755