Fisher Information Matrix 🐠
Mar 2019
I learned about the Fisher Information Matrix before/during the goddamn military service.
Email me if you find this article helpful, or if it sucks (probably).
Interests in Fisher Information Matrix
The Fisher Information Matrix is closely related to the Hessian matrix. The Hessian is a square matrix of second-order partial derivatives. As we learned in high school, second-order information tells us about the local curvature, one step beyond what the gradient alone gives us. This property allows efficient optimization.
Let's say we have a model parametrized by $\theta$ and we would like to optimize the likelihood $p(x|\theta)$ w.r.t. $\theta$. Usually, we aim at maximizing the log-likelihood instead. For the sake of convenience in the following sections, we define a score function as follows:
$$s(\theta) = \nabla_\theta \log p(x|\theta)$$
The score function is the gradient of the log-likelihood.
The interpretation of the score function is that it measures the changes of the log-likelihood w.r.t. each parameter in $\theta$. Namely, suppose $\theta = [\theta_0, \theta_1]$; the score function describes how the log-likelihood changes when $\theta_0$ or $\theta_1$ changes.
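Now, one natural question arises: what if we are interested in the relation between $\theta_0$ and $\theta_1$? To this end, we need the covariance of the score function:
$$F = \mathbb{E}_{p(x|\theta)}\left[\left(s(\theta) - \mathbb{E}_{p(x|\theta)}[s(\theta)]\right)\left(s(\theta) - \mathbb{E}_{p(x|\theta)}[s(\theta)]\right)^T\right]$$
Before diving deeper, let's first calculate this problematic term: the mean of the score function $\mathbb{E}_{p(x|\theta)}[s(\theta)]$. Don't worry. It will be surprisingly easy (assuming we can swap the gradient and the integral):
$$\begin{aligned}
\mathbb{E}_{p(x|\theta)}[s(\theta)] &= \int \nabla_\theta \log p(x|\theta)\, p(x|\theta)\, dx \\
&= \int \frac{\nabla_\theta p(x|\theta)}{p(x|\theta)}\, p(x|\theta)\, dx \\
&= \nabla_\theta \int p(x|\theta)\, dx \\
&= \nabla_\theta 1 = 0
\end{aligned}$$
Now, $F$ becomes much easier:
$$F = \mathbb{E}_{p(x|\theta)}\left[\nabla_\theta \log p(x|\theta)\, \nabla_\theta \log p(x|\theta)^T\right]$$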
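If the exact form of the log-likelihood w.r.t. $\theta$ is known, we might be able to calculate $F$ directly. However, most of the time the likelihood is intractable. A workaround is to calculate the empirical Fisher if we have samples $x_1, \dots, x_N$ drawn from $p(x|\theta)$:
$$F \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \log p(x_i|\theta)\, \nabla_\theta \log p(x_i|\theta)^T$$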
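To make the empirical estimate concrete, here is a minimal sketch on a toy model I picked for illustration (not something from the text above): a 1-D Gaussian with unknown mean and known standard deviation. The score w.r.t. the mean has a closed form, so we can check that its average is close to zero and that the empirical Fisher is close to the analytic value $1/\sigma^2$. The function names are made up for this example.

```python
import numpy as np

def gaussian_score(x, mu, sigma):
    """Score of N(mu, sigma^2) w.r.t. mu: d/dmu log p(x | mu) = (x - mu) / sigma^2."""
    return (x - mu) / sigma**2

def empirical_fisher(scores):
    """Average outer product of per-sample scores; scores has shape (N, D)."""
    scores = np.atleast_2d(scores)
    return scores.T @ scores / scores.shape[0]

rng = np.random.default_rng(0)
mu, sigma = 1.5, 2.0                      # toy parameters (assumed for illustration)
x = rng.normal(mu, sigma, size=200_000)   # samples drawn from the model itself

scores = gaussian_score(x, mu, sigma).reshape(-1, 1)  # shape (N, 1): one parameter
mean_score = scores.mean()                            # should be close to 0
fisher_hat = empirical_fisher(scores)[0, 0]           # should be close to 1 / sigma^2

print(f"mean score       : {mean_score:.4f}")
print(f"empirical Fisher : {fisher_hat:.4f}")
print(f"analytic Fisher  : {1.0 / sigma**2:.4f}")
```

For a neural network you would replace `gaussian_score` with per-sample gradients of the log-likelihood obtained from automatic differentiation; the averaging step stays the same.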
Note that this formulation is very common. In generative models, we optimize the likelihood of the data $p(x|\theta)$ (the data can be continuous or discrete) w.r.t. $\theta$. In supervised learning, it's more straightforward: we optimize $p(y|x, \theta)$ w.r.t. $\theta$. Similarly, we can consider a reinforcement learning problem as an inference problem in which we again optimize w.r.t. $\theta$ such that the total return/cost is maximized/minimized (it might not be intuitive, but it's possible with some smart formulation).
The takeaway so far
The Fisher Information Matrix describes the covariance of the gradient of the log-likelihood function. Note that we call it "information" because the Fisher information measures how much the data tell us about the parameters.
🔨 Case study: Elastic weight consolidation
Figure 1. Illustration of the learning process of task B after that of task A.
tl;dr: EWC is an algorithm to avoid catastrophic forgetting in neural networks. It slows down learning on certain weights based on how important they are to previously seen tasks.
Let's say we have two tasks A and B. In continual learning, we first learn task A and then task B. When learning the second task B, neural networks tend to forget what they learned on the previous task A. Writing $\mathcal{L}_B(\theta)$ for the loss on task B, the learning process can be written as (corresponds to "no penalty" in figure 1):
$$\theta^* = \arg\min_\theta \mathcal{L}_B(\theta)$$
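To avoid forgetting the learned knowledge in task A, one simple trick is to minimize the distance between $\theta$ and $\theta^*_A$, the parameters learned on task A. Thus, the learning process becomes (corresponds to "l2" in figure 1):
$$\theta^* = \arg\min_\theta \mathcal{L}_B(\theta) + \frac{\lambda}{2} \left\| \theta - \theta^*_A \right\|_2^2$$
where $\lambda$ is the scalar that sets how important the old task is compared to the new one.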
It turns out that the l2 constraint is so strong that it can hamper the learning process of task B. Here, we have one more observation: in neural networks, we often over-parametrize the models, so some parameters are less useful and others are more valuable. In the l2 constraint case, each parameter is treated equally. Instead, we want to use the diagonal components of the Fisher Information Matrix to identify which parameters are more important to task A and apply higher weights to them (corresponds to "EWC" in figure 1):
$$\theta^* = \arg\min_\theta \mathcal{L}_B(\theta) + \frac{\lambda}{2} \sum_i F_i \left( \theta_i - \theta^*_{A,i} \right)^2$$
where $F_i$ is the $i$-th diagonal entry of the Fisher Information Matrix and $i$ labels each parameter.
$F_i$ tells us how important parameter $i$ is to the previous task A. To compute $F$, we sample the data from task A once and calculate the empirical Fisher Information Matrix as described before. If you also find it interesting, check the PyTorch implementation @moskomule/ewc.pytorch. (tbh, I didn't run this code.)
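Here is a minimal NumPy sketch of the penalty term only, to show how the diagonal Fisher reweights the plain l2 penalty. It is not taken from @moskomule/ewc.pytorch; the function names, the toy gradients, and the use of $\lambda/2$ as the prefactor are assumptions made for illustration.

```python
import numpy as np

def diagonal_fisher(per_sample_grads):
    """Diagonal of the empirical Fisher: mean of squared per-sample gradients.

    per_sample_grads has shape (N, D): one log-likelihood gradient per sample
    from task A, evaluated at the parameters learned on task A.
    """
    return np.mean(np.square(per_sample_grads), axis=0)

def l2_penalty(theta, theta_star_a, lam):
    """(lam / 2) * sum_i (theta_i - theta*_A,i)^2: every parameter is treated equally."""
    return 0.5 * lam * np.sum((theta - theta_star_a) ** 2)

def ewc_penalty(theta, theta_star_a, fisher_diag, lam):
    """(lam / 2) * sum_i F_i * (theta_i - theta*_A,i)^2: important parameters cost more."""
    return 0.5 * lam * np.sum(fisher_diag * (theta - theta_star_a) ** 2)

# Toy gradients (hypothetical): parameter 0 has large gradients on task A,
# parameter 1 barely matters.
grads_a = np.array([[2.0, 0.1], [-2.0, -0.1], [2.0, 0.1], [-2.0, -0.1]])
fisher_diag = diagonal_fisher(grads_a)

theta_star_a = np.array([1.0, 1.0])
move_important = np.array([2.0, 1.0])    # shifts only parameter 0
move_unimportant = np.array([1.0, 2.0])  # shifts only parameter 1
lam = 1.0

penalty_important = ewc_penalty(move_important, theta_star_a, fisher_diag, lam)
penalty_unimportant = ewc_penalty(move_unimportant, theta_star_a, fisher_diag, lam)
print(penalty_important, penalty_unimportant)
print(l2_penalty(move_important, theta_star_a, lam),
      l2_penalty(move_unimportant, theta_star_a, lam))
```

The l2 penalty charges both moves the same, while the EWC penalty makes moving the important parameter far more expensive, which is the whole point. In training, this penalty would be added to the task-B loss before taking a gradient step.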
The relation between Fisher Information Matrix and KL-divergence
This part is sort of mathy. Hang in there! 🧟
KL-divergence is widely used to measure the difference between two distributions. Here, we will prove that Fisher Information Matrix defines the local curvature in distribution space for which KL-divergence is the metric.
Relation between Fisher and Hessian
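(Skip to the last line of this subsection if you are not interested in it.) We begin by briefly proving the relation between Fisher and Hessian. The Hessian matrix is a square matrix of second-order partial derivatives of a scalar-valued function, which describes the local curvature. The Hessian matrix of the log-likelihood can be written as:
$$\begin{aligned}
H_{\log p(x|\theta)} &= \frac{\partial}{\partial \theta} \left( \frac{\nabla_\theta p(x|\theta)}{p(x|\theta)} \right) \\
&= \frac{H_{p(x|\theta)}\, p(x|\theta) - \nabla_\theta p(x|\theta)\, \nabla_\theta p(x|\theta)^T}{p(x|\theta)^2} \\
&= \frac{H_{p(x|\theta)}}{p(x|\theta)} - \left( \frac{\nabla_\theta p(x|\theta)}{p(x|\theta)} \right) \left( \frac{\nabla_\theta p(x|\theta)}{p(x|\theta)} \right)^T
\end{aligned}$$
Now, let's take the expectation w.r.t. the current model $p(x|\theta)$:
$$\begin{aligned}
\mathbb{E}_{p(x|\theta)}\left[ H_{\log p(x|\theta)} \right] &= \mathbb{E}_{p(x|\theta)}\left[ \frac{H_{p(x|\theta)}}{p(x|\theta)} \right] - \mathbb{E}_{p(x|\theta)}\left[ \nabla_\theta \log p(x|\theta)\, \nabla_\theta \log p(x|\theta)^T \right] \\
&= \int \frac{H_{p(x|\theta)}}{p(x|\theta)}\, p(x|\theta)\, dx - F \\
&= H_{\int p(x|\theta)\, dx} - F \\
&= H_1 - F = -F
\end{aligned}$$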
Thus, we have: The negative of Fisher is the expectation of Hessian of log-likelihood.
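A quick sanity check of this identity, on a toy model I chose for illustration: a Bernoulli distribution with parameter $\theta$, where $\log p(x|\theta) = x \log\theta + (1-x)\log(1-\theta)$. Since $x$ only takes two values, both expectations can be computed exactly by enumeration, with no sampling. The function name is made up for this sketch.

```python
import numpy as np

def bernoulli_fisher_and_hessian(theta):
    """Return (E[score^2], E[Hessian of log-likelihood]) for Bernoulli(theta)."""
    xs = np.array([0.0, 1.0])                # the two possible outcomes
    probs = np.array([1.0 - theta, theta])   # p(x | theta) for x = 0 and x = 1

    # First and second derivatives of log p(x | theta) w.r.t. theta.
    score = xs / theta - (1.0 - xs) / (1.0 - theta)
    hessian = -xs / theta**2 - (1.0 - xs) / (1.0 - theta) ** 2

    fisher = np.sum(probs * score**2)            # E[s^2], using E[s] = 0
    expected_hessian = np.sum(probs * hessian)   # E[H]
    return fisher, expected_hessian

fisher, expected_hessian = bernoulli_fisher_and_hessian(0.3)
print(fisher, expected_hessian)   # the two values should be negatives of each other
```

Both quantities match the known closed form $1/(\theta(1-\theta))$ up to sign.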
Relation between Fisher and KL-divergence
With the conclusion above, we can move on to this interesting property: Fisher Information Matrix defines the local curvature in distribution space for which KL-divergence is the metric. Note that there are two components here: (1) local curvature (Hessian). (2) for which KL-divergence is the metric (KL between two distributions).
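We start by checking the KL between two distributions $p(x|\theta)$ and $p(x|\theta')$ (assume that $\theta'$ is very close to $\theta$, since we are interested in local curvature):
$$KL\left[ p(x|\theta)\, \|\, p(x|\theta') \right] = \mathbb{E}_{p(x|\theta)}\left[ \log p(x|\theta) \right] - \mathbb{E}_{p(x|\theta)}\left[ \log p(x|\theta') \right]$$
and take the first derivative w.r.t. $\theta'$:
$$\begin{aligned}
\nabla_{\theta'} KL\left[ p(x|\theta)\, \|\, p(x|\theta') \right] &= \nabla_{\theta'} \mathbb{E}_{p(x|\theta)}\left[ \log p(x|\theta) \right] - \nabla_{\theta'} \mathbb{E}_{p(x|\theta)}\left[ \log p(x|\theta') \right] \\
&= -\mathbb{E}_{p(x|\theta)}\left[ \nabla_{\theta'} \log p(x|\theta') \right] \\
&= -\int p(x|\theta)\, \nabla_{\theta'} \log p(x|\theta')\, dx
\end{aligned}$$
The first line comes from the fact that KL can be decomposed into cross entropy minus entropy; the entropy term does not depend on $\theta'$, so its derivative vanishes. Then, we take the second derivative w.r.t. $\theta'$:
$$\nabla^2_{\theta'} KL\left[ p(x|\theta)\, \|\, p(x|\theta') \right] = -\int p(x|\theta)\, \nabla^2_{\theta'} \log p(x|\theta')\, dx$$
Thus, we have the Hessian w.r.t. $\theta'$ evaluated at $\theta' = \theta$:
$$\begin{aligned}
H_{KL\left[ p(x|\theta)\, \|\, p(x|\theta') \right]} &= -\int p(x|\theta)\, \nabla^2_{\theta'} \log p(x|\theta') \Big|_{\theta' = \theta}\, dx \\
&= -\mathbb{E}_{p(x|\theta)}\left[ H_{\log p(x|\theta)} \right] \\
&= F
\end{aligned}$$
(The expectation is over samples from $p(x|\theta)$. In practice, we usually only have samples from the current parameters $\theta$, not from future ones $\theta'$.)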
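This can be checked numerically on a toy case. The sketch below assumes a Bernoulli model (my choice, not from the derivation above), for which the KL has a closed form and the Fisher is $1/(\theta(1-\theta))$. It estimates the second derivative of the KL w.r.t. $\theta'$ at $\theta' = \theta$ with a central finite difference, and also compares the exact KL for a small step $\delta$ against the quadratic approximation $\frac{1}{2} F \delta^2$ that follows from a second-order Taylor expansion. The function names and step sizes are made up for illustration.

```python
import numpy as np

def kl_bernoulli(theta, theta_prime):
    """KL[ Bernoulli(theta) || Bernoulli(theta_prime) ] in closed form."""
    return (theta * np.log(theta / theta_prime)
            + (1.0 - theta) * np.log((1.0 - theta) / (1.0 - theta_prime)))

def kl_curvature(theta, eps=1e-4):
    """Central finite difference for d^2 KL / d theta'^2 at theta' = theta."""
    return (kl_bernoulli(theta, theta + eps)
            - 2.0 * kl_bernoulli(theta, theta)
            + kl_bernoulli(theta, theta - eps)) / eps**2

theta = 0.3
fisher = 1.0 / (theta * (1.0 - theta))   # analytic Fisher of Bernoulli(theta)
curvature = kl_curvature(theta)          # numerical Hessian of the KL

delta = 0.01                             # a small step in parameter space
kl_exact = kl_bernoulli(theta, theta + delta)
kl_quadratic = 0.5 * fisher * delta**2   # second-order approximation

print(f"Fisher        : {fisher:.4f}")
print(f"KL curvature  : {curvature:.4f}")
print(f"KL exact      : {kl_exact:.6e}")
print(f"KL quadratic  : {kl_quadratic:.6e}")
```

The curvature matches the Fisher closely, and the quadratic approximation gets better as `delta` shrinks, which is what "local" curvature means here.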
Really interesting, right?