PyTorchのTensorのデータ型(dtype)と型変換(キャスト)
PyTorchテンソルtorch.Tensorはtorch.float32やtorch.int64などのデータ型dtypeを持つ。
ここでは以下の内容について説明する。
torch.Tensorのデータ型dtype一覧torch.Tensorのデータ型を取得:dtype属性- データ型
dtypeを指定してtorch.Tensorを生成 torch.Tensorの型変換(キャスト)to()メソッドfloat(),double()メソッドなど
- 演算における暗黙の型変換(キャスト)
型変換(キャスト)ではなく、デバイス(GPU / CPU)を切り替えたい場合は以下の記事を参照。
本記事のサンプルコードにおけるPyTorchのバージョンは以下の通り。バージョンが異なると仕様が異なる場合もあるので注意。
import torch
print(torch.__version__)
# 1.7.1
torch.Tensorのデータ型dtype一覧
torch.Tensorのデータ型dtype一覧は以下の通り。
データ型dtype |
|
|---|---|
| 16ビット浮動小数点数(※) | torch.float16 or torch.half |
| 16ビット浮動小数点数(※) | torch.bfloat16 |
| 32ビット浮動小数点数 | torch.float32 or torch.float |
| 64ビット浮動小数点数 | torch.float64 or torch.double |
| 8ビット符号なし整数 | torch.uint8 |
| 8ビット符号あり整数 | torch.int8 |
| 16ビット符号あり整数 | torch.int16 or torch.short |
| 32ビット符号あり整数 | torch.int32 or torch.int |
| 64ビット符号あり整数 | torch.int64 or torch.long |
| ブーリアン | torch.bool |
| 64ビット複素数 | torch.complex64 or torch.cfloat |
| 128ビット複素数 | torch.complex128 or torch.cdouble |
※16ビット浮動小数点数のtorch.float16およびtorch.halfは符号部(sign)1ビット、仮数部(significand)10ビット、指数部(exponent)5ビットで、torch.bfloat16は符号部1ビット、仮数部7ビット、指数部8ビットという違いがある。
これらはtorch.dtype型で、以降で説明するようにtorch.Tensorを生成する関数やto()メソッドなどの引数に指定できる。
print(type(torch.float32))
# <class 'torch.dtype'>
torch.float32とtorch.float、torch.float64とtorch.double、torch.int64とtorch.longなどは同じもの。どちらを使ってもよい。
print(torch.float32 is torch.float)
# True
print(torch.int64 is torch.long)
# True
ニューラルネットワークの計算は一般的に32ビット浮動小数点数(torch.float32, torch.float)で行われることが多い。
torch.Tensorのデータ型を取得: dtype属性
torch.Tensorのデータ型はdtype属性で取得できる。
t_float32 = torch.tensor([0.1, 1.5, 2.9])
print(t_float32)
# tensor([0.1000, 1.5000, 2.9000])
print(t_float32.dtype)
# torch.float32
print(type(t_float32.dtype))
# <class 'torch.dtype'>
データ型dtypeを指定してtorch.Tensorを生成
torch.tensor()あるいはtorch.ones(), torch.zeros()などでは、引数dtypeを指定して任意のデータ型のtorch.Tensorを生成できる。
t_float64 = torch.tensor([0.1, 1.5, 2.9], dtype=torch.float64)
print(t_float64.dtype)
# torch.float64
t_int32 = torch.ones(3, dtype=torch.int32)
print(t_int32.dtype)
# torch.int32
torch.Tensorの型変換(キャスト)
to()メソッド
torch.Tensorのto()メソッドで型変換(キャスト)ができる。
to()の第一引数dtypeにtorch.float64などのtorch.dtypeを指定する。
t_float64 = t_float32.to(torch.float64)
print(t_float64.dtype)
# torch.float64
浮動小数点数から整数へのキャストは小数点以下切り捨てになる。
print(t_float32)
# tensor([0.1000, 1.5000, 2.9000])
print(t_float32.to(torch.int64))
# tensor([0, 1, 2])
to()メソッドはto(device='cuda:0')のようにCPUからGPUへのコピー(あるいはGPUからCPUへのコピー)にも使われる。dtypeとdeviceを同時に指定することもできる。
float(), double()メソッドなど
float()メソッドでtorch.float(= torch.float32)、double()メソッドでtorch.double(= torch.float64)というように、データ型の名前のメソッドでキャストすることもできる。
half(), float(), double(), short(), int(), long()メソッドは定義されているが、float16()やint32()メソッドのようなfloatXX(), intXX()メソッドは定義されていないので注意。
t_float64 = t_float32.double()
print(t_float64.dtype)
# torch.float64
# t_float32.float64()
# AttributeError: 'Tensor' object has no attribute 'float64'
演算における暗黙の型変換(キャスト)
異なるデータ型のtorch.Tensorの二項演算(四則演算など)においては、暗黙の型変換(キャスト)が行われる。
floatとintの場合はfloatにキャストされ、float同士・int同士の場合はビット数の大きい方にキャストされる。
t_float16 = torch.ones(3, dtype=torch.float16)
t_int64 = torch.ones(3, dtype=torch.int64)
print((t_float16 + t_int64).dtype)
# torch.float16
t_float32 = torch.ones(3, dtype=torch.float32)
t_float64 = torch.ones(3, dtype=torch.float64)
print((t_float32 + t_float64).dtype)
# torch.float64
より詳細なルールについては公式ドキュメントを参照。