Machine learning models can fail when they try to make predictions for individuals who were underrepresented in the datasets they were trained on.
For example, a model that predicts the best treatment option for someone with a chronic disease might be trained on a dataset that contains mostly male patients. That model could make incorrect predictions for female patients when deployed in a hospital setting.
To improve outcomes, engineers can try to balance the training dataset by removing data points until all subgroups are equally represented. While dataset balancing is promising, it often requires removing a large amount of data, which hurts the model's overall performance.
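For intuition, this kind of balancing can be sketched as below. This is a minimal illustration, not part of the researchers' method; it assumes a pandas DataFrame with a hypothetical "subgroup" column.

```python
import pandas as pd

def balance_by_subgroup(df: pd.DataFrame, group_col: str = "subgroup",
                        seed: int = 0) -> pd.DataFrame:
    """Downsample every subgroup to the size of the smallest one."""
    smallest = df[group_col].value_counts().min()
    balanced = (
        df.groupby(group_col, group_keys=False)
          .apply(lambda g: g.sample(n=smallest, random_state=seed))
    )
    return balanced.reset_index(drop=True)

# Example: a dataset with 90 percent male and 10 percent female patients would
# shrink to twice the size of the female subgroup, discarding most male records.
```

The cost is visible in the example comment: balancing can throw away the majority of the data, which is the drawback the new technique aims to avoid.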
MIT researchers have developed a new technique that identifies and removes the specific points in a training dataset that contribute most to a model's failures on minority subgroups. By removing far fewer data points than other approaches, this technique maintains the model's overall accuracy while improving its performance on underrepresented groups.
Additionally, the technique can discover hidden sources of bias in a training dataset that lacks labels. In many applications, unlabeled data is much more common than labeled data.
This method could be combined with other approaches to improve the fairness of machine learning models deployed in high-stakes situations. For example, it might someday help ensure that underrepresented patients are not misdiagnosed due to a biased AI model.
“Many other algorithms that try to address this issue assume each data point matters as much as every other data point. In this paper, we are showing that assumption is not true. There are specific points in our dataset that are contributing to this bias, and we can find those data points, remove them, and get better performance,” says Kimia Hamidieh, an electrical engineering and computer science (EECS) graduate student at MIT and co-lead author of a paper on this technique.
She wrote the paper with co-lead authors Saachi Jain PhD '24 and fellow EECS graduate student Kristian Georgiev; Andrew Ilyas MEng '18, PhD '23, a Stein Fellow at Stanford University; and senior authors Marzyeh Ghassemi, an associate professor of EECS and a member of the Institute for Medical Engineering and Sciences and the Laboratory for Information and Decision Systems, and Aleksander Madry, the Cadence Design Systems Professor at MIT. The research will be presented at the Conference on Neural Information Processing Systems.
Remove bad examples
Machine learning models are often trained using huge datasets gathered from many sources across the internet. These datasets are far too large to be carefully curated by hand, so they may contain bad examples that hurt model performance.
Scientists also know that some data points affect a model's performance on certain downstream tasks more than others.
The MIT researchers combined these two ideas into an approach that identifies and removes these problematic data points. They seek to solve a problem known as worst-group error, which occurs when a model underperforms on minority subgroups in a training dataset.
The researchers' new technique builds on prior work in which they introduced a method called TRAK, which identifies the most important training examples for a specific model output.
For this new technique, they take incorrect predictions the model made about minority subgroups and use TRAK to identify which training examples contributed the most to those incorrect predictions.
“By aggregating this information across bad test predictions in the right way, we are able to find the specific parts of the training that are driving worst-group accuracy down overall,” explains Ilyas.
They then remove those specific samples and retrain the model using the remaining data.
Since having more data usually yields better overall performance, removing only the samples that drive worst-group failures maintains the model's overall accuracy while boosting its performance on minority subgroups.
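The workflow described above — score each training example's contribution to the worst group's mistakes, drop the highest-scoring examples, and retrain — can be roughly sketched as follows. This is an illustrative sketch only, not the authors' implementation: the `attribution_scores` array is assumed to come from a TRAK-style attribution step, and `compute_trak_scores` and `train_model` are hypothetical helpers.

```python
import numpy as np

def drop_most_harmful(n_train: int, worst_group_errors, attribution_scores,
                      k: int) -> list[int]:
    """Rank training examples by their aggregate contribution to the worst
    group's wrong predictions and return the indices to keep.

    attribution_scores: array of shape (n_test, n_train), where entry (i, j)
    estimates how much training example j pushed the model toward test
    prediction i (a TRAK-style influence estimate).
    worst_group_errors: indices of wrong predictions on the minority subgroup.
    """
    # Aggregate each training example's contribution across all bad
    # predictions made on the minority subgroup.
    harm = attribution_scores[worst_group_errors].sum(axis=0)
    # Indices of the k most harmful training examples.
    worst_k = set(np.argsort(harm)[-k:].tolist())
    return [i for i in range(n_train) if i not in worst_k]

# Hypothetical usage:
# scores = compute_trak_scores(model, train_set, test_set)      # assumed helper
# keep = drop_most_harmful(len(train_set), minority_error_idx, scores, k=500)
# model = train_model([train_set[i] for i in keep])              # assumed helper
```

The key design choice this reflects is that only a targeted slice of the training set is removed, rather than downsampling entire subgroups as in balancing.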
A more accessible approach
Across three machine learning datasets, their method outperformed multiple techniques. In one instance, it boosted worst-group accuracy while removing about 20,000 fewer training samples than a conventional data-balancing method. Their technique also achieved higher accuracy than methods that require making changes to the inner workings of a model.
Because the MIT method instead involves changing a dataset, it would be easier for a practitioner to use and can be applied to many types of models.
It can also be used when bias is unknown because subgroups in a training dataset are not labeled. By identifying the data points that contribute most to a feature the model is learning, researchers can understand the variables it is using to make a prediction.
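As an illustration of that kind of inspection, the sketch below ranks training examples by their attribution score for a single prediction so a practitioner can look for an unexpected shared attribute among them. The score array and metadata list are hypothetical stand-ins, not part of the researchers' tooling.

```python
import numpy as np

def top_contributors(attribution_scores: np.ndarray, test_idx: int,
                     train_metadata: list, top_n: int = 10) -> None:
    """Show the training examples that most influenced one test prediction,
    so a practitioner can inspect them for a hidden shared attribute."""
    row = attribution_scores[test_idx]        # scores over all training points
    top = np.argsort(row)[-top_n:][::-1]      # most influential first
    for j in top:
        print(f"train example {j}: score={row[j]:.4f}, info={train_metadata[j]}")
```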
“This is a tool anyone can use when they are training a machine learning model. They can look at those data points and see whether they are aligned with the capability they are trying to teach the model,” says Hamidieh.
Using the technique to detect unknown subgroup bias would require some intuition about which groups to look for, so the researchers hope to validate it and explore it more fully through future human studies.
They also want to improve the performance and reliability of their technique and ensure the method is accessible and easy to use for practitioners who could someday deploy it in real-world settings.
“Having tools that let you look critically at the data and figure out which data points are going to lead to bias or other undesirable behavior is a first step toward building models that will be fairer and more reliable,” says Ilyas.
This work is funded, in part, by the National Science Foundation and the U.S. Defense Advanced Research Projects Agency.