-
Notifications
You must be signed in to change notification settings - Fork 15
/
train_variants.py
134 lines (95 loc) · 3.98 KB
/
train_variants.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from datetime import datetime
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import gc
import train_loop
def shuffle(images, labels):
    """Shuffle two parallel arrays in unison.

    Both arrays are copied and permuted along their first axis with the
    same random ordering, so corresponding entries stay aligned. The
    inputs themselves are left untouched.
    """
    shuffled_images = np.copy(images)
    shuffled_labels = np.copy(labels)
    # Save the RNG state and restore it before the second shuffle so
    # both arrays receive the identical permutation.
    saved_state = np.random.get_state()
    np.random.shuffle(shuffled_images)
    np.random.set_state(saved_state)
    np.random.shuffle(shuffled_labels)
    return shuffled_images, shuffled_labels
def split_train_and_test(images, labels, ratio=0.8):
    """Randomly partition the data into training and test subsets.

    ratio is the fraction of samples assigned to the training set; the
    remainder forms the test set. Returns two lists:
    [train_images, train_labels] and [test_images, test_labels].
    """
    shuffled_images, shuffled_labels = shuffle(images, labels)
    # Index at which the (shuffled) data is cut into train vs. test.
    cutoff = int(shuffled_images.shape[0] * ratio)
    training = [shuffled_images[:cutoff], shuffled_labels[:cutoff]]
    testing = [shuffled_images[cutoff:], shuffled_labels[cutoff:]]
    return training, testing
def create_sets(num, images, labels):
    """Shuffle the data and split it into num near-equal subsets.

    When the sample count is not evenly divisible by num, the leftover
    samples are handed out one each to the first subsets, so set sizes
    differ by at most one. Returns (image_sets, label_sets) as parallel
    lists covering all samples exactly once.
    """
    images, labels = shuffle(images, labels)
    total = images.shape[0]
    base_size = total // num
    leftovers = total - base_size * num
    image_sets = []
    label_sets = []
    # Walk a cursor through the arrays, taking base_size (+1 while
    # leftovers remain) samples per subset.
    start = 0
    for i in range(num):
        end = start + base_size + (1 if i < leftovers else 0)
        image_sets.append(images[start:end])
        label_sets.append(labels[start:end])
        start = end
    return image_sets, label_sets
def get_rotations(num, image_sets, label_sets):
    """Build the num train/test rotations for cross validation.

    Rotation i holds out set i as the test set and concatenates every
    other set into the training set. For image_sets = [A, B, C] the
    training images are [B+C, A+C, A+B].
    """
    training_sets = []
    test_sets = []
    for held_out in range(num):
        # The held-out fold becomes this rotation's test data.
        test_sets.append((image_sets[held_out], label_sets[held_out]))
        # Everything else is merged into the training data.
        train_images = [s for idx, s in enumerate(image_sets) if idx != held_out]
        train_labels = [s for idx, s in enumerate(label_sets) if idx != held_out]
        training_sets.append((
            np.concatenate(train_images),
            np.concatenate(train_labels),
        ))
    return training_sets, test_sets
def train_single(inFile, size=512):
    """Run a single training pass over the dataset stored at inFile.

    inFile => path prefix without extension; '<inFile>.npy' holds the
    images and '<inFile>_labels.npy' holds the labels.
    size => forwarded to train_loop.train_net.
    """
    print('Training...')
    # Memory-map the arrays so the whole dataset is not loaded eagerly;
    # shuffle() inside the split copies what it needs.
    images = np.load(inFile + '.npy', mmap_mode='r')
    labels = np.load(inFile + '_labels.npy', mmap_mode='r')
    # One random train/test split (default 80/20), then train.
    training, test = split_train_and_test(images, labels)
    train_loop.train_net(training, test, size=size)
def train_cross_validation(inFile, sets=3, size=512):
    """Train the network repeatedly in a cross validation fashion, so the
    tests cover the whole dataset and the effect of outliers is reduced.

    inFile => path prefix without extension; '<inFile>.npy' holds the
        images and '<inFile>_labels.npy' holds the labels.
    sets => number of cross validation folds (training is repeated this
        many times and each test set holds dataset_size / sets samples).
    size => forwarded to train_loop.train_net.
    """
    print('Starting {}-fold cross validation study...'.format(sets))
    # Memory-map the dataset rather than loading it all at once.
    images = np.load(inFile + '.npy', mmap_mode='r')
    labels = np.load(inFile + '_labels.npy', mmap_mode='r')
    # Partition into folds and build one train/test rotation per fold.
    image_sets, label_sets = create_sets(sets, images, labels)
    training_sets, test_sets = get_rotations(sets, image_sets, label_sets)
    for i, (train_data, test_data) in enumerate(zip(training_sets, test_sets)):
        print('Set {}'.format(i+1))
        train_loop.train_net(
            train_data,
            test_data,
            size=size,
            run_name='Set {} ({})'.format(i+1, datetime.now().strftime(r'%Y-%m-%d_%H:%M')),
        )
        # Tear down the old TF graph and collect garbage between folds
        # so each run starts from a clean slate.
        tf.reset_default_graph()
        gc.collect()