Skip to content

Commit

Permalink
compiler: cleanup ParTile
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Nov 2, 2023
1 parent fb2170d commit 09d3d74
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 26 deletions.
9 changes: 5 additions & 4 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from devito.mpi.routines import mpi_registry
from devito.parameters import configuration
from devito.operator import Operator
from devito.tools import as_tuple, is_integer, timed_pass, UnboundTuple
from devito.tools import (as_tuple, is_integer, timed_pass,
UnboundTuple, UnboundedMultiTuple)
from devito.types import NThreads

__all__ = ['CoreOperator', 'CustomOperator',
Expand Down Expand Up @@ -338,11 +339,11 @@ def __new__(cls, items, rule=None, tag=None):
return obj


class ParTile(tuple, OptOption):
class ParTile(UnboundedMultiTuple, OptOption):

def __new__(cls, items, default=None):
if not items:
return tuple()
return UnboundedMultiTuple()
elif isinstance(items, bool):
if not default:
raise ValueError("Expected `default` value, got None")
Expand Down Expand Up @@ -394,7 +395,7 @@ def __new__(cls, items, default=None):
else:
raise ValueError("Expected bool or iterable, got %s instead" % type(items))

obj = super().__new__(cls, items)
obj = super().__new__(cls, *items)
obj.default = as_tuple(default)

return obj
9 changes: 3 additions & 6 deletions devito/passes/clusters/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ class BlockSizeGenerator(object):
"""

def __init__(self, par_tile):
self.umt = UnboundedMultiTuple(*par_tile)
self.umt = par_tile
self.tip = -1

# This is for Clusters that need a small par-tile to avoid under-utilizing
Expand Down Expand Up @@ -459,11 +459,11 @@ def next(self, prefix, d, clusters):
return self.umt_small.next()

if x:
item = self.umt.curitem
item = self.umt.curitem()
else:
# We can't `self.umt.iter()` because we might still want to
# fallback to `self.umt_small`
item = self.umt.nextitem
item = self.umt.nextitem()

# Handle user-provided rules
# TODO: This is also rudimentary
Expand All @@ -481,9 +481,6 @@ def next(self, prefix, d, clusters):
# This is like "pattern unmatched" -- fallback to `umt_small`
umt = self.umt_small

if not x:
umt.iter()

return umt.next()


Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/languages/openacc.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _make_partree(self, candidates, nthreads=None):
if self._is_offloadable(root) and \
all(i.is_Affine for i in [root] + collapsable) and \
self.par_tile:
tile = self.par_tile.next()
tile = self.par_tile.nextitem()
assert isinstance(tile, UnboundTuple)

body = self.DeviceIteration(gpu_fit=self.gpu_fit, tile=tile,
Expand Down
7 changes: 3 additions & 4 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from devito.passes.iet.langbase import (LangBB, LangTransformer, DeviceAwareMixin,
make_sections_from_imask)
from devito.symbolics import INT, ccode
from devito.tools import UnboundTuple, as_tuple, flatten, is_integer, prod
from devito.tools import as_tuple, flatten, is_integer, prod
from devito.types import Symbol

__all__ = ['PragmaSimdTransformer', 'PragmaShmTransformer',
Expand Down Expand Up @@ -622,8 +622,7 @@ def __init__(self, sregistry, options, platform, compiler):
super().__init__(sregistry, options, platform, compiler)

self.gpu_fit = options['gpu-fit']
self.par_tile = UnboundTuple(*options['par-tile'],
default=options['par-tile'].default)
self.par_tile = options['par-tile']
self.par_disabled = options['par-disabled']

def _score_candidate(self, n0, root, collapsable=()):
Expand Down Expand Up @@ -659,7 +658,7 @@ def _make_partree(self, candidates, nthreads=None, index=None):
if self._is_offloadable(root):
body = self.DeviceIteration(gpu_fit=self.gpu_fit,
ncollapsed=len(collapsable)+1,
tile=self.par_tile.next(),
tile=self.par_tile.nextitem(),
**root.args)
partree = ParallelTree([], body, nthreads=nthreads)

Expand Down
32 changes: 21 additions & 11 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,14 +670,9 @@ def __new__(cls, *items, **kwargs):
obj = super().__new__(cls, tuple(nitems))
obj.last = len(nitems)
obj.current = 0
obj._default = kwargs.get('default', nitems[0])

return obj

@property
def default(self):
return self._default

@property
def prod(self):
return np.prod(self)
Expand All @@ -686,7 +681,7 @@ def iter(self):
self.current = 0

def next(self):
if self.last == 0:
if not self:
return None
item = self[self.current]
if self.current == self.last-1 or self.current == -1:
Expand All @@ -703,6 +698,8 @@ def __repr__(self):
return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems))

def __getitem__(self, idx):
if not self:
return None
if isinstance(idx, slice):
start = idx.start or 0
stop = idx.stop or self.last
Expand Down Expand Up @@ -754,26 +751,39 @@ class UnboundedMultiTuple(UnboundTuple):

def __new__(cls, *items, **kwargs):
obj = super().__new__(cls, *items, **kwargs)
obj.current = -1
# MultiTuple are un-initialized
obj.current = None
return obj

@property
def curitem(self):
if self.current is None:
raise StopIteration
if not self:
return None
return self[self.current]

@property
def nextitem(self):
return self[min(self.current + 1, max(self.last - 1, 0))]
if not self:
return None
self.iter()
return self.curitem()

def index(self, item):
return self.index(item)

def iter(self):
self.current = min(self.current + 1, self.last - 1)
if self.current is None:
self.current = 0
else:
self.current = min(self.current + 1, self.last - 1)
self[self.current].current = 0
return

def next(self):
if not self:
return None
if self.current is None:
raise StopIteration
if self[self.current].current == -1:
raise StopIteration
return self[self.current].next()
8 changes: 8 additions & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,14 @@ def test_ctypes_to_cstr(dtype, expected):

def test_unbounded_multi_tuple():
ub = UnboundedMultiTuple([1, 2], [3, 4])
with pytest.raises(StopIteration):
ub.next()

with pytest.raises(StopIteration):
assert ub.curitem()

ub.iter()
assert ub.curitem() == (1, 2)
assert ub.next() == 1
assert ub.next() == 2

Expand All @@ -121,6 +127,8 @@ def test_unbounded_multi_tuple():
ub.iter()
assert ub.next() == 3

assert ub.nextitem() == (3, 4)


def test_unbound_tuple():
# Make sure we don't drop needed None for 2.5d
Expand Down

0 comments on commit 09d3d74

Please sign in to comment.