How to Handle Imbalanced Datasets in Machine Learning
An imbalanced dataset (a.k.a. an unbalanced or skewed dataset) is one in which one class (or label) is far more prevalent than the other class(es).
For example, in a binary classification problem with two classes (e.g., “churn” vs. “not churn”), an imbalanced dataset might have many more examples of the “not churn” class compared to the “churn” class. This can be a problem when training machine learning models, as the model may be biased towards the majority class and have poor performance when it comes to predicting the minority class.
Various strategies can address an imbalanced dataset, such as undersampling the majority class, oversampling the minority class, or using weighted loss functions. It’s essential to choose the appropriate approach based on the specific characteristics of your dataset and the goals of your model.
This post walks through practical strategies for dealing with imbalanced datasets.
A practical example of the impact of unbalanced datasets
Consider a case where you are building a machine learning model to predict whether a customer will churn (i.e., stop using a company’s products or services). You have a dataset of customer data, and you want to train a model to predict whether a customer will churn based on their past behavior.
Unfortunately, the dataset is unbalanced. Many more customers did not churn (the majority class) compared to those who did churn (the minority class). For example, there may be 10,000 customers who did not churn and only 1,000 customers who did churn.
If you train a machine learning model on this unbalanced dataset, the model may be biased towards the majority class (i.e., it will predict that most customers will not churn). This is because the model has seen more examples of the majority class during training and may have learned to expect the majority class more often.
As a result, the model may have poor performance when predicting the minority class (i.e., customers who will churn). This can be a problem because accurately predicting churn is likely an important goal for the company. For example, if the model predicts that a customer will not churn when they will, the company may not take any action to try to retain that customer, which can result in lost revenue.
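To make the risk concrete, here is a minimal sketch that uses the 10,000 / 1,000 split from the example above and scores a baseline that always predicts "not churn". The placeholder features and the label encoding (1 = churn) are assumptions made purely for illustration.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

# Hypothetical labels matching the example: 10,000 "not churn" (0), 1,000 "churn" (1)
y = np.array([0] * 10_000 + [1] * 1_000)
X = np.zeros((len(y), 1))  # placeholder features; this baseline ignores them

# A baseline that always predicts the most frequent class
baseline = DummyClassifier(strategy="most_frequent").fit(X, y)
y_pred = baseline.predict(X)

accuracy = accuracy_score(y, y_pred)                 # 10,000 of 11,000 labels match
churn_recall = recall_score(y, y_pred, pos_label=1)  # share of churners identified

print(f"Accuracy: {accuracy:.3f}")
print(f"Recall (churn): {churn_recall:.3f}")
```

The baseline reaches roughly 91% accuracy without identifying a single churner, which is why accuracy alone can be misleading on data like this.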
To address this issue, you could try one of the strategies mentioned below to balance the dataset and mitigate the impact of the class imbalance. For example, you could undersample the majority class, oversample the minority class, or use a weighted loss function during training. Yet, it’s essential to be careful when using these techniques, as they can also have drawbacks (e.g., loss of important information and overfitting).
Strategies to avoid bias when you have an imbalanced dataset
We know that an imbalanced dataset will cause the model to be biased toward the majority class. Here are some strategies that you can try to address an unbalanced dataset:
Collect more data: One option is to collect more data for the minority class to balance the dataset. Yet, this is not always practical or workable.
Undersample the majority class: Another option is to undersample the majority class by randomly selecting a smaller subset of the majority class data. This can help balance the dataset, but it may also result in losing important information.
Oversample the minority class: A third option is to oversample the minority class by generating synthetic samples or sampling with replacement from the minority class. This can help balance the dataset but may lead to overfitting if not done carefully.
Use weighted loss functions: Some ML algorithms allow you to specify a weight for each class when training the model. You can assign a higher weight to the minority class to give it more influence on the model.
Adjust evaluation and training: You can use class-specific evaluation metrics (e.g., precision and recall) or train on a class-balanced subsample to mitigate the impact of the class imbalance.
Use a synthetic dataset: Synthetic data is artificially generated data that preserves the statistical properties of the original data. Although it’s not applicable in every situation, it’s helpful in some applications.
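The undersampling and oversampling strategies in the list above can be sketched with the resample utility from scikit-learn. The toy arrays, class sizes, and random seeds below are assumptions chosen only for illustration.

```python
import numpy as np
from sklearn.utils import resample

rng = np.random.default_rng(0)

# Hypothetical toy data: 900 majority (0) rows and 100 minority (1) rows
X = rng.normal(size=(1000, 3))
y = np.array([0] * 900 + [1] * 100)

X_maj, y_maj = X[y == 0], y[y == 0]
X_min, y_min = X[y == 1], y[y == 1]

# Undersample: keep a random subset of the majority class, without replacement
X_maj_down, y_maj_down = resample(
    X_maj, y_maj, replace=False, n_samples=len(y_min), random_state=0
)
X_under = np.vstack([X_maj_down, X_min])
y_under = np.concatenate([y_maj_down, y_min])

# Oversample: draw minority rows with replacement until the classes match
X_min_up, y_min_up = resample(
    X_min, y_min, replace=True, n_samples=len(y_maj), random_state=0
)
X_over = np.vstack([X_maj, X_min_up])
y_over = np.concatenate([y_maj, y_min_up])

print(np.bincount(y_under))  # class counts after undersampling
print(np.bincount(y_over))   # class counts after oversampling
```

In practice, resampling is usually applied to the training split only, so that the test set still reflects the real class distribution.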
There is no one-size-fits-all solution for dealing with an unbalanced dataset. The best approach will depend on your dataset’s specific characteristics and your model’s goals.
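For the weighted-loss strategy, many scikit-learn estimators accept a class_weight argument. The sketch below uses an assumed synthetic dataset from make_classification and a logistic regression chosen only as an example; it prints the weights that the "balanced" setting corresponds to and then trains with them.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical synthetic dataset with roughly a 90/10 class split
X, y = make_classification(
    n_samples=2000, n_features=10, weights=[0.9, 0.1], random_state=0
)

# "balanced" weights are n_samples / (n_classes * count of each class)
classes = np.unique(y)
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y)
print(dict(zip(classes.tolist(), weights.round(2).tolist())))

# Passing class_weight="balanced" applies the same weighting inside the loss
clf = LogisticRegression(class_weight="balanced", max_iter=1000)
clf.fit(X, y)
```

The minority class receives the larger weight, so each misclassified minority example costs more during training than a misclassified majority example.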
Synthetic Minority Oversampling Technique (SMOTE) to overcome the imbalanced dataset issue
Generating synthetic data can help address bias arising from an imbalanced dataset, mainly if you cannot collect more real-world data. However, it’s essential to use caution when using synthetic data, as it can also introduce its own biases if not done carefully.
Here are some potential benefits and drawbacks to consider when using synthetic data to address bias in an imbalanced dataset:
Benefits:
- Synthetic data can help you balance the dataset by generating additional examples of the minority class.
- Synthetic data can be used to preserve privacy or protect sensitive information by generating data that is similar to real-world data but does not contain sensitive information.
Drawbacks:
- Synthetic data is not real-world data, and it may not accurately reflect the actual distribution of the data. This can lead to a model that is not representative of the real-world problem.
- Synthetic data can introduce biases if the generation process is not carefully designed. For example, if you sample from a distribution that is not representative of real-world data, the generated data may be biased.
Overall, generating synthetic data can help address bias in an imbalanced dataset, but it’s essential to use it carefully and consider the potential drawbacks. Other approaches, such as undersampling or oversampling, may also be worth considering depending on the specific characteristics of your dataset and the goals of your model.
Synthetic Minority Oversampling Technique with scikit-learn
Here is an example of how you might use synthetic data with scikit-learn, a popular machine-learning library for Python:
First, you will need to generate synthetic data using a suitable method. One option is to use the SMOTE (Synthetic Minority Oversampling Technique) class from the imblearn library (the imbalanced-learn package), which implements a popular method for generating synthetic data for imbalanced datasets.
from imblearn.over_sampling import SMOTE
# Generate a balanced training set using SMOTE
# (fit_resample replaces the older fit_sample method)
X, y = SMOTE().fit_resample(X_train, y_train)
Here, X and y are the resampled training data and labels: the original samples plus the synthetic minority-class samples. X_train and y_train are the original training data and labels. The SMOTE class generates synthetic samples of the minority class by interpolating between existing minority class samples and their nearest minority-class neighbors in feature space.
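The interpolation idea can be sketched in plain NumPy. This is a simplified illustration of the technique, not the imblearn implementation: the helper name smote_like_samples, its parameters, and the toy points are all hypothetical.

```python
import numpy as np

def smote_like_samples(X_min, n_new, k=3, seed=0):
    """Create n_new synthetic points by interpolating between minority neighbors."""
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority samples
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=2)
    np.fill_diagonal(dists, np.inf)               # a point is not its own neighbor
    neighbors = np.argsort(dists, axis=1)[:, :k]  # indices of the k nearest neighbors

    base = rng.integers(0, len(X_min), size=n_new)            # random starting points
    picked = neighbors[base, rng.integers(0, k, size=n_new)]  # one neighbor per start
    gap = rng.random((n_new, 1))                              # position along the segment
    return X_min[base] + gap * (X_min[picked] - X_min[base])

# Hypothetical minority-class points in two dimensions
X_min = np.array([[1.0, 1.0], [1.2, 0.9], [0.8, 1.1], [1.1, 1.3], [0.9, 0.8]])
X_new = smote_like_samples(X_min, n_new=10)
print(X_new.shape)
```

Each synthetic point lies on the line segment between a real minority sample and one of its neighbors, so the new points stay inside the region the minority class already occupies.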
Next, you can use the resampled data to train a machine learning model using scikit-learn. For example, you might use a decision tree classifier:
from sklearn.tree import DecisionTreeClassifier
# Train a decision tree classifier on the resampled training data
clf = DecisionTreeClassifier()
clf.fit(X, y)
Finally, you can evaluate the performance of the model on the test set, which should be left in its original, non-resampled form, using scikit-learn’s built-in evaluation functions:
from sklearn.metrics import classification_report
# Make predictions on the test set
y_pred = clf.predict(X_test)
# Print a classification report
print(classification_report(y_test, y_pred))
This will print a classification report showing the precision, recall, and F1 score for each class, as well as the overall accuracy of the model. You can use these metrics to evaluate the model’s performance and identify potential biases.
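If you want the per-class numbers programmatically rather than as printed text, classification_report accepts output_dict=True. The labels below are made-up values for illustration, and the encoding 1 = churn is an assumption.

```python
from sklearn.metrics import classification_report

# Hypothetical true labels and predictions: 8 "not churn" (0), 2 "churn" (1)
y_true = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
y_hat = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0]

report = classification_report(y_true, y_hat, output_dict=True)

# Keys are the class labels as strings, plus "accuracy" and the averaged rows
minority = report["1"]
print(f"Churn precision: {minority['precision']:.2f}")
print(f"Churn recall:    {minority['recall']:.2f}")
print(f"Accuracy:        {report['accuracy']:.2f}")
```

On these toy labels the overall accuracy looks healthy while precision and recall for the churn class are noticeably lower, which is the kind of gap the report is meant to surface.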
This is just one example of how synthetic data can be used with scikit-learn. There are many other options and techniques that you can use depending on the specific characteristics of your dataset and the goals of your model.
Pick the right evaluation metric for your skewed dataset
Using a different evaluation metric can be one way to address the issue of a skewed or unbalanced dataset. In particular, class-specific metrics such as precision and recall can be useful when the classes in the dataset are imbalanced.
Precision and recall are evaluation metrics commonly used in classification tasks, often in combination (e.g., as the F1 score). Precision measures the proportion of true positive predictions out of all positive predictions made by the model. Recall measures the proportion of true positive predictions out of all actual positive cases.
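As a quick worked example, these definitions can be checked by hand against scikit-learn on a small set of made-up labels (1 is assumed to be the positive class).

```python
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score

# Hypothetical labels: 1 = positive (minority) class
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_hat = [1, 1, 0, 0, 1, 0, 0, 0, 0, 0]

# For binary labels [0, 1], ravel() yields TN, FP, FN, TP
tn, fp, fn, tp = confusion_matrix(y_true, y_hat).ravel()

precision = tp / (tp + fp)  # correct positives out of all predicted positives
recall = tp / (tp + fn)     # correct positives out of all actual positives
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean of the two

print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")

# The library functions give the same values
print(precision_score(y_true, y_hat), recall_score(y_true, y_hat), f1_score(y_true, y_hat))
```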
For example, suppose you are building a machine learning model to predict whether a customer will churn (i.e., stop using a company’s products or services). Suppose the dataset is imbalanced, with many more customers who did not churn compared to those who did churn. In that case, you may want to use precision and recall as evaluation metrics to better understand the model’s performance on the minority class (customers who did churn).
Using class-specific metrics such as precision and recall can be helpful when the classes in the dataset are imbalanced because they allow you to evaluate the model’s performance on each class separately. This can be useful for identifying potential biases in the model and understanding how well the model can predict the minority class.
Precision and recall are not the only evaluation metrics that can be useful when the classes in the dataset are imbalanced. Other metrics, such as the Area Under the Precision-Recall Curve (AUPRC) and the Matthews correlation coefficient (MCC), can also be helpful in this context. The best metric to use will depend on your dataset’s specific characteristics and your model’s goals.
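Both metrics are available in scikit-learn. The sketch below uses average_precision_score, one common way to summarize the precision-recall curve, and matthews_corrcoef. The labels, scores, and the 0.5 decision threshold are assumptions made for illustration.

```python
import numpy as np
from sklearn.metrics import average_precision_score, matthews_corrcoef

# Hypothetical true labels and predicted probabilities for the positive class
y_true = np.array([0, 0, 0, 0, 0, 0, 1, 0, 1, 1])
y_score = np.array([0.05, 0.10, 0.15, 0.20, 0.30, 0.35, 0.40, 0.60, 0.70, 0.90])

# Average precision summarizes the precision-recall curve from the raw scores
auprc = average_precision_score(y_true, y_score)

# MCC needs hard predictions; 0.5 is an assumed decision threshold
y_hat = (y_score >= 0.5).astype(int)
mcc = matthews_corrcoef(y_true, y_hat)

print(f"AUPRC (average precision): {auprc:.3f}")
print(f"MCC: {mcc:.3f}")
```

Average precision is computed from scores and is independent of any single threshold, while MCC depends on the threshold used to turn scores into labels.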
Here’s an example of how you might use the recall_score function to evaluate the recall of a machine-learning model on a test set:
from sklearn.metrics import recall_score
# Make predictions on the test set
y_pred = model.predict(X_test)
# Calculate the recall of the model
recall = recall_score(y_test, y_pred)
print(f"Model recall: {recall:.3f}")
In this example, y_pred is a numpy array of predictions made by the model on the test set and y_test is a numpy array of true labels for the test set. The recall_score function calculates the recall of the model based on the true labels and predictions and returns the result as a float.
The recall_score function has several optional parameters that you can use to customize the recall calculation. For example, you can set the pos_label parameter to choose which class should be considered the positive class (the default is 1). You can also set the average parameter to control how the recall is averaged across classes (e.g., "micro", "macro", or "weighted").
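Here is a small sketch of how these parameters change the result. The labels are made-up values, with 1 assumed to mean churn.

```python
from sklearn.metrics import recall_score

# Hypothetical labels: 0 = not churn (majority), 1 = churn (minority)
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]
y_hat = [0, 0, 0, 0, 0, 1, 1, 0, 0, 0]

# Recall for the default positive class (1) and for class 0
recall_churn = recall_score(y_true, y_hat)              # pos_label=1 by default
recall_stay = recall_score(y_true, y_hat, pos_label=0)

# Averaging across both classes
recall_macro = recall_score(y_true, y_hat, average="macro")        # unweighted mean
recall_weighted = recall_score(y_true, y_hat, average="weighted")  # weighted by support

print(recall_churn, recall_stay, recall_macro, recall_weighted)
```

The weighted average is pulled toward the majority class, so on imbalanced data the per-class or macro values are usually more informative about the minority class.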
It’s important to note that the recall_score function is just one of many evaluation metrics available in scikit-learn. You can use many other metrics to evaluate your machine learning model’s performance, depending on your dataset’s specific characteristics and the goals of your model.
Final thoughts
An imbalanced dataset is a terrible source for training a machine learning model, especially if accuracy is the metric you’re trying to improve: a model can reach high accuracy simply by always predicting the majority label.
Depending on the situation, you may need to use a different evaluation metric, such as recall. Your domain expertise plays a huge role here.
But if you need accuracy, you can try synthetic data or other methods outlined in this article.