TensorFlow 2.0におけるBatch Normalizationの動作(training, trainable)
TensorFlow2.0
以降(TF2)におけるBatch Normalization(Batch Norm)層、tf.keras.layers.BatchNormalization
の動作について、引数training
およびtrainable
属性と訓練モード・推論モードの関係を中心に、以下の内容を説明する。
- Batch Normalization(Batch Norm)のアルゴリズム
BatchNormalization
層のTrainable params
/Non-trainable params
BatchNormalization
層の訓練モードと推論モード- 実際の動作の確認
- 訓練モード(
training=True
) - 推論モード(
training=False
) trainable
属性がFalse
のとき
- 訓練モード(
最初にまとめておくと、TensorFlow2.0
以降(TF2)のBatchNormalization
の動作は以下の通り。
- 訓練モード(
training=True
)- ミニバッチの平均と分散で正規化する
- 平均と分散の移動平均
moving_mean
とmoving_variance
を更新する
- 推論モード(
training=False
)moving_mean
とmoving_variance
で正規化するmoving_mean
とmoving_variance
を更新しない
trainable
属性がFalse
のときtraining
の値によらず、常に推論モードで動作する(※TensorFlow2.0
以降)fit()
などのメソッドでも推論モード
trainable
属性の基本的な使い方については以下の記事を参照。
以下のサンプルコードのTensorFlowのバージョンは2.1.0
。TensorFlowに統合されたKerasを使う。
import tensorflow as tf
import numpy as np
print(tf.__version__)
# 2.1.0
Batch Normalization(Batch Norm)のアルゴリズム
Batch Normalization(Batch Norm)のアルゴリズムを簡単に示す。
まず、ミニバッチごとに平均が0
、分散が1
となるようにデータの正規化を行う。数式で表すと以下の通り。$\mu$ は平均、$\sigma^2$ は分散。$\epsilon$ はゼロ除算を防ぐための小さな値。
$$ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} $$
さらに、正規化されたデータに対してスケーリングとシフトを行う。その係数が $\gamma$ と $\beta$ で、訓練によって調整されるパラメータ。
$$ y = \gamma \hat{x} + \beta $$
訓練(学習)時はミニバッチの平均・分散を使って正規化するが、推論時は一般的にそれらの移動平均を使って正規化することが多い。
BatchNormalization層のTrainable params / Non-trainable params
TensorFlow2.0
以降(TF2)ではKerasとの統合が強化され、Kerasで提供されているレイヤー(または、Kerasのレイヤーの基底クラスを継承したカスタムレイヤー)の使用が推奨されている。
Batch Normalizationはtf.keras.layers.BatchNormalization
として提供されている。
以下のBatchNormalization
だけのモデルを例とする。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
model.summary()
# Model: "sequential"
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# bn (BatchNormalization) (None, 1) 4
# =================================================================
# Total params: 4
# Trainable params: 2
# Non-trainable params: 2
# _________________________________________________________________
print(model.layers[0].trainable)
# True
summary()
の出力を見ると分かるように、Batch Normalization層はtrainable
属性がTrue
でもNon-trainable params
(訓練対象ではないパラメータ)を含む。
trainable_weights
およびnon_trainable_weights
属性でそれぞれのパラメータの内容を確認できる。
for w in model.trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn/gamma:0 [1.]
# bn/beta:0 [0.]
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn/moving_mean:0 [0.]
# bn/moving_variance:0 [1.]
訓練対象のパラメータgamma
とbeta
がスケーリングとシフトの係数 $\gamma$ と $\beta$、訓練対象でないパラメータmoving_mean
とmoving_variance
が推論時に使われる、平均・分散の移動平均を表している。
レイヤーまたはモデルのtrainable
属性をFalse
とするとFreezeされ、gamma
とbeta
もNon-trainable params
(訓練対象ではないパラメータ)となる。
model.layers[0].trainable = False
print(model.layers[0].trainable_weights)
# []
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn/gamma:0 [1.]
# bn/beta:0 [0.]
# bn/moving_mean:0 [0.]
# bn/moving_variance:0 [1.]
BatchNormalization層の訓練モードと推論モード
BatchNormalization
層の訓練モードと推論モードの切り替えはtraining
によって制御される。
training
: Python boolean indicating whether the layer should behave in training mode or in inference mode.
training=True
: The layer will normalize its inputs using the mean and variance of the current batch of inputs.training=False
: The layer will normalize its inputs using the mean and variance of its moving statistics, learned during training. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
training
はtrainable
属性とは別のもので、model(training=False)
のようにモデル自体をcall
する(呼び出す)ときに指定する引数。
training=True
が訓練モード、training=False
が推論モードで、fit()
やpredict()
のようなメソッドでは自動的に適切なモードが選択される。例えば、fit()
では訓練モード、predict()
では推論モードとなる。
上述のように、BatchNormalization
のmoving_mean
, moving_variance
は、推論時に使われる移動平均された平均・分散。
それらの値は訓練時に算出されるが、勾配計算によって更新されるものではないため、レイヤーのtrainable
属性がTrue
でもFalse
でもNon-trainable params
として扱われる。
2) Updates to the weights (moving statistics) are based on the forward pass of a model rather than the result of gradient computations. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
TensorFlow 2.0
以降は、trainable
をFalse
とすると常に推論モード(training=False
のときの挙動)で動作するようになった。
About setting
layer.trainable = False
on a `BatchNormalization layer:The meaning of setting
layer.trainable = False
is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated duringfit()
ortrain_on_batch(),
and its state updates will not be run.Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the
training
argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.However, in the case of the
BatchNormalization
layer, settingtrainable = False
on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).This behavior has been introduced in TensorFlow 2.0, in order to enable
layer.trainable = False
to produce the most commonly expected behavior in the convnet fine-tuning use case. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
TensorFlow 1.x
では挙動が異なるので注意。trainable
がFalse
でも推論モードにはならない。
This behavior only occurs as of TensorFlow 2.0. In 1.*, setting
layer.trainable = False
would freeze the layer but would not switch it to inference mode. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0
実際の動作の確認
まとめるとBatchNormalization
の動作は以下の通り。
- 訓練モード(
training=True
)- ミニバッチの平均と分散で正規化する
- 平均と分散の移動平均
moving_mean
とmoving_variance
を更新する
- 推論モード(
training=False
)moving_mean
とmoving_variance
で正規化するmoving_mean
とmoving_variance
を更新しない
trainable
属性がFalse
のときtraining
の値によらず、常に推論モードで動作する(※TensorFlow2.0
以降)fit()
などのメソッドでも推論モード
簡単なサンプルコードで実際の動作を確認する。
冒頭に書いたように、サンプルコードのTensorFlowのバージョンは2.1.0
。TensorFlowに統合されたKerasを使う。バージョンが違う場合は異なる結果となる可能性があるので注意。
デフォルトではmoving_mean
とmoving_variance
の初期値はそれぞれ0
と1
。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0 [0.]
# bn_1/moving_variance:0 [1.]
値が100
の要素が1個だけのデータを使う。バッチ分の次元が必要なので形状(1, 1)
の二次元としている。
a = np.array([[100]]).astype('float32')
print(a)
# [[100.]]
print(a.shape)
# (1, 1)
訓練モード(training=True)
このモデルを訓練モード(training=True
)で呼び出すと、出力は0
で、moving_mean
とmoving_variance
の値が更新される。
print(model(a, training=True))
# tf.Tensor([[0.]], shape=(1, 1), dtype=float32)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0 [1.]
# bn_1/moving_variance:0 [0.99]
出力が0
なのは、訓練モードではミニバッチごとの平均 $\mu$ と分散 $\sigma^2$ が使われるから。この場合は要素が1個だけなので $x - \mu = 0$ となる。
moving_mean
とmoving_variance
の値はミニバッチの平均・分散の移動平均で、モデルを呼び出すたびに以下のように更新される。
moving_mean = moving_mean * momentum + mean * (1 - momentum)
moving_variance = moving_variance * momentum + variance * (1 - momentum)
momentum
はtf.keras.layers.BatchNormalization()
の引数で設定可能で、デフォルト値は0.99
。mean
とvariance
はミニバッチの平均と分散(この例では100
と0
)。
初期値から1回更新した値を計算すると、上のサンプルコードの結果と一致することが分かる。
moving_mean = 0 * 0.99 + 100 * (1 - 0.99) = 1
moving_variance = 1 * 0.99 + 0 * (1 - 0.99) = 0.99
なお、Kerasでは更新前のmoving_mean
, moving_variance
にかける係数をmomentum
と呼ぶが、PyTorchではミニバッチの平均・分散にかける係数(上の式では1 - momentum
)をmomentum
と呼ぶという違いがある。
- Momentum vs. decay in normalization.py for batch normalization · Issue #6839 · keras-team/keras
- torch.nn.BatchNorm1d — PyTorch master documentation
同じデータで呼び出し続けるとそのデータの平均と分散(この例では100
と0
)に近づいていくことが確認できる。
for i in range(1000):
model(a, training=True)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0 [99.99573]
# bn_1/moving_variance:0 [4.273953e-05]
訓練モードではmoving_mean
とmoving_variance
の値は使われないので、それらの値が更新されても呼び出し時の出力は変わらない。
print(model(a, training=True))
# tf.Tensor([[0.]], shape=(1, 1), dtype=float32)
推論モード(training=False
)での呼び出しや、predict()
では、更新されたmoving_mean
とmoving_variance
の値が使われる。
print(model(a, training=False))
# tf.Tensor([[0.13110352]], shape=(1, 1), dtype=float32)
print(model.predict(a))
# [[0.13110352]]
trainable
属性をFalse
とすると常に推論モードで動作する。training=True
でもmoving_mean
とmoving_variance
の値が使われる。
model.layers[0].trainable = False
print(model(a, training=True))
# tf.Tensor([[0.13110352]], shape=(1, 1), dtype=float32)
推論モード(training=False)
推論モードではmoving_mean
とmoving_variance
の値は更新されないため、何度モデルを呼び出しても同じ値のまま変わらない。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_2/moving_mean:0 [0.]
# bn_2/moving_variance:0 [1.]
for i in range(1000):
model(a, training=False)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_2/moving_mean:0 [0.]
# bn_2/moving_variance:0 [1.]
推論モードでの呼び出しとpredict()
の出力は以下の通り。ミニバッチの平均・分散ではなくmoving_mean
とmoving_variance
の値(ここでは0
と1
)が使われる。$\epsilon$ の初期値が0.001
なので、数式通りの結果となっていることが分かる。
print(model(a, training=False))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)
print(model.predict(a))
# [[99.95004]]
print((100 - 0) / np.sqrt(1 + 0.001))
# 99.95003746877732
$\epsilon$ の初期値はtf.keras.layers.BatchNormalization()
の引数epsilon
で設定可能。
trainable属性がFalseのとき
trainable
属性をFalse
とすると、training
の値によらず常に推論モードで動作する(※TensorFlow 2.0
以降)。
training=True
で呼び出してもmoving_mean
とmoving_variance
の値は更新されない。
model = tf.keras.Sequential([
tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0 [0.]
# bn_3/moving_variance:0 [1.]
model.layers[0].trainable = False
for i in range(1000):
model(a, training=True)
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.]
# bn_3/moving_mean:0 [0.]
# bn_3/moving_variance:0 [1.]
training=True
で呼び出しても正規化にはmoving_mean
とmoving_variance
の値が使われる。
print(model(a, training=True))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)
print(model(a, training=False))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)
print(model.predict(a))
# [[99.95004]]
trainable
属性がFalse
の場合、通常は訓練モードで動作するfit()
などのメソッドにおいても推論モードで動作する。
fit()
を実行してもmoving_mean
とmoving_variance
の値は更新されない。
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(a, a, verbose=0)
# WARNING:tensorflow:The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.
#
# <tensorflow.python.keras.callbacks.History at 0x13c3e78d0>
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.]
# bn_3/moving_mean:0 [0.]
# bn_3/moving_variance:0 [1.]
trainable
属性をTrue
とすると、fit()
では訓練モードで動作する。gamma
, beta
はもちろん、moving_mean
とmoving_variance
の値も更新される。
model.layers[0].trainable = True
model.compile(optimizer='adam', loss='mean_squared_error')
model.fit(a, a, verbose=0)
# <tensorflow.python.keras.callbacks.History at 0x13c3d1e50>
for w in model.trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.001]
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0 [1.]
# bn_3/moving_variance:0 [0.99]
上述のように、moving_mean
とmoving_variance
は勾配計算によって調整されるものではないためnon_trainable_weights
に含まれるが、値は更新されるので要注意。
訓練モードでの呼び出しの例と同様に、エポック数を増やすとさらに更新されていく。
model.fit(a, a, epochs=1000, verbose=0)
# <tensorflow.python.keras.callbacks.History at 0x13c5625d0>
for w in model.trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0 [1.]
# bn_3/beta:0 [0.9988577]
for w in model.non_trainable_weights:
print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0 [99.99573]
# bn_3/moving_variance:0 [4.273953e-05]