
Commit

implemented an efficient (progressive) way to have drop_edge in the model
MartinXPN committed Feb 3, 2021
1 parent 5ce7427 commit 67aa6c7
Showing 3 changed files with 65 additions and 10 deletions.
46 changes: 46 additions & 0 deletions abcde/dropout.py
@@ -0,0 +1,46 @@
import torch
from torch_geometric.utils import degree
from torch_geometric.utils.dropout import filter_adj
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import coalesce


def dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False,
                num_nodes=None, training=True):
    if p < 0. or p > 1.:
        raise ValueError('Dropout probability has to be between 0 and 1, '
                         'but got {}'.format(p))

    if not training:
        return edge_index, edge_attr

    N = maybe_num_nodes(edge_index, num_nodes)
    row, col = edge_index

    if force_undirected:
        row, col, edge_attr = filter_adj(row, col, edge_attr, row < col)

    # Mask for which edges to keep
    mask = edge_index.new_full((row.size(0), ), 1 - p, dtype=torch.float)
    mask = torch.bernoulli(mask).to(torch.bool)
    row_deg, col_deg = degree(row), degree(col)

    # Always keep edges that touch a low-degree (< 5) endpoint, so that
    # sparsely connected nodes are not cut off by the random edge drop
    # initial_keep = torch.sum(mask)
    mask |= row_deg[row] < 5
    mask |= col_deg[col] < 5
    # print(f'Total #edges {edge_index.size()} Initially planned to keep {initial_keep} edges, Eventually kept {torch.sum(mask)}')

    # return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]
    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)

    if force_undirected:
        edge_index = torch.stack(
            [torch.cat([row, col], dim=0),
             torch.cat([col, row], dim=0)], dim=0)
        if edge_attr is not None:
            edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
        edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
    else:
        edge_index = torch.stack([row, col], dim=0)

    return edge_index, edge_attr
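
A minimal usage sketch (not part of the commit; the toy 4-node cycle below is made up for illustration). dropout_adj drops each edge with probability p, but any edge whose endpoint has degree below 5 is always kept, so sparsely connected parts of the graph survive the dropout:

import torch
from abcde.dropout import dropout_adj

# Hypothetical toy graph: a 4-node cycle with both directions listed, as torch_geometric expects
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 0],
                           [1, 0, 2, 1, 3, 2, 0, 3]])
# Every node here has degree < 5, so all edges are kept no matter what p is
kept, _ = dropout_adj(edge_index, p=0.5, force_undirected=True, training=True)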
17 changes: 11 additions & 6 deletions abcde/models.py
@@ -1,6 +1,6 @@
import inspect
from typing import List, Dict, Tuple

import inspect
import numpy as np
import pytorch_lightning as pl
import torch
@@ -11,6 +11,7 @@
from torch_geometric.data import Batch
from torch_geometric.nn import GCNConv, LayerNorm

from abcde.dropout import dropout_adj
from abcde.loss import PairwiseRankingCrossEntropyLoss
from abcde.metrics import kendall_tau, top_k_ranking_accuracy

@@ -120,14 +121,17 @@ def forward(self, inputs):


class ABCDE(BetweennessCentralityEstimator):
    def __init__(self, nb_gcn_cycles: Tuple[int, ...], conv_sizes: Tuple[int, ...],
    def __init__(self, nb_gcn_cycles: Tuple[int, ...], conv_sizes: Tuple[int, ...], drops: Tuple[float, ...],
                 lr_reduce_patience: int = 1, dropout: float = 0.):
        super().__init__(lr_reduce_patience=lr_reduce_patience)
        print('gcn cycles:', nb_gcn_cycles)
        print('conv sizes:', conv_sizes)
        print('drops:', drops)
        assert len(nb_gcn_cycles) == len(conv_sizes) == len(drops)
        self.save_hyperparameters()
        self.nb_gcn_cycles: Tuple[int, ...] = nb_gcn_cycles
        self.conv_sizes: Tuple[int, ...] = conv_sizes
        self.drops: Tuple[float, ...] = drops
        self.dropout: float = dropout

        self.node_mlp = nn.Sequential(
@@ -156,24 +160,25 @@ def __init__(self, nb_gcn_cycles: Tuple[int, ...], conv_sizes: Tuple[int, ...],
                ) for _ in range(gcn_cycles)
            ]))

        print(f'Largest Linear: {transition_size} x 32')
        self.out_mlp = nn.Sequential(
            nn.Linear(transition_size, 32),
            nn.LeakyReLU(negative_slope=0.3),
            LayerNorm(32),
            nn.LeakyReLU(negative_slope=0.3),
            nn.Dropout(self.dropout),
            nn.Linear(32, 1),
        )

    def forward(self, inputs):
        node_features, edge_index = inputs.x, inputs.edge_index
        # drop_edge, _ = dropout_adj(edge_index, p=0.3, force_undirected=True, training=self.training)
        prev_block_out = self.node_mlp(node_features)

        for transition, convolutions in zip(self.transitions, self.conv_blocks):
        for transition, convolutions, drop in zip(self.transitions, self.conv_blocks, self.drops):
            drop_edge, _ = dropout_adj(edge_index, p=drop, force_undirected=True, training=self.training)
            x = transition(prev_block_out)
            states = [x]
            for conv in convolutions:
                x = conv(x, edge_index=edge_index)
                x = conv(x, edge_index=drop_edge)
                states.append(x)
            x = torch.amax(torch.stack(states), dim=0)
            prev_block_out = torch.cat([x, prev_block_out], dim=-1)
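
To see the progressive scheme in isolation, a hedged sketch (toy graph and sizes made up, not code from the commit): each convolution block re-samples its own thinned edge set with its block-specific probability, so early blocks train on heavily pruned graphs while the last block, with drop probability 0, sees the full edge set, mirroring the drops=(0.5, 0.4, 0.3, 0.2, 0.1, 0) schedule passed in train.py below:

import torch
from abcde.dropout import dropout_adj

edge_index = torch.randint(0, 1000, (2, 20000))  # hypothetical random graph with 1000 nodes
drops = (0.5, 0.4, 0.3, 0.2, 0.1, 0)             # same schedule as in train.py
for block_id, p in enumerate(drops):
    # Each block gets a freshly sampled, progressively denser view of the graph
    drop_edge, _ = dropout_adj(edge_index, p=p, force_undirected=True, training=True)
    print(f'block {block_id}: p={p}, edges fed to the convolutions: {drop_edge.size(1)}')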
12 changes: 8 additions & 4 deletions train.py
@@ -10,7 +10,8 @@

# Fix the seed for reproducibility
fix_random_seed(42)
experiment = ExperimentSetup(name='deep', create_latest=True, long_description="""
experiment = ExperimentSetup(name='drop', create_latest=True, long_description="""
Try dropping edges while training
Graphs are only of 'powerlaw' type.
Use unique convolutions.
Use blocks of convolutions followed with max pooling and skip connections
@@ -23,8 +24,11 @@
    TensorBoardLogger(experiment.log_dir, name=experiment.name, default_hp_metric=False),
    # AimLogger(experiment=experiment.name),
]
model = ABCDE(nb_gcn_cycles=(4, 4, 6, 6, 8),
              conv_sizes=(64, 64, 32, 32, 16), lr_reduce_patience=2, dropout=0.1)
# For previous best changes needed: conv_sizes=(64, 64, 32, 32, 16, 16), drops=(0, 0, 0, 0, 0, 0)
model = ABCDE(nb_gcn_cycles=(4, 4, 6, 6, 8, 8),
              conv_sizes=(48, 48, 32, 32, 24, 24),
              drops=(0.5, 0.4, 0.3, 0.2, 0.1, 0),
              lr_reduce_patience=2, dropout=0.1)
data = GraphDataModule(min_nodes=4000, max_nodes=5000, nb_train_graphs=160, nb_valid_graphs=240,
                       batch_size=16, graph_type='powerlaw', regenerate_epoch_interval=10,
                       repeats=8)
@@ -34,7 +38,7 @@
                     reload_dataloaders_every_epoch=True,
                     callbacks=[
                         EarlyStopping(monitor='val_kendal', patience=5, verbose=True, mode='max'),
                         ModelCheckpoint(dirpath=experiment.model_save_path, filename='deep-{epoch:02d}-{val_kendal:.2f}', monitor='val_kendal', save_top_k=5, verbose=True, mode='max'),
                         ModelCheckpoint(dirpath=experiment.model_save_path, filename='drop-{epoch:02d}-{val_kendal:.2f}', monitor='val_kendal', save_top_k=5, verbose=True, mode='max'),
                         LearningRateMonitor(logging_interval='step'),
                     ])
trainer.fit(model, datamodule=data)
