PyTorchのTensorの次元数、形状、要素数を取得: dim(), size(), numel()

Posted: | Tags: Python, PyTorch

PyTorchテンソルtorch.Tensorの次元数、形状、要素数を取得するには、dim(), size(), numel()などを使う。エイリアスもいくつか定義されている。

ここでは以下の内容について説明する。

  • torch.Tensorの次元数を取得: dim(), ndimension(), ndim
  • torch.Tensorの形状を取得: size(), shape
  • torch.Tensorの要素数を取得: numel(), nelement()

NumPy配列numpy.ndarrayの次元数、形状、要素数の取得については以下の記事を参照。

本記事のサンプルコードにおけるPyTorchのバージョンは1.7.1。以下のtorch.Tensorを例とする。

import torch

print(torch.__version__)
# 1.7.1

t = torch.zeros(2, 3)
print(t)
# tensor([[0., 0., 0.],
#         [0., 0., 0.]])

torch.Tensorの次元数を取得: dim(), ndimension(), ndim

torch.Tensorの次元数はdim()メソッドで取得できる。返り値は整数int

print(t.dim())
# 2

print(type(t.dim()))
# <class 'int'>

ndimension()メソッドや、numpy.ndarrayと同じndim属性も使用できる。

print(t.ndimension())
# 2

print(t.ndim)
# 2

torch.Tensorの形状を取得: size(), shape

torch.Tensorの形状はsize()メソッドで取得できる。返り値はtorch.Size

print(t.size())
# torch.Size([2, 3])

print(type(t.size()))
# <class 'torch.Size'>

torch.Sizeはタプルのサブクラス。各要素をインデックスで指定して取得したり、アンパックで個別に取得したりできる。要素は整数int

print(issubclass(type(t.size()), tuple))
# True

print(t.size()[1])
# 3

print(type(t.size()[1]))
# <class 'int'>

a, b = t.size()
print(a)
# 2

print(b)
# 3

バージョン1.7.1時点では公式ドキュメントに見当たらなかったが、numpy.ndarrayと同様にshape属性も使用できる。

print(t.shape)
# torch.Size([2, 3])

torch.Tensorの要素数を取得: numel(), nelement()

torch.Tensorの全要素数はtorch.numel()関数で取得できる。返り値は整数int

print(torch.numel(t))
# 6

print(type(torch.numel(t)))
# <class 'int'>

numel()torch.Tensorのメソッドとしても定義されている。また、そのエイリアスとしてnelement()メソッドも提供されている。

print(t.numel())
# 6

print(t.nelement())
# 6

numpy.ndarraysize属性で要素数を取得できるが、torch.Tensorでは使えないので注意。

関連カテゴリー

関連記事