1+ """
2+ Know more, visit my Python tutorial page: https://morvanzhou.github.io/tutorials/
3+ My Youtube Channel: https://www.youtube.com/user/MorvanZhou
4+
5+ Dependencies:
6+ tensorflow: 1.1.0
7+ matplotlib
8+ numpy
9+ """
10+ import tensorflow as tf
11+ import numpy as np
12+ import matplotlib .pyplot as plt
13+
14+
15+ # Hyper Parameters
16+ TIME_STEP = 10 # rnn time step
17+ INPUT_SIZE = 1 # rnn input size
18+ CELL_SIZE = 32 # rnn cell size
19+ LR = 0.02 # learning rate
20+
21+ # show data
22+ steps = np .linspace (0 , np .pi * 2 , 100 , dtype = np .float32 )
23+ x_np = np .sin (steps ) # float32 for converting torch FloatTensor
24+ y_np = np .cos (steps )
25+ plt .plot (steps , y_np , 'r-' , label = 'target (cos)' )
26+ plt .plot (steps , x_np , 'b-' , label = 'input (sin)' )
27+ plt .legend (loc = 'best' )
28+ plt .show ()
29+
30+ # tensorflow placeholders
31+ tf_x = tf .placeholder (tf .float32 , [None , TIME_STEP , INPUT_SIZE ]) # shape(batch, 5, 1)
32+ tf_y = tf .placeholder (tf .float32 , [None , TIME_STEP , INPUT_SIZE ]) # input y
33+
34+ # RNN
35+ rnn_cell = tf .contrib .rnn .BasicRNNCell (num_units = CELL_SIZE )
36+ init_s = rnn_cell .zero_state (batch_size = 1 , dtype = tf .float32 ) # very first hidden state
37+ outputs , final_s = tf .nn .dynamic_rnn (
38+ rnn_cell , # cell you have chosen
39+ tf_x , # input
40+ initial_state = init_s , # the initial hidden state
41+ time_major = False , # False: (batch, time step, input); True: (time step, batch, input)
42+ )
43+ outs2D = tf .reshape (outputs , [- 1 , CELL_SIZE ]) # reshape 3D output to 2D for fully connected layer
44+ net_outs2D = tf .layers .dense (outs2D , INPUT_SIZE )
45+ outs = tf .reshape (net_outs2D , [- 1 , TIME_STEP , INPUT_SIZE ]) # reshape back to 3D
46+
47+ loss = tf .losses .mean_squared_error (labels = tf_y , predictions = outs ) # compute cost
48+ train_op = tf .train .AdamOptimizer (LR ).minimize (loss )
49+
50+ sess = tf .Session ()
51+ init_op = tf .group (tf .global_variables_initializer ())
52+ sess .run (init_op ) # initialize var in graph
53+
54+ plt .figure (1 , figsize = (12 , 5 ))
55+ plt .ion () # continuously plot
56+ plt .show ()
57+
58+ for step in range (60 ):
59+ start , end = step * np .pi , (step + 1 )* np .pi # time steps
60+ # use sin predicts cos
61+ steps = np .linspace (start , end , TIME_STEP )
62+ x = np .sin (steps )[np .newaxis , :, np .newaxis ] # shape (batch, time_step, input_size)
63+ y = np .cos (steps )[np .newaxis , :, np .newaxis ]
64+ if 'final_s_' not in globals (): # first state, no any hidden state
65+ feed_dict = {tf_x : x , tf_y : y }
66+ else : # has hidden state, so pass it to rnn
67+ feed_dict = {tf_x : x , tf_y : y , init_s : final_s_ }
68+ _ , pred_ , final_s_ = sess .run ([train_op , outs , final_s ], feed_dict ) # train
69+
70+ # plotting
71+ plt .plot (steps , y .flatten (), 'r-' )
72+ plt .plot (steps , pred_ .flatten (), 'b-' )
73+ plt .ylim ((- 1.2 , 1.2 ))
74+ plt .draw ()
75+ plt .pause (0.05 )
76+
77+ plt .ioff ()
78+ plt .show ()
0 commit comments