From 698d6a8d0df908d0a981a8b01624422ad23f17a8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 24 Oct 2023 09:05:53 +0000 Subject: [PATCH] compiler: Patch custom coefficients --- devito/passes/iet/definitions.py | 4 +++- devito/symbolics/extended_sympy.py | 14 +++++--------- tests/test_unexpansion.py | 12 +++++++++++- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index e4195c95c6..0568e833ad 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -355,7 +355,9 @@ def place_definitions(self, iet, globs=None, **kwargs): includes = set() if isinstance(iet, EntryFunction) and globs: for i in sorted(globs, key=lambda f: f.name): - includes.add(self._alloc_array_on_global_mem(iet, i, storage)) + v = self._alloc_array_on_global_mem(iet, i, storage) + if v: + includes.add(v) iet, efuncs = self._inject_definitions(iet, storage) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 87a56c13b7..5db4f6b7d8 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -4,7 +4,7 @@ import numpy as np import sympy -from sympy import Expr, Integer, Function, Number, Tuple, sympify +from sympy import Expr, Function, Number, Tuple, sympify from sympy.core.decorators import call_highest_priority from devito.tools import (Pickable, as_tuple, is_integer, float2, float3, float4, # noqa @@ -278,14 +278,10 @@ class ListInitializer(sympy.Expr, Pickable): def __new__(cls, params): args = [] for p in as_tuple(params): - if isinstance(p, str): - args.append(Symbol(p)) - elif is_integer(p): - args.append(Integer(p)) - elif not isinstance(p, Expr): - raise ValueError("`params` must be an iterable of Expr or str") - else: - args.append(p) + try: + args.append(sympify(p)) + except sympy.SympifyError: + raise ValueError("Illegal param `%s`" % p) obj = sympy.Expr.__new__(cls, *args) obj.params = tuple(args) return obj diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index 1e269328c1..8685e6626e 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -1,7 +1,8 @@ import numpy as np from conftest import assert_structure, get_params, get_arrays, check_array -from devito import Buffer, Eq, Function, TimeFunction, Grid, Operator, cos, sin +from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator, + Substitutions, Coefficient, cos, sin) from devito.types import Symbol @@ -37,6 +38,15 @@ def test_fallback_to_default(self): op.arguments(dt=1, time_M=10) op.cfunction + def test_numeric_coeffs(self): + grid = Grid(shape=(11,), extent=(10.,)) + u = Function(name='u', grid=grid, coefficients='symbolic', space_order=2) + + coeffs = Substitutions(Coefficient(2, u, grid.dimensions[0], np.zeros(3))) + + op = Operator(Eq(u, u.dx2, coefficients=coeffs), opt=({'expand': False},)) + op.cfunction + class Test1Pass(object):