Skip to content

Commit

Permalink
Merge pull request #81 from pdebench/fix/linter_warnings
Browse files Browse the repository at this point in the history
Fix ruff linter warnings
  • Loading branch information
leiterrl authored Dec 28, 2024
2 parents 9f3ca2b + 0cb1e07 commit 7da8b79
Show file tree
Hide file tree
Showing 50 changed files with 575 additions and 631 deletions.
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,20 @@ repos:
args: [--prose-wrap=always]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.14"
rev: "v0.8.4"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.8.0"
hooks:
- id: mypy
files: pdebench|tests
args: []
additional_dependencies:
- pytest
# - repo: https://github.com/pre-commit/mirrors-mypy
# rev: "v1.8.0"
# hooks:
# - id: mypy
# files: pdebench|tests
# args: []
# additional_dependencies:
# - pytest

- repo: https://github.com/codespell-project/codespell
rev: "v2.2.6"
Expand All @@ -62,7 +62,7 @@ repos:
exclude_types: [jupyter]

- repo: https://github.com/shellcheck-py/shellcheck-py
rev: "v0.9.0.6"
rev: "v0.10.0.1"
hooks:
- id: shellcheck

Expand Down
57 changes: 57 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import shutil
from pathlib import Path

import nox

DIR = Path(__file__).parent.resolve()

nox.needs_version = ">=2024.3.2"
nox.options.sessions = ["precommit", "pylint", "tests", "build"]
nox.options.default_venv_backend = "uv|mamba|virtualenv"


@nox.session(python=["3.10"])
def precommit(session: nox.Session) -> None:
"""
Run the linter.
"""
session.install("pre-commit")
session.run(
"pre-commit", "run", "--all-files", "--show-diff-on-failure", *session.posargs
)


@nox.session(python=["3.10"])
def pylint(session: nox.Session) -> None:
"""
Run PyLint.
"""
# This needs to be installed into the package environment, and is slower
# than a pre-commit check
session.install(".", "pylint>=3.2")
session.run("pylint", "pdebench", *session.posargs)


@nox.session(python=["3.10"])
def tests(session: nox.Session) -> None:
"""
Run the unit and regular tests.
"""
session.install(".[test]")
session.run("pytest", *session.posargs)


@nox.session(python=["3.10"])
def build(session: nox.Session) -> None:
"""
Build an SDist and wheel.
"""

build_path = DIR.joinpath("build")
if build_path.exists():
shutil.rmtree(build_path)

