AnalyticsDojo

38. Neural Networks

  • This was adopted from the PyTorch Tutorials.

  • http://pytorch.org/tutorials/beginner/pytorch_with_examples.html

38.1. Neural Networks

  • Neural networks are the foundation of deep learning, which has revolutionized the field of machine learning.

In the mathematical theory of artificial neural networks, the universal approximation theorem states [1] that a feed-forward network with a single hidden layer containing a finite number of neurons (i.e., a multilayer perceptron) can approximate continuous functions on compact subsets of Rⁿ, under mild assumptions on the activation function.
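
For intuition, here is a small sketch (an addition, not part of the original tutorial) showing that a single hidden layer of two ReLU units can represent the continuous function |x| exactly, since relu(x) + relu(-x) = |x|:

import numpy as np

#Two hidden ReLU units computing relu(x) and relu(-x); summing them yields |x|.
x_demo = np.linspace(-2, 2, 9).reshape(-1, 1)
w1_demo = np.array([[1.0, -1.0]])
w2_demo = np.array([[1.0], [1.0]])
y_demo = np.maximum(x_demo.dot(w1_demo), 0).dot(w2_demo)
print(np.allclose(y_demo, np.abs(x_demo)))  # True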

38.1.1. Generate Fake Data

  • D_in is the number of dimensions of an input variable.

  • D_out is the number of dimensions of an output variable.

  • Here we are learning some special “fake” data that represents the XOR problem.

  • Here, the dependent variable is 1 if either the first or second input variable is 1 (a quick check of this appears after the output below).

# -*- coding: utf-8 -*-
import numpy as np

#These are our independent (x) and dependent (y) variables. 
x = np.array([ [0,0,0],[1,0,0],[0,1,0],[0,0,0] ])
y = np.array([[0,1,1,0]]).T
print("Input data:\n",x,"\n Output data:\n",y)
Input data:
 [[0 0 0]
 [1 0 0]
 [0 1 0]
 [0 0 0]] 
 Output data:
 [[0]
 [1]
 [1]
 [0]]
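
As a quick check of that description, the sketch below (an addition, reusing the x array defined above) recomputes the dependent variable as the logical OR of the first two input columns; it reproduces y exactly:

#Sketch: y should equal "first OR second input variable is 1".
print(np.logical_or(x[:, 0], x[:, 1]).astype(int).reshape(-1, 1))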

38.1.2. A Simple Neural Network

  • Here we are going to build a neural network with a single hidden layer of 2 units (a shape check follows the initialization code below).

np.random.seed(seed=83832)
#D_in is the number of input variables. 
#H is the hidden dimension.
#D_out is the number of dimensions for the output. 
D_in, H, D_out = 3, 2, 1

# Randomly initialize the weights of our single-hidden-layer network.
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)
#Note: this bias term is initialized here but never used in the forward pass below.
bias = np.random.randn(H, 1)
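
Before training, a quick shape check (a sketch, reusing the arrays defined above) makes the architecture explicit: 3 inputs feed 2 hidden units, which feed 1 output.

#Sketch: verify the layer shapes of the single-hidden-layer network.
print(w1.shape)          # (3, 2): input-to-hidden weights
print(w2.shape)          # (2, 1): hidden-to-output weights
print(x.dot(w1).shape)   # (4, 2): hidden pre-activations for the 4 training rows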

38.1.3. Learn the Appropriate Weights via Backpropagation

  • The learning rate controls how quickly the model adjusts its parameters.

# -*- coding: utf-8 -*-

learning_rate = .01
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)

    #A ReLU is just the activation: it keeps positive values and zeroes out negatives.
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2
0 10.6579261591
1 9.10203339893
2 7.92822558061
3 7.01603070961
4 6.28979819918
...
497 1.21172160408e-08
498 1.16502654283e-08
499 1.12013092852e-08
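
One way to build confidence in the hand-derived gradients above is a finite-difference check. The sketch below (an addition; it assumes x, y, w1, and w2 from the cells above are still in scope) compares the analytic gradient of a single weight with a numerical estimate:

#Sketch: numerical gradient check for one entry of w1.
eps = 1e-5

def loss_fn(w1_, w2_):
    #Same forward pass and squared-error loss as in the training loop.
    return np.square(np.maximum(x.dot(w1_), 0).dot(w2_) - y).sum()

w1_plus, w1_minus = w1.copy(), w1.copy()
w1_plus[0, 0] += eps
w1_minus[0, 0] -= eps
numeric = (loss_fn(w1_plus, w2) - loss_fn(w1_minus, w2)) / (2 * eps)

#Analytic gradient of the same entry, using the backprop rules from the loop.
h = x.dot(w1)
grad_h = (2.0 * (np.maximum(h, 0).dot(w2) - y)).dot(w2.T)
grad_h[h < 0] = 0
analytic = x.T.dot(grad_h)[0, 0]
print(numeric, analytic)  # the two estimates should agree closely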

#Fully connected forward pass using the learned weights to predict y.


pred = np.maximum(x.dot(w1),0).dot(w2)

print (pred, "\n", y)
[[ 0.        ]
 [ 0.99992661]
 [ 1.00007337]
 [ 0.        ]] 
 [[0]
 [1]
 [1]
 [0]]
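
For comparison with the PyTorch tutorial this notebook was adapted from, here is a minimal sketch of the same single-hidden-layer network written with PyTorch tensors and autograd (an addition; it assumes torch is installed and reuses the x and y arrays from above):

import torch

xt = torch.tensor(x, dtype=torch.float32)
yt = torch.tensor(y, dtype=torch.float32)
w1t = torch.randn(3, 2, requires_grad=True)
w2t = torch.randn(2, 1, requires_grad=True)

for t in range(500):
    #Forward pass: linear, ReLU, linear.
    y_pred = xt.mm(w1t).clamp(min=0).mm(w2t)
    loss = (y_pred - yt).pow(2).sum()
    #Autograd computes the gradients of the loss with respect to w1t and w2t.
    loss.backward()
    with torch.no_grad():
        w1t -= 1e-2 * w1t.grad
        w2t -= 1e-2 * w2t.grad
        w1t.grad.zero_()
        w2t.grad.zero_()
print(y_pred.detach())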

38.1.4. Hidden Layers are Often Viewed as Unknown

  • A hidden layer is just a weight matrix (followed by an activation); we can inspect the learned values directly, as in the sketch at the end of this section.

#The learned input-to-hidden weight matrix.
w1
array([[-0.20401151,  1.01377406],
       [-0.10186284,  1.01392285],
       [ 1.07856887,  0.01873049]])
w2
array([[ 0.49346731],
       [ 0.98634069]])
# The ReLU just zeroes out the negative numbers.
h_relu
array([[ 0.        ,  0.        ,  0.        ],
       [ 0.72108356,  0.        ,  0.        ],
       [ 0.72753913,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ]])
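
To see the hidden layer as nothing more than a weight matrix applied to the inputs, the sketch below (an addition, reusing the learned w1 from above) recomputes the hidden activations directly; each row is an example and each column a hidden unit:

#Sketch: recompute the hidden representation from the learned weights.
print(np.maximum(x.dot(w1), 0))  # shape (4, 2) for the H = 2 network defined above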