diff --git a/og_marl/tf2_systems/offline/continuous_bc.py b/og_marl/tf2_systems/offline/continuous_bc.py index fe56348..2aa5f59 100644 --- a/og_marl/tf2_systems/offline/continuous_bc.py +++ b/og_marl/tf2_systems/offline/continuous_bc.py @@ -17,6 +17,7 @@ import hydra import numpy as np +import jax from omegaconf import DictConfig import sonnet as snt import tensorflow as tf @@ -183,6 +184,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/discrete_bc.py b/og_marl/tf2_systems/offline/discrete_bc.py index 36fc86a..c21fd78 100644 --- a/og_marl/tf2_systems/offline/discrete_bc.py +++ b/og_marl/tf2_systems/offline/discrete_bc.py @@ -18,6 +18,7 @@ import numpy as np import sonnet as snt import tensorflow as tf +import jax import tensorflow_probability as tfp import tree import hydra @@ -202,6 +203,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/iddpg_bc.py b/og_marl/tf2_systems/offline/iddpg_bc.py index c354e84..4288e47 100644 --- a/og_marl/tf2_systems/offline/iddpg_bc.py +++ b/og_marl/tf2_systems/offline/iddpg_bc.py @@ -18,6 +18,7 @@ import copy import hydra import numpy as np +import jax from omegaconf import DictConfig import tensorflow as tf from tensorflow import Tensor @@ -292,6 +293,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/iddpg_cql.py b/og_marl/tf2_systems/offline/iddpg_cql.py index 416986a..e5624a3 100644 --- a/og_marl/tf2_systems/offline/iddpg_cql.py +++ b/og_marl/tf2_systems/offline/iddpg_cql.py @@ -18,6 +18,7 @@ import copy import hydra import numpy as np +import jax from omegaconf import DictConfig import tensorflow as tf from tensorflow import Tensor @@ -414,6 +415,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/iql_bcq.py b/og_marl/tf2_systems/offline/iql_bcq.py index 60766e4..46c1ddb 100644 --- a/og_marl/tf2_systems/offline/iql_bcq.py +++ b/og_marl/tf2_systems/offline/iql_bcq.py @@ -17,6 +17,7 @@ import copy import hydra +import jax from omegaconf import DictConfig import sonnet as snt import tensorflow as tf @@ -296,6 +297,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/iql_cql.py b/og_marl/tf2_systems/offline/iql_cql.py index 64faaf2..8b538d3 100644 --- a/og_marl/tf2_systems/offline/iql_cql.py +++ b/og_marl/tf2_systems/offline/iql_cql.py @@ -16,6 +16,7 @@ import copy import numpy as np +import jax import tensorflow as tf import sonnet as snt import tree @@ -260,6 +261,8 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/maddpg_cql.py b/og_marl/tf2_systems/offline/maddpg_cql.py index 307eed6..863b6c8 100644 --- a/og_marl/tf2_systems/offline/maddpg_cql.py +++ b/og_marl/tf2_systems/offline/maddpg_cql.py @@ -18,6 +18,7 @@ import hydra import numpy as np +import jax import sonnet as snt from omegaconf import DictConfig import tensorflow as tf @@ -408,6 +409,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/maicq.py b/og_marl/tf2_systems/offline/maicq.py index a5a574f..6e42948 100644 --- a/og_marl/tf2_systems/offline/maicq.py +++ b/og_marl/tf2_systems/offline/maicq.py @@ -18,6 +18,7 @@ import copy import hydra import numpy as np +import jax from omegaconf import DictConfig import sonnet as snt import tensorflow as tf @@ -318,6 +319,8 @@ def _tf_train_step( def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/omar.py b/og_marl/tf2_systems/offline/omar.py index 4f8568c..16ca65b 100644 --- a/og_marl/tf2_systems/offline/omar.py +++ b/og_marl/tf2_systems/offline/omar.py @@ -19,6 +19,7 @@ import hydra import tree import numpy as np +import jax from omegaconf import DictConfig import tensorflow as tf import sonnet as snt @@ -490,6 +491,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/qmix_bcq.py b/og_marl/tf2_systems/offline/qmix_bcq.py index e2709cc..15c12b6 100644 --- a/og_marl/tf2_systems/offline/qmix_bcq.py +++ b/og_marl/tf2_systems/offline/qmix_bcq.py @@ -19,6 +19,7 @@ import hydra import tree import numpy as np +import jax from omegaconf import DictConfig import sonnet as snt import tensorflow as tf @@ -333,6 +334,8 @@ def mixing( def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/offline/qmix_cql.py b/og_marl/tf2_systems/offline/qmix_cql.py index 104e62f..080c06a 100644 --- a/og_marl/tf2_systems/offline/qmix_cql.py +++ b/og_marl/tf2_systems/offline/qmix_cql.py @@ -19,6 +19,7 @@ from omegaconf import DictConfig import tree import numpy as np +import jax import tensorflow as tf from tensorflow import Tensor import sonnet as snt @@ -331,6 +332,8 @@ def mixing( def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment(cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"]) buffer = FlashbaxReplayBuffer( diff --git a/og_marl/tf2_systems/online/iddpg.py b/og_marl/tf2_systems/online/iddpg.py index e6c5edb..9cd1c60 100644 --- a/og_marl/tf2_systems/online/iddpg.py +++ b/og_marl/tf2_systems/online/iddpg.py @@ -18,6 +18,7 @@ import copy import hydra import numpy as np +import jax from omegaconf import DictConfig import tensorflow as tf from tensorflow import Tensor @@ -293,6 +294,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment( cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"] ) diff --git a/og_marl/tf2_systems/online/maddpg.py b/og_marl/tf2_systems/online/maddpg.py index 88c82bb..adb2037 100644 --- a/og_marl/tf2_systems/online/maddpg.py +++ b/og_marl/tf2_systems/online/maddpg.py @@ -18,6 +18,7 @@ import copy import hydra import numpy as np +import jax from omegaconf import DictConfig import tensorflow as tf from tensorflow import Tensor @@ -294,6 +295,8 @@ def _tf_train_step(self, experience: Dict[str, Any]) -> Dict[str, Numeric]: def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment( cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"] ) diff --git a/og_marl/tf2_systems/online/qmix.py b/og_marl/tf2_systems/online/qmix.py index 6501bbf..d64257d 100644 --- a/og_marl/tf2_systems/online/qmix.py +++ b/og_marl/tf2_systems/online/qmix.py @@ -18,6 +18,7 @@ import copy import hydra import numpy as np +import jax from omegaconf import DictConfig import tensorflow as tf from tensorflow import Tensor @@ -298,6 +299,8 @@ def mixing( def run_experiment(cfg: DictConfig) -> None: print(cfg) + jax.config.update('jax_platform_name', 'cpu') + env = get_environment( cfg["task"]["source"], cfg["task"]["env"], cfg["task"]["scenario"], seed=cfg["seed"] )