note.nkmk.me

scikit-learnのサンプルデータセットの一覧と使い方

Date: 2019-04-16 / tags: Python, scikit-learn, 機械学習

scikit-learnには分類(classification)や回帰(regression)などの機械学習の問題に使えるデータセットが同梱されている。アルゴリズムを試してみたりするのに便利。

画像などのサイズの大きいデータをダウンロードするための関数も用意されている。

公式ドキュメントの表記に従い、scikit-learnに同梱されているデータをトイ・データセット(Toy dataset)、ダウンロードが必要なサイズの大きいデータを実世界データセット(Real world dataset)と呼ぶ。

ここでは以下の内容について説明する。

  • トイ・データセット(Toy dataset)の一覧
  • 実世界データセット(Real world dataset)の一覧
  • インポートの方法
  • Bunch型の使い方(例: load_iris()
  • データのダウンロード(例: fetch_olivetti_faces()

人工的なデータを生成するための関数もあるがここでは触れない。

以降の内容はscikit-learnのバージョン0.20.3時点のもの。

スポンサーリンク

トイ・データセット(Toy dataset)の一覧

公式ドキュメントは以下。

バージョン0.20.3時点で7つのデータセットがある。詳細はリンク先を参照。

実世界データセット(Real world dataset)の一覧

公式ドキュメントは以下。

バージョン0.20.3時点で9つのデータセット(うち2つは同じデータソース)がある。詳細はリンク先を参照。

インポートの方法

上に一覧で示したデータを取得する関数はsklearn.datasetsモジュールにある。

公式ドキュメントのサンプルコードでは各関数をインポートしている。

from sklearn.datasets import load_iris

data = load_iris()

print(type(data))
# <class 'sklearn.utils.Bunch'>

以下のようにsklearn.datasetsモジュールをインポートしてもOK。

import sklearn.datasets

data_iris = sklearn.datasets.load_iris()

print(type(data_iris))
# <class 'sklearn.utils.Bunch'>

data_boston = sklearn.datasets.load_boston()

print(type(data_boston))
# <class 'sklearn.utils.Bunch'>

インポートについての詳細は以下の記事を参照。

Bunch型の使い方(例: load_iris())

sklearn.datasetsモジュールの関数はBunch型のオブジェクトを返す。以下、load_iris()を例とする。格納されている情報に違いはあるが、他の関数でも基本的には同様。

import pandas as pd
from sklearn.datasets import load_iris

data = load_iris()

print(type(data))
# <class 'sklearn.utils.Bunch'>

Bunch型は辞書dictのサブクラス。

print(issubclass(type(data), dict))
# True

辞書のメソッドが使用可能。例えばkeys()でキーの一覧を確認できる。

print(data.keys())
# dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])

値を取得する場合は、辞書のように['キー名']としてもいいし、.キー名のようにしてもいい。

print(data['feature_names'])
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

print(data.feature_names)
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

データのパスfilenameや概要DESCRなどの情報が格納されている。関数によってはdata_filenametarget_filenameに分かれていたりするものもある。

print(data.filename)
# /usr/local/lib/python3.7/site-packages/sklearn/datasets/data/iris.csv

print(data.DESCR)
# .. _iris_dataset:
# 
# Iris plants dataset
# --------------------
# 
# **Data Set Characteristics:**
# 
#     :Number of Instances: 150 (50 in each of three classes)
#     :Number of Attributes: 4 numeric, predictive attributes and the class
#     :Attribute Information:
#         - sepal length in cm
#         - sepal width in cm
#         - petal length in cm
#         - petal width in cm
#         - class:
#                 - Iris-Setosa
#                 - Iris-Versicolour
#                 - Iris-Virginica
#                 
#     :Summary Statistics:
# 
#     ============== ==== ==== ======= ===== ====================
#                     Min  Max   Mean    SD   Class Correlation
#     ============== ==== ==== ======= ===== ====================
#     sepal length:   4.3  7.9   5.84   0.83    0.7826
#     sepal width:    2.0  4.4   3.05   0.43   -0.4194
#     petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
#     petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
#     ============== ==== ==== ======= ===== ====================
# 
#     :Missing Attribute Values: None
#     :Class Distribution: 33.3% for each of 3 classes.
#     :Creator: R.A. Fisher
#     :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
#     :Date: July, 1988
# 
# The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
# from Fisher's paper. Note that it's the same as in R, but not as in the UCI
# Machine Learning Repository, which has two wrong data points.
# 
# This is perhaps the best known database to be found in the
# pattern recognition literature.  Fisher's paper is a classic in the field and
# is referenced frequently to this day.  (See Duda & Hart, for example.)  The
# data set contains 3 classes of 50 instances each, where each class refers to a
# type of iris plant.  One class is linearly separable from the other 2; the
# latter are NOT linearly separable from each other.
# 
# .. topic:: References
# 
#    - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
#      Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
#      Mathematical Statistics" (John Wiley, NY, 1950).
#    - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
#      (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
#    - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
#      Structure and Classification Rule for Recognition in Partially Exposed
#      Environments".  IEEE Transactions on Pattern Analysis and Machine
#      Intelligence, Vol. PAMI-2, No. 1, 67-71.
#    - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
#      on Information Theory, May 1972, 431-433.
#    - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
#      conceptual clustering system finds 3 classes in the data.
#    - Many, many more ...

説明変数がdata、目的変数がtargetに格納されている。

X = data.data
y = data.target

print(type(X))
# <class 'numpy.ndarray'>

print(X.shape)
# (150, 4)

print(X[:5])
# [[5.1 3.5 1.4 0.2]
#  [4.9 3.  1.4 0.2]
#  [4.7 3.2 1.3 0.2]
#  [4.6 3.1 1.5 0.2]
#  [5.  3.6 1.4 0.2]]

print(type(y))
# <class 'numpy.ndarray'>

print(y.shape)
# (150,)

print(y)
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
#  0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
#  1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
#  2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
#  2 2]

pandas.DataFrameにしたい場合は各特徴量の名称であるfeature_namesを引数columnsに指定するとよい。

df_X = pd.DataFrame(data.data, columns=data.feature_names)

print(df_X.head())
#    sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)
# 0                5.1               3.5                1.4               0.2
# 1                4.9               3.0                1.4               0.2
# 2                4.7               3.2                1.3               0.2
# 3                4.6               3.1                1.5               0.2
# 4                5.0               3.6                1.4               0.2

pandas.Seriesの場合はそのまま。

s_y = pd.Series(data.target)

print(s_y.head())
# 0    0
# 1    0
# 2    0
# 3    0
# 4    0
# dtype: int64

pandas.DataFramepandas.Seriesにするとpandasの便利なメソッドが使える。

print(df_X.describe())
#        sepal length (cm)  sepal width (cm)  petal length (cm)  \
# count         150.000000        150.000000         150.000000   
# mean            5.843333          3.057333           3.758000   
# std             0.828066          0.435866           1.765298   
# min             4.300000          2.000000           1.000000   
# 25%             5.100000          2.800000           1.600000   
# 50%             5.800000          3.000000           4.350000   
# 75%             6.400000          3.300000           5.100000   
# max             7.900000          4.400000           6.900000   
# 
#        petal width (cm)  
# count        150.000000  
# mean           1.199333  
# std            0.762238  
# min            0.100000  
# 25%            0.300000  
# 50%            1.300000  
# 75%            1.800000  
# max            2.500000  

print(s_y.value_counts())
# 2    50
# 1    50
# 0    50
# dtype: int64

バージョン0.18以降は引数return_X_y=Trueとすることでdatatargetを直接取得できる。関数によっては引数return_X_yが定義されていない場合もあるので注意。

X, y = load_iris(return_X_y=True)

print(type(X))
# <class 'numpy.ndarray'>

print(X.shape)
# (150, 4)

print(type(y))
# <class 'numpy.ndarray'>

print(y.shape)
# (150,)

データのダウンロード(例: fetch_olivetti_faces())

実世界データセットのfetch_xxx()は初回実行時にデータをダウンロードする。例はfetch_olivetti_faces()

from sklearn.datasets import fetch_olivetti_faces

data = fetch_olivetti_faces()

デフォルトではホームディレクトリにscikit_learn_dataというディレクトリが作られ、そこにデータがダウンロードされる。

以下のような表示がされる。進捗状況は特に示されないので注意。

downloading Olivetti faces from https://ndownloader.figshare.com/files/5976027 to /Users/xxx/scikit_learn_data

一度ダウンロードされれば次からはローカルのファイルが読み込まれるようになる。

fetch_xxx()では引数data_homeにデータをダウンロードするパスを指定することも可能。その場所にデータが無ければダウンロードされ、データがあればそれが読み込まれる。

fetch_xxx()Bunch型のオブジェクトを返す。使い方は上述の通り。

print(type(data))
# <class 'sklearn.utils.Bunch'>

print(data.keys())
# dict_keys(['data', 'images', 'target', 'DESCR'])

print(data.data.shape)
# (400, 4096)

print(data.target.shape)
# (400,)

fetch_olivetti_faces()ではreturn_X_yが定義されていない。

fetch_olivetti_faces()にはdataのほかimagesもある。

print(data.images.shape)
# (400, 64, 64)

fetch_olivetti_faces()は400枚の64ピクセル x 64ピクセルの顔写真画像(白黒)のデータセット。imagesには各画像がそのままの形で格納されており全体の形状shape(400, 64, 64)となっている。一方、dataには各画像が平坦化(一次元化)されて格納されており全体の形状shape(400, 4096)となっている(64 x 64 = 4096)。

どのようなデータかは概要DESCRや上にリンクを示した公式ドキュメントなどに書かれている。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事