TensorFlow 2.0におけるBatch Normalizationの動作(training, trainable)
TensorFlow2.0以降(TF2)におけるBatch Normalization(Batch Norm)層、tf.keras.layers.BatchNormalizationの動作について、引数trainingおよびtrainable属性と訓練モード・推論モードの関係を中心に、以下の内容を説明する。
- Batch Normalization(Batch Norm)のアルゴリズム
BatchNormalization層のTrainable params/Non-trainable paramsBatchNormalization層の訓練モードと推論モード- 実際の動作の確認
- 訓練モード(
training=True) - 推論モード(
training=False) trainable属性がFalseのとき
- 訓練モード(
最初にまとめておくと、TensorFlow2.0以降(TF2)のBatchNormalizationの動作は以下の通り。
- 訓練モード(
training=True)- ミニバッチの平均と分散で正規化する
- 平均と分散の移動平均
moving_meanとmoving_varianceを更新する
- 推論モード(
training=False)moving_meanとmoving_varianceで正規化するmoving_meanとmoving_varianceを更新しない
trainable属性がFalseのときtrainingの値によらず、常に推論モードで動作する(※TensorFlow2.0以降)fit()などのメソッドでも推論モード
trainable属性の基本的な使い方については以下の記事を参照。
以下のサンプルコードのTensorFlowのバージョンは2.1.0。TensorFlowに統合されたKerasを使う。
import tensorflow as tf
import numpy as np
print(tf.__version__)
# 2.1.0
Batch Normalization(Batch Norm)のアルゴリズム
Batch Normalization(Batch Norm)のアルゴリズムを簡単に示す。
まず、ミニバッチごとに平均が0、分散が1となるようにデータの正規化を行う。数式で表すと以下の通り。$\mu$ は平均、$\sigma^2$ は分散。$\epsilon$ はゼロ除算を防ぐための小さな値。
$$ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} $$
さらに、正規化されたデータに対してスケーリングとシフトを行う。その係数が $\gamma$ と $\beta$ で、訓練によって調整されるパラメータ。
$$ y = \gamma \hat{x} + \beta $$
訓練(学習)時はミニバッチの平均・分散を使って正規化するが、推論時は一般的にそれらの移動平均を使って正規化することが多い。
BatchNormalization層のTrainable params / Non-trainable params
TensorFlow2.0以降(TF2)ではKerasとの統合が強化され、Kerasで提供されているレイヤー(または、Kerasのレイヤーの基底クラスを継承したカスタムレイヤー)の使用が推奨されている。
Batch Normalizationはtf.keras.layers.BatchNormalizationとして提供されている。
以下のBatchNormalizationだけのモデルを例とする。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
model.summary()
# Model: "sequential"
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# bn (BatchNormalization) (None, 1) 4
# =================================================================
# Total params: 4
# Trainable params: 2
# Non-trainable params: 2
# _________________________________________________________________
print(model.layers[0].trainable)
# True
summary()の出力を見ると分かるように、Batch Normalization層はtrainable属性がTrueでもNon-trainable params(訓練対象ではないパラメータ)を含む。
trainable_weightsおよびnon_trainable_weights属性でそれぞれのパラメータの内容を確認できる。
for w in model.trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn/gamma:0 [1.]
# bn/beta:0 [0.]
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn/moving_mean:0 [0.]
# bn/moving_variance:0 [1.]
訓練対象のパラメータgammaとbetaがスケーリングとシフトの係数 $\gamma$ と $\beta$、訓練対象でないパラメータmoving_meanとmoving_varianceが推論時に使われる、平均・分散の移動平均を表している。
レイヤーまたはモデルのtrainable属性をFalseとするとFreezeされ、gammaとbetaもNon-trainable params(訓練対象ではないパラメータ)となる。
model.layers[0].trainable = False
print(model.layers[0].trainable_weights)
# []
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn/gamma:0 [1.]
# bn/beta:0 [0.]
# bn/moving_mean:0 [0.]
# bn/moving_variance:0 [1.]
BatchNormalization層の訓練モードと推論モード
BatchNormalization層の訓練モードと推論モードの切り替えはtrainingによって制御される。
training: Python boolean indicating whether the layer should behave in training mode or in inference mode.
training=True: The layer will normalize its inputs using the mean and variance of the current batch of inputs.training=False: The layer will normalize its inputs using the mean and variance of its moving statistics, learned during training. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
trainingはtrainable属性とは別のもので、model(training=False)のようにモデル自体をcallする(呼び出す)ときに指定する引数。
training=Trueが訓練モード、training=Falseが推論モードで、fit()やpredict()のようなメソッドでは自動的に適切なモードが選択される。例えば、fit()では訓練モード、predict()では推論モードとなる。
上述のように、BatchNormalizationのmoving_mean, moving_varianceは、推論時に使われる移動平均された平均・分散。
それらの値は訓練時に算出されるが、勾配計算によって更新されるものではないため、レイヤーのtrainable属性がTrueでもFalseでもNon-trainable paramsとして扱われる。
2) Updates to the weights (moving statistics) are based on the forward pass of a model rather than the result of gradient computations. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
TensorFlow 2.0以降は、trainableをFalseとすると常に推論モード(training=Falseのときの挙動)で動作するようになった。
About setting
layer.trainable = Falseon a `BatchNormalization layer:The meaning of setting
layer.trainable = Falseis to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated duringfit()ortrain_on_batch(),and its state updates will not be run.Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the
trainingargument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.However, in the case of the
BatchNormalizationlayer, settingtrainable = Falseon the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).This behavior has been introduced in TensorFlow 2.0, in order to enable
layer.trainable = Falseto produce the most commonly expected behavior in the convnet fine-tuning use case. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
TensorFlow 1.xでは挙動が異なるので注意。trainableがFalseでも推論モードにはならない。
This behavior only occurs as of TensorFlow 2.0. In 1.*, setting
layer.trainable = Falsewould freeze the layer but would not switch it to inference mode. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
実際の動作の確認
まとめるとBatchNormalizationの動作は以下の通り。
- 訓練モード(
training=True)- ミニバッチの平均と分散で正規化する
- 平均と分散の移動平均
moving_meanとmoving_varianceを更新する
- 推論モード(
training=False)moving_meanとmoving_varianceで正規化するmoving_meanとmoving_varianceを更新しない
trainable属性がFalseのときtrainingの値によらず、常に推論モードで動作する(※TensorFlow2.0以降)fit()などのメソッドでも推論モード
簡単なサンプルコードで実際の動作を確認する。
冒頭に書いたように、サンプルコードのTensorFlowのバージョンは2.1.0。TensorFlowに統合されたKerasを使う。バージョンが違う場合は異なる結果となる可能性があるので注意。
デフォルトではmoving_meanとmoving_varianceの初期値はそれぞれ0と1。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0 [0.]
# bn_1/moving_variance:0 [1.]
値が100の要素が1個だけのデータを使う。バッチ分の次元が必要なので形状(1, 1)の二次元としている。
a = np.array([[100]]).astype('float32')
print(a)
# [[100.]]
print(a.shape)
# (1, 1)
訓練モード(training=True)
このモデルを訓練モード(training=True)で呼び出すと、出力は0で、moving_meanとmoving_varianceの値が更新される。
print(model(a, training=True))
# tf.Tensor([[0.]], shape=(1, 1), dtype=float32)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0 [1.]
# bn_1/moving_variance:0 [0.99]
出力が0なのは、訓練モードではミニバッチごとの平均 $\mu$ と分散 $\sigma^2$ が使われるから。この場合は要素が1個だけなので $x - \mu = 0$ となる。
moving_meanとmoving_varianceの値はミニバッチの平均・分散の移動平均で、モデルを呼び出すたびに以下のように更新される。
moving_mean = moving_mean * momentum + mean * (1 - momentum)
moving_variance = moving_variance * momentum + variance * (1 - momentum)
momentumはtf.keras.layers.BatchNormalization()の引数で設定可能で、デフォルト値は0.99。meanとvarianceはミニバッチの平均と分散(この例では100と0)。
初期値から1回更新した値を計算すると、上のサンプルコードの結果と一致することが分かる。
moving_mean = 0 * 0.99 + 100 * (1 - 0.99) = 1
moving_variance = 1 * 0.99 + 0 * (1 - 0.99) = 0.99
なお、Kerasでは更新前のmoving_mean, moving_varianceにかける係数をmomentumと呼ぶが、PyTorchではミニバッチの平均・分散にかける係数(上の式では1 - momentum)をmomentumと呼ぶという違いがある。
- Momentum vs. decay in normalization.py for batch normalization · Issue #6839 · keras-team/keras
- torch.nn.BatchNorm1d — PyTorch master documentation
同じデータで呼び出し続けるとそのデータの平均と分散(この例では100と0)に近づいていくことが確認できる。
for i in range(1000):
model(a, training=True)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0 [99.99573]
# bn_1/moving_variance:0 [4.273953e-05]
訓練モードではmoving_meanとmoving_varianceの値は使われないので、それらの値が更新されても呼び出し時の出力は変わらない。
print(model(a, training=True))
# tf.Tensor([[0.]], shape=(1, 1), dtype=float32)
推論モード(training=False)での呼び出しや、predict()では、更新されたmoving_meanとmoving_varianceの値が使われる。
print(model(a, training=False))
# tf.Tensor([[0.13110352]], shape=(1, 1), dtype=float32)
print(model.predict(a))
# [[0.13110352]]
trainable属性をFalseとすると常に推論モードで動作する。training=Trueでもmoving_meanとmoving_varianceの値が使われる。
model.layers[0].trainable = False
print(model(a, training=True))
# tf.Tensor([[0.13110352]], shape=(1, 1), dtype=float32)
推論モード(training=False)
推論モードではmoving_meanとmoving_varianceの値は更新されないため、何度モデルを呼び出しても同じ値のまま変わらない。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_2/moving_mean:0 [0.]
# bn_2/moving_variance:0 [1.]
for i in range(1000):
model(a, training=False)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_2/moving_mean:0 [0.]
# bn_2/moving_variance:0 [1.]
推論モードでの呼び出しとpredict()の出力は以下の通り。ミニバッチの平均・分散ではなくmoving_meanとmoving_varianceの値(ここでは0と1)が使われる。$\epsilon$ の初期値が0.001なので、数式通りの結果となっていることが分かる。
print(model(a, training=False))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)
print(model.predict(a))
# [[99.95004]]
print((100 - 0) / np.sqrt(1 + 0.001))
# 99.95003746877732
$\epsilon$ の初期値はtf.keras.layers.BatchNormalization()の引数epsilonで設定可能。
trainable属性がFalseのとき
trainable属性をFalseとすると、trainingの値によらず常に推論モードで動作する(※TensorFlow 2.0以降)。
training=Trueで呼び出してもmoving_meanとmoving_varianceの値は更新されない。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0 [0.]
# bn_3/moving_variance:0 [1.]
model.layers[0].trainable = False
for i in range(1000):
model(a, training=True)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.]
# bn_3/moving_mean:0 [0.]
# bn_3/moving_variance:0 [1.]
training=Trueで呼び出しても正規化にはmoving_meanとmoving_varianceの値が使われる。
print(model(a, training=True))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)
print(model(a, training=False))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)
print(model.predict(a))
# [[99.95004]]
trainable属性がFalseの場合、通常は訓練モードで動作するfit()などのメソッドにおいても推論モードで動作する。
fit()を実行してもmoving_meanとmoving_varianceの値は更新されない。
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(a, a, verbose=0)
# WARNING:tensorflow:The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.
#
# <tensorflow.python.keras.callbacks.History at 0x13c3e78d0>
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.]
# bn_3/moving_mean:0 [0.]
# bn_3/moving_variance:0 [1.]
trainable属性をTrueとすると、fit()では訓練モードで動作する。gamma, betaはもちろん、moving_meanとmoving_varianceの値も更新される。
model.layers[0].trainable = True
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(a, a, verbose=0)
# <tensorflow.python.keras.callbacks.History at 0x13c3d1e50>
for w in model.trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.001]
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0 [1.]
# bn_3/moving_variance:0 [0.99]
上述のように、moving_meanとmoving_varianceは勾配計算によって調整されるものではないためnon_trainable_weightsに含まれるが、値は更新されるので要注意。
訓練モードでの呼び出しの例と同様に、エポック数を増やすとさらに更新されていく。
model.fit(a, a, epochs=1000, verbose=0)
# <tensorflow.python.keras.callbacks.History at 0x13c5625d0>
for w in model.trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.9988577]
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0 [99.99573]
# bn_3/moving_variance:0 [4.273953e-05]