note.nkmk.me

PyTorch Hub, torchvision.modelsで学習済みモデルをダウンロード・使用

Posted: 2021-02-20 / Tags: Python, PyTorch, 機械学習

PyTorch, torchvisionでは、学習済みモデル(訓練済みモデル)をダウンロードして使用できる。

VGGやResNetのような有名なモデルはtorchvision.modelsに含まれている。また、PyTorch Hubという仕組みも用意されており、簡単にモデルを公開したりダウンロードしたりできるようになっている。

ここでは以下の内容について説明する。

  • torchvision.modelsで学習済みモデルをダウンロード・使用
    • 利用できるモデル
    • モデルの生成
    • データのダウンロード場所
  • PyTorch Hubで学習済みモデルをダウンロード・使用
    • 公開されているモデルをダウンロード
    • 自分のモデルを公開

学習済みモデルを使って実際に画像分類を行う例は以下の記事を参照。

本記事におけるPyTorchおよびtorchvisionのバージョンは以下の通り。バージョンによって仕様が異なる場合もあるので注意。

import torch
import torchvision
import pprint

print(torch.__version__)
# 1.7.1

print(torchvision.__version__)
# 0.8.2

pprintは出力を見やすくするために使っている。

スポンサーリンク

torchvision.modelsで学習済みモデルをダウンロード・使用

利用できるモデル

torchvision.modelsに画像分類、torchvision.models.segmentationにセマンティックセグメンテーション、torchvision.models.detectionに物体検出、torchvision.models.videoに動画分類のモデルが含まれている。

pprint.pprint([s for s in dir(torchvision.models) if s[0].isupper()], compact=True)
# ['AlexNet', 'DenseNet', 'GoogLeNet', 'GoogLeNetOutputs', 'Inception3',
#  'InceptionOutputs', 'MNASNet', 'MobileNetV2', 'ResNet', 'ShuffleNetV2',
#  'SqueezeNet', 'VGG']

pprint.pprint([s for s in dir(torchvision.models.segmentation) if s[0].isupper()], compact=True)
# ['DeepLabV3', 'FCN']

pprint.pprint([s for s in dir(torchvision.models.detection) if s[0].isupper()], compact=True)
# ['FasterRCNN', 'KeypointRCNN', 'MaskRCNN', 'RetinaNet']

モデルの生成

各モデルは以下のように生成できる。それぞれのモデルごとにクラスが定義されているが、いずれもtorch.nn.Moduleのサブクラス。

vgg16 = torchvision.models.vgg16()

print(type(vgg16))
# <class 'torchvision.models.vgg.VGG'>

print(issubclass(type(vgg16), torch.nn.Module))
# True

print()でモデルの構成が確認できる。ここでは出力は省略する。

print(vgg16)

例えばVGGでもvgg11()vgg16_bn()(Batch normalizationが含まれている)など様々な種類がある。詳細は公式ドキュメントを参照。

デフォルトはランダムな重みで初期化されたモデルが生成されるが、引数pretrainedTrueとすると学習済みモデルが生成できる。

vgg16_pretrained = torchvision.models.vgg16(pretrained=True)

例えばバイアスを見ると、学習済みモデルでは値が設定されていることが確認できる。

print(vgg16.features[0].bias)
# Parameter containing:
# tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
#         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
#         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
#        requires_grad=True)

print(vgg16_pretrained.features[0].bias)
# Parameter containing:
# tensor([ 0.4034,  0.3778,  0.4644, -0.3228,  0.3940, -0.3953,  0.3951, -0.5496,
#          0.2693, -0.7602, -0.3508,  0.2334, -1.3239, -0.1694,  0.3938, -0.1026,
#          0.0460, -0.6995,  0.1549,  0.5628,  0.3011,  0.3425,  0.1073,  0.4651,
#          0.1295,  0.0788, -0.0492, -0.5638,  0.1465, -0.3890, -0.0715,  0.0649,
#          0.2768,  0.3279,  0.5682, -1.2640, -0.8368, -0.9485,  0.1358,  0.2727,
#          0.1841, -0.5325,  0.3507, -0.0827, -1.0248, -0.6912, -0.7711,  0.2612,
#          0.4033, -0.4802, -0.3066,  0.5807, -1.3325,  0.4844, -0.8160,  0.2386,
#          0.2300,  0.4979,  0.5553,  0.5230, -0.2182,  0.0117, -0.5516,  0.2108],
#        requires_grad=True)

