Skip to content

Commit

Permalink
Merge pull request #36 from anant15/device_inception_resent_v1
Browse files Browse the repository at this point in the history
Add: device option to Inception_resnet_v1
  • Loading branch information
timesler authored Nov 10, 2019
2 parents 5a38bf5 + b12db1a commit 95c737f
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions models/inception_resnet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def forward(self, x):


class Block35(nn.Module):

def __init__(self, scale=1.0):
super().__init__()

Expand Down Expand Up @@ -185,7 +185,7 @@ class InceptionResnetV1(nn.Module):
datasets. Pretrained state_dicts are automatically downloaded on model instantiation if
requested and cached in the torch cache. Subsequent instantiations use the cache rather than
redownloading.
Keyword Arguments:
pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or 'casia-webface'.
(default: {None})
Expand All @@ -196,7 +196,7 @@ class InceptionResnetV1(nn.Module):
initialized. (default: {None})
dropout_prob {float} -- Dropout probability. (default: {0.6})
"""
def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6):
def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
super().__init__()

# Set simple attributes
Expand All @@ -213,6 +213,7 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
else:
tmp_classes = self.num_classes


# Define layers
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
Expand Down Expand Up @@ -258,16 +259,21 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr

if pretrained is not None:
load_weights(self, pretrained)

if self.num_classes is not None:
self.logits = nn.Linear(512, self.num_classes)

self.device = torch.device('cpu')
if device is not None:
self.device = device
self.to(device)

def forward(self, x):
"""Calculate embeddings or probabilities given a batch of input image tensors.
Arguments:
x {torch.tensor} -- Batch of image tensors representing faces.
Returns:
torch.tensor -- Batch of embeddings or softmax probabilities.
"""
Expand Down Expand Up @@ -296,11 +302,11 @@ def forward(self, x):

def load_weights(mdl, name):
"""Download pretrained state_dict and load into model.
Arguments:
mdl {torch.nn.Module} -- Pytorch model.
name {str} -- Name of dataset that was used to generate pretrained state_dict.
Raises:
ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
"""
Expand All @@ -315,7 +321,7 @@ def load_weights(mdl, name):

model_dir = os.path.join(get_torch_home(), 'checkpoints')
os.makedirs(model_dir, exist_ok=True)

state_dict = {}
for i, path in enumerate([features_path, logits_path]):
cached_file = os.path.join(model_dir, '{}_{}.pt'.format(name, path[-10:]))
Expand All @@ -327,7 +333,7 @@ def load_weights(mdl, name):
with open(cached_file, 'wb') as f:
f.write(r.content)
state_dict.update(torch.load(cached_file))

mdl.load_state_dict(state_dict)


Expand Down

0 comments on commit 95c737f

Please sign in to comment.