Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Adding support for Pathways proxy #690

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions axlearn/cloud/gcp/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Note that these utilities do not handle resource management.
"""
import atexit
import importlib
import io
import logging
import math
Expand Down Expand Up @@ -443,6 +444,7 @@ class Config(GKEJob.Config):
enable_tpu_ici_resiliency: Optional[bool] = None
location_hint: Optional[str] = None
enable_tpu_smart_repair: bool = False
import_pathways: Optional[list[str]] = []
jesus-orozco marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def define_flags(cls, fv: flags.FlagValues):
Expand All @@ -457,6 +459,9 @@ def define_flags(cls, fv: flags.FlagValues):
"not all TPU types support this flag.",
**common_kwargs,
)
flags.DEFINE_list(
"import_pathways", [], "Modules to enable pathways proxy.", **common_kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here and below.

)

@classmethod
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
Expand All @@ -480,6 +485,15 @@ def __init__(self, cfg: Config):
raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}")
super().__init__(cfg)
self._output_volume_mount = dict(name="shared-output", mountPath="/output")
if len(cfg.import_pathways) > 0:
self._import_pathways(cfg.import_pathways)

def _import_pathways(self, import_pathways: list[str]):
try:
for module in import_pathways:
importlib.import_module(module)
except ModuleNotFoundError:
logging.error("An error occurred while importing pathways dependencies.")

def _maybe_add_volume_mount(self, volume_mounts: list[dict], *, spec: Optional[VolumeMount]):
if spec:
Expand Down
4 changes: 3 additions & 1 deletion axlearn/common/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ def setup():
logging.info("Devices: %s", devices)
local_devices = jax.local_devices()
logging.info("Local Devices: %s", local_devices)
if not devices or not all(device.platform == FLAGS.jax_backend for device in devices):
if FLAGS.jax_backend != "proxy" and (
not devices or not all(device.platform == FLAGS.jax_backend for device in devices)
):
raise RuntimeError(f"Expected backend {FLAGS.jax_backend}. Got {devices}.")
if FLAGS.data_dir:
# TODO(ruoming): Get rid of --data_dir and use only env var DATA_DIR.
Expand Down
9 changes: 6 additions & 3 deletions axlearn/common/utils_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def setup(
if initialization_timeout is not None:
init_kwargs["initialization_timeout"] = initialization_timeout

if jax_backend == "tpu":
# TPU resources orchestrated by Pathways use 'proxy' as the JAX backend
if jax_backend in ("tpu", "proxy"):
jesus-orozco marked this conversation as resolved.
Show resolved Hide resolved
if not (
distributed_coordinator is None and num_processes is None and process_id is None
):
Expand Down Expand Up @@ -92,5 +93,7 @@ def setup(
# local_device_ids arg allows us to maintain expected behavior
init_kwargs["local_device_ids"] = list(range(8))

jax.distributed.initialize(**init_kwargs)
_jax_distributed_initialized = True
# When using Pathways proxy for TPU backend, jax distributed init is not needed
if jax_backend != "proxy":
jax.distributed.initialize(**init_kwargs)
_jax_distributed_initialized = True
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ audio = [
"levenshtein==0.25.1",
]

jesus-orozco marked this conversation as resolved.
Show resolved Hide resolved
# Pathways utilities.
pathways = [
"pathwaysutils==0.0.7", # for JAX+Pathways single-controller accelerator coordinator
]

[tool.flit.module]
# This defines the import name. https://flit.pypa.io/en/stable/pyproject_toml.html#module-section
name = "axlearn"
Expand Down