Skip to content

Commit 3fb2d34

Browse files
committed
added atrous_conv2d
1 parent 123db75 commit 3fb2d34

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ language: python
33
python:
44
- "2.7"
55
- "3.5"
6+
cache: pip
67
install:
78

89
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then

tensorgraph/layers/conv.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,50 @@ def _variables(self):
123123
return [self.filter, self.b]
124124

125125

126+
class Atrous_Conv2D(Template):
127+
128+
@Template.init_name_scope
129+
def __init__(self, input_channels, num_filters, rate=1, kernel_size=(3,3),
130+
filter=None, b=None, padding='VALID', stddev=0.1):
131+
'''
132+
PARAM:
133+
padding: "SAME", same as input shape
134+
"VALID", 0 padding, output shape = h - 2*((k+(k-1)*(r-1))/2)
135+
where k: kernel_size, r: rate and effective filter
136+
size is k+(k-1)*(r-1) with (r-1) zeros inserted
137+
between every horizontal and vertical neighbouring
138+
filter values.
139+
'''
140+
self.input_channels = input_channels
141+
self.num_filters = num_filters
142+
self.kernel_size = kernel_size
143+
self.rate = rate
144+
self.padding = padding
145+
146+
self.filter_shape = self.kernel_size + (self.input_channels, self.num_filters)
147+
self.filter = filter
148+
if self.filter is None:
149+
self.filter = tf.Variable(tf.random_normal(self.filter_shape, stddev=stddev),
150+
name=self.__class__.__name__ + '_filter')
151+
152+
self.b = b
153+
if self.b is None:
154+
self.b = tf.Variable(tf.zeros([self.num_filters]), name=self.__class__.__name__ + '_b')
155+
156+
157+
def _train_fprop(self, state_below):
158+
'''state_below: (b, h, w, c)
159+
'''
160+
conv_out = tf.nn.atrous_conv2d(state_below, self.filter, rate=self.rate, padding=self.padding)
161+
return tf.nn.bias_add(conv_out, self.b)
162+
163+
164+
@property
165+
def _variables(self):
166+
return [self.filter, self.b]
167+
168+
169+
126170
class Depthwise_Conv2D(Template):
127171

128172
@Template.init_name_scope

test/layer_conv_test.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
import tensorflow as tf
33
import tensorgraph as tg
4-
from tensorgraph.layers import Depthwise_Conv2D
4+
from tensorgraph.layers import Depthwise_Conv2D, Atrous_Conv2D
55
import numpy as np
66

77
def test_Depthwise_Conv2d():
@@ -19,5 +19,41 @@ def test_Depthwise_Conv2d():
1919
print(out.shape)
2020

2121

22+
def test_Atrous_Conv2d():
23+
24+
seq = tg.Sequential()
25+
seq.add(Atrous_Conv2D(input_channels=5, num_filters=2, kernel_size=(3, 3), rate=3, padding='SAME'))
26+
27+
h, w, c = 100, 300, 5
28+
X_ph = tf.placeholder('float32', [None, h, w, c])
29+
30+
y_sb = seq.train_fprop(X_ph)
31+
with tf.Session() as sess:
32+
init = tf.global_variables_initializer()
33+
sess.run(init)
34+
out = sess.run(y_sb, feed_dict={X_ph:np.random.rand(32, h, w, c)})
35+
print(out.shape)
36+
assert out.shape[1] == h and out.shape[2] == w
37+
38+
39+
seq = tg.Sequential()
40+
r = 2
41+
k = 5
42+
seq.add(Atrous_Conv2D(input_channels=5, num_filters=2, kernel_size=(k, k), rate=r, padding='VALID'))
43+
44+
h, w, c = 100, 300, 5
45+
X_ph = tf.placeholder('float32', [None, h, w, c])
46+
47+
y_sb = seq.train_fprop(X_ph)
48+
with tf.Session() as sess:
49+
init = tf.global_variables_initializer()
50+
sess.run(init)
51+
out = sess.run(y_sb, feed_dict={X_ph:np.random.rand(32, h, w, c)})
52+
print(out.shape)
53+
assert out.shape[1] == h - 2*((k+(k-1)*(r-1))/2), out.shape[2] == w - 2*((w+(w-1)*(r-1))/2)
54+
55+
56+
2257
if __name__ == '__main__':
23-
test_Depthwise_Conv2d()
58+
# test_Depthwise_Conv2d()
59+
test_Atrous_Conv2d()

0 commit comments

Comments
 (0)