note.nkmk.me

NumPy配列ndarrayの下三角行列・上三角行列を抽出・生成(tril, triu, tri)

Date: 2019-09-18 / tags: Python, NumPy

対角線より上の成分が0である行列を下三角行列、対角線より下の成分が0である行列を上三角行列という。

NumPy配列ndarrayから下三角行列・上三角行列を抽出(取得)するにはnumpy.tril(), numpy.triu()、下三角行列・上三角行列を新たに生成するにはnumpy.tri()を使う。

ここでは以下の内容について説明する。

  • NumPy配列ndarrayから下三角行列を抽出(取得): numpy.tril()
  • NumPy配列ndarrayから上三角行列を抽出(取得): numpy.triu()
  • 下三角行列・上三角行列を生成: numpy.tri()

対角行列を生成したり対角成分を抽出したりするにはnumpy.diag()を使う。以下の記事を参照。

スポンサーリンク

NumPy配列ndarrayから下三角行列を抽出(取得): numpy.tril()

numpy.tril()の引数にnumpy.ndarrayを指定すると、対角線より上の成分を0としたnumpy.ndarrayが返される。ビューではなく、元のnumpy.ndarrayとメモリを共有しないコピー。

import numpy as np

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

print(np.tril(a))
# [[ 0  0  0  0]
#  [ 4  5  0  0]
#  [ 8  9 10  0]
#  [12 13 14 15]]

第二引数kで境界となる対角線の位置を指定可能。正の値だと上側(右側)、負の値だと下側(左側)に移動する。

print(np.tril(a, k=2))
# [[ 0  1  2  0]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

print(np.tril(a, k=-1))
# [[ 0  0  0  0]
#  [ 4  0  0  0]
#  [ 8  9  0  0]
#  [12 13 14  0]]

正方行列でなくてもOK。考え方は正方行列の場合と同じ。

a = np.arange(12).reshape(3, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

print(np.tril(a))
# [[ 0  0  0  0]
#  [ 4  5  0  0]
#  [ 8  9 10  0]]

print(np.tril(a, k=-1))
# [[0 0 0 0]
#  [4 0 0 0]
#  [8 9 0 0]]

3次元以上の配列の場合、最後の2つの軸(次元)に対して処理が適用される。

print(np.tril(np.arange(32).reshape(2, 4, 4)))
# [[[ 0  0  0  0]
#   [ 4  5  0  0]
#   [ 8  9 10  0]
#   [12 13 14 15]]
# 
#  [[16  0  0  0]
#   [20 21  0  0]
#   [24 25 26  0]
#   [28 29 30 31]]]

print(np.tril(np.arange(16).reshape(1, 1, 4, 4)))
# [[[[ 0  0  0  0]
#    [ 4  5  0  0]
#    [ 8  9 10  0]
#    [12 13 14 15]]]]

NumPy配列ndarrayから上三角行列を抽出(取得): numpy.triu()

numpy.triu()の引数にnumpy.ndarrayを指定すると、対角線より下の成分を0としたnumpy.ndarrayが返される。

a = np.arange(16).reshape(4, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]
#  [12 13 14 15]]

print(np.triu(a))
# [[ 0  1  2  3]
#  [ 0  5  6  7]
#  [ 0  0 10 11]
#  [ 0  0  0 15]]

numpy.tril()と同じく第二引数kを指定可能。

print(np.triu(a, k=2))
# [[0 0 2 3]
#  [0 0 0 7]
#  [0 0 0 0]
#  [0 0 0 0]]

print(np.triu(a, k=-1))
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 0  9 10 11]
#  [ 0  0 14 15]]

正方行列でない場合や3次元以上の配列の場合もnumpy.tril()と同様。

a = np.arange(12).reshape(3, 4)
print(a)
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 8  9 10 11]]

print(np.triu(a))
# [[ 0  1  2  3]
#  [ 0  5  6  7]
#  [ 0  0 10 11]]

print(np.triu(a, k=-1))
# [[ 0  1  2  3]
#  [ 4  5  6  7]
#  [ 0  9 10 11]]
print(np.triu(np.arange(16).reshape(1, 1, 4, 4)))
# [[[[ 0  1  2  3]
#    [ 0  5  6  7]
#    [ 0  0 10 11]
#    [ 0  0  0 15]]]]

下三角行列・上三角行列を生成: numpy.tri()

numpy.tri()は対角線より上の成分が0、それ以外の成分が1の下三角行列を生成する。

第一引数Nに整数値を指定するとNN列の正方行列(二次元配列)が生成される。

デフォルトはデータ型dtypefloat

print(np.tri(4))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]

print(type(np.tri(4)))
# <class 'numpy.ndarray'>

print(np.tri(4).dtype)
# float64
source: numpy_tri.py

データ型は引数dtypeで指定できる。

print(np.tri(4, dtype=int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]

print(np.tri(4, dtype=int).dtype)
# int64
source: numpy_tri.py

numpy.tril(), numpy.triu()と同様に引数kを指定可能。

print(np.tri(4, k=1))
# [[1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]
#  [1. 1. 1. 1.]]

print(np.tri(4, k=-1))
# [[0. 0. 0. 0.]
#  [1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]]
source: numpy_tri.py

第二引数Mを指定するとNM列となる。

print(np.tri(3, 4))
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]]

print(np.tri(3, 4, k=-1))
# [[0. 0. 0. 0.]
#  [1. 0. 0. 0.]
#  [1. 1. 0. 0.]]
source: numpy_tri.py

上三角行列を生成する関数はないので、numpy.flip()による上下左右反転か、numpy.rot90()による回転を利用する。numpy.rot90()の場合は第二引数に2を指定して180度回転する(90度回転を2回)。

print(np.flip(np.tri(4)))
# [[1. 1. 1. 1.]
#  [0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]]

print(np.rot90(np.tri(4), 2))
# [[1. 1. 1. 1.]
#  [0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]]
source: numpy_tri.py

numpy.tri()で上三角行列を生成してから反転または回転するので引数kの符号に注意。

print(np.flip(np.tri(4, k=1)))
# [[1. 1. 1. 1.]
#  [1. 1. 1. 1.]
#  [0. 1. 1. 1.]
#  [0. 0. 1. 1.]]

print(np.flip(np.tri(4, k=-1)))
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 0.]]
source: numpy_tri.py
スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事