Is your training dataset balanced?

How to Check if a Dataset Is Imbalanced

A dataset is imbalanced if the classes within the dataset are not evenly distributed.

For example, if you are building a machine learning model to classify whether an email is spam, and 99% of the emails in your dataset are not spam, then your dataset is imbalanced.

Spam classification is a relatively low-stakes example, but ML models trained on imbalanced datasets can have far more severe consequences.

For example, if a machine learning model is trained on a dataset where most patients do not have a specific disease, it may be less accurate at predicting the disease in patients with it. This could lead to misdiagnosis or inadequate treatment.

Let’s take another example. Suppose an ML model is used to predict the likelihood of a defendant reoffending, and the dataset is imbalanced with many non-offenders; the model may be biased towards predicting a low risk of reoffending. This could result in lenient sentences for high-risk individuals and unfairly harsh sentences for low-risk individuals.

Finally, imbalanced datasets can also have broader social impacts. For instance, if a machine learning model is used to predict which job candidates are most likely to succeed, and the dataset is imbalanced with a majority of successful candidates coming from a specific group (e.g., men, a particular racial group, etc.), the model may be biased towards predicting success for candidates from that group. This could perpetuate existing societal biases and contribute to unequal opportunities.

It is essential to be aware of the potential impacts of imbalanced datasets and take steps to address them when building machine learning models. This article focuses on various techniques to identify imbalanced datasets.

Grab your aromatic coffee (or tea) and get ready…!

Use visual inspection to find imbalanced datasets.

You can visually inspect the distribution of classes in your dataset by plotting a bar chart or histogram. If one class is significantly more frequent than the others, your dataset is likely imbalanced.

Use a library such as Matplotlib to create a bar chart or histogram to visualize the class distribution in your dataset. For example:

import matplotlib.pyplot as plt
import numpy as np

# Assume that y is a list or array containing the class labels of your dataset
unique_labels, counts = np.unique(y, return_counts=True)

plt.bar(unique_labels, counts)
plt.xlabel('Class label')
plt.ylabel('Number of samples')
plt.title('Class distribution in dataset')
plt.show()
Python

This will create a bar chart showing the number of samples for each class label.

Analyze the class distribution.

You can also check the class distribution by calculating the proportion of samples in each class. For example, if your dataset has 1000 samples, with 900 belonging to class 0 and 100 to class 1, your dataset is imbalanced.

Use the numpy library to calculate the proportion of samples in each class. For example:

import numpy as np

# Assume that y is a list or array containing the class labels of your dataset
unique_labels, counts = np.unique(y, return_counts=True)
proportions = counts / len(y)

print(proportions)
Python

This will print the proportion of samples in each class.
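
If you prefer a single number, you can turn these counts into an imbalance ratio: the size of the largest class divided by the size of the smallest. The snippet below is a minimal sketch; the is_imbalanced helper and the 4:1 threshold are arbitrary choices for illustration, not a standard rule:

import numpy as np

# Hypothetical helper: flag a dataset as imbalanced when the majority class
# outnumbers the minority class by more than a chosen ratio (threshold is arbitrary)
def is_imbalanced(y, ratio_threshold=4.0):
    _, counts = np.unique(y, return_counts=True)
    imbalance_ratio = counts.max() / counts.min()
    return imbalance_ratio, imbalance_ratio > ratio_threshold

y = np.array([0] * 900 + [1] * 100)
ratio, flag = is_imbalanced(y)
print(f'Imbalance ratio: {ratio:.1f}, imbalanced: {flag}')  # Imbalance ratio: 9.0, imbalanced: True
Python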

Watch out for biases in your performance metrics.

Another way to identify an imbalanced dataset is by looking at the performance metrics of your machine-learning model. If your model consistently performs well on one class but poorly on another, your dataset is likely imbalanced.

Widely used libraries such as scikit-learn can calculate performance metrics for your machine-learning model. For example:

from sklearn.metrics import accuracy_score, f1_score

# Assume that y_true is a list or array containing the true class labels and y_pred is a list or array containing the predicted class labels
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

print(f'Accuracy: {accuracy:.2f}')
print(f'F1 score: {f1:.2f}')
Python

This will print the accuracy and F1 score of the model. The dataset is likely imbalanced if the performance is significantly better on one class than the other.
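
Note that overall accuracy and a single F1 score can hide this problem, because they average over all classes. To see per-class performance, scikit-learn's classification_report and confusion_matrix break the results down by class. A minimal sketch, assuming y_true and y_pred as above:

from sklearn.metrics import classification_report, confusion_matrix

# Assume that y_true and y_pred are defined as above
# Per-class precision, recall, and F1 score
print(classification_report(y_true, y_pred))

# Confusion matrix: rows are true classes, columns are predicted classes
print(confusion_matrix(y_true, y_pred))
Python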

How to use Shannon entropy to measure the imbalance

Shannon entropy is a measure of the impurity or uncertainty of a system. In the context of imbalanced datasets, Shannon entropy can be used as a measure of the imbalance between the classes. A dataset with high Shannon entropy is more balanced, while a dataset with low Shannon entropy is more imbalanced.

To calculate Shannon entropy, you can use the following formula:

H = -∑ p(x) log(p(x))

where H is the Shannon entropy, p(x) is the proportion of samples belonging to class x, and the sum ∑ runs over all classes.

Here is an example of how to calculate Shannon entropy in Python:

import numpy as np

# Assume that y is a list or array containing the class labels of your dataset
unique_labels, counts = np.unique(y, return_counts=True)
proportions = counts / len(y)

# Calculate Shannon entropy
shannon_entropy = -np.sum(proportions * np.log(proportions))

print(f'Shannon entropy: {shannon_entropy:.2f}')
Python

This will print the Shannon entropy of the dataset. A higher Shannon entropy value indicates a more balanced dataset, while a lower value indicates a more imbalanced dataset.
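
Because the maximum possible entropy depends on the number of classes (it equals log(k) for k classes), it is often easier to interpret a normalized version, sometimes called balance or Shannon equitability, which lies between 0 and 1. A minimal sketch of this normalization:

import numpy as np

# Assume that y is a list or array containing the class labels of your dataset
unique_labels, counts = np.unique(y, return_counts=True)
proportions = counts / len(y)

# Normalize by the maximum possible entropy, log(k), where k is the number of classes
shannon_entropy = -np.sum(proportions * np.log(proportions))
balance = shannon_entropy / np.log(len(unique_labels))

print(f'Balance (near 0 = highly imbalanced, 1 = perfectly balanced): {balance:.2f}')
Python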

Here is a full example of how to use Shannon entropy to measure imbalance using a synthetic dataset:

import numpy as np

# Generate synthetic imbalanced dataset
num_samples = 1000
num_class_0 = 900
num_class_1 = 100

# Class 0 samples
X0 = np.random.normal(0, 1, size=(num_class_0, 2))
y0 = np.zeros(num_class_0)

# Class 1 samples
X1 = np.random.normal(1, 1, size=(num_class_1, 2))
y1 = np.ones(num_class_1)

# Combine class 0 and class 1 samples
X = np.concatenate((X0, X1))
y = np.concatenate((y0, y1))

# Calculate Shannon entropy
unique_labels, counts = np.unique(y, return_counts=True)
proportions = counts / len(y)
shannon_entropy = -np.sum(proportions * np.log(proportions))

print(f'Shannon entropy: {shannon_entropy:.2f}')
Python

In this example, we generated a synthetic dataset with 1000 samples, 900 belonging to class 0 and 100 to class 1. We then calculated the Shannon entropy of the dataset using the formula provided above. With the natural logarithm used by np.log, the Shannon entropy is roughly 0.33, well below the maximum of ln(2) ≈ 0.69 for a two-class dataset, which indicates a relatively imbalanced dataset.

You can also use the visual inspection and class distribution techniques mentioned earlier to confirm that the dataset is imbalanced.

For example, use Matplotlib to create a bar chart or histogram of the class distribution or calculate the proportion of samples in each class.

Final thoughts

Imbalanced datasets are a real problem for ML engineers. For ML systems in production, the real challenge is that we need to keep monitoring the data for newly emerging rare events.

As we continuously retrain models in production, new training data may become skewed towards a single class, and such imbalance can cause serious issues.
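
One simple way to catch this is to check the class proportions of each new training batch and raise an alert when the minority share drops too low. The sketch below is only an illustration; the check_new_batch helper and the 10% threshold are arbitrary assumptions, not a recommendation:

import numpy as np

def check_new_batch(y_new, minority_share_threshold=0.10):
    """Warn if the rarest class in a new batch falls below a chosen share."""
    _, counts = np.unique(y_new, return_counts=True)
    minority_share = counts.min() / counts.sum()
    if minority_share < minority_share_threshold:
        print(f'Warning: minority class share is {minority_share:.2%}; consider rebalancing.')
    return minority_share

# Example: a retraining batch that has drifted towards class 0
y_new = np.array([0] * 980 + [1] * 20)
check_new_batch(y_new)
Python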

Related: Data challenges in Production ML Systems.

We must do everything we can to ensure that the dataset is not skewed. For this, we can visually inspect the data, study the class distribution, or monitor per-class performance metrics.

A different approach is to use the Shannon entropy value: a single number that tells us whether the dataset is potentially skewed towards a specific class.


Thanks for the read, friend. It seems you and I have lots of common interests. Say Hi to me on LinkedIn, Twitter, and Medium.

