Skip to content

Commit

Permalink
Make jax run on CPU so it does not conflict with TF.
Browse files Browse the repository at this point in the history
  • Loading branch information
jcformanek committed Nov 10, 2024
1 parent 27e32ac commit c092663
Show file tree
Hide file tree
Showing 14 changed files with 42 additions and 0 deletions.
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/continuous_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import sonnet as snt
import tensorflow as tf
Expand Down Expand Up @@ -183,6 +184,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/discrete_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
import sonnet as snt
import tensorflow as tf
import jax
import tensorflow_probability as tfp
import tree
import hydra
Expand Down Expand Up @@ -202,6 +203,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/iddpg_bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import tensorflow as tf
from tensorflow import Tensor
Expand Down Expand Up @@ -292,6 +293,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/iddpg_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import tensorflow as tf
from tensorflow import Tensor
Expand Down Expand Up @@ -414,6 +415,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/iql_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import copy
import hydra
import jax
from omegaconf import DictConfig
import sonnet as snt
import tensorflow as tf
Expand Down Expand Up @@ -296,6 +297,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/iql_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import copy
import numpy as np
import jax
import tensorflow as tf
import sonnet as snt
import tree
Expand Down Expand Up @@ -260,6 +261,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/maddpg_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import hydra
import numpy as np
import jax
import sonnet as snt
from omegaconf import DictConfig
import tensorflow as tf
Expand Down Expand Up @@ -408,6 +409,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/maicq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import sonnet as snt
import tensorflow as tf
Expand Down Expand Up @@ -318,6 +319,8 @@ def _tf_train_step(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/omar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import hydra
import tree
import numpy as np
import jax
from omegaconf import DictConfig
import tensorflow as tf
import sonnet as snt
Expand Down Expand Up @@ -490,6 +491,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/qmix_bcq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import hydra
import tree
import numpy as np
import jax
from omegaconf import DictConfig
import sonnet as snt
import tensorflow as tf
Expand Down Expand Up @@ -333,6 +334,8 @@ def mixing(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/offline/qmix_cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from omegaconf import DictConfig
import tree
import numpy as np
import jax
import tensorflow as tf
from tensorflow import Tensor
import sonnet as snt
Expand Down Expand Up @@ -331,6 +332,8 @@ def mixing(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"])

buffer = FlashbaxReplayBuffer(
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/online/iddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import tensorflow as tf
from tensorflow import Tensor
Expand Down Expand Up @@ -293,6 +294,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(
cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]
)
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/online/maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import tensorflow as tf
from tensorflow import Tensor
Expand Down Expand Up @@ -294,6 +295,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]:
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(
cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]
)
Expand Down
3 changes: 3 additions & 0 deletions og_marl/tf2_systems/online/qmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import hydra
import numpy as np
import jax
from omegaconf import DictConfig
import tensorflow as tf
from tensorflow import Tensor
Expand Down Expand Up @@ -298,6 +299,8 @@ def mixing(
def run_experiment(cfg: DictConfig) -> None:
print(cfg)

jax.config.update('jax_platform_name', 'cpu')

env = get_environment(
cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]
)
Expand Down

0 comments on commit c092663

Please sign in to comment.