Skip to content

Commit

Permalink
misc: cleanup multituple
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 31, 2023
1 parent a5050f4 commit 31ade85
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 28 deletions.
2 changes: 1 addition & 1 deletion devito/core/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _normalize_kwargs(cls, **kwargs):
o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE)

# GPU parallelism
o['par-tile'] = ParTile(oo.pop('par-tile', False), default=(32, 4, 4, 4))
o['par-tile'] = ParTile(oo.pop('par-tile', False), default=(32, 4, 4))
o['par-collapse-ncores'] = 1 # Always collapse (meaningful if `par-tile=False`)
o['par-collapse-work'] = 1 # Always collapse (meaningful if `par-tile=False`)
o['par-chunk-nonaffine'] = oo.pop('par-chunk-nonaffine', cls.PAR_CHUNK_NONAFFINE)
Expand Down
10 changes: 5 additions & 5 deletions devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.mpi.routines import mpi_registry
from devito.parameters import configuration
from devito.operator import Operator
from devito.tools import as_tuple, is_integer, timed_pass
from devito.tools import as_tuple, is_integer, timed_pass, UnboundTuple
from devito.types import NThreads

__all__ = ['CoreOperator', 'CustomOperator',
Expand Down Expand Up @@ -327,12 +327,12 @@ class OptOption(object):
pass


class ParTileArg(tuple):
class ParTileArg(UnboundTuple):

def __new__(cls, items, shm=0, tag=None):
def __new__(cls, items, rule=None, tag=None):
if items is None:
items = tuple()
obj = super().__new__(cls, items)
obj = super().__new__(cls, *items)
obj.rule = rule
obj.tag = tag
return obj
Expand All @@ -355,7 +355,7 @@ def __new__(cls, items, default=None):

x = items[0]
if is_integer(x):
# E.g., (32, 4, 8)
# E.g., 32
items = (ParTileArg(items),)

elif x is None:
Expand Down
52 changes: 30 additions & 22 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __reduce__(self):
args = tuple()
else:
args = self.default_factory,
return type(self), args, None, None, self.items()
return type(self), args, None, None, self()

def copy(self):
return self.__copy__()
Expand Down Expand Up @@ -639,7 +639,7 @@ def __hash__(self):
return self._hash


class UnboundTuple(object):
class UnboundTuple(tuple):
"""
An UnboundedTuple is a tuple that can be
infinitely iterated over.
Expand All @@ -657,30 +657,37 @@ class UnboundTuple(object):
UnboundTuple(3, 4)
"""

def __init__(self, *items):
def __new__(cls, *items, **kwargs):
nitems = []
for i in as_tuple(items):
if isinstance(i, Iterable):
if isinstance(i, UnboundTuple):
nitems.append(i)
elif isinstance(i, Iterable):
nitems.append(UnboundTuple(*i))
elif i is not None:
nitems.append(i)

self.items = tuple(nitems)
self.last = len(self.items)
self.current = 0
obj = super().__new__(cls, tuple(nitems))
obj.last = len(nitems)
obj.current = 0

return obj

@property
def default(self):
return self.items[0]
return self[0]

@property
def prod(self):
return np.prod(self.items)
return np.prod(self)

def iter(self):
self.current = 0

def next(self):
if self.last == 0:
return None
item = self.items[self.current]
item = self[self.current]
if self.current == self.last-1 or self.current == -1:
self.current = -1
else:
Expand All @@ -691,7 +698,7 @@ def __len__(self):
return self.last

def __repr__(self):
sitems = [s.__repr__() for s in self.items]
sitems = [s.__repr__() for s in self]
return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems))

def __getitem__(self, idx):
Expand All @@ -704,9 +711,9 @@ def __getitem__(self, idx):
return UnboundTuple(*[self[i] for i in range(start, stop, step)])
try:
if idx >= self.last-1:
return self.items[self.last-1]
return super().__getitem__(self.last-1)
else:
return self.items[idx]
return super().__getitem__(idx)
except TypeError:
# Slice, ...
return UnboundTuple(self[i] for i in idx)
Expand Down Expand Up @@ -744,27 +751,28 @@ class UnboundedMultiTuple(UnboundTuple):
3
"""

def __init__(self, *items):
super().__init__(*items)
self.current = -1
def __new__(cls, *items, **kwargs):
obj = super().__new__(cls, *items, **kwargs)
obj.current = -1
return obj

@property
def curitem(self):
return self.items[self.tip]
return self[self.current]

@property
def nextitem(self):
return self.items[min(self.tip + 1, max(len(self.items) - 1, 0))]
return self[min(self.current + 1, max(self.last - 1, 0))]

def index(self, item):
return self.items.index(item)
return self.index(item)

def iter(self):
self.current = min(self.current + 1, self.last - 1)
self.items[self.current].current = 0
self[self.current].current = 0
return

def next(self):
if self.items[self.current].current == -1:
if self[self.current].current == -1:
raise StopIteration
return self.items[self.current].next()
return self[self.current].next()
1 change: 1 addition & 0 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def test_custom_rule0(self):

# Check generated code. By having specified "1" as rule, we expect the
# given par-tile to be applied to the kernel with id 1
from IPython import embed; embed()
bns, _ = assert_blocking(op, {'z0_blk0', 'x1_blk0', 'z2_blk0'})
root = bns['x1_blk0']
iters = FindNodes(Iteration).visit(root)
Expand Down

0 comments on commit 31ade85

Please sign in to comment.