note.nkmk.me

scikit-learnでデータを訓練用とテスト用に分割するtrain_test_split

Date: 2019-04-16 / tags: Python, scikit-learn, 機械学習

scikit-learnのtrain_test_split()関数を使うと、NumPy配列ndarryやリストなどを二分割できる。機械学習においてデータを訓練用(学習用)とテスト用に分割してホールドアウト検証を行う際に用いる。

ここでは以下の内容について説明する。

  • train_test_split()の基本的な使い方
  • 割合、個数を指定: 引数test_size, train_size
  • シャッフルするかを指定: 引数shuffle
  • 乱数シードを指定: 引数random_seed
  • 複数のデータを分割
  • 層化抽出: 引数stratify
  • 具体的な例(アイリスデータセット)
スポンサーリンク

train_test_split()の基本的な使い方

train_test_split()にNumPy配列ndarryを渡すと、二分割されたndarrayが要素として格納されたリストが返される。

import numpy as np
from sklearn.model_selection import train_test_split

a = np.arange(10)
print(a)
# [0 1 2 3 4 5 6 7 8 9]

print(train_test_split(a))
# [array([8, 4, 9, 0, 6, 7, 2]), array([1, 5, 3])]

print(type(train_test_split(a)))
# <class 'list'>

print(len(train_test_split(a)))
# 2

以下のように、アンパックでそれぞれ2つの変数に代入することが多い。

a_train, a_test = train_test_split(a)

print(a_train)
# [1 5 7 4 6 9 3]

print(a_test)
# [2 8 0]

例はndarryだが、Python標準のリストやpandas.DataFrame、疎行列scipy.sparseにも対応している。

割合、個数を指定: 引数test_size, train_size

引数test_sizeでテスト用(返されるリストの2つめの要素)の割合を指定できる。デフォルトはtest_size=0.25で25%がテスト用となる。小数点以下は切り上げとなり、上の例では10 * 0.25 = 2.5 -> 3となる。

test_sizeには0.0 ~ 1.0の割合か、個数を指定する。

割合で指定。

a_train, a_test = train_test_split(a, test_size=0.6)

print(a_train)
# [1 0 5 2]

print(a_test)
# [7 3 4 6 9 8]

個数で指定。

a_train, a_test = train_test_split(a, test_size=6)

print(a_train)
# [7 0 5 3]

print(a_test)
# [4 8 9 2 1 6]

デフォルトでは残りが自動的に訓練用(返されるリストの1つめの要素)となるが、引数train_sizeで別途指定することも可能。

a_train, a_test = train_test_split(a, test_size=0.3, train_size=0.4)

print(a_train)
# [2 8 1 5]

print(a_test)
# [7 6 0]

a_train, a_test = train_test_split(a, test_size=3, train_size=4)

print(a_train)
# [3 7 1 6]

print(a_test)
# [8 0 4]

上の例のように訓練用とテスト用を合わせて100%以下でも問題ないが、100%を超える値を指定するとエラーとなる。

# a_train, a_test = train_test_split(a, test_size=0.8, train_size=0.7)
# ValueError: The sum of test_size and train_size = 1.500000, should be smaller than 1.0. Reduce test_size and/or train_size.

# a_train, a_test = train_test_split(a, test_size=8, train_size=7)
# ValueError: The sum of train_size and test_size = 15, should be smaller than the number of samples 10. Reduce test_size and/or train_size.

シャッフルするかを指定: 引数shuffle

これまでの例のように、デフォルトでは要素がシャッフルされて分割される。引数shuffle=Falseとするとシャッフルされずに先頭から順番に分割される。

a_train, a_test = train_test_split(a, shuffle=False)

print(a_train)
# [0 1 2 3 4 5 6]

print(a_test)
# [7 8 9]

乱数シードを指定: 引数random_seed

シャッフルされる場合、デフォルトでは実行するたびにランダムに分割される。引数random_seedを指定して乱数シードを固定すると常に同じように分割される。

a_train, a_test = train_test_split(a, random_state=0)

print(a_train)
# [9 1 6 7 3 0 5]

print(a_test)
# [2 8 4]

機械学習のモデルの性能を比較するような場合、どのように分割されるかによって結果が異なってしまうため、乱数シードを固定して常に同じように分割されるようにする必要がある。

複数のデータを分割

これまでの例では1つのデータ(ndarray)を分割したが、train_test_split()は複数のデータを分割することができる。

以下の2つのデータを例とする。

