Skip to content

Commit

Permalink
Outperformed the vanilla DrBC paper on Amazon with Kendall-Tau metric
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinXPN committed Jan 7, 2021
1 parent be81b98 commit b82d0c4
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 8 additions & 1 deletion abcde/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import numpy as np
import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -63,6 +65,11 @@ def validation_step(self, batch, batch_idx):
pred = self(batch).cpu().detach().numpy().flatten()
label = batch.y.cpu().detach().numpy().flatten()

# Vertices with deg(v) <= 1 (leafs) can't have high betweenness-centrality
degrees = batch.x.cpu().detach().numpy().flatten()
mask = degrees * len(degrees) < 1.1
pred[mask] = pred.min() - np.finfo(np.float32).eps

top_pred = np.argsort(-pred)
top_label = np.argsort(-label)
res = {
Expand All @@ -73,7 +80,7 @@ def validation_step(self, batch, batch_idx):
'val_mse': mean_squared_error(label, pred),
'val_max_error': max_error(label, pred)
}
self.log_dict(res)
self.log_dict(copy.copy(res))
return res

def configure_optimizers(self):
Expand Down
2 changes: 2 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def predict(model_path: Union[str, IO],
graph = Data(x=torch.from_numpy(degrees),
y=torch.from_numpy(label),
edge_index=torch.from_numpy(edge_index))
print('Graph:', graph)

res = model.validation_step(graph, batch_idx=0)
end = time.time()
res['run_time'] = end - start,
Expand Down

0 comments on commit b82d0c4

Please sign in to comment.