|
| 1 | +''' |
| 2 | +MNIST using Recurrent Neural Network to predict handwritten digits |
| 3 | +In this tutorial, I am going to demonstrate how to use recurrent neural |
| 4 | +network to predict the famous handwritten digits "MNIST". |
| 5 | +The MNIST dataset consists: |
| 6 | +mnist.train: 55000 training images |
| 7 | +mnist.validation: 5000 validation images |
| 8 | +mnist.test: 10000 test images |
| 9 | +Each image is 28 pixels (rows) by 28 pixels (cols). |
| 10 | +''' |
| 11 | + |
| 12 | +import tensorflow as tf |
| 13 | +import numpy as np |
| 14 | +import matplotlib.pyplot as plt |
| 15 | +import argparse |
| 16 | + |
| 17 | +# Useful function for arguments. |
| 18 | +def str2bool(v): |
| 19 | + return v.lower() in ("yes", "true") |
| 20 | + |
| 21 | +# Parser |
| 22 | +parser = argparse.ArgumentParser(description='Creating Classifier') |
| 23 | + |
| 24 | +###################### |
| 25 | +# Optimization Flags # |
| 26 | +###################### |
| 27 | + |
| 28 | +parser.add_argument('--learning_rate', default=0.001, type=float, help='initial learning rate') |
| 29 | +parser.add_argument('--seed', default=111, type=int, help='seed') |
| 30 | + |
| 31 | +################## |
| 32 | +# Training Flags # |
| 33 | +################## |
| 34 | +parser.add_argument('--batch_size', default=128, type=int, help='Batch size for training') |
| 35 | +parser.add_argument('--num_epoch', default=10, type=int, help='Number of training iterations') |
| 36 | +parser.add_argument('--batch_per_log', default=10, type=int, help='Print the log at what number of batches?') |
| 37 | + |
| 38 | +############### |
| 39 | +# Model Flags # |
| 40 | +############### |
| 41 | +parser.add_argument('--hidden_size', default=128, type=int, help='Number of neurons for RNN hodden layer') |
| 42 | + |
| 43 | +# Add all arguments to parser |
| 44 | +args = parser.parse_args() |
| 45 | + |
| 46 | + |
| 47 | +# Reset the graph set the random numbers to be the same using "seed" |
| 48 | +tf.reset_default_graph() |
| 49 | +tf.set_random_seed(args.seed) |
| 50 | +np.random.seed(args.seed) |
| 51 | + |
| 52 | +# Divide 28x28 images to rows of data to feed to RNN as sequantial information |
| 53 | +step_size = 28 |
| 54 | +input_size = 28 |
| 55 | +output_size = 10 |
| 56 | + |
| 57 | +# Input tensors |
| 58 | +X = tf.placeholder(tf.float32, [None, step_size, input_size]) |
| 59 | +y = tf.placeholder(tf.int32, [None]) |
| 60 | + |
| 61 | +# Rnn |
| 62 | +cell = tf.nn.rnn_cell.BasicRNNCell(num_units=args.hidden_size) |
| 63 | +output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32) |
| 64 | + |
| 65 | +# Forward pass and loss calcualtion |
| 66 | +logits = tf.layers.dense(state, output_size) |
| 67 | +cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits) |
| 68 | +loss = tf.reduce_mean(cross_entropy) |
| 69 | + |
| 70 | +# optimizer |
| 71 | +optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate).minimize(loss) |
| 72 | + |
| 73 | +# Prediction |
| 74 | +prediction = tf.nn.in_top_k(logits, y, 1) |
| 75 | +accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32)) |
| 76 | + |
| 77 | +# input data |
| 78 | +from tensorflow.examples.tutorials.mnist import input_data |
| 79 | +mnist = input_data.read_data_sets("MNIST_data/") |
| 80 | + |
| 81 | +# Process MNIST |
| 82 | +X_test = mnist.test.images # X_test shape: [num_test, 28*28] |
| 83 | +X_test = X_test.reshape([-1, step_size, input_size]) |
| 84 | +y_test = mnist.test.labels |
| 85 | + |
| 86 | +# initialize the variables |
| 87 | +init = tf.global_variables_initializer() |
| 88 | + |
| 89 | +# Empty list for tracking |
| 90 | +loss_train_list = [] |
| 91 | +acc_train_list = [] |
| 92 | + |
| 93 | +# train the model |
| 94 | +with tf.Session() as sess: |
| 95 | + sess.run(init) |
| 96 | + n_batches = mnist.train.num_examples // args.batch_size |
| 97 | + for epoch in range(args.num_epoch): |
| 98 | + for batch in range(n_batches): |
| 99 | + X_train, y_train = mnist.train.next_batch(args.batch_size) |
| 100 | + X_train = X_train.reshape([-1, step_size, input_size]) |
| 101 | + sess.run(optimizer, feed_dict={X: X_train, y: y_train}) |
| 102 | + loss_train, acc_train = sess.run( |
| 103 | + [loss, accuracy], feed_dict={X: X_train, y: y_train}) |
| 104 | + loss_train_list.append(loss_train) |
| 105 | + acc_train_list.append(acc_train) |
| 106 | + print('Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format( |
| 107 | + epoch + 1, loss_train, acc_train)) |
| 108 | + loss_test, acc_test = sess.run( |
| 109 | + [loss, accuracy], feed_dict={X: X_test, y: y_test}) |
| 110 | + print('Test Loss: {:.3f}, Test Acc: {:.3f}'.format(loss_test, acc_test)) |
0 commit comments