Skip to content

Commit 7bb4e13

Browse files
committed
added an implementation of gradient descent
1 parent c9d10c5 commit 7bb4e13

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

chapter_02/gradient_descent.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import math
2+
import random
3+
4+
# our network:
#
# x0 ---w0---|
# x1 ---w1---v--sigmoid->
# x2 ---w2---|
#
# where x0 = 1 is the bias

# the weights of the first (and only) layer: [w0, w1, w2],
# where w0 is the bias weight; trained in place by
# train_by_gradient_descent below
weights = [ 1., 2., 3.]
# the output node
# NOTE(review): this global is never read or written below —
# forward_pass uses its own local v_ instead; kept for the diagram's sake
v = 0.
16+
17+
# sigmoid function
def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + e^-x), stable for all x.

    The naive form math.exp(-x) raises OverflowError for large negative
    x (around x < -745), so for x < 0 we use the algebraically equal
    form e^x / (1 + e^x), whose exponent is always non-positive.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)
20+
21+
def forward_pass(inputs):
    """Return the network output for inputs = [x1, x2].

    Computes sigmoid(w0 + x1*w1 + x2*w2) against the global weights;
    the bias input x0 is fixed at 1, so w0 contributes directly.
    """
    activation = weights[0] + inputs[0] * weights[1] + inputs[1] * weights[2]
    return sigmoid(activation)
25+
26+
def error_func(x, y):
    """Return the squared error (x - y)**2; symmetric in its arguments."""
    diff = x - y
    return diff * diff
28+
29+
def total_error(train_data):
    """Return the mean squared error over all (inputs, target) examples.

    Runs a forward pass for every example and averages the squared
    differences between target and network output.
    """
    squared_errors = [
        error_func(target, forward_pass(inputs))
        for inputs, target in train_data
    ]
    return sum(squared_errors) / len(squared_errors)
38+
39+
# training data for AND
# ([x1, x2], expected result)
train_data_and = [
    ([ 0, 0 ], 0.),
    ([ 0, 1 ], 0.),
    ([ 1, 0 ], 0.),
    ([ 1, 1 ], 1.)
]

# show the untrained network's output and error value for the
# examples (1,0) and (1,1)
for _idx, _label in ((2, "(1,0)"), (3, "(1,1)")):
    _inputs, _target = train_data_and[_idx]
    out = forward_pass(_inputs)
    print("network output for " + _label)
    print(out)
    print("error value for " + _label)
    print(error_func(out, _target))
61+
62+
def train_by_gradient_descent(train_data):
    """Perform one gradient-descent step on a single training example.

    train_data is one ([x1, x2], target) pair. Updates the global
    `weights` list in place using the chain rule:
        ∂err/∂wi = ∂err/∂out * ∂out/∂net * ∂net/∂wi
    """
    inputs, target = train_data

    # network output for this example
    prediction = forward_pass(inputs)

    # ∂err/∂out for err = (target - out)^2 is -2*(target - out).
    # (Defining the error as 1/2*(target - out)^2 would cancel the
    # factor 2, but we keep it here for fun.)
    d_err_d_out = -2 * (target - prediction)

    # ∂out/∂net: derivative of the sigmoid, out * (1 - out)
    d_out_d_net = prediction * (1 - prediction)

    # shared upstream factor of all three weight gradients
    delta = d_err_d_out * d_out_d_net

    # ∂net/∂wi is simply the i-th input; the bias input x0 is fixed at 1.
    # Update each weight with a learning rate of 0.5.
    learning_rate = 0.5
    for i, x_i in enumerate([1, inputs[0], inputs[1]]):
        weights[i] = weights[i] - learning_rate * (delta * x_i)
100+
101+
# let's train our network for 100 steps, each on a randomly chosen
# example from the four AND training pairs (stochastic gradient descent)
for _ in range(0, 100):
    example = train_data_and[random.randint(0, 3)]
    train_by_gradient_descent(example)
print("total error:")
print(total_error(train_data_and))

print("learned weights:")
print(weights)

# show the trained network's output and error value for the
# examples (1,0) and (1,1)
for _idx, _label in ((2, "(1,0)"), (3, "(1,1)")):
    _inputs, _target = train_data_and[_idx]
    out = forward_pass(_inputs)
    print("network output for " + _label)
    print(out)
    print("error value for " + _label)
    print(error_func(out, _target))

0 commit comments

Comments
 (0)