1+ """
2+ Know more, visit my Python tutorial page: https://morvanzhou.github.io/tutorials/
3+ My Youtube Channel: https://www.youtube.com/user/MorvanZhou
4+
5+ Dependencies:
6+ tensorflow: 1.1.0
7+ matplotlib
8+ numpy
9+ """
10+ import tensorflow as tf
11+ from tensorflow .examples .tutorials .mnist import input_data
12+ import matplotlib .pyplot as plt
13+ from mpl_toolkits .mplot3d import Axes3D
14+ from matplotlib import cm
15+ import numpy as np
16+
17+ tf .set_random_seed (1 )
18+
19+ # Hyper Parameters
20+ BATCH_SIZE = 64
21+ LR = 0.002 # learning rate
22+ N_TEST_IMG = 5
23+
24+ # Mnist digits
25+ mnist = input_data .read_data_sets ('./mnist' , one_hot = False ) # use not one-hotted target data
26+ test_x = mnist .test .images [:200 ]
27+ test_y = mnist .test .labels [:200 ]
28+
29+ # plot one example
30+ print (mnist .train .images .shape ) # (55000, 28 * 28)
31+ print (mnist .train .labels .shape ) # (55000, 10)
32+ plt .imshow (mnist .train .images [0 ].reshape ((28 , 28 )), cmap = 'gray' )
33+ plt .title ('%i' % np .argmax (mnist .train .labels [0 ]))
34+ plt .show ()
35+
36+ # tf placeholder
37+ tf_x = tf .placeholder (tf .float32 , [None , 28 * 28 ]) # value in the range of (0, 1)
38+
39+ # encoder
40+ en0 = tf .layers .dense (tf_x , 128 , tf .nn .tanh )
41+ en1 = tf .layers .dense (en0 , 64 , tf .nn .tanh )
42+ en2 = tf .layers .dense (en1 , 12 , tf .nn .tanh )
43+ encoded = tf .layers .dense (en2 , 3 )
44+
45+ # decoder
46+ de0 = tf .layers .dense (encoded , 12 , tf .nn .tanh )
47+ de1 = tf .layers .dense (de0 , 64 , tf .nn .tanh )
48+ de2 = tf .layers .dense (de1 , 128 , tf .nn .tanh )
49+ decoded = tf .layers .dense (de2 , 28 * 28 , tf .nn .sigmoid )
50+
51+ loss = tf .losses .mean_squared_error (labels = tf_x , predictions = decoded )
52+ train = tf .train .AdamOptimizer (LR ).minimize (loss )
53+
54+ sess = tf .Session ()
55+ sess .run (tf .global_variables_initializer ())
56+
57+ # initialize figure
58+ f , a = plt .subplots (2 , N_TEST_IMG , figsize = (5 , 2 ))
59+ plt .ion () # continuously plot
60+ plt .show ()
61+
62+ # original data (first row) for viewing
63+ view_data = mnist .test .images [:N_TEST_IMG ]
64+ for i in range (N_TEST_IMG ):
65+ a [0 ][i ].imshow (np .reshape (view_data [i ], (28 , 28 )), cmap = 'gray' )
66+ a [0 ][i ].set_xticks (())
67+ a [0 ][i ].set_yticks (())
68+
69+ for step in range (8000 ):
70+ b_x , b_y = mnist .train .next_batch (BATCH_SIZE )
71+ _ , encoded_ , decoded_ , loss_ = sess .run ([train , encoded , decoded , loss ], {tf_x : b_x })
72+
73+ if step % 100 == 0 : # plotting
74+ print ('train loss: %.4f' % loss_ )
75+ # plotting decoded image (second row)
76+ decoded_data = sess .run (decoded , {tf_x : view_data })
77+ for i in range (N_TEST_IMG ):
78+ a [1 ][i ].clear ()
79+ a [1 ][i ].imshow (np .reshape (decoded_data [i ], (28 , 28 )), cmap = 'gray' )
80+ a [1 ][i ].set_xticks (())
81+ a [1 ][i ].set_yticks (())
82+ plt .draw ()
83+ plt .pause (0.01 )
84+
85+ plt .ioff ()
86+
87+ # visualize in 3D plot
88+ view_data = test_x [:200 ]
89+ encoded_data = sess .run (encoded , {tf_x : view_data })
90+ fig = plt .figure (2 )
91+ ax = Axes3D (fig )
92+ X = encoded_data [:, 0 ]
93+ Y = encoded_data [:, 1 ]
94+ Z = encoded_data [:, 2 ]
95+ for x , y , z , s in zip (X , Y , Z , test_y ):
96+ c = cm .rainbow (int (255 * s / 9 ))
97+ ax .text (x , y , z , s , backgroundcolor = c )
98+ ax .set_xlim (X .min (), X .max ())
99+ ax .set_ylim (Y .min (), Y .max ())
100+ ax .set_zlim (Z .min (), Z .max ())
101+ plt .show ()
0 commit comments