A Primer on Distributional Simplicity Bias

A circuit of what we know so far about distributional simplicity bias in neural networks.

Published

May 17, 2024

Here are a few of the questions I care about: How are neural networks learning? What does a neural network “know” if you stop it halfway through training, compared to at the end of training? Is there something about how networks learn that is consistent across MLPs, ResNets and Transformers? I think the distributional simplicity bias (DSB) conjecture helps with thinking about some of these questions.

Summary of Refinetti, Ingrosso, and Goldt (2023)

Conjecture 1 (Distributional Simplicity Bias) Any parametric model, trained on a classification task using SGD, initially exploits lower-order input statistics to classify its inputs. As training progresses, the network uses increasingly higher order statistics (Refinetti, Ingrosso, and Goldt 2023).

The discussion on “statistics” in DSB by Refinetti, Ingrosso, and Goldt (2023) is centered around the idea of cumulants (see the section on moments and cumulants below for the mathematical formulation).

To demonstrate Conjecture 1, Refinetti, Ingrosso, and Goldt (2023) trained ResNet18 models on a few different approximations of the CIFAR-10 dataset, each of which they periodically tested, during model training, for accuracy on the original CIFAR-10 dataset. The CIFAR-10 approximations they used were:

  • GM: Samples from a mixture of Gaussians, where each Gaussian was fitted to one class of CIFAR-10, capturing only the mean and covariance (a minimal sampling sketch follows this list).
  • WGAN: Samples from a WGAN for each class of CIFAR-10
  • CIFAR-5m: Samples from the CIFAR-5m dataset, which consists of 6 million synthetic CIFAR-10-like images generated by the CIFAR-10 Denoising Diffusion generative model.
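To make the GM clone concrete, here is a minimal sketch of the idea: fit a mean and covariance to each class, then sample from the resulting class-conditional Gaussian. This is not the authors’ exact pipeline, and the random array below merely stands in for CIFAR-10; it only illustrates the construction.

Code
import numpy as np

def sample_gm_clone(images, labels, n_per_class, rng=None):
    """Sample a 'GM' clone: one Gaussian per class, fitted to the mean and covariance only."""
    rng = np.random.default_rng() if rng is None else rng
    clone_x, clone_y = [], []
    for c in np.unique(labels):
        cls = images[labels == c]
        cls = cls.reshape(len(cls), -1).astype(np.float64)      # flatten C, H, W into one vector
        mu = cls.mean(axis=0)                                   # first-order statistics
        cov = np.cov(cls, rowvar=False) + 1e-6 * np.eye(cls.shape[1])  # second-order statistics (small ridge keeps it numerically PSD)
        clone_x.append(rng.multivariate_normal(mu, cov, size=n_per_class))
        clone_y.append(np.full(n_per_class, c))
    return np.concatenate(clone_x), np.concatenate(clone_y)

# Toy stand-in for CIFAR-10: random 8x8x3 "images" with two classes
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(200, 8, 8, 3))
labels = rng.integers(0, 2, size=200)
x_gm, y_gm = sample_gm_clone(images, labels, n_per_class=50, rng=rng)
print(x_gm.shape, y_gm.shape)  # (100, 192) (100,)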

The key result of this experiment is that the CIFAR-10 test accuracies of the ResNets trained on the different clones initially collapse onto that of the base model (trained on CIFAR-10 itself): the GM-trained model matches the base model’s test accuracy for about 50 steps of SGD, the WGAN-trained model for about 1,000 steps, and the CIFAR-5m-trained model for about 2,000 steps.

To summarize: the closer the approximate dataset is to the original, the longer the two models’ test accuracies stay close. Beyond that point, the accuracy of the model trained on the approximate dataset plateaus (but does not fall!).

The authors conducted this experiment on various architectures (two-layer MLP, LeNet, ViT, DenseNet, ResNet18), and still found the same pattern. Additionally, these results hold even if you use a pretrained ResNet18 model, as opposed to a randomly initialized one.

Although, regarding architecture, they did see a curious deviation: a two-layer MLP learning the Gaussian approximation does better than a two-layer MLP learning the WGAN approximation.

This suggests that for a simple fully-connected network, having precise first and second moments is more important than the higher-order cumulants generated by the convolutions of the WGAN.

In their discussion, the authors make an important note: while they have full control over the statistics of the mixture-of-Gaussians approximation, they do not have this control over the other approximate datasets. It is therefore reasonable to believe that the third-order cumulants of the WGAN and CIFAR-5m datasets are not equivalent to those of the original dataset. So the accuracy should diverge relatively soon after it does for the mixture-of-Gaussians approximation. But it does not! The authors thus hypothesize that perhaps only a part of the cumulant is relevant for learning.

It’s an open question, then, as to which information about higher order cumulants is deemed important to the neural network.

Summary of Belrose et al. (2024)

Belrose et al. (2024) study DSB through a complementary set of CIFAR-10 experiments, among others.

We train our models on real datasets, then test them throughout training on synthetic data that probe the model’s reliance on statistics of different orders

Unlike Refinetti, Ingrosso, and Goldt (2023), Belrose et al. (2024) train a model on the original CIFAR-10 and test that model on approximations of the CIFAR-10 dataset. Still, their results support the DSB conjecture. They see that the accuracy on first-order samples peaks earlier in training and has a lower maximum than accuracy on second-order samples, followed by second-order hypercube-constrained samples (see Figure 1).

First-order samples are generated with ICS (Belrose et al. 2024) to produce images with matching class means (but no other moments). Second-order samples are drawn from a Gaussian distribution, with the hypercube-constrained variant restricting the pixel range to [0,255].
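As a rough illustration of the second-order probes, here is a minimal sketch: draw samples from a Gaussian matching a class’s mean and covariance, optionally clipping to the pixel hypercube [0, 255]. Clipping is just one simple way to impose the constraint; Belrose et al. (2024) may construct their hypercube-constrained samples differently, and the mu/cov inputs below are assumed to come from a per-class fit.

Code
import numpy as np

def second_order_samples(mu, cov, n, constrain=False, rng=None):
    """Draw samples matching a class's mean and covariance (second-order statistics).
    If constrain=True, clip to the pixel hypercube [0, 255] (a simple stand-in for the
    hypercube constraint; the paper's exact procedure may differ)."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.multivariate_normal(mu, cov, size=n)
    return np.clip(x, 0, 255) if constrain else x

# Example with a made-up 3-pixel "class"
rng = np.random.default_rng(0)
mu = np.array([120.0, 200.0, 30.0])
cov = np.array([[400.0, 50.0, 0.0], [50.0, 900.0, 10.0], [0.0, 10.0, 2500.0]])
print(second_order_samples(mu, cov, n=4, constrain=True, rng=rng))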

Figure 1: Figure from Belrose et al. (2024), test accuracy on synthetic samples of different statistical orders over the course of training

Additionally, the authors run a set of experiments where they “graft the low-order statistics of class B onto class A”. For earlier steps \((2^{4} - 2^{12})\), the models are fooled by both the first-order and the second-order grafting, with the latter leading to a greater percentage of fooling (see Figure 2).

Figure 2: Figure from Belrose et al. (2024), fooling models with grafting statistics

Some statistics: moments and cumulants

In this section, we set up moments, cumulants, and their significance.

The discussion of statistics in Refinetti, Ingrosso, and Goldt (2023) is at the level of dataset classes, i.e. the first-order moment of the “horse” class in the CIFAR-10 dataset.

As such, the random variable should be the dataset class, with \(N\) components (or images). Each component is a size-\(CHW\) vector, obtained by flattening the \(C, H, W\) dimensions. However, the authors are not entirely consistent in their analytic derivations, where \(x\) is a \(D\)-dimensional sample and \(x^i\) is a component (e.g. the x-axis value).

Moments

Given a random variable (r.v.) \(x \in \mathbb{R}^D\), the first-order moments would look like:

\[ \mu^i = \mathbb{E}[x^i], \quad \mu^j = \mathbb{E}[x^j], \quad ..., \]

where \(i, j \in \{1, \dots, D\}\) index the components. A second-order moment would look like:

\[ \mu^{ij} = \mathbb{E}[x^i x^j]. \]

For moments, there is a uniqueness theorem which claims that the collection of all moments is enough to completely characterize a probability distribution (Roberts, Yaida, and Hanin 2022, Ch.1.2). Of course, access to all the moments is often infeasible.

For Gaussian distributions, we can characterize the distribution with just two moments. The higher-order moments of Gaussian distributions can be reformulated as combinations of the covariance (second-order) matrix, thanks to Wick’s theorem (Roberts, Yaida, and Hanin 2022, Ch.1.2). But, if we care about non-Gaussian distributions, we need a better system.
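Before moving on, a quick one-dimensional sanity check of Wick’s theorem: for a zero-mean Gaussian, the fourth moment should equal three times the squared second moment, which we can confirm by sampling.

Code
import numpy as np

rng = np.random.default_rng(0)
sigma = 2.0
x = rng.normal(0.0, sigma, size=1_000_000)

# Wick's theorem (1-D, zero mean): E[x^4] = 3 * (E[x^2])^2
print(np.mean(x**4))           # ~ 48.0 (= 3 * sigma**4)
print(3 * np.mean(x**2)**2)    # ~ 48.0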

Cumulants

Cumulants, \(\kappa\), are another handy statistic, closely related to moments. Like moments, cumulants can be derived from a generating function: the cumulant-generating function (CGF) is the natural log of a random variable’s moment-generating function (MGF).

Careful, the natural log relation does not extend to the statistic itself, i.e. \(\kappa_n\) is not the natural log of \(\mu_n\).

Defining \(\mathcal{A}\) as the algebra of random variables, where \(x_i \in \mathcal{A}\), our cumulants are generated from a multilinear map \(\kappa_n: \mathcal{A}^n \rightarrow \mathbb{R}\). These cumulants can be defined as a function of moments, by the moment-cumulant formula (MCF):

\(x_i\) is either a component of a r.v. or the result of a valid manipulation between components of r.v.s. So, think of \(\mathcal{A}\) as the playground where all r.v.s and their moves hang out.

\[ \mathbb{E}[x^1 \cdot x^2 \cdot ... \cdot x^n] = \sum_{\pi \in P(n)} \kappa_{\pi}(x^1, x^2, ..., x^n), \]

where \(P(n)\) is the set of all partitions of \(n\) components, and \(\pi\) refers to a particular partition. For example, when \(n=2\), one possible partition (on indices) \(\pi\) is \(\{\{1\}, \{2\}\}\).

A handy fact, which you can verify with the equation above, is that the first moment equals the first cumulant. The second cumulant is the covariance (variance, in one dimension), which matches the raw second moment only for zero-mean variables.
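A quick numerical check of these facts, using scipy’s kstat (the same k-statistic routine used in the plotting code later in this section) on a non-Gaussian, nonzero-mean example:

Code
import numpy as np
from scipy.stats import kstat

rng = np.random.default_rng(1)
x = rng.exponential(scale=2.0, size=500_000)   # non-Gaussian, nonzero mean

print(kstat(x, n=1), np.mean(x))           # first cumulant == first (raw) moment
print(kstat(x, n=2), np.var(x, ddof=1))    # second cumulant == variance
print(np.mean(x**2))                       # the raw second moment differs, since the mean is nonzero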

Why Bother?

According to (Roberts, Yaida, and Hanin 2022), we care about cumulants over moments because they offer a good indicator of deviation from Gaussian statistics, i.e., small cumulants signify small deviations from Gaussian statistics. Similarly, Speicher (2023) notes that “almost vanishing” cumulants are a proxy for “almost factorization” of moments, suggesting that small cumulants characterize independence.

While this blog won’t provide a satisfying mathematical formulation for the significance of cumulants, I do attempt a visual:

Code
import numpy as np
from scipy.stats import skewnorm, norm, kstat

def calculate_skewnorm_params(shape):
    # Choose loc and scale so the skew-normal has mean 0 and variance 1,
    # i.e. its first two cumulants match the standard normal's.
    if shape == 0:
        return 0, 1  # Mean 0, Std Dev 1 for the normal distribution
    delta = shape / np.sqrt(1 + shape**2)
    scale = np.sqrt(1 / (1 - (2 * delta**2 / np.pi)))
    loc = -1 * scale * delta * np.sqrt(2 / np.pi)
    return loc, scale

x = np.linspace(-5, 5, 1000)
norm_pdf = norm.pdf(x)
# First four cumulants of a standard normal, estimated from 100,000 samples
norm_cumulants = [kstat(norm.rvs(size=100_000), n=i) for i in range(1, 5)]

skewnorm_data = {}
skewness_values = [-0.5, -1, -1.5, 0, 0.5, 1, 1.5]
for skewness in skewness_values:
    loc, scale = calculate_skewnorm_params(skewness)
    # Sample the standardized skew-normal and estimate its first four cumulants
    skewnorm_samples = skewnorm.rvs(a=skewness, loc=loc, scale=scale, size=100_000) if skewness != 0 else norm.rvs(size=100_000)
    skewnorm_pdf = skewnorm.pdf(x, a=skewness, loc=loc, scale=scale) if skewness != 0 else norm_pdf
    skewnorm_cumulants = [abs(kstat(skewnorm_samples, n=i)) for i in range(1, 5)]
    skewnorm_data[skewness] = {"x": x.tolist(), "pdf": skewnorm_pdf.tolist(), "cumulants": skewnorm_cumulants}

# Convert to plain lists so the page's interactive plot can access them
norm_pdf = norm_pdf.tolist()
x = x.tolist()

Ultimately, this is kind of a contrived (and bad) example. It lets you manipulate the skewness (which is tied to the third moment), showing that the third and fourth cumulants change. Is this useful? Meh. Did I spend a lot of time on it? Yes.

The “nearly” Gaussian distributions that Roberts, Yaida, and Hanin (2022) talk about should have “all odd moments and all [odd cumulants] vanish”, which is clearly not the case for our skewed distributions.

We plot the probability density function (PDF) and compute cumulants by sampling our random variable 100,000 times. Skewing the distribution away from the zero-mean normal, we see that our first and second cumulants barely budge, yet the higher-order cumulants grow with the skewness, empirically confirming that they measure how “un-Gaussian” our distribution is.

Next, we walk through some of the analytic calculations that justify the distributional simplicity bias.

Setting up the gradient flow

To avoid littering in-text citations, I’m disclaiming here that the entire walkthrough below is a reformulation of work done by Refinetti, Ingrosso, and Goldt (2023). I’m just filling in the gaps and rethinking the presentation.

Deriving gradient flow

Returning to our random variable \(x\), let’s create a dataset along with its labels \(y\), which are always either \(\pm 1\).

Assume a single neuron with weight vector \(w\) and an activation function \(\sigma\). Then the output of the linear transformation is \(\lambda = \dfrac{w_i x^i}{\sqrt{D}}\) (summation over the repeated index is implied), the output of our network is \(\hat{y} = \sigma(\lambda)\), and our loss function is \(\ell = (\hat{y} - y)^2\).

[Figure: a single-neuron network, with an input layer \(x^1, \dots, x^D\), a node computing \(\lambda\), and an output node computing \(\sigma(\lambda)\).]

A single neuron network

We obtain the gradient flow for our neuron by starting with the loss function \[ \begin{align*} \ell &= \left( \hat{y} - y \right)^2 \\[1ex] \frac{\partial \ell}{\partial w_i} &= 2 \left( \sigma(\lambda) - y \right) \sigma'(\lambda) \, \frac{x^i}{\sqrt{D}} \\[1ex] &= \frac{2}{\sqrt{D}} \left( \sigma(\lambda) - y \right) \sigma'(\lambda) \, x^i \\[1ex] &= \mathbb{E}_{j \sim \mathcal{D}} \left[ \left( \sigma(\lambda_j) - y_j \right) \sigma'(\lambda_j) x^i_j \right], \end{align*} \]

where on the second line, we take the first derivative of the objective. We then collect the coefficients and introduce an expectation over all of our dataset \(\mathcal{D}\). Following the original derivation, we drop the scaling coefficients in the final line.
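A quick sanity check of this gradient, using tanh as an example activation (any smooth \(\sigma\) would do; the choice is just for illustration) and a central finite-difference comparison:

Code
import numpy as np

rng = np.random.default_rng(0)
D = 5
x = rng.normal(size=D)
y = 1.0
w = rng.normal(size=D)
sigma = np.tanh
dsigma = lambda z: 1.0 - np.tanh(z) ** 2   # derivative of tanh

def loss(w):
    lam = w @ x / np.sqrt(D)
    return (sigma(lam) - y) ** 2

# Analytic gradient: 2 (sigma(lambda) - y) sigma'(lambda) x^i / sqrt(D)
lam = w @ x / np.sqrt(D)
grad_analytic = 2 * (sigma(lam) - y) * dsigma(lam) * x / np.sqrt(D)

# Central finite differences agree with the analytic expression
eps = 1e-6
grad_fd = np.array([(loss(w + eps * e) - loss(w - eps * e)) / (2 * eps) for e in np.eye(D)])
print(np.allclose(grad_analytic, grad_fd, atol=1e-7))   # True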

Taylor expansions

We can approximate our gradient flow with a Taylor expansion of our output

More precisely, we do a Maclaurin expansion.

\[ \begin{align*} \sigma(\lambda) &= \sigma(0) + \sigma^{(1)}(0) \lambda + \dfrac{\sigma^{(2)}(0) \lambda^2}{2!} + ..., \end{align*} \]

where \(\sigma^{(k)}\) refers to the \(k\)-th derivative of \(\sigma\). We also have \(\sigma'(\lambda)\) in our gradient flow equation, which we obtain by taking the first derivative, w.r.t. \(\lambda\), of the expansion above

\[ \begin{align*} \sigma'(\lambda) &= 0 + \sigma^{(1)}(0) + \dfrac{2 \sigma^{(2)}(0) \lambda}{2!} + ..., \end{align*} \]

To clarify, \(\sigma'\) and \(\sigma^{(1)}\) are indeed identical.

To simplify the notation, we introduce a term \(\beta_k = \dfrac{\sigma^{(k)}(0)}{k!}\), allowing us to rewrite our expansions as \[ \sigma(\lambda) = \sum_{k=0}^\infty \beta_k \lambda^k, \quad \sigma'(\lambda) = \sum_{k=0}^\infty \tilde{\beta}_{k+1} \lambda^k, \]

where \(\tilde{\beta}_k = k\beta_k\).
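To see these expansions in action, here is a small check using tanh as an example activation (an assumption for illustration; its Maclaurin coefficients are well known), truncated after a handful of terms:

Code
import numpy as np

# Maclaurin coefficients beta_k = sigma^(k)(0) / k! for sigma = tanh
beta = {0: 0.0, 1: 1.0, 2: 0.0, 3: -1/3, 4: 0.0, 5: 2/15, 6: 0.0, 7: -17/315}
beta_tilde = {k: k * b for k, b in beta.items()}   # tilde(beta)_k = k * beta_k

lam = 0.3
sigma_series = sum(b * lam**k for k, b in beta.items())
dsigma_series = sum(beta_tilde[k + 1] * lam**k for k in range(len(beta) - 1))

print(sigma_series, np.tanh(lam))            # both ~ 0.2913
print(dsigma_series, 1 - np.tanh(lam)**2)    # both ~ 0.9151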

Plug and chug

Recalling our gradient flow equation \[ \begin{align*} \frac{\partial \ell}{\partial w_i} &= \mathbb{E} \left[ \left( \sigma(\lambda) - y \right) \sigma'(\lambda) x^i \right] \\[1ex] &= \mathbb{E} \left[ \sigma(\lambda) \sigma'(\lambda) x^i - y \sigma'(\lambda) x^i \right], \end{align*} \]

we focus on the first term within the gradient flow, thanks to the linearity of expectation. We also drop the expectation notation for clarity. We plug in the expansions for \(\sigma(\lambda)\) and \(\sigma'(\lambda)\), giving us

\[ \begin{align*} \mathbb{E} \left[ \sigma(\lambda) \sigma'(\lambda) x^i \right] &= \mathbb{E} \left[ \left( \sum_{k=0}^\infty {\beta}_{k} \lambda^k \right) \left( \sum_{k=0}^\infty \tilde{\beta}_{k+1} \lambda^k \right) x^i \right] \\[1ex] &= \mathbb{E} \left[ \sum_{k=0}^\infty \sum_{n=0}^k {\beta}_{n} \lambda^n \tilde{\beta}_{k-n+1} \lambda^{k-n} x^i \right] \\[1ex] &= \mathbb{E} \left[ \sum_{k=0}^\infty \sum_{n=0}^k {\beta}_{n} \tilde{\beta}_{k-n+1} \lambda^{k} x^i \right] \\[1ex] &= \sum_{k=0}^\infty \mathbb{E} \left[ \sum_{n=0}^k {\beta}_{n} \tilde{\beta}_{k-n+1} \lambda^{k} x^i \right], \\[1ex] \end{align*} \]

where in the second line, we take advantage of the Cauchy product formulation for reindexing. Next, we collect the \(\lambda\) terms to simplify. In the final line, we use the linearity of expectation to move the expectation inside the sum. Repeating this for the second term gives us

Cauchy’s name is on the Eiffel tower. Pretty sick.

\[ \begin{align*} \mathbb{E} \left[ -y \sigma'(\lambda) x^i \right] &= \mathbb{E} \left[ -y \left( \sum_{k=0}^\infty \tilde{\beta}_{k+1} \lambda^k \right) x^i \right] \\[1ex] &= - \sum_{k=0}^\infty \mathbb{E} \left[ \tilde{\beta}_{k+1} y \lambda^k x^i \right], \end{align*} \]

resulting in the formulation \[ \begin{align*} \frac{\partial \ell}{\partial w_i} &= \sum_{k=0}^\infty \mathbb{E} \left[ \sum_{n=0}^k {\beta}_{n} \tilde{\beta}_{k-n+1} \lambda^{k} x^i - \tilde{\beta}_{k+1} y \lambda^k x^i \right] \\[1ex] &= \sum_{k=0}^\infty \mathbb{E} \left[ \left( \sum_{n=0}^k {\beta}_{n} \tilde{\beta}_{k-n+1} - \tilde{\beta}_{k+1} y \right) \lambda^k x^i \right]. \end{align*} \]

To simplify this, we use \(\gamma_k = \sum_{n=0}^k {\beta}_{n} \tilde{\beta}_{k-n+1}\), which, for a couple of \(k\)s looks like

\[ \begin{align*} \gamma_0 &= \beta_0 \tilde{\beta}_1 &= \beta_0 \beta_1, \\[1ex] \gamma_1 &= \beta_0 \tilde{\beta}_2 + \beta_1 \tilde{\beta}_1 &= 2 \beta_0 {\beta}_2 + \beta_1^2. \\[1ex] \end{align*} \]

Thus, we have

You probably need this: \[ \begin{align*} \lambda^k &= \left( \dfrac{w_i x^i}{\sqrt{D}} \right)^k.\\ \tilde{\beta}_k &= k\beta_k. \\ \gamma_k &= \sum_{n=0}^k {\beta}_{n} \tilde{\beta}_{k-n+1}. \end{align*} \]

\[ \frac{\partial \ell}{\partial w_i}= \sum_{k=0}^\infty \mathbb{E} \left[ \left( \gamma_k - \tilde{\beta}_{k+1} y \right) \lambda^k x^i \right]. \]

This formulation is sort of cool. The authors note that

[the] gradient flow updates of the weight have thus two contributions: the first, proportional to \(\gamma\), depends only on the inputs, while the second, proportional to \(\tilde{\beta}\), depends on the product of inputs and their label.

Roughly speaking, the term \(\lambda^k x^i\) looks like we are transforming the original input with the power of an already transformed input. This is then distributed between two coefficients, which grow at different rates, one of which is scaled by the class label. Regardless of the input’s label, it still contributes to the weights(?)
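To convince yourself that this reorganization is consistent, here is a small numerical check. It again uses tanh as an example activation, truncates the sum at order \(K\), and keeps the inputs small so the truncated series converges; all of these choices are illustrative assumptions, not part of the original derivation.

Code
import numpy as np

rng = np.random.default_rng(0)
D, N, K = 4, 200_000, 9          # K = truncation order of the sum over k

# Maclaurin coefficients of tanh, beta_k = sigma^(k)(0)/k!, up to order 10
beta = [0.0, 1.0, 0.0, -1/3, 0.0, 2/15, 0.0, -17/315, 0.0, 62/2835, 0.0]
bt = [k * beta[k] for k in range(len(beta))]                                  # tilde(beta)_k = k * beta_k
gamma = [sum(beta[n] * bt[k - n + 1] for n in range(k + 1)) for k in range(K + 1)]

w = rng.normal(size=D)
x = 0.2 * rng.normal(size=(N, D))     # small inputs keep |lambda| well inside the series' radius
y = rng.choice([-1.0, 1.0], size=N)
lam = x @ w / np.sqrt(D)

i = 0                                 # look at a single weight component
direct = np.mean((np.tanh(lam) - y) * (1 - np.tanh(lam) ** 2) * x[:, i])
series = sum(np.mean((gamma[k] - bt[k + 1] * y) * lam ** k * x[:, i]) for k in range(K + 1))
print(direct, series)                 # the two agree closely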

Analytical walkthrough

Here, it is a good idea to note that the authors are working with a toy classification dataset, where the labels are binary, \(y = \pm 1\). The first two components are Cartesian coordinates, equally likely to be drawn from either of two classes of “rectangles”, which are linearly separable by a line parallel to the y-axis.
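The paper’s exact construction isn’t reproduced here, but a minimal guess at such a “rectangles” dataset, based on the description above, might look like the sketch below (all parameters are made up).

Code
import numpy as np

def sample_rectangles(n, gap=0.5, width=1.0, height=2.0, rng=None):
    """A minimal guess at the 2-D 'rectangles' toy dataset: two axis-aligned rectangles
    separated along the x-axis, with labels y = +/-1 drawn with equal probability."""
    rng = np.random.default_rng() if rng is None else rng
    y = rng.choice([-1.0, 1.0], size=n)
    x1 = y * (gap / 2 + rng.uniform(0.0, width, size=n))   # sign of the x-coordinate set by the class
    x2 = rng.uniform(-height / 2, height / 2, size=n)      # y-coordinate distributed identically for both classes
    X = np.stack([x1, x2], axis=1)
    return X - X.mean(axis=0), y        # center the data so E[x^i] = 0, as assumed later

X, y = sample_rectangles(10_000, rng=np.random.default_rng(0))
print(X[y == 1].mean(axis=0), X[y == -1].mean(axis=0))     # class means differ essentially only along the x-axis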

Zeroth order

Referring back to our reformulated gradient, if we set \(k=0\), our gradient is

\[ \frac{\partial \ell}{\partial w_i}= \mathbb{E} \left[ x^i \gamma_0 - x^i {\beta}_{1} y \right]. \]

The authors assume a normalized dataset, such that \(\mathbb{E}[x^i] = 0\). As a result, we are left with the term

\[ \frac{\partial \ell}{\partial w_i}= - {\beta}_{1} \mathbb{E} \left[ x^i y \right]. \]

This simplifies to

\[\boxed{ \frac{\partial \ell}{\partial w_i}= - {\beta}_{1} \left( \kappa_+^i - \kappa_-^i \right) } \]

because we can separate our samples into their respective classes and then take the expectation. This gives us a “mean-based” classifier, which simply draws a boundary perpendicular to the difference vector between two class means.

Recall that the first moment is identical to the first cumulant, so we can use \(\kappa^i\) and \(\mu^i\) interchangeably. This does not extend to higher orders.
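To make the “mean-based classifier” intuition concrete, here is a small sketch on two Gaussian clouds (a stand-in for the toy dataset, not the paper’s setup): a weight vector aligned with the class-mean difference \(m\), which is exactly the direction the zeroth-order gradient pushes \(w\) toward, already classifies well.

Code
import numpy as np

rng = np.random.default_rng(0)
n = 5_000
y = rng.choice([-1.0, 1.0], size=n)
X = rng.normal(size=(n, 2)) + np.outer(y, [1.5, 0.0])   # first coordinate carries the class signal
X = X - X.mean(axis=0)                                  # normalize so E[x^i] = 0

# "Mean-based" classifier: project onto m^i = kappa_+^i - kappa_-^i,
# i.e. a boundary perpendicular to the difference of class means
m = X[y == 1].mean(axis=0) - X[y == -1].mean(axis=0)
y_hat = np.sign(X @ m)
print((y_hat == y).mean())   # high accuracy from first-order statistics alone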

First order

Unrolling our summation up to \(k=1\), and using the result from above, we obtain

\[ \frac{\partial \ell}{\partial w_i}= - {\beta}_{1} \left( \kappa_+^i - \kappa_-^i \right) + \mathbb{E} \left[ \left( \gamma_1 - \tilde{\beta}_{2} y \right) \lambda x^i \right]. \]

Here, the authors substitute \(\lambda = w_j x^j\), introducing correlations with an additional component

Actually, they don’t explain or justify this move. I don’t know why this is the case.

\[ \begin{align*} \frac{\partial \ell}{\partial w_i} &= - {\beta}_{1} \left( \kappa_+^i - \kappa_-^i \right) + \mathbb{E} \left[ w_j \left( \gamma_1 x^i x^j - \tilde{\beta}_{2} y x^i x^j \right) \right] \\[1ex] &= - {\beta}_{1} \left( \kappa_+^i - \kappa_-^i \right) + w_j \left( \gamma_1 \mu^{ij} - \tilde{\beta}_{2} \left(\mu^{ij}_+ - \mu^{ij}_- \right) \right) \\[1ex] &\boxed{= - {\beta}_{1} m^i + w_j \gamma_1 \mu^{ij}} \end{align*} \]

where the final term vanishes in the last line because the authors claim the covariance matrices of the two classes in the toy dataset are the same, i.e. \(\mu_+^{ij} = \mu_{-}^{ij}\). We also introduce \(m^i = \kappa_+^i - \kappa_-^i\), which represents the class mean difference.

Here, the authors pause to ask why \(\mu^{ij}\) is meaningful to the linear classifier, since the discriminatory information provided by the statistic \(\mu^{ij}\) is not immediately obvious. But, along the lines of Fisher’s linear discriminant analysis (LDA), if you expand the second moment over the two (equally likely) classes via the law of total expectation, you obtain \[ \mu^{ij} = \frac{1}{2} ( \mu_+^{ij} + \mu_-^{ij} ). \]

Recalling the definition of covariance \[ \begin{align*} cov(x^i, x^j) &= \mu^{ij} - \mu^i \mu^j \\[1ex] \mu^{ij} &= cov(x^i, x^j) + \mu^i \mu^j \end{align*} \]

Oh, and it’s handy that the second cumulant \(\kappa^{ij} = cov(x^i, x^j)\)

Expanding the first term in the total expectation

\[ \begin{align*} \mu_+^{ij} &= \kappa_+^{ij} + \mu_+^i\mu_+^j \\[1ex] &= \kappa_+^{ij} + \left(\mu^i + \frac{1}{2} m^i \right) \left(\mu^j + \frac{1}{2} m^j\right) \\[1ex] &= \kappa_+^{ij} + \mu^i \mu^j + \frac{1}{2} (\mu^i m^j + \mu^j m^i) + \frac{1}{4} m^i m^j. \end{align*} \]

We make use of the properties \[ \begin{aligned} \mu_+^i &= \mu^i + \frac{1}{2} m^i \\ \mu_-^i &= \mu^i - \frac{1}{2} m^i. \end{aligned} \] These come from a bit of algebraic manipulation of the law of total expectation.

Doing a similar expansion for the second term \(\mu_-^{ij}\), we have

\[ \mu_-^{ij} = \kappa_-^{ij} + \mu^i \mu^j - \frac{1}{2} (\mu^i m^j + \mu^j m^i) + \frac{1}{4} m^i m^j. \]

Adding the two expressions together, we get:

\[ \begin{aligned} \mu^{ij} &= \frac{1}{2} (\mu_+^{ij} + \mu_-^{ij}) \\[1ex] &= \frac{1}{2} \left[ (\kappa_+^{ij} + \kappa_-^{ij}) + 2 \mu^i \mu^j + \frac{1}{2} m^i m^j \right] \\[1ex] &= \frac{1}{2}(\kappa_+^{ij} + \kappa_-^{ij}) + \mu^i \mu^j + \frac{1}{4} m^i m^j \end{aligned} \]

Using the language from LDA, the first term is a “within-class” covariance, while the last term provides the “between-class” contribution; the \(\mu^i \mu^j\) term vanishes for the normalized, zero-mean dataset.
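This decomposition is easy to verify numerically. The sketch below builds a balanced two-class dataset with arbitrary (made-up) class means and a global offset, and checks that the second-moment matrix equals the average within-class covariance plus the \(\mu^i \mu^j\) and \(\frac{1}{4} m^i m^j\) terms.

Code
import numpy as np

rng = np.random.default_rng(0)
n = 10_000                                  # exactly balanced classes, so mu = (mu_+ + mu_-)/2 holds exactly
y = np.repeat([-1.0, 1.0], n // 2)
X = rng.normal(size=(n, 3)) + np.outer(y, [2.0, -1.0, 0.5]) + np.array([1.0, 3.0, -2.0])

mu = X.mean(axis=0)                                             # mu^i
m = X[y == 1].mean(axis=0) - X[y == -1].mean(axis=0)            # m^i = kappa_+^i - kappa_-^i
second_moment = X.T @ X / n                                     # mu^{ij} = E[x^i x^j]

within = 0.5 * (np.cov(X[y == 1], rowvar=False, bias=True)
                + np.cov(X[y == -1], rowvar=False, bias=True))  # (kappa_+^{ij} + kappa_-^{ij}) / 2
reconstructed = within + np.outer(mu, mu) + 0.25 * np.outer(m, m)
print(np.max(np.abs(second_moment - reconstructed)))            # essentially zero (exact up to floating point)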

Second order

Like before, for \(k=2\), we unroll our sum by another value and reuse solutions from previous orders, resulting in

\[ \begin{align*} \frac{\partial \ell}{\partial w_i} &= - {\beta}_{1} m^i + w_j \gamma_1 \mu^{ij} + \mathbb{E} \left[ w_j w_k \left( \gamma_2 x^i x^j x^k - \tilde{\beta}_{3} y x^i x^j x^k \right) \right] \\[1ex] &= - {\beta}_{1} m^i + w_j \gamma_1 \mu^{ij} + w_j w_k \left( \gamma_2 \mu^{ijk} - \tilde{\beta}_{3} \left(\mu^{ijk}_+ - \mu^{ijk}_- \right) \right) \\[1ex] &= - {\beta}_{1} m^i + w_j \gamma_1 \mu^{ij} - w_j w_k \tilde{\beta}_{3} \left(\mu^{ijk}_+ - \mu^{ijk}_- \right), \end{align*} \]

where we remove a term in the final line because we know that \(\mu^{ijk} = 0\). Focusing on the difference between the third order moments, we can rewrite that as cumulants using the moment-to-cumulant formula.

\[ \begin{align*} \mu^{ijk}_+ - \mu^{ijk}_- &= \left( \kappa_+^{ijk} + \kappa_+^{i}\kappa_+^{jk}[3] + \kappa_+^{i}\kappa_+^{j}\kappa_+^{k} \right) - \left( \kappa_-^{ijk} + \kappa_-^{i}\kappa_-^{jk}[3] + \kappa_-^{i}\kappa_-^{j}\kappa_-^{k} \right). \end{align*} \]

In the equation above, we introduce a bracket notation to condense all the permutations of that type into a single term. From earlier, we also know that the covariance matrices of the two classes are equivalent, so we can condense those bracketed terms further.

\[ \begin{align*} \mu^{ijk}_+ - \mu^{ijk}_- &= \left( \kappa_+^{ijk} - \kappa_-^{ijk} \right) + \left( \kappa_+^{i} - \kappa_-^{i} \right) \kappa_\pm^{jk}[3] + \kappa_+^{i}\kappa_+^{j}\kappa_+^{k} - \kappa_-^{i}\kappa_-^{j}\kappa_-^{k} \end{align*} \]

In fact, we have encountered \(\kappa_+^i - \kappa_-^i = m^i\) when dealing with the zeroth order classifier. Plugging this back into our unrolled sum

\[ \boxed{ \begin{aligned} \frac{\partial \ell}{\partial w_i} &= - {\beta}_{1} m^i + w_j \gamma_1 \mu^{ij} \\ &\quad - w_j w_k \tilde{\beta}_{3} \left( \left( \kappa_+^{ijk} - \kappa_-^{ijk} \right) + m^i \kappa_\pm^{jk}[3] + \kappa_+^{i}\kappa_+^{j}\kappa_+^{k} - \kappa_-^{i}\kappa_-^{j}\kappa_-^{k} \right) \end{aligned} } \]

we get the gradient up to the second order.

Third order

⚠️ Proceed with caution! ⚠️

Applying the same initial approach for \(k=3\), we direct our focus only to the new term

\[ \begin{aligned} \mathbb{E} \left[ w_j w_k w_l \left( \gamma_3 x^i x^j x^k x^l - \tilde{\beta}_{4} y x^i x^j x^k x^l \right) \right] &= w_j w_k w_l \gamma_3 \mu^{ijkl} , \end{aligned} \]

where the \(\tilde{\beta}_4\) term vanishes because \(\kappa_+^{ijkl} - \kappa_-^{ijkl} = 0\).

At this point, the authors make the claim that, similar to the first order case, \(\mu^{ijkl}\) can be split into contributions from a within class fourth order moment \(\mu^{ijkl}_w\) and contributions from other lower orders

\[ \mu^{ijkl} \approx \mu^{ijkl}_w + \tau, \]

and that \(\mu^{ijkl}_w\) should be split into two components

\[ \mu^{ijkl}_w = \kappa^{ijkl}_w + \mu_{w,G}^{ijkl}, \]

where the latter does not play a role in the current classifier, which is operating on the “rectangular” dataset.

Last updated

2024-05-17 07:01:55 UTC

Corrections

If you see mistakes or want to suggest changes, please create an issue on the source repository. Suggestions are appreciated!

Reuse

Generated text and figures are licensed under Creative Commons Attribution CC BY 4.0. The raw article and its contents are available on GitHub, unless otherwise noted. Figures that have been reused from other sources don’t fall under this license and can be recognized by a note in their caption: ‘Figure from …’

References

Belrose, Nora, Quintin Pope, Lucia Quirke, Alex Mallen, and Xiaoli Fern. 2024. “Neural Networks Learn Statistics of Increasing Complexity.” arXiv Preprint arXiv:2402.04362.
Refinetti, Maria, Alessandro Ingrosso, and Sebastian Goldt. 2023. “Neural Networks Trained with SGD Learn Distributions of Increasing Complexity.” In International Conference on Machine Learning, 28843–63. PMLR.
Roberts, Daniel A., Sho Yaida, and Boris Hanin. 2022. The Principles of Deep Learning Theory. Cambridge University Press. https://arxiv.org/abs/2106.10165.
Speicher, Roland. 2023. “High-Dimensional Analysis: Random Matrices and Machine Learning.” Saarland University; Lecture Notes. https://rolandspeicher.files.wordpress.com/2023/08/hda_rmml.pdf.

Citation

BibTeX citation:
@online{aswani2024,
  author = {Aswani, Nishant},
  title = {A {Primer} on {Distributional} {Simplicity} {Bias}},
  date = {2024-05-17},
  url = {https://nishantaswani.com/articles/dsb.html},
  langid = {en}
}
For attribution, please cite this work as:
Aswani, Nishant. 2024. “A Primer on Distributional Simplicity Bias.” May 17, 2024. https://nishantaswani.com/articles/dsb.html.