diff --git a/src/torch_ppr/utils.py b/src/torch_ppr/utils.py index 9aa2eed..cf27ddb 100644 --- a/src/torch_ppr/utils.py +++ b/src/torch_ppr/utils.py @@ -143,7 +143,7 @@ def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None): else: # hotfix until torch.sparse.sum is implemented adj_sum = adj.t() @ torch.ones(adj.shape[0]) - if not torch.allclose(adj_sum, torch.ones_like(adj_sum)): + if not torch.allclose(adj_sum, torch.ones_like(adj_sum), rtol=1.0e-04): raise ValueError(f"Invalid column sum: {adj_sum}. expected 1.0")