PyTorchのTensorのデータ型(dtype)と型変換(キャスト)
PyTorchテンソルtorch.Tensor
はtorch.float32
やtorch.int64
などのデータ型dtype
を持つ。
ここでは以下の内容について説明する。
torch.Tensor
のデータ型dtype
一覧torch.Tensor
のデータ型を取得:dtype
属性- データ型
dtype
を指定してtorch.Tensor
を生成 torch.Tensor
の型変換(キャスト)to()
メソッドfloat()
,double()
メソッドなど
- 演算における暗黙の型変換(キャスト)
型変換(キャスト)ではなく、デバイス(GPU / CPU)を切り替えたい場合は以下の記事を参照。
本記事のサンプルコードにおけるPyTorchのバージョンは以下の通り。バージョンが異なると仕様が異なる場合もあるので注意。
import torch
print(torch.__version__)
# 1.7.1
torch.Tensorのデータ型dtype一覧
torch.Tensor
のデータ型dtype
一覧は以下の通り。
データ型dtype |
|
---|---|
16ビット浮動小数点数(※) | torch.float16 or torch.half |
16ビット浮動小数点数(※) | torch.bfloat16 |
32ビット浮動小数点数 | torch.float32 or torch.float |
64ビット浮動小数点数 | torch.float64 or torch.double |
8ビット符号なし整数 | torch.uint8 |
8ビット符号あり整数 | torch.int8 |
16ビット符号あり整数 | torch.int16 or torch.short |
32ビット符号あり整数 | torch.int32 or torch.int |
64ビット符号あり整数 | torch.int64 or torch.long |
ブーリアン | torch.bool |
64ビット複素数 | torch.complex64 or torch.cfloat |
128ビット複素数 | torch.complex128 or torch.cdouble |
※16ビット浮動小数点数のtorch.float16
およびtorch.half
は符号部(sign)1ビット、仮数部(significand)10ビット、指数部(exponent)5ビットで、torch.bfloat16
は符号部1ビット、仮数部7ビット、指数部8ビットという違いがある。
これらはtorch.dtype
型で、以降で説明するようにtorch.Tensor
を生成する関数やto()
メソッドなどの引数に指定できる。
print(type(torch.float32))
# <class 'torch.dtype'>
torch.float32
とtorch.float
、torch.float64
とtorch.double
、torch.int64
とtorch.long
などは同じもの。どちらを使ってもよい。
print(torch.float32 is torch.float)
# True
print(torch.int64 is torch.long)
# True
ニューラルネットワークの計算は一般的に32ビット浮動小数点数(torch.float32
, torch.float
)で行われることが多い。
torch.Tensorのデータ型を取得: dtype属性
torch.Tensor
のデータ型はdtype
属性で取得できる。
t_float32 = torch.tensor([0.1, 1.5, 2.9])
print(t_float32)
# tensor([0.1000, 1.5000, 2.9000])
print(t_float32.dtype)
# torch.float32
print(type(t_float32.dtype))
# <class 'torch.dtype'>
データ型dtypeを指定してtorch.Tensorを生成
torch.tensor()
あるいはtorch.ones()
, torch.zeros()
などでは、引数dtype
を指定して任意のデータ型のtorch.Tensor
を生成できる。
t_float64 = torch.tensor([0.1, 1.5, 2.9], dtype=torch.float64)
print(t_float64.dtype)
# torch.float64
t_int32 = torch.ones(3, dtype=torch.int32)
print(t_int32.dtype)
# torch.int32
torch.Tensorの型変換(キャスト)
to()メソッド
torch.Tensor
のto()
メソッドで型変換(キャスト)ができる。
to()
の第一引数dtype
にtorch.float64
などのtorch.dtype
を指定する。
t_float64 = t_float32.to(torch.float64)
print(t_float64.dtype)
# torch.float64
浮動小数点数から整数へのキャストは小数点以下切り捨てになる。
print(t_float32)
# tensor([0.1000, 1.5000, 2.9000])
print(t_float32.to(torch.int64))
# tensor([0, 1, 2])
to()
メソッドはto(device='cuda:0')
のようにCPUからGPUへのコピー(あるいはGPUからCPUへのコピー)にも使われる。dtype
とdevice
を同時に指定することもできる。
float(), double()メソッドなど
float()
メソッドでtorch.float
(= torch.float32
)、double()
メソッドでtorch.double
(= torch.float64
)というように、データ型の名前のメソッドでキャストすることもできる。
half()
, float()
, double()
, short()
, int()
, long()
メソッドは定義されているが、float16()
やint32()
メソッドのようなfloatXX()
, intXX()
メソッドは定義されていないので注意。
t_float64 = t_float32.double()
print(t_float64.dtype)
# torch.float64
# t_float32.float64()
# AttributeError: 'Tensor' object has no attribute 'float64'
演算における暗黙の型変換(キャスト)
異なるデータ型のtorch.Tensor
の二項演算(四則演算など)においては、暗黙の型変換(キャスト)が行われる。
float
とint
の場合はfloat
にキャストされ、float
同士・int
同士の場合はビット数の大きい方にキャストされる。
t_float16 = torch.ones(3, dtype=torch.float16)
t_int64 = torch.ones(3, dtype=torch.int64)
print((t_float16 + t_int64).dtype)
# torch.float16
t_float32 = torch.ones(3, dtype=torch.float32)
t_float64 = torch.ones(3, dtype=torch.float64)
print((t_float32 + t_float64).dtype)
# torch.float64
より詳細なルールについては公式ドキュメントを参照。