note.nkmk.me

PyTorchのtorch.flattenとtorch.nn.Flattenの違い

Posted: 2021-03-20 / Tags: Python, PyTorch, 機械学習

PyTorchのtorch.flatten()はすべての次元を平坦化(一次元化)するが、torch.nn.Flattenのインスタンスは最初の次元(バッチ用の次元)はそのままで以降の次元を平坦化するという違いがある(デフォルトの場合)。

ここでは以下の内容について説明する。

  • torch.flatten()の使い方
  • torch.nn.Flattenの使い方
  • 平坦化処理を含むtorch.nn.Moduleのサブクラスを定義する際の注意

本記事におけるPyTorchのバージョンは以下の通り。バージョンが異なると仕様が異なる可能性があるので注意。

import torch
import torch.nn as nn

print(torch.__version__)
# 1.7.1
スポンサーリンク

torch.flatten()の使い方

torch.flatten()は第一引数に指定したtorch.Tensorを平坦化する。

t = torch.zeros(2, 3, 4, 5)
print(t.shape)
# torch.Size([2, 3, 4, 5])

print(torch.flatten(t).shape)
# torch.Size([120])

第二引数start_dim、第三引数end_dimを指定するとその間の次元のみが平坦化される。

print(torch.flatten(t, 1, 2).shape)
# torch.Size([2, 12, 5])

負の値で後ろからの位置も指定できる。-1が一番最後。

デフォルト値はstart_dim=0, end_dim=-1で、上の例のように最初から最後まですべての次元が平坦化される。

torch.nn.Flattenの使い方

torch.nn.Flattentorch.nn.Moduleのサブクラス。torch.nn.Sequentialや独自のクラスを定義してモデル(ネットワーク)を構築する際に用いる。

torch.nn.Flattenのインスタンスを生成し実行すると、第一引数に指定したtorch.Tensorが平坦化される。

torch.flatten()と異なり、デフォルトはstart_dim=1, end_dim=-1の範囲が平坦化され、最初の次元はそのまま。

flatten = nn.Flatten()
print(flatten(t).shape)
# torch.Size([2, 60])

start_dim, end_dimはコンストラクタtorch.nn.Flatten()で指定可能。第一引数がstart_dim、第二引数がend_dim

flatten_all = nn.Flatten(0, -1)
print(flatten_all(t).shape)
# torch.Size([120])

最初の次元を平坦化しないのは、バッチ処理用の次元を保持するため。

例えば、高さhwチャンネル数c(RGB画像の場合はc = 3)を入力とする場合、一枚の画像のテンソルの形状は(c, h, w)、n枚の画像をまとめたテンソルの形状は(n, c, h, w)となる。

モデルに対しては(n, c, h, w)のテンソルが入力されるが、平坦化の処理としては(n, c * h * w)に変換(それぞれの画像を平坦化)することが望ましい。torch.nn.Flattenのデフォルト値はそのような処理になっている。

平坦化処理を含むtorch.nn.Moduleのサブクラスを定義する際の注意

平坦化処理を含むtorch.nn.Moduleのサブクラスを定義する場合、上述の通り、最初の次元(バッチ用の次元)は保持することが望ましい。

torch.nn.Flattenを用いる場合はデフォルトのままでOK。

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.flatten(x)
        return x

net = Net()
print(net(t).shape)
# torch.Size([2, 60])

torch.flatten()forward()内で用いる場合、デフォルトのままだとすべての次元が平坦化されてしまう。

class NetFunctional(nn.Module):
    def forward(self, x):
        x = torch.flatten(x)
        return x

net_func = NetFunctional()
print(net_func(t).shape)
# torch.Size([120])

第二引数start_dim1とすればよい。

class NetFunctionalDim(nn.Module):
    def forward(self, x):
        x = torch.flatten(x, 1)
        return x

net_func_dim = NetFunctionalDim()
print(net_func_dim(t).shape)
# torch.Size([2, 60])

一部のみを平坦化するようなモデルを構築したい場合は、torch.nn.Flatten()またはtorch.flatten()start_dim, end_dimを所望の値に指定すればよい。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事