note.nkmk.me

PyTorchのTensorの要素の値を取得: item()

Posted: 2021-02-14 / Tags: Python, PyTorch

PyTorchテンソルtorch.Tensorの要素をPython組み込み型(intfloat)の値として取得するにはitem()メソッドを使う。

ここでは以下の内容について説明する。

  • torch.Tensorの要素の値を取得: item()
  • 複数要素のテンソルの場合
  • int()float()で変換
  • そのほかの関数・メソッドの場合

なお、PyTorchテンソルtorch.TensorをリストやNumPy配列numpy.ndarrayに変換するには、tolist(), numpy()メソッドを使う。

本記事のサンプルコードにおけるPyTorchのバージョンは1.7.1。以下のtorch.Tensorを例とする。

import torch

print(torch.__version__)
# 1.7.1

t = torch.arange(6).reshape(2, 3)
print(t)
# tensor([[0, 1, 2],
#         [3, 4, 5]])
スポンサーリンク

torch.Tensorの要素の値を取得: item()

torch.Tensornumpy.ndarrayと同じようにインデックスで要素を指定できるが、0次元のtorch.Tensorとして扱われる。

print(t[1, 1])
# tensor(4)

print(type(t[1, 1]))
# <class 'torch.Tensor'>

print(t[1, 1].ndim)
# 0

item()メソッドで、Python組み込み型(この場合は整数int)として要素の値を取得できる。

print(t[1, 1].item())
# 4

print(type(t[1, 1].item()))
# <class 'int'>

複数要素のテンソルの場合

item()で変換できるのは要素数が1個のtorch.Tensorのみ。複数要素の場合はエラーになる。

print(t[:2, 1])
# tensor([1, 4])

# print(t[:2, 1].item())
# ValueError: only one element tensors can be converted to Python scalars

複数要素のPyTorchテンソルtorch.TensorをリストやNumPy配列numpy.ndarrayに変換するには、tolist(), numpy()を使う。

なお、item()で変換できるかどうかは次元数ではなく要素数で決まる。多次元テンソルでも要素数が1個であれば変換可能。

print(t[1, [1]])
# tensor([4])

print(t[1, [1]].ndim)
# 1

print(t[1, [1]].item())
# 4

int()やfloat()で変換

要素数1個のテンソルはint()float()でも変換できる。

print(int(t[1, 1]))
# 4

print(float(t[1, 1]))
# 4.0

str()'tensor(xxx)'のような文字列に変換するので注意。値を文字列に変換したい場合はitem()などを合わせて使う。

print(str(t[1, 1]))
# tensor(4)

print(type(str(t[1, 1])))
# <class 'str'>

print(str(t[1, 1].item()))
# 4

print(type(str(t[1, 1].item())))
# <class 'str'>

int(), float()item()メソッドと同じく複数要素のテンソルは変換できない。多次元でも要素数が1個であれば変換可能。

# print(int(t[:2, 1]))
# ValueError: only one element tensors can be converted to Python scalars

print(int(t[1, [1]]))
# 4

# print(float(t[:2, 1]))
# ValueError: only one element tensors can be converted to Python scalars

print(float(t[1, [1]]))
# 4.0

str()は要素数に限らず'tensor(xxx)'のような文字列に変換する。

print(str(t[:2, 1]))
# tensor([1, 4])

print(type(str(t[:2, 1])))
# <class 'str'>

そのほかの関数・メソッドの場合

torch.Tensorの最大値や最小値、合計などを算出するtorch.max(), torch.min(), torch.sum()なども、結果として0次元のtorch.Tensorを返す場合がある。

Python組み込み型(intfloat)の値として取得するにはitem()を使えばよい。

print(torch.max(t))
# tensor(5)

print(torch.max(t).item())
# 5

print(torch.sum(t))
# tensor(15)

print(torch.sum(t).item())
# 15

平均torch.mean()や標準偏差torch.std()なども同様。

なお、常に0次元のtorch.Tensorを返すわけではなく、次元ごとに算出するように引数を設定した場合などは複数要素のテンソルを返す。

print(torch.sum(t, 0))
# tensor([3, 5, 7])
スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事