X = np.arange(20).reshape(2, 10).T
print(X)
# [[ 0 10]
#  [ 1 11]
#  [ 2 12]
#  [ 3 13]
#  [ 4 14]
#  [ 5 15]
#  [ 6 16]
#  [ 7 17]
#  [ 8 18]
#  [ 9 19]]

y = np.arange(10)
print(y)
# [0 1 2 3 4 5 6 7 8 9]

train_test_split()の引数に順に指定すると、以下のように分割される。2つのデータが対応して分割されているのが確認できる。

X_train, X_test, y_train, y_test = train_test_split(X, y)

print(X_train)
# [[ 6 16]
#  [ 5 15]
#  [ 4 14]
#  [ 2 12]
#  [ 3 13]
#  [ 8 18]
#  [ 1 11]]

print(X_test)
# [[ 7 17]
#  [ 0 10]
#  [ 9 19]]

print(y_train)
# [6 5 4 2 3 8 1]

print(y_test)
# [7 0 9]

3つ以上でも問題ないが、それぞれのデータのlen()またはshape[0](最初の次元の大きさ)が一致している必要がある。一致していないとエラー。

y_mismatch = np.arange(8)
print(y_mismatch)
# [0 1 2 3 4 5 6 7]

# X_train, X_test, y_train, y_test = train_test_split(X, y_mismatch)
# ValueError: Found input variables with inconsistent numbers of samples: [10, 8]

層化抽出: 引数stratify

例えば教師あり学習では特徴行列(説明変数)と正解ラベル(目的変数)の2つのデータを用いる。

二値分類(2クラス分類)では正解ラベルは以下のように0, 1のいずれかになる。

y = np.array([0] * 5 + [1] * 5)
print(y)
# [0 0 0 0 0 1 1 1 1 1]

教師あり学習のためにデータを分割したい場合、正解ラベルの比率を訓練用とテスト用で一致させたいことがあるが、以下の例ではテスト用に0の要素が含まれていない。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100)

print(y_train)
# [0 1 0 0 0 0 1 1]

print(y_test)
# [1 1]

引数stratifyに均等に分割させたいデータ(多くの場合は正解ラベル)を指定すると、そのデータの値の比率が一致するように分割される。

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=100,
                                                    stratify=y)

print(y_train)
# [1 1 0 0 0 1 1 0]

print(y_test)
# [1 0]

サンプル数が少ないとイメージしにくいので、次の具体例も参照されたい。

具体的な例(アイリスデータセット)

具体的な例として、アイリスデータセットを分割する。

150件のデータがSepal Length(がく片の長さ)、Sepal Width(がく片の幅)、Petal Length(花びらの長さ)、Petal Width(花びらの幅)の4つの特徴量を持っており、Setosa, Versicolor, Virginicaの3品種に分類されている。

load_iris()でデータを取得する。正解ラベルy0, 1, 2の3種類。

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

data = load_iris()

X = data['data']
y = data['target']

print(X.shape)
# (150, 4)

print(X[:5])
# [[5.1 3.5 1.4 0.2]
#  [4.9 3.  1.4 0.2]
#  [4.7 3.2 1.3 0.2]
#  [4.6 3.1 1.5 0.2]
#  [5.  3.6 1.4 0.2]]

print(y.shape)
# (150,)

print(y)
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
#  0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
#  2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#  2 2]

train_test_split()で以下のように分割できる。サイズが大きいので形状shapeのみ示している。

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

print(X_train.shape)
# (112, 4)

print(X_test.shape)
# (38, 4)

print(y_train.shape)
# (112,)

print(y_test.shape)
# (38,)

テスト用の正解ラベルy_testを確認すると、各ラベルの数が不均等になっている。

print(y_test)
# [2 1 0 2 0 2 0 1 1 1 2 1 1 1 1 0 1 1 0 0 2 1 0 0 2 0 0 1 1 0 2 1 0 2 2 1 0
#  1]

print((y_test == 0).sum())
# 13

print((y_test == 1).sum())
# 16

print((y_test == 2).sum())
# 9

引数stratifyを指定すると、各ラベルの数を揃えて分割できる。

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, stratify=y)

print(y_test)
# [0 0 0 0 1 1 1 0 1 2 2 2 1 2 1 0 0 2 0 1 2 1 1 0 2 0 0 1 2 1 0 1 2 2 0 1 2
#  2]

print((y_test == 0).sum())
# 13

print((y_test == 1).sum())
# 13

print((y_test == 2).sum())
# 12
スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事