PyTorch Hub, torchvision.modelsで学習済みモデルをダウンロード・使用
PyTorch, torchvisionでは、学習済みモデル(訓練済みモデル)をダウンロードして使用できる。
VGGやResNetのような有名なモデルはtorchvision.models
に含まれている。また、PyTorch Hubという仕組みも用意されており、簡単にモデルを公開したりダウンロードしたりできるようになっている。
ここでは以下の内容について説明する。
torchvision.models
で学習済みモデルをダウンロード・使用- 利用できるモデル
- モデルの生成
- データのダウンロード場所
- PyTorch Hubで学習済みモデルをダウンロード・使用
- 公開されているモデルをダウンロード
- 自分のモデルを公開
学習済みモデルを使って実際に画像分類を行う例は以下の記事を参照。
本記事におけるPyTorchおよびtorchvisionのバージョンは以下の通り。バージョンによって仕様が異なる場合もあるので注意。
import torch
import torchvision
import pprint
print(torch.__version__)
# 1.7.1
print(torchvision.__version__)
# 0.8.2
pprintは出力を見やすくするために使っている。
torchvision.modelsで学習済みモデルをダウンロード・使用
利用できるモデル
torchvision.models
に画像分類、torchvision.models.segmentation
にセマンティックセグメンテーション、torchvision.models.detection
に物体検出、torchvision.models.video
に動画分類のモデルが含まれている。
pprint.pprint([s for s in dir(torchvision.models) if s[0].isupper()], compact=True)
# ['AlexNet', 'DenseNet', 'GoogLeNet', 'GoogLeNetOutputs', 'Inception3',
# 'InceptionOutputs', 'MNASNet', 'MobileNetV2', 'ResNet', 'ShuffleNetV2',
# 'SqueezeNet', 'VGG']
pprint.pprint([s for s in dir(torchvision.models.segmentation) if s[0].isupper()], compact=True)
# ['DeepLabV3', 'FCN']
pprint.pprint([s for s in dir(torchvision.models.detection) if s[0].isupper()], compact=True)
# ['FasterRCNN', 'KeypointRCNN', 'MaskRCNN', 'RetinaNet']
モデルの生成
各モデルは以下のように生成できる。それぞれのモデルごとにクラスが定義されているが、いずれもtorch.nn.Module
のサブクラス。
vgg16 = torchvision.models.vgg16()
print(type(vgg16))
# <class 'torchvision.models.vgg.VGG'>
print(issubclass(type(vgg16), torch.nn.Module))
# True
print()
でモデルの構成が確認できる。ここでは出力は省略する。
print(vgg16)
例えばVGGでもvgg11()
やvgg16_bn()
(Batch normalizationが含まれている)など様々な種類がある。詳細は公式ドキュメントを参照。
デフォルトはランダムな重みで初期化されたモデルが生成されるが、引数pretrained
をTrue
とすると学習済みモデルが生成できる。
vgg16_pretrained = torchvision.models.vgg16(pretrained=True)
例えばバイアスを見ると、学習済みモデルでは値が設定されていることが確認できる。
print(vgg16.features[0].bias)
# Parameter containing:
# tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
# 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
# 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
# requires_grad=True)
print(vgg16_pretrained.features[0].bias)
# Parameter containing:
# tensor([ 0.4034, 0.3778, 0.4644, -0.3228, 0.3940, -0.3953, 0.3951, -0.5496,
# 0.2693, -0.7602, -0.3508, 0.2334, -1.3239, -0.1694, 0.3938, -0.1026,
# 0.0460, -0.6995, 0.1549, 0.5628, 0.3011, 0.3425, 0.1073, 0.4651,
# 0.1295, 0.0788, -0.0492, -0.5638, 0.1465, -0.3890, -0.0715, 0.0649,
# 0.2768, 0.3279, 0.5682, -1.2640, -0.8368, -0.9485, 0.1358, 0.2727,
# 0.1841, -0.5325, 0.3507, -0.0827, -1.0248, -0.6912, -0.7711, 0.2612,
# 0.4033, -0.4802, -0.3066, 0.5807, -1.3325, 0.4844, -0.8160, 0.2386,
# 0.2300, 0.4979, 0.5553, 0.5230, -0.2182, 0.0117, -0.5516, 0.2108],
# requires_grad=True)
データのダウンロード場所
学習済みパラメータのデータファイルのダウンロード場所は以下の通り。上から優先順位が高い順。すでにダウンロード済みのファイルが存在する場合はそれが使われる。
The locations are used in the order of
- Callinghub.set_dir(<PATH_TO_HUB_DIR>)
-$TORCH_HOME/hub
, if environment variableTORCH_HOME
is set.
-$XDG_CACHE_HOME/torch/hub
, if environment variableXDG_CACHE_HOME
is set.
-~/.cache/torch/hub
torch.hub - Where are my downloaded models saved? — PyTorch 1.7.0 documentation
hub.set_dir()
を実行したり、環境変数TORCH_HOME
, XDG_CACHE_HOME
を設定したりしていないと、~/.cache/torch/hub
にダウンロードされる(~
はホームディレクトリ)。
PyTorch Hubで学習済みモデルをダウンロード・使用
PyTorch Hubという、ユーザーがモデルを簡単に公開したりダウンロードしたりできる仕組みも用意されている。
torch.hub
はPyTorch1.0
で追加された。
なお、2021年年2月現在まだベータリリースとのことなので、仕様が変更される可能性もある。
*This is a beta release - we will be collecting feedback and improving the PyTorch Hub over the coming months. PyTorch Hub | PyTorch
torch.hub
でダウンロードされるリポジトリや学習済みモデルなどのデータのダウンロード場所は上述のtorchvision.models
と同じ。
The locations are used in the order of
- Callinghub.set_dir(<PATH_TO_HUB_DIR>)
-$TORCH_HOME/hub
, if environment variableTORCH_HOME
is set.
-$XDG_CACHE_HOME/torch/hub
, if environment variableXDG_CACHE_HOME
is set.
-~/.cache/torch/hub
torch.hub - Where are my downloaded models saved? — PyTorch 1.7.0 documentation
公開されているモデルのダウンロード
モデルの生成にはtorch.hub.load()
を使う。初回実行時にデータがダウンロードされる。
第一引数にリポジトリの名前(GitHubの場合、'repo_owner/repo_name[:tag_name]'
)、第二引数にモデルの名前を指定する。そのほか、モデルごとに定義されている引数を適宜指定する。
model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN',
pretrained=True, useGPU=False)
torch.hub.list()
でリポジトリで公開されているモデルの一覧を確認したり、torch.hub.help()
でモデルのdocstringを確認したりすることもできる。
print(torch.hub.list('facebookresearch/pytorch_GAN_zoo:hub'))
# ['DCGAN', 'PGAN']
print(torch.hub.help('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN'))
#
# DCGAN basic model
# pretrained (bool): load a pretrained model ? In this case load a model
# trained on fashionGen cloth
#
基本的には元のリポジトリや公開者のWebサイトなどで使い方が説明されているはずなので、詳細はそちらを参照すればよい。ほかのライブラリのインストールが必要な場合もある。
上の例はFacebook Researchが公開しているDCGANのモデル。
- DCGAN on FashionGen | PyTorch
- facebookresearch/pytorch_GAN_zoo: A mix of GAN implementations including progressive growing
自分のモデルを公開
GitHubのリポジトリに設定ファイルhubconf.py
を追加すると、他の人がtorch.hub.load()
でそのモデルを使えるようになる。
以下のリポジトリにプルリクエストを送るとPyTorch Hubのページに掲載されるらしい。自分のモデルを宣伝したい場合は検討してみるといいかもしれない。