PyTorchのtorch.flattenとtorch.nn.Flattenの違い
PyTorchのtorch.flatten()はすべての次元を平坦化(一次元化)するが、torch.nn.Flattenのインスタンスは最初の次元(バッチ用の次元)はそのままで以降の次元を平坦化するという違いがある(デフォルトの場合)。
ここでは以下の内容について説明する。
torch.flatten()の使い方torch.nn.Flattenの使い方- 平坦化処理を含む
torch.nn.Moduleのサブクラスを定義する際の注意
本記事におけるPyTorchのバージョンは以下の通り。バージョンが異なると仕様が異なる可能性があるので注意。
import torch
import torch.nn as nn
print(torch.__version__)
# 1.7.1
torch.flatten()の使い方
torch.flatten()は第一引数に指定したtorch.Tensorを平坦化する。
t = torch.zeros(2, 3, 4, 5)
print(t.shape)
# torch.Size([2, 3, 4, 5])
print(torch.flatten(t).shape)
# torch.Size([120])
第二引数start_dim、第三引数end_dimを指定するとその間の次元のみが平坦化される。
print(torch.flatten(t, 1, 2).shape)
# torch.Size([2, 12, 5])
負の値で後ろからの位置も指定できる。-1が一番最後。
デフォルト値はstart_dim=0, end_dim=-1で、上の例のように最初から最後まですべての次元が平坦化される。
torch.nn.Flattenの使い方
torch.nn.Flattenはtorch.nn.Moduleのサブクラス。torch.nn.Sequentialや独自のクラスを定義してモデル(ネットワーク)を構築する際に用いる。
torch.nn.Flattenのインスタンスを生成し実行すると、第一引数に指定したtorch.Tensorが平坦化される。
torch.flatten()と異なり、デフォルトはstart_dim=1, end_dim=-1の範囲が平坦化され、最初の次元はそのまま。
flatten = nn.Flatten()
print(flatten(t).shape)
# torch.Size([2, 60])
start_dim, end_dimはコンストラクタtorch.nn.Flatten()で指定可能。第一引数がstart_dim、第二引数がend_dim。
flatten_all = nn.Flatten(0, -1)
print(flatten_all(t).shape)
# torch.Size([120])
最初の次元を平坦化しないのは、バッチ処理用の次元を保持するため。
例えば、高さh幅wチャンネル数c(RGB画像の場合はc = 3)を入力とする場合、一枚の画像のテンソルの形状は(c, h, w)、n枚の画像をまとめたテンソルの形状は(n, c, h, w)となる。
モデルに対しては(n, c, h, w)のテンソルが入力されるが、平坦化の処理としては(n, c * h * w)に変換(それぞれの画像を平坦化)することが望ましい。torch.nn.Flattenのデフォルト値はそのような処理になっている。
平坦化処理を含むtorch.nn.Moduleのサブクラスを定義する際の注意
平坦化処理を含むtorch.nn.Moduleのサブクラスを定義する場合、上述の通り、最初の次元(バッチ用の次元)は保持することが望ましい。
torch.nn.Flattenを用いる場合はデフォルトのままでOK。
class Net(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
def forward(self, x):
x = self.flatten(x)
return x
net = Net()
print(net(t).shape)
# torch.Size([2, 60])
torch.flatten()をforward()内で用いる場合、デフォルトのままだとすべての次元が平坦化されてしまう。
class NetFunctional(nn.Module):
def forward(self, x):
x = torch.flatten(x)
return x
net_func = NetFunctional()
print(net_func(t).shape)
# torch.Size([120])
第二引数start_dimを1とすればよい。
class NetFunctionalDim(nn.Module):
def forward(self, x):
x = torch.flatten(x, 1)
return x
net_func_dim = NetFunctionalDim()
print(net_func_dim(t).shape)
# torch.Size([2, 60])
一部のみを平坦化するようなモデルを構築したい場合は、torch.nn.Flatten()またはtorch.flatten()のstart_dim, end_dimを所望の値に指定すればよい。