-
Notifications
You must be signed in to change notification settings - Fork 7
/
hyperparameter.py
107 lines (91 loc) · 3.12 KB
/
hyperparameter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from functools import partial
from network import *
import torch.optim as optim
def get_hyperparameters(network_type):
    """Return the model and training hyperparameters for ``network_type``.

    Args:
        network_type: Name of the configuration to load. Accepted names are
            'mlp', 'lenet', 'conv6', 'vgg11', 'vgg16', 'vgg19', 'resnet18',
            'vgg19_64', 'resnet50_64', 'wrn-16-8_64', 'mlp_global' and
            'conv6_global'.

    Returns:
        A 6-tuple ``(network, prune_ratios, optimizer, pretrain_iteration,
        finetune_iteration, batch_size)`` where

        * ``network`` is the result of calling ``.cuda()`` on a freshly
          constructed model,
        * ``prune_ratios`` is a new list of numbers (built on every call, so
          callers may mutate it freely),
        * ``optimizer`` is ``functools.partial(optim.Adam, lr=...)``, i.e. a
          factory with the learning rate already bound rather than an
          optimizer instance; the caller supplies the remaining arguments,
        * ``pretrain_iteration`` and ``finetune_iteration`` are ints,
        * ``batch_size`` is 60 for every configuration.

    Raises:
        ValueError: If ``network_type`` is not one of the accepted names.
    """
    batch_size = 60  # shared by every configuration

    # name -> (model factory, prune_ratios, Adam lr, pretrain iters, finetune iters)
    # The factories are lambdas so that only the requested model is built.
    # NOTE(review): the '*_global' entries carry 7 ratios while the other
    # entries appear to carry one ratio per layer -- presumably they are
    # consumed differently by the caller; confirm before editing them.
    configs = {
        'mlp': (
            lambda: MaskedMLP([784, 500, 500, 500, 500, 10]),
            [.5, .5, .5, .5, .25], 0.0012, 50000, 50000,
        ),
        'lenet': (
            lambda: MaskedLeNet(),
            [.2, .2, .3, .3, .15], 0.0012, 50000, 50000,
        ),
        'conv6': (
            lambda: MaskedConv6(),
            [.15, .15, .15, .15, .15, .15, .20, .20, .10], 0.0003, 30000, 20000,
        ),
        'vgg11': (
            lambda: MaskedVGG11(use_bn=True),
            [.15] * 8 + [.10], 0.0003, 35000, 25000,
        ),
        'vgg16': (
            lambda: MaskedVGG16(use_bn=True),
            [.15] * 13 + [.10], 0.0003, 50000, 35000,
        ),
        'vgg19': (
            lambda: MaskedVGG19(use_bn=True),
            [.15] * 16 + [.10], 0.0003, 60000, 40000,
        ),
        'resnet18': (
            lambda: MaskedResNet18(),
            [0] + [.15] * 16 + [.10], 0.0003, 35000, 25000,
        ),
        'vgg19_64': (
            lambda: MaskedVGG19_64(use_bn=True),
            [.15] * 16 + [.10], 0.0003, 60000, 40000,
        ),
        'resnet50_64': (
            lambda: MaskedResNet50_64(),
            [0] + [.15] * 48 + [.10], 0.0003, 60000, 40000,
        ),
        'wrn-16-8_64': (
            lambda: MaskedWideResNet_64(16, 8),
            [0] + [.15] * 12 + [.10], 0.0003, 60000, 40000,
        ),
        'mlp_global': (
            lambda: MaskedMLP([784, 500, 500, 500, 500, 10]),
            [0.9364, 0.9679, 0.9837, 0.9916, 0.9957, 0.9977, 0.9988],
            0.0012, 50000, 50000,
        ),
        'conv6_global': (
            lambda: MaskedConv6(),
            [0.8938, 0.9114, 0.9261, 0.9382, 0.9483, 0.9568, 0.9638],
            0.0003, 30000, 20000,
        ),
    }

    try:
        entry = configs.get(network_type)
    except TypeError:
        # Unhashable arguments (e.g. a list) cannot be looked up; reject them
        # the same way as any other unrecognised value.
        entry = None
    if entry is None:
        raise ValueError(
            'Unknown network: {!r} (expected one of: {})'.format(
                network_type, ', '.join(configs)
            )
        )

    build_network, prune_ratios, learning_rate, pretrain_iteration, finetune_iteration = entry

    # Build the model first, then bind the learning rate into the Adam factory.
    network = build_network().cuda()
    optimizer = partial(optim.Adam, lr=learning_rate)
    return network, prune_ratios, optimizer, pretrain_iteration, finetune_iteration, batch_size