Skip to content

Commit

Permalink
change device resolution order (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
mberr authored May 9, 2022
1 parent d2a4897 commit c16e2b1
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/torch_ppr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,16 @@ def personalized_page_rank(
:return: shape: ``(k, n)``
the PPR vectors for each node index
"""
# resolve device first
device = resolve_device(device=device)
# prepare adjacency and indices only once
adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index)
adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index).to(device=device)
if indices is None:
indices = torch.arange(adj.shape[0], device=device)
else:
indices = torch.as_tensor(indices, dtype=torch.long, device=device)
# normalize inputs
batch_size = batch_size or len(indices)
device = resolve_device(device=device)
return batched_personalized_page_rank(
adj=adj, indices=indices, device=device, batch_size=batch_size, **kwargs
).t()

0 comments on commit c16e2b1

Please sign in to comment.