PyTorchのTensorの次元数、形状、要素数を取得: dim(), size(), numel()
PyTorchテンソルtorch.Tensor
の次元数、形状、要素数を取得するには、dim()
, size()
, numel()
などを使う。エイリアスもいくつか定義されている。
- torch.Tensor.dim() — PyTorch 1.7.1 documentation
- torch.Tensor.size() — PyTorch 1.7.1 documentation
- torch.numel() — PyTorch 1.7.1 documentation
ここでは以下の内容について説明する。
torch.Tensor
の次元数を取得:dim()
,ndimension()
,ndim
torch.Tensor
の形状を取得:size()
,shape
torch.Tensor
の要素数を取得:numel()
,nelement()
NumPy配列numpy.ndarray
の次元数、形状、要素数の取得については以下の記事を参照。
本記事のサンプルコードにおけるPyTorchのバージョンは1.7.1
。以下のtorch.Tensor
を例とする。
import torch
print(torch.__version__)
# 1.7.1
t = torch.zeros(2, 3)
print(t)
# tensor([[0., 0., 0.],
# [0., 0., 0.]])
source: torch_tensor_dim_size_numel.py
torch.Tensorの次元数を取得: dim(), ndimension(), ndim
torch.Tensor
の次元数はdim()
メソッドで取得できる。返り値は整数int
。
print(t.dim())
# 2
print(type(t.dim()))
# <class 'int'>
source: torch_tensor_dim_size_numel.py
ndimension()
メソッドや、numpy.ndarray
と同じndim
属性も使用できる。
- torch.Tensor.ndimension() — PyTorch 1.7.1 documentation
- torch.Tensor.ndim — PyTorch 1.7.1 documentation
print(t.ndimension())
# 2
print(t.ndim)
# 2
source: torch_tensor_dim_size_numel.py
torch.Tensorの形状を取得: size(), shape
torch.Tensor
の形状はsize()
メソッドで取得できる。返り値はtorch.Size
。
print(t.size())
# torch.Size([2, 3])
print(type(t.size()))
# <class 'torch.Size'>
source: torch_tensor_dim_size_numel.py
torch.Size
はタプルのサブクラス。各要素をインデックスで指定して取得したり、アンパックで個別に取得したりできる。要素は整数int
。
print(issubclass(type(t.size()), tuple))
# True
print(t.size()[1])
# 3
print(type(t.size()[1]))
# <class 'int'>
a, b = t.size()
print(a)
# 2
print(b)
# 3
source: torch_tensor_dim_size_numel.py
バージョン1.7.1
時点では公式ドキュメントに見当たらなかったが、numpy.ndarray
と同様にshape
属性も使用できる。
print(t.shape)
# torch.Size([2, 3])
source: torch_tensor_dim_size_numel.py
torch.Tensorの要素数を取得: numel(), nelement()
torch.Tensor
の全要素数はtorch.numel()
関数で取得できる。返り値は整数int
。
print(torch.numel(t))
# 6
print(type(torch.numel(t)))
# <class 'int'>
source: torch_tensor_dim_size_numel.py
numel()
はtorch.Tensor
のメソッドとしても定義されている。また、そのエイリアスとしてnelement()
メソッドも提供されている。
- torch.Tensor.numel() — PyTorch 1.7.1 documentation
- torch.Tensor.nelement() — PyTorch 1.7.1 documentation
print(t.numel())
# 6
print(t.nelement())
# 6
source: torch_tensor_dim_size_numel.py
numpy.ndarray
はsize
属性で要素数を取得できるが、torch.Tensor
では使えないので注意。