scikit-learnのサンプルデータセットの一覧と使い方
scikit-learnには分類(classification)や回帰(regression)などの機械学習の問題に使えるデータセットが同梱されている。アルゴリズムを試してみたりするのに便利。
画像などのサイズの大きいデータをダウンロードするための関数も用意されている。
公式ドキュメントの表記に従い、scikit-learnに同梱されているデータをトイ・データセット(Toy dataset)、ダウンロードが必要なサイズの大きいデータを実世界データセット(Real world dataset)と呼ぶ。
ここでは以下の内容について説明する。
- トイ・データセット(Toy dataset)の一覧
- 実世界データセット(Real world dataset)の一覧
- インポートの方法
Bunch
型の使い方(例:load_iris()
)- データのダウンロード(例:
fetch_olivetti_faces()
)
人工的なデータを生成するための関数もあるがここでは触れない。
以降の内容はscikit-learnのバージョン0.20.3
時点のもの。
トイ・データセット(Toy dataset)の一覧
公式ドキュメントは以下。
バージョン0.20.3
時点で7つのデータセットがある。詳細はリンク先を参照。
load_boston()
load_iris()
- sklearn.datasets.load_iris — scikit-learn 0.20.3 documentation
- 分類
- アイリス(アヤメ)の種類
load_diabetes()
load_digits()
load_linnerud()
- sklearn.datasets.load_linnerud — scikit-learn 0.20.3 documentation
- 回帰
- 生理学的(physiological)測定結果と運動(exercise)測定結果
load_wine()
load_breast_cancer
実世界データセット(Real world dataset)の一覧
公式ドキュメントは以下。
バージョン0.20.3
時点で9つのデータセット(うち2つは同じデータソース)がある。詳細はリンク先を参照。
fetch_olivetti_faces()
- sklearn.datasets.fetch_olivetti_faces — scikit-learn 0.20.3 documentation
- 分類
- 同一人物の様々な状態の顔画像(40人 x 10枚)
fetch_20newsgroups()
fetch_20newsgroups_vectorized()
- sklearn.datasets.fetch_20newsgroups_vectorized — scikit-learn 0.20.3 documentation
- 分類
fetch_20newsgroups()
の特徴抽出済みバージョン
fetch_lfw_people()
- sklearn.datasets.fetch_lfw_people — scikit-learn 0.20.3 documentation
- 分類
- LFW: The Labeled Faces in the Wild
- 有名人の顔写真
- 人物を分類
fetch_lfw_pairs()
- sklearn.datasets.fetch_lfw_pairs — scikit-learn 0.20.3 documentation
- 分類
- LFW: The Labeled Faces in the Wild
- 有名人の顔写真
- 2枚ずつペアになっている
- ペアが同一人物かどうかを判定
fetch_covtype()
fetch_rcv1()
- sklearn.datasets.fetch_rcv1 — scikit-learn 0.20.3 documentation
- 分類
- RCV1: Reuters Corpus Volume I
- カテゴリ別のニュース(ベクトル化済み)
fetch_kddcup99()
fetch_california_housing()
インポートの方法
上に一覧で示したデータを取得する関数はsklearn.datasets
モジュールにある。
公式ドキュメントのサンプルコードでは各関数をインポートしている。
from sklearn.datasets import load_iris
data = load_iris()
print(type(data))
# <class 'sklearn.utils.Bunch'>
以下のようにsklearn.datasets
モジュールをインポートしてもOK。
import sklearn.datasets
data_iris = sklearn.datasets.load_iris()
print(type(data_iris))
# <class 'sklearn.utils.Bunch'>
data_boston = sklearn.datasets.load_boston()
print(type(data_boston))
# <class 'sklearn.utils.Bunch'>
インポートについての詳細は以下の記事を参照。
Bunch型の使い方(例: load_iris())
sklearn.datasets
モジュールの関数はBunch
型のオブジェクトを返す。以下、load_iris()
を例とする。格納されている情報に違いはあるが、他の関数でも基本的には同様。
import pandas as pd
from sklearn.datasets import load_iris
data = load_iris()
print(type(data))
# <class 'sklearn.utils.Bunch'>
Bunch
型は辞書dict
のサブクラス。
print(issubclass(type(data), dict))
# True
辞書のメソッドが使用可能。例えばkeys()
でキーの一覧を確認できる。
print(data.keys())
# dict_keys(['data', 'target', 'target_names', 'DESCR', 'feature_names', 'filename'])
値を取得する場合は、辞書のように['キー名']
としてもいいし、.キー名
のようにしてもいい。
print(data['feature_names'])
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
print(data.feature_names)
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
データのパスfilename
や概要DESCR
などの情報が格納されている。関数によってはdata_filename
とtarget_filename
に分かれていたりするものもある。
print(data.filename)
# /usr/local/lib/python3.7/site-packages/sklearn/datasets/data/iris.csv
print(data.DESCR)
# .. _iris_dataset:
#
# Iris plants dataset
# --------------------
#
# **Data Set Characteristics:**
#
# :Number of Instances: 150 (50 in each of three classes)
# :Number of Attributes: 4 numeric, predictive attributes and the class
# :Attribute Information:
# - sepal length in cm
# - sepal width in cm
# - petal length in cm
# - petal width in cm
# - class:
# - Iris-Setosa
# - Iris-Versicolour
# - Iris-Virginica
#
# :Summary Statistics:
#
# ============== ==== ==== ======= ===== ====================
# Min Max Mean SD Class Correlation
# ============== ==== ==== ======= ===== ====================
# sepal length: 4.3 7.9 5.84 0.83 0.7826
# sepal width: 2.0 4.4 3.05 0.43 -0.4194
# petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
# petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
# ============== ==== ==== ======= ===== ====================
#
# :Missing Attribute Values: None
# :Class Distribution: 33.3% for each of 3 classes.
# :Creator: R.A. Fisher
# :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
# :Date: July, 1988
#
# The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
# from Fisher's paper. Note that it's the same as in R, but not as in the UCI
# Machine Learning Repository, which has two wrong data points.
#
# This is perhaps the best known database to be found in the
# pattern recognition literature. Fisher's paper is a classic in the field and
# is referenced frequently to this day. (See Duda & Hart, for example.) The
# data set contains 3 classes of 50 instances each, where each class refers to a
# type of iris plant. One class is linearly separable from the other 2; the
# latter are NOT linearly separable from each other.
#
# .. topic:: References
#
# - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
# Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
# Mathematical Statistics" (John Wiley, NY, 1950).
# - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
# (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
# - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
# Structure and Classification Rule for Recognition in Partially Exposed
# Environments". IEEE Transactions on Pattern Analysis and Machine
# Intelligence, Vol. PAMI-2, No. 1, 67-71.
# - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE Transactions
# on Information Theory, May 1972, 431-433.
# - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al"s AUTOCLASS II
# conceptual clustering system finds 3 classes in the data.
# - Many, many more ...
説明変数がdata
、目的変数がtarget
に格納されている。
X = data.data
y = data.target
print(type(X))
# <class 'numpy.ndarray'>
print(X.shape)
# (150, 4)
print(X[:5])
# [[5.1 3.5 1.4 0.2]
# [4.9 3. 1.4 0.2]
# [4.7 3.2 1.3 0.2]
# [4.6 3.1 1.5 0.2]
# [5. 3.6 1.4 0.2]]
print(type(y))
# <class 'numpy.ndarray'>
print(y.shape)
# (150,)
print(y)
# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
# 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
# 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
# 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
# 2 2]
pandas.DataFrame
にしたい場合は各特徴量の名称であるfeature_names
を引数columns
に指定するとよい。
df_X = pd.DataFrame(data.data, columns=data.feature_names)
print(df_X.head())
# sepal length (cm) sepal width (cm) petal length (cm) petal width (cm)
# 0 5.1 3.5 1.4 0.2
# 1 4.9 3.0 1.4 0.2
# 2 4.7 3.2 1.3 0.2
# 3 4.6 3.1 1.5 0.2
# 4 5.0 3.6 1.4 0.2
pandas.Series
の場合はそのまま。
s_y = pd.Series(data.target)
print(s_y.head())
# 0 0
# 1 0
# 2 0
# 3 0
# 4 0
# dtype: int64
pandas.DataFrame
やpandas.Series
にするとpandasの便利なメソッドが使える。
print(df_X.describe())
# sepal length (cm) sepal width (cm) petal length (cm) \
# count 150.000000 150.000000 150.000000
# mean 5.843333 3.057333 3.758000
# std 0.828066 0.435866 1.765298
# min 4.300000 2.000000 1.000000
# 25% 5.100000 2.800000 1.600000
# 50% 5.800000 3.000000 4.350000
# 75% 6.400000 3.300000 5.100000
# max 7.900000 4.400000 6.900000
#
# petal width (cm)
# count 150.000000
# mean 1.199333
# std 0.762238
# min 0.100000
# 25% 0.300000
# 50% 1.300000
# 75% 1.800000
# max 2.500000
print(s_y.value_counts())
# 2 50
# 1 50
# 0 50
# dtype: int64
バージョン0.18
以降は引数return_X_y=True
とすることでdata
とtarget
を直接取得できる。関数によっては引数return_X_y
が定義されていない場合もあるので注意。
X, y = load_iris(return_X_y=True)
print(type(X))
# <class 'numpy.ndarray'>
print(X.shape)
# (150, 4)
print(type(y))
# <class 'numpy.ndarray'>
print(y.shape)
# (150,)
データのダウンロード(例: fetch_olivetti_faces())
実世界データセットのfetch_xxx()
は初回実行時にデータをダウンロードする。例はfetch_olivetti_faces()
。
from sklearn.datasets import fetch_olivetti_faces
data = fetch_olivetti_faces()
デフォルトではホームディレクトリにscikit_learn_data
というディレクトリが作られ、そこにデータがダウンロードされる。
以下のような表示がされる。進捗状況は特に示されないので注意。
downloading Olivetti faces from https://ndownloader.figshare.com/files/5976027 to /Users/xxx/scikit_learn_data
一度ダウンロードされれば次からはローカルのファイルが読み込まれるようになる。
fetch_xxx()
では引数data_home
にデータをダウンロードするパスを指定することも可能。その場所にデータが無ければダウンロードされ、データがあればそれが読み込まれる。
fetch_xxx()
もBunch
型のオブジェクトを返す。使い方は上述の通り。
print(type(data))
# <class 'sklearn.utils.Bunch'>
print(data.keys())
# dict_keys(['data', 'images', 'target', 'DESCR'])
print(data.data.shape)
# (400, 4096)
print(data.target.shape)
# (400,)
fetch_olivetti_faces()
ではreturn_X_y
が定義されていない。
fetch_olivetti_faces()
にはdata
のほかimages
もある。
print(data.images.shape)
# (400, 64, 64)
fetch_olivetti_faces()
は400枚の64ピクセル x 64ピクセル
の顔写真画像(白黒)のデータセット。images
には各画像がそのままの形で格納されており全体の形状shape
は(400, 64, 64)
となっている。一方、data
には各画像が平坦化(一次元化)されて格納されており全体の形状shape
は(400, 4096)
となっている(64 x 64 = 4096
)。
どのようなデータかは概要DESCR
や上にリンクを示した公式ドキュメントなどに書かれている。