note.nkmk.me

NumPyのeyeまたはidentityでone-hot表現に変換

Date: 2017-08-17 / tags: Python, NumPy
スポンサーリンク

one-hot表現とは

1つだけが1(high)で、それ以外は0(low)のビット列をone-hotと呼ぶ。1-of-K表現とも呼ばれる。

ちなみに、1つだけが0でそれ以外が1であるビット列をone-coldと呼ぶこともある。らしい。

TensorFlowなどの機械学習で分類を行う際には、正解ラベルをone-hotで表現する必要がある。例えば、手書き数字(09の10種類)のデータセットであるMNISTで正解となるラベルが2の場合、one-hotで表すと、[0,0,1,0,0,0,0,0,0,0]となる。

NumPyのeye関数またはidentity関数を使うと簡単にone-hot表現に変換できる。

numpy.eye()

numpy.eye()は、1が斜めに並んで、それ以外は0となる2次元のndarrayを返す関数。

e = np.eye(4)
print(type(e))
print(e)
print(e.dtype)
# <class 'numpy.ndarray'>
# [[ 1.  0.  0.  0.]
#  [ 0.  1.  0.  0.]
#  [ 0.  0.  1.  0.]
#  [ 0.  0.  0.  1.]]
# float64

デフォルトのデータ型はfloat64。引数dtypeでデータ型を指定できる。

e = np.eye(4, M=3, k=1, dtype=np.int8)
print(e)
print(e.dtype)
# [[0 1 0]
#  [0 0 1]
#  [0 0 0]
#  [0 0 0]]
# int8

引数Mで列のサイズ、k1の始まり位置を変えられる。…が、どういう時に使うのかよく分からない。

numpy.identity()

numpy.identity()は名前の通り、単位行列(identity matrix)を返す関数。

i = np.identity(4)
print(i)
print(i.dtype)
# [[ 1.  0.  0.  0.]
#  [ 0.  1.  0.  0.]
#  [ 0.  0.  1.  0.]
#  [ 0.  0.  0.  1.]]
# float64

デフォルトのデータ型はfloat64。引数dtypeでデータ型を指定できる。

i = np.identity(4, dtype=np.uint8)
print(i)
print(i.dtype)
# [[1 0 0 0]
#  [0 1 0 0]
#  [0 0 1 0]
#  [0 0 0 1]]
# uint8

他の引数はない。

なぜ同じような関数が2つもあるのかと思って、ソースを見てみると、numpy.identity()は内部でnumpy.eye()を呼んでいるだけ。

    from numpy import eye
    return eye(n, dtype=dtype)
source: numeric.py

numpy.eye()でも単位行列は得られるけど、numpy.identity()というわかりやすい名前の関数も用意してある、ということだろう。

one-hot表現に変換

単位行列があればone-hot表現に変換するのは簡単。

例えば変換元が10種類の場合は、10×10の単位行列を作ってインデックスに変換元の値をいれてやればいい。

a = [3, 0, 8, 1, 9]
a_one_hot = np.identity(10)[a]
print(a)
print(a_one_hot)
# [3, 0, 8, 1, 9]
# [[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
#  [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
#  [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.]
#  [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
#  [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]]

適当な09の値をもつ配列aをone-hot表現に変換している。

irisデータセットのように正解ラベルが3種類の場合もやり方は同じ。

a = [2, 2, 0, 1, 0]
a_one_hot = np.identity(3)[a]
print(a)
print(a_one_hot)
# [2, 2, 0, 1, 0]
# [[ 0.  0.  1.]
#  [ 0.  0.  1.]
#  [ 1.  0.  0.]
#  [ 0.  1.  0.]
#  [ 1.  0.  0.]]

データ型は引数dtypeで適宜指定すればOK。

例ではわかりやすい名前のnumpy.identity()を使っているが、numpy.eye()でも同じ。お好みで。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事