I am using Stochastic Weight Averaging (SWA) with Batch Normalization layers in TensorFlow 2.2. For batch norm I use tf.keras.layers.BatchNormalization. For SWA I use my own code to average the weights (I wrote it before tfa.optimizers.SWA appeared). I have read in multiple sources that when combining batch norm with SWA we must run a forward pass so that certain data (the running mean and variance of the activations, and/or the momentum values?) becomes available to the batch norm layers for the averaged weights. What I do not understand, despite a lot of reading, is exactly what needs to be done and how (a sketch of my current understanding follows the questions below). Specifically:
- When must the forward/prediction pass be run? At the end of each mini-batch, end of each epoch, end of all training?
- When the forward pass is run, how are the resulting running mean and variance values made available to the batch norm layers?
- Is this process performed magically by the tfa.optimizers.SWA class?
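
To make the question concrete, here is a minimal sketch of what I currently think "run a forward pass" means. The model, the dataset and swa_weights are all dummy placeholders just so the snippet runs, and I do not know whether the reset step is actually required or whether training=True is the right way to update the statistics:

```python
import tensorflow as tf

# Dummy model and dataset, purely for illustration.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(1),
])
train_dataset = tf.data.Dataset.from_tensor_slices(
    tf.random.normal((256, 8))).batch(32)

# swa_weights stands in for the averaged weights my own SWA code maintains;
# here it is just the current weights so the snippet runs.
swa_weights = model.get_weights()

# 1) Load the averaged weights into the model.
model.set_weights(swa_weights)

# 2) (Is this needed?) Reset the moving statistics of every batch norm layer
#    so they are re-estimated from scratch for the averaged weights.
for layer in model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.moving_mean.assign(tf.zeros_like(layer.moving_mean))
        layer.moving_variance.assign(tf.ones_like(layer.moving_variance))

# 3) Run forward passes in training mode so each batch norm layer updates its
#    moving_mean / moving_variance; no gradient step is taken.
for x_batch in train_dataset:
    model(x_batch, training=True)
```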