BTS II: An introduction to learning theory

(David here.) In this blogpost, we'll go off on a tangent and explore what it means to learn from data. In the process, we will get slightly closer (but not quite there) to the context where Rademacher complexity emerges.

This is the second blogpost in the sequence, continuing from BTS I.

\newcommand\lrangle[1]{\left\langle #1 \right\rangle} \newcommand \E {\mathbb E} \newcommand \cH {\mathcal H} \newcommand \Var {\mathrm{Var}} \newcommand \MSE {\mathrm{MSE}}

\newcommand \cF {\mathcal F} \newcommand \cG {\mathcal G} \newcommand \cR {\mathcal R} \newcommand \cS {\mathcal S} \newcommand \cD {\mathcal D}

Commentary: should you learn theory?

I'll take a page from Dylan and sneak in some non-math opinions. The question at hand is

As an Olympiad contestant, should you learn any "advanced theory"?

"Advanced theory" here is generally anything outside the scope of e.g. the IMO syllabus. It could be things like collegiate topics (linear algebra, group theory etc.) or really anything at all.

There are a few possible answers, but let me focus on two of them:

  • No. It (mostly) doesn't help you solve any new Olympiad problems, and distracts you from the core problem solving aspects of Olympiads.
  • Yes. Some problems are more naturally interpreted using some theory, and learning them helps you understand the big picture ideas better. Besides, Olympiads will not follow you into college but the bits of theory you learn might.

You might know that in the USA, MOP leans heavily towards "yes", and they have had great success in recent years at the IMO. But there is some truth to "no" as well, since it is easy to learn lots of theory and not know how exactly to apply it.

My personal experience is that I "got into" many mathematical areas because of that one Olympiad problem that introduced me to it, but at the same time it's very easy to Google a topic and have a textbook make your head spin. You can strive for an 80-20 here: explore what you want to when it "feels right to you", and give up when it's too overwhelming.

This advice also applies to this blogpost, if you are still of contest age. Feel free to follow along as far as you are comfortable!

Prediction problems

There are lots of situations in life where we would like to use data to predict things that we care about. For instance,

  • maybe we want to look at the partial scores during the IMO and predict the medal cutoffs
  • maybe we want to look at part of a paragraph and predict the word that comes next.
  • maybe we want to look at yesterday's stock price movements and predict tomorrow's stock price movements
  • maybe we want to look at (...) and predict (...)

In general, we have data $X$ and we would like to predict $Y$. Sometimes, $Y$ is completely determined by $X$ (i.e. $Y=f(X)$ for some function $f$). Oftentimes, this is not the case, so instead we speak of a joint probability distribution of $(X,Y)$. (You could think of this as the $(X,Y)$ you'd encounter "in the wild".)

Of course, predictions should be a (deterministic) function of data (which we'll denote as $h(X)$). How can we measure how good our predictions are? If the prediction is a real number, we could use the so-called mean squared error $$\MSE(h) := \E[(Y-h(X))^2]$$ where the $\E$ averages over the data distribution $(X,Y)$. Then, one way to figure out what a good prediction function $h$ is would be to minimize $\MSE(h)$.

(In fact, I could tell you what the best $h$ is: it's $h(x) = \E[Y \mid X=x]$! I'll leave it as an exercise to the reader as to why.)
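
If you'd like a numerical sanity check without spoiling the why, here's a tiny simulation in Python (which is what I'll use for all the code sketches in this post). The toy distribution here - $Y = X^2$ plus Gaussian noise - is made up purely for illustration; the point is just that the conditional mean comes out with the smallest MSE.

    import random

    random.seed(0)

    def sample():
        # Made-up toy distribution: X uniform on [0, 1], Y = X^2 plus Gaussian noise.
        x = random.random()
        return x, x * x + random.gauss(0, 0.1)

    def mse(h, n=200_000):
        # Monte Carlo estimate of E[(Y - h(X))^2].
        total = 0.0
        for _ in range(n):
            x, y = sample()
            total += (y - h(x)) ** 2
        return total / n

    print(mse(lambda x: x * x))   # conditional mean E[Y|X=x] = x^2: about 0.01 (just the noise variance)
    print(mse(lambda x: x))       # some other predictor: noticeably worse
    print(mse(lambda x: 0.5))     # a constant predictor: worse still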

So prediction problems are easy? Wait, not so fast. For one, every $\E$ that you see represents an unknowable number - there's no way to figure out what the expected stock price is tomorrow, for example, let alone minimize $\MSE(h)$ over $h$ when we can't even evaluate it.

Unless... cue statistics.

Statistics

The good news is that we do have a way to "figure out" $\E[Y]$. I'll describe how all of this works using an example that everyone is familiar with: election polls.

Let's say that you would like to get a sense of election results in your GRC, and in particular the percentage of the voters that voted for the PAP. However, it's somewhat impractical to actually get everyone's votes (that's what the election count is - and it takes a long while), so to get a sneak peek we can just pick 100 vote slips at random and look at what's written on them.

Let's say that out of these 100 vote slips, 60 are for the PAP. Obviously, this doesn't mean that the actual percentage of voters who voted for the PAP is 60%. A priori, it could really be that we just happened to pick exactly the 60 people who voted for PAP and no one else did.

But polls work, and behind the scenes the central limit theorem is helping us. Let's say that the actual proportion of the population that votes for PAP is $p$, and that we draw $n$ samples.

  • A single vote has mean $p$ and variance $\sigma^2 = p(1-p)$.
  • The total number of PAP votes is a sum of $n$ (roughly!) independent votes, so it approximately follows the normal distribution with mean $np$ and variance $n\sigma^2$, because for independent random variables, both expected values and variances add up.
  • The fraction of PAP votes we estimate therefore approximately follows the normal distribution with mean $p$ and variance $\sigma^2/n$. (Variance is a squared quantity, so dividing the underlying count by $n$ means dividing the variance by $n^2$.)

Thus, the poll mean follows a normal distribution with mean $p$ and variance $\sigma^2 / n$, or equivalently standard deviation $\sigma/\sqrt{n}$. You might know that the normal distribution has some complicated formula attached to it, like $$p(x) = \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)$$ but really the takeaway is simply that

For a normal distribution:

  • there's a 68% chance of being within 1 standard deviation of the mean,
  • a 95% chance of being within 2 standard deviations of the mean,
  • and beyond that, the odds of being $d$ standard deviations above the mean scale like $\exp(-d^2/2)$ (see the quick numerical check below).
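
Here's that quick check: the chance of a standard normal landing within $d$ standard deviations of its mean can be written via the error function, which Python's math module provides.

    from math import erf, sqrt

    # P(|Z| <= d) for a standard normal Z, written via the error function.
    for d in range(1, 5):
        within = erf(d / sqrt(2))
        print(d, round(within, 4), round(1 - within, 6))
    # d=1 gives ~0.68, d=2 gives ~0.95, and the leftover probability vanishes very quickly.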

So in our case, let's say we're happy with 2 standard deviations (covering 95% of all scenarios). Then, 60% is likely within $p \pm \frac{2\sigma}{\sqrt{n}}$, and we have that $\sigma^2 = p(1-p)\le 1/4$, i.e. $\sigma \le 1/2$. Simplifying with $n = 100$, we get that there is a 95% chance that 60% is within $p\pm 10\%$. It could still very well be that the real $p$ was $0.5$, but then our sample would be a 2.5%-outlier scenario.

Reversing the logic, you could believe that there is an $n$ where it is extremely likely that your sample mean is within 1% or even 0.1% of the actual mean (even though the number of samples we'd need starts to get ridiculous).
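
Here's a small simulation of the polling story. The "true" proportion of 0.55 is made up; the code checks how often a poll of 100 lands within 10 percentage points of the truth, and how often a poll of 10,000 lands within 1 percentage point.

    import random

    random.seed(0)

    p_true = 0.55    # a made-up "true" proportion; any value works
    trials = 2_000

    def coverage(n, margin):
        # Fraction of simulated polls of size n whose result lands within `margin` of p_true.
        hits = 0
        for _ in range(trials):
            poll = sum(random.random() < p_true for _ in range(n)) / n
            hits += abs(poll - p_true) <= margin
        return hits / trials

    print(coverage(100, 0.10))      # ~0.95 or a bit more: the two-standard-deviation story above
    print(coverage(10_000, 0.01))   # same margin in units of sigma/sqrt(n), so again ~0.95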

If you're queasy about the fact that the sample mean only approximately follows a normal distribution, you'll be glad to know that there are theorems that rigorously establish this intuition, for example Hoeffding's inequality. But we'll assume that this picture is mostly right.

Sampling, accuracy and reliability

For any estimation problem, we are going to end up with a trade-off between:

  • how many samples we need
  • the accuracy of our estimate
  • the reliability of our estimate (which is typically measured by the probability of the estimate being off by more than our accuracy threshold).

In the context of estimating a mean, we can phrase this tradeoff using a few different formulations. Focusing on the number of samples needed, we have

We can estimate any true mean we want up to $\epsilon$ using $O(\log(1/\delta)/\epsilon^2)$ samples, with the probability of being wrong at most $\delta$.

(Here, the suppressed constant depends on the value we're estimating). Shuffling this around to focus on the accuracy of our estimate, we also have

With $n$ samples, we can estimate any true mean we want up to $O(\sqrt{\log(1/\delta) / n})$ error, with the probability of being wrong at most $\delta$.

Oftentimes for convenience we suppress the $\delta$ term and just write $\tilde O(\sqrt{1/n})$, because once $\delta$ has a fairly small number of zeros beyond the decimal point, we typically stop caring exactly how small it is. In general, any squiggle you see means "extra logarithmic terms we don't really care about", and honestly most of the time we will be careless and drop even that too.
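
To make the sample-count formulation above concrete, here's a back-of-the-envelope calculator using the constants from Hoeffding's inequality for a variable bounded in $[0,1]$ (so about $\log(2/\delta)/(2\epsilon^2)$ samples suffice); the accuracies and failure probabilities below are just illustrative choices.

    from math import ceil, log

    def samples_needed(eps, delta):
        # Hoeffding's inequality for a [0, 1]-bounded variable:
        #   P(|sample mean - true mean| >= eps) <= 2 * exp(-2 * n * eps**2),
        # so n >= log(2 / delta) / (2 * eps**2) samples suffice.
        return ceil(log(2 / delta) / (2 * eps ** 2))

    for eps in (0.1, 0.01, 0.001):
        for delta in (0.05, 1e-6):
            print(eps, delta, samples_needed(eps, delta))
    # Halving eps quadruples the sample count; shrinking delta costs only a log factor.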

Multiple mean estimation

If you have $N$ unknown means that you would like to estimate using data, then it follows that with the same number of samples, the probability of at least one estimate being wrong is at most $N\delta$ (by a "union bound"). As you might be able to see, this gets absorbed easily into the $\log(1/\delta)$: asking each estimate to fail with probability at most $\delta/N$ only turns $\log(1/\delta)$ into $\log(N/\delta)$.

Let's restate one of the statements from the last section for multiple mean estimation:

We can estimate any $N$ true means we want up to $\epsilon$ using $O(\log( N / \delta )/\epsilon^2)$ samples, with the probability of being wrong at most $\delta$.

One way to read this is that to estimate $N$ true means you only need $O(\log(N) / \epsilon^2)$ additional samples (compared to estimating a single mean) - this reads like a good deal to me!
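
Continuing the back-of-the-envelope calculation from the previous section (with the same illustrative $\epsilon = 0.01$ and $\delta = 0.05$), here's how the sample count grows with $N$ - the dependence really is only logarithmic.

    from math import ceil, log

    eps, delta = 0.01, 0.05   # the same illustrative choices as before

    for num_means in (1, 10, 1000, 10 ** 6):
        # Union bound: ask each individual estimate to fail with probability at most delta / num_means.
        n = ceil(log(2 * num_means / delta) / (2 * eps ** 2))
        print(num_means, n)
    # Going from 1 mean to a million means increases the sample count by less than a factor of 5.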

The more general takeaway is perhaps that

Typically, datasets are representative of the underlying population/distribution.

Machine learning

We can do the "same" idea for a prediction problem. It may not always be possible to learn the exact function (or there may not be one - $Y$ might not be fully determined by $X$!), but among a family of functions (denoted here as $\cH$) we can define a best one: it's the one that minimizes the mean squared error: $$h^* = \arg\min_{h\in \cH} \E[(Y-h(X))^2]$$

For example, if $\cH = \{\text{any function}\}$, then our "best guess" is $h^*(x) = \E[Y \mid X=x]$.

But again, we have no way of knowing what the mean actually is. So one idea is as follows: we take the data, and we compute a sample estimate of the mean squared error, just like what we've done before: $$\hat \MSE(h) := \hat \E[(Y-h(X))^2] = \frac{1}{n}\sum_{i=1}^n (y^{(i)}-h(x^{(i)}))^2.$$

(It's customary to use a little hat to indicate sample averages, so we'll use it on the expectation $\E$ to denote a sampled estimate of an expected value.)

Now, instead of minimizing $\MSE(h)$, which we don't have access to, we can minimize $\hat\MSE(h)$ as a proxy! So the function we learn is $$\hat h := \arg\min_{h\in \cH} \hat\E[(Y-h(X))^2].$$ (And congratulations - at this point you've basically learnt machine learning.)

Footnote. The academic term for this is Empirical Risk Minimization, where $\hat \MSE(h)$ is precisely the "empirical risk" being minimized.

Another footnote. Why do we use $h$ (and $\cH$) instead of $f$? In learning theory, a candidate function is traditionally also called a hypothesis.
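
Here's empirical risk minimization in miniature, with everything made up for illustration: a small dataset drawn from $Y = 2X$ plus noise, and a finite hypothesis class of functions $x \mapsto ax$ for a handful of candidate slopes $a$. We simply pick the slope with the smallest empirical MSE.

    import random

    random.seed(0)

    # Made-up data: Y = 2X plus Gaussian noise.
    data = []
    for _ in range(50):
        x = random.random()
        data.append((x, 2 * x + random.gauss(0, 0.1)))

    # A finite hypothesis class: h_a(x) = a * x for a handful of candidate slopes a.
    slopes = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]

    def empirical_mse(a):
        # hat-MSE of the hypothesis h(x) = a * x on our dataset.
        return sum((y - a * x) ** 2 for x, y in data) / len(data)

    best_a = min(slopes, key=empirical_mse)
    print(best_a)   # the empirical risk minimizer; should come out as 2.0 here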

The perils of approximate optimization

What could possibly go wrong with this approach? Let me start sowing the seeds of doubt with the following caution from optimization:

Optimizing an approximation of a function may not get you an approximate optimizer of that function!

And here's a specific example. Consider the function $f(x) = x^{100} + ax$, where $a=0$. This function is clearly minimized at $x=0$. However, if instead you minimized $\hat f(x) = x^{100}+ \hat a x$ for a slightly-off estimate $\hat a$ of $a$, then:

  • for $\hat a > 0.001$, the minimizer is roughly $-1$
  • for $\hat a < -0.001$, the minimizer is roughly $1$

You can see this for yourself on Desmos. (Clicking on the graph will show you where the minimizer is.)
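
If you'd rather not click through, here's a crude grid search showing the same jump; the grid and the values of $\hat a$ below are arbitrary choices.

    def minimizer(a_hat, grid_size=60_001):
        # Crude grid search for the minimizer of x**100 + a_hat * x over [-1.5, 1.5].
        xs = [-1.5 + 3.0 * i / (grid_size - 1) for i in range(grid_size)]
        return min(xs, key=lambda x: x ** 100 + a_hat * x)

    for a_hat in (-0.01, -0.001, 0.0, 0.001, 0.01):
        print(a_hat, round(minimizer(a_hat), 3))
    # A tiny error in a_hat moves the minimizer from 0 out to roughly +/- 0.9, i.e. order 1.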

This is perhaps a good reason to start to doubt our initial, simple story that minimizing $\hat \MSE$ is really the right move. At the very least, we now expect some bells and whistles.

And yet, machine learning works

How can we eliminate the above bad example? Does it even matter?

Let's try a different perspective. Suppose we only had a finite $\cH$ - so instead of having continuous parameters, we have finitely many choices. The crazy thought is now:

What if $\hat\MSE(h) \approx \MSE(h)$ for each model $h\in \cH$?

This is not so crazy if you remember that we get a really good deal for multiple mean estimation. I'll skip working through the details here, but we get that $$\sup_{h\in \cH} |\hat \MSE(h) - \MSE(h)| \le c\cdot \sqrt{\frac{\log|\cH|}{n}}$$ holds true except with a probability that decays exponentially in $c$.
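
For the curious, here's a sketch of where that bound comes from, under the simplifying assumption that the squared loss is bounded by some constant $B$ (the general statement needs a bit more care). For a single fixed $h$, Hoeffding's inequality gives $$\Pr\left[\,|\hat \MSE(h) - \MSE(h)| > t\,\right] \le 2\exp\left(-\frac{2nt^2}{B^2}\right).$$ Taking a union bound over all $|\cH|$ hypotheses and choosing $t = c\sqrt{\log|\cH|/n}$, the probability that any of them violates the bound is at most $$2|\cH|\exp\left(-\frac{2c^2\log|\cH|}{B^2}\right) = 2\,|\cH|^{1 - 2c^2/B^2},$$ which indeed decays exponentially (in $c^2$) once $c$ is large enough.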

That's all fine and good, but in many applications we care about, $\cH$ isn't finite! For example, in linear regression, we have $\cH = \{\text{linear functions of }x_1,\ldots,x_d\}$, which is parametrized by $d$ continuous variables (the coefficients). It turns out that if your parametrization is well-behaved enough (and I'm going to deliberately handwave this), you can get a bound that looks like $$\sup_{h\in \cH} |\hat \MSE(h) - \MSE(h)| \le c\cdot \sqrt{\frac{\dim(\cH)}{n}}$$ where $\dim(\cH)$ is the number of variables in your parametrization.

What's next?

Very roughly speaking, we've shown that machine learning works. Yet, this last bound leaves a lot to be desired - there's too much it doesn't take into account. For example:

  • what if there was a superfluous (or almost superfluous) variable - should it really contribute to the "dimension"?
  • what if your distribution of data was weird: maybe $X$ is always equal to a constant $x_0$. Then regardless of the family of functions $\cH$, it should only have an "effective" dimension of 1, since for any model $h$ we only care about its value $h(x_0)$.

The more complicated story that accounts for both of these things will (finally!) involve Rademacher complexity.

Another loose end that we had was Massart's lemma:

(Massart's lemma) For finite $\cS'$, one has $$\cR(\cS') \le C\cdot \left(\max_{z\in \cS'} \|z\|_2\right) \cdot \sqrt{\log |\cS'|}.$$

This kind of echoes what we're seeing here with the finite $\cH$ case, so we're already getting another hint that the Rademacher complexity must be related. We'll also see how we can prove this in the next post.
