TensorFlow, Kerasで重み・バイアスなどの値を取得、可視化

Posted: | Tags: Python, TensorFlow, Keras, 機械学習

TensorFlow, Kerasで構築したモデルやレイヤーの重み(カーネルの重み)やバイアスなどのパラメータの値を取得したり可視化したりする方法について説明する。

  • レイヤーのパラメータ(重み・バイアスなど)を取得
    • get_weights()メソッド
    • weights属性
    • trainable_weights, non_trainable_weights属性
    • kernel, bias属性など
  • モデルのweights属性とget_weights()
  • 畳み込み層(CNN)のフィルタ(カーネル)の重みを可視化

パラメータの値そのものではなくパラメータ数を取得したい場合は以下の記事を参照。

最後に説明する畳み込み層(畳み込みニューラルネットワーク / CNN: Convolutional Neural Network)のフィルタ(カーネル)の重みの可視化は、重みの値をそのまま画像化するという最もシンプルな方法のみ。

なお、weightsget_weights()といった名前にも現れているように、カーネルの重みやバイアスなどのパラメータを総称して重み(Weights)と呼ぶこともある。本記事においても特に厳密に使い分けているわけではないので、適宜読み替えていただきたい。

以下のサンプルコードのTensorFlowのバージョンは2.1.0。TensorFlowに統合されたKerasを使う。

import tensorflow as tf
import numpy as np

print(tf.__version__)
# 2.1.0

以下のモデルを例とする。ただの例なので、なにか意味のあるモデルではない。

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(1, (3, 3), padding='same',
                           name='L0_conv2d', input_shape=(10, 10, 1)),
    tf.keras.layers.Flatten(name='L1_flatten'),
    tf.keras.layers.Dense(10, name='L2_dense', use_bias=False),
    tf.keras.layers.Dense(1, name='L3_dense'),
    tf.keras.layers.BatchNormalization(name='L4_bn')
])

model.summary()
# Model: "sequential"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #   
# =================================================================
# L0_conv2d (Conv2D)           (None, 10, 10, 1)         10        
# _________________________________________________________________
# L1_flatten (Flatten)         (None, 100)               0         
# _________________________________________________________________
# L2_dense (Dense)             (None, 10)                1000      
# _________________________________________________________________
# L3_dense (Dense)             (None, 1)                 11        
# _________________________________________________________________
# L4_bn (BatchNormalization)   (None, 1)                 4         
# =================================================================
# Total params: 1,025
# Trainable params: 1,023
# Non-trainable params: 2
# _________________________________________________________________

レイヤーのパラメータ(重み・バイアスなど)を取得

get_weights()メソッド

入力10出力1のシンプルな全結合層を例とする。

l3 = model.layers[3]

get_weights()メソッドはリストを返す。このレイヤーの場合、要素数は2

print(type(l3.get_weights()))
# <class 'list'>

print(len(l3.get_weights()))
# 2

要素はnumpy.ndarrayで、このレイヤーの場合、最初の要素がカーネルの重み、2つ目の要素がバイアスの値。重みの値はランダムな初期値となっている。

print(l3.get_weights()[0])
# [[-0.45019907]
#  [ 0.3547594 ]
#  [-0.01801795]
#  [ 0.5543849 ]
#  [-0.13720274]
#  [-0.71705985]
#  [ 0.30951375]
#  [-0.19865453]
#  [ 0.11943179]
#  [ 0.5920785 ]]

print(type(l3.get_weights()[0]))
# <class 'numpy.ndarray'>

print(l3.get_weights()[1])
# [0.]

print(type(l3.get_weights()[1]))
# <class 'numpy.ndarray'>

get_weights()が返すリストの要素数はレイヤーの種類や設定によって異なる。

print(len(model.layers[0].weights))
# 2

print(len(model.layers[1].weights))
# 0

print(len(model.layers[2].weights))
# 1

print(len(model.layers[3].weights))
# 2

