note.nkmk.me

pandas.DataFrameをGroupByでグルーピングし統計量を算出

Date: 2018-01-13 / tags: Python, pandas
このエントリーをはてなブックマークに追加

pandas.DataFramegroupby()メソッドでデータをグルーピング(グループ分け)できる。グループごとにデータを集計して、それぞれの平均、最小値、最大値、合計などの統計量を算出することが可能。

irisデータセット

例として、irisデータセットを使用する。

ここでは、seabornにサンプルとして含まれているデータを使う。pandas.DataFrameなので便利。

import pandas as pd
import seaborn as sns

df = sns.load_dataset("iris")
print(df.shape)
print(df.head(5))
# (150, 5)
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 3           4.6          3.1           1.5          0.2  setosa
# 4           5.0          3.6           1.4          0.2  setosa

スペースを削減するために列名を省略しておく。

df.columns = ['sl', 'sw', 'pl', 'pw', 'species']
print(df.head(5))
#     sl   sw   pl   pw species
# 0  5.1  3.5  1.4  0.2  setosa
# 1  4.9  3.0  1.4  0.2  setosa
# 2  4.7  3.2  1.3  0.2  setosa
# 3  4.6  3.1  1.5  0.2  setosa
# 4  5.0  3.6  1.4  0.2  setosa

DataFrame.groupby()でグルーピング

DataFramegroupby()メソッドでグルーピング(グループ分け)する。

列名を指定するとその列の値に応じてグルーピングされる。返されるのはGroupByオブジェクトで、それ自体をprintしても何も表示されない。

grouped = df.groupby('species')
print(grouped)
print(type(grouped))
# <pandas.core.groupby.DataFrameGroupBy object at 0x111b209e8>
# <class 'pandas.core.groupby.DataFrameGroupBy'>

GroupByオブジェクトのメソッド一覧は以下の公式ドキュメント参照。

size()メソッドでそれぞれのグループごとのサンプル数が確認できる。

print(grouped.size()) 
# species
# setosa        50
# versicolor    50
# virginica     50
# dtype: int64

グループごとの統計量を算出

平均、最小値、最大値、合計などを算出

GroupByオブジェクトに対しmean(), min(), max(), sum()などのメソッドを適用することで、グループごとの平均、最小値、最大値、合計などの統計量を算出できる。

print(grouped.mean()) 
#                sl     sw     pl     pw
# species                               
# setosa      5.006  3.428  1.462  0.246
# versicolor  5.936  2.770  4.260  1.326
# virginica   6.588  2.974  5.552  2.026

print(grouped.min()) 
#              sl   sw   pl   pw
# species                       
# setosa      4.3  2.3  1.0  0.1
# versicolor  4.9  2.0  3.0  1.0
# virginica   4.9  2.2  4.5  1.4

print(grouped.max()) 
#              sl   sw   pl   pw
# species                       
# setosa      5.8  4.4  1.9  0.6
# versicolor  7.0  3.4  5.1  1.8
# virginica   7.9  3.8  6.9  2.5

print(grouped.sum()) 
#                sl     sw     pl     pw
# species                               
# setosa      250.3  171.4   73.1   12.3
# versicolor  296.8  138.5  213.0   66.3
# virginica   329.4  148.7  277.6  101.3

そのほか、標準偏差std()、分散var()などもある。

複数の統計量を同時に算出

GroupByオブジェクトのagg()メソッドで、複数の統計量を同時に算出することも可能。min(), max()などのメソッド名の文字列をリストで指定する。

print(grouped.agg(['min', 'max'])) 
#              sl        sw        pl        pw     
#             min  max  min  max  min  max  min  max
# species                                           
# setosa      4.3  5.8  2.3  4.4  1.0  1.9  0.1  0.6
# versicolor  4.9  7.0  2.0  3.4  3.0  5.1  1.0  1.8
# virginica   4.9  7.9  2.2  3.8  4.5  6.9  1.4  2.5

また、describe()メソッドを使うと主要な統計量を一括で算出できる。

print(grouped.describe()) 
#               pl                                                 pw         \
#            count   mean       std  min  25%   50%    75%  max count   mean   
# species                                                                      
# setosa      50.0  1.462  0.173664  1.0  1.4  1.50  1.575  1.9  50.0  0.246   
# versicolor  50.0  4.260  0.469911  3.0  4.0  4.35  4.600  5.1  50.0  1.326   
# virginica   50.0  5.552  0.551895  4.5  5.1  5.55  5.875  6.9  50.0  2.026   
#            ...    sl         sw                                                
#            ...   75%  max count   mean       std  min    25%  50%    75%  max  
# species    ...                                                                 
# setosa     ...   5.2  5.8  50.0  3.428  0.379064  2.3  3.200  3.4  3.675  4.4  
# versicolor ...   6.3  7.0  50.0  2.770  0.313798  2.0  2.525  2.8  3.000  3.4  
# virginica  ...   6.9  7.9  50.0  2.974  0.322497  2.2  2.800  3.0  3.175  3.8  
# [3 rows x 32 columns]

グループごとの統計量のグラフをプロット

GroupByオブジェクトに対しmean(), min(), max(), sum()などのメソッドを適用すると返ってくるのはpandas.DataFrameなので、そのままplot()メソッドを使ってグラフを描画して可視化できる。

print(type(grouped.max()))
# <class 'pandas.core.frame.DataFrame'>

ax = grouped.max().plot.bar(rot=0)
fig = ax.get_figure()
fig.savefig('data/dst/iris_pandas_groupby_max.jpg')

pandas groupby plot

スポンサーリンク
シェア
このエントリーをはてなブックマークに追加

関連カテゴリー

関連記事