pandasのMultiindexで階層ごとの統計量・サンプル数を算出
Multiindex(マルチインデックス、階層型インデックス)を設定したpandas.DataFrame
, pandas.Series
では、各階層の項目ごとの統計量(平均、最大、最小、合計、標準偏差など)およびサンプル数を算出できる。
マルチインデックスを設定せずgroupby
メソッドを使っても同様のことが可能。groupby
を使った方法も適宜紹介するが、詳細は以下の記事を参照。
また、pandas.pivot_table()
, pandas.crosstab()
という関数を用いてカテゴリごとの統計量やサンプル数を算出することもできる。この方法が一番シンプルかも知れない。
マルチインデックスの設定方法やマルチインデックスにおける行・列の選択方法などについては以下の記事を参照。
例としてタイタニックの生存者のデータを使用する。Kaggleの問題からダウンロードできる。
比較のため、インデックスが一列だけのオブジェクトとマルチインデックスを設定したオブジェクトを用意する。
import pandas as pd
df = pd.read_csv('data/src/titanic_train.csv')
df.drop(labels=['Name', 'Ticket', 'Cabin'], axis=1, inplace=True)
df_single = df.set_index('PassengerId')
print(df_single.head())
# Survived Pclass Sex Age SibSp Parch Fare Embarked
# PassengerId
# 1 0 3 male 22.0 1 0 7.2500 S
# 2 1 1 female 38.0 1 0 71.2833 C
# 3 1 3 female 26.0 0 0 7.9250 S
# 4 1 1 female 35.0 1 0 53.1000 S
# 5 0 3 male 35.0 0 0 8.0500 S
df_multi = df.set_index(['Sex', 'Pclass', 'Embarked', 'PassengerId']).sort_index()
print(df_multi.head())
# Survived Age SibSp Parch Fare
# Sex Pclass Embarked PassengerId
# female 1 C 2 1 38.0 1 0 71.2833
# 32 1 NaN 1 0 146.5208
# 53 1 49.0 1 0 76.7292
# 178 0 50.0 0 0 28.7125
# 195 1 44.0 0 0 27.7208
print(df_multi.tail())
# Survived Age SibSp Parch Fare
# Sex Pclass Embarked PassengerId
# male 3 S 877 0 20.0 0 0 9.8458
# 878 0 19.0 0 0 7.8958
# 879 0 NaN 0 0 7.8958
# 882 0 33.0 0 0 7.8958
# 885 0 25.0 0 0 7.0500
drop()
で不要な列を削除し、head()
, tail()
で先頭、末尾のデータを表示している。
統計量(平均、最大、最小、合計、標準偏差など)を算出
pandas.DataFrame
, pandas.Series
には統計量を算出する以下のようなメソッドがある。
- 平均:
mean()
- 最大:
max()
- 最小:
min()
- 合計:
sum()
- 標準偏差:
std()
pandas.DataFrame
でこれらのメソッドを呼ぶと、各列に対する統計量がpandas.Series
で返る。
インデックスに設定した列に対する統計量は算出されない。
print(df_multi.mean())
# Survived 0.383838
# Age 29.699118
# SibSp 0.523008
# Parch 0.381594
# Fare 32.204208
# dtype: float64
print(df_single.mean())
# Survived 0.383838
# Pclass 2.308642
# Age 29.699118
# SibSp 0.523008
# Parch 0.381594
# Fare 32.204208
# dtype: float64
print(df_multi.max())
# Survived 1.0000
# Age 80.0000
# SibSp 8.0000
# Parch 6.0000
# Fare 512.3292
# dtype: float64
print(df_single.max())
# Survived 1
# Pclass 3
# Sex male
# Age 80
# SibSp 8
# Parch 6
# Fare 512.329
# dtype: object
以下、平均値を返すmean()
メソッドで説明するが、ほかのメソッドでも同様。
マルチインデックスの場合、引数level
にマルチインデックスの列名を文字列で指定するかレベル(階層)を整数値で指定するとそのレベル(階層)の項目ごとの統計量が算出される。レベル値は外側(大項目)が0
。
print(df_multi.mean(level='Sex'))
# Survived Age SibSp Parch Fare
# Sex
# female 0.742038 27.915709 0.694268 0.649682 44.479818
# male 0.188908 30.726645 0.429809 0.235702 25.523893
print(df_multi.mean(level=0))
# Survived Age SibSp Parch Fare
# Sex
# female 0.742038 27.915709 0.694268 0.649682 44.479818
# male 0.188908 30.726645 0.429809 0.235702 25.523893
print(df_multi.mean(level=1))
# Survived Age SibSp Parch Fare
# Pclass
# 1 0.629630 38.233441 0.416667 0.356481 84.154687
# 2 0.472826 29.877630 0.402174 0.380435 20.662183
# 3 0.242363 25.140620 0.615071 0.393075 13.675550
print(df_multi.mean(level=2))
# Survived Age SibSp Parch Fare
# Embarked
# C 0.553571 30.814769 0.386905 0.363095 59.954144
# Q 0.389610 28.089286 0.428571 0.168831 13.276030
# S 0.336957 29.445397 0.571429 0.413043 27.079812
引数level
にリストを指定すると複数の階層に対する統計量が算出できる。列名でもレベル値でもどちらでもOK。
print(df_multi.mean(level=['Sex', 'Pclass']))
# Survived Age SibSp Parch Fare
# Sex Pclass
# female 1 0.968085 34.611765 0.553191 0.457447 106.125798
# 2 0.921053 28.722973 0.486842 0.605263 21.970121
# 3 0.500000 21.750000 0.895833 0.798611 16.118810
# male 1 0.368852 41.281386 0.311475 0.278689 67.226127
# 2 0.157407 30.740707 0.342593 0.222222 19.741782
# 3 0.135447 26.507589 0.498559 0.224784 12.661633
print(df_multi.mean(level=[0, 1, 2]))
# Survived Age SibSp Parch Fare
# Sex Pclass Embarked
# female 1 C 0.976744 36.052632 0.511628 0.302326 115.640309
# Q 1.000000 33.000000 1.000000 0.000000 90.000000
# S 0.958333 32.704545 0.604167 0.625000 99.026910
# 2 C 1.000000 19.142857 0.714286 0.571429 25.268457
# Q 1.000000 30.000000 0.000000 0.000000 12.350000
# S 0.910448 29.719697 0.477612 0.626866 21.912687
# 3 C 0.652174 14.062500 0.565217 0.826087 14.694926
# Q 0.727273 22.850000 0.212121 0.242424 10.307833
# S 0.375000 23.223684 1.238636 1.000000 18.670077
# male 1 C 0.404762 40.111111 0.238095 0.333333 93.536707
# Q 0.000000 44.000000 2.000000 0.000000 90.000000
# S 0.354430 41.897188 0.329114 0.253165 52.949947
# 2 C 0.200000 25.937500 0.500000 0.500000 25.421250
# Q 0.000000 57.000000 0.000000 0.000000 12.350000
# S 0.154639 30.875889 0.329897 0.195876 19.232474
# 3 C 0.232558 25.016800 0.232558 0.139535 9.352237
# Q 0.076923 28.142857 0.589744 0.128205 11.924251
# S 0.128302 26.574766 0.528302 0.252830 13.307149
マルチインデックスでない場合、groupby()
メソッドを使ってグルーピングできる。groupby()
の引数by
に列名または列名のリストを指定し、さらにmean()
などの統計量を算出するメソッドを呼ぶ。
print(df_single.groupby(by='Sex').mean())
# Survived Pclass Age SibSp Parch Fare
# Sex
# female 0.742038 2.159236 27.915709 0.694268 0.649682 44.479818
# male 0.188908 2.389948 30.726645 0.429809 0.235702 25.523893
print(df_single.groupby(by=['Sex', 'Pclass', 'Embarked']).mean())
# Survived Age SibSp Parch Fare
# Sex Pclass Embarked
# female 1 C 0.976744 36.052632 0.511628 0.302326 115.640309
# Q 1.000000 33.000000 1.000000 0.000000 90.000000
# S 0.958333 32.704545 0.604167 0.625000 99.026910
# 2 C 1.000000 19.142857 0.714286 0.571429 25.268457
# Q 1.000000 30.000000 0.000000 0.000000 12.350000
# S 0.910448 29.719697 0.477612 0.626866 21.912687
# 3 C 0.652174 14.062500 0.565217 0.826087 14.694926
# Q 0.727273 22.850000 0.212121 0.242424 10.307833
# S 0.375000 23.223684 1.238636 1.000000 18.670077
# male 1 C 0.404762 40.111111 0.238095 0.333333 93.536707
# Q 0.000000 44.000000 2.000000 0.000000 90.000000
# S 0.354430 41.897188 0.329114 0.253165 52.949947
# 2 C 0.200000 25.937500 0.500000 0.500000 25.421250
# Q 0.000000 57.000000 0.000000 0.000000 12.350000
# S 0.154639 30.875889 0.329897 0.195876 19.232474
# 3 C 0.232558 25.016800 0.232558 0.139535 9.352237
# Q 0.076923 28.142857 0.589744 0.128205 11.924251
# S 0.128302 26.574766 0.528302 0.252830 13.307149
サンプル数をカウント
各項目ごとにいくつのデータが含まれているか、サンプル数をカウントするにはgroupby()
メソッドおよびsize()
メソッドを使う。
マルチインデックスの場合、groupby()
の引数level
にマルチインデックスの列名、レベル(階層)、または、そのリストを指定する。
print(df_multi.groupby(level='Sex').size())
# Sex
# female 314
# male 577
# dtype: int64
print(df_multi.groupby(level=2).size())
# Embarked
# C 168
# Q 77
# S 644
# dtype: int64
print(df_multi.groupby(level=[0, 1, 2]).size())
# Sex Pclass Embarked
# female 1 C 43
# Q 1
# S 48
# 2 C 7
# Q 2
# S 67
# 3 C 23
# Q 33
# S 88
# male 1 C 42
# Q 1
# S 79
# 2 C 10
# Q 1
# S 97
# 3 C 43
# Q 39
# S 265
# dtype: int64
マルチインデックスでない場合、groupby()
の引数by
に列名または列名のリストを指定する。
print(df_single.groupby(by=['Sex', 'Pclass', 'Embarked']).size())
# Sex Pclass Embarked
# female 1 C 43
# Q 1
# S 48
# 2 C 7
# Q 2
# S 67
# 3 C 23
# Q 33
# S 88
# male 1 C 42
# Q 1
# S 79
# 2 C 10
# Q 1
# S 97
# 3 C 43
# Q 39
# S 265
# dtype: int64