MIT researchers have developed a technique to identify and remove specific data points in training datasets that disproportionately contribute to a model's errors on minority subgroups. This approach aims to address biases in machine learning models by focusing on problematic data points, ultimately improving the fairness and accuracy of predictions for underrepresented groups.
Machine learning models often struggle to make accurate predictions for individuals underrepresented in their training datasets. For example, a model trained predominantly on data from male patients may fail to predict effective treatment options for female patients with the same chronic condition when applied in real-world settings.
One solution is to balance the dataset by removing data points until all subgroups are equally represented. However, this approach can require removing large amounts of data, which negatively impacts the model’s overall performance.
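For illustration only, here is a minimal sketch of that kind of naive balancing, assuming hypothetical NumPy arrays of features `X` and subgroup labels `groups`; downsampling every group to the size of the smallest one can discard a large share of the data:

```python
import numpy as np

def balance_by_downsampling(X, groups, seed=0):
    """Naive balancing: keep only as many examples per group as the smallest group has."""
    rng = np.random.default_rng(seed)
    unique_groups = np.unique(groups)
    target = min(np.sum(groups == g) for g in unique_groups)  # smallest subgroup size
    keep = []
    for g in unique_groups:
        idx = np.flatnonzero(groups == g)
        keep.append(rng.choice(idx, size=target, replace=False))
    keep = np.sort(np.concatenate(keep))
    return X[keep], groups[keep]

# Hypothetical example: 9,000 majority vs. 1,000 minority examples.
X = np.random.randn(10_000, 5)
groups = np.array([0] * 9_000 + [1] * 1_000)
X_bal, g_bal = balance_by_downsampling(X, groups)
print(len(X) - len(X_bal))  # 8,000 examples thrown away just to equalize the groups
```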
To address this challenge, MIT researchers have developed a new technique that selectively removes the specific data points most responsible for a model’s errors on minority subgroups. Unlike traditional balancing methods, this technique eliminates far fewer data points, preserving the model’s overall accuracy while enhancing its ability to perform well for underrepresented groups.
Beyond improving fairness, this method can also uncover hidden sources of bias in unlabeled datasets, which are far more common than labeled ones in many applications. By identifying these biases, the technique provides a pathway to refine datasets and improve model performance without requiring exhaustive labeling efforts.
This method could be combined with other strategies to enhance fairness in high-stakes applications, such as medical diagnostics. In the future, it may help ensure that underrepresented patients aren’t misdiagnosed due to biased AI models, making these systems more equitable and reliable.
Many other algorithms that try to address this issue assume each datapoint matters as much as every other datapoint. In this paper, we are showing that assumption is not true. There are specific points in our dataset that are contributing to this bias, and we can find those data points, remove them, and get better performance.
Kimia Hamidieh, Electrical Engineering and Computer Science (EECS) Graduate Student, Massachusetts Institute of Technology
The research will be presented at the Conference on Neural Information Processing Systems.
Removing Bad Examples
Machine-learning models are often trained on massive datasets sourced from the internet. Due to their size, these datasets can’t be manually curated, leaving room for problematic examples that can degrade model performance. Additionally, researchers have found that certain data points have a disproportionate impact on how well a model performs on specific tasks.
Building on these insights, the MIT researchers developed a method to identify and remove problematic data points that lead to errors in minority subgroups. Their goal is to address the issue of worst-group error, which arises when a model underperforms for underrepresented subgroups in the training data.
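For reference, worst-group accuracy is simply the model's accuracy on whichever subgroup it handles worst. A minimal sketch, assuming hypothetical arrays of true labels, predictions, and subgroup labels:

```python
import numpy as np

def worst_group_accuracy(y_true, y_pred, groups):
    """Accuracy on the single subgroup where the model performs worst."""
    accs = []
    for g in np.unique(groups):
        mask = groups == g
        accs.append(np.mean(y_pred[mask] == y_true[mask]))
    return min(accs)

# Overall accuracy can look high while one subgroup fares much worse.
y_true = np.array([1, 1, 1, 1, 0, 0, 0, 0])
y_pred = np.array([1, 1, 1, 1, 0, 0, 1, 1])
groups = np.array([0, 0, 0, 0, 1, 1, 1, 1])
print(worst_group_accuracy(y_true, y_pred, groups))  # 0.5, although overall accuracy is 0.75
```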
Their new technique leverages previous work, particularly a method called TRAK, which identifies the most influential training examples for a specific model output.
By analyzing the model’s incorrect predictions on minority subgroups, the researchers use TRAK to pinpoint the training examples that contributed most to those errors. This targeted approach minimizes the number of data points that need to be removed, preserving overall performance while improving fairness.
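The exact procedure is described in the researchers' paper; as a rough sketch of the aggregation idea only, assume a hypothetical matrix `influence` of TRAK-style attribution scores with one row per training example and one column per held-out example. Summing the columns that correspond to misclassified minority-group examples yields a ranking of training points by how much they contributed to those errors:

```python
import numpy as np

def rank_training_points(influence, errors_mask):
    """
    influence:   (n_train, n_test) matrix of attribution scores
                 (a hypothetical stand-in for TRAK outputs).
    errors_mask: boolean vector marking held-out examples the model
                 got wrong on the minority subgroup.
    Returns training indices sorted from most to least responsible
    for those errors.
    """
    contribution = influence[:, errors_mask].sum(axis=1)
    return np.argsort(contribution)[::-1]

# Hypothetical usage: rank 50,000 training points against 120 bad predictions.
influence = np.random.randn(50_000, 1_000)
errors_mask = np.zeros(1_000, dtype=bool)
errors_mask[:120] = True
ranked = rank_training_points(influence, errors_mask)
to_remove = ranked[:500]  # candidate points to drop before retraining
```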
By aggregating this information across bad test predictions in the right way, we are able to find the specific parts of the training that are driving worst-group accuracy down overall.
Andrew Ilyas, Stein Fellow, Stanford University
After identifying the problematic samples, the researchers remove those specific data points and retrain the model on the remaining dataset.
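Concretely, the remove-and-retrain step amounts to dropping those indices and refitting the model. A minimal sketch with a stand-in scikit-learn classifier and the hypothetical `to_remove` indices from the ranking step above:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Hypothetical training data and indices flagged by the attribution step.
X_train = np.random.randn(50_000, 20)
y_train = np.random.randint(0, 2, size=50_000)
to_remove = np.arange(500)  # placeholder for the ranked indices

# Retrain on everything except the flagged examples.
keep = np.setdiff1d(np.arange(len(X_train)), to_remove)
model = LogisticRegression(max_iter=1000).fit(X_train[keep], y_train[keep])
```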
This targeted approach strikes a balance between fairness and performance. Typically, more data improves a model's overall accuracy, but by removing only the samples that contribute to failures for minority subgroups, the model retains its overall performance while significantly improving its accuracy for underrepresented groups. This method ensures the model becomes more equitable without sacrificing its effectiveness.
A More Accessible Approach
The MIT researchers tested their method across three machine-learning datasets and found it consistently outperformed other techniques. In one case, their approach improved worst-group accuracy while removing about 20,000 fewer training samples compared to a standard data-balancing method. It also achieved higher accuracy than methods requiring modifications to a model’s internal architecture.
A key advantage of this method is its simplicity and flexibility. Since it focuses on adjusting the dataset rather than altering the model itself, it is easier for practitioners to implement and compatible with a wide range of machine-learning models.
Additionally, the technique can be applied even when bias in the dataset is unknown or when subgroup labels are unavailable. By identifying data points that most influence the features the model is learning, researchers can uncover the variables driving its predictions. This insight can help pinpoint hidden biases and improve model fairness without needing detailed subgroup annotations.
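As an illustration of how a practitioner might start probing for hidden subgroups without labels (this is not the authors' procedure, just one plausible workflow), the model's misclassified validation examples can be clustered in its embedding space and each cluster inspected by hand for a shared attribute:

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical embeddings of misclassified validation examples
# (e.g., the model's penultimate-layer features).
embeddings = np.random.randn(300, 64)

# Group the failures into a few candidate "hidden subgroups" for manual inspection.
clusters = KMeans(n_clusters=5, n_init=10, random_state=0).fit_predict(embeddings)
for c in range(5):
    idx = np.flatnonzero(clusters == c)
    print(f"cluster {c}: {len(idx)} failures, e.g. examples {idx[:5]}")
```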
“This is a tool anyone can use when they are training a machine-learning model. They can look at those datapoints and see whether they are aligned with the capability they are trying to teach the model,” added Hamidieh.
Detecting unknown subgroup bias using this technique requires some intuition about which groups to investigate. The researchers aim to validate and expand on this capability through future human studies, enabling a deeper understanding of how the method can uncover hidden biases.
Looking ahead, the team plans to refine the performance and reliability of their approach while prioritizing accessibility and ease of use. By simplifying the method for practitioners, they hope to facilitate its deployment in real-world environments, where addressing biases in machine-learning models can have significant impacts.
Ilyas added, “When you have tools that let you critically look at the data and figure out which datapoints are going to lead to bias or other undesirable behavior, it gives you a first step toward building models that are going to be more fair and more reliable.”
This research is partially supported by the US Defense Advanced Research Projects Agency and the National Science Foundation.