forked from liuzhuang13/DenseNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: DenseConnectLayer.lua
More file actions
152 lines (130 loc) · 5.5 KB
/
DenseConnectLayer.lua
File metadata and controls
152 lines (130 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
require 'nn'
require 'cudnn'
require 'cunn'
-- Tag `module` with a gradInput-sharing key and return it. Modules carrying
-- the same key can have their gradInput buffers pooled by a memory
-- optimization pass elsewhere in the project; this helper only records the
-- key, it performs no sharing itself. A missing/false key is a caller bug.
local ShareGradInput = function(module, key)
   assert(key)
   module['__shareGradInputKey'] = key
   return module
end
--------------------------------------------------------------------------------
-- Standard densely connected layer (memory inefficient)
--------------------------------------------------------------------------------
-- Build one standard DenseNet layer: a BN-ReLU(-1x1Conv-BN-ReLU)-3x3Conv
-- transform producing opt.growthRate feature maps, whose output is
-- concatenated with the layer input along the channel dimension (dim 2).
-- Memory inefficient: nn.Concat re-materializes all previous feature maps.
function DenseConnectLayerStandard(nChannels, opt)
   local transform = nn.Sequential()
      :add(ShareGradInput(cudnn.SpatialBatchNormalization(nChannels), 'first'))
      :add(cudnn.ReLU(true))

   local nInner = nChannels
   if opt.bottleneck then
      -- 1x1 bottleneck compresses to 4*growthRate maps before the 3x3 conv
      transform:add(cudnn.SpatialConvolution(nInner, 4 * opt.growthRate, 1, 1, 1, 1, 0, 0))
      nInner = 4 * opt.growthRate
      if opt.dropRate > 0 then transform:add(nn.Dropout(opt.dropRate)) end
      transform:add(cudnn.SpatialBatchNormalization(nInner))
      transform:add(cudnn.ReLU(true))
   end
   transform:add(cudnn.SpatialConvolution(nInner, opt.growthRate, 3, 3, 1, 1, 1, 1))
   if opt.dropRate > 0 then transform:add(nn.Dropout(opt.dropRate)) end

   -- Concat(2) stacks {identity(input), transform(input)} along channels.
   local concat = nn.Concat(2)
   concat:add(nn.Identity())
   concat:add(transform)
   return nn.Sequential():add(concat)
end
--------------------------------------------------------------------------------
-- Customized densely connected layer (memory efficient)
--------------------------------------------------------------------------------
local DenseConnectLayerCustom, parent = torch.class('nn.DenseConnectLayerCustom', 'nn.Container')

-- Memory-efficient dense layer. Instead of nn.Concat, it keeps the layer
-- split in two halves: net1 (shared BN-ReLU head) and net2 (conv tail),
-- and passes previous outputs through by reference (see updateOutput).
function DenseConnectLayerCustom:__init(nChannels, opt)
   parent.__init(self)
   self.train = true
   self.opt = opt

   local growthRate = opt.growthRate

   -- Pre-activation head; its gradInput buffer is shared via the 'first' key.
   local head = nn.Sequential()
   head:add(ShareGradInput(cudnn.SpatialBatchNormalization(nChannels), 'first'))
   head:add(cudnn.ReLU(true))

   -- Convolutional tail producing growthRate new feature maps.
   local tail = nn.Sequential()
   local nInner = nChannels
   if opt.bottleneck then
      tail:add(cudnn.SpatialConvolution(nInner, 4 * growthRate, 1, 1, 1, 1, 0, 0))
      nInner = 4 * growthRate
      tail:add(cudnn.SpatialBatchNormalization(nInner))
      tail:add(cudnn.ReLU(true))
   end
   tail:add(cudnn.SpatialConvolution(nInner, growthRate, 3, 3, 1, 1, 1, 1))
   self.net1 = head
   self.net2 = tail

   -- Contiguous concatenation buffer for the outputs of previous layers.
   self.input_c = torch.Tensor():type(opt.tensorType)
   -- Snapshot of BatchNorm running statistics: when optMemory=4 the backward
   -- pass re-forwards net1, and these copies let us undo that second update.
   self.saved_bn_running_mean = torch.Tensor():type(opt.tensorType)
   self.saved_bn_running_var = torch.Tensor():type(opt.tensorType)

   self.gradInput = {}
   self.output = {}
   self.modules = {self.net1, self.net2}
end
-- Forward pass. `input` is either a single tensor (first layer of a block)
-- or a table of all preceding layers' outputs. Returns a table holding the
-- pass-through inputs followed by the growthRate new feature maps, so later
-- layers see every previous output without re-copying tensors.
function DenseConnectLayerCustom:updateOutput(input)
   if type(input) ~= 'table' then
      self.output[1] = input
      self.output[2] = self.net2:forward(self.net1:forward(input))
   else
      -- Alias previous outputs into self.output by reference (no copy)...
      for i = 1, #input do
         self.output[i] = input[i]
      end
      -- ...then concatenate them along the channel dim (2) into one
      -- contiguous buffer for the BN-ReLU head + conv tail.
      torch.cat(self.input_c, input, 2)
      self.net1:forward(self.input_c)
      self.output[#input+1] = self.net2:forward(self.net1.output)
   end
   if self.opt.optMemory == 4 then
      -- optMemory=4 re-forwards net1 during backward, which would update the
      -- BatchNorm running stats a second time; save them now so
      -- updateGradInput can restore them afterwards.
      local running_mean, running_var = self.net1:get(1).running_mean, self.net1:get(1).running_var
      self.saved_bn_running_mean:resizeAs(running_mean):copy(running_mean)
      self.saved_bn_running_var:resizeAs(running_var):copy(running_var)
   end
   return self.output
end
-- Backward pass (gradients w.r.t. input). Note this ACCUMULATES in place
-- into the gradOutput tensors (self.gradInput aliases them), which is what
-- makes the implementation memory efficient.
function DenseConnectLayerCustom:updateGradInput(input, gradOutput)
   if type(input) ~= 'table' then
      self.gradInput = gradOutput[1]
      -- With optMemory=4 net1's activations were discarded; recompute them
      -- before backpropagating through net2.
      if self.opt.optMemory == 4 then self.net1:forward(input) end
      self.net2:updateGradInput(self.net1.output, gradOutput[2])
      self.gradInput:add(self.net1:updateGradInput(input, self.net2.gradInput))
   else
      torch.cat(self.input_c, input, 2)
      if self.opt.optMemory == 4 then self.net1:forward(self.input_c) end
      -- Only the last gradOutput entry flows through net2/net1; the others
      -- are pass-through gradients for the aliased inputs.
      self.net2:updateGradInput(self.net1.output, gradOutput[#gradOutput])
      self.net1:updateGradInput(self.input_c, self.net2.gradInput)
      -- Split net1.gradInput (over the concatenated channels) back into
      -- per-input slices and add each onto its pass-through gradient.
      local nC = 1
      for i = 1, #input do
         self.gradInput[i] = gradOutput[i]
         self.gradInput[i]:add(self.net1.gradInput:narrow(2,nC,input[i]:size(2)))
         nC = nC + input[i]:size(2)
      end
   end
   if self.opt.optMemory == 4 then
      -- Undo the BatchNorm running-stat update caused by the re-forward above.
      self.net1:get(1).running_mean:copy(self.saved_bn_running_mean)
      self.net1:get(1).running_var:copy(self.saved_bn_running_var)
   end
   return self.gradInput
end
-- Accumulate parameter gradients for both halves of the layer, scaled by
-- `scale` (default 1). net1's input is the concatenated buffer when the
-- layer received a table, otherwise the raw input tensor.
function DenseConnectLayerCustom:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   self.net2:accGradParameters(self.net1.output, gradOutput[#gradOutput], scale)
   local headInput = (type(input) == 'table') and self.input_c or input
   self.net1:accGradParameters(headInput, self.net2.gradInput, scale)
end
-- Pretty-print the layer in nn.Concat's tree style: both internal nets read
-- {input}; their results are concatenated into {output}.
function DenseConnectLayerCustom:__tostring__()
   local tab = ' '
   local line = '\n'
   -- Renamed from `next`: that local shadowed Lua's builtin `next` function.
   local branchNext = ' |`-> '
   local lastNext = ' `-> '
   local ext = ' | '
   local extlast = ' '
   local last = ' ... -> '
   local str = 'DenseConnectLayerCustom'
   str = str .. ' {' .. line .. tab .. '{input}'
   for i = 1, #self.modules do
      if i == #self.modules then
         -- Last branch gets the closing elbow and plain continuation indent.
         str = str .. line .. tab .. lastNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. extlast)
      else
         str = str .. line .. tab .. branchNext .. '(' .. i .. '): ' .. tostring(self.modules[i]):gsub(line, line .. tab .. ext)
      end
   end
   str = str .. line .. tab .. last .. '{output}'
   str = str .. line .. '}'
   return str
end