diff --git a/models/inception_resnet_v1.py b/models/inception_resnet_v1.py
index b9472282..83590bad 100644
--- a/models/inception_resnet_v1.py
+++ b/models/inception_resnet_v1.py
@@ -31,7 +31,7 @@ def forward(self, x):
 
 
 class Block35(nn.Module):
-    
+
     def __init__(self, scale=1.0):
         super().__init__()
 
@@ -185,7 +185,7 @@ class InceptionResnetV1(nn.Module):
     datasets. Pretrained state_dicts are automatically downloaded on model instantiation if
     requested and cached in the torch cache. Subsequent instantiations use the cache rather than
     redownloading.
-    
+
     Keyword Arguments:
         pretrained {str} -- Optional pretraining dataset. Either 'vggface2' or
             'casia-webface'. (default: {None})
@@ -196,7 +196,7 @@ class InceptionResnetV1(nn.Module):
             initialized. (default: {None})
         dropout_prob {float} -- Dropout probability. (default: {0.6})
     """
-    def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6):
+    def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_prob=0.6, device=None):
         super().__init__()
 
         # Set simple attributes
@@ -213,6 +213,7 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
         else:
             tmp_classes = self.num_classes
 
+
         # Define layers
         self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
         self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
@@ -258,16 +259,21 @@ def __init__(self, pretrained=None, classify=False, num_classes=None, dropout_pr
 
         if pretrained is not None:
             load_weights(self, pretrained)
-        
+
         if self.num_classes is not None:
             self.logits = nn.Linear(512, self.num_classes)
 
+        self.device = torch.device('cpu')
+        if device is not None:
+            self.device = device
+            self.to(device)
+
     def forward(self, x):
         """Calculate embeddings or probabilities given a batch of input image tensors.
-        
+
         Arguments:
             x {torch.tensor} -- Batch of image tensors representing faces.
-        
+
         Returns:
             torch.tensor -- Batch of embeddings or softmax probabilities.
         """
@@ -296,11 +302,11 @@ def forward(self, x):
 
 def load_weights(mdl, name):
     """Download pretrained state_dict and load into model.
-    
+
     Arguments:
         mdl {torch.nn.Module} -- Pytorch model.
         name {str} -- Name of dataset that was used to generate pretrained state_dict.
-    
+
     Raises:
         ValueError: If 'pretrained' not equal to 'vggface2' or 'casia-webface'.
     """
@@ -315,7 +321,7 @@ def load_weights(mdl, name):
 
     model_dir = os.path.join(get_torch_home(), 'checkpoints')
     os.makedirs(model_dir, exist_ok=True)
-    
+
     state_dict = {}
     for i, path in enumerate([features_path, logits_path]):
         cached_file = os.path.join(model_dir, '{}_{}.pt'.format(name, path[-10:]))
@@ -327,7 +333,7 @@ def load_weights(mdl, name):
             with open(cached_file, 'wb') as f:
                 f.write(r.content)
         state_dict.update(torch.load(cached_file))
-    
+
     mdl.load_state_dict(state_dict)
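For reviewers, a minimal usage sketch of the `device` keyword introduced by this patch (not part of the diff itself). It assumes the module path shown above (`models/inception_resnet_v1.py`) and the documented `'vggface2'` pretrained option; the 160x160 dummy batch is illustrative only and follows the usual FaceNet input size.

```python
# Usage sketch for the device-aware constructor added in this patch.
# Assumes the repo layout above; 'vggface2' is one of the two documented
# pretrained options, and the 160x160 input size is an assumption for the demo.
import torch

from models.inception_resnet_v1 import InceptionResnetV1

# Fall back to CPU when CUDA is unavailable.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Passing `device` stores it on the model and calls self.to(device) at
# construction time; omitting it keeps the previous CPU-only default.
resnet = InceptionResnetV1(pretrained='vggface2', device=device).eval()

# Dummy batch of four RGB face crops; embeddings come back as (4, 512).
faces = torch.randn(4, 3, 160, 160, device=device)
with torch.no_grad():
    embeddings = resnet(faces)
print(embeddings.shape)  # torch.Size([4, 512])
```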