Overcoming catastrophic forgetting in NNs via Elastic Weight Consolidation - James Kirkpatrick

This paper introduces a new approach to learning tasks sequentially without forgetting previously learned tasks. In this regime,

  •     Data/Tasks are presented to the model sequentially.
  •     The model needs to adapt to each new task while retaining the knowledge acquired on previous tasks.
  •     The model should not store and replay data from prior tasks, since the storage and replay overhead becomes prohibitive as the number of tasks grows.
What is the crux of this paper?

The authors tackle catastrophic forgetting in machine learning models by drawing an analogy to how the brain avoids it: knowledge is consolidated by reducing the plasticity of synapses that are vital to previously learned tasks. Inspired by these neurobiological models of synaptic consolidation, the authors devise EWC, or elastic weight consolidation. This algorithm slows down learning on certain weights in proportion to how important they were for the previous tasks.

How is it achieved?

Suppose a task A is learned by a deep neural network, resulting in the parameters $\theta^*_{A}$. When learning the next task B, EWC protects task A by finding parameters $\theta^*_{B}$ that stay close to $\theta^*_{A}$. This constraint is implemented as a quadratic penalty - like a spring anchoring the new parameters to the old values.
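Concretely, the loss minimized while training on task B augments the task-B loss $\mathcal{L}_B$ with a per-parameter quadratic term, weighted by the Fisher information $F_i$ of each parameter and a hyperparameter $\lambda$ that sets how important the old task is relative to the new one:

$$\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^{*}_{A,i}\right)^2$$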

The EWC algorithm derives this penalty from Bayesian principles. When a new task is to be learned, the network parameters are tempered by a prior, namely the posterior distribution over the parameters given the data from the previous tasks.

$$\mathrm{log}\: p(\theta \mid \mathcal{D}) = \mathrm{log} \: p(\mathcal{D}_B \mid \theta) + \mathrm{log}\:  p(\theta \mid \mathcal{D}_A) - \mathrm{log}\:  p(\mathcal{D}_B)$$

As the true posterior is intractable, a Laplace approximation is used to approximate it as a Gaussian with mean $\theta^*_A$ and a diagonal precision given by the diagonal of the Fisher information matrix. This allows faster learning on parameters that contribute little to the previous tasks and slower learning on parameters that are crucial to them.
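As a rough illustration, here is a minimal PyTorch-style sketch of the two pieces EWC needs: an estimate of the diagonal Fisher information after task A, and the quadratic penalty added to the loss while training on task B. The names (`diagonal_fisher`, `ewc_penalty`, `lam`, `loader_A`) are placeholders of my own, not from the paper's code, and the Fisher estimator is deliberately simplified.

```python
import torch
import torch.nn.functional as F

def diagonal_fisher(model, loader, device="cpu"):
    """Estimate the diagonal Fisher information of `model` at its current parameters,
    using per-example squared gradients of the log-likelihood of labels sampled from
    the model's own predictive distribution."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    model.eval()
    count = 0
    for x, _ in loader:
        for xi in x.to(device):
            model.zero_grad()
            log_probs = torch.log_softmax(model(xi.unsqueeze(0)), dim=1)
            y = torch.multinomial(log_probs.exp(), 1).squeeze(1)  # sample a label from the model
            F.nll_loss(log_probs, y).backward()
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2
            count += 1
    return {n: f / count for n, f in fisher.items()}

def ewc_penalty(model, fisher, theta_star, lam):
    """Quadratic penalty anchoring the current parameters to theta_star,
    with per-parameter stiffness given by the Fisher diagonal."""
    penalty = 0.0
    for n, p in model.named_parameters():
        if n in fisher:
            penalty = penalty + (fisher[n] * (p - theta_star[n]) ** 2).sum()
    return 0.5 * lam * penalty

# Typical use (names are placeholders):
#   after training on task A:
#     theta_star = {n: p.detach().clone() for n, p in model.named_parameters()}
#     fisher = diagonal_fisher(model, loader_A)
#   while training on task B, minimize:
#     loss = task_B_loss(model, batch) + ewc_penalty(model, fisher, theta_star, lam)
```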

EWC tries to squeeze more functionality into a fixed network, which raises the question of how that functionality is allocated. Does the network dedicate entirely separate parts of itself to the different tasks, or are weights shared across tasks (the more efficient arrangement)? The overlap between the Fisher information matrices of the tasks answers this: a high overlap means the same weights are being reused across tasks.
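As a rough, hypothetical way to probe this, one could compare the diagonal Fisher estimates of two tasks directly; the cosine similarity below is only an illustrative stand-in for the overlap metric defined in the paper.

```python
import torch

def fisher_overlap(fisher_a, fisher_b, eps=1e-12):
    """Illustrative proxy for Fisher overlap: cosine similarity between the flattened
    diagonal Fisher estimates of two tasks. Values near 1 suggest the tasks lean on
    the same weights; values near 0 suggest they occupy separate parts of the network."""
    a = torch.cat([f.flatten() for f in fisher_a.values()])
    b = torch.cat([f.flatten() for f in fisher_b.values()])
    return (torch.dot(a, b) / (a.norm() * b.norm() + eps)).item()
```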


