note.nkmk.me

pandas.DataFrameをGroupByでグルーピングし統計量を算出

Date: 2018-01-13 / Modified: 2018-06-27 / tags: Python, pandas

pandas.DataFrame, pandas.Seriesgroupby()メソッドでデータをグルーピング(グループ分け)できる。グループごとにデータを集約して、それぞれの平均、最小値、最大値、合計などの統計量を算出したり、任意の関数で処理したりすることが可能。

マルチインデックスを設定することでも同様の処理ができる。以下の記事を参照。

また、pandas.pivot_table(), pandas.crosstab()という関数を用いてカテゴリごとの統計量やサンプル数を算出することもできる。この方法が一番シンプルかも知れない。

ここでは以下の内容について説明する。

  • irisデータセット
  • groupby()でグルーピング
  • 平均、最小値、最大値、合計などを算出
  • 任意の処理を適用して集約: agg()
  • 主要な統計量を一括算出: describe()
  • グラフをプロット
スポンサーリンク

irisデータセット

例としてirisデータセットを使用する。

irisデータセットについては以下の記事を参照。

ここではseabornにサンプルとして含まれているデータを使う。

import pandas as pd
import seaborn as sns
import numpy as np

df = sns.load_dataset("iris")
print(df.shape)
# (150, 5)

print(df.head(5))
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 3           4.6          3.1           1.5          0.2  setosa
# 4           5.0          3.6           1.4          0.2  setosa

スペースを削減するために省略した列名に変更しておく。

df.columns = ['sl', 'sw', 'pl', 'pw', 'species']
print(df.head(5))
#     sl   sw   pl   pw species
# 0  5.1  3.5  1.4  0.2  setosa
# 1  4.9  3.0  1.4  0.2  setosa
# 2  4.7  3.2  1.3  0.2  setosa
# 3  4.6  3.1  1.5  0.2  setosa
# 4  5.0  3.6  1.4  0.2  setosa

groupby()でグルーピング

pandas.DataFramegroupby()メソッドでグルーピング(グループ分け)する。

引数に列名を指定するとその列の値ごとにグルーピングされる。

返されるのはGroupByオブジェクトでそれ自体をprint()で出力しても中身は表示されない。

grouped = df.groupby('species')
print(grouped)
# <pandas.core.groupby.groupby.DataFrameGroupBy object at 0x10c69f6a0>

print(type(grouped))
# <class 'pandas.core.groupby.groupby.DataFrameGroupBy'>

pandas.Seriesにも同様にgroupby()メソッドが用意されている。

GroupByオブジェクトのメソッド一覧は以下の公式ドキュメント参照。

size()メソッドでそれぞれのグループごとのサンプル数が確認できる。

print(grouped.size())
# species
# setosa        50
# versicolor    50
# virginica     50
# dtype: int64

平均、最小値、最大値、合計などを算出

GroupByオブジェクトに対しmean(), min(), max(), sum()などのメソッドを適用すると、グループごとの平均、最小値、最大値、合計などの統計量を算出できる。

print(grouped.mean())
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

print(grouped.min())
#              sl   sw   pl   pw
# species                       
# setosa      4.3  2.3  1.0  0.1
# versicolor  4.9  2.0  3.0  1.0
# virginica   4.9  2.2  4.5  1.4

print(grouped.max())
#              sl   sw   pl   pw
# species                       
# setosa      5.8  4.4  1.9  0.6
# versicolor  7.0  3.4  5.1  1.8
# virginica   7.9  3.8  6.9  2.5

print(grouped.sum())
#                sl     sw     pl     pw
# species                               
# setosa      250.3  171.4   73.1   12.3
# versicolor  296.8  138.5  213.0   66.3
# virginica   329.4  148.7  277.6  101.3

そのほか標準偏差std()、分散var()などもある。

いずれのメソッドも新たなpandas.DataFrameを返す。

print(type(grouped.mean()))
# <class 'pandas.core.frame.DataFrame'>

任意の処理を適用して集約: agg()

GroupByオブジェクトのagg()メソッドで任意の処理を適用することができる。

