Exploring the Mathematics of Statistical Learning Theory
It is impossible to deny the power of neural networks – recent progress in deep learning has consolidated them as the approach of choice for many learning tasks. There is, however, still a delicate art to their design, application and training. This post will dig into statistical learning theory, which provides some theoretical backbone to the fundamental tradeoffs present in any learning task, which in turn illuminates some of the nuances present in the current machine learning landscape.
A depiction of the connection between machine learning and mathematics. Generated by DALL·E 3
Introduction
A neural network (NN) with a single, arbitrarily wide dense hidden layer can approximate any continuous function to any desired accuracy. When I first learned about the universal approximation theorem I was in awe. I knew NNs were powerful in practice, but learning they can do anything in principle felt huge – I thought perhaps we really can capture intelligence this way. Of course there is always a tradeoff, and a discussion that commonly follows universal approximation is that the data needed to train such a network grows incredibly fast with its size, not to mention the computing power; some arguments put this growth as exponential in the size of the model.
In this post I want to introduce various notions from statistical learning theory which quantify some of the considerations mentioned above. This post will be more mathematical and theoretical in nature, but I have honestly found myself thinking differently about my own uses of machine learning since broaching this challenging topic. I hope others will find it similarly useful and interesting.
The main source for this post is Understanding Machine Learning: From Theory to Algorithms. A lot of this text will follow their notation and choice of topics. However I am also borrowing some of the philosophy from the Geometric Deep Learning project. Finally, a popular set of notes which discuss these notions can be found in the CS229 lectures.
What are we learning, and how?
As many mathematics courses tend to do, we will start with numerous definitions to have a common language to build on. The basic idea of learning is to predict some features $y$ of a system given some observed inputs $x$ (for example, predicting the price of a house from its size and location).
In the context of supervised learning we learn to make predictions by observing samples of paired inputs and outputs, $(x_1, y_1), \dots, (x_m, y_m)$, and using these to build a predictive function $h$ such that $h(x) \approx y$ even for inputs we have never seen before.
The notation so far will serve us well, but we can get a little more abstract still (and use fancy fonts in the process). The space of all possible inputs is $\mathcal{X}$ and the space of all possible outputs is $\mathcal{Y}$, so a hypothesis is a function $h : \mathcal{X} \to \mathcal{Y}$. We assume the samples are drawn independently from an underlying probability distribution $\mathcal{D}$ over $\mathcal{X} \times \mathcal{Y}$, and we collect the $m$ training samples into a set $S$.
To give an example of when the probability distribution approach is helpful in making the mathematics concrete, consider a loss function $\ell(h, (x, y))$ which measures how badly the prediction $h(x)$ misses the true output $y$ (for a regression task a natural choice is the squared error $(h(x) - y)^2$). The quantity we really care about is the expected loss over the whole distribution, known as the true risk, $L_{\mathcal{D}}(h) = \mathbb{E}_{(x,y) \sim \mathcal{D}}\left[\ell(h, (x, y))\right]$.
Similarly we can define the loss over the training sample (or any other finite subset of $\mathcal{X} \times \mathcal{Y}$), known as the empirical risk, $L_S(h) = \frac{1}{m} \sum_{i=1}^{m} \ell(h, (x_i, y_i))$.
By the law of large numbers we can say that, for a fixed hypothesis $h$, the empirical risk $L_S(h)$ converges to the true risk $L_{\mathcal{D}}(h)$ as the number of samples $m$ grows, which is what makes the training loss a sensible proxy for the quantity we actually care about.
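To make the law of large numbers statement concrete, here is a minimal numerical sketch. Everything in it is an assumption made purely for illustration: a toy one-dimensional distribution, the squared loss, and a fixed (deliberately imperfect) hypothesis $h(x) = 2x$.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(m):
    """Draw m i.i.d. samples (x, y) from a toy distribution D:
    x ~ Uniform(0, 1) and y = 3x plus Gaussian noise."""
    x = rng.uniform(0, 1, size=m)
    y = 3 * x + rng.normal(scale=0.5, size=m)
    return x, y

def squared_loss(h, x, y):
    return (h(x) - y) ** 2

h = lambda x: 2 * x   # a fixed (deliberately imperfect) hypothesis

# Empirical risk L_S(h) on training samples of increasing size m:
for m in [10, 100, 10_000, 1_000_000]:
    x, y = sample(m)
    print(m, squared_loss(h, x, y).mean())

# For this toy distribution the true risk L_D(h) can be computed by hand:
# E[(2x - 3x - noise)^2] = E[x^2] + Var(noise) = 1/3 + 0.25 ≈ 0.583,
# and the printed empirical risks settle towards this value as m grows.
```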
Remark:
One key assumption we can make is that the target function we want to learn does indeed exist, i.e. that there is some deterministic function $f : \mathcal{X} \to \mathcal{Y}$ with $y = f(x)$ for every sample drawn from $\mathcal{D}$. This is closely related to the realizability assumption. In general no such function needs to exist: the same input $x$ can appear with different outputs $y$, and all the distribution $\mathcal{D}$ gives us is the conditional probability of $y$ given $x$.
Coming from a physics background, the way I like to interpret the lack of realizability is through marginal distributions. Taking the example of house prices, we could imagine that if you had access to hundreds or thousands of other important variables then it would be possible to give an exact price for the house (property energy rating, space for parking, windows facing east or west, etc.). This could be a much larger input space $\mathcal{X}'$, on which the price really would be a deterministic function.
What we actually observe is the marginal distribution over the smaller space $\mathcal{X}$, $p(y \mid x) = \int p(y \mid x, z)\, p(z \mid x)\, \mathrm{d}z$, where $z$ stands for all of the unobserved variables. Integrating out the hidden variables turns an underlying deterministic relationship into one that looks noisy.
Simple Example
Let's illustrate with an example. We will consider a toy system with a single real-valued input $x$, drawn uniformly from an interval, and a binary label $y$ which is positive whenever $x$ falls inside a fixed "green" region and negative otherwise.
This is illustrated in the figure below, with a training set of 16 points scattered across the interval.
The approach we will take to learning these samples is to build an auxiliary function $g(x)$, and to predict the positive label exactly where $g(x)$ lies above zero.
Below is an example of such a function: a simple, smooth curve which classifies every one of the training points correctly.
Below we show a second example which is much less natural, and more complicated, than the first:
Clearly the second function also classifies every training point correctly, but it does so by wrapping itself tightly around the individual samples rather than capturing the underlying green region. On new points drawn from the same distribution it will perform far worse than the first function.
This particular example was an example of Occam’s razor: the simpler hypothesis was preferable, and extrapolated better to new samples. However if we constructed a new example with two separated green regions, then we would have needed a bit more complexity than the first function offered. The conclusion here is that we want to build learned functions which are as simple as possible, but no simpler.
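The contrast between the two functions can be reproduced in a few lines of code. This is only a sketch in the spirit of the figures above: the particular green region, the evenly spaced training points and the two hypotheses below are assumptions chosen for illustration, not the exact setup of the figures.

```python
import numpy as np

rng = np.random.default_rng(1)
green = (0.3, 0.6)                                  # the "green" region (assumed here)
label = lambda x: ((x > green[0]) & (x < green[1])).astype(int)

x_train = np.linspace(0.03, 0.97, 16)               # 16 training points (evenly spaced
y_train = label(x_train)                            # here purely for simplicity)

# Hypothesis 1: the tightest interval containing the positive training points.
lo, hi = x_train[y_train == 1].min(), x_train[y_train == 1].max()
h_simple = lambda x: ((x >= lo) & (x <= hi)).astype(int)

# Hypothesis 2: "memorize" -- predict 1 only within a tiny window around each
# positive training point and 0 everywhere else.
pos = x_train[y_train == 1]
h_memorize = lambda x: (np.abs(x[:, None] - pos[None, :]).min(axis=1) < 1e-3).astype(int)

x_test = rng.uniform(0, 1, 10_000)
y_test = label(x_test)
for name, h in [("simple", h_simple), ("memorize", h_memorize)]:
    print(name,
          "train acc:", (h(x_train) == y_train).mean(),
          "test acc:", (h(x_test) == y_test).mean())
```

Both hypotheses are perfect on the training set; only the simple one generalizes to new samples.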
As a second point, note that this system is realizable, since each input $x$ has exactly one correct label, fully determined by whether it lies in the green region.
Inductive Bias
We have seen that using arbitrarily complicated functions to model the training data is not a good approach. However, due to the artificially simple example above, the problem was easy to spot by eye; in a realistic, high-dimensional setting we need to restrict in advance the kinds of functions we are willing to learn. This restriction is known as an inductive bias.
We can formalize this idea: we aim to select the best performing function from a pre-chosen set of functions, the hypothesis class $\mathcal{H}$. Picking the hypothesis with the smallest loss on the training sample, $h_S = \mathrm{argmin}_{h \in \mathcal{H}} L_S(h)$, is known as Empirical Risk Minimization (ERM).
For example, if we had a single input and a single output, a valid hypothesis class would be all linear functions $h(x) = ax + b$, governed by just two parameters.
Another valid class would be the set of quadratic polynomials. Or a third set could be the union of those two, although in practice the quadratic choices would almost always have lower loss than the linear functions, making this construction not very helpful. Alternatives to ERM tackle this issue (see Structural Risk Minimization).
Hypothesis classes can also be much more complicated, like a NN governed by millions of parameters instead of two. It could also be something more geometric, like choosing a hyperrectangle in the input space $\mathcal{X}$ and labeling points inside it as positive and points outside as negative.
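As a concrete sketch of ERM, here is the two-parameter linear class from above fitted to a toy training sample by minimizing the empirical squared loss. The data-generating process and all constants below are assumptions for illustration; for this particular class the minimization happens to have a closed-form least-squares solution.

```python
import numpy as np

rng = np.random.default_rng(2)

# A training sample S from some unknown distribution D (here the "unknown"
# truth is y = 1.5x - 0.3 plus noise, which the learner never sees directly).
x = rng.uniform(-1, 1, 50)
y = 1.5 * x - 0.3 + rng.normal(scale=0.2, size=50)

# ERM over the class H = {h(x) = a*x + b}: minimize the empirical squared
# loss L_S(h) over the two parameters (a, b). For this class the minimizer
# is given in closed form by least squares.
A = np.stack([x, np.ones_like(x)], axis=1)
(a, b), *_ = np.linalg.lstsq(A, y, rcond=None)

h_S = lambda x_new: a * x_new + b                  # the ERM hypothesis
print(f"learned a = {a:.2f}, b = {b:.2f}")
print("empirical risk L_S(h_S):", np.mean((h_S(x) - y) ** 2))
```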
Bias-Complexity Tradeoff
The main tradeoff under consideration in statistical learning theory is known as the bias-complexity tradeoff, or bias-variance tradeoff. This tradeoff is strongly connected to under- and overfitting, which remain important topics in the modern landscape of deep learning (see also double descent). There are multiple ways of heuristically describing this tradeoff, one of them being that a large hypothesis class has an extremely strong ability to memorize spurious correlations in the finite training set. It is common to hear that the model starts to memorize noise instead of signal, but it is important to remember that this phenomenon exists even with noise-free data.
I personally think the idea of memorizing data rather than learning from data is apt. Consider the example given above, where the more complicated second function memorized all of the previously seen training data but learned nothing about the true underlying distribution. Even if a model isn't capable of memorizing all previously seen data, it can still memorize spurious correlations.
Remark:
A second tradeoff, less important in the context of learning theory, is that searching for the best performing function in a hypothesis class can be computationally expensive, and infeasible in practice. Sometimes, like in the case of a straight line, we can solve the problem analytically. In other cases, like with NNs, we can use backpropagation to quickly optimize millions or billions of parameters to find a well performing function, but even then there is a level of randomness, and running the same code many times will yield different results. In general we can argue that the problem risks becoming combinatorially expensive. For example, imagine a hypothesis class governed by selecting the values of a few dozen discrete, binary switches: the number of distinct hypotheses grows exponentially with the number of switches, and exhaustively comparing them all quickly becomes hopeless.
Sources of Error
Using the notation introduced above, the mean loss of a hypothesis $h$ over the true distribution is the true risk $L_{\mathcal{D}}(h)$, and this is the quantity we would ultimately like to make as small as possible.
However, in practice, we do not have access to the underlying distribution, making $L_{\mathcal{D}}(h)$ impossible to evaluate directly; the best we can do is minimize the empirical risk $L_S(h)$, which yields the ERM hypothesis $h_S$.
Two distinct sources of error are already present here. Firstly, our hypothesis class might be too limited to fit the problem perfectly. We quantify this as the approximation error, $\epsilon_{\mathrm{app}} = \min_{h \in \mathcal{H}} L_{\mathcal{D}}(h)$. Secondly, because we only ever see a finite sample, the hypothesis we actually select need not be the best one in the class; the remainder is the estimation error, $\epsilon_{\mathrm{est}} = L_{\mathcal{D}}(h_S) - \min_{h \in \mathcal{H}} L_{\mathcal{D}}(h)$, so that the total error decomposes as $L_{\mathcal{D}}(h_S) = \epsilon_{\mathrm{app}} + \epsilon_{\mathrm{est}}$.
In the equation above the cancellation is trivial to see mathematically (it is essentially adding and subtracting $\min_{h \in \mathcal{H}} L_{\mathcal{D}}(h)$), but the two terms have very different meanings: the approximation error is the best the class $\mathcal{H}$ could ever do, while the estimation error measures how far the finite training sample has led us away from that best-in-class hypothesis.
With these errors defined we can clarify the bias-complexity tradeoff. Generally speaking, by increasing our number of training samples we can reduce the estimation error. However, if our hypothesis class is chosen poorly, then there will be an irreducible approximation error. Suppose then we want to build a bigger hypothesis class to reduce the approximation error. This leads to each training sample being less informative, and the estimation error will be larger – searching in a larger space requires more data. This is the tradeoff at the heart of statistical learning theory, and much of machine learning practice. It is crucial to build a small hypothesis class which still somehow manages to model the true nature of the system being learned.
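The tradeoff is easy to see numerically. Below is a small, hedged sketch in which polynomial classes of increasing degree stand in for "bigger hypothesis classes"; the target function, the noise level and the sample size are all assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(3)

def sample(m):
    x = rng.uniform(-1, 1, m)
    y = np.sin(3 * x) + rng.normal(scale=0.1, size=m)   # unknown target plus noise
    return x, y

x_train, y_train = sample(20)       # a small training set
x_test, y_test = sample(5_000)      # a large held-out set, standing in for the true risk

for degree in [0, 1, 3, 5, 9, 15]:
    # Hypothesis class: polynomials of the given degree, fitted by ERM
    # (least squares on the training sample).
    coeffs = np.polyfit(x_train, y_train, degree)
    train_err = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    test_err = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
    print(f"degree {degree:2d}: train error {train_err:.3f}, test error {test_err:.3f}")
```

With only 20 samples, the lowest-degree classes underfit (large approximation error), while the highest-degree classes drive the training error down yet see the test error climb back up (large estimation error).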
Above I used the word "generally". Can we be confident that decreasing estimation error is really just a case of having more samples? It turns out that it is not always the case, as even our simple example above shows. Loosely speaking, because there are infinitely many real numbers in the input range, a hypothesis class flexible enough to fix its value separately at every single point (the kind of class that produced the second, memorizing function) can fit any finite training set perfectly while saying nothing about the rest of the interval. No finite number of samples pins such a class down, so its estimation error need not shrink at all.
In the following section we will make these two distinct cases much more quantitative.
Fundamental Theorem of Learning
We can now introduce the precise mathematical notion of learnability. The way it works is by providing probabilistic guarantees on the error of a learned hypothesis.
As a prelude it is worth recapping the way continuity and calculus are studied in a rigorous mathematical sense. While in practice we're comfortable simply taking derivatives and running calculations, the mathematics which underpins differentiation is based on an interplay between arbitrarily small quantities: for any tolerance $\epsilon > 0$ one must be able to find a $\delta > 0$ such that staying within $\delta$ of a point keeps the function within $\epsilon$ of its limiting value.
The concept of learnability follows a similar pattern. In this case you must guarantee that there is only a small probability $\delta$ of the training set misleading us, and that otherwise the learned hypothesis is approximately correct, i.e. within a small error $\epsilon$ of the best achievable. This is the origin of the name Probably Approximately Correct (PAC) learning.
Using symbols: a hypothesis class $\mathcal{H}$ is PAC learnable if there exists a function $m_{\mathcal{H}}(\epsilon, \delta)$ and a learning algorithm such that, for every $\epsilon, \delta \in (0, 1)$ and every distribution $\mathcal{D}$, running the algorithm on $m \geq m_{\mathcal{H}}(\epsilon, \delta)$ i.i.d. samples returns a hypothesis $h$ satisfying $L_{\mathcal{D}}(h) \leq \min_{h' \in \mathcal{H}} L_{\mathcal{D}}(h') + \epsilon$ with probability at least $1 - \delta$.
Remark:
This definition is sometimes known as agnostic PAC learnability, because of the irreducible minimum error $\min_{h' \in \mathcal{H}} L_{\mathcal{D}}(h')$ appearing in the guarantee. We are not assuming realizability within the chosen hypothesis class, and the minimum may be non-zero either because the class was chosen poorly or because the problem is not realizable at all.
This may sound very abstract, and far from the day-to-day practice of machine learning, but what it tells us is that learning an arbitrarily good estimate is just a case of training on enough (but finitely many) samples, if and only if our hypothesis class is PAC learnable. In other words, using a purely mathematical approach we can guarantee good generalization for certain classes of learned hypotheses! Note however that while we can reduce the estimation error arbitrarily, the approximation error could still be large. By considering a larger hypothesis class we will typically find that the approximation error drops, but that the required number of samples $m_{\mathcal{H}}(\epsilon, \delta)$ grows: the bias-complexity tradeoff again.
Let's take the example where we must choose from a finite number of distinct hypotheses. It can be shown that any finite size hypothesis class is PAC learnable, with sample complexity bounded by
$$m_{\mathcal{H}}(\epsilon, \delta) \leq \left\lceil \frac{2\log\left(2|\mathcal{H}|/\delta\right)}{\epsilon^{2}} \right\rceil,$$
where $|\mathcal{H}|$ is the number of hypotheses in the class. Reassuringly, the number of samples needed grows only logarithmically with the size of the class, although it grows like $1/\epsilon^{2}$ with the desired accuracy.
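Plugging numbers into this bound makes the scaling tangible. A tiny sketch (using the agnostic-case bound quoted above; the particular values of $\epsilon$ and $\delta$ are arbitrary):

```python
import math

def m_finite(H_size, eps, delta):
    """Upper bound on the sample complexity of a finite hypothesis class
    in the agnostic PAC setting (the bound quoted above)."""
    return math.ceil(2 * math.log(2 * H_size / delta) / eps ** 2)

# Logarithmic growth in the size of the class...
for H_size in [10, 10**3, 10**6, 10**9]:
    print(f"|H| = {H_size:>10}: m <= {m_finite(H_size, eps=0.05, delta=0.01)}")

# ...but quadratic growth in 1/epsilon.
for eps in [0.1, 0.05, 0.025]:
    print(f"eps = {eps}: m <= {m_finite(10**6, eps=eps, delta=0.01)}")
```

A million-fold increase in $|\mathcal{H}|$ adds only $\log(10^6) \approx 14$ to the numerator, while halving $\epsilon$ quadruples the bound.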
Beyond Finite Hypothesis Classes
Note that while finite hypothesis classes are always PAC learnable, the converse does not hold. Infinite hypothesis classes (e.g. ones parameterized by real numbers) may be PAC learnable or not. One concept that governs learnability of infinite classes is the VC dimension. The logic behind the VC dimension follows from the No Free Lunch Theorem, which shows that the hypothesis class consisting of all possible functions from an infinite domain $\mathcal{X}$ to $\{0, 1\}$ is not PAC learnable: whatever the learner and however many (finite) samples it sees, there is always a distribution on which it fails.
The definition of VC dimension considers the restriction of a hypothesis class $\mathcal{H}$ to a finite set of points $C = \{c_1, \dots, c_k\} \subset \mathcal{X}$. If the hypotheses in $\mathcal{H}$ can realize every one of the $2^k$ possible binary labelings of these points, we say that $\mathcal{H}$ shatters $C$. The VC dimension of $\mathcal{H}$ is the size of the largest set of points that it can shatter.
The fundamental theorem of learning relates various notions of learnability, not all of which are covered here. One of the relationships however is that a hypothesis class is PAC learnable if and only if it has a finite VC dimension! This immediately explains why all finite classes are PAC learnable, and allows us to understand why some infinite classes are PAC learnable and some are not. In the example above, a class of wiggly functions able to single out arbitrary collections of points, like the second function, can shatter arbitrarily large sets and hence has infinite VC dimension: it is not PAC learnable, no matter how many samples we collect. The class of simple functions like the first one can only shatter small sets, and so remains learnable.
Remark:
The fundamental theorem of learning even bounds the number of samples needed for learnability: there exist universal constants $C_{1}, C_{2}$ such that
$$C_{1}\,\frac{d + \log(1/\delta)}{\epsilon^{2}} \;\leq\; m_{\mathcal{H}}(\epsilon, \delta) \;\leq\; C_{2}\,\frac{d + \log(1/\delta)}{\epsilon^{2}},$$
where $d$ is the VC dimension of the hypothesis class. Up to constants, the sample complexity is therefore controlled by the VC dimension alone.
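The definition of shattering is easy to poke at computationally. Below is a small brute-force sketch for two classic one-dimensional classes, thresholds and intervals; the classes and the test points are assumptions for illustration, and the check simply enumerates which labelings each class can realize.

```python
from itertools import product

def shatters(labelings_of, points):
    """A class shatters `points` if it can realize all 2^k possible labelings."""
    return len(set(labelings_of(points))) == 2 ** len(points)

def threshold_labelings(points):
    """Labelings achievable by threshold functions h_t(x) = 1[x >= t]."""
    for t in list(points) + [max(points) + 1]:
        yield tuple(int(x >= t) for x in points)

def interval_labelings(points):
    """Labelings achievable by interval functions h_{a,b}(x) = 1[a <= x <= b]."""
    candidates = list(points) + [max(points) + 1]
    for a, b in product(candidates, repeat=2):
        yield tuple(int(a <= x <= b) for x in points)

print(shatters(threshold_labelings, [0.5]))           # True
print(shatters(threshold_labelings, [0.2, 0.8]))      # False: (1, 0) is unreachable
print(shatters(interval_labelings, [0.2, 0.8]))       # True
print(shatters(interval_labelings, [0.2, 0.5, 0.8]))  # False: (1, 0, 1) is unreachable
```

Thresholds shatter one point but not two, giving VC dimension 1; intervals shatter two points but not three, giving VC dimension 2. Both are infinite classes, yet both are PAC learnable because these dimensions are finite.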
Neural Networks and Geometric Deep Learning
So where does this leave us when it comes to our use of NNs? In the introduction I stated that any function can be approximated using even a single dense hidden layer, provided that layer can be made arbitrarily wide. This sounds appealing at first, since an arbitrarily wide network is guaranteed to have the lowest possible approximation error. The sections above show that this is not necessarily a good thing, however: as the network gets larger it requires more samples to keep the generalization error low, and as it becomes arbitrarily large we lose PAC learnability itself.
Remark:
We can give some quick results specific to a simple dense layer NN, even if they are not necessary for the bottom-line conclusions here. Firstly, for a network using the sign activation function which receives inputs in $\{\pm 1\}^n$, implementing every possible Boolean function of those inputs requires a number of neurons that grows exponentially with $n$, so universal expressiveness really is paid for in size. Secondly, the VC dimension of the class of functions computed by such sign-activation networks scales as $O(|E| \log |E|)$, where $|E|$ is the number of edges (tunable weights), so the sample complexity grows with the size of the architecture.
For the remainder of this post I will discuss NNs in the light of statistical learning theory; however, I want to be honest about the possibility that applying this abstract framework to NNs may carry a lot of subtlety and nuance. To name a few examples, there is an understanding that adding a regularization term can vastly reduce a model's ability to overfit, and therefore provide lower generalization errors. There is also an emerging notion that stochastic gradient descent (SGD) itself provides an implicit regularization. With regularization in place, it is also possible to enter the interpolation regime following a double descent, where a model's generalization error goes down despite a larger hypothesis class, completely contrary to the usual thinking in statistical learning.
Sticking for now with the common notion of the bias-complexity tradeoff, the question is how we can keep the approximation power of NNs without having an arbitrarily large hypothesis class. This is where I would like to introduce Geometric Deep Learning into the story. The idea of this mathematical program is to consider the underlying symmetries of the systems we're trying to learn, and to use them to construct much better suited NNs. This approach is built up in a hierarchy of mathematical structures, starting with sets and graphs which carry permutation symmetry, moving through grids with translational symmetry, and arriving at manifolds with invariance under continuous deformations. In each case we can construct an equivariant NN, which is forced to encode the underlying symmetry.
By using a symmetry to reduce the size of our hypothesis class, we will not incur the cost of higher approximation error, since we are removing hypotheses we know can’t have been fundamentally correct. Simultaneously we have a smaller hypothesis class and hence a smaller estimation error. Heuristically we can say that we avoid wasting a huge amount of training samples learning about a symmetry that could be encoded.
The easiest example to give is the case of categorizing images. Imagine we are labeling images as containing dogs or cats. Let's say that as an intermediate step the NN learns to detect noses in the image. This part of the architecture should be able to detect noses regardless of whether they appear in the middle of the image or in the top left corner. With a naive dense layer network this ability would have to be learned over and over again, once for each part of the image. More generally, we say that a dog is still a dog after being moved up, down, left or right in the image – dogs exhibit translational invariance! In fact, the GDL approach states that the domain of images itself exhibits translational invariance. The architecture that encodes this symmetry is the convolutional neural network (CNN), which is known to achieve fantastic performance on image recognition tasks.
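The key mechanical ingredient is weight sharing: the same small kernel is applied at every position, so shifting the input simply shifts the output. Here is a minimal numpy sketch of that equivariance property; it is a single hand-rolled 1-D convolution rather than a full CNN, and the signal, kernel and shift are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(4)

def conv1d(signal, kernel):
    """'Valid' 1-D convolution: the same kernel is applied at every position,
    which is exactly the weight sharing that encodes translation symmetry."""
    k = len(kernel)
    return np.array([signal[i:i + k] @ kernel for i in range(len(signal) - k + 1)])

signal = rng.normal(size=32)
kernel = rng.normal(size=5)

shifted_input = np.roll(signal, 3)                       # translate the input by 3 steps
shift_after = np.roll(conv1d(signal, kernel), 3)         # convolve, then translate
conv_after = conv1d(shifted_input, kernel)               # translate, then convolve

# Away from the edges the two orders of operations agree: the feature map
# simply moves along with the input.
print(np.allclose(shift_after[3:-3], conv_after[3:-3]))  # True
```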
Another good example is if you are training on sets (graphs follow similar ideas). Let's say you're training a model to take in 5 elements of a set that have no particular order to them. If you encode these into a vector and pass it into a naive NN, it could learn something fictitious about the relative ordering of the elements in the vector. This means that the model would mistake two different orderings of the same set for genuinely different inputs, and waste capacity (and training samples) discovering that the order does not matter. An architecture that processes every element with the same weights and aggregates them with an order-insensitive operation, such as a sum, cannot make this mistake.
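A hedged sketch of the difference, using nothing but numpy and randomly initialized (untrained) weights: the Deep-Sets-style sum aggregation is one standard way to build permutation invariance in, while the flattened dense layer is the naive alternative.

```python
import numpy as np

rng = np.random.default_rng(5)
relu = lambda z: np.maximum(z, 0)

# Randomly initialized (untrained) toy weights.
W_phi = rng.normal(size=(3, 8))      # shared per-element encoder
W_rho = rng.normal(size=(8, 1))      # set-level readout
W_dense = rng.normal(size=(15, 1))   # naive layer acting on the flattened set

def set_network(elements):
    """Encode each element with the same weights, then sum: the sum makes the
    output independent of the order of the elements (Deep-Sets style)."""
    encoded = relu(elements @ W_phi)            # shape (n_elements, 8)
    return (encoded.sum(axis=0) @ W_rho).item()

def naive_network(elements):
    """Flatten the set into one long vector: the output depends on the order."""
    return (relu(elements.reshape(-1)) @ W_dense).item()

elements = rng.normal(size=(5, 3))              # a "set" of 5 elements, 3 features each
permuted = elements[rng.permutation(5)]         # the same set, in a different order

print(np.isclose(set_network(elements), set_network(permuted)))      # True
print(np.isclose(naive_network(elements), naive_network(permuted)))  # almost surely False
```

The invariant model gives the same answer (up to floating point) for every ordering by construction, so no training samples are spent learning that fact.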
Similarly, in physics informed neural networks (PINNs) we bake in the notion that evolving the system forward in time must obey a known differential equation: the learned function is penalized not just for disagreeing with the data, but also for violating the physics.
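As a sketch of the idea (and only a sketch: the toy equation $\mathrm{d}u/\mathrm{d}t = -u$ with $u(0) = 1$, the architecture and every hyperparameter below are assumptions for illustration, and the loss here contains only the physics residual and the initial condition, with no data term), a minimal PINN in PyTorch might look like this:

```python
import torch

torch.manual_seed(0)

# A tiny network u(t) trained to satisfy the toy ODE du/dt = -u with u(0) = 1,
# whose exact solution is exp(-t).
net = torch.nn.Sequential(
    torch.nn.Linear(1, 32), torch.nn.Tanh(), torch.nn.Linear(32, 1)
)
opt = torch.optim.Adam(net.parameters(), lr=1e-2)

for step in range(2000):
    t = torch.rand(64, 1, requires_grad=True)        # collocation points in [0, 1]
    u = net(t)
    du_dt, = torch.autograd.grad(u, t, torch.ones_like(u), create_graph=True)

    physics_loss = ((du_dt + u) ** 2).mean()         # residual of du/dt = -u
    initial_loss = (net(torch.zeros(1, 1)) - 1.0).pow(2).mean()   # enforce u(0) = 1
    loss = physics_loss + initial_loss

    opt.zero_grad()
    loss.backward()
    opt.step()

print(net(torch.tensor([[0.5]])).item())             # should be close to exp(-0.5) ≈ 0.607
```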
As a final note, some symmetries are very strict and mathematical, like translational invariance. Others are less clear, or maybe just approximate. Even in these cases I think there is scope to bake in some approximations to the model which greatly aid in learning. For example it is possible that we could encode that a NN should treat nearby and distant pairs of atoms differently when attempting to learn from chemical data.
Conclusion
My takeaway from statistical learning theory is to try approaching each task in its own independent way. While there are now myriad known "off the shelf" network architectures which deal with many types of problems (e.g. CNNs for image analysis, LSTMs for time series, transformers for language tasks, etc.), these architectures were once also unknown – exploration is key. If your task at hand has some symmetry, exact or approximate, exploiting it carries enormous benefits.
Adding to this point, the known network architectures can be composed in interesting and exciting ways, and can also be combined with more classical computational techniques. Consider for example the recent work by DeepMind on modeling weather, which built a custom architecture for the problem at hand while keeping in mind the symmetries needed at each step. In particular, the creation of a tessellated graph network relating nearby points on the Earth is an inspiring approach, and the results speak for themselves.
By creating an appropriate architecture for the problem at hand, you can shrink the hypothesis class without paying a higher approximation error, make good use of every available sample rather than spending them on re-learning trivial symmetries, and ultimately improve the generalization error.
Beyond the more pragmatic takeaways from this blog, I also hope to have shared some of the joy of a more "pure mathematics" approach to machine learning. While many theorems, corollaries and lemmas may be too abstract to apply directly to much of the day-to-day work in machine learning, I think they still provide guiding principles on which to build new ideas and get to the right results faster.