PyTorchのTensorの次元数、形状、要素数を取得: dim(), size(), numel()
PyTorchテンソルtorch.Tensorの次元数、形状、要素数を取得するには、dim(), size(), numel()などを使う。エイリアスもいくつか定義されている。
- torch.Tensor.dim() — PyTorch 1.7.1 documentation
- torch.Tensor.size() — PyTorch 1.7.1 documentation
- torch.numel() — PyTorch 1.7.1 documentation
ここでは以下の内容について説明する。
torch.Tensorの次元数を取得:dim(),ndimension(),ndimtorch.Tensorの形状を取得:size(),shapetorch.Tensorの要素数を取得:numel(),nelement()
NumPy配列numpy.ndarrayの次元数、形状、要素数の取得については以下の記事を参照。
本記事のサンプルコードにおけるPyTorchのバージョンは1.7.1。以下のtorch.Tensorを例とする。
import torch
print(torch.__version__)
# 1.7.1
t = torch.zeros(2, 3)
print(t)
# tensor([[0., 0., 0.],
# [0., 0., 0.]])
source: torch_tensor_dim_size_numel.py
torch.Tensorの次元数を取得: dim(), ndimension(), ndim
torch.Tensorの次元数はdim()メソッドで取得できる。返り値は整数int。
print(t.dim())
# 2
print(type(t.dim()))
# <class 'int'>
source: torch_tensor_dim_size_numel.py
ndimension()メソッドや、numpy.ndarrayと同じndim属性も使用できる。
- torch.Tensor.ndimension() — PyTorch 1.7.1 documentation
- torch.Tensor.ndim — PyTorch 1.7.1 documentation
print(t.ndimension())
# 2
print(t.ndim)
# 2
source: torch_tensor_dim_size_numel.py
torch.Tensorの形状を取得: size(), shape
torch.Tensorの形状はsize()メソッドで取得できる。返り値はtorch.Size。
print(t.size())
# torch.Size([2, 3])
print(type(t.size()))
# <class 'torch.Size'>
source: torch_tensor_dim_size_numel.py
torch.Sizeはタプルのサブクラス。各要素をインデックスで指定して取得したり、アンパックで個別に取得したりできる。要素は整数int。
print(issubclass(type(t.size()), tuple))
# True
print(t.size()[1])
# 3
print(type(t.size()[1]))
# <class 'int'>
a, b = t.size()
print(a)
# 2
print(b)
# 3
source: torch_tensor_dim_size_numel.py
バージョン1.7.1時点では公式ドキュメントに見当たらなかったが、numpy.ndarrayと同様にshape属性も使用できる。
print(t.shape)
# torch.Size([2, 3])
source: torch_tensor_dim_size_numel.py
torch.Tensorの要素数を取得: numel(), nelement()
torch.Tensorの全要素数はtorch.numel()関数で取得できる。返り値は整数int。
print(torch.numel(t))
# 6
print(type(torch.numel(t)))
# <class 'int'>
source: torch_tensor_dim_size_numel.py
numel()はtorch.Tensorのメソッドとしても定義されている。また、そのエイリアスとしてnelement()メソッドも提供されている。
- torch.Tensor.numel() — PyTorch 1.7.1 documentation
- torch.Tensor.nelement() — PyTorch 1.7.1 documentation
print(t.numel())
# 6
print(t.nelement())
# 6
source: torch_tensor_dim_size_numel.py
numpy.ndarrayはsize属性で要素数を取得できるが、torch.Tensorでは使えないので注意。