引数に適用したい関数を指定する。関数などの呼び出し可能オブジェクト(callable)または関数名の文字列で指定可能。

print(grouped.agg(min))
#              sl   sw   pl   pw
# species                       
# setosa      4.3  2.3  1.0  0.1
# versicolor  4.9  2.0  3.0  1.0
# virginica   4.9  2.2  4.5  1.4

print(grouped.agg('max'))
#              sl   sw   pl   pw
# species                       
# setosa      5.8  4.4  1.9  0.6
# versicolor  7.0  3.4  5.1  1.8
# virginica   7.9  3.8  6.9  2.5

なお、組み込み関数にないmean()などはmeanと指定するとエラーになる。NumPyの関数np.meanか文字列'mean'として指定する。

# print(grouped.agg(mean))
# NameError: name 'mean' is not defined

print(grouped.agg(np.mean))
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

print(grouped.agg('mean'))
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

NumPyの関数np.meanpd.np.meanとして指定することも可能。

リストで指定すると複数の処理を同時に適用できる。この場合は結果のpandas.DataFramecolumnsがマルチインデックスになる。

print(grouped.agg([min, 'max']))
#              sl        sw        pl        pw     
#             min  max  min  max  min  max  min  max
# species                                           
# setosa      4.3  5.8  2.3  4.4  1.0  1.9  0.1  0.6
# versicolor  4.9  7.0  2.0  3.4  3.0  5.1  1.0  1.8
# virginica   4.9  7.9  2.2  3.8  4.5  6.9  1.4  2.5

列名をキーとした辞書(dict型オブジェクト)で列ごとに異なる処理を適用することも可能。

print(grouped.agg({'sl': min, 'sw': max, 'pl': np.mean, 'pw': 'mean'}))
#              sl   sw     pl     pw
# species                           
# setosa      4.3  4.4  1.462  0.246
# versicolor  4.9  3.4  4.260  1.326
# virginica   4.9  3.8  5.552  2.026

無名関数(ラムダ式)でもOK。

print(grouped.agg(lambda x: max(x) - min(x)))
#              sl   sw   pl   pw
# species                       
# setosa      1.5  2.1  0.9  0.5
# versicolor  2.1  1.4  2.1  0.8
# virginica   3.0  1.6  2.4  1.1

ラムダ式に対しては各グループの値がpandas.Seriesとして渡される。

print(grouped.agg(lambda x: type(x))['sl'])
# species
# setosa        <class 'pandas.core.series.Series'>
# versicolor    <class 'pandas.core.series.Series'>
# virginica     <class 'pandas.core.series.Series'>
# Name: sl, dtype: object

pandas.Seriesを受け取って一つのオブジェクトを返すラムダ式でないとエラーになるので注意。

# print(grouped.agg(lambda x: x + 1))
# Exception: Must produce aggregated value

文字列の要素に対して処理した例は以下の記事の最後を参照。

複数の統計量を一括算出: describe()

describe()メソッドを使うとグループごとの主要な統計量を一括で算出できる。

以下の例ではsl列に対する結果のみ出力している。

print(grouped.describe()['sl']) 
#             count   mean       std  min    25%  50%  75%  max
# species                                                      
# setosa       50.0  5.006  0.352490  4.3  4.800  5.0  5.2  5.8
# versicolor   50.0  5.936  0.516171  4.9  5.600  5.9  6.3  7.0
# virginica    50.0  6.588  0.635880  4.9  6.225  6.5  6.9  7.9

各項目の意味などは以下の記事を参照。

グループごとの統計量のグラフをプロット

上述のようにGroupByオブジェクトに対しmean(), min(), max(), sum()などのメソッドを適用すると返ってくるのはpandas.DataFrameなので、そのままplot()メソッドを使ってグラフを描画して可視化できる。

print(type(grouped.max()))
# <class 'pandas.core.frame.DataFrame'>

ax = grouped.max().plot.bar(rot=0)
fig = ax.get_figure()
fig.savefig('data/dst/iris_pandas_groupby_max.jpg')

pandas groupby plot

plot()についての詳細は以下の記事を参照。

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事