pandas: Random sampling from DataFrame with sample()

Modified: | Tags: Python, pandas

You can get a random sample of rows or columns from a pandas.DataFrame or Series with the sample() method. This is useful for spot-checking the contents of a large pandas.DataFrame or Series.

This article covers the following topics.

  • Default behavior of sample()
  • Rows or columns: axis
  • The number of rows and columns: n
  • The fraction of rows and columns: frac
  • The seed for the random number generator: random_state
  • With or without replacement: replace
  • Reset index: ignore_index, reset_index()

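The examples use the iris dataset, loaded with seaborn's load_dataset(). This function fetches the data from the seaborn-data repository online, so an internet connection is needed the first time.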

import pandas as pd
import seaborn as sns

df = sns.load_dataset("iris")
print(df.shape)
# (150, 5)

The following examples are for pandas.DataFrame, but pandas.Series also has sample(). The usage is the same for both.
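As a quick illustration of the Series case, here is a minimal sketch. It uses a small hand-built Series (hypothetical values) instead of the iris data so that it runs without downloading anything.

```python
import pandas as pd

# A small hypothetical Series standing in for real data
s = pd.Series([10, 20, 30, 40, 50], index=list("abcde"))

# Series.sample() accepts the same n / frac / random_state / replace arguments
picked = s.sample(n=3, random_state=0)
print(picked)
```

Each selected element keeps its original index label, so you can see where it came from.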

You can also inspect a large pandas.DataFrame or Series with head() and tail(), which return the first and last n rows (5 by default).

Default behavior of sample()

By default, one row is randomly selected.

print(df.sample())
#      sepal_length  sepal_width  petal_length  petal_width    species
# 133           6.3          2.8           5.1          1.5  virginica
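Note that even a single sampled row comes back as a DataFrame, not a Series. The sketch below checks this on a small hand-built DataFrame (hypothetical values standing in for iris).

```python
import pandas as pd

# Small stand-in for the iris DataFrame (hypothetical values)
df = pd.DataFrame({
    "sepal_length": [5.1, 4.9, 4.7, 4.6, 5.0],
    "species": ["setosa"] * 5,
})

row = df.sample()
# The result is still a DataFrame with one row, not a Series
print(type(row).__name__, row.shape)  # DataFrame (1, 2)
```

If you need the row as a Series, take it out with row.iloc[0].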

Rows or columns: axis

If the axis parameter is set to 1 (or 'columns'), a column is randomly selected instead of a row.

print(df.sample(axis=1))
#      petal_width
# 0            0.2
# 1            0.2
# 2            0.2
# 3            0.2
# 4            0.2
# ..           ...
# 145          2.3
# 146          1.9
# 147          2.0
# 148          2.3
# 149          1.8
# 
# [150 rows x 1 columns]
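axis can be combined with the other parameters. This sketch, using a small hypothetical DataFrame, selects two distinct columns with axis="columns", which is equivalent to axis=1.

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]})

# axis="columns" is the same as axis=1; n sets how many columns to pick
cols = df.sample(n=2, axis="columns", random_state=1)
print(cols.columns.tolist())
```

All rows are kept; only the set and order of columns changes.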

The number of rows and columns: n

The number of rows or columns to be selected can be specified in the n parameter.

print(df.sample(n=3))
#     sepal_length  sepal_width  petal_length  petal_width     species
# 29           4.7          3.2           1.6          0.2      setosa
# 67           5.8          2.7           4.1          1.0  versicolor
# 18           5.7          3.8           1.7          0.3      setosa
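A typical use of n is peeking at a few random rows of a large table. The sketch below assumes a hypothetical 1,000-row DataFrame. Because sampling is without replacement by default, the selected index labels are all distinct.

```python
import pandas as pd

# Hypothetical "large" DataFrame with 1,000 rows
df = pd.DataFrame({"id": range(1000), "value": [i * 0.5 for i in range(1000)]})

# Spot-check five random rows; without replacement no row appears twice
peek = df.sample(n=5, random_state=0)
print(peek)
print(peek.index.is_unique)  # True
```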

The fraction of rows and columns: frac

The fraction of rows or columns to select can be specified in the frac parameter, where frac=1 means 100%. With the 150-row iris data, frac=0.04 selects 150 * 0.04 = 6 rows.

print(df.sample(frac=0.04))
#      sepal_length  sepal_width  petal_length  petal_width     species
# 15            5.7          4.4           1.5          0.4      setosa
# 66            5.6          3.0           4.5          1.5  versicolor
# 131           7.9          3.8           6.4          2.0   virginica
# 64            5.6          2.9           3.6          1.3  versicolor
# 81            5.5          2.4           3.7          1.0  versicolor
# 137           6.4          3.1           5.5          1.8   virginica
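Setting frac=1 returns every row exactly once in random order, which is a common way to shuffle a DataFrame. The sketch below uses a small hypothetical 10-row DataFrame.

```python
import pandas as pd

df = pd.DataFrame({"x": range(10)})

# frac=1 returns all rows, each exactly once, in random order (a shuffle)
shuffled = df.sample(frac=1, random_state=42)
print(sorted(shuffled.index) == list(df.index))  # True

# frac=0.3 of 10 rows selects 3 rows
print(len(df.sample(frac=0.3)))  # 3
```

Add ignore_index=True (or reset_index(drop=True)) if the shuffled result should be renumbered from 0.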

You cannot specify n and frac at the same time; doing so raises a ValueError.

# print(df.sample(n=3, frac=0.04))
# ValueError: Please enter a value for `frac` OR `n`, not both
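If n and frac might both come from user input, you can catch the error explicitly. This is a minimal sketch on a hypothetical DataFrame; the exact message text may vary between pandas versions.

```python
import pandas as pd

df = pd.DataFrame({"x": range(10)})

raised = False
try:
    df.sample(n=3, frac=0.5)
except ValueError as e:
    # n and frac are mutually exclusive, so pandas rejects the call
    raised = True
    print("ValueError:", e)
```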

The seed for the random number generator: random_state

The seed for the random number generator can be specified in the random_state parameter. Passing the same random_state returns the same rows or columns every time, which makes the result reproducible.

print(df.sample(n=3, random_state=0))
#      sepal_length  sepal_width  petal_length  petal_width     species
# 114           5.8          2.8           5.1          2.4   virginica
# 62            6.0          2.2           4.0          1.0  versicolor
# 33            5.5          4.2           1.4          0.2      setosa

print(df.sample(n=3, random_state=0))
#      sepal_length  sepal_width  petal_length  petal_width     species
# 114           5.8          2.8           5.1          2.4   virginica
# 62            6.0          2.2           4.0          1.0  versicolor
# 33            5.5          4.2           1.4          0.2      setosa
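random_state also accepts a numpy.random.RandomState object. In the sketch below (a hypothetical 100-row DataFrame and a helper named draw_twice, introduced only for illustration), one RandomState is shared across two calls: its state advances after each call, so the two samples usually differ, yet re-creating the RandomState with the same seed reproduces the whole sequence.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": range(100)})

def draw_twice(seed):
    # Hypothetical helper: one RandomState shared by two sample() calls.
    # Each call advances its state, so the draws usually differ,
    # but the pair is reproducible for the same seed.
    rng = np.random.RandomState(seed)
    first = df.sample(n=3, random_state=rng)
    second = df.sample(n=3, random_state=rng)
    return first.index.tolist(), second.index.tolist()

print(draw_twice(0) == draw_twice(0))  # True
```

An integer seed, by contrast, restarts the generator on every call, which is why the two identical calls above return the same rows.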

With or without replacement: replace

If the replace parameter is set to True, rows or columns are sampled with replacement, so the same row or column may be selected more than once.

The default value for replace is False (sampling without replacement).

print(df.head(3))
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa

print(df.head(3).sample(n=3, replace=True))
#    sepal_length  sepal_width  petal_length  petal_width species
# 0           5.1          3.5           1.4          0.2  setosa
# 0           5.1          3.5           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa

With replace=True, n may be greater than the original number of rows or columns, and frac may be greater than 1.

print(df.head(3).sample(n=5, replace=True))
#    sepal_length  sepal_width  petal_length  petal_width species
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 0           5.1          3.5           1.4          0.2  setosa
# 0           5.1          3.5           1.4          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa

print(df.head(3).sample(frac=2, replace=True))
#    sepal_length  sepal_width  petal_length  petal_width species
# 2           4.7          3.2           1.3          0.2  setosa
# 1           4.9          3.0           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
# 0           5.1          3.5           1.4          0.2  setosa
# 2           4.7          3.2           1.3          0.2  setosa
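The contrast with the default replace=False can be sketched as follows, on a hypothetical 3-row DataFrame: asking for more rows than exist works with replacement but raises a ValueError without it.

```python
import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3]})

# With replace=True, asking for more rows than exist is allowed
boot = df.sample(n=5, replace=True, random_state=0)
print(len(boot))  # 5

# Without replacement (the default), the same request raises ValueError
raised = False
try:
    df.sample(n=5)
except ValueError as e:
    raised = True
    print("ValueError:", e)
```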

Reset index: ignore_index, reset_index()

To renumber the index of the result as 0, 1, ..., n-1, set the ignore_index parameter of sample() to True.

print(df.sample(n=3, ignore_index=True))
#    sepal_length  sepal_width  petal_length  petal_width     species
# 0           5.2          2.7           3.9          1.4  versicolor
# 1           6.3          2.5           4.9          1.5  versicolor
# 2           5.7          3.0           4.2          1.2  versicolor
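The sketch below, on a hypothetical DataFrame whose values differ from its index, confirms that the original labels are discarded. It assumes pandas 1.3.0 or later.

```python
import pandas as pd

df = pd.DataFrame({"x": range(10, 20)})

# ignore_index=True (pandas >= 1.3.0) renumbers the result 0..n-1
out = df.sample(n=3, ignore_index=True, random_state=0)
print(out.index.tolist())  # [0, 1, 2]
```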

The ignore_index parameter was added in pandas 1.3.0. In earlier versions, use the reset_index() method instead. Set its drop parameter to True to discard the original index rather than keep it as a new column.

print(df.sample(n=3).reset_index(drop=True))
#    sepal_length  sepal_width  petal_length  petal_width    species
# 0           4.9          3.1           1.5          0.2     setosa
# 1           7.9          3.8           6.4          2.0  virginica
# 2           6.3          2.8           5.1          1.5  virginica
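With the same seed, both approaches give identical results, as this sketch on a hypothetical DataFrame shows (assuming pandas 1.3.0 or later for ignore_index). It also shows what happens when drop is left at its default of False.

```python
import pandas as pd

df = pd.DataFrame({"x": range(10, 20)})

a = df.sample(n=3, random_state=7, ignore_index=True)
b = df.sample(n=3, random_state=7).reset_index(drop=True)
print(a.equals(b))  # True

# Without drop=True, the old labels are kept in a new "index" column
print(df.sample(n=3, random_state=7).reset_index().columns.tolist())  # ['index', 'x']
```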