データのダウンロード場所

学習済みパラメータのデータファイルのダウンロード場所は以下の通り。上から優先順位が高い順。すでにダウンロード済みのファイルが存在する場合はそれが使われる。

The locations are used in the order of
- Calling hub.set_dir(<PATH_TO_HUB_DIR>)
- $TORCH_HOME/hub, if environment variable TORCH_HOME is set.
- $XDG_CACHE_HOME/torch/hub, if environment variable XDG_CACHE_HOME is set.
- ~/.cache/torch/hub
torch.hub - Where are my downloaded models saved? — PyTorch 1.7.0 documentation

hub.set_dir()を実行したり、環境変数TORCH_HOME, XDG_CACHE_HOMEを設定したりしていないと、~/.cache/torch/hubにダウンロードされる(~はホームディレクトリ)。

PyTorch Hubで学習済みモデルをダウンロード・使用

PyTorch Hubという、ユーザーがモデルを簡単に公開したりダウンロードしたりできる仕組みも用意されている。

torch.hubはPyTorch1.0で追加された。

なお、2021年年2月現在まだベータリリースとのことなので、仕様が変更される可能性もある。

*This is a beta release - we will be collecting feedback and improving the PyTorch Hub over the coming months. PyTorch Hub | PyTorch

torch.hubでダウンロードされるリポジトリや学習済みモデルなどのデータのダウンロード場所は上述のtorchvision.modelsと同じ。

The locations are used in the order of
- Calling hub.set_dir(<PATH_TO_HUB_DIR>)
- $TORCH_HOME/hub, if environment variable TORCH_HOME is set.
- $XDG_CACHE_HOME/torch/hub, if environment variable XDG_CACHE_HOME is set.
- ~/.cache/torch/hub
torch.hub - Where are my downloaded models saved? — PyTorch 1.7.0 documentation

公開されているモデルのダウンロード

モデルの生成にはtorch.hub.load()を使う。初回実行時にデータがダウンロードされる。

第一引数にリポジトリの名前(GitHubの場合、'repo_owner/repo_name[:tag_name]')、第二引数にモデルの名前を指定する。そのほか、モデルごとに定義されている引数を適宜指定する。

model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN',
                       pretrained=True, useGPU=False)
source: torch_hub.py

torch.hub.list()でリポジトリで公開されているモデルの一覧を確認したり、torch.hub.help()でモデルのdocstringを確認したりすることもできる。

print(torch.hub.list('facebookresearch/pytorch_GAN_zoo:hub'))
# ['DCGAN', 'PGAN']
source: torch_hub.py
print(torch.hub.help('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN'))
# 
#     DCGAN basic model
#     pretrained (bool): load a pretrained model ? In this case load a model
#     trained on fashionGen cloth
#     
source: torch_hub.py

基本的には元のリポジトリや公開者のWebサイトなどで使い方が説明されているはずなので、詳細はそちらを参照すればよい。ほかのライブラリのインストールが必要な場合もある。

上の例はFacebook Researchが公開しているDCGANのモデル。

自分のモデルを公開

GitHubのリポジトリに設定ファイルhubconf.pyを追加すると、他の人がtorch.hub.load()でそのモデルを使えるようになる。

以下のリポジトリにプルリクエストを送るとPyTorch Hubのページに掲載されるらしい。自分のモデルを宣伝したい場合は検討してみるといいかもしれない。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事