
Commit f4ee12c

Morvan Zhou (MorvanZhou) authored and committed
update
1 parent acbb46f · commit f4ee12c

File tree

1 file changed: 6 additions, 8 deletions


tutorial-contents/401_CNN.py

Lines changed: 6 additions & 8 deletions
@@ -15,7 +15,6 @@
 tf.set_random_seed(1)
 np.random.seed(1)

-EPOCH = 1               # train the training data n times, to save time, we just train 1 epoch
 BATCH_SIZE = 50
 LR = 0.001              # learning rate

@@ -24,7 +23,7 @@
 test_y = mnist.test.labels[:2000]

 tf_x = tf.placeholder(tf.float32, [None, 28*28])/255.   # normalize to range (0, 1)
-image = tf.reshape(tf_x, [-1, 28, 28, 1])               # (batch, height, width, channel)
+image = tf.reshape(tf_x, [-1, 28, 28, 1])               # (batch, height, width, channel)
 tf_y = tf.placeholder(tf.int32, [None, 10])             # input y

 # CNN
@@ -35,28 +34,27 @@
     strides=1,
     padding='same',
     activation=tf.nn.relu
-)                                                        # -> (28, 28, 16)
+)                                                        # -> (28, 28, 16)
 pool1 = tf.layers.max_pooling2d(
     conv1,
     pool_size=2,
     strides=2,
-)                                                        # -> (14, 14, 16)
+)                                                        # -> (14, 14, 16)
 conv2 = tf.layers.conv2d(pool1, 32, 5, 1, 'same', activation=tf.nn.relu)    # -> (14, 14, 32)
 pool2 = tf.layers.max_pooling2d(conv2, 2, 2)             # -> (7, 7, 32)
-flat = tf.reshape(pool2, [-1, 7*7*32])                   # -> (7*7*32, )
-output = tf.layers.dense(flat, 10)                       # output layer
+flat = tf.reshape(pool2, [-1, 7*7*32])                   # -> (7*7*32, )
+output = tf.layers.dense(flat, 10)                       # output layer

 loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)   # compute cost
 train_op = tf.train.AdamOptimizer(LR).minimize(loss)

 accuracy = tf.metrics.accuracy(          # return (acc, update_op), and create 2 local variables
     labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]

-sess = tf.Session()     # control training and others
+sess = tf.Session()
 init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())     # the local var is for accuracy_op
 sess.run(init_op)     # initialize var in graph

-
 for step in range(600):
     b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
     _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
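The shape comments kept in this diff can be checked by building just the layer stack. The following is a minimal sketch, not part of the commit, assuming the same TF 1.x API the tutorial uses (tf.layers, tf.placeholder); the variable names are illustrative. Two stride-2 max-pools take 28 -> 14 -> 7, so the flattened vector has 7*7*32 = 1568 features per image.

import tensorflow as tf

# sketch only: reproduce the conv/pool stack from the diff to confirm the shape comments
x  = tf.placeholder(tf.float32, [None, 28, 28, 1])                    # (batch, height, width, channel)
c1 = tf.layers.conv2d(x, 16, 5, 1, 'same', activation=tf.nn.relu)     # -> (?, 28, 28, 16)
p1 = tf.layers.max_pooling2d(c1, 2, 2)                                # -> (?, 14, 14, 16)
c2 = tf.layers.conv2d(p1, 32, 5, 1, 'same', activation=tf.nn.relu)    # -> (?, 14, 14, 32)
p2 = tf.layers.max_pooling2d(c2, 2, 2)                                # -> (?, 7, 7, 32)
flat = tf.reshape(p2, [-1, 7*7*32])                                   # -> (?, 1568)
print(p2.shape, flat.shape)                                           # (?, 7, 7, 32) (?, 1568)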

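The grouped initializer kept in the diff, tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()), is needed because tf.metrics.accuracy returns (accuracy, update_op) and keeps its running total/count in local variables. Below is a minimal, self-contained sketch of that pattern, again assuming TF 1.x; the placeholder names are illustrative and not from the tutorial file.

import numpy as np
import tensorflow as tf

labels = tf.placeholder(tf.int64, [None])
preds  = tf.placeholder(tf.int64, [None])
acc, acc_update = tf.metrics.accuracy(labels=labels, predictions=preds)   # creates 2 local variables

with tf.Session() as sess:
    # the local variables (total, count) must be initialized or running the metric fails
    sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(acc_update, {labels: np.array([0, 1, 1]), preds: np.array([0, 1, 0])})   # fold in one batch
    print(sess.run(acc))   # running accuracy over all batches seen so far: ~0.667

The tutorial indexes the returned pair with [1], so each run of its accuracy op both updates the running counts and returns the updated value.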