session.install("build")
session.run("python", "-m", "build")
8 changes: 7 additions & 1 deletion pdebench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
"""

from __future__ import annotations

import logging

_logger = logging.getLogger(__name__)
_logger.propagate = False

__version__ = "0.0.1"
__author__ = "Makoto Takamoto, Timothy Praditia, Raphael Leiteritz, Dan MacKinlay, Francesco Alesiani, Dirk Pflüger, Mathias Niepert"
__credits__ = "NEC labs Europe, University of Stuttgart, CSIRO" "s Data61"
__credits__ = "NEC labs Europe, University of Stuttgart, CSIRO's Data61"
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>
Expand Down Expand Up @@ -144,8 +143,10 @@
THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import logging
import time
from math import ceil

Expand All @@ -157,11 +158,13 @@
# Hydra
from omegaconf import DictConfig

logger = logging.getLogger(__name__)


# Init arguments with Hydra
@hydra.main(config_path="config")
def main(cfg: DictConfig) -> None:
print(f"advection velocity: {cfg.args.beta}")
logger.info("advection velocity: %f", cfg.args.beta)

# cell edge coordinate
xe = jnp.linspace(cfg.args.xL, cfg.args.xR, cfg.args.nx + 1)
Expand All @@ -181,14 +184,14 @@ def evolve(u):
uu = uu.at[0].set(u)

while t < cfg.args.fin_time:
print(f"save data at t = {t:.3f}")
logger.info("save data at t = %f", t)
u = set_function(xc, t, cfg.args.beta)
uu = uu.at[i_save].set(u)
t += cfg.args.dt_save
i_save += 1

tm_fin = time.time()
print(f"total elapsed time is {tm_fin - tm_ini} sec")
logger.info("total elapsed time is %f sec", tm_fin - tm_ini)
uu = uu.at[-1].set(u)
return uu, t

Expand All @@ -199,9 +202,9 @@ def set_function(x, t, beta):
u = set_function(xc, t=0, beta=cfg.args.beta)
u = device_put(u) # putting variables in GPU (not necessary??)
uu, t = evolve(u)
print(f"final time is: {t:.3f}")
logger.info("final time is: %f", t)

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
jnp.save(cwd + cfg.args.save + "/Advection_beta" + str(cfg.args.beta), uu)
jnp.save(cwd + cfg.args.save + "/x_coordinate", xe)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>
Expand Down Expand Up @@ -144,22 +143,29 @@
THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import random
import sys

# Hydra
from math import ceil, exp, log
from pathlib import Path

import hydra
import jax
import jax.numpy as jnp
from jax import device_put, lax

# Hydra
from omegaconf import DictConfig

sys.path.append("..")
import logging

from utils import Courant, bc, init_multi, limiting

logger = logging.getLogger(__name__)


def _pass(carry):
return carry
Expand Down Expand Up @@ -192,7 +198,7 @@ def main(cfg: DictConfig) -> None:
else:
beta = cfg.multi.beta

print("beta: ", beta)
logger.info("beta: %f", beta)

@jax.jit
def evolve(u):
Expand All @@ -204,7 +210,8 @@ def evolve(u):
uu = jnp.zeros([it_tot, u.shape[0]])
uu = uu.at[0].set(u)

cond_fun = lambda x: x[0] < fin_time
def cond_fun(x):
return x[0] < fin_time

def _body_fun(carry):
def _show(_carry):
Expand All @@ -226,9 +233,7 @@ def _show(_carry):

carry = t, tsave, steps, i_save, dt, u, uu
t, tsave, steps, i_save, dt, u, uu = lax.while_loop(cond_fun, _body_fun, carry)
uu = uu.at[-1].set(u)

return uu
return uu.at[-1].set(u)

@jax.jit
def simulation_fn(i, carry):
Expand Down Expand Up @@ -265,12 +270,11 @@ def flux(u):
fL = uL * beta
fR = uR * beta
# upwind advection scheme
f_upwd = 0.5 * (
return 0.5 * (
fR[1 : cfg.multi.nx + 2]
+ fL[2 : cfg.multi.nx + 3]
- jnp.abs(beta) * (uL[2 : cfg.multi.nx + 3] - uR[1 : cfg.multi.nx + 2])
)
return f_upwd

u = init_multi(xc, numbers=cfg.multi.numbers, k_tot=4, init_key=cfg.multi.init_key)
u = device_put(u) # putting variables in GPU (not necessary??)
Expand All @@ -285,7 +289,7 @@ def flux(u):
# reshape before saving
uu = uu.reshape((-1, *uu.shape[2:]))

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
Path(cwd + cfg.multi.save).mkdir(parents=True, exist_ok=True)
jnp.save(cwd + cfg.multi.save + "1D_Advection_Sols_beta" + str(beta)[:5], uu)
Expand Down
14 changes: 9 additions & 5 deletions pdebench/data_gen/data_gen_NLE/BurgersEq/burgers_Hydra.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>
Expand Down Expand Up @@ -144,8 +143,10 @@
THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import logging
import sys
import time
from math import ceil
Expand All @@ -161,6 +162,8 @@
sys.path.append("..")
from utils import Courant, Courant_diff, bc, init, limiting

logger = logging.getLogger(__name__)


def _pass(carry):
return carry
Expand Down Expand Up @@ -201,7 +204,8 @@ def evolve(u):

tm_ini = time.time()

cond_fun = lambda x: x[0] < fin_time
def cond_fun(x):
return x[0] < fin_time

def _body_fun(carry):
def _save(_carry):
Expand All @@ -227,7 +231,7 @@ def _save(_carry):
uu = uu.at[-1].set(u)

tm_fin = time.time()
print(f"total elapsed time is {tm_fin - tm_ini} sec")
logger.info("total elapsed time is %f sec", tm_fin - tm_ini)
return uu, t

@jax.jit
Expand Down Expand Up @@ -285,9 +289,9 @@ def flux(u):
u = init(xc=xc, mode=cfg.args.init_mode, u0=cfg.args.u0, du=cfg.args.du)
u = device_put(u) # putting variables in GPU (not necessary??)
uu, t = evolve(u)
print(f"final time is: {t:.3f}")
logger.info("final time is: %.3f", t)

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
if cfg.args.init_mode == "sinsin":
jnp.save(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#!/usr/bin/env python
"""
<NAME OF THE PROGRAM THIS FILE BELONGS TO>
Expand Down Expand Up @@ -144,8 +143,10 @@
THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
"""

from __future__ import annotations

import logging
import random
import sys
from math import ceil, exp, log
Expand All @@ -162,6 +163,8 @@
sys.path.append("..")
from utils import Courant, Courant_diff, bc, init_multi, limiting

logger = logging.getLogger(__name__)


def _pass(carry):
return carry
Expand Down Expand Up @@ -191,7 +194,7 @@ def main(cfg: DictConfig) -> None:
) # uniform number between 0.01 to 100
else:
epsilon = cfg.multi.epsilon
print("epsilon: ", epsilon)
logger.info("epsilon: %f", epsilon)
# t-coordinate
it_tot = ceil((fin_time - ini_time) / dt_save) + 1
tc = jnp.arange(it_tot + 1) * dt_save
Expand All @@ -206,7 +209,8 @@ def evolve(u):
uu = jnp.zeros([it_tot, u.shape[0]])
uu = uu.at[0].set(u)

cond_fun = lambda x: x[0] < fin_time
def cond_fun(x):
return x[0] < fin_time

def _body_fun(carry):
def _show(_carry):
Expand All @@ -228,9 +232,7 @@ def _show(_carry):

carry = t, tsave, steps, i_save, dt, u, uu
t, tsave, steps, i_save, dt, u, uu = lax.while_loop(cond_fun, _body_fun, carry)
uu = uu.at[-1].set(u)

return uu
return uu.at[-1].set(u)

@jax.jit
def simulation_fn(i, carry):
Expand Down Expand Up @@ -301,7 +303,7 @@ def flux(u):
# reshape before saving
uu = uu.reshape((-1, *uu.shape[2:]))

print("data saving...")
logger.info("data saving...")
cwd = hydra.utils.get_original_cwd() + "/"
Path(cwd + cfg.multi.save).mkdir(parents=True, exist_ok=True)
jnp.save(cwd + cfg.multi.save + "1D_Burgers_Sols_Nu" + str(epsilon)[:5], uu)
Expand Down
Loading

0 comments on commit 7da8b79

Please sign in to comment.