Skip to content

Commit

Permalink
api: Patch imperfect cross derivs w symbolic coeffs
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Oct 26, 2023
1 parent 4b0e39a commit 36cf0be
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 14 deletions.
14 changes: 11 additions & 3 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def _symbolic_functions(self):
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)

@cached_property
def _coeff_symbol(self, *args, **kwargs):
if self._uses_symbolic_coefficients:
return W
else:
raise ValueError("Couldn't find any symbolic coefficients")

def _eval_at(self, func):
if not func.is_Staggered:
# Cartesian grid, do no waste time
Expand Down Expand Up @@ -327,6 +334,10 @@ def highest_priority(DiffOp):
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


# Abstract symbol representing a symbolic coefficient
W = sympy.Function('W')


class DifferentiableOp(Differentiable):

__sympy_class__ = None
Expand Down Expand Up @@ -766,9 +777,6 @@ def _new_rawargs(self, *args, **kwargs):
kwargs.pop('is_commutative', None)
return self.func(*args, **kwargs)

def _coeff_symbol(self, *args, **kwargs):
return self.base._coeff_symbol(*args, **kwargs)


class diffify(object):

Expand Down
8 changes: 0 additions & 8 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,6 @@ def coefficients(self):
"""Form of the coefficients of the function."""
return self._coefficients

@cached_property
def _coeff_symbol(self):
if self.coefficients == 'symbolic':
return sympy.Function('W')
else:
raise ValueError("Function was not declared with symbolic "
"coefficients.")

@cached_property
def shape(self):
"""
Expand Down
18 changes: 18 additions & 0 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def test_aggregate_w_custom_coeffs(self):

def test_cross_derivs(self):
grid = Grid(shape=(11, 11, 11))

q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2,
coefficients='symbolic')
q0 = TimeFunction(name='q', grid=grid, space_order=8, time_order=2)
Expand All @@ -389,3 +390,20 @@ def test_cross_derivs(self):

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

def test_cross_derivs_imperfect(self):
grid = Grid(shape=(11, 11, 11))

p = TimeFunction(name='p', grid=grid, space_order=4, time_order=2,
coefficients='symbolic')
q = TimeFunction(name='q', grid=grid, space_order=4, time_order=2,
coefficients='symbolic')

p0 = TimeFunction(name='p', grid=grid, space_order=4, time_order=2)
q0 = TimeFunction(name='q', grid=grid, space_order=4, time_order=2)

eq0 = Eq(q0.forward, (q0.dx + p0.dx).dy)
eq1 = Eq(q.forward, (q.dx + p.dx).dy)

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())
16 changes: 13 additions & 3 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_backward_dt2(self):
assert_structure(op, ['t,x,y'], 't,x,y')


class TestSymbolicCoefficients(object):
class TestSymbolicCoeffs(object):

def test_fallback_to_default(self):
grid = Grid(shape=(8, 8, 8))
Expand All @@ -40,12 +40,22 @@ def test_fallback_to_default(self):

def test_numeric_coeffs(self):
grid = Grid(shape=(11,), extent=(10.,))

u = Function(name='u', grid=grid, coefficients='symbolic', space_order=2)
v = Function(name='v', grid=grid, coefficients='symbolic', space_order=2)

coeffs = Substitutions(Coefficient(2, u, grid.dimensions[0], np.zeros(3)))

op = Operator(Eq(u, u.dx2, coefficients=coeffs), opt=({'expand': False},))
op.cfunction
opt = ('advanced', {'expand': False})

# Pure derivative
Operator(Eq(u, u.dx2, coefficients=coeffs), opt=opt).cfunction

# Mixed derivative
Operator(Eq(u, u.dx.dx, coefficients=coeffs), opt=opt).cfunction

# Non-perfect mixed derivative
Operator(Eq(u, (u.dx + v.dx).dx, coefficients=coeffs), opt=opt).cfunction


class Test1Pass(object):
Expand Down

0 comments on commit 36cf0be

Please sign in to comment.