PyTorchのtorch.flattenとtorch.nn.Flattenの違い
PyTorchのtorch.flatten()
はすべての次元を平坦化(一次元化)するが、torch.nn.Flatten
のインスタンスは最初の次元(バッチ用の次元)はそのままで以降の次元を平坦化するという違いがある(デフォルトの場合)。
ここでは以下の内容について説明する。
torch.flatten()
の使い方torch.nn.Flatten
の使い方- 平坦化処理を含む
torch.nn.Module
のサブクラスを定義する際の注意
本記事におけるPyTorchのバージョンは以下の通り。バージョンが異なると仕様が異なる可能性があるので注意。
import torch
import torch.nn as nn
print(torch.__version__)
# 1.7.1
torch.flatten()の使い方
torch.flatten()
は第一引数に指定したtorch.Tensor
を平坦化する。
t = torch.zeros(2, 3, 4, 5)
print(t.shape)
# torch.Size([2, 3, 4, 5])
print(torch.flatten(t).shape)
# torch.Size([120])
第二引数start_dim
、第三引数end_dim
を指定するとその間の次元のみが平坦化される。
print(torch.flatten(t, 1, 2).shape)
# torch.Size([2, 12, 5])
負の値で後ろからの位置も指定できる。-1
が一番最後。
デフォルト値はstart_dim=0
, end_dim=-1
で、上の例のように最初から最後まですべての次元が平坦化される。
torch.nn.Flattenの使い方
torch.nn.Flatten
はtorch.nn.Module
のサブクラス。torch.nn.Sequential
や独自のクラスを定義してモデル(ネットワーク)を構築する際に用いる。
torch.nn.Flatten
のインスタンスを生成し実行すると、第一引数に指定したtorch.Tensor
が平坦化される。
torch.flatten()
と異なり、デフォルトはstart_dim=1
, end_dim=-1
の範囲が平坦化され、最初の次元はそのまま。
flatten = nn.Flatten()
print(flatten(t).shape)
# torch.Size([2, 60])
start_dim
, end_dim
はコンストラクタtorch.nn.Flatten()
で指定可能。第一引数がstart_dim
、第二引数がend_dim
。
flatten_all = nn.Flatten(0, -1)
print(flatten_all(t).shape)
# torch.Size([120])
最初の次元を平坦化しないのは、バッチ処理用の次元を保持するため。
例えば、高さh
幅w
チャンネル数c
(RGB画像の場合はc = 3
)を入力とする場合、一枚の画像のテンソルの形状は(c, h, w)
、n枚の画像をまとめたテンソルの形状は(n, c, h, w)
となる。
モデルに対しては(n, c, h, w)
のテンソルが入力されるが、平坦化の処理としては(n, c * h * w)
に変換(それぞれの画像を平坦化)することが望ましい。torch.nn.Flatten
のデフォルト値はそのような処理になっている。
平坦化処理を含むtorch.nn.Moduleのサブクラスを定義する際の注意
平坦化処理を含むtorch.nn.Module
のサブクラスを定義する場合、上述の通り、最初の次元(バッチ用の次元)は保持することが望ましい。
torch.nn.Flatten
を用いる場合はデフォルトのままでOK。
class Net(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
def forward(self, x):
x = self.flatten(x)
return x
net = Net()
print(net(t).shape)
# torch.Size([2, 60])
torch.flatten()
をforward()
内で用いる場合、デフォルトのままだとすべての次元が平坦化されてしまう。
class NetFunctional(nn.Module):
def forward(self, x):
x = torch.flatten(x)
return x
net_func = NetFunctional()
print(net_func(t).shape)
# torch.Size([120])
第二引数start_dim
を1
とすればよい。
class NetFunctionalDim(nn.Module):
def forward(self, x):
x = torch.flatten(x, 1)
return x
net_func_dim = NetFunctionalDim()
print(net_func_dim(t).shape)
# torch.Size([2, 60])
一部のみを平坦化するようなモデルを構築したい場合は、torch.nn.Flatten()
またはtorch.flatten()
のstart_dim
, end_dim
を所望の値に指定すればよい。