Memory Aware Synapses - Rahaf Aljundi


This paper proposes an approach for solving task-based sequential learning (TBSL) problems in neural networks. In TBSL, datasets arrive sequentially as tasks, and the model is trained on them one by one as they arrive. Once a task has been trained on, its data is discarded and cannot be used to re-train the model. If the new data fed to the model is very different from the data it was trained on before, the model adapts entirely to the new data and forgets what it learned earlier. This issue is termed catastrophic forgetting.

All the experiments that test the method discussed in the paper involve classification tasks.

What is the crux of the paper?
This paper is about doing continual learning within a single, fixed neural network.
The authors observe that, for the model to remember older tasks, the parameters important for those tasks should not be overwritten while training on future tasks. "Important parameters" usually refers to the weights of the neural network that contribute the most to its final output.
In this paper, however, the authors take important parameters to be those which, when perturbed slightly, produce a large change in the output of the neural network.
Parameters considered important for past tasks should be updated only minimally (in case they need to be reused) or not at all while training on future tasks.

How is this achieved?
This is achieved by assigning an importance weight to every parameter $\theta_{ij}$ (the parameter connecting the $j^{th}$ neuron to the $i^{th}$ neuron). This weight signifies how important the parameter is: the higher the importance weight, the more important the parameter.
When a new task arrives, its training loss has two components. The first is the usual training loss incurred on the new task. The second is a regularization term that penalizes updates to parameters with high importance weights.
Once training on the task reaches convergence, the importance weights are updated to accumulate the information of that task. A sketch of the overall loop is given below.
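Before delving into the details, here is a minimal PyTorch-style sketch of this loop. The helper names `train_on_task` and `compute_importance` are placeholders (concrete sketches follow in the sections below), and summing per-task importances is one simple accumulation choice, not necessarily the paper's exact scheme.

```python
import torch

def mas_continual_learning(model, tasks, train_on_task, compute_importance,
                           lam=1.0):
    # Importance weights start at zero: before any task is seen,
    # no parameter is considered important.
    omega = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    # Snapshot of the parameters after the previous task (theta^{t-1}).
    theta_prev = {n: p.detach().clone() for n, p in model.named_parameters()}

    for task in tasks:
        # Train with the two-component loss: the usual task loss plus
        # the importance-weighted penalty.
        train_on_task(model, task, omega, theta_prev, lam)

        # At convergence, update the importance weights on this task's
        # inputs (no labels needed) and re-snapshot the parameters.
        task_omega = compute_importance(model, task)
        omega = {n: omega[n] + task_omega[n] for n in omega}  # accumulate
        theta_prev = {n: p.detach().clone()
                      for n, p in model.named_parameters()}

    return model, omega
```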

Delving into the technicalities

Calculating the importance weights: 

The importance weight of a parameter is high if a change in the value of that parameter produces a large change in the output of the neural network (a large change in $f$, the input-output mapping function we are trying to model, since this is a classification task). What mathematical quantity describes the change in a function when its parameters change? Yes, that's right, the GRADIENT!
This can be expressed mathematically as $g_{ij}(x_{n})=\nabla_{\theta_{ij}}f(x_{n};\theta)$. However, as $f$ is a multi-dimensional vector, every data point $x_{n}$ in the task would require $k$ backpropagation passes, where $k$ is the number of categories in the classification task. Therefore, instead of differentiating the output vector directly, the gradient of the squared $L_{2}$ norm of the output is used. The differentiated quantity is now a scalar, so a single backward pass suffices and there is no backprop overhead: $g_{ij}(x_{n})=\nabla_{\theta_{ij}}\parallel f(x_{n};\theta)\parallel_{2}^{2}$.
Finally, the importance weight for each parameter $\theta_{ij}$ is calculated by averaging the gradient magnitudes over all $N$ representative data points of the current task:
$\Omega_{ij}=\frac{1}{N}\sum_{n=1}^{N}\parallel g_{ij}(x_{n})\parallel$
Note that the labels were not used at all when calculating the importance weights. The beauty of this method is therefore that it extends to computing importance weights even when labelled data is not available. A minimal sketch of the computation follows.
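As a concrete illustration, here is a minimal PyTorch sketch of this computation. The function name `compute_importance` is hypothetical, and averaging over batches rather than individual data points is a simplification; treat this as a sketch of the idea, not the paper's reference code.

```python
import torch

def compute_importance(model, data_loader):
    # Omega_ij = (1/N) * sum_n | d ||f(x_n; theta)||_2^2 / d theta_ij |
    omega = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    n_batches = 0
    for x in data_loader:              # only inputs are needed, no labels
        model.zero_grad()
        out = model(x)                 # f(x; theta)
        # The squared L2 norm of the output is a scalar, so a single
        # backward pass suffices regardless of the number of categories.
        out.pow(2).sum().backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                omega[n] += p.grad.detach().abs()
        n_batches += 1
    # Average the accumulated gradient magnitudes over the data.
    return {n: w / max(n_batches, 1) for n, w in omega.items()}
```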
Question: Is $\Omega$ cumulatively summed as the tasks go by? If so, won't $\Omega$ become very large as the number of tasks grows?

Designing the importance weight penalty for the loss function: 
This section describes how the penalty term is designed so that the important parameters are not updated much during training.
Given a task $t$, if $\Omega_{ij}$ is high, the corresponding $\theta_{ij}^{t}$ should not be updated much; instead, it should stay as close to $\theta_{ij}^{t-1}$ as possible.
If $\Omega_{ij}$ is low, there should be room for $\theta_{ij}^{t}$ to deviate from $\theta_{ij}^{t-1}$ (a smaller penalty for updating $\theta_{ij}^{t}$ than in the case above).
This objective can be attained by minimizing the product $\Omega_{ij}(\theta_{ij}^{t}-\theta_{ij}^{t-1})^{2}$.
The difference between $\theta_{ij}^{t}$ and $\theta_{ij}^{t-1}$ is squared so that the penalty is never negative and its gradient is simple: $2\Omega_{ij}(\theta_{ij}^{t}-\theta_{ij}^{t-1})$, which pulls each parameter back toward its old value with a strength proportional to its importance.
So the final loss function can be written as: $L(\theta)=\sum_{n=1}^{N}l_{t}(f(x_{n};\theta),y_{n})+\lambda\sum_{i,j}\Omega_{ij}(\theta_{ij}^{t}-\theta_{ij}^{t-1})^{2}$
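Rendered as code, the regularized objective might look like the following PyTorch sketch; `mas_loss`, `lam`, and the surrounding training loop are illustrative names and choices, not the paper's official implementation.

```python
import torch
import torch.nn.functional as F

def mas_loss(model, task_loss, omega, theta_prev, lam):
    # L(theta) = task loss
    #          + lambda * sum_ij Omega_ij * (theta_ij - theta_ij^{t-1})^2
    penalty = sum((omega[n] * (p - theta_prev[n]).pow(2)).sum()
                  for n, p in model.named_parameters())
    return task_loss + lam * penalty

# Example usage while training on a new task t; model, task_loader,
# omega, and theta_prev are assumed defined as in the sketches above,
# with theta_prev frozen at the end of task t-1.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for x, y in task_loader:
    optimizer.zero_grad()
    loss = mas_loss(model, F.cross_entropy(model(x), y),
                    omega, theta_prev, lam=1.0)
    loss.backward()
    optimizer.step()
```

Note that only $\theta^{t-1}$ and $\Omega$ need to be stored between tasks; none of the old tasks' data is required.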

Question: Should you care about the gradient property?

Link to Original Paper: https://arxiv.org/abs/1711.09601
