diff --git a/devito/core/gpu.py b/devito/core/gpu.py index 46f8914f6b1..7aa24e492f4 100644 --- a/devito/core/gpu.py +++ b/devito/core/gpu.py @@ -65,7 +65,7 @@ def _normalize_kwargs(cls, **kwargs): o['cire-schedule'] = oo.pop('cire-schedule', cls.CIRE_SCHEDULE) # GPU parallelism - o['par-tile'] = ParTile(oo.pop('par-tile', False), default=(32, 4, 4, 4)) + o['par-tile'] = ParTile(oo.pop('par-tile', False), default=(32, 4, 4)) o['par-collapse-ncores'] = 1 # Always collapse (meaningful if `par-tile=False`) o['par-collapse-work'] = 1 # Always collapse (meaningful if `par-tile=False`) o['par-chunk-nonaffine'] = oo.pop('par-chunk-nonaffine', cls.PAR_CHUNK_NONAFFINE) diff --git a/devito/core/operator.py b/devito/core/operator.py index 19f61f4627d..3daa9913169 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -6,7 +6,7 @@ from devito.mpi.routines import mpi_registry from devito.parameters import configuration from devito.operator import Operator -from devito.tools import as_tuple, is_integer, timed_pass +from devito.tools import as_tuple, is_integer, timed_pass, UnboundTuple from devito.types import NThreads __all__ = ['CoreOperator', 'CustomOperator', @@ -327,12 +327,12 @@ class OptOption(object): pass -class ParTileArg(tuple): +class ParTileArg(UnboundTuple): - def __new__(cls, items, shm=0, tag=None): + def __new__(cls, items, rule=None, tag=None): if items is None: items = tuple() - obj = super().__new__(cls, items) + obj = super().__new__(cls, *items) obj.rule = rule obj.tag = tag return obj @@ -355,7 +355,7 @@ def __new__(cls, items, default=None): x = items[0] if is_integer(x): - # E.g., (32, 4, 8) + # E.g., 32 items = (ParTileArg(items),) elif x is None: diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index 198bf2fb0b5..d8f92e0cf29 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -216,7 +216,7 @@ def __reduce__(self): args = tuple() else: args = self.default_factory, - return type(self), args, None, None, self.items() + return type(self), args, None, None, self() def copy(self): return self.__copy__() @@ -639,7 +639,7 @@ def __hash__(self): return self._hash -class UnboundTuple(object): +class UnboundTuple(tuple): """ An UnboundedTuple is a tuple that can be infinitely iterated over. @@ -657,30 +657,37 @@ class UnboundTuple(object): UnboundTuple(3, 4) """ - def __init__(self, *items): + def __new__(cls, *items, **kwargs): nitems = [] for i in as_tuple(items): - if isinstance(i, Iterable): + if isinstance(i, UnboundTuple): + nitems.append(i) + elif isinstance(i, Iterable): nitems.append(UnboundTuple(*i)) elif i is not None: nitems.append(i) - self.items = tuple(nitems) - self.last = len(self.items) - self.current = 0 + obj = super().__new__(cls, tuple(nitems)) + obj.last = len(nitems) + obj.current = 0 + + return obj @property def default(self): - return self.items[0] + return self[0] @property def prod(self): - return np.prod(self.items) + return np.prod(self) + + def iter(self): + self.current = 0 def next(self): if self.last == 0: return None - item = self.items[self.current] + item = self[self.current] if self.current == self.last-1 or self.current == -1: self.current = -1 else: @@ -691,7 +698,7 @@ def __len__(self): return self.last def __repr__(self): - sitems = [s.__repr__() for s in self.items] + sitems = [s.__repr__() for s in self] return "%s(%s)" % (self.__class__.__name__, ", ".join(sitems)) def __getitem__(self, idx): @@ -704,9 +711,9 @@ def __getitem__(self, idx): return UnboundTuple(*[self[i] for i in range(start, stop, step)]) try: if idx >= self.last-1: - return self.items[self.last-1] + return super().__getitem__(self.last-1) else: - return self.items[idx] + return super().__getitem__(idx) except TypeError: # Slice, ... return UnboundTuple(self[i] for i in idx) @@ -744,27 +751,28 @@ class UnboundedMultiTuple(UnboundTuple): 3 """ - def __init__(self, *items): - super().__init__(*items) - self.current = -1 + def __new__(cls, *items, **kwargs): + obj = super().__new__(cls, *items, **kwargs) + obj.current = -1 + return obj @property def curitem(self): - return self.items[self.tip] + return self[self.current] @property def nextitem(self): - return self.items[min(self.tip + 1, max(len(self.items) - 1, 0))] + return self[min(self.current + 1, max(self.last - 1, 0))] def index(self, item): - return self.items.index(item) + return self.index(item) def iter(self): self.current = min(self.current + 1, self.last - 1) - self.items[self.current].current = 0 + self[self.current].current = 0 return def next(self): - if self.items[self.current].current == -1: + if self[self.current].current == -1: raise StopIteration - return self.items[self.current].next() + return self[self.current].next() diff --git a/tests/test_dle.py b/tests/test_dle.py index bfcb732bb40..7f505ab920b 100644 --- a/tests/test_dle.py +++ b/tests/test_dle.py @@ -363,6 +363,7 @@ def test_custom_rule0(self): # Check generated code. By having specified "1" as rule, we expect the # given par-tile to be applied to the kernel with id 1 + from IPython import embed; embed() bns, _ = assert_blocking(op, {'z0_blk0', 'x1_blk0', 'z2_blk0'}) root = bns['x1_blk0'] iters = FindNodes(Iteration).visit(root)