BTS II: An introduction to learning theory
(David here.) In this blogpost, we'll go off on a tangent and explore what it means to learn from data. In the process, we will get slightly closer (but not quite there) to the context where Rademacher complexity emerges.
This is the second blogpost in the sequence, continuing off BTS I.
$$\newcommand\lrangle[1]{\left\langle #1 \right\rangle}$$ $$\newcommand \E {\mathbb E}$$ $$\newcommand \cH {\mathcal H}$$ $$\newcommand \Var {\mathrm{Var}}$$ $$\newcommand \MSE {\mathrm{MSE}}$$
$\newcommand \cF {\mathcal F}$ $\newcommand \cG {\mathcal G}$ $\newcommand \cR {\mathcal R}$ $\newcommand \cS {\mathcal S}$ $\newcommand \cD {\mathcal D}$
Commentary: should you learn theory?
I'll take a page from Dylan and sneak in some non-math opinions. The question at hand is
As an Olympiad contestant, should you learn any "advanced theory"?
"Advanced theory" here is generally anything outside the scope of e.g. the IMO syllabus. It could be things like collegiate topics (linear algebra, group theory etc.) or really anything at all.
There are a few possible answers, but let me focus on two of them:
- No. It (mostly) doesn't help you solve any new Olympiad problems, and distracts you from the core problem solving aspects of Olympiads.
- Yes. Some problems are more naturally interpreted using some theory, and learning them helps you understand the big picture ideas better. Besides, Olympiads will not follow you into college but the bits of theory you learn might.
You might know that in the USA, MOP leans heavily towards "yes", and they have had great success in recent years at the IMO. But there is some truth to "no" as well, since it is easy to learn lots of theory and not know how exactly to apply it.
My personal experience is that I "got into" many mathematical areas because of that one Olympiad problem that introduced me to it, but at the same time it's very easy to Google a topic and have a textbook make your head spin. You can strive for an 80-20 here: explore what you want to when it "feels right to you", and give up when it's too overwhelming.
This advice also applies to this blogpost, if you are still of contest age. Feel free to follow along as far as you are comfortable with!
Prediction problems
There are lots of situations in life where we would like to use data to predict things that we care about. For instance,
- maybe we want to look at the partial scores during the IMO and predict the medal cutoffs
- maybe we want to look at part of a paragraph and predict the word that comes next
- maybe we want to look at yesterday's stock price movements and predict tomorrow's stock price movements
- maybe we want to look at (...) and predict (...)
In general, we have data $X$ and we would like to predict $Y$. Sometimes, $Y$ is completely determined by $X$ (i.e. $Y=f(X)$ for some function $f$). Oftentimes, this is not the case, so instead we speak of a joint probability distribution of $(X,Y)$. (You could think of this as the $(X,Y)$ you'd encounter "in the wild".)
Of course, predictions should be a (deterministic) function of data (which we'll denote as $h(X)$). How can we measure how good our predictions are? If the prediction is a real number, we could use the so-called mean squared error $$\MSE(h) := \E[(Y-h(X))^2]$$ where the $\E$ averages over the data distribution $(X,Y)$. Then, one way to figure out what a good prediction function $h$ is would be to minimize $\MSE(h)$ over all choices of $h$.
(In fact, I could tell you what the best $h$ is: it's $h(x) = \E[Y|X=x]$! I'll leave it as an exercise to the reader as to why.)
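Without spoiling the exercise, here is a quick numerical sanity check (not a proof). It is a minimal sketch on a made-up toy distribution: $X$ uniform on $\{0,1,2\}$ and $Y = X^2 + \text{noise}$, so that $\E[Y|X=x] = x^2$. The distribution, the two competitor functions and all names are assumptions chosen purely for illustration.

```python
import random

def mse(h, samples):
    """Average squared error of predictor h over (x, y) samples."""
    return sum((y - h(x)) ** 2 for x, y in samples) / len(samples)

def draw_samples(n, rng):
    """Toy distribution (an assumption): X uniform on {0, 1, 2}, Y = X^2 + Gaussian noise."""
    samples = []
    for _ in range(n):
        x = rng.choice([0, 1, 2])
        y = x ** 2 + rng.gauss(0.0, 1.0)
        samples.append((x, y))
    return samples

rng = random.Random(0)
samples = draw_samples(50_000, rng)

conditional_mean = lambda x: x ** 2   # E[Y | X = x] for this toy distribution
shifted = lambda x: x ** 2 + 0.5      # a competitor that is off by a constant
linear = lambda x: 2 * x - 0.5        # a competitor with a different shape

mse_best = mse(conditional_mean, samples)
mse_shifted = mse(shifted, samples)
mse_linear = mse(linear, samples)
print(mse_best, mse_shifted, mse_linear)
```

The conditional mean should come out with the smallest error of the three, and its error should be close to the noise variance (1 in this toy setup), which is the part of $Y$ that no function of $X$ can predict.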
So prediction problems are easy? Wait, not so fast. For one, every $\E$ that you see represents an unknowable number - there's no way to figure out what the expected stock price is tomorrow, for example, let alone minimize an expectation over $h$.
Unless... cue statistics.
Statistics
The good news is that we do have a way to "figure out" $\E[Y]$. I'll describe how all of this works using an example that everyone is familiar with: election polls.
Let's say that you would like to get a sense of election results in your GRCs, and in particular the percentage of the voters that voted for PAP. However, it's somewhat impractical to actually get everyone's votes (that's what the election count is - and it takes a long while), so to get a sneak peek we can just pick 100 vote slips at random and look at what's written on them.
Let's say out of these 100 votes, 60 of them vote for the PAP. Obviously, this doesn't mean that the actual percentage of voters who voted for the PAP is 60%. A priori, it could really be that we just happened to pick exactly the 60 people who voted for PAP and no one else did.
But polls work, and behind the scenes the central limit theorem is helping us. Let's say that the actual percentage of the population that votes for PAP is $p$, and that we draw $n$ samples.
- A single vote has mean $p$ and variance $\sigma^2 = p(1-p)$.
- The total number of PAP votes consists of $n$ independent (roughly!) votes, so it approximately follows the normal distribution with mean $np$ and variance $n\sigma^2$, because for independent random variables, expected values and variances both add up.
- The fraction of PAP votes we estimate follows the normal distribution with mean $p$ and variance $\sigma^2/n$. (Variance is a squared quantity, so dividing the underlying by $n$ means that we have to divide the variance by $n^2$.)
Thus, the poll mean follows a normal distribution with mean $p$ and variance $\sigma^2 / n$, or equivalently standard deviation $\sigma/\sqrt{n}$. You might know that the normal distribution has some complicated formula attached to it, like $$p(x) = \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$ but really the takeaway is simply that
For a normal distribution:
- there's a 68% chance of being within 1 standard deviation of the mean,
- a 95% chance of being within 2 standard deviations of the mean,
- and beyond that the odds of being $d$ standard deviations greater than the mean scale like $\exp(-d^2/2)$.
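These rules of thumb can be checked directly with the error function from Python's standard library. This is a small sketch: the function names are made up for illustration, and the last column prints $\exp(-d^2/2)$ next to the exact upper tail so you can see the tail sitting below it.

```python
import math

def prob_within(d):
    """P(|Z| <= d) for a standard normal Z, via the error function."""
    return math.erf(d / math.sqrt(2.0))

def upper_tail(d):
    """P(Z > d) for a standard normal Z."""
    return 0.5 * math.erfc(d / math.sqrt(2.0))

for d in (1, 2, 3, 4):
    print(d, prob_within(d), upper_tail(d), math.exp(-d * d / 2.0))
```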
So in our case, let's say we're happy with 2 standard deviations (being in 95% of all scenarios). Then, 60% is likely within $p \pm \frac{2\sigma}{\sqrt{n}}$, and we have that $\sigma^2 = p(1-p)\le 1/4$, i.e. $\sigma \le 1/2$. Plugging in $n=100$ gives $\frac{2\sigma}{\sqrt{n}} \le \frac{2\cdot 1/2}{10} = 10\%$, so there is a 95% chance that 60% is within $p\pm 10\%$. It could still very well be that the real $p$ was $p=0.5$, but then our sample would be a 2.5%-outlier scenario.
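If you'd rather see this than trust the algebra, here is a rough simulation sketch. It assumes a true $p = 0.6$ (we wouldn't know this in real life) and perfectly independent votes. The helper names `margin_of_error` and `simulate_coverage` are made up for this post.

```python
import math
import random

def margin_of_error(n, sigma=0.5, num_sd=2.0):
    """Half-width num_sd * sigma / sqrt(n); sigma = 0.5 is the worst case since p(1-p) <= 1/4."""
    return num_sd * sigma / math.sqrt(n)

def simulate_coverage(p, n, trials, rng):
    """Fraction of simulated polls whose sample mean lands within the margin of the true p."""
    margin = margin_of_error(n)
    hits = 0
    for _ in range(trials):
        votes = sum(1 for _ in range(n) if rng.random() < p)
        if abs(votes / n - p) <= margin:
            hits += 1
    return hits / trials

rng = random.Random(0)
margin = margin_of_error(100)
coverage = simulate_coverage(p=0.6, n=100, trials=5_000, rng=rng)
print(margin, coverage)
```

The margin for $n = 100$ is the 10% from above, and the simulated coverage should land around 95% or a bit higher, since the margin uses the worst-case $\sigma$.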
Reversing the logic, you could believe that there is an $n$ where it is extremely likely that your sample mean is within 1% or even 0.1% of the actual mean (even though the number of samples we'd need starts to get ridiculous).
If you're queasy about the fact that the sample mean is really only approximately normally distributed, you'll be glad to know that there are theorems that rigorously establish this intuition, for example Hoeffding's inequality. But we'll assume that this picture is mostly right.
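For reference, here is the form of Hoeffding's inequality that fits the polling setup. The assumptions are that the samples $X_1,\dots,X_n$ are independent, each lies in $[0,1]$ (a vote slip is either 0 or 1) and each has mean $p$. Writing $\bar X$ for the sample mean, for every $\epsilon > 0$ we have $$\Pr\left(|\bar X - p| \ge \epsilon\right) \le 2\exp(-2n\epsilon^2).$$ Note that this holds for every $n$, with no "approximately normal" caveat. The price is that it is more conservative: with $n=100$ and $\epsilon = 0.1$ the right-hand side is $2e^{-2}\approx 0.27$, compared to the 5% suggested by the normal picture.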
Sampling, accuracy and reliability
For any estimation problem, we are going to end up with a trade-off between:
- how many samples we need
- the accuracy of our estimate
- the reliability of our estimate (which is typically measured by the probability of the estimate being off our accuracy threshold).
In the context of estimating the mean, we can talk about this tradeoff using a few different formulations. For sampling, we have
We can estimate any true mean we want up to $\epsilon$ using $O(\log(1/\delta)/\epsilon^2)$ samples, with the probability of being wrong at most $\delta$.
(Here, the suppressed constant depends on the value we're estimating). Shuffling this around to focus on the accuracy of our estimate, we also have
With $n$ samples, we can estimate any true mean we want up to $O(\sqrt{\log(1/\delta) / n})$ error, with the probability of being wrong at most $\delta$.
Oftentimes for convenience we suppress the $\delta$ term and just write $\tilde O(\sqrt{1/n})$, because we typically don't care about exactly how small $\delta$ is after a fairly small number of zeros beyond the decimal point. In general, any squiggle you see means "extra logarithmic terms we don't really care about", and honestly most of the time we will be careless and drop even that too.
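To make the two formulations above concrete, here is a small sketch that plugs in explicit constants. It assumes the quantity being averaged lies in $[0,1]$ and uses Hoeffding's bound $2\exp(-2n\epsilon^2)\le\delta$ as the source of the constants. Other inequalities would give different constants but the same shape. The function names are made up.

```python
import math

def samples_needed(epsilon, delta):
    """Smallest n with 2 * exp(-2 * n * epsilon^2) <= delta (Hoeffding, values in [0, 1])."""
    return math.ceil(math.log(2.0 / delta) / (2.0 * epsilon ** 2))

def error_bound(n, delta):
    """Accuracy epsilon guaranteed with n samples and failure probability at most delta."""
    return math.sqrt(math.log(2.0 / delta) / (2.0 * n))

for eps in (0.1, 0.01, 0.001):
    print(eps, samples_needed(eps, delta=0.05))
```

Shrinking $\epsilon$ by a factor of 10 costs a factor of 100 in samples, while shrinking $\delta$ only costs a logarithmic amount. That asymmetry is why the $\delta$ term is the one that gets swept under the squiggle.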
Multiple mean estimation
If you have $N$ unknown means that you would like to estimate using data, then it follows that with the same number of samples, the probability of at least one of them being wrong is at most $N\delta$ (by a "union bound"). As you might be able to see, this gets absorbed easily into the $\log(1/\delta)$.
Let's restate one of the statements from the last section for multiple mean estimation:
We can estimate any $N$ true means we want up to $\epsilon$ using $O(\log( N / \delta )/\epsilon^2)$ samples, with the probability of being wrong at most $\delta$.
One way to read this is that to estimate $N$ true means you only need $O(\log(N) / \epsilon^2)$ additional samples (compared to estimating a single mean) - this reads like a good deal to me!
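Here is how good the deal is in numbers. This is a sketch under the same assumptions as before: every quantity lies in $[0,1]$, and the constants come from combining Hoeffding's bound with the union bound, i.e. requiring $2N\exp(-2n\epsilon^2)\le\delta$. The function name is made up.

```python
import math

def samples_for_n_means(num_means, epsilon, delta):
    """Smallest n with 2 * num_means * exp(-2 * n * epsilon^2) <= delta."""
    return math.ceil(math.log(2.0 * num_means / delta) / (2.0 * epsilon ** 2))

single = samples_for_n_means(1, epsilon=0.1, delta=0.05)
million = samples_for_n_means(10 ** 6, epsilon=0.1, delta=0.05)
print(single, million)
```

Going from one mean to a million means multiplies the required sample count by less than 5 in this setup.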
The more general takeaway is perhaps that
Typically, datasets are representative of the underlying population/distribution.
Machine learning
We can apply the "same" idea to a prediction problem. It may not always be possible to learn the exact function (or there may not be one - $Y$ might not be fully determined by $X$!), but among a family of functions (denoted here as $\mathcal H$) we can define a best one: it's the one that minimizes the mean squared error: $$h^* = \arg\min_{h\in \cH} \E[(Y-h(X))^2]$$
For example, if $\cH = \{\text{any function}\}$ , then our "best guess" is $h(X) = \E[Y|X]$.
But again, we have no way of knowing what the mean actually is. So one idea is as follows: we take the data, and we compute a sample estimate of the mean squared error, just like what we've done before: $$\hat \MSE(h) := \hat \E[(Y-h(X))^2] = \frac{1}{n}\sum_{i=1}^n (y^{(i)}-h(x^{(i)}))^2.$$
(It's customary to use a little hat to indicate sample averages, so we'll use it on the expectation $\mathbb E$ to denote a sampled estimate of an expected value.)
Now, instead of minimizing $\MSE(h)$, which we don't have access to, we can minimize $\hat\MSE(h)$ as a proxy! So the function we learn is $$\hat h := \arg\min_{h\in \cH} \hat\E[(Y-h(X))^2].$$ (And congratulations - at this point you've basically learnt machine learning.)
Footnote. The academic term for this is Empirical Risk Minimization, where $\hat \MSE(h)$ is precisely the said "Empirical Risk".
Another footnote. Why do we use $h$ (and $\cH$) instead of $f$? In the theory, the candidate function was sometimes also called a hypothesis.
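Here is what the recipe looks like when $\cH$ is small enough to search by brute force. This is a minimal sketch, not how anyone trains real models. The hypothesis class (lines through the origin with a handful of slopes), the toy data ($Y = 2X + \text{noise}$) and all names are assumptions made up for illustration.

```python
import random

def empirical_mse(h, data):
    """Sample average of (y - h(x))^2 over the dataset."""
    return sum((y - h(x)) ** 2 for x, y in data) / len(data)

def erm(hypotheses, data):
    """Return the (name, function) pair in the class with the smallest empirical MSE."""
    return min(hypotheses.items(), key=lambda item: empirical_mse(item[1], data))

# Hypothetical finite class: lines through the origin with slopes 0.0, 0.5, ..., 4.0.
hypotheses = {k / 2: (lambda x, s=k / 2: s * x) for k in range(9)}

# Toy data (an assumption): Y = 2X + noise, X uniform on [0, 1].
rng = random.Random(0)
data = []
for _ in range(2_000):
    x = rng.random()
    data.append((x, 2.0 * x + rng.gauss(0.0, 0.5)))

best_slope, h_hat = erm(hypotheses, data)
print(best_slope, empirical_mse(h_hat, data))
```

With this much data the empirical minimizer picks out the slope that generated the data, even though the code never sees the true distribution, only the samples.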
The perils of approximate optimization
What could possibly go wrong with this approach? Let me start sowing the seeds of doubt with the following caution from optimization:
Optimizing an approximate function may not get you the approximate optimizer!
And here's a specific example. Consider the function $f(x) = x^{100} + ax$, where $a=0$. This function is clearly minimized at $x=0$. However, if instead you minimized $\hat f(x) = x^{100}+ \hat a x$ with a slightly wrong estimate $\hat a$ of $a$, then:
- for $\hat a > 0.001$, the minimizer is roughly $-1$
- for $\hat a < -0.001$, the minimizer is roughly $1$
You can see this for yourself on Desmos. (Clicking on the graph will show you where the minimizer is.)
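If you prefer numbers to graphs, the minimizer can also be written in closed form. Since $x^{100} + \hat a x$ is convex, it suffices to set the derivative $100x^{99} + \hat a$ to zero, which gives $x = -\mathrm{sign}(\hat a)\,(|\hat a|/100)^{1/99}$. The short sketch below just evaluates that. The function name is made up.

```python
def minimizer(a_hat):
    """Minimizer of x**100 + a_hat * x, from setting the derivative 100 * x**99 + a_hat to zero."""
    if a_hat == 0:
        return 0.0
    magnitude = (abs(a_hat) / 100.0) ** (1.0 / 99.0)
    return -magnitude if a_hat > 0 else magnitude

for a_hat in (0.0, 1e-3, -1e-3, 1e-6):
    print(a_hat, minimizer(a_hat))
```

The 99th root is what does the damage: even a perturbation as tiny as $10^{-6}$ moves the minimizer most of the way from $0$ to $\pm 1$.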
This is perhaps a good reason to start to doubt our initial, simple story that minimizing $\hat \MSE$ is really the right move. At the very least, we now expect some bells and whistles.
And yet, machine learning works
How can we eliminate the above bad example? Does it even matter?
Let's try a different perspective. Suppose we only had a finite $\cH$ - so instead of having continuous parameters, we have finitely many choices. The crazy thought is now:
What if $\hat\MSE(h) \approx \MSE(h)$ for each model $h\in \cH$?
This is not so crazy if you remember that we get a really good deal for multiple mean estimation. I'll skip working through the details here, but we get that $$\sup_{h\in \cH} |\hat \MSE(h) - \MSE(h)| \le c\cdot \sqrt{\frac{\log|\cH|}{n}}$$ holds except with a probability that decays exponentially in $c$.
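For the curious, here is a sketch of the skipped details under one simplifying assumption that the post doesn't state: the squared errors $(Y-h(X))^2$ always lie in $[0,1]$. For each fixed $h$, $\hat\MSE(h)$ is then an average of $n$ independent samples with mean $\MSE(h)$, so Hoeffding's inequality applies to it. A union bound over the finitely many $h$ gives $$\Pr\left(\sup_{h\in\cH}|\hat\MSE(h)-\MSE(h)| > \epsilon\right) \le \sum_{h\in\cH}\Pr\left(|\hat\MSE(h)-\MSE(h)| > \epsilon\right) \le 2|\cH|\exp(-2n\epsilon^2).$$ Plugging in $\epsilon = c\sqrt{\log|\cH|/n}$ turns the right-hand side into $$2|\cH|\exp(-2c^2\log|\cH|) = 2|\cH|^{1-2c^2},$$ which shrinks very quickly as $c$ grows (as long as $|\cH|\ge 2$).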
That's all fine and good, but in many applications we care about, $\cH$ isn't finite! For example, in linear regression, we have $$\cH = \{\text{linear functions of }x_1,...,x_d\}$$ which are parametrized by $d$ continuous variables (the coefficients). It turns out that if your parametrization is well-behaved enough (and I'm going to deliberately handwave this), you can get a bound that looks like $$\sup_{h\in \cH} |\hat \MSE(h) - \MSE(h)| \le c\cdot \sqrt{\frac{\dim(\cH)}{n}}$$ where $\dim(\cH)$ is the number of variables in your parametrization.
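One link in the chain is worth spelling out: why does a uniform bound like this mean that minimizing $\hat\MSE$ is a sensible proxy? Write $\epsilon$ for the right-hand side of either bound, and assume the minimizers $h^*$ and $\hat h$ exist. Then, whenever the bound holds, $$\MSE(\hat h) \le \hat\MSE(\hat h) + \epsilon \le \hat\MSE(h^*) + \epsilon \le \MSE(h^*) + 2\epsilon.$$ The first and last steps use the uniform bound (applied to $\hat h$ and to $h^*$ respectively), and the middle step uses the fact that $\hat h$ minimizes $\hat\MSE$ over $\cH$. So the learnt function is at most $2\epsilon$ worse than the best function in the family. Note that the first step really needs the bound to hold for all $h$ at once, because $\hat h$ depends on the data.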
What's next?
Very roughly speaking, we've shown that machine learning works. Yet, there's a lot to be desired in this very last bound - there's too much it doesn't take into account. For example:
- what if there was a superfluous variable (or almost superfluous) - should it really contribute to the "dimension"?
- what if your distribution of data was weird: maybe $X$ is always equal to a constant $x_0$, then regardless of the family of functions $\cH$, it should only have an "effective" dimension of 1, since for any model $h$ we only care about its value $h(x_0)$.
The more complicated story that accounts for both of these things will (finally!) involve Rademacher complexity.
Another loose end that we had was Massart's lemma:
(Massart's lemma) For finite $\cS'$, one has $$\cR(\cS') \le C\cdot \left(\max_{z\in \cS'} \|z\|_2\right) \cdot \sqrt{\log |\cS'|}.$$
This kind of echoes what we're seeing here with the finite $|\cH|$ case, so we're already getting another hint that the Rademacher complexity must be related. We'll also see how we can prove this in the next post.