Skip to content

Commit

Permalink
Merge pull request #627 from mtsokol/sparse-sort-and-take
Browse files Browse the repository at this point in the history
API: Add `sort` and `take` functions for COO format
  • Loading branch information
mtsokol authored Jan 16, 2024
2 parents 82fb0d5 + 5bd29ad commit 16b46cd
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 35 deletions.
4 changes: 4 additions & 0 deletions docs/generated/sparse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ API

save_npz

sort

squeeze

stack
Expand All @@ -152,6 +154,8 @@ API

sum

take

tensordot

tril
Expand Down
4 changes: 4 additions & 0 deletions sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@
nansum,
result_type,
roll,
sort,
take,
tril,
triu,
unique_counts,
Expand Down Expand Up @@ -283,13 +285,15 @@
"sign",
"sin",
"sinh",
"sort",
"sqrt",
"square",
"squeeze",
"stack",
"std",
"subtract",
"sum",
"take",
"tan",
"tanh",
"tensordot",
Expand Down
4 changes: 4 additions & 0 deletions sparse/_coo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
nansum,
result_type,
roll,
sort,
stack,
take,
tril,
triu,
unique_counts,
Expand Down Expand Up @@ -51,7 +53,9 @@
"nansum",
"result_type",
"roll",
"sort",
"stack",
"take",
"tril",
"triu",
"unique_counts",
Expand Down
213 changes: 178 additions & 35 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
from collections.abc import Iterable
from functools import reduce
from typing import NamedTuple, Optional, Tuple
from typing import Any, NamedTuple, Optional, Tuple

import numba

Expand Down Expand Up @@ -1090,14 +1090,8 @@ def expand_dims(x, /, *, axis=0):
(1, 6, 1)
"""
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)
x = _validate_coo_input(x)

if not isinstance(axis, int):
raise IndexError(f"Invalid axis position: {axis}")
Expand All @@ -1109,6 +1103,8 @@ def expand_dims(x, /, *, axis=0):
new_shape.insert(axis, 1)
new_shape = tuple(new_shape)

from .core import COO

return COO(
new_coords,
x.data,
Expand Down Expand Up @@ -1140,14 +1136,8 @@ def flip(x, /, *, axis=None):
relative to ``x``, are reordered.
"""
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)
x = _validate_coo_input(x)

if axis is None:
axis = range(x.ndim)
Expand All @@ -1158,6 +1148,8 @@ def flip(x, /, *, axis=None):
for ax in axis:
new_coords[ax, :] = x.shape[ax] - 1 - x.coords[ax, :]

from .core import COO

return COO(
new_coords,
x.data,
Expand Down Expand Up @@ -1203,14 +1195,8 @@ def unique_counts(x, /):
>>> sparse.unique_counts(x)
UniqueCountsResult(values=array([-3, 0, 1, 2]), counts=array([1, 1, 2, 2]))
"""
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)
x = _validate_coo_input(x)

x = x.flatten()
values, counts = np.unique(x.data, return_counts=True)
Expand Down Expand Up @@ -1250,6 +1236,116 @@ def unique_values(x, /):
>>> sparse.unique_values(x)
array([-3, 0, 1, 2])
"""

x = _validate_coo_input(x)

x = x.flatten()
values = np.unique(x.data)
if x.nnz < x.size:
values = np.sort(np.concatenate([[x.fill_value], values]))
return values


def sort(x, /, *, axis=-1, descending=False):
    """
    Returns a sorted copy of an input array ``x``.

    Parameters
    ----------
    x : SparseArray
        Input array. Should have a real-valued data type.
    axis : int
        Axis along which to sort. If set to ``-1``, the function must sort along
        the last axis. Default: ``-1``. For one-dimensional input the only axis
        is always used.
    descending : bool
        Sort order. If ``True``, the array must be sorted in descending order (by value).
        If ``False``, the array must be sorted in ascending order (by value).
        Default: ``False``.

    Returns
    -------
    out : COO
        A sorted array with the same shape as ``x``.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.

    Examples
    --------
    >>> import sparse
    >>> x = sparse.COO.from_numpy([1, 0, 2, 0, 2, -3])
    >>> sparse.sort(x).todense()
    array([-3,  0,  0,  1,  2,  2])
    >>> sparse.sort(x, descending=True).todense()
    array([ 2,  2,  1,  0,  0, -3])
    """

    from .._common import moveaxis
    from .core import COO

    x = _validate_coo_input(x)

    # The sorting kernel works on two coordinate rows (group, position), so a
    # 1-D input temporarily gets a leading helper axis of length 1.
    was_1d = x.ndim == 1
    if was_1d:
        x = x[None, :]
        axis = -1

    # Bring the sort axis last and collapse every other axis into one "group" axis.
    x = moveaxis(x, source=axis, destination=-1)
    x_shape = x.shape
    x = x.reshape((-1, x_shape[-1]))

    new_coords, new_data = _sort_coo(x.coords, x.data, x.fill_value, sort_axis_len=x_shape[-1], descending=descending)

    x = COO(new_coords, new_data, x.shape, has_duplicates=False, sorted=True, fill_value=x.fill_value)

    # Undo the collapse and put the sorted axis back where it came from.
    x = x.reshape(x_shape)
    x = moveaxis(x, source=-1, destination=axis)

    if was_1d:
        # Drop only the helper axis. A bare ``squeeze()`` would also remove the
        # data axis when it has length 1 and return a 0-D array.
        x = x.reshape(x_shape[1:])

    return x


def take(x, indices, /, *, axis=None):
    """
    Selects elements of an array along an axis.

    Parameters
    ----------
    x : SparseArray
        Input array.
    indices : ndarray
        Array indices. The array must be one-dimensional and have an integer data type.
    axis : int
        Axis over which to select values. A negative ``axis`` counts from the
        last dimension. With ``None`` the input is flattened first and the
        indices address the flattened array. Default: ``None``.

    Returns
    -------
    out : COO
        A COO array with requested indices.

    Raises
    ------
    ValueError
        If the input array isn't and can't be converted to COO format.
    """

    x = _validate_coo_input(x)

    if axis is not None:
        axis = normalize_axis(axis, x.ndim)
        # Full slices for every axis before ``axis``, then the requested
        # indices; the ellipsis keeps all remaining axes untouched.
        leading = (slice(None),) * axis
        return x[leading + (indices, ...)]

    return x.flatten()[indices]


def _validate_coo_input(x: Any):
from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
Expand All @@ -1259,11 +1355,65 @@ def unique_values(x, /):
elif not isinstance(x, COO):
x = x.asformat(COO)

x = x.flatten()
values = np.unique(x.data)
if x.nnz < x.size:
values = np.sort(np.concatenate([[x.fill_value], values]))
return values
return x


@numba.jit(nopython=True, nogil=True)
def _sort_coo(
    coords: np.ndarray,
    data: np.ndarray,
    fill_value: float,
    sort_axis_len: int,
    descending: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    # Sorts the stored values of a two-row coordinate array group by group.
    #
    # coords        : array with exactly two rows; row 0 holds the group
    #                 coordinate of each stored value, row 1 its position
    #                 along the axis being sorted.
    # data          : stored values, aligned with the columns of ``coords``.
    # fill_value    : value of the entries that are not stored.
    # sort_axis_len : length of the axis being sorted, i.e. stored plus
    #                 unstored entries per group.
    # descending    : sort each group in descending instead of ascending order.
    #
    # Returns a new coordinate array (the original group row stacked on the
    # recomputed positions) and a sorted copy of ``data``; ``data`` itself is
    # not modified.
    #
    # NOTE(review): group detection only compares neighbouring entries, so this
    # assumes entries of one group are contiguous in ``coords`` -- confirm
    # callers always pass group-sorted coordinates.
    assert coords.shape[0] == 2
    group_coords = coords[0, :]
    sort_coords = coords[1, :]

    data = data.copy()
    result_indices = np.empty_like(sort_coords)

    # We iterate through all groups and sort each one of them.
    # first and last index of a group is tracked.
    prev_group = -1
    group_first_idx = -1
    group_last_idx = -1
    # We add `-1` sentinel to know when the last group ends
    for idx, group in enumerate(np.append(group_coords, -1)):
        if group == prev_group:
            continue

        # ``prev_group`` is still the initial -1 only on the first boundary,
        # where there is no finished group to process yet.
        if prev_group != -1:
            group_last_idx = idx

            group_slice = slice(group_first_idx, group_last_idx)
            group_size = group_last_idx - group_first_idx

            # SORT VALUES
            if group_size > 1:
                # np.sort in numba doesn't support `np.sort`'s arguments so `stable`
                # keyword can't be supported.
                # https://numba.pydata.org/numba-doc/latest/reference/numpysupported.html#other-methods
                data[group_slice] = np.sort(data[group_slice])
                if descending:
                    data[group_slice] = data[group_slice][::-1]

            # SORT INDICES
            # Stored values get positions 0..group_size-1; the unstored
            # (fill-valued) entries of the group occupy one contiguous run.
            fill_value_count = sort_axis_len - group_size
            indices = np.arange(group_size)
            # find a place where fill_value would be
            for pos in range(group_size):
                if (not descending and fill_value < data[group_slice][pos]) or (
                    descending and fill_value > data[group_slice][pos]
                ):
                    # Everything from the first value that sorts after the
                    # fill value is shifted past the run of fill values.
                    indices[pos:] += fill_value_count
                    break
            result_indices[group_first_idx:group_last_idx] = indices

        prev_group = group
        group_first_idx = idx

    return np.vstack((group_coords, result_indices)), data


@numba.jit(nopython=True, nogil=True)
Expand Down Expand Up @@ -1323,14 +1473,7 @@ def _arg_minmax_common(
assert mode in ("max", "min")
max_mode_flag = mode == "max"

from .core import COO

if isinstance(x, scipy.sparse.spmatrix):
x = COO.from_scipy_sparse(x)
elif not isinstance(x, SparseArray):
raise ValueError(f"Input must be an instance of SparseArray, but it's {type(x)}.")
elif not isinstance(x, COO):
x = x.asformat(COO)
x = _validate_coo_input(x)

if not isinstance(axis, (int, type(None))):
raise ValueError(f"`axis` must be `int` or `None`, but it's: {type(axis)}.")
Expand Down
Loading

0 comments on commit 16b46cd

Please sign in to comment.