note.nkmk.me

PyTorchのTensorのデータ型(dtype)と型変換(キャスト)

Posted: 2021-03-06 / Tags: Python, PyTorch, 機械学習

PyTorchテンソルtorch.Tensortorch.float32torch.int64などのデータ型dtypeを持つ。

ここでは以下の内容について説明する。

  • torch.Tensorのデータ型dtype一覧
  • torch.Tensorのデータ型を取得: dtype属性
  • データ型dtypeを指定してtorch.Tensorを生成
  • torch.Tensorの型変換(キャスト)
    • to()メソッド
    • float(), double()メソッドなど
  • 演算における暗黙の型変換(キャスト)

型変換(キャスト)ではなく、デバイス(GPU / CPU)を切り替えたい場合は以下の記事を参照。

本記事のサンプルコードにおけるPyTorchのバージョンは以下の通り。バージョンが異なると仕様が異なる場合もあるので注意。

import torch

print(torch.__version__)
# 1.7.1
スポンサーリンク

torch.Tensorのデータ型dtype一覧

torch.Tensorのデータ型dtype一覧は以下の通り。

データ型dtype
16ビット浮動小数点数(※) torch.float16 or torch.half
16ビット浮動小数点数(※) torch.bfloat16
32ビット浮動小数点数 torch.float32 or torch.float
64ビット浮動小数点数 torch.float64 or torch.double
8ビット符号なし整数 torch.uint8
8ビット符号あり整数 torch.int8
16ビット符号あり整数 torch.int16 or torch.short
32ビット符号あり整数 torch.int32 or torch.int
64ビット符号あり整数 torch.int64 or torch.long
ブーリアン torch.bool
64ビット複素数 torch.complex64 or torch.cfloat
128ビット複素数 torch.complex128 or torch.cdouble

※16ビット浮動小数点数のtorch.float16およびtorch.halfは符号部(sign)1ビット、仮数部(significand)10ビット、指数部(exponent)5ビットで、torch.bfloat16は符号部1ビット、仮数部7ビット、指数部8ビットという違いがある。

これらはtorch.dtype型で、以降で説明するようにtorch.Tensorを生成する関数やto()メソッドなどの引数に指定できる。

print(type(torch.float32))
# <class 'torch.dtype'>

torch.float32torch.floattorch.float64torch.doubletorch.int64torch.longなどは同じもの。どちらを使ってもよい。

print(torch.float32 is torch.float)
# True

print(torch.int64 is torch.long)
# True

ニューラルネットワークの計算は一般的に32ビット浮動小数点数(torch.float32, torch.float)で行われることが多い。

torch.Tensorのデータ型を取得: dtype属性

torch.Tensorのデータ型はdtype属性で取得できる。

t_float32 = torch.tensor([0.1, 1.5, 2.9])
print(t_float32)
# tensor([0.1000, 1.5000, 2.9000])

print(t_float32.dtype)
# torch.float32

print(type(t_float32.dtype))
# <class 'torch.dtype'>

データ型dtypeを指定してtorch.Tensorを生成

torch.tensor()あるいはtorch.ones(), torch.zeros()などでは、引数dtypeを指定して任意のデータ型のtorch.Tensorを生成できる。

t_float64 = torch.tensor([0.1, 1.5, 2.9], dtype=torch.float64)
print(t_float64.dtype)
# torch.float64

t_int32 = torch.ones(3, dtype=torch.int32)
print(t_int32.dtype)
# torch.int32

torch.Tensorの型変換(キャスト)

to()メソッド

torch.Tensorto()メソッドで型変換(キャスト)ができる。

to()の第一引数dtypetorch.float64などのtorch.dtypeを指定する。

t_float64 = t_float32.to(torch.float64)
print(t_float64.dtype)
# torch.float64

浮動小数点数から整数へのキャストは小数点以下切り捨てになる。

print(t_float32)
# tensor([0.1000, 1.5000, 2.9000])

print(t_float32.to(torch.int64))
# tensor([0, 1, 2])

to()メソッドはto(device='cuda:0')のようにCPUからGPUへのコピー(あるいはGPUからCPUへのコピー)にも使われる。dtypedeviceを同時に指定することもできる。

float(), double()メソッドなど

float()メソッドでtorch.float(= torch.float32)、double()メソッドでtorch.double(= torch.float64)というように、データ型の名前のメソッドでキャストすることもできる。

half(), float(), double(), short(), int(), long()メソッドは定義されているが、float16()int32()メソッドのようなfloatXX(), intXX()メソッドは定義されていないので注意。

t_float64 = t_float32.double()
print(t_float64.dtype)
# torch.float64

# t_float32.float64()
# AttributeError: 'Tensor' object has no attribute 'float64'

演算における暗黙の型変換(キャスト)

異なるデータ型のtorch.Tensorの二項演算(四則演算など)においては、暗黙の型変換(キャスト)が行われる。

floatintの場合はfloatにキャストされ、float同士・int同士の場合はビット数の大きい方にキャストされる。

t_float16 = torch.ones(3, dtype=torch.float16)
t_int64 = torch.ones(3, dtype=torch.int64)

print((t_float16 + t_int64).dtype)
# torch.float16

t_float32 = torch.ones(3, dtype=torch.float32)
t_float64 = torch.ones(3, dtype=torch.float64)

print((t_float32 + t_float64).dtype)
# torch.float64

より詳細なルールについては公式ドキュメントを参照。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事