forked from markozeman/SuperpositionDevelopment
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plots.py
199 lines (165 loc) · 6.68 KB
/
plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import math
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
def show_grayscale_image(img, vmin=0, vmax=255):
    """
    Show grayscale image (1 channel) with a fixed intensity range.

    The range is pinned (rather than auto-scaled per image) so that the same pixel value
    always maps to the same gray level across images.

    :param img: image to plot
    :param vmin: pixel value mapped to black (default 0, suited to 8-bit images)
    :param vmax: pixel value mapped to white (default 255, suited to 8-bit images);
                 e.g. pass vmin=0, vmax=1 for images scaled to [0, 1], or None for both to auto-scale
    :return: None
    """
    plt.imshow(img, cmap='gray', vmin=vmin, vmax=vmax)
    plt.show()
def show_image(img):
    """
    Display a colour image (3 channels) in a matplotlib window.

    Colours are rendered as-is by ``imshow`` with its default settings.

    :param img: image to display
    :return: None
    """
    plt.imshow(img)
    plt.show()
def plot_general(line_1, line_2, legend_lst, title, x_label, y_label, vertical_lines_x, vl_min, vl_max, text_strings=None):
    """
    Plot two lines on the same plot with additional general information.

    Dotted vertical lines are drawn at ``vertical_lines_x`` spanning [vl_min, vl_max]; optional
    text labels are placed at the bottom of those lines (y = vl_min), slightly to the right.

    Note: the font size is set through ``plt.rc``, which changes the global matplotlib
    settings, so figures created afterwards also use font size 18.

    :param line_1: y values of the first line
    :param line_2: y values of the second line
    :param legend_lst: list of two values -> [first line label, second line label]
    :param title: plot title (string)
    :param x_label: label of axis x (string)
    :param y_label: label of axis y (string)
    :param vertical_lines_x: x values of where to draw vertical lines
    :param vl_min: vertical lines minimum y value
    :param vl_max: vertical lines maximum y value
    :param text_strings: optional list of text strings to add to the bottom of vertical lines
                         (the i-th string is placed next to the i-th vertical line)
    :return: None
    """
    font = {'size': 18}
    plt.rc('font', **font)

    plt.plot(line_1, linewidth=3)
    plt.plot(line_2, linewidth=3)
    # ``colors`` (plural) is the correct keyword for the vlines collection
    plt.vlines(vertical_lines_x, vl_min, vl_max, colors='k', alpha=0.5, linestyles='dotted', linewidth=3)

    plt.legend(legend_lst)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    if text_strings is not None:
        for line_x, text in zip(vertical_lines_x, text_strings):
            # Text objects take ``color`` (singular); ``colors`` raises an AttributeError
            plt.text(line_x + 0.25, vl_min, text, color='k', alpha=0.5)
    plt.show()
def plot_many_lines(lines, legend, title, x_label, y_label):
    """
    Draw several lines (of the same length) on one set of axes and show the figure.

    Each inner sequence is plotted against its index on the x-axis.

    :param lines: list of lists of values, one inner list per line
    :param legend: label for each line (len(lines) = len(legend)), in the same order as ``lines``
    :param title: plot title (string)
    :param x_label: label of axis x (string)
    :param y_label: label of axis y (string)
    :return: None
    """
    for series in lines:
        plt.plot(series)

    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(legend)
    plt.show()
def plot_weights_histogram(x, bins):
    """
    Show a histogram of weight values.

    :param x: data/values to plot
    :param bins: number of bins on histogram (forwarded to ``plt.hist``)
    :return: None
    """
    plt.hist(x, bins=bins)

    plt.xlabel('Weight value')
    plt.ylabel('Occurrences')
    plt.title('Values of trained weights in the network')
    plt.show()
def weights_heatmaps(W_matrices, labels, task_index):
    """
    Plot heat maps of weights from layers in the network, one subplot per matrix.

    Up to three matrices are placed in a single row; more are split over two rows.
    Matrices at (0-based) positions 2 and 3 use the sequential 'Blues' colormap, all others
    the diverging 'coolwarm' colormap.

    :param W_matrices: list of 2D numpy arrays which represent weights between layers
    :param labels: list of strings to put in the plot title (one per matrix, same order)
    :param task_index: integer index of the current task
    :return: None
    """
    plt.figure()

    n_matrices = len(W_matrices)
    if n_matrices <= 3:
        plot_layout = (1, n_matrices)
    else:
        plot_layout = (2, math.ceil(n_matrices / 2))

    for layer_index, weights_matrix in enumerate(W_matrices):
        plt.subplot(*plot_layout, layer_index + 1)
        # NOTE(review): positions 2-3 get a sequential colormap, presumably because those
        # matrices hold non-negative values — confirm against the caller
        cmap = 'Blues' if 2 <= layer_index <= 3 else 'coolwarm'
        sns.heatmap(weights_matrix, cmap=cmap, linewidth=0)
        plt.title("Task %d || %s" % (task_index, labels[layer_index]))

    plt.tight_layout()
    # To save to disk, call plt.savefig(path, bbox_inches='tight', dpi=300) here, before
    # plt.show() — after show() many backends have already discarded the figure.
    plt.show()
def plot_confusion_matrix(conf_mat):
    """
    Plot an already calculated confusion matrix as an annotated heat map.

    Rows are labelled 'true' (y-axis) and columns 'predicted' (x-axis); the title shows the
    overall accuracy, i.e. the diagonal sum divided by the total sum.

    :param conf_mat: 2D confusion matrix (numpy array or anything ``np.asarray`` accepts);
                     integer counts are annotated as integers, other dtypes with general format
    :return: None
    """
    conf_mat = np.asarray(conf_mat)

    # compute accuracy
    all_cases = conf_mat.sum()
    correct_cases = conf_mat.diagonal().sum()
    # guard against an all-zero matrix instead of producing NaN with a division warning
    acc = (correct_cases / all_cases) * 100 if all_cases else 0.0

    # seaborn formats annotations with ``fmt``; 'd' raises on float values, so only use it for ints
    fmt = 'd' if np.issubdtype(conf_mat.dtype, np.integer) else 'g'

    n_rows, n_cols = conf_mat.shape
    df_cm = pd.DataFrame(conf_mat, index=range(n_rows), columns=range(n_cols))
    sns.heatmap(df_cm, annot=True, cmap='Blues', linewidth=0, fmt=fmt)
    plt.title('Test accuracy: %d / %d = %.2f %%' % (correct_cases, all_cases, acc))
    plt.xlabel('predicted')
    plt.ylabel('true')
    plt.show()
def plot_multiple_results(data, legend_lst, colors, x_label, y_label, vertical_lines_x, vl_min, vl_max, show_CI=True, text_strings=None, every_n=1):
    """
    Plot more lines from the saved results on the same plot with additional information.

    For every experiment the mean over its samples (axis 0) is drawn as a thick line; every
    other experiment's line is dashed. Optionally a shaded band of mean +/- 2*std is drawn
    behind it, with the upper edge capped at 100 (intended for accuracies in percent).

    :param data: list of 2D matrices, each matrix has more samples of the same experiment (number of experiments x length of experiment)
    :param legend_lst: list of label values (length of data)
    :param colors: list of colors used for lines (length of data); each must be a valid
                   matplotlib color (it is passed as ``color=`` to both the band and the line)
    :param x_label: label of axis x (string)
    :param y_label: label of axis y (string)
    :param vertical_lines_x: x values of where to draw vertical lines
    :param vl_min: vertical lines minimum y value
    :param vl_max: vertical lines maximum y value
    :param show_CI: show confidence interval range (boolean)
    :param text_strings: optional list of text strings to add to the bottom of vertical lines
    :param every_n: plot only every n-th point of each curve (default 1 = all points); the
                    kept points stay at their original x positions
    :return: None
    """
    mean_lines = []
    for i, experiment in enumerate(data):
        matrix = np.array(experiment)
        # per-step statistics across the samples of this experiment, subsampled every n-th step
        mean = np.mean(matrix, axis=0)[::every_n]
        std = np.std(matrix, axis=0)[::every_n]
        # x positions of the kept points in the original (non-subsampled) indexing
        x_values = range(0, mean.shape[0] * every_n, every_n)

        # plot the shaded range of the confidence intervals (mean +/- 2*std)
        if show_CI:
            up_limit = mean + (2 * std)
            up_limit[up_limit > 100] = 100  # cut accuracies above 100
            down_limit = mean - (2 * std)
            plt.fill_between(x_values, up_limit, down_limit, color=colors[i], alpha=0.25)

        # plot the mean on top (every other line is dashed)
        style = {} if i % 2 == 0 else {'linestyle': '--'}
        mean_lines.extend(plt.plot(x_values, mean, color=colors[i], linewidth=3, **style))

    if legend_lst:
        # Pass the mean-line handles explicitly: with labels only, newer matplotlib assigns
        # labels to artists in creation order, so the CI bands would take the mean lines' labels.
        plt.legend(mean_lines, legend_lst)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    plt.vlines(vertical_lines_x, vl_min, vl_max, colors='k', linestyles='dashed', linewidth=2, alpha=0.5)
    if text_strings is not None:
        for line_x, text in zip(vertical_lines_x, text_strings):
            plt.text(line_x + 0.5, vl_min, text, color='k', alpha=0.5)
    plt.show()