Comparing the neural representations of two CNNs trained on different continual learning algorithms
May 12, 2023
d3 = require("d3@7")
import {Legend, Swatches} from "@d3/color-legend"
import {pack} from "@esperanc/range-input-variations"
import {testMax, testMin} from "@niniack/forgetting-representations"
import {viewof heatmap_controls} from "@niniack/forgetting-representations"
import {heatmaps} from "@niniack/forgetting-representations"
import {viewof cam_controls_one} from "@niniack/forgetting-representations"
import {viewof cam_controls_two} from "@niniack/forgetting-representations"
import {cams} from "@niniack/forgetting-representations"
html`
<style>
svg {
display: block;
margin: auto;
align-items: center;
}
</style>
`
A key concept in continual learning (CL), catastrophic forgetting has been of interest since the 1990s. Simply put, it occurs when machine learning models forget knowledge from previous tasks as they learn new ones. In recent literature, it is often discussed in the context of deep networks, usually with the aim of developing training algorithms or architectures that overcome the issue. Comparatively little work has been done on analyzing how catastrophic forgetting manifests itself, or on probing how existing algorithms affect model weights.
Two papers that piqued my interest (Ramasesh, Dyer, and Raghu 2020; Davari et al. 2022) approach the problem by looking at the neural representations of models. They ask the central question of “how does catastrophic forgetting affect the hidden representations of deep networks?” It is an important question because internal representations are often also used for explainability. Further, these internal representations are what eventually generate the output, so it is likely worth taking a look at how they evolve over the course of CL training regimes.
Inspired by these works, this blog post conducts a small-scale comparison between neural representations of models trained on different CL methods. I rely on a generalized framework to employ three distance metrics to compare two popular CL baselines, Learning without Forgetting (Li and Hoiem 2017) and Memory-Aware Synapses (Aljundi et al. 2018).
Before we get to the experiments, let’s first dive into some of the math!
Williams et al. (2022) lay out a framework that generalizes two comparison metrics: canonical correlation analysis (CCA) and the orthogonal Procrustes distance.
In the following derivations, we assume two sets of neural activations, $X \in \mathbb{R}^{n \times p}$ and $Y \in \mathbb{R}^{n \times p}$, where $n$ is the number of input samples and $p$ is the number of neurons in the layer being compared.
The orthogonal Procrustes distance aims to find an orthogonal matrix $Q \in \mathbb{R}^{p \times p}$, with $Q^\top Q = I$, that best aligns $Y$ with $X$:

$$\min_{Q^\top Q = I} \lVert X - YQ \rVert_F,$$

which is equivalent to minimizing the squared Frobenius norm

$$\lVert X - YQ \rVert_F^2 = \operatorname{tr}\left(X^\top X\right) + \operatorname{tr}\left(Y^\top Y\right) - 2\operatorname{tr}\left(X^\top Y Q\right),$$

i.e., to maximizing $\operatorname{tr}\left(X^\top Y Q\right)$ over orthogonal $Q$.
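As a concrete reference, here is a minimal NumPy sketch of that computation (an illustration, not the code used for these experiments). The optimal $Q$ has a closed form via the SVD of $X^\top Y$.

import numpy as np

def procrustes_distance(X, Y):
    # Orthogonal Procrustes distance between activation matrices X, Y (n samples x p neurons).
    # The optimal rotation maximizes tr(X^T Y Q) and is Q = V U^T, where U S V^T = SVD(X^T Y).
    U, _, Vt = np.linalg.svd(X.T @ Y)
    Q = (U @ Vt).T
    return np.linalg.norm(X - Y @ Q, ord="fro")

# Sanity check: the distance is ~0 when Y is just a rotated copy of X.
X = np.random.randn(1000, 64)
R, _ = np.linalg.qr(np.random.randn(64, 64))   # random orthogonal matrix
print(procrustes_distance(X, X @ R))           # ~1e-12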
Now, we move on to another popular metric!
CCA seeks to find projection vectors $w_x, w_y \in \mathbb{R}^{p}$ that maximize the correlation between the projected activations (assuming the activations are mean-centered):

$$\max_{w_x, w_y} \operatorname{corr}\left(X w_x,\ Y w_y\right).$$

Correlations should be scale-invariant: two vectors should still be perfectly correlated if one of them is simply a scaled version of the other. So, we also impose the constraint that

$$\lVert X w_x \rVert_2 = \lVert Y w_y \rVert_2 = 1.$$

We can do this because dividing a projection such as $X w_x$ by its norm does not change the correlation. By doing so, we can rewrite the objective above to the following:

$$\max_{w_x, w_y} w_x^\top X^\top Y w_y \quad \text{s.t.} \quad \lVert X w_x \rVert_2 = \lVert Y w_y \rVert_2 = 1.$$
To avoid reducing the dimensionality of the neural representations, we can extend this problem to find weight matrices $W_X, W_Y \in \mathbb{R}^{p \times p}$:

$$\max_{W_X, W_Y} \operatorname{tr}\left(W_X^\top X^\top Y W_Y\right) \quad \text{s.t.} \quad W_X^\top X^\top X W_X = W_Y^\top Y^\top Y W_Y = I,$$

where the constraints are the matrix analogue of the unit-norm condition above. But, that's really equivalent to:

$$\min_{W_X, W_Y} \lVert X W_X - Y W_Y \rVert_F^2 \quad \text{s.t.} \quad W_X^\top X^\top X W_X = W_Y^\top Y^\top Y W_Y = I,$$

since the constraints fix the norms of $X W_X$ and $Y W_Y$.
Although this problem looks very different from the Procrustes distance, the two problems can actually be written in an equivalent form.
To convert CCA into the Procrustes problem, we can employ a change of variables:

$$\tilde{X} = X \left(X^\top X\right)^{-1/2}, \quad \tilde{Y} = Y \left(Y^\top Y\right)^{-1/2}, \quad U_X = \left(X^\top X\right)^{1/2} W_X, \quad U_Y = \left(Y^\top Y\right)^{1/2} W_Y,$$

such that the problem is now:

$$\min_{U_X, U_Y} \lVert \tilde{X} U_X - \tilde{Y} U_Y \rVert_F^2 \quad \text{s.t.} \quad U_X^\top U_X = U_Y^\top U_Y = I.$$

We can employ another change of variables to further simplify the form of the problem:

$$Q = U_Y U_X^\top,$$

so that it takes the form:

$$\min_{Q^\top Q = I} \lVert \tilde{X} - \tilde{Y} Q \rVert_F^2.$$

The two problems are of the same form; the only difference is that CCA solves the problem on the whitened versions, $\tilde{X}$ and $\tilde{Y}$, of the neural activations.
Returning back to the original CCA problem, before any change of variables, Williams et al. (2022) show that by modifying the constraints to introduce a hyperparameter $\alpha \in [0, 1]$,

$$W_X^\top \left(\alpha I + (1 - \alpha) X^\top X\right) W_X = W_Y^\top \left(\alpha I + (1 - \alpha) Y^\top Y\right) W_Y = I,$$

we obtain a family of metrics that interpolates between the two problems. Here, if $\alpha = 0$ we recover the CCA problem, and if $\alpha = 1$ we recover the orthogonal Procrustes problem.
This is the generalized framework I use to compare the neural representations in the experiments! With no particular intuition, I pick three values of $\alpha$, which give the three distance metrics used below.
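To make the interpolation concrete, here is a small NumPy sketch of how such an $\alpha$-regularized distance could be computed: partially whiten each set of activations according to the constraint above, then solve the Procrustes problem between the whitened versions. This is an illustration under my own naming (inv_sqrt, partial_whiten, regularized_distance) and assumes mean-centered activations; the exact normalization in the paper and in my experiments may differ.

import numpy as np

def inv_sqrt(S, eps=1e-10):
    # Inverse matrix square root of a symmetric PSD matrix via its eigendecomposition.
    vals, vecs = np.linalg.eigh(S)
    return vecs @ np.diag(1.0 / np.sqrt(np.maximum(vals, eps))) @ vecs.T

def partial_whiten(X, alpha):
    # Transformation implied by the constraint W^T (alpha*I + (1-alpha)*X^T X) W = I.
    # alpha = 0 fully whitens the activations (CCA); alpha = 1 leaves them unchanged (Procrustes).
    S = alpha * np.eye(X.shape[1]) + (1 - alpha) * (X.T @ X)
    return X @ inv_sqrt(S)

def regularized_distance(X, Y, alpha):
    # Procrustes distance between the partially whitened activations.
    Xt, Yt = partial_whiten(X, alpha), partial_whiten(Y, alpha)
    U, _, Vt = np.linalg.svd(Xt.T @ Yt)
    Q = (U @ Vt).T
    return np.linalg.norm(Xt - Yt @ Q, ord="fro")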
Now, let’s quickly talk about the continual learning algorithms we will be using in the experiments: Learning without Forgetting and Memory-Aware Synapses.
LwF employs knowledge distillation (KD) (Hinton, Vinyals, and Dean 2015) as the main tool to prevent forgetting. The algorithm assumes a multi-head, task-incremental setting, in which a model can be viewed as having a set of shared parameters $\theta_s$ and a set of task-specific parameters $\theta_o$ for each previously learned task.
In each learning session, the goal of the algorithm is to add a new set of task-specific parameters $\theta_n$ and to learn them, along with updates to $\theta_s$, without hurting performance on the old tasks. This is encouraged by minimizing a loss of the form

$$\mathcal{L} = \mathcal{L}_{new}\left(Y_n, \hat{Y}_n\right) + \lambda_o \mathcal{L}_{old}\left(Y_o, \hat{Y}_o\right),$$

where $\mathcal{L}_{new}$ is the usual cross-entropy loss on the new task and $\mathcal{L}_{old}$ is a knowledge-distillation loss that keeps the current model's outputs on the old task heads, $\hat{Y}_o$, close to the outputs $Y_o$ recorded before training on the new task. It's not just a simple cross entropy loss because we are using the modified, temperature-scaled outputs

$$y_o'^{(i)} = \frac{\left(y_o^{(i)}\right)^{1/T}}{\sum_j \left(y_o^{(j)}\right)^{1/T}}, \qquad \hat{y}_o'^{(i)} = \frac{\left(\hat{y}_o^{(i)}\right)^{1/T}}{\sum_j \left(\hat{y}_o^{(j)}\right)^{1/T}},$$

as in knowledge distillation, so that smaller probabilities contribute more to the loss

$$\mathcal{L}_{old}\left(y_o, \hat{y}_o\right) = -\sum_i y_o'^{(i)} \log \hat{y}_o'^{(i)}.$$
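A minimal PyTorch-style sketch of that loss might look like the following; it is an illustration rather than the exact training code, and the hyperparameters (T, lambda_o) are placeholders.

import torch.nn.functional as F

def lwf_distillation_loss(old_head_logits, recorded_old_logits, T=2.0):
    # Soft cross-entropy between temperature-scaled outputs: keep the current model's
    # predictions on the old task heads close to those recorded before the new task.
    log_p_new = F.log_softmax(old_head_logits / T, dim=1)
    p_old = F.softmax(recorded_old_logits / T, dim=1)
    return -(p_old * log_p_new).sum(dim=1).mean()

def lwf_loss(new_head_logits, targets, old_head_logits, recorded_old_logits,
             lambda_o=1.0, T=2.0):
    # Cross-entropy on the new task plus the weighted distillation term on the old heads.
    ce_new = F.cross_entropy(new_head_logits, targets)
    kd_old = lwf_distillation_loss(old_head_logits, recorded_old_logits, T)
    return ce_new + lambda_o * kd_old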
MAS is another regularization-based CL technique focused on identifying and preserving network parameters relevant for previous tasks while learning a new task. It does this by calculating an importance weight for each parameter in the model with respect to the previous tasks. The importance of a layer's parameters, $\theta = \{\theta_{ij}\}$, is meant to capture how much a small perturbation $\delta = \{\delta_{ij}\}$ would change the function $F$ the network has learned so far, measured on the data points $x_k$:

$$F\left(x_k; \theta + \delta\right) - F\left(x_k; \theta\right).$$

But, that can be simplified with the first-order Taylor series approximation:

$$F\left(x_k; \theta + \delta\right) - F\left(x_k; \theta\right) \approx \sum_{i,j} g_{ij}\left(x_k\right) \delta_{ij}, \qquad g_{ij}\left(x_k\right) = \frac{\partial F\left(x_k; \theta\right)}{\partial \theta_{ij}}.$$

The importance weight is then the average gradient magnitude over the $N$ data points,

$$\Omega_{ij} = \frac{1}{N} \sum_{k=1}^{N} \left\lVert g_{ij}\left(x_k\right) \right\rVert,$$

and, while learning a new task, the loss is augmented with the penalty $\frac{\lambda}{2} \sum_{i,j} \Omega_{ij} \left(\theta_{ij} - \theta_{ij}^{*}\right)^2$, where $\theta^{*}$ are the parameter values at the end of the previous task. In practice, since the output of $F$ is multi-dimensional, MAS takes gradients of its squared $\ell_2$ norm.
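Here is a rough PyTorch sketch of those two pieces (the importance estimate and the penalty). It is illustrative only: the function and variable names are mine, gradients are accumulated per batch for efficiency (with a batch size of 1 this matches the per-sample definition exactly), and for a multi-head model model(x) would stand for the output of the relevant head.

import torch

def mas_importance(model, data_loader, device="cpu"):
    # Omega_ij: average magnitude of the gradient of the squared L2 norm of the output
    # with respect to each parameter, over the data seen so far.
    importance = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    n_batches = 0
    model.eval()
    for x, _ in data_loader:                      # labels are not needed
        x = x.to(device)
        model.zero_grad()
        out = model(x)
        out.pow(2).sum(dim=1).mean().backward()   # squared L2 norm of the outputs
        for n, p in model.named_parameters():
            if n in importance and p.grad is not None:
                importance[n] += p.grad.abs()
        n_batches += 1
    return {n: omega / n_batches for n, omega in importance.items()}

def mas_penalty(model, importance, old_params, lam=1.0):
    # lambda/2 * sum_ij Omega_ij * (theta_ij - theta*_ij)^2, added to the new-task loss.
    penalty = 0.0
    for n, p in model.named_parameters():
        if n in importance:
            penalty = penalty + (importance[n] * (p - old_params[n]).pow(2)).sum()
    return 0.5 * lam * penalty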
With the necessary background out of the way, we can (finally) move onto the experimental setup.
I trained three multi-head VGG networks: one with the LwF strategy, one with the MAS strategy, and one without any CL strategy (the “Naive” strategy). The trainset, consisting of 100k images from 200 classes, was split into 10 tasks (experiences), each made up of 10k images from 20 classes. Similarly, the testset, consisting of 10k images from the same 200 classes, was split into 10 tasks, resulting in 1k images from 20 classes per experience.
The models were trained using stochastic gradient descent (SGD) with 0.9 momentum and a constant learning rate, and a snapshot of each model was saved after every experience.
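For reference, the 10-task split described above could be produced with something like the following PyTorch snippet. It is a sketch, not the benchmark code I used: it assumes the dataset exposes a .targets attribute (as torchvision datasets do) and assigns classes to tasks in label order; continual-learning libraries typically provide this kind of split out of the box.

import torch
from torch.utils.data import Subset

def split_into_tasks(dataset, n_tasks=10, n_classes=200):
    # Split a classification dataset into n_tasks experiences, each covering an
    # equal, disjoint slice of the class labels (here: 10 tasks x 20 classes).
    classes_per_task = n_classes // n_tasks
    targets = torch.as_tensor(dataset.targets)
    tasks = []
    for t in range(n_tasks):
        lo, hi = t * classes_per_task, (t + 1) * classes_per_task
        idx = torch.nonzero((targets >= lo) & (targets < hi), as_tuple=True)[0]
        tasks.append(Subset(dataset, idx.tolist()))
    return tasks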
After training, each snapshot was evaluated on all experiences. As a result, many models were evaluated on tasks that they had never learned. To keep the experiment computationally reasonable, each evaluation was limited to 1k images from the task's 20 classes. During each evaluation, I hooked into the intermediate layers of the model to store the neural activations for a set of chosen layers.
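Capturing those activations can be done with PyTorch forward hooks, roughly as sketched below. The layer names are illustrative (they depend on how the VGG is defined), and flattening conv feature maps to one vector per sample is only one possible convention; spatially pooling them first is another common choice to keep the matrices small.

import torch

def register_activation_hooks(model, layer_names):
    # Capture intermediate activations by module name, flattened to (n_samples, n_features),
    # so they can be stacked into the X and Y matrices used by the distance metrics.
    store, handles = {}, []
    modules = dict(model.named_modules())
    for name in layer_names:
        def hook(module, inputs, output, name=name):
            store.setdefault(name, []).append(output.detach().flatten(1).cpu())
        handles.append(modules[name].register_forward_hook(hook))
    return store, handles

# Illustrative usage:
# store, handles = register_activation_hooks(model, ["features.3", "features.6"])
# with torch.no_grad():
#     for x, _ in eval_loader:
#         model(x)
# activations = {name: torch.cat(chunks) for name, chunks in store.items()}
# for h in handles:
#     h.remove()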
Then, with the layers held constant, I fit the neural activations from different strategies, across all experiences, such that I was solving the problem:

$$\min_{W_X, W_Y} \lVert X W_X - Y W_Y \rVert_F^2 \quad \text{s.t.} \quad W_X^\top \left(\alpha I + (1 - \alpha) X^\top X\right) W_X = W_Y^\top \left(\alpha I + (1 - \alpha) Y^\top Y\right) W_Y = I,$$

to obtain the transformation matrices $W_X$ and $W_Y$. Then, to get a distance score, I compute:

$$d_\alpha\left(X, Y\right) = \lVert X W_X - Y W_Y \rVert_F.$$
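Putting the pieces together, the values behind each heatmap could be assembled along the lines of the sketch below, which reuses the hypothetical regularized_distance helper from the earlier sketch; the dictionary structure for the stored activations is purely illustrative.

import numpy as np

def distance_grid(acts_a, acts_b, layers, n_experiences, alpha):
    # acts_a, acts_b: dicts mapping (layer, experience) -> activation matrix (n x p)
    # for two strategies; returns a layers x experiences grid of distances.
    grid = np.zeros((len(layers), n_experiences))
    for i, layer in enumerate(layers):
        for e in range(n_experiences):
            grid[i, e] = regularized_distance(acts_a[(layer, e)], acts_b[(layer, e)], alpha)
    return grid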
Take a look!
Given the not-so-great hyperparameter selection, the models did not do a great job of learning. So, I ran the experiments on 1k samples from both the train and test datasets, expecting that the activations would differ between the two, but they do not seem to.
The heatmaps are all generated with the final model for each approach, trained on all experiences. I picked layers 3, 6, 8, 11, and 13 from the VGG for the comparisons. And, as described earlier, you can see how the distances change based on the distance metric we use. As a reminder, $\alpha = 0$ recovers CCA, $\alpha = 1$ recovers the orthogonal Procrustes distance, and values in between give the regularized metrics.
From the heatmaps, we see (as one might expect) that the two CL algorithms lead to neural activations that are relatively similar to each other (leftmost heatmap).
Generally, the MAS algorithm leads to neural representations that are the most different from the naive implementation: its heatmap is darker across the grid. Under the assumption that the naive implementation is the worst outcome, this might imply that MAS is more effective at preventing forgetting; after all, it's the furthest away from the worst outcome.
However, plotting the accuracy of each final model across all tasks, LwF is clearly the superior model! It seems that MAS has led to internal representations that have diverged in a direction that harms performance.
{
  // Line chart of final accuracy per experience for each strategy.
  // `labels` (strategy display names) and `colors` (an ordinal d3 color scale)
  // are defined in other cells of this notebook.
  const data = {
    naive_acc: [0.0390, 0.0270, 0.0520, 0.0250, 0.0550, 0.0470, 0.0510, 0.0840, 0.0740, 0.4300],
    mas_acc: [0.3850, 0.3450, 0.3390, 0.3380, 0.4140, 0.3530, 0.3290, 0.3830, 0.4500, 0.4310],
    lwf_acc: [0.5050, 0.5030, 0.5400, 0.4690, 0.5200, 0.4790, 0.4030, 0.4200, 0.4860, 0.4310]
  };
  const nameMapper = {
    naive_acc: labels[0],
    mas_acc: labels[1],
    lwf_acc: labels[2]
  };
  const width = 900, height = 400;
  const margin = ({top: 20, right: 20, bottom: 40, left: 50});
  const svg = d3.create("svg")
    .attr("width", width)
    .attr("height", height);
  // x: experience index (0-9, displayed as 1-10); y: accuracy in [0, 1]
  const xScale = d3.scaleLinear()
    .domain([0, 9])
    .range([margin.left, width - margin.right]);
  const yScale = d3.scaleLinear()
    .domain([0, 1])
    .range([height - margin.bottom, margin.top]);
  const line = d3.line()
    .defined(d => !isNaN(d))
    .x((d, i) => xScale(i))
    .y(d => yScale(d));
  // one line per strategy
  for (const name in data) {
    svg.append("path")
      .datum(data[name])
      .attr("fill", "none")
      .attr("stroke", colors(nameMapper[name]))
      .attr("stroke-width", 1.5)
      .attr("d", line);
  }
  // axes and labels
  svg.append("g")
    .attr("transform", `translate(0,${height - margin.bottom})`)
    .call(d3.axisBottom(xScale).ticks(9).tickFormat(i => i + 1));
  svg.append("g")
    .attr("transform", `translate(${margin.left},0)`)
    .call(d3.axisLeft(yScale));
  svg.append("text")
    .attr("transform", `translate(${width / 2},${height - 5})`)
    .style("text-anchor", "middle")
    .style("font-size", "12px")
    .text("Experience");
  svg.append("text")
    .attr("transform", "rotate(-90)")
    .attr("y", margin.left - 35)
    .attr("x", -(height / 2))
    .style("text-anchor", "middle")
    .style("font-size", "12px")
    .text("Accuracy");
  svg.append("text")
    .attr("x", width / 2)
    .attr("y", margin.top)
    .attr("text-anchor", "middle")
    .text("Final Accuracy on All Tasks");
  return svg.node();
}
Another interesting phenomenon is how the heatmaps change across different selections of $\alpha$. For the earlier layers of the model, the regularized metrics are noticeably lighter than those produced by CCA, while for the later layers the regularized metrics are significantly darker! So, the regularized metrics draw a sharper distinction between representations across layers. By that reading, the change in internal representations across algorithms is concentrated in the later layers, suggesting that forgetting is largely a result of changes in the later layers.
Looking at layers 6, 8, and 11 with the regularized metrics, there seems to be a noticeable shift in internal representations between the CL algorithms and the naive approach at experience 5. This could signify that task 5 is significantly different from previously seen experiences. In that case, an analysis like this might help determine which tasks are “problematic” or difficult to learn. Or, it might signal that the model is reaching its capacity.
While looking at CAMs isn’t a rigorous analysis, it’s still fun!
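For anyone who wants to reproduce something similar, a Grad-CAM-style map for one image can be computed roughly as below. This is a generic sketch (the CAMs shown above may have been produced with a different method or implementation); layer is the convolutional module to visualize, and the code assumes a single-output head indexed by class_idx.

import torch
import torch.nn.functional as F

def grad_cam(model, layer, x, class_idx):
    # Weight the chosen conv layer's feature maps by the spatially averaged gradient
    # of the class score, then ReLU and normalize. x has shape (1, C, H, W).
    feats, grads = [], []
    h1 = layer.register_forward_hook(lambda m, i, o: feats.append(o))
    h2 = layer.register_full_backward_hook(lambda m, gi, go: grads.append(go[0]))
    score = model(x)[0, class_idx]
    model.zero_grad()
    score.backward()
    h1.remove()
    h2.remove()
    fmap = feats[0].detach().squeeze(0)            # (channels, h, w)
    grad = grads[0].detach().squeeze(0)
    weights = grad.mean(dim=(1, 2))                # one weight per channel
    cam = F.relu((weights[:, None, None] * fmap).sum(dim=0))
    return cam / (cam.max() + 1e-8)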
Scrolling through the experiences and different layers, the CAMs tell a similar story to the heatmaps. The earlier layers produce CAMs that are visually similar, while the later layers have more obvious differences.
Most CAMs produced from the naive approach have far fewer “bright spots”: the CAMs are darker and more sparse. On the other hand, the CAMs produced by MAS are much brighter and more activated than the LwF CAMs. This seems to be more pronounced in the later layers.
Building thousands of saliency maps and then running distance metrics on them is computationally expensive, so it was out of the scope of this project. But that seems like the natural next step!
Do comparisons with regularized metrics between the CAMs of different algorithms tell the same story as the heatmaps we saw?
Is there something we can learn from this to design a better CL algorithm that improves saliency AND prevents forgetting?
If we did this for several CL algorithms, could we somehow cluster them by how similar their internal representations are? Could we do this by layer? What about for different architectures?
If you see mistakes or want to suggest changes, please send me a message at nishantaswani@nyu.edu. Suggestions are appreciated!
Generated text and figures are licensed under Creative Commons Attribution CC BY 4.0. The figures that have been reused from other sources don’t fall under this license and can be recognized by a note in their caption: ‘Figure from …’
@online{aswani2023,
author = {Aswani, Nishant},
title = {Exploring the {Anatomy} of {Catastrophic} {Forgetting}},
date = {2023-05-12},
url = {https://nishantaswani.com/articles/anatomy.html},
langid = {en}
}