pandasのピボットテーブルでカテゴリ毎の統計量などを算出

Posted: | Tags: Python, pandas

pandas.pivot_table()関数を使うと、Excelなどの表計算ソフトのピボットテーブル機能と同様の処理が実現できる。

カテゴリデータ(カテゴリカルデータ、質的データ)のカテゴリごとにグルーピング(グループ分け)して量的データの統計量(平均、合計、最大、最小、標準偏差など)を確認・分析できる。便利。

カテゴリごとの出現回数・頻度を集計する場合はpandas.crosstab()という関数が別途用意されている(pivot_table()でも可能)。

ここでは、

  • pandas.pivot_table()関数の基本的な使い方
  • カテゴリごとの小計・総計を算出: 引数margins
  • 結果の値の算出方法を指定: 引数aggfunc
  • 元データの欠損値NaNを除外するか指定: 引数dropna

について説明する。

例としてタイタニックの生存情報のデータを使用する。Kaggleの問題からダウンロードできる。

import numpy as np
import pandas as pd

df = pd.read_csv('data/src/titanic_train.csv', index_col=0).drop(['Name', 'Ticket', 'SibSp', 'Parch'], axis=1)

print(df.head())
#              Survived  Pclass     Sex   Age     Fare Cabin Embarked
# PassengerId                                                        
# 1                   0       3    male  22.0   7.2500   NaN        S
# 2                   1       1  female  38.0  71.2833   C85        C
# 3                   1       3  female  26.0   7.9250   NaN        S
# 4                   1       1  female  35.0  53.1000  C123        S
# 5                   0       3    male  35.0   8.0500   NaN        S

適当に列を除外している。

pandas.pivot_table()関数の基本的な使い方

pandas.pivot_table()関数の必須の引数は以下の3つ。

  • data(第一引数): 元データのpandas.DataFrameオブジェクトを指定。
  • index: 元データの列名を指定。結果の行見出しとなる。
  • columns: 元データの列名を指定。結果の列見出しとなる。

引数index, columnsに指定していない列の平均値が結果として算出される。このときデータ型が数値でない列は除外される。平均値以外の値を算出する方法は後述。

print(pd.pivot_table(df, index='Pclass', columns='Sex'))
#               Age                   Fare             Survived          
# Sex        female       male      female       male    female      male
# Pclass                                                                 
# 1       34.611765  41.281386  106.125798  67.226127  0.968085  0.368852
# 2       28.722973  30.740707   21.970121  19.741782  0.921053  0.157407
# 3       21.750000  26.507589   16.118810  12.661633  0.500000  0.135447

pandas.pivot_table()関数が返すのはpandas.DataFrame

print(type(pd.pivot_table(df, index='Pclass', columns='Sex')))
# <class 'pandas.core.frame.DataFrame'>

さらに、引数valuesに元データの列名を指定すると、その列に対する結果のみが算出される。

print(pd.pivot_table(df, index='Pclass', columns='Sex', values='Age'))
# Sex        female       male
# Pclass                      
# 1       34.611765  41.281386
# 2       28.722973  30.740707
# 3       21.750000  26.507589

引数index, columns, valuesには元データの列名のリストを指定することも可能。結果はマルチインデックス(階層型インデックス)のpandas.DataFrameとして返される。

print(pd.pivot_table(df, index=['Sex', 'Pclass'], columns='Survived', values=['Age', 'Fare']))
#                      Age                   Fare            
# Survived               0          1           0           1
# Sex    Pclass                                              
# female 1       25.666667  34.939024  110.604167  105.978159
#        2       36.000000  28.080882   18.250000   22.288989
#        3       23.818182  19.329787   19.773093   12.464526
# male   1       44.581967  36.248000   62.894910   74.637320
#        2       33.369048  16.022000   19.488965   21.095100
#        3       27.255814  22.274211   12.204469   15.579696

カテゴリごとの小計・総計を算出: 引数margins

