From c16e2b10ae2f4f45118618f9e59200bf9da3eb25 Mon Sep 17 00:00:00 2001 From: Max Berrendorf Date: Mon, 9 May 2022 15:36:44 +0200 Subject: [PATCH] change device resolution order (#5) --- src/torch_ppr/api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torch_ppr/api.py b/src/torch_ppr/api.py index 401b8b6..08536d1 100644 --- a/src/torch_ppr/api.py +++ b/src/torch_ppr/api.py @@ -114,15 +114,16 @@ def personalized_page_rank( :return: shape: ``(k, n)`` the PPR vectors for each node index """ + # resolve device first + device = resolve_device(device=device) # prepare adjacency and indices only once - adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index) + adj = prepare_page_rank_adjacency(adj=adj, edge_index=edge_index).to(device=device) if indices is None: indices = torch.arange(adj.shape[0], device=device) else: indices = torch.as_tensor(indices, dtype=torch.long, device=device) # normalize inputs batch_size = batch_size or len(indices) - device = resolve_device(device=device) return batched_personalized_page_rank( adj=adj, indices=indices, device=device, batch_size=batch_size, **kwargs ).t()