How do neural networks learn? A mathematical formula explains how they detect relevant patterns

The insights, published in the journal Science, can also be used to make other types of machine learning architectures more effective.
Ioana Patringenaru and Daniel Kane
March 13, 2024

Neural networks have been powering breakthroughs in artificial intelligence, including the large language models that are now being used in a wide range of applications, from finance to human resources to healthcare. But these networks remain a black box whose inner workings engineers and scientists struggle to understand. Now, a team led by data and computer scientists at the University of California San Diego has given neural networks the equivalent of an X-ray to uncover how they actually learn. 

The researchers found that a formula used in statistical analysis provides a streamlined mathematical description of how neural networks, such as GPT-2, a precursor to ChatGPT, learn relevant patterns in data, known as features. This formula also explains how neural networks use these relevant patterns to make predictions. 

“We are trying to understand neural networks from first principles,” said Daniel Beaglehole, a PhD student in the UC San Diego Department of Computer Science and Engineering and co-first author of the study. “With our formula, one can simply interpret which features the network is using to make predictions.”  

Adit Radhakrishnan, co-first author of the study and a postdoctoral fellow at Harvard who worked on the paper as an MIT EECS PhD student funded by the Schmidt Center, added: “We showed that neural networks, unlike other machine learning models, automatically implement this formula to identify features most relevant for prediction.”

The team presented their findings in the March 7 issue of the journal Science.

Why does it matter how neural networks make predictions? AI-powered tools are now pervasive in everyday life. Banks use them to approve loans. Hospitals use them to analyze medical data, such as X-rays and MRIs. Companies use them to screen job applicants. But it is currently difficult to understand the mechanism neural networks use to make decisions, or how biases in the training data might affect those decisions.

“If you don’t understand how neural networks learn, it’s very hard to establish whether neural networks produce reliable, accurate, and appropriate responses,” said Mikhail Belkin, the paper’s corresponding author and a professor at the UC San Diego Halicioglu Data Science Institute. “This is particularly significant given the rapid recent growth of machine learning and neural net technology.”

Former Eric and Wendy Schmidt Center PhD fellow Adit Radhakrishnan's research focuses on advancing the theoretical foundations of machine learning and developing new methods for tackling biomedical problems.

Understanding how neural networks make predictions is especially important in biological applications. In drug discovery, for example, researchers want more than a model that accurately predicts which drugs are effective in treating cancer; they also want to discover the biological mechanisms that make such drugs effective, explained Radhakrishnan. “By applying our findings to models trained to predict the effect of drugs on cancer cells, we can discover features of cancer cells that make them susceptible to a given drug and then develop new drugs to specifically target those mechanisms,” he said.

The study is part of a larger effort in Belkin’s research group to develop a mathematical theory that explains how neural networks work. “Technology has outpaced theory by a huge amount,” he said. “We need to catch up.” 

The team also showed that the statistical formula they used to understand how neural networks learn, known as Average Gradient Outer Product (AGOP), could be applied to improve performance and efficiency in other types of machine learning architectures that do not include neural networks.
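The article does not spell the formula out. As the AGOP is commonly defined, for a trained predictor f that maps a d-dimensional input to a number, it averages the outer product of the predictor's gradient with itself over the n data points. The symbols f, x_i, n and d below are notation chosen for this sketch, not taken from the article:

```latex
\mathrm{AGOP}(f; x_1, \dots, x_n) \;=\; \frac{1}{n} \sum_{i=1}^{n} \nabla f(x_i)\, \nabla f(x_i)^{\top}
```

The result is a symmetric d-by-d matrix. Input coordinates with large diagonal entries are those to which the prediction is most sensitive on average, and the top eigenvectors give the directions in input space that matter most.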

“If we understand the underlying mechanisms that drive neural networks, we should be able to build machine learning models that are simpler, more efficient, and more interpretable,” Belkin said. “We hope this will help democratize AI.”

The machine learning systems that Belkin envisions would need less computational power, and therefore less power from the grid, to function. These systems also would be less complex and so easier to understand. 

Illustrating the new findings with an example

(Artificial) neural networks are computational tools for learning relationships between data characteristics (e.g., identifying specific objects or faces in an image). One example of a task is determining whether a person in a new image is wearing glasses. Machine learning approaches this problem by providing the neural network with many example (training) images, each labeled as an image of “a person wearing glasses” or “a person not wearing glasses.” The neural network learns the relationship between images and their labels, and extracts the data patterns, or features, that it needs to focus on to make a determination. One of the reasons AI systems are considered a black box is that it is often difficult to describe mathematically what criteria the systems are actually using to make their predictions, including potential biases. The new work provides a simple mathematical explanation for how the systems are learning these features.

Features are relevant patterns in the data. In the example above, there is a wide range of features that the neural network learns, and then uses, to determine whether a person in a photograph is wearing glasses. One feature it would need to pay attention to for this task is the upper part of the face. Other features could be the eye or the nose area, where glasses often rest. The network selectively pays attention to the features that it learns are relevant and then discards the other parts of the image, such as the lower part of the face, the hair and so on.

Feature learning is the ability to recognize relevant patterns in data and then use those patterns to make predictions. In the glasses example, the network learns to pay attention to the upper part of the face. In the new Science paper, the researchers identified a statistical formula that describes how the neural networks are learning features. 
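To make the glasses example concrete, here is a minimal Python sketch of how the AGOP can be turned into a per-pixel importance map. Everything in it is an assumption for illustration: the 4x4 "images" are random numbers, `predictor` is a hand-written stand-in that only looks at the top two rows (it is not a trained network), and the gradients are estimated by finite differences rather than by the method used in the paper.

```python
import numpy as np

def numerical_gradient(f, x, eps=1e-5):
    """Central finite-difference estimate of the gradient of f at x."""
    grad = np.zeros_like(x)
    for k in range(x.size):
        step = np.zeros_like(x)
        step[k] = eps
        grad[k] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

H, W = 4, 4

# Hypothetical stand-in for a "glasses" classifier: it only uses the
# top two rows of the image (the "upper part of the face").
weights = np.zeros((H, W))
weights[:2, :] = 1.0

def predictor(flat_image):
    return np.tanh(flat_image @ weights.ravel())

rng = np.random.default_rng(0)
images = rng.normal(size=(100, H * W))  # toy data, one flattened image per row

# Average gradient outer product over the data set.
grads = np.stack([numerical_gradient(predictor, img) for img in images])
agop_matrix = grads.T @ grads / len(images)

# The diagonal says how sensitive the prediction is to each pixel.
importance_map = np.diag(agop_matrix).reshape(H, W)
print(np.round(importance_map, 3))
```

In this toy setup the bottom two rows of `importance_map` are zero because the stand-in predictor never reads them, which mirrors the article's point that the network discards the parts of the image it has learned are irrelevant.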

Alternatives to neural network architectures: The researchers went on to show that inserting this formula into computing systems that do not rely on neural networks allowed these systems to learn faster and more efficiently.
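One way to picture "inserting the formula" into a system without a neural network is to alternate between fitting a kernel model and recomputing the AGOP of that model, then using the AGOP to reweight the inputs on the next round. The Python sketch below follows that idea under stated assumptions: the Gaussian kernel, the bandwidth, the ridge term, the trace normalization and all function names are illustrative choices for this example and should not be read as the authors' implementation.

```python
import numpy as np

def mahalanobis_gaussian_kernel(X, Z, M, bandwidth):
    """k(x, z) = exp(-(x - z)^T M (x - z) / (2 * bandwidth^2)), M symmetric."""
    XM = X @ M
    sq = ((XM * X).sum(axis=1)[:, None]
          + ((Z @ M) * Z).sum(axis=1)[None, :]
          - 2.0 * XM @ Z.T)
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth**2))

def learn_feature_matrix(X, y, n_iters=3, ridge=1e-3, bandwidth=2.0):
    """Alternate kernel ridge regression with an AGOP update of M."""
    n, d = X.shape
    M = np.eye(d)  # start with all input coordinates treated equally
    for _ in range(n_iters):
        # 1) Fit f(x) = sum_j alpha_j k(x, x_j) by kernel ridge regression.
        K = mahalanobis_gaussian_kernel(X, X, M, bandwidth)
        alpha = np.linalg.solve(K + ridge * np.eye(n), y)

        # 2) Gradient of f at every training point:
        #    grad f(x_i) = -(1 / bandwidth^2) * M @ sum_j alpha_j K_ij (x_i - x_j)
        weighted = K * alpha[None, :]
        diffs = weighted.sum(axis=1)[:, None] * X - weighted @ X
        grads = -(diffs @ M) / bandwidth**2

        # 3) Replace M with the average gradient outer product.
        M = grads.T @ grads / n
        # Assumption of this sketch: rescale so the kernel width stays comparable.
        M = d * M / np.trace(M)
    return M

# Toy data: the target depends only on the first of five coordinates.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = np.sin(X[:, 0])

M = learn_feature_matrix(X, y)
print(np.round(np.diag(M), 3))  # per-coordinate importance
```

The design choice worth noting is that no backpropagation is involved: each round solves a linear system and takes gradients of a closed-form kernel predictor, and the learned matrix `M` plays the role that learned features play in a neural network.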

“How do I ignore what’s not necessary? Humans are good at this,” said Belkin. “Machines are doing the same thing. Large Language Models, for example, are implementing this ‘selective paying attention’ and we haven’t known how they do it. In our Science paper, we present a mechanism explaining at least some of how the neural nets are ‘selectively paying attention.’” 

Study funders included the National Science Foundation and the Simons Foundation for the Collaboration on the Theoretical Foundations of Deep Learning. Belkin is part of The Institute for Learning-enabled Optimization at Scale, or TILOS, which is funded by the NSF and led by UC San Diego.

Paper title: Mechanism for feature learning in neural networks and backpropagation-free machine learning models

Adit Radhakrishnan, Harvard School of Engineering and Applied Sciences and Broad Institute of MIT and Harvard

Daniel Beaglehole and Mikhail Belkin, University of California San Diego

Parthe Pandit, IIT Bombay. Pandit did the work for this paper as a postdoctoral researcher at the UC San Diego Halicioglu Data Science Institute.

This story was adapted from a piece in UC San Diego Today.
