I am trying to implement a batch normalization layer in TensorFlow. I have no problem running the train step, which uses tf.nn.moments to get the mean and variance.
For test time, I'd like to set up an exponential moving average to track the mean and variance. I am trying to do it like this:
def batch_normalized_linear_layer(state_below, scope_name, n_inputs, n_outputs, stddev, wd, eps=.0001):
    with tf.variable_scope(scope_name) as scope:
        weight = _variable_with_weight_decay(
            "weights", shape=[n_inputs, n_outputs],
            stddev=stddev, wd=wd
        )
        act = tf.matmul(state_below, weight)
        # get moments
        act_mean, act_variance = tf.nn.moments(act, [0])
        # get mean and variance variables
        mean = _variable_on_cpu('bn_mean', [n_outputs], tf.constant_initializer(0.0))
        variance = _variable_on_cpu('bn_variance', [n_outputs], tf.constant_initializer(1.0))
        # assign the moments
        assign_mean = mean.assign(act_mean)
        assign_variance = variance.assign(act_variance)
        act_bn = tf.mul((act - mean), tf.rsqrt(variance + eps), name=scope.name+"_bn")
        beta = _variable_on_cpu("beta", [n_outputs], tf.constant_initializer(0.0))
        gamma = _variable_on_cpu("gamma", [n_outputs], tf.constant_initializer(1.0))
        bn = tf.add(tf.mul(act_bn, gamma), beta)
        output = tf.nn.relu(bn, name=scope.name)
        _activation_summary(output)
    return output, mean, variance
Where _variable_on_cpu is defined as:
def _variable_on_cpu(name, shape, initializer):
    """Helper to create a Variable stored on CPU memory.
    Args:
        name: name of the variable
        shape: list of ints
        initializer: initializer for Variable
    Returns:
        Variable Tensor
    """
    with tf.device('/cpu:0'):
        var = tf.get_variable(name, shape, initializer=initializer)
    return var
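For reference, the exponential moving average mentioned above can be sketched in plain Python, independent of TensorFlow. This is only an illustration of the intended bookkeeping: the names batch_moments and ema_update, the decay of 0.9, and the sample batches are assumptions made for the example, and the variance is the biased (divide by N) form.

```python
def batch_moments(rows):
    """Per-column mean and biased variance of a list of equal-length rows."""
    n = len(rows)
    columns = list(zip(*rows))
    means = [sum(col) / n for col in columns]
    variances = [
        sum((x - m) ** 2 for x in col) / n
        for col, m in zip(columns, means)
    ]
    return means, variances


def ema_update(running, batch_value, decay=0.9):
    """Blend a running statistic with the latest batch statistic."""
    return [decay * r + (1.0 - decay) * b for r, b in zip(running, batch_value)]


# Running statistics start at the same values as bn_mean and bn_variance.
running_mean, running_var = [0.0, 0.0], [1.0, 1.0]

# Two made-up batches, each with two examples and two features.
batches = [
    [[1.0, 2.0], [3.0, 6.0]],
    [[2.0, 4.0], [4.0, 8.0]],
]
for batch in batches:
    mean, var = batch_moments(batch)
    running_mean = ema_update(running_mean, mean)
    running_var = ema_update(running_var, var)
```

At test time the running values would replace the per-batch moments in the normalization step; during training the batch moments are used directly and the running values are only updated.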
I believe that I am setting
    assign_mean = mean.assign(act_mean)
    assign_variance = variance.assign(act_variance)
incorrectly, but I am not sure how. When I use TensorBoard to track these mean and variance variables, they stay flat at their initialized values.
output = tf.with_dependencies(dependencies=[assign_mean, assign_variance], output_tensor=output) just before return.
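A likely explanation for the flat curves, offered as a reading of the code rather than a confirmed diagnosis: in a deferred-execution graph an op runs only if something that is fetched depends on it, and nothing returned from the layer depends on assign_mean or assign_variance. The toy model below is plain Python, not TensorFlow's API; Node, run, and the state dictionary are hypothetical names used only to show the effect of adding a control dependency.

```python
class Node:
    """A toy deferred operation: it runs only when something evaluates it."""

    def __init__(self, name, fn, inputs=(), control_inputs=()):
        self.name = name
        self.fn = fn
        self.inputs = list(inputs)                  # values passed to fn
        self.control_inputs = list(control_inputs)  # run first, values discarded


def run(node, executed):
    """Evaluate `node`, recording the name of every op that executes."""
    for dep in node.control_inputs:
        run(dep, executed)
    args = [run(dep, executed) for dep in node.inputs]
    executed.append(node.name)
    return node.fn(*args)


state = {"mean": 0.0}  # stands in for the bn_mean variable


def assign_fn(value):
    state["mean"] = value
    return value


batch_mean = Node("batch_mean", lambda: 2.5)
assign_mean = Node("assign_mean", assign_fn, inputs=[batch_mean])

# Case 1: the output does not depend on assign_mean, so the assign never runs.
output = Node("output", lambda: "activations")
executed = []
run(output, executed)
unchanged_mean = state["mean"]

# Case 2: the output lists assign_mean as a control dependency.
output_with_deps = Node("output", lambda: "activations",
                        control_inputs=[assign_mean])
executed_with_deps = []
run(output_with_deps, executed_with_deps)
updated_mean = state["mean"]
```

In the first case the stored mean keeps its initial value; in the second the assign executes before the output is produced. In TensorFlow's graph mode the comparable mechanism is a control dependency, for example building the output op inside a `with tf.control_dependencies([assign_mean, assign_variance]):` block.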