引数marginsTrueとすると、各カテゴリごとの結果(小計)および全体の結果(総計)が算出できる。

print(pd.pivot_table(df, index='Sex', columns='Pclass', values='Age', margins=True))
# Pclass          1          2          3        All
# Sex                                               
# female  34.611765  28.722973  21.750000  27.915709
# male    41.281386  30.740707  26.507589  30.726645
# All     38.233441  29.877630  25.140620  29.699118

小計・総計の行ラベル・列ラベルは引数margins_nameで指定できる。デフォルトは'All'

print(pd.pivot_table(df, index='Sex', columns='Pclass', values='Age',
                     margins=True, margins_name='Total'))
# Pclass          1          2          3      Total
# Sex                                               
# female  34.611765  28.722973  21.750000  27.915709
# male    41.281386  30.740707  26.507589  30.726645
# Total   38.233441  29.877630  25.140620  29.699118

結果の値の算出方法を指定: 引数aggfunc

デフォルトでは平均値が算出されるが、引数aggfuncに関数を指定するとほかの方法で値を算出できる。

デフォルト(引数aggfuncを省略した場合)はnumpy.mean()が指定される。

関数を引数として指定する際は()を書かないので注意。

print(pd.pivot_table(df, index='Sex', columns='Pclass', values='Age',
                     margins=True, aggfunc=np.min))
# Pclass     1     2     3   All
# Sex                           
# female  2.00  2.00  0.75  0.75
# male    0.92  0.67  0.42  0.42
# All     0.92  0.67  0.42  0.42

引数aggfuncには関数のリストを指定することもできる。

print(pd.pivot_table(df, index='Sex', columns='Pclass', values='Age',
                     margins=True, aggfunc=[np.min, np.max]))
#         amin                    amax                  
# Pclass     1     2     3   All     1     2     3   All
# Sex                                                   
# female  2.00  2.00  0.75  0.75  63.0  57.0  63.0  63.0
# male    0.92  0.67  0.42  0.42  80.0  70.0  74.0  80.0
# All     0.92  0.67  0.42  0.42  80.0  70.0  74.0  80.0

指定する関数は一次元配列に対してスカラー値を返す関数であればよい。NumPyの関数に限らず、例えば一次元配列の要素数を返すPythonの組み込み関数len()などでもOK。

print(pd.pivot_table(df, index='Sex', columns='Pclass', values='Age',
                     margins=True, aggfunc=len))
# Pclass      1      2      3    All
# Sex                               
# female   94.0   76.0  144.0  261.0
# male    122.0  108.0  347.0  453.0
# All     186.0  173.0  355.0  714.0

なお、len()を使うとカテゴリごとの出現回数が算出できるが、上述のように、crosstab()という関数も別途用意されている。crosstab()だと結果を1に規格化(正規化)したりできる。

また、次に示す欠損値NaNの取り扱いに注意。

元データの欠損値NaNを除外するか指定: 引数dropna

上述のlen()を使った例の合計値が元データのデータ数よりも少なくなっている。これは欠損値NaNの扱いに起因するもの。

print(len(df))
# 891

print(df.isnull().sum())
# Survived      0
# Pclass        0
# Sex           0
# Age         177
# Fare          0
# Cabin       687
# Embarked      2
# dtype: int64

欠損値NaNの数をカウントする方法については以下の記事を参照。

引数dropnaFalseとするとすべての要素数がカウントされる。

print(pd.pivot_table(df, index='Sex', columns='Pclass', values='Age',
                     margins=True, aggfunc=len, dropna=False))
# Pclass      1      2      3    All
# Sex                               
# female   94.0   76.0  144.0  314.0
# male    122.0  108.0  347.0  577.0
# All     216.0  184.0  491.0  891.0

pivot_table()NaNの処理をまかせると思わぬ結果になることがあるので、元データの時点でNaNを除外するか別の値に置換するなどのケアをしておいたほうが無難かもしれない。

関連カテゴリー

関連記事