Skip to content

Commit

Permalink
Improve input validation (#10)
Browse files Browse the repository at this point in the history
* make public methods keyword-only
* move validation into public method

* check dtype
also emit a warning if the adjacency looks like an edge index
* check value range

* add tests for adjacency validation
* fix seed for reproducible tests
  • Loading branch information
mberr authored May 12, 2022
1 parent 3daee4c commit d4a7498
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 9 deletions.
7 changes: 7 additions & 0 deletions src/torch_ppr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
prepare_page_rank_adjacency,
prepare_x0,
resolve_device,
validate_adjacency,
validate_x,
)

Expand All @@ -26,6 +27,7 @@


def page_rank(
*,
adj: Optional[torch.Tensor] = None,
edge_index: Optional[torch.LongTensor] = None,
num_nodes: Optional[int] = None,
Expand Down Expand Up @@ -68,6 +70,8 @@ def page_rank(
"""
# normalize inputs
adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index, num_nodes=num_nodes)
validate_adjacency(adj=adj)

x0 = prepare_x0(x0=x0, n=adj.shape[0])

# input normalization
Expand All @@ -89,6 +93,7 @@ def page_rank(


def personalized_page_rank(
*,
adj: Optional[torch.Tensor] = None,
edge_index: Optional[torch.LongTensor] = None,
num_nodes: Optional[int] = None,
Expand Down Expand Up @@ -129,6 +134,8 @@ def personalized_page_rank(
adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index, num_nodes=num_nodes).to(
device=device
)
validate_adjacency(adj=adj)

if indices is None:
indices = torch.arange(adj.shape[0], device=device)
else:
Expand Down
22 changes: 20 additions & 2 deletions src/torch_ppr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,29 @@ def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None):
:raises ValueError:
if the adjacency matrix is invalid
"""
# check dtype
if not torch.is_floating_point(adj):
if adj.shape[0] == 2 and adj.shape[1] != 2:
logger.warning(
"The passed adjacency matrix looks like an edge_index; did you pass it for the wrong parameter?"
)
raise ValueError(
f"Invalid adjacency matrix data type: {adj.dtype}, should be a floating dtype."
)

# check shape
if n is None:
n = adj.shape[0]
if adj.shape != (n, n):
raise ValueError(f"Invalid shape: {adj.shape}. expected: {(n, n)}")
raise ValueError(f"Invalid adjacency matrix shape: {adj.shape}. expected: {(n, n)}")

# check value range
adj = adj.coalesce()
values = adj.values()
if (values < 0.0).any() or (values > 1.0).any():
raise ValueError(
f"Invalid values outside of [0, 1]: min={values.min().item()}, max={values.max().item()}"
)

# check column-sum
adj_sum = torch.sparse.sum(adj, dim=0).to_dense()
Expand Down Expand Up @@ -149,7 +167,7 @@ def prepare_page_rank_adjacency(
If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``.
:raises ValueError:
if neither is provided
if neither is provided, or the adjacency matrix is invalid
:return: shape: ``(n, n)``
the symmetric, normalized, and sparse adjacency matrix
Expand Down
21 changes: 15 additions & 6 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from torch_ppr import api
from torch_ppr import page_rank, personalized_page_rank, utils


class APITest(unittest.TestCase):
Expand All @@ -23,21 +23,30 @@ def setUp(self) -> None:
],
dim=-1,
)
self.adj = utils.prepare_page_rank_adjacency(edge_index=self.edge_index)

def test_page_rank(self):
def test_page_rank_edge_index(self):
"""Test Page Rank calculation for an adjacency given as edge list."""
page_rank(edge_index=self.edge_index)

def test_page_rank_adj(self):
"""Test Page Rank calculation."""
api.page_rank(edge_index=self.edge_index)
page_rank(adj=self.adj)

def test_personalized_page_rank_edge_index(self):
"""Test Personalized Page Rank calculation for an adjacency given as edge list."""
personalized_page_rank(edge_index=self.edge_index)

def test_personalized_page_rank(self):
def test_personalized_page_rank_adj(self):
"""Test Personalized Page Rank calculation."""
api.personalized_page_rank(edge_index=self.edge_index)
personalized_page_rank(adj=self.adj)

def test_page_rank_manual(self):
"""Test Page Rank calculation on a simple manually created example."""
# A - B - C
# |
# D
edge_index = torch.as_tensor(data=[(0, 1), (1, 2), (1, 3)]).t()
x = api.page_rank(edge_index=edge_index)
x = page_rank(edge_index=edge_index)
# verify that central node has the largest PR value
assert x.argmax() == 1
35 changes: 35 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class UtilsTest(unittest.TestCase):

def setUp(self) -> None:
"""Prepare data."""
# fix seed for reproducible tests
torch.manual_seed(seed=42)
self.edge_index = torch.cat(
[
torch.randint(self.num_nodes, size=(2, self.num_edges - self.num_nodes)),
Expand Down Expand Up @@ -60,6 +62,39 @@ def test_edge_index_to_sparse_matrix(self):
)
assert adj.shape == (self.num_nodes, self.num_nodes)

def test_validate_adjacancy(self):
    """Test adjacency validation.

    A prepared adjacency matrix has to pass validation, both with the size
    inferred from the matrix and with the size given explicitly. Each of the
    malformed inputs below has to be rejected with a ``ValueError``.
    """
    adj = utils.prepare_page_rank_adjacency(edge_index=self.edge_index)
    # plain validation with shape inference
    utils.validate_adjacency(adj=adj)
    # plain validation with explicit shape
    utils.validate_adjacency(adj=adj, n=self.num_nodes)

    # test error raising; each case is labelled so that a failure names the
    # offending input instead of only pointing at the loop body
    invalid_inputs = {
        # an edge_index instead of adj
        "edge_index instead of adj": self.edge_index,
        # wrong shape
        "wrong shape": torch.sparse_coo_tensor(
            indices=torch.empty(2, 0, dtype=torch.long),
            values=torch.empty(0),
            size=(2, 3),
        ),
        # wrong value range
        "wrong value range": torch.sparse_coo_tensor(
            indices=self.edge_index,
            values=torch.full(size=(self.num_edges,), fill_value=2.0),
            size=(self.num_nodes, self.num_nodes),
        ),
        # wrong sum
        "wrong sum": torch.sparse_coo_tensor(
            indices=self.edge_index,
            values=torch.ones(self.num_edges),
            size=(self.num_nodes, self.num_nodes),
        ),
    }
    # use a separate loop name so the valid ``adj`` above is not shadowed, and
    # a sub-test per case so one missing error does not hide the other cases
    for reason, invalid_adj in invalid_inputs.items():
        with self.subTest(reason=reason), self.assertRaises(ValueError):
            utils.validate_adjacency(adj=invalid_adj)

def test_prepare_page_rank_adjacency(self):
"""Test adjacency preparation."""
for (adj, edge_index) in (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ def test_version_type(self):
This is only meant to be an example test.
"""
version = get_version()
version = get_version(with_git_hash=True)
self.assertIsInstance(version, str)

0 comments on commit d4a7498

Please sign in to comment.