Skip to content

Commit

Permalink
deeper drop edge network
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinXPN committed Feb 6, 2021
1 parent d6b383d commit e6ff617
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion abcde/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __post_init__(self, save_path):
workers: int = max(os.cpu_count(), 1)

if save_path is not None and save_path.exists():
print('Found existing dataset at', save_path, 'loading it...')
print(f'Found existing dataset at {save_path} => loading it...')
self.graphs = torch.load(save_path)
return

Expand Down
2 changes: 1 addition & 1 deletion abcde/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def dropout_adj(edge_index, edge_attr=None, p=0.5, force_undirected=False,
raise ValueError('Dropout probability has to be between 0 and 1, '
'but got {}'.format(p))

if not training:
if not training or p == 0:
return edge_index, edge_attr

N = maybe_num_nodes(edge_index, num_nodes)
Expand Down
9 changes: 5 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
AimLogger(experiment=experiment.name),
]
# Previous best: nb_gcn_cycles=(4, 4, 6, 6, 8), conv_sizes=(64, 64, 32, 32, 16), drops=(0, 0, 0, 0, 0)
model = ABCDE(nb_gcn_cycles=(4, 4, 6, 6, 8),
conv_sizes=(64, 64, 32, 32, 16),
drops=(0, 0, 0, 0, 0),
model = ABCDE(nb_gcn_cycles=(4, 4, 6, 6, 8, 8),
conv_sizes=(64, 48, 32, 32, 24, 24),
drops=(0.4, 0.3, 0.2, 0.2, 0.1, 0.1),
lr_reduce_patience=2, dropout=0.1)
data = GraphDataModule(min_nodes=4000, max_nodes=5000, nb_train_graphs=160, nb_valid_graphs=240,
batch_size=16, graph_type='powerlaw', repeats=8, regenerate_epoch_interval=10,
Expand All @@ -41,8 +41,9 @@
max_epochs=100, terminate_on_nan=True, enable_pl_optimizer=True,
reload_dataloaders_every_epoch=True,
callbacks=[
EarlyStopping(monitor='val_kendal', patience=7, verbose=True, mode='max'),
EarlyStopping(monitor='val_kendal', patience=5, verbose=True, mode='max'),
ModelCheckpoint(dirpath=experiment.model_save_path, filename='drop-{epoch:02d}-{val_kendal:.2f}', monitor='val_kendal', save_top_k=5, verbose=True, mode='max'),
LearningRateMonitor(logging_interval='epoch'),
])
trainer.fit = telegram_sender(token='1653878275:AAEIr-mLt9-SSAyYPon1n-CgFQpINjUWHDw', chat_id=695404691)(trainer.fit)
trainer.fit(model, datamodule=data)

0 comments on commit e6ff617

Please sign in to comment.