PyTorchのTensorの要素の値を取得: item()
PyTorchテンソルtorch.Tensor
の要素をPython組み込み型(int
やfloat
)の値として取得するにはitem()
メソッドを使う。
ここでは以下の内容について説明する。
torch.Tensor
の要素の値を取得:item()
- 複数要素のテンソルの場合
int()
やfloat()
で変換- そのほかの関数・メソッドの場合
なお、PyTorchテンソルtorch.Tensor
をリストやNumPy配列numpy.ndarray
に変換するには、tolist()
, numpy()
メソッドを使う。
- torch.Tensor.tolist() — PyTorch 1.7.1 documentation
- torch.Tensor.numpy() — PyTorch 1.7.1 documentation
本記事のサンプルコードにおけるPyTorchのバージョンは1.7.1
。以下のtorch.Tensor
を例とする。
import torch
print(torch.__version__)
# 1.7.1
t = torch.arange(6).reshape(2, 3)
print(t)
# tensor([[0, 1, 2],
# [3, 4, 5]])
torch.Tensorの要素の値を取得: item()
torch.Tensor
もnumpy.ndarray
と同じようにインデックスで要素を指定できるが、0次元のtorch.Tensor
として扱われる。
print(t[1, 1])
# tensor(4)
print(type(t[1, 1]))
# <class 'torch.Tensor'>
print(t[1, 1].ndim)
# 0
item()
メソッドで、Python組み込み型(この場合は整数int
)として要素の値を取得できる。
print(t[1, 1].item())
# 4
print(type(t[1, 1].item()))
# <class 'int'>
複数要素のテンソルの場合
item()
で変換できるのは要素数が1個のtorch.Tensor
のみ。複数要素の場合はエラーになる。
print(t[:2, 1])
# tensor([1, 4])
# print(t[:2, 1].item())
# ValueError: only one element tensors can be converted to Python scalars
複数要素のPyTorchテンソルtorch.Tensor
をリストやNumPy配列numpy.ndarray
に変換するには、tolist()
, numpy()
を使う。
- torch.Tensor.tolist() — PyTorch 1.7.1 documentation
- torch.Tensor.numpy() — PyTorch 1.7.1 documentation
なお、item()
で変換できるかどうかは次元数ではなく要素数で決まる。多次元テンソルでも要素数が1個であれば変換可能。
print(t[1, [1]])
# tensor([4])
print(t[1, [1]].ndim)
# 1
print(t[1, [1]].item())
# 4
int()やfloat()で変換
要素数1個のテンソルはint()
やfloat()
でも変換できる。
print(int(t[1, 1]))
# 4
print(float(t[1, 1]))
# 4.0
str()
は'tensor(xxx)'
のような文字列に変換するので注意。値を文字列に変換したい場合はitem()
などを合わせて使う。
print(str(t[1, 1]))
# tensor(4)
print(type(str(t[1, 1])))
# <class 'str'>
print(str(t[1, 1].item()))
# 4
print(type(str(t[1, 1].item())))
# <class 'str'>
int()
, float()
はitem()
メソッドと同じく複数要素のテンソルは変換できない。多次元でも要素数が1個であれば変換可能。
# print(int(t[:2, 1]))
# ValueError: only one element tensors can be converted to Python scalars
print(int(t[1, [1]]))
# 4
# print(float(t[:2, 1]))
# ValueError: only one element tensors can be converted to Python scalars
print(float(t[1, [1]]))
# 4.0
str()
は要素数に限らず'tensor(xxx)'
のような文字列に変換する。
print(str(t[:2, 1]))
# tensor([1, 4])
print(type(str(t[:2, 1])))
# <class 'str'>
そのほかの関数・メソッドの場合
torch.Tensor
の最大値や最小値、合計などを算出するtorch.max()
, torch.min()
, torch.sum()
なども、結果として0次元のtorch.Tensor
を返す場合がある。
Python組み込み型(int
やfloat
)の値として取得するにはitem()
を使えばよい。
print(torch.max(t))
# tensor(5)
print(torch.max(t).item())
# 5
print(torch.sum(t))
# tensor(15)
print(torch.sum(t).item())
# 15
平均torch.mean()
や標準偏差torch.std()
なども同様。
なお、常に0次元のtorch.Tensor
を返すわけではなく、次元ごとに算出するように引数を設定した場合などは複数要素のテンソルを返す。
print(torch.sum(t, 0))
# tensor([3, 5, 7])