-
Notifications
You must be signed in to change notification settings - Fork 3
/
predict.py
177 lines (136 loc) · 6.93 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import argparse
import numpy as np
import matplotlib.pyplot as plt
import skimage.io
from keras.models import load_model
from constants import verbosity, save_dir, overlap, \
model_name, tests_path, input_width, input_height, scale_fact
from utils import float_im
# Command-line interface: the image to enhance, plus optional model / output file names.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    'image',
    help='image name (example: "bird.png") that must be inside the "./input/" folder',
    type=str,
)
parser.add_argument(
    '-m', '--model',
    help='model name (in the "./save/" folder), followed by ".h5"',
    type=str,
    default=model_name,
)
parser.add_argument(
    '-s', '--save',
    help='the name of the saved image which will appear inside the "output" folder',
    type=str,
    default='your_image.png',
)
# NOTE: arguments are parsed at import time, so importing this module reads sys.argv.
args = parser.parse_args()
def predict(args):
"""
Super-resolution on the input image using the model.
:param args:
:return:
'predictions' contains an array of every single cropped sub-image once enhanced (the outputs of the model).
'image' is the original image, untouched.
'crops' is the array of every single cropped sub-image that will be used as input to the model.
"""
model = load_model(save_dir + '/' + args.model)
image = skimage.io.imread(tests_path + args.image)[:, :, :3] # removing possible extra channels (Alpha)
print("Image shape:", image.shape)
predictions = []
images = []
# Padding and cropping the image
overlap_pad = (overlap, overlap) # padding tuple
pad_width = (overlap_pad, overlap_pad, (0, 0)) # assumes color channel as last
padded_image = np.pad(image, pad_width, 'constant') # padding the border
crops = seq_crop(padded_image) # crops into multiple sub-parts the image based on 'input_' constants
# Arranging the divided image into a single-dimension array of sub-images
for i in range(len(crops)): # amount of vertical crops
for j in range(len(crops[0])): # amount of horizontal crops
current_image = crops[i][j]
images.append(current_image)
print("Moving on to predictions. Amount:", len(images))
upscaled_overlap = overlap * 2
for p in range(len(images)):
if p % 3 == 0 and verbosity == 2:
print("--prediction #", p)
# Hack due to some GPUs that can only handle one image at a time
input_img = (np.expand_dims(images[p], 0)) # Add the image to a batch where it's the only member
pred = model.predict(input_img)[0] # returns a list of lists, one for each image in the batch
# Cropping the useless parts of the overlapped predictions (to prevent the repeated erroneous edge-prediction)
pred = pred[upscaled_overlap:pred.shape[0]-upscaled_overlap, upscaled_overlap:pred.shape[1]-upscaled_overlap]
predictions.append(pred)
return predictions, image, crops
def show_pred_output(input, pred):
    """
    Show the input image and the enhanced output side by side in one figure.

    Each panel is titled with the image size as "<width>x<height>", and its
    x-axis is hidden. The parameter name ``input`` shadows the builtin; it is
    kept for backward compatibility with keyword callers.

    :param input: the original image array.
    :param pred: the enhanced (reconstructed) image array.
    """
    plt.figure(figsize=(20, 20))
    plt.suptitle("Results")
    panels = (("Input", input), ("Output", pred))
    for position, (label, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        height, width = picture.shape[0], picture.shape[1]
        plt.title(label + " : " + str(width) + "x" + str(height))
        # cmap only matters for single-channel data; RGB arrays are drawn as-is
        drawn = plt.imshow(picture, cmap=plt.cm.binary)
        drawn.axes.get_xaxis().set_visible(False)
    plt.show()
# adapted from https://stackoverflow.com/a/52463034/9768291
def seq_crop(img):
"""
To crop the whole image in a list of sub-images of the same size.
Size comes from "input_" variables in the 'constants' (Evaluation).
Padding with 0 the Bottom and Right image.
:param img: input image
:return: list of sub-images with defined size (as per 'constants')
"""
sub_images = [] # will contain all the cropped sub-parts of the image
j, shifted_height = 0, 0
while shifted_height < (img.shape[0] - input_height):
horizontal = []
shifted_height = j * (input_height - overlap)
i, shifted_width = 0, 0
while shifted_width < (img.shape[1] - input_width):
shifted_width = i * (input_width - overlap)
horizontal.append(crop_precise(img,
shifted_width,
shifted_height,
input_width,
input_height))
i += 1
sub_images.append(horizontal)
j += 1
return sub_images
def crop_precise(img, coord_x, coord_y, width_length, height_length):
    """
    Cut a rectangular window out of an image and convert it with ``float_im``.

    The window starts at the top-left point (coord_x, coord_y). It spans
    `width_length` columns to the right and `height_length` rows downward.
    NumPy slicing truncates at the image border, so a window that runs past
    the edge comes back smaller. It is NOT zero-padded.

    :param img: image to crop (rows first, then columns).
    :param coord_x: column index of the top-left corner.
    :param coord_y: row index of the top-left corner.
    :param width_length: number of columns to take, starting at coord_x.
    :param height_length: number of rows to take, starting at coord_y.
    :return: the cropped window, passed through ``float_im``.
    """
    rows = slice(coord_y, coord_y + height_length)
    cols = slice(coord_x, coord_x + width_length)
    window = img[rows, cols]
    return float_im(window)  # From [0,255] to [0.,1.]
# adapted from https://stackoverflow.com/a/52733370/9768291
def reconstruct(predictions, crops, border=None):
    """
    Reconstruct a whole image from an array of mini-predictions.

    The image had to be split into sub-images because the GPU's memory
    couldn't handle a prediction on the whole image.

    Tiles in the same row must share a height. Every row must add up to the
    same total width.

    :param predictions: upsampled tiles, flat, from left to right, top to bottom.
        If `crops` is empty, `predictions` is taken to be already nested as
        rows of tiles.
    :param crops: 2-D list of the cropped inputs. Only its layout (number of rows
        and tiles per row) is used, as a template to regroup `predictions`.
    :param border: pixels removed from every outer edge of the stitched image.
        Defaults to ``overlap * (scale_fact - 1)``, which removes the padding
        added before cropping.
    :return: the reconstructed image as a whole.
    :raises ValueError: if there is nothing to reconstruct, or if the number
        of predictions does not match the number of crops.
    """
    if border is None:
        # using "-2" instead of "-1" leaves the outer edge-prediction error
        border = overlap * (scale_fact - 1)

    if len(crops) != 0:
        # Unflatten predictions into the same rows-of-tiles shape as `crops`.
        flat = list(predictions)
        row_sizes = [len(row) for row in crops]
        if len(flat) != sum(row_sizes):
            raise ValueError("expected {} predictions for the crop grid, got {}"
                             .format(sum(row_sizes), len(flat)))
        tiles = iter(flat)
        grid = [[next(tiles) for _ in range(n)] for n in row_sizes]
    else:
        grid = predictions

    if len(grid) == 0 or any(len(row) == 0 for row in grid):
        raise ValueError("nothing to reconstruct: empty prediction grid")

    # Stitch each row side by side (axis 1), then stack the rows (axis 0).
    recon = np.concatenate([np.concatenate(row, axis=1) for row in grid], axis=0)

    # Removing the pad from the reconstruction
    return recon[border:recon.shape[0] - border, border:recon.shape[1] - border]
if __name__ == '__main__':
    import os  # only the script entry point needs it

    print(" - ", args)
    preds, original, crops = predict(args)  # returns the predictions along with the original
    enhanced = reconstruct(preds, crops)  # reconstructs the enhanced image from predictions

    # Save and display the result. Clip to [0, 1] because imsave expects float
    # images in that range; presumably the model output can overshoot it.
    enhanced = np.clip(enhanced, 0, 1)
    os.makedirs('output', exist_ok=True)  # imsave fails if the folder does not exist
    plt.imsave('output/' + args.save, enhanced, cmap=plt.cm.gray)
    show_pred_output(original, enhanced)