print(len(model.layers[4].weights))
# 4

get_weights()の要素はただのnumpy.ndarrayなので、パラメータの値そのものの情報しか持っていない。パラメータ名などを確認したい場合は次に示すweights属性を使う。

weight属性

weights属性もget_weights()メソッドと同じくリストを返す。

print(type(l3.weights))
# <class 'list'>

print(len(l3.weights))
# 2

ただし、その要素はnumpy.ndarrayではなくResourceVariableというクラス。

print(l3.weights[0])
# <tf.Variable 'L3_dense/kernel:0' shape=(10, 1) dtype=float32, numpy=
# array([[-0.45019907],
#        [ 0.3547594 ],
#        [-0.01801795],
#        [ 0.5543849 ],
#        [-0.13720274],
#        [-0.71705985],
#        [ 0.30951375],
#        [-0.19865453],
#        [ 0.11943179],
#        [ 0.5920785 ]], dtype=float32)>

print(type(l3.weights[0]))
# <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>

print(l3.weights[1])
# <tf.Variable 'L3_dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>

print(type(l3.weights[1]))
# <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>

ResourceVariabletf.Variableのサブクラスで、name, shape属性などを持つ。

print(issubclass(type(l3.weights[0]), tf.Variable))
# True

print(l3.weights[0].name)
# L3_dense/kernel:0

print(l3.weights[0].shape)
# (10, 1)

print(l3.weights[1].name)
# L3_dense/bias:0

print(l3.weights[1].shape)
# (1,)

numpy()メソッドで値をnumpy.ndarrayとして取得することも可能。get_weights()メソッドで得られるnumpy.ndarrayと等価。

print(l3.weights[0].numpy())
# [[-0.45019907]
#  [ 0.3547594 ]
#  [-0.01801795]
#  [ 0.5543849 ]
#  [-0.13720274]
#  [-0.71705985]
#  [ 0.30951375]
#  [-0.19865453]
#  [ 0.11943179]
#  [ 0.5920785 ]]

print(l3.weights[1].numpy())
# [0.]

print(np.array_equal(l3.weights[0].numpy(), l3.get_weights()[0]))
# True

print(np.array_equal(l3.weights[1].numpy(), l3.get_weights()[1]))
# True

各レイヤーの情報は以下の通り。モデル生成時にuse_bias=Falseとしたレイヤーにはバイアスがないことや、BatchNormalization層の4つのパラメータの名前などが確認できる。

for w in model.layers[0].weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L0_conv2d/kernel:0       (3, 3, 1, 1)
# L0_conv2d/bias:0         (1,)

print(model.layers[1].weights)
# []

for w in model.layers[2].weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L2_dense/kernel:0        (100, 10)

for w in model.layers[3].weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L3_dense/kernel:0        (10, 1)
# L3_dense/bias:0          (1,)

for w in model.layers[4].weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L4_bn/gamma:0            (1,)
# L4_bn/beta:0             (1,)
# L4_bn/moving_mean:0      (1,)
# L4_bn/moving_variance:0  (1,)

なお、variablesという属性もあるが、これはweightsのエイリアスで中身は等価。どちらを使っても同じ。

variables: Returns the list of all layer variables/weights. Alias of self.weights. tf.keras.layers.Layer | TensorFlow Core v2.1.0

print(l3.weights == l3.variables)
# True

trainable_weights, non_trainable_weights属性

trainable_weights, non_trainable_weightsという属性もある。

weightsのうち、訓練対象のパラメータがtrainable_weightsに、訓練対象でないパラメータがnon_trainable_weightsに含まれる。

基本的には、レイヤーのtrainable属性がTrueであればすべてのパラメータがtrainable_weightsに含まれてnon_trainable_weightsは空、trainable属性がFalseであればすべてのパラメータがnon_trainable_weightsに含まれてtrainable_weightsは空となる。

print(l3.trainable)
# True

print(l3.trainable_weights == l3.weights)
# True

print(l3.non_trainable_weights)
# []

l3.trainable = False

print(l3.non_trainable_weights == l3.weights)
# True

print(l3.trainable_weights)
# []

ただし、BatchNormalization層はレイヤーのtrainable属性がTrueであってもnon_trainable_weightsに割り当てられるパラメータがある。詳細は公式ドキュメントを参照。

print(model.layers[4].trainable)
# True

for w in model.layers[4].trainable_weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L4_bn/gamma:0            (1,)
# L4_bn/beta:0             (1,)

for w in model.layers[4].non_trainable_weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L4_bn/moving_mean:0      (1,)
# L4_bn/moving_variance:0  (1,)

kernel, bias属性など

カーネルの重みやバイアスを直接取得できるkernel, bias属性もある。weights属性の要素と同一オブジェクト。

print(l3.kernel)
# <tf.Variable 'L3_dense/kernel:0' shape=(10, 1) dtype=float32, numpy=
# array([[-0.45019907],
#        [ 0.3547594 ],
#        [-0.01801795],
#        [ 0.5543849 ],
#        [-0.13720274],
#        [-0.71705985],
#        [ 0.30951375],
#        [-0.19865453],
#        [ 0.11943179],
#        [ 0.5920785 ]], dtype=float32)>

print(l3.kernel is l3.weights[0])
# True

print(l3.bias)
# <tf.Variable 'L3_dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>

print(l3.bias is l3.weights[1])
# True

weights属性はすべてのレイヤーの基底クラスであるtf.keras.layers.Layerに含まれる属性なのでどんな種類のレイヤーにもあるのに対し、kernel, bias属性があるのは所定のレイヤーのクラスのみ。

# print(model.layers[1].kernel)
# AttributeError: 'Flatten' object has no attribute 'kernel'

# print(model.layers[4].kernel)
# AttributeError: 'BatchNormalization' object has no attribute 'kernel'

use_bias=Falseとしたレイヤーの場合のbias属性はNone。属性自体はあるのでエラーにはならない。

print(model.layers[2].bias)
# None

所定のレイヤーのクラスには、カーネルとバイアス以外のパラメータに対応する属性が提供されている。

print(model.layers[4].gamma)
# <tf.Variable 'L4_bn/gamma:0' shape=(1,) dtype=float32, numpy=array([1.], dtype=float32)>

print(model.layers[4].gamma is model.layers[4].weights[0])
# True

モデルのweights属性とget_weights()

Kerasにおいては、モデルtf.keras.Modelはレイヤーtf.keras.layers.Layerのサブクラスであるから、レイヤーの属性やメソッドがモデルでも使える。

print(issubclass(tf.keras.Model, tf.keras.layers.Layer))
# True

ちなみに、この例のモデルのようなSequentialModelのサブクラス。当然ながらLayerのサブクラスでもある。

print(type(model))
# <class 'tensorflow.python.keras.engine.sequential.Sequential'>

print(issubclass(tf.keras.Sequential, tf.keras.Model))
# True

print(issubclass(tf.keras.Sequential, tf.keras.layers.Layer))
# True

モデルのweights属性は、モデル内の全てのレイヤーのweightsの要素を含むリストを返す。

print(type(model.weights))
# <class 'list'>

print(len(model.weights))
# 9

print(type(model.weights[0]))
# <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>

for w in model.weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L0_conv2d/kernel:0       (3, 3, 1, 1)
# L0_conv2d/bias:0         (1,)
# L2_dense/kernel:0        (100, 10)
# L3_dense/kernel:0        (10, 1)
# L3_dense/bias:0          (1,)
# L4_bn/gamma:0            (1,)
# L4_bn/beta:0             (1,)
# L4_bn/moving_mean:0      (1,)
# L4_bn/moving_variance:0  (1,)

weights属性やtrainable_weights, non_trainable_weights属性は基底クラスのLayerに含まれているので、モデルでも使えるが、kernelbiasなどの特定の種類のレイヤーにのみ設定されている属性はモデルでは使えない。

# print(model.kernel)
# AttributeError: 'Sequential' object has no attribute 'kernel'

get_weights()メソッドはLayerに含まれるメソッドなのでモデルでも使える。モデル内の全てのレイヤーのget_weights()が返すリストの要素を含むリストを返す。要素はただのnumpy.ndarrayなので、どれがどのレイヤーの何なのかという情報は得られない。

print(type(model.get_weights()))
# <class 'list'>

print(len(model.get_weights()))
# 9

print(type(model.get_weights()[0]))
# <class 'numpy.ndarray'>

for a in model.get_weights():
    print(a.shape)
# (3, 3, 1, 1)
# (1,)
# (100, 10)
# (10, 1)
# (1,)
# (1,)
# (1,)
# (1,)
# (1,)

ネストしたモデルの場合

上述のように、モデルはレイヤーのサブクラスなので、モデルを一つのレイヤーとして扱うこともできる。

以下のようにネストした(入れ子になった)モデルを例とする。

inner_model = tf.keras.Sequential([
    tf.keras.layers.Dense(100, name='L_in_0', input_shape=(1000,)),
    tf.keras.layers.Dense(10, name='L_in_1')
], name='Inner_model')

outer_model = tf.keras.Sequential([
    inner_model,
    tf.keras.layers.Dense(1, name='L_out_1')
])

outer_model.summary()
# Model: "sequential_1"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #   
# =================================================================
# Inner_model (Sequential)     (None, 10)                101110    
# _________________________________________________________________
# L_out_1 (Dense)              (None, 1)                 11        
# =================================================================
# Total params: 101,121
# Trainable params: 101,121
# Non-trainable params: 0
# _________________________________________________________________

outer_model.layers[0].summary()
# Model: "Inner_model"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #   
# =================================================================
# L_in_0 (Dense)               (None, 100)               100100    
# _________________________________________________________________
# L_in_1 (Dense)               (None, 10)                1010      
# =================================================================
# Total params: 101,110
# Trainable params: 101,110
# Non-trainable params: 0
# _________________________________________________________________

weights, get_weights()いずれも、内側のモデルのパラメータを再帰的に取得してリスト化する。

print(len(outer_model.weights))
# 6

for w in outer_model.weights:
    print('{:<25}{}'.format(w.name, w.shape))
# L_in_0/kernel:0          (1000, 100)
# L_in_0/bias:0            (100,)
# L_in_1/kernel:0          (100, 10)
# L_in_1/bias:0            (10,)
# L_out_1/kernel:0         (10, 1)
# L_out_1/bias:0           (1,)

print(len(outer_model.get_weights()))
# 6

for a in outer_model.get_weights():
    print(a.shape)
# (1000, 100)
# (100,)
# (100, 10)
# (10,)
# (10, 1)
# (1,)

畳み込み層(CNN)のフィルタ(カーネル)の重みを可視化

取得したパラメータの活用法として、畳み込み層(畳み込みニューラルネットワーク / CNN: Convolutional Neural Network)のフィルタ(カーネル)の可視化の例を示す。

CNNの可視化には様々な手法が提案されている。

ここでは最もシンプルな、最初のレイヤーのフィルターの重みの値をそのまま画像に変換する例を示す。

サンプルコード全体は以下。

Kerasに同梱されたResNet50の学習済みモデルを使う。初回実行時に学習データのダウンロードが行われる。

ResNetを使うのは、最初のレイヤーのフィルタのサイズが7 x 7で比較的大きく結果が分かりやすいため。他のモデルでもやり方は同じ。

import tensorflow as tf
from PIL import Image
import skimage.util

print(tf.__version__)
# 2.1.0

model = tf.keras.applications.ResNet50()

最初の畳み込み層のフィルタ(カーネル)の重みの値をnumpy.ndarrayとして取得。レイヤー名はsummary()の出力などで確認できる(ここでは省略)。

w = model.get_layer('conv1_conv').kernel.numpy()
print(type(w))
# <class 'numpy.ndarray'>

print(w.shape)
# (7, 7, 3, 64)

形状(7, 7, 3, 64)は、サイズが7 x 7の3チャンネルのフィルタが64個あるという意味。

画像に変換するために、値を0 - 255の範囲に規格化しunit8に変換する。

print(w.min(), w.max())
# -0.6710244 0.70432377

w_scale = ((w - w.min()) / (w.max() - w.min()) * 255).astype('uint8')
print(w_scale.min(), w_scale.max())
# 0 255

64個のフィルターを並べて表示するために後ほど使うskimage.util.montage()に合わせて、形状が(64, 7, 7, 3)となるように次元(軸)を入れ替える。

w_transpose = w_scale.transpose(3, 0, 1, 2)
print(w_transpose.shape)
# (64, 7, 7, 3)

ResNet50の学習済み重みデータはcaffeで訓練されており、画像のチャンネルの並びがBGR前提。フィルターのチャンネルもBGRの並びになっているのでRGBに並べ替える。

w_transpose_rgb = w_transpose[...,::-1]

ちなみに、ResNetV2はTensorFlowで訓練されており、画像のチャンネルの並びはRGB前提なので、並べ替える必要はない。

各モデルの中でimagenet_utils.pypreprocess_input()を呼んでおり、そのときの引数modeで判別できる。デフォルトはmode='caffe'

skimage.util.montage()で縦横に並べる。引数multichannel=Trueを忘れないように注意。

montage = skimage.util.montage(w_transpose_rgb, multichannel=True)
print(montage.shape)
# (56, 56, 3)

PIllowの画像に変換、リサイズしてから、画像ファイルとして保存する。

pil_img = Image.fromarray(montage).resize(
    (montage.shape[1] * 8, montage.shape[0] * 8)
)

pil_img.save('../data/img/dst/resnet50_conv1_conv.png')

リサイズや画像保存はPillow以外の方法でも構わないが、リサイズ時の補間方法がバイリニアなどだとなめらかに補間されてしまうため、ニアレストネイバーのほうが結果が見やすい。Pillowのリサイズはデフォルトでニアレストネイバー。

結果の画像は以下の通り。エッジやブロブ(塊)を抽出するフィルタが学習されていることが分かる。

ResNet50 first conv visualization

最初のレイヤーのフィルタは上の例のように3チャンネル(RGB)の画像としてそのまま可視化できるが、以降のCNNのフィルタは単純にはいかない。

参考までに中間層のフィルタを見てみる。ここからは、うまくいかないことを示すための例。

w = model.get_layer('conv3_block1_2_conv').kernel.numpy()
print(type(w))
# <class 'numpy.ndarray'>

print(w.shape)
# (3, 3, 128, 128)

形状(3, 3, 128, 128)3 x 3の128チャンネルのフィルタが128個あるという意味。

3 x 3のフィルタを白黒(グレースケール)として並べて画像化する。流れは上の例と同じ。

w = ((w - w.min()) / (w.max() - w.min()) * 255).astype('uint8')
print(w.min(), w.max())
# 0 255

w = w.reshape(3, 3, 128 * 128).transpose(2, 0, 1)
print(w.shape)
# (16384, 3, 3)

montage = skimage.util.montage(w)
print(montage.shape)
# (384, 384)

Image.fromarray(montage).save('../data/img/dst/resnet50_conv3_block1_2_conv.png')

結果の画像は以下の通り。人間が何らかの意味を読み取れるものではない。

ResNet50 middle conv visualization

中間層のCNNに対しては、フィルタの重みを直接見るのではなく、各レイヤーの出力(特に強く活性化する出力)を見るといった手法が提案されている。

関連カテゴリー

関連記事