-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
101 lines (86 loc) · 3.96 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# coding: utf-8
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory
import utils
import os
slim = tf.contrib.slim
def gram(layer):
    """Return the per-image Gram matrix of a rank-4 feature map.

    The feature map is flattened to ``[batch, positions, channels]``. The
    channel-by-channel inner products give a ``[batch, channels, channels]``
    tensor, which is divided by the total number of elements per image
    (spatial_1 * spatial_2 * channels).

    Assumes NHWC layout (TF default); only the product of the two spatial
    dims is used, so their order does not affect the result.
    """
    dims = tf.shape(layer)
    batch, rows, cols, channels = dims[0], dims[1], dims[2], dims[3]

    # Collapse both spatial axes so each channel becomes one long vector.
    flat = tf.reshape(layer, tf.stack([batch, -1, channels]))

    normalizer = tf.to_float(rows * cols * channels)
    return tf.matmul(flat, flat, transpose_a=True) / normalizer
def get_style_features(FLAGS):
    """Evaluate the target Gram matrices of the style image.

    Builds a fresh graph and reads ``FLAGS.style_image``, decoding it as PNG
    when the name ends in 'png' and as JPEG otherwise. It then:

    1. Preprocesses the image with the evaluation-mode preprocessing of
       ``FLAGS.loss_model`` at ``FLAGS.style_size`` x ``FLAGS.style_size``.
       Per the original notes this resizes the shorter side and then
       central-crops; that logic lives in ``image_preprocessing_fn`` and
       should be confirmed there.
    2. Runs the preprocessed image through the loss network.
    3. Computes the Gram matrix of every endpoint named in
       ``FLAGS.style_layers`` (a comma-separated string).

    As a side effect, the un-preprocessed style image is written to
    ``generated/<FLAGS.style>_<FLAGS.style_size>.jpg``.

    Args:
        FLAGS: flag/config object. Reads ``style_layers``, ``loss_model``,
            ``style_size``, ``style_image`` and ``style``; it is also passed
            to ``utils._get_init_fn``.

    Returns:
        List of evaluated Gram matrices, one per style layer, in the order
        of ``FLAGS.style_layers``, with the batch dimension removed.
    """
    style_layers = FLAGS.style_layers.split(',')
    with tf.Graph().as_default():
        network_fn = nets_factory.get_network_fn(
            FLAGS.loss_model,
            num_classes=1,
            is_training=False)
        image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
            FLAGS.loss_model,
            is_training=False)

        size = FLAGS.style_size
        img_bytes = tf.read_file(FLAGS.style_image)
        # Force 3 channels. Without this, RGBA PNGs (4 channels) or grayscale
        # images (1 channel) reach a loss network that presumably expects RGB.
        if FLAGS.style_image.lower().endswith('png'):
            image = tf.image.decode_png(img_bytes, channels=3)
        else:
            image = tf.image.decode_jpeg(img_bytes, channels=3)
        images = tf.stack([image_preprocessing_fn(image, size, size)])
        _, endpoints_dict = network_fn(images, spatial_squeeze=False)

        features = []
        for layer in style_layers:
            feature = endpoints_dict[layer]
            # gram() yields [batch, C, C]. The batch holds one image, so drop it.
            features.append(tf.squeeze(gram(feature), [0]))

        with tf.Session() as sess:
            init_func = utils._get_init_fn(FLAGS)
            init_func(sess)

            if not os.path.exists('generated'):
                os.makedirs('generated')
            save_file = 'generated/' + FLAGS.style + "_" + str(FLAGS.style_size) + '.jpg'

            target_image = image_unprocessing_fn(images[0, :])
            value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
            # Encode before opening the file so a failure here cannot leave an
            # empty or truncated JPEG on disk.
            jpeg_bytes = sess.run(value)
            with open(save_file, 'wb') as f:
                f.write(jpeg_bytes)
            tf.logging.info('Target style pattern is saved to: %s.', save_file)

            return sess.run(features)
def style_loss(endpoints_dict, style_features_t, style_layers, style_layers_weights):
    """Weighted Gram-matrix style loss over several loss-network layers.

    For each ``(target_gram, layer, weight)`` triple (``zip`` stops at the
    shortest of the three sequences):

    - ``endpoints_dict[layer]`` is split in half along axis 0, and only the
      first half (the generated images) is used. This assumes the caller
      batched [generated, content]; confirm against the caller.
    - The layer loss is ``2 * l2_loss(gram(generated) - target_gram) / size``,
      where ``size`` is the element count of the generated feature map, not
      of the Gram matrix. ``target_gram`` broadcasts against the per-image
      Gram matrices.

    Returns:
        ``(total, per_layer)``: the sum of ``weight * layer_loss`` over all
        layers (the Python int 0 if there are none), and the list of those
        weighted per-layer terms.
    """
    total = 0
    per_layer = []
    for target_gram, layer_name, layer_weight in zip(style_features_t, style_layers, style_layers_weights):
        generated, _ = tf.split(endpoints_dict[layer_name], 2, 0)
        num_elements = tf.size(generated)
        unweighted = tf.nn.l2_loss(gram(generated) - target_gram) * 2 / tf.to_float(num_elements)
        weighted = layer_weight * unweighted
        total += weighted
        per_layer.append(weighted)
    return total, per_layer
def content_loss(endpoints_dict, content_layers):
    """Sum of per-layer mean squared feature differences.

    Each endpoint named in ``content_layers`` is split in half along axis 0.
    The first half is treated as the generated images and the second as the
    content images; this assumes the caller batched [generated, content]
    (confirm against the caller). Since ``l2_loss`` is half the sum of
    squares, ``2 * l2_loss(diff) / size`` is the mean squared difference for
    that layer. The per-layer values are summed, starting from 0.
    """
    def _layer_term(layer_name):
        generated, content = tf.split(endpoints_dict[layer_name], 2, 0)
        num_elements = tf.size(generated)
        return tf.nn.l2_loss(generated - content) * 2 / tf.to_float(num_elements)

    return sum((_layer_term(name) for name in content_layers), 0)
def total_variation_loss(layer):
    """Smoothness penalty on differences between neighbouring pixels.

    For a rank-4 tensor (assumed NHWC; confirm), it forms differences
    between adjacent entries along axis 1 (vertical) and axis 2
    (horizontal). The result is ``l2_loss(diff) / size(diff)`` summed over
    both directions, i.e. half the mean squared difference per direction.
    """
    vertical = layer[:, :-1, :, :] - layer[:, 1:, :, :]
    horizontal = layer[:, :, :-1, :] - layer[:, :, 1:, :]
    horizontal_term = tf.nn.l2_loss(horizontal) / tf.to_float(tf.size(horizontal))
    vertical_term = tf.nn.l2_loss(vertical) / tf.to_float(tf.size(vertical))
    return horizontal_term + vertical_term