forked from artyom-beilis/zx_spectrum_deep_learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.bas
231 lines (231 loc) · 7.44 KB
/
train.bas
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
5 REM "Hyper-parameters. Lines 10-32 are the very fast two-class setup"
10 LET kernels=4: LET ksize=3
20 LET itersize=4: LET clsno=2
30 LET batch=64
32 LET epochs=2
34 REM "Full run: these values replace the fast ones assigned above"
36 LET kernels=12: LET itersize=1
38 LET epochs=5: LET clsno=10
40 LET blr=0.01
42 LET blr2=0.001: REM "learning rate used from epoch 2 onward"
44 LET iblr=1/blr
46 LET wd=0.0005
48 LET wdcomp=1-wd: REM "every parameter is scaled by this in the update step"
50 LET beta=0.90: REM "factor applied to the previous update term"
60 REM "Sizes derived from the 8x8 input digit"
62 LET intsz=9-ksize: REM "side of one convolution output map"
64 LET poolsz=INT (intsz/2): REM "side of one map after 2x2 pooling"
66 LET fltsz=kernels*poolsz*poolsz: REM "length of the pooled vector"
70 REM "First index of m,b,k,o: 1=value 2=summed gradient 3=update term"
72 DIM m(3,clsno,fltsz): REM "IP Matrix"
74 DIM b(3,clsno): REM "IP Offset"
76 DIM k(3,kernels,ksize,ksize): REM "Kernels"
78 DIM o(3,kernels): REM "Kernel Offset"
80 REM "First index of c,p,r: 1=forward value 2=gradient"
90 DIM c(2,kernels,intsz,intsz): REM "conv_res"
100 DIM p(2,fltsz): REM "pool res"
110 DIM r(2,clsno): REM "class outputs"
120 DIM s(fltsz): REM "pool mask"
130 DIM d(8,8): REM "digit"
135 REM "Profiling log: times in z(), two-letter tags in z$(), timep is the next free slot"
140 DIM z(20)
141 DIM z$(20,2)
142 LET timep=1
200 REM "Helper functions"
205 REM "FN g(s): s times (the sum of twelve RND values minus 6), used for random start weights"
210 DEF FN g(s)=s*(RND+RND+RND+RND+RND+RND+RND+RND+RND+RND+RND+RND-6)
215 REM "FN r(x): x when x is positive, otherwise 0"
220 DEF FN r(x)=x*(x>0)
300 LOAD ""SCREEN$
302 REM "Initialisation: random start values for the kernels k(1,) and the IP matrix m(1,)"
304 PRINT AT 21,0;"INI ";
306 LET sigma=1/(kernels*ksize*ksize)
308 FOR n=1 TO kernels
310 FOR r=1 TO ksize
312 FOR c=1 TO ksize
314 LET k(1,n,r,c)=FN g(sigma)
316 NEXT c
318 NEXT r
320 NEXT n
322 LET sigma=2/(fltsz + clsno)
324 FOR r=1 TO clsno
326 FOR c=1 TO fltsz
328 LET m(1,r,c)=FN g(sigma)
329 NEXT c
330 NEXT r
400 REM "train loop"
401 REM "Each epoch visits every cell (d,b): read digit, forward, backward."
402 REM "Gradients are summed over itersize values of b and then applied by GO SUB 5000."
403 REM "From epoch 2 on the learning rate blr is replaced by blr2."
404 REM "acc counts correct predictions, count counts samples; both feed the epoch report."
410 FOR e=1 TO epochs
420 IF e>=2 THEN LET blr=blr2: LET iblr=1/blr
422 GO SUB 9000: LET start=time
425 LET acc=0: LET count=0
427 LET iter=0
430 FOR b=0 TO batch-1
435 LET loss=0
440 FOR d=0 TO clsno-1
445 LET n$="GP": GO SUB 9020
450 GO SUB 8000: REM "Get pixel from d,b"
455 LET mark=24: GO SUB 8050: REM "Mark"
460 GO SUB 3500 : REM "Forward"
470 GO SUB 4000 : REM "Backward"
480 LET acc=acc + accres
490 LET count=count+1
500 LET loss=loss+lossres
510 LET mark=16*(1+accres): GO SUB 8050: REM "Mark Res"
520 NEXT d
525 LET iter=iter+1
530 IF iter=itersize THEN GO SUB 5000: LET iter=0: REM "apply update"
540 NEXT b
541 IF iter>0 THEN GO SUB 5000: LET iter=0: REM "apply a partial group so no summed gradient leaks into the next epoch or is lost after the last one"
542 GO SUB 9000: LET pass=time-start
545 LET acc=INT(acc / count * 1000)/10
550 PRINT AT 20,0;"Epoch=";e;" Acc=";acc;"% time ";INT(pass+0.5);"m "
560 NEXT e
999 REM "Test: load the next screen and score every cell with forward passes only"
1000 LOAD ""SCREEN$
1010 GO SUB 9000
1020 LET start=time
1030 LET acc=0
1040 LET count=0
1050 FOR b=0 TO batch-1
1060 FOR d=0 TO clsno-1
1070 GO SUB 8000: REM "read the digit for cell (d,b) into d()"
1080 LET mark=24
1090 GO SUB 8050: REM "flag the cell being scored"
1100 GO SUB 3500: REM "forward pass, sets accres to 1 or 0"
1110 LET acc=acc + accres
1120 LET count=count+1
1130 LET mark=16*(1+accres)
1140 GO SUB 8050: REM "flag the result: 32 when correct, 16 when wrong"
1150 NEXT d
1160 NEXT b
1170 GO SUB 9000
1180 LET pass=time-start
1190 LET acc=INT(acc / count * 1000)/10
1200 PRINT AT 20,0;"Test Acc=";acc;"% time ";INT(pass+0.5);"m ";
1210 GO TO 9999: REM "end of run; no line 9999 appears in this listing"
3498 REM "Forward pass for the digit held in d(), whose true class is d"
3499 REM "Sets r(1,), lossres and accres; c(1,), p(1,) and s() are kept for the backward pass"
3500 PRINT AT 21,0;"FW ";
3501 LET n$="FC"
3502 GO SUB 9020
3503 REM "Convolution: each output starts at the kernel offset, then adds the kernel weight under every set pixel"
3505 FOR r=1 TO intsz
3506 FOR c=1 TO intsz
3510 FOR n=1 TO kernels
3512 LET c(1,n,r,c)=o(1,n)
3514 NEXT n
3520 FOR i=1 TO ksize
3522 FOR j=1 TO ksize
3530 IF d(r+i-1,c+j-1) THEN FOR n=1 TO kernels: LET c(1,n,r,c)=c(1,n,r,c) + k(1,n,i,j): NEXT n
3540 NEXT j
3542 NEXT i
3560 NEXT c
3562 NEXT r
3570 LET n$="FP"
3572 GO SUB 9020
3598 REM "2x2 max pool followed by FN r; s() records which of the four cells held the maximum (0..3)"
3600 LET pos=1
3602 FOR n=1 TO kernels
3604 FOR r=1 TO poolsz*2 STEP 2
3606 FOR c=1 TO poolsz*2 STEP 2
3610 LET index=0
3612 LET maxv=c(1,n,r,c)
3620 IF c(1,n,r,c+1) > maxv THEN LET maxv=c(1,n,r,c+1): LET index=1
3630 IF c(1,n,r+1,c) > maxv THEN LET maxv=c(1,n,r+1,c): LET index=2
3640 IF c(1,n,r+1,c+1) > maxv THEN LET maxv=c(1,n,r+1,c+1): LET index=3
3650 LET s(pos)=index
3652 LET p(1,pos)=FN r(maxv)
3654 LET pos=pos+1
3660 NEXT c
3662 NEXT r
3664 NEXT n
3670 LET n$="FI"
3672 GO SUB 9020
3698 REM "Inner product: r(1,i) is b(1,i) plus the sum over j of p(1,j)*m(1,i,j)"
3700 FOR i=1 TO clsno
3710 LET sum=b(1,i)
3720 FOR j=1 TO fltsz
3730 LET sum=sum+p(1,j)*m(1,i,j)
3740 NEXT j
3750 LET r(1,i)=sum
3760 NEXT i
3770 LET n$="FL"
3772 GO SUB 9020
3798 REM "Prediction: zero-based index of the largest output"
3800 LET maxind=0
3802 LET maxv=r(1,1)
3810 FOR i=2 TO clsno
3820 IF r(1,i) > maxv THEN LET maxv=r(1,i): LET maxind=i-1
3830 NEXT i
3835 REM "Loss: half the summed squared difference between the one-hot target and the outputs"
3838 LET sum=0
3840 FOR i=1 TO clsno
3850 LET tgt = (i-1) = d
3860 LET sdiff=tgt-r(1,i)
3862 LET sum=sum+sdiff*sdiff
3870 NEXT i
3880 LET lossres=0.5*sum
3882 LET accres=(maxind=d)
3890 RETURN
3900 REM "END OF LOSS FW"
3998 REM "BACK PROP for the sample last run through the forward pass at 3500"
3999 REM "Adds this sample's gradients to m(2,), b(2,), k(2,) and o(2,); no parameter is changed here"
4000 PRINT AT 21,0;"BW ";
4001 LET n$="BL": GO SUB 9020
4002 REM "Loss Backward: gradient of the squared-error loss is output minus one-hot target"
4005 FOR i=1 TO clsno
4010 LET tgt = (i-1) = d
4020 LET r(2,i) = r(1,i) - tgt
4030 NEXT i
4035 LET n$="BI":GO SUB 9020
4049 REM "IP Backward: offset and matrix gradients, and the gradient reaching the pooled vector p(2,)"
4050 FOR k=1 TO clsno
4060 LET b(2,k) = b(2,k) + r(2,k)
4070 NEXT k
4080 FOR j=1 TO fltsz : LET p(2,j) = 0: NEXT j
4090 FOR i=1 TO clsno : FOR j=1 TO fltsz
4100 LET m(2,i,j) = m(2,i,j) + p(1,j)*r(2,i)
4110 LET p(2,j)=p(2,j) + m(1,i,j) * r(2,i)
4120 NEXT j: NEXT i
4121 LET n$="BP":GO SUB 9020
4200 REM "Pool/ReLU backward"
4205 REM "No gradient passes where the pooled output p(1,) was not positive"
4210 FOR k=1 TO fltsz
4220 IF p(1,k) <= 0 THEN LET p(2,k) = 0
4230 NEXT k
4235 REM "Send each pooled gradient to the cell recorded in s(); the other three cells get 0"
4236 REM "Loop bounds are poolsz*2 so the windows and pos match the forward max pool exactly."
4237 REM "With an odd intsz the last row and column lie outside every window; they are never written and keep the 0 given by DIM"
4240 LET pos=1
4250 FOR k=1 TO kernels
4260 FOR r=1 TO poolsz*2 STEP 2 : FOR c=1 TO poolsz*2 STEP 2
4270 LET indx=s(pos)
4280 LET topv=p(2,pos)
4290 LET c(2,k,r ,c ) = (indx=0) * topv
4300 LET c(2,k,r ,c+1) = (indx=1) * topv
4310 LET c(2,k,r+1,c ) = (indx=2) * topv
4320 LET c(2,k,r+1,c+1) = (indx=3) * topv
4330 LET pos=pos+1
4340 NEXT c: NEXT r
4350 NEXT k
4355 LET n$="Bo":GO SUB 9020
4360 REM "Conv Bias BW: the offset gradient of kernel n is the sum of c(2,n,,) over its whole map"
4400 FOR n=1 TO kernels
4410 LET sum=0
4420 FOR i=1 TO intsz: FOR j=1 TO intsz
4430 LET sum=sum+c(2,n,i,j)
4440 NEXT j : NEXT i
4450 LET o(2,n) = o(2,n) + sum
4460 NEXT n
4467 LET n$="BC":GO SUB 9020
4499 REM "Conv BW: every set input pixel passes the output gradient to the kernel weight that covered it"
4500 FOR r=1 TO intsz: FOR c=1 TO intsz
4510 FOR i=1 TO ksize : FOR j=1 TO ksize
4520 IF d(r+i-1,c+j-1) THEN FOR n=1 TO kernels: LET k(2,n,i,j) = k(2,n,i,j) + c(2,n,r,c) : NEXT n
4530 NEXT j: NEXT i
4540 NEXT c: NEXT r
4545 LET n$="EE": GO SUB 9020: GO SUB 9200
4550 RETURN
4599 REM "End of BW"
4998 REM "Apply Update: fold the summed gradients into the parameters, then clear the sums"
4999 REM "For each parameter: term=term*beta+gradient*blr, value=wdcomp*value-term, gradient=0"
5000 PRINT AT 21,0;"UPD ";
5002 REM "IP matrix"
5004 FOR i=1 TO clsno
5006 FOR j=1 TO fltsz
5010 LET m(3,i,j) = m(3,i,j) * beta + m(2,i,j) * blr
5020 LET m(1,i,j) = wdcomp * m(1,i,j) - m(3,i,j)
5030 LET m(2,i,j) = 0
5040 NEXT j
5042 NEXT i
5048 REM "IP offsets"
5050 FOR i=1 TO clsno
5052 LET b(3,i) = b(3,i) * beta + b(2,i) * blr
5060 LET b(1,i) = wdcomp * b(1,i) - b(3,i)
5070 LET b(2,i) = 0
5080 NEXT i
5099 REM "Kernels"
5100 FOR i=1 TO kernels
5102 FOR j=1 TO ksize
5104 FOR k=1 TO ksize
5110 LET k(3,i,j,k) = k(3,i,j,k) * beta + k(2,i,j,k) * blr
5120 LET k(1,i,j,k) = wdcomp * k(1,i,j,k) - k(3,i,j,k)
5130 LET k(2,i,j,k) = 0
5140 NEXT k
5142 NEXT j
5144 NEXT i
5148 REM "Kernel offsets"
5150 FOR i=1 TO kernels
5152 LET o(3,i) = o(3,i) * beta + o(2,i) * blr
5160 LET o(1,i) = wdcomp * o(1,i) - o(3,i)
5170 LET o(2,i) = 0
5180 NEXT i
5190 RETURN
5299 REM "End of Apply Update"
7999 REM "Fill d() with the 8x8 pixels of the digit for class d (0 to 9), sample b (0 to 63)"
8000 LET row=(2*d+(b>=32))*8
8010 LET col=(b-32*(b>=32))*8
8020 FOR r=0 TO 7
8022 FOR c=0 TO 7
8030 LET d(r+1,c+1)=POINT (col+c,175-row-r)
8035 REM PRINT AT 4+r,c;d(r+1,c+1);
8040 NEXT c
8042 NEXT r
8044 RETURN
8049 REM "Mark digit/batch: POKE the value mark into the attribute byte of the cell for class d, sample b"
8050 LET row=2*d+(b>=32)
8055 LET col=b-32*(b>=32)
8060 LET addr=22528+32*row+col: REM "22528 = 16384+6144"
8070 POKE addr,mark
8090 RETURN
8999 REM "Get timer in minutes: sets time from the 3-byte frame counter at 23672-23674 (3000 counts per minute)"
9000 LET time=65536*PEEK 23674+256*PEEK 23673+PEEK 23672
9002 LET tick=65536*PEEK 23674+256*PEEK 23673+PEEK 23672
9004 REM "The counter can advance between the three PEEKs. Reading the high byte first makes a torn read come out too small, so the larger of two reads is the valid one"
9006 IF tick>time THEN LET time=tick
9008 LET time=time/3000
9010 RETURN
9020 RETURN : REM "Profiling is off. Delete this line to record a timestamp on every GO SUB 9020"
9021 REM "Stores the frame counter in seconds in z(timep) and the caller's tag n$ in z$(timep)"
9022 REM "z() and z$() hold 20 entries; further records are dropped instead of stopping with a subscript error"
9023 IF timep>20 THEN RETURN
9024 REM "Counter is read twice, high byte first, keeping the larger value, so a tick between PEEKs cannot corrupt it"
9025 LET tick=65536*PEEK 23674+256*PEEK 23673+PEEK 23672
9026 LET z(timep)=65536*PEEK 23674+256*PEEK 23673+PEEK 23672
9027 IF tick>z(timep) THEN LET z(timep)=tick
9028 LET z(timep)=z(timep)/50
9029 LET z$(timep)=n$
9030 LET timep=timep+1
9040 RETURN
9200 RETURN : REM "Profiling report is off. Delete this line to print the log and stop"
9205 REM "Prints each recorded tag with the time that passed until the next record"
9210 CLS
9215 FOR i=2 TO timep-1
9220 PRINT z$(i-1);":";(z(i) - z(i-1))
9230 NEXT i
9240 STOP