TensorFlow 2.0におけるBatch Normalizationの動作(training, trainable)

Posted: | Tags: Python, TensorFlow, Keras, 機械学習

TensorFlow2.0以降(TF2)におけるBatch Normalization(Batch Norm)層、tf.keras.layers.BatchNormalizationの動作について、引数trainingおよびtrainable属性と訓練モード・推論モードの関係を中心に、以下の内容を説明する。

  • Batch Normalization(Batch Norm)のアルゴリズム
  • BatchNormalization層のTrainable params / Non-trainable params
  • BatchNormalization層の訓練モードと推論モード
  • 実際の動作の確認
    • 訓練モード(training=True
    • 推論モード(training=False
    • trainable属性がFalseのとき

最初にまとめておくと、TensorFlow2.0以降(TF2)のBatchNormalizationの動作は以下の通り。

  • 訓練モード(training=True
    • ミニバッチの平均と分散で正規化する
    • 平均と分散の移動平均moving_meanmoving_varianceを更新する
  • 推論モード(training=False
    • moving_meanmoving_varianceで正規化する
    • moving_meanmoving_varianceを更新しない
  • trainable属性がFalseのとき
    • trainingの値によらず、常に推論モードで動作する(※TensorFlow 2.0以降)
      • fit()などのメソッドでも推論モード

trainable属性の基本的な使い方については以下の記事を参照。

以下のサンプルコードのTensorFlowのバージョンは2.1.0。TensorFlowに統合されたKerasを使う。

import tensorflow as tf
import numpy as np

print(tf.__version__)
# 2.1.0

Batch Normalization(Batch Norm)のアルゴリズム

Batch Normalization(Batch Norm)のアルゴリズムを簡単に示す。

まず、ミニバッチごとに平均が0、分散が1となるようにデータの正規化を行う。数式で表すと以下の通り。$\mu$ は平均、$\sigma^2$ は分散。$\epsilon$ はゼロ除算を防ぐための小さな値。

$$ \hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} $$

さらに、正規化されたデータに対してスケーリングとシフトを行う。その係数が $\gamma$ と $\beta$ で、訓練によって調整されるパラメータ。

$$ y = \gamma \hat{x} + \beta $$

訓練(学習)時はミニバッチの平均・分散を使って正規化するが、推論時は一般的にそれらの移動平均を使って正規化することが多い。

BatchNormalization層のTrainable params / Non-trainable params

TensorFlow2.0以降(TF2)ではKerasとの統合が強化され、Kerasで提供されているレイヤー(または、Kerasのレイヤーの基底クラスを継承したカスタムレイヤー)の使用が推奨されている。

Batch Normalizationはtf.keras.layers.BatchNormalizationとして提供されている。

以下のBatchNormalizationだけのモデルを例とする。

model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])

model.summary()
# Model: "sequential"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #   
# =================================================================
# bn (BatchNormalization)      (None, 1)                 4         
# =================================================================
# Total params: 4
# Trainable params: 2
# Non-trainable params: 2
# _________________________________________________________________

print(model.layers[0].trainable)
# True

summary()の出力を見ると分かるように、Batch Normalization層はtrainable属性がTrueでもNon-trainable params(訓練対象ではないパラメータ)を含む。

trainable_weightsおよびnon_trainable_weights属性でそれぞれのパラメータの内容を確認できる。

for w in model.trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn/gamma:0                    [1.]
# bn/beta:0                     [0.]

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn/moving_mean:0              [0.]
# bn/moving_variance:0          [1.]

訓練対象のパラメータgammabetaがスケーリングとシフトの係数 $\gamma$ と $\beta$、訓練対象でないパラメータmoving_meanmoving_varianceが推論時に使われる、平均・分散の移動平均を表している。

レイヤーまたはモデルのtrainable属性をFalseとするとFreezeされ、gammabetaNon-trainable params(訓練対象ではないパラメータ)となる。

model.layers[0].trainable = False

print(model.layers[0].trainable_weights)
# []

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn/gamma:0                    [1.]
# bn/beta:0                     [0.]
# bn/moving_mean:0              [0.]
# bn/moving_variance:0          [1.]

BatchNormalization層の訓練モードと推論モード

BatchNormalization層の訓練モードと推論モードの切り替えはtrainingによって制御される。

  • training: Python boolean indicating whether the layer should behave in training mode or in inference mode.
    • training=True: The layer will normalize its inputs using the mean and variance of the current batch of inputs.
    • training=False: The layer will normalize its inputs using the mean and variance of its moving statistics, learned during training. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0

trainingtrainable属性とは別のもので、model(training=False)のようにモデル自体をcallする(呼び出す)ときに指定する引数。

training=Trueが訓練モード、training=Falseが推論モードで、fit()predict()のようなメソッドでは自動的に適切なモードが選択される。例えば、fit()では訓練モード、predict()では推論モードとなる。

上述のように、BatchNormalizationmoving_mean, moving_varianceは、推論時に使われる移動平均された平均・分散。

それらの値は訓練時に算出されるが、勾配計算によって更新されるものではないため、レイヤーのtrainable属性がTrueでもFalseでもNon-trainable paramsとして扱われる。

2) Updates to the weights (moving statistics) are based on the forward pass of a model rather than the result of gradient computations. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0

TensorFlow 2.0以降は、trainableFalseとすると常に推論モード(training=Falseのときの挙動)で動作するようになった。

About setting layer.trainable = False on a `BatchNormalization layer:

The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run.

Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0

TensorFlow 1.xでは挙動が異なるので注意。trainableFalseでも推論モードにはならない。

This behavior only occurs as of TensorFlow 2.0. In 1.*, setting layer.trainable = False would freeze the layer but would not switch it to inference mode. tf.keras.layers.BatchNormalization | TensorFlow Core v2.1.0

実際の動作の確認

まとめるとBatchNormalizationの動作は以下の通り。

  • 訓練モード(training=True
    • ミニバッチの平均と分散で正規化する
    • 平均と分散の移動平均moving_meanmoving_varianceを更新する
  • 推論モード(training=False
    • moving_meanmoving_varianceで正規化する
    • moving_meanmoving_varianceを更新しない
  • trainable属性がFalseのとき
    • trainingの値によらず、常に推論モードで動作する(※TensorFlow 2.0以降)
      • fit()などのメソッドでも推論モード

簡単なサンプルコードで実際の動作を確認する。

冒頭に書いたように、サンプルコードのTensorFlowのバージョンは2.1.0。TensorFlowに統合されたKerasを使う。バージョンが違う場合は異なる結果となる可能性があるので注意。

デフォルトではmoving_meanmoving_varianceの初期値はそれぞれ01

model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0            [0.]
# bn_1/moving_variance:0        [1.]

値が100の要素が1個だけのデータを使う。バッチ分の次元が必要なので形状(1, 1)の二次元としている。

a = np.array([[100]]).astype('float32')
print(a)
# [[100.]]

print(a.shape)
# (1, 1)

訓練モード(training=True)

このモデルを訓練モード(training=True)で呼び出すと、出力は0で、moving_meanmoving_varianceの値が更新される。

print(model(a, training=True))
# tf.Tensor([[0.]], shape=(1, 1), dtype=float32)

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0            [1.]
# bn_1/moving_variance:0        [0.99]

出力が0なのは、訓練モードではミニバッチごとの平均 $\mu$ と分散 $\sigma^2$ が使われるから。この場合は要素が1個だけなので $x - \mu = 0$ となる。

moving_meanmoving_varianceの値はミニバッチの平均・分散の移動平均で、モデルを呼び出すたびに以下のように更新される。

moving_mean = moving_mean * momentum + mean * (1 - momentum)
moving_variance = moving_variance * momentum + variance * (1 - momentum)

momentumtf.keras.layers.BatchNormalization()の引数で設定可能で、デフォルト値は0.99meanvarianceはミニバッチの平均と分散(この例では1000)。

初期値から1回更新した値を計算すると、上のサンプルコードの結果と一致することが分かる。

moving_mean = 0 * 0.99 + 100 * (1 - 0.99) = 1
moving_variance = 1 * 0.99 + 0 * (1 - 0.99) = 0.99

なお、Kerasでは更新前のmoving_mean, moving_varianceにかける係数をmomentumと呼ぶが、PyTorchではミニバッチの平均・分散にかける係数(上の式では1 - momentum)をmomentumと呼ぶという違いがある。

同じデータで呼び出し続けるとそのデータの平均と分散(この例では1000)に近づいていくことが確認できる。

for i in range(1000):
    model(a, training=True)

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_1/moving_mean:0            [99.99573]
# bn_1/moving_variance:0        [4.273953e-05]

訓練モードではmoving_meanmoving_varianceの値は使われないので、それらの値が更新されても呼び出し時の出力は変わらない。

print(model(a, training=True))
# tf.Tensor([[0.]], shape=(1, 1), dtype=float32)

推論モード(training=False)での呼び出しや、predict()では、更新されたmoving_meanmoving_varianceの値が使われる。

print(model(a, training=False))
# tf.Tensor([[0.13110352]], shape=(1, 1), dtype=float32)

print(model.predict(a))
# [[0.13110352]]

trainable属性をFalseとすると常に推論モードで動作する。training=Trueでもmoving_meanmoving_varianceの値が使われる。

model.layers[0].trainable = False

print(model(a, training=True))
# tf.Tensor([[0.13110352]], shape=(1, 1), dtype=float32)

推論モード(training=False)

推論モードではmoving_meanmoving_varianceの値は更新されないため、何度モデルを呼び出しても同じ値のまま変わらない。

model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_2/moving_mean:0            [0.]
# bn_2/moving_variance:0        [1.]

for i in range(1000):
    model(a, training=False)

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_2/moving_mean:0            [0.]
# bn_2/moving_variance:0        [1.]

推論モードでの呼び出しとpredict()の出力は以下の通り。ミニバッチの平均・分散ではなくmoving_meanmoving_varianceの値(ここでは01)が使われる。$\epsilon$ の初期値が0.001なので、数式通りの結果となっていることが分かる。

print(model(a, training=False))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)

print(model.predict(a))
# [[99.95004]]

print((100 - 0) / np.sqrt(1 + 0.001))
# 99.95003746877732

$\epsilon$ の初期値はtf.keras.layers.BatchNormalization()の引数epsilonで設定可能。

trainable属性がFalseのとき

trainable属性をFalseとすると、trainingの値によらず常に推論モードで動作する(※TensorFlow 2.0以降)。

training=Trueで呼び出してもmoving_meanmoving_varianceの値は更新されない。

model = tf.keras.Sequential([
    tf.keras.layers.BatchNormalization(name='bn', input_shape=(1,))
])

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0            [0.]
# bn_3/moving_variance:0        [1.]

model.layers[0].trainable = False

for i in range(1000):
    model(a, training=True)

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0                  [1.]
# bn_3/beta:0                   [0.]
# bn_3/moving_mean:0            [0.]
# bn_3/moving_variance:0        [1.]

training=Trueで呼び出しても正規化にはmoving_meanmoving_varianceの値が使われる。

print(model(a, training=True))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)

print(model(a, training=False))
# tf.Tensor([[99.95004]], shape=(1, 1), dtype=float32)

print(model.predict(a))
# [[99.95004]]

trainable属性がFalseの場合、通常は訓練モードで動作するfit()などのメソッドにおいても推論モードで動作する。

fit()を実行してもmoving_meanmoving_varianceの値は更新されない。

model.compile(optimizer='adam', loss='mean_squared_error')

model.fit(a, a, verbose=0)
# WARNING:tensorflow:The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.
# 
# <tensorflow.python.keras.callbacks.History at 0x13c3e78d0>

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0                  [1.]
# bn_3/beta:0                   [0.]
# bn_3/moving_mean:0            [0.]
# bn_3/moving_variance:0        [1.]

trainable属性をTrueとすると、fit()では訓練モードで動作する。gamma, betaはもちろん、moving_meanmoving_varianceの値も更新される。

model.layers[0].trainable = True

model.compile(optimizer='adam', loss='mean_squared_error')

model.fit(a, a, verbose=0)
# <tensorflow.python.keras.callbacks.History at 0x13c3d1e50>

for w in model.trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0                  [1.]
# bn_3/beta:0                   [0.001]

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0            [1.]
# bn_3/moving_variance:0        [0.99]

上述のように、moving_meanmoving_varianceは勾配計算によって調整されるものではないためnon_trainable_weightsに含まれるが、値は更新されるので要注意。

訓練モードでの呼び出しの例と同様に、エポック数を増やすとさらに更新されていく。

model.fit(a, a, epochs=1000, verbose=0)
# <tensorflow.python.keras.callbacks.History at 0x13c5625d0>

for w in model.trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/gamma:0                  [1.]
# bn_3/beta:0                   [0.9988577]

for w in model.non_trainable_weights:
    print('{:<30}{}'.format(w.name, w.numpy()))
# bn_3/moving_mean:0            [99.99573]
# bn_3/moving_variance:0        [4.273953e-05]

関連カテゴリー

関連記事