softmax.go
// Copyright (c) 2021, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package decoder

import (
"bufio"
"compress/gzip"
"encoding/json"
"fmt"
"io"
"math"
"os"
"path/filepath"
"sort"
"cogentcore.org/core/math32"
"cogentcore.org/lab/base/mpi"
"cogentcore.org/lab/tensor"
"github.com/emer/emergent/v2/emer"
)

// SoftMax is a softmax decoder, which is the best choice for 1-hot classification,
// using the widely used softmax function: https://en.wikipedia.org/wiki/Softmax_function
type SoftMax struct {
// learning rate
Lrate float32 `default:"0.1"`
// layers to decode
Layers []emer.Layer
// number of different categories to decode
NCats int
// unit values
Units []SoftMaxUnit
// sorted list of indexes into Units, in descending order from strongest to weakest -- i.e., Sorted[0] has the most likely categorization, and its activity is Units[Sorted[0]].Act
Sorted []int
// number of inputs -- total size of all layer inputs
NInputs int
// input values, copied from layers
Inputs []float32
// current target index of correct category
Target int
// for holding layer values
ValuesTsrs map[string]*tensor.Float32 `display:"-"`
// synaptic weights: outer loop is units, inner loop is inputs
Weights tensor.Float32
// MPI communicator: MPI users must set this to their communicator via direct assignment
Comm *mpi.Comm `display:"-"`
// delta weight changes: only for MPI mode -- outer loop is units, inner loop is inputs
MPIDWts tensor.Float32
}

// SoftMaxUnit has the variables for one softmax decoder unit
type SoftMaxUnit struct {
// final activation = exp(Net) / sum of exp(Net) over all units
Act float32
// net input = sum x * w
Net float32
// exp(Net)
Exp float32
}

// InitLayer initializes the decoder with the number of categories and the layers to decode
func (sm *SoftMax) InitLayer(ncats int, layers []emer.Layer) {
sm.Layers = layers
nin := 0
for _, ly := range sm.Layers {
nin += ly.AsEmer().Shape.Len()
}
sm.Init(ncats, nin)
}

// Init initializes the decoder with the number of categories and number of inputs
func (sm *SoftMax) Init(ncats, ninputs int) {
sm.NInputs = ninputs
sm.Lrate = 0.1 // seems pretty good
sm.NCats = ncats
sm.Units = make([]SoftMaxUnit, ncats)
sm.Sorted = make([]int, ncats)
sm.Inputs = make([]float32, sm.NInputs)
sm.Weights.SetShapeSizes(sm.NCats, sm.NInputs)
for i := range sm.Weights.Values {
sm.Weights.Values[i] = .1
}
}

// Decode decodes the given variable name from layers (forward pass).
// See the Sorted list of indexes for the full decoding output -- Sorted[0]
// is the most likely category, and is returned here as a convenience.
// di is a data parallel index, for networks capable
// of processing input patterns in parallel.
func (sm *SoftMax) Decode(varNm string, di int) int {
sm.Input(varNm, di)
sm.Forward()
sm.Sort()
return sm.Sorted[0]
}

// Train trains the decoder with given target correct answer (0..NCats-1)
func (sm *SoftMax) Train(targ int) {
sm.Target = targ
sm.Back()
}
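
// A minimal usage sketch (illustrative only; the layer list, the "ActM" variable
// name, and the data index 0 are assumptions, not fixed by this file):
//
//	var sm SoftMax
//	sm.InitLayer(nCats, []emer.Layer{outLayer}) // decode nCats categories from outLayer
//	// per trial, after the network has been run:
//	pred := sm.Decode("ActM", 0) // forward pass; returns the most likely category
//	sm.Train(correctCat)         // delta-rule weight update toward the correct category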

// TrainMPI trains the decoder with given target correct answer (0..NCats-1)
// MPI version uses mpi to synchronize weight changes across parallel nodes.
func (sm *SoftMax) TrainMPI(targ int) {
sm.Target = targ
sm.BackMPI()
}
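
// Note: when using TrainMPI, Comm must first be set to the MPI communicator (see the
// Comm field above). A common pattern (an assumption, not prescribed by this file) is
// for each rank to train on its own subset of trials, with BackMPI summing the
// per-rank weight changes across all ranks.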

// ValuesTsr gets value tensor of given name, creating if not yet made
func (sm *SoftMax) ValuesTsr(name string) *tensor.Float32 {
if sm.ValuesTsrs == nil {
sm.ValuesTsrs = make(map[string]*tensor.Float32)
}
tsr, ok := sm.ValuesTsrs[name]
if !ok {
tsr = &tensor.Float32{}
sm.ValuesTsrs[name] = tsr
}
return tsr
}

// Input grabs the input from the given variable in the layers.
// di is a data parallel index, for networks capable
// of processing input patterns in parallel.
func (sm *SoftMax) Input(varNm string, di int) {
off := 0
for _, ly := range sm.Layers {
lb := ly.AsEmer()
tsr := sm.ValuesTsr(lb.Name)
lb.UnitValuesTensor(tsr, varNm, di)
for j, v := range tsr.Values {
sm.Inputs[off+j] = v
}
off += lb.Shape.Len()
}
}

// Forward computes the forward pass from the input.
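// With m = the maximum Net across units, each unit's activation is
// Act = exp(Net - m) / sum_j exp(Net_j - m), the numerically stable form of softmax.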
func (sm *SoftMax) Forward() {
max := float32(-math.MaxFloat32)
for ui := range sm.Units {
u := &sm.Units[ui]
net := float32(0)
off := ui * sm.NInputs
for j, in := range sm.Inputs {
net += sm.Weights.Values[off+j] * in
}
u.Net = net
if net > max {
max = net
}
}
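// shift net inputs by the max before exponentiating: softmax is invariant to this
// shift, and it keeps FastExp within a numerically safe range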
sum := float32(0)
for ui := range sm.Units {
u := &sm.Units[ui]
u.Net -= max
u.Exp = math32.FastExp(u.Net)
sum += u.Exp
}
for ui := range sm.Units {
u := &sm.Units[ui]
u.Act = u.Exp / sum
}
}

// Sort updates Sorted indexes of the current Unit category activations sorted
// from highest to lowest. i.e., the 0-index value has the strongest
// decoded output category, 1 the next-strongest, etc.
func (sm *SoftMax) Sort() {
for i := range sm.Sorted {
sm.Sorted[i] = i
}
sort.Slice(sm.Sorted, func(i, j int) bool {
return sm.Units[sm.Sorted[i]].Act > sm.Units[sm.Sorted[j]].Act
})
}

// Back computes the backward error propagation pass.
func (sm *SoftMax) Back() {
lr := sm.Lrate
for ui := range sm.Units {
u := &sm.Units[ui]
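// del is the cross-entropy gradient for a softmax output: (1 - Act) for the
// target unit, -Act for all others, scaled by the learning rate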
var del float32
if ui == sm.Target {
del = lr * (1 - u.Act)
} else {
del = -lr * u.Act
}
off := ui * sm.NInputs
for j, in := range sm.Inputs {
sm.Weights.Values[off+j] += del * in
}
}
}

// BackMPI computes the backward error propagation pass.
// The MPI version shares weight changes across nodes.
func (sm *SoftMax) BackMPI() {
if sm.MPIDWts.Len() != sm.Weights.Len() {
tensor.SetShapeFrom(&sm.MPIDWts, &sm.Weights)
}
lr := sm.Lrate
for ui := range sm.Units {
u := &sm.Units[ui]
var del float32
if ui == sm.Target {
del = lr * (1 - u.Act)
} else {
del = -lr * u.Act
}
off := ui * sm.NInputs
for j, in := range sm.Inputs {
sm.MPIDWts.Values[off+j] = del * in
}
}
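// sum the per-rank delta weight changes across all MPI nodes, then apply the combined update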
sm.Comm.AllReduceF32(mpi.OpSum, sm.MPIDWts.Values, nil)
for i, dw := range sm.MPIDWts.Values {
sm.Weights.Values[i] += dw
}
}

type softMaxForSerialization struct {
Weights []float32 `json:"weights"`
}

// Save saves the decoder weights to the given file path.
// If the path ends in .gz, the file will be gzipped.
func (sm *SoftMax) Save(path string) error {
file, err := os.Create(path)
if err != nil {
return err
}
defer file.Close()
ext := filepath.Ext(path)
var writer io.Writer
if ext == ".gz" {
gw := gzip.NewWriter(file)
defer gw.Close()
writer = gw
} else {
bw := bufio.NewWriter(file)
defer bw.Flush()
writer = bw
}
encoder := json.NewEncoder(writer)
return encoder.Encode(softMaxForSerialization{Weights: sm.Weights.Values})
}

// Load loads the decoder weights from the given file path.
// If the path ends in .gz, the file is assumed to be gzipped.
// If the shape of the decoder does not match the shape of the saved weights,
// an error will be returned.
func (sm *SoftMax) Load(path string) error {
ext := filepath.Ext(path)
var reader io.Reader
file, err := os.Open(path)
if err != nil {
return err
}
defer file.Close()
if ext == ".gz" {
gr, err := gzip.NewReader(file)
if err != nil {
return err
}
defer gr.Close()
reader = gr
} else {
reader = bufio.NewReader(file)
}
decoder := json.NewDecoder(reader)
var s softMaxForSerialization
if err := decoder.Decode(&s); err != nil {
return err
}
if len(sm.Weights.Values) != len(s.Weights) {
return fmt.Errorf("loaded weights length %d does not match expected length %d", len(s.Weights), len(sm.Weights.Values))
}
sm.Weights.Values = s.Weights
return nil
}
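
// A save / load round trip might look like this (sketch; the file name is illustrative,
// and gzip compression is selected purely by the .gz extension):
//
//	if err := sm.Save("softmax_weights.json.gz"); err != nil {
//		// handle error
//	}
//	// later, on a decoder initialized with the same NCats and NInputs:
//	if err := sm.Load("softmax_weights.json.gz"); err != nil {
//		// handle error
//	}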