Skip to content

Commit 5ad6e16

Browse files
committed
add finetuning
1 parent d2bde86 commit 5ad6e16

File tree

5 files changed

+75
-351
lines changed

5 files changed

+75
-351
lines changed
Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
3-
# File: vgg_cifar.py
3+
# File: inception_cifar.py
44
# Author: Qian Ge <geqian1001@gmail.com>
55

66
import os
@@ -12,14 +12,15 @@
1212

1313
sys.path.append('../')
1414
import loader as loader
15-
from src.nets.vgg import VGG_CIFAR10
15+
from src.nets.googlenet import GoogleNet_cifar
1616
from src.helper.trainer import Trainer
1717
from src.helper.evaluator import Evaluator
1818

1919

2020
DATA_PATH = '/home/qge2/workspace/data/dataset/cifar/'
21-
SAVE_PATH = '/home/qge2/workspace/data/out/vgg/cifar/final/'
22-
VGG_PATH = '/home/qge2/workspace/data/pretrain/vgg/vgg19.npy'
21+
# DATA_PATH = '/Users/gq/workspace/Dataset/cifar-10-batches-py/'
22+
SAVE_PATH = '/home/qge2/workspace/data/out/googlenet/cifar/'
23+
PRETRINED_PATH = '/home/qge2/workspace/data/pretrain/inception/googlenet.npy'
2324

2425
def get_args():
2526
parser = argparse.ArgumentParser()
@@ -51,14 +52,14 @@ def train():
5152

5253
pre_trained_path=None
5354
if FLAGS.finetune:
54-
pre_trained_path = VGG_PATH
55-
train_model = VGG_CIFAR10(
55+
pre_trained_path = PRETRINED_PATH
56+
train_model = GoogleNet_cifar(
5657
n_channel=3, n_class=10, pre_trained_path=pre_trained_path,
57-
bn=True, wd=5e-3, trainable=True, sub_vgg_mean=False)
58+
bn=True, wd=0, trainable=True, sub_imagenet_mean=False)
5859
train_model.create_train_model()
5960

60-
valid_model = VGG_CIFAR10(
61-
n_channel=3, n_class=10, bn=True, sub_vgg_mean=False)
61+
valid_model = GoogleNet_cifar(
62+
n_channel=3, n_class=10, bn=True, sub_imagenet_mean=False)
6263
valid_model.create_test_model()
6364

6465
trainer = Trainer(train_model, valid_model, train_data, init_lr=FLAGS.lr)
@@ -71,24 +72,24 @@ def train():
7172
for epoch_id in range(FLAGS.maxepoch):
7273
trainer.train_epoch(sess, keep_prob=FLAGS.keep_prob, summary_writer=writer)
7374
trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer)
74-
saver.save(sess, '{}vgg-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
75-
saver.save(sess, '{}vgg-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
75+
# saver.save(sess, '{}inception-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
76+
# saver.save(sess, '{}inception-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
7677

7778
def evaluate():
7879
FLAGS = get_args()
7980
train_data, valid_data = loader.load_cifar(
8081
cifar_path=DATA_PATH, batch_size=FLAGS.bsize, substract_mean=True)
8182

82-
valid_model = VGG_CIFAR10(
83-
n_channel=3, n_class=10, bn=True, sub_vgg_mean=False)
83+
valid_model = GoogleNet(
84+
n_channel=3, n_class=10, bn=True, sub_imagenet_mean=False)
8485
valid_model.create_test_model()
8586

8687
evaluator = Evaluator(valid_model)
8788

8889
with tf.Session() as sess:
8990
saver = tf.train.Saver()
9091
sess.run(tf.global_variables_initializer())
91-
saver.restore(sess, '{}vgg-cifar-epoch-{}'.format(SAVE_PATH, FLAGS.load))
92+
saver.restore(sess, '{}inception-cifar-epoch-{}'.format(SAVE_PATH, FLAGS.load))
9293
print('training set:', end='')
9394
evaluator.accuracy(sess, train_data)
9495
print('testing set:', end='')

src/models/inception_module.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,6 @@ def inception_layer(conv_11_size, conv_33_reduce_size, conv_33_size,
6161
convpool = L.conv(filter_size=1, out_dim=pool_size,
6262
name='{}_pool_proj'.format(name))
6363

64-
# conv_11 = conv(inputs, 1, conv_11_size, '{}_1x1'.format(name))
65-
66-
# conv_33_reduce = conv(inputs, 1, conv_33_reduce_size,
67-
# '{}_3x3_reduce'.format(name))
68-
# conv_33 = conv(conv_33_reduce, 3, conv_33_size, '{}_3x3'.format(name))
69-
70-
# conv_55_reduce = conv(inputs, 1, conv_55_reduce_size,
71-
# '{}_5x5_reduce'.format(name))
72-
# conv_55 = conv(conv_55_reduce, 5, conv_55_size, '{}_5x5'.format(name))
73-
74-
# pool = max_pool(inputs, '{}_pool'.format(name), stride=1,
75-
# padding='SAME', filter_size=3)
76-
# convpool = conv(pool, 1, pool_size, '{}_pool_proj'.format(name))
7764
output = tf.concat([conv_11, conv_33, conv_55, convpool], 3,
7865
name='{}_concat'.format(name))
7966
layer_dict['cur_input'] = output
@@ -143,7 +130,7 @@ def inception_layers(layer_dict, inputs=None, pretrained_dict=None,
143130

144131
return layer_dict['cur_input']
145132

146-
def inception_fc(layer_dict, n_class, keep_prob, inputs=None,
133+
def inception_fc(layer_dict, n_class, keep_prob=1., inputs=None,
147134
pretrained_dict=None, is_training=True,
148135
bn=False, init_w=None, trainable=True, wd=0):
149136

@@ -161,36 +148,29 @@ def inception_fc(layer_dict, n_class, keep_prob, inputs=None,
161148

162149
return layer_dict['cur_input']
163150

151+
def auxiliary_classifier(layer_dict, n_class, keep_prob=1., inputs=None,
152+
pretrained_dict=None, is_training=True,
153+
bn=False, init_w=None, trainable=True, wd=0):
154+
155+
if inputs is not None:
156+
layer_dict['cur_input'] = inputs
164157

165-
# with arg_scope([inception_layer],
166-
# trainable=self._trainable,
167-
# data_dict=data_dict):
168-
# # inception3a = inception_layer(
169-
# # pool2_lrn, 64, 96, 128, 16, 32, 32, name='inception_3a')
170-
# # inception3b = inception_layer(
171-
# # inception3a, 128, 128, 192, 32, 96, 64, name='inception_3b')
172-
# # pool3 = max_pool(
173-
# # inception3b, 'pool3', padding='SAME', filter_size=3, stride=2)
174-
175-
# # inception4a = inception_layer(
176-
# # pool3, 192, 96, 208, 16, 48, 64, name='inception_4a')
177-
# # inception4b = inception_layer(
178-
# # inception4a, 160, 112, 224, 24, 64, 64, name='inception_4b')
179-
# # inception4c = inception_layer(
180-
# # inception4b, 128, 128, 256, 24, 64, 64, name='inception_4c')
181-
# # inception4d = inception_layer(
182-
# # inception4c, 112, 144, 288, 32, 64, 64, name='inception_4d')
183-
# # inception4e = inception_layer(
184-
# # inception4d, 256, 160, 320, 32, 128, 128, name='inception_4e')
185-
# # pool4 = max_pool(
186-
# # inception4e, 'pool4', padding='SAME', filter_size=3, stride=2)
187-
188-
# inception5a = inception_layer(
189-
# pool4, 256, 160, 320, 32, 128, 128, name='inception_5a')
190-
# inception5b = inception_layer(
191-
# inception5a, 384, 192, 384, 48, 128, 128, name='inception_5b')
192-
193-
158+
layer_dict['cur_input'] = tf.layers.average_pooling2d(
159+
inputs=layer_dict['cur_input'],
160+
pool_size=5, strides=3,
161+
padding='valid', name='averagepool')
194162

163+
arg_scope = tf.contrib.framework.arg_scope
164+
with arg_scope([L.conv], layer_dict=layer_dict, pretrained_dict=pretrained_dict,
165+
bn=bn, init_w=init_w, trainable=trainable,
166+
is_training=is_training, wd=wd, add_summary=False):
195167

168+
L.conv(1, 128, name='conv', stride=1, nl=tf.nn.relu)
169+
L.conv(4, 1024, name='fc_1', stride=1, padding='VALID')
170+
L.drop_out(layer_dict, is_training, keep_prob=keep_prob)
171+
L.conv(1, 1024, name='fc_2', stride=1, padding='VALID', nl=tf.nn.relu)
172+
L.drop_out(layer_dict, is_training, keep_prob=keep_prob)
173+
L.conv(1, n_class, name='classifier', stride=1, padding='VALID')
174+
layer_dict['cur_input'] = tf.squeeze(layer_dict['cur_input'], [1, 2])
175+
return layer_dict['cur_input']
196176

src/models/vgg_module.py

Lines changed: 0 additions & 154 deletions
This file was deleted.

src/nets/googlenet.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def create_train_model(self):
6060
with tf.variable_scope('fc_layers', reuse=tf.AUTO_REUSE):
6161
self.layers['logits'] = self._fc_layers(self.layers['inception_out'])
6262

63+
with tf.variable_scope('auxiliary_classifier_0'):
64+
self.layers['auxiliary_logits_0'] = self._auxiliary_classifier(
65+
self.layers['inception_4a'])
66+
with tf.variable_scope('auxiliary_classifier_1'):
67+
self.layers['auxiliary_logits_1'] = self._auxiliary_classifier(
68+
self.layers['inception_4d'])
69+
6370
def create_test_model(self):
6471
self.set_is_training(is_training=False)
6572
self._create_test_input()
@@ -101,24 +108,50 @@ def _fc_layers(self, inputs):
101108
is_training=self.is_training, wd=self._wd)
102109
return fc_out
103110

111+
def _auxiliary_classifier(self, inputs):
112+
logits = auxiliary_classifier(
113+
layer_dict=self.layers, n_class=self.n_class, keep_prob=self.keep_prob,
114+
inputs=inputs, pretrained_dict=None, is_training=self.is_training,
115+
bn=self._bn, init_w=INIT_W, trainable=self._trainable, wd=self._wd)
116+
return logits
117+
104118
def _get_loss(self):
105119
with tf.name_scope('loss'):
106120
labels = self.label
107-
logits = self.layers['gap_out']
108-
# logits = tf.squeeze(logits, axis=1)
121+
logits = self.layers['logits']
122+
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
123+
labels=labels,
124+
logits=logits,
125+
name='cross_entropy')
126+
auxilarity_loss = self._get_auxiliary_loss(0) + self._get_auxiliary_loss(1)
127+
return tf.reduce_mean(cross_entropy) + 0.3 * auxilarity_loss
128+
129+
def _get_auxiliary_loss(self, loss_id):
130+
with tf.name_scope('auxilarity_loss_{}'.format(loss_id)):
131+
labels = self.label
132+
logits = self.layers['auxiliary_logits_{}'.format(loss_id)]
109133
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
110134
labels=labels,
111135
logits=logits,
112136
name='cross_entropy')
113-
return tf.reduce_mean(cross_entropy)
137+
return tf.reduce_mean(cross_entropy)
114138

115139
def _get_optimizer(self):
116140
return tf.train.AdamOptimizer(self.lr)
117141

118142
def get_accuracy(self):
119143
with tf.name_scope('accuracy'):
120-
prediction = tf.argmax(self.layers['gap_out'], axis=1)
144+
prediction = tf.argmax(self.layers['logits'], axis=1)
121145
correct_prediction = tf.equal(prediction, self.label)
122146
return tf.reduce_mean(
123147
tf.cast(correct_prediction, tf.float32),
124148
name = 'result')
149+
150+
class GoogleNet_cifar(GoogleNet):
151+
def _fc_layers(self, inputs):
152+
fc_out = module.inception_fc(
153+
layer_dict=self.layers, n_class=self.n_class, keep_prob=self.keep_prob,
154+
inputs=inputs, pretrained_dict=None,
155+
bn=self._bn, init_w=INIT_W, trainable=self._trainable,
156+
is_training=self.is_training, wd=self._wd)
157+
return fc_out

0 commit comments

Comments
 (0)