Reverse Mode Differentiation with Vector Calculus

This note dates back to 2016, when I was struggling with my first serious neural model. Before that I had mostly used Scala for NLP. But there was no auto differentiation available natively in Scala, so we ended up deciding to write our own package. The task, naturally, turned out to be waaaaay too non-trivial, and we gave it up a few days later. Well, anyway, here is the wrap-up.

Heads-up

The purpose of this note is to give a brief on how to write backpropagation. Manually :)

Why is this useful? Back in the days when autograd libraries (e.g. TensorFlow/PyTorch) were not yet full-fledged, I think this is basically how people trained their networks with SGD.

If you want something that makes sense mathematically, check out the “Matrix Cookbook”. This note is just a simplified/optimized brief that works in practice and is so much easier to write on paper.

Finally, for backpropagation, all we need is to get the gradients using the chain rule.

The Forward Computation of “Chain Rule”

Say we have an n-layer feedforward network. Around the k-th layer, it’s like this:

where $\boldsymbol{\theta}$’s are the parameters and $\boldsymbol{x}$’s are naturally the input and output. And at the very end of the pipeline, we have a loss function:

where $g_k$ is a layer/function. So the gradients we are after are the partial derivatives with respect to $\boldsymbol{\theta}$ and $\boldsymbol{x}$. For instance,

And the gradient on parameters:

Let’s pretend equations (3) and (4) are all we have, without infinitely expanding.
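
Written out, the setup looks roughly like this. It is a sketch under assumed notation: the output of layer $k$ is called $\boldsymbol{x}_{k+1}$, and $g_k$ stands for everything downstream of layer $k$ up to the loss $J$.

$$
\begin{aligned}
\boldsymbol{x}_{k+1} &= f_k(\boldsymbol{\theta}_k, \boldsymbol{x}_k), \qquad J = g_k(\boldsymbol{x}_{k+1}) \\
\frac{\partial J}{\partial \boldsymbol{x}_k} &= \frac{\partial J}{\partial \boldsymbol{x}_{k+1}} \frac{\partial \boldsymbol{x}_{k+1}}{\partial \boldsymbol{x}_k} \\
\frac{\partial J}{\partial \boldsymbol{\theta}_k} &= \frac{\partial J}{\partial \boldsymbol{x}_{k+1}} \frac{\partial \boldsymbol{x}_{k+1}}{\partial \boldsymbol{\theta}_k}
\end{aligned}
$$

The factor ${\partial J} / {\partial \boldsymbol{x}_{k+1}}$ comes from layer $k+1$ in the same way, which is the expansion we stop short of here.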

The Reverse-Mode Differentiation “Chain Rule” in Vector Calculus

The chain rule above is written in scalar notation only. In the vector world, it makes no bloody sense to write it that way. The reason is that multiplication in vector space is not always commutative, so the order of the factors matters. That is, scalar multiplication can be generalized in two possible ways:

  • matrix multiplication (non-commutative)
  • element-wise multiplication $\odot$ (commutative)
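
A quick way to see the difference is to multiply two small matrices both ways. The sketch below uses plain Python lists and two hypothetical helpers, `matmul` and `hadamard`, written just for this illustration.

```python
def matmul(a, b):
    """Matrix product of two list-of-lists matrices."""
    return [[sum(a[r][k] * b[k][c] for k in range(len(b)))
             for c in range(len(b[0]))]
            for r in range(len(a))]


def hadamard(a, b):
    """Element-wise product of two same-shaped matrices."""
    return [[x * y for x, y in zip(row_a, row_b)]
            for row_a, row_b in zip(a, b)]


A = [[1, 2], [3, 4]]
B = [[0, 1], [1, 0]]

print(matmul(A, B))                      # [[2, 1], [4, 3]]
print(matmul(B, A))                      # [[3, 4], [1, 2]]
print(hadamard(A, B) == hadamard(B, A))  # True
```

Swapping the operands changes the matrix product but not the element-wise one, which is why the order has to be pinned down when the chain rule is written with matrices.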

Therefore the scalar chain rule needs to be generalized accordingly as well. The way to do it is to treat a partial derivative as both a function and a value. So the actual gradient ${\partial J} / {\partial \boldsymbol{\theta}_k}$ becomes:

where we will keep using square brackets ${\partial}/{\partial}[\cdot]$ to denote the “derivative function”, and plain ${\partial}/{\partial}$ to denote its value.

PS, what does “a function and a value” mean? It’s all about the order of operations. If we write something like $A[B[C]]$, then the order is enforced as: calculate $C$, then $B$, then $A$. And when considering $B$ as a value, it’s actually the tensor calculated from $B[C]$.

Well, what is happening inside of these “functions”?

What we’re gonna do here is fill in those gradients in equation (5).

To do that, let’s focus on the k-th layer. Assume the $\boldsymbol{x}$’s are row vectors, and the k-th layer is defined as $f_k(\boldsymbol{\theta}_k, \boldsymbol{x}_k) = \boldsymbol{x}_k \boldsymbol{W}_k \boldsymbol{V}_k$ (i.e. two linear transformations; yes, they can be merged into one, but let’s go with it), where $\boldsymbol{\theta}_k = \{\boldsymbol{W}_k, \boldsymbol{V}_k\}$.

Remember that we haven’t yet defined the function $g$’s, so let’s assume the following gradients have already been calculated:

According to the definition of $f_k$, we can have the following immediately:

where $\boldsymbol{z}$ is an arbitrary tensor (e.g. the right-hand side of equations (6-7)). Now, plugging equations (6-10) back into equation (5) will yield the gradient ${\partial J} / {\partial \boldsymbol{\theta}_k}$.
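
As a concrete sketch of the toy layer, here is the forward pass and the three “derivative functions” in plain Python, followed by a finite-difference check on one entry of $\boldsymbol{W}_k$. The helper names (`matmul`, `transpose`, `forward`, `backward`), the numbers, and the toy loss (the sum of the layer outputs) are assumptions made for the illustration.

```python
def matmul(a, b):
    """Matrix product of two list-of-lists matrices."""
    return [[sum(a[r][k] * b[k][c] for k in range(len(b)))
             for c in range(len(b[0]))]
            for r in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def forward(x, W, V):
    """Layer output x W V, with x a 1 x n row vector."""
    return matmul(matmul(x, W), V)


def backward(x, W, V, z):
    """Apply the derivative functions to an upstream gradient z (shaped like x W V)."""
    z_Vt = matmul(z, transpose(V))            # gradient w.r.t. the intermediate x W
    dW = matmul(transpose(x), z_Vt)           # same shape as W
    dV = matmul(transpose(matmul(x, W)), z)   # same shape as V
    dx = matmul(z_Vt, transpose(W))           # same shape as x
    return dx, dW, dV


x = [[1.0, 2.0]]
W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]

# Toy loss J = sum of the outputs, so the upstream gradient is all ones.
z = [[1.0, 1.0]]
dx, dW, dV = backward(x, W, V, z)


def loss_with_W(W_candidate):
    return sum(forward(x, W_candidate, V)[0])


# Central finite difference on the single entry W[0][1].
eps = 1e-6
W_plus = [row[:] for row in W]
W_plus[0][1] += eps
W_minus = [row[:] for row in W]
W_minus[0][1] -= eps
numeric = (loss_with_W(W_plus) - loss_with_W(W_minus)) / (2 * eps)

print(dW[0][1], numeric)
```

Note how the multiplication order inside `backward` is fixed by the shapes: each result has to come out with the same shape as the thing it is the gradient of.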

Above is just a toy example. Let’s see something serious.

Gradient of LSTM

Let’s take the Graves LSTM formulation for example. The forward pass is

(refer to https://arxiv.org/pdf/1308.0850.pdf)
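
For reference, the forward pass from that paper can be written roughly as follows. This is a sketch, assuming row vectors and the naming used in the rest of this note: $\boldsymbol{U}$ multiplies the input $\boldsymbol{x}_t$, $\boldsymbol{W}$ the previous hidden state $\boldsymbol{h}_{t-1}$, $\boldsymbol{V}$ the cell state (the peephole connections, which are diagonal in the paper), and $\boldsymbol{b}$ is a bias.

$$
\begin{aligned}
\boldsymbol{i}_t &= \sigma(\boldsymbol{x}_t \boldsymbol{U}_i + \boldsymbol{h}_{t-1} \boldsymbol{W}_i + \boldsymbol{c}_{t-1} \boldsymbol{V}_i + \boldsymbol{b}_i) \\
\boldsymbol{f}_t &= \sigma(\boldsymbol{x}_t \boldsymbol{U}_f + \boldsymbol{h}_{t-1} \boldsymbol{W}_f + \boldsymbol{c}_{t-1} \boldsymbol{V}_f + \boldsymbol{b}_f) \\
\boldsymbol{c}_t &= \boldsymbol{f}_t \odot \boldsymbol{c}_{t-1} + \boldsymbol{i}_t \odot \tanh(\boldsymbol{x}_t \boldsymbol{U}_c + \boldsymbol{h}_{t-1} \boldsymbol{W}_c + \boldsymbol{b}_c) \\
\boldsymbol{o}_t &= \sigma(\boldsymbol{x}_t \boldsymbol{U}_o + \boldsymbol{h}_{t-1} \boldsymbol{W}_o + \boldsymbol{c}_t \boldsymbol{V}_o + \boldsymbol{b}_o) \\
\boldsymbol{h}_t &= \boldsymbol{o}_t \odot \tanh(\boldsymbol{c}_t)
\end{aligned}
$$

Taken in this order, the lines appear to match how the equations are referred to below: (11) for the input gate, (13) for the cell state, and (14) for the output gate.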

Imagine that the backpropagation of this LSTM cell is a function whose input is $d\boldsymbol{h}$, the gradient of $\boldsymbol{h}_t$. Based on that, we are going to get the gradients of the gates and hidden states: $d\boldsymbol{i}$, $d\boldsymbol{f}$, $d\boldsymbol{o}$, and $d\boldsymbol{c}$. From those we further get the gradients of the learnable parameters $d\boldsymbol{U}_i$, $d\boldsymbol{U}_f$, and so on.

To this end, let’s first focus on equation (11). Here we want $d\boldsymbol{U}_i$. Assuming $d\boldsymbol{i}$ is already calculated, the gradient of $\boldsymbol{U}_i$ is:

where $(\cdot)$ represents the linear summation in equation (11). Referring back to the toy example above, what we’re gonna do next is get the definitions of these “derivative functions”:

where $\boldsymbol{z}$ is a generalized notation for arguments. Plugging equations (17-19) back to (16) yields:

A simple sanity check is to make sure the dimensionality of $d\boldsymbol{U}_i$ matches that of $\boldsymbol{U}_i$. Remember $\boldsymbol{x}$ is a row vector, so the dimensionality matches here.

Using the same method as above, we can get $d\boldsymbol{U}_i$, $d\boldsymbol{W}_i$, $d\boldsymbol{V}_i$, $d\boldsymbol{U}_f$, $d\boldsymbol{W}_f$, $d\boldsymbol{V}_f$, $d\boldsymbol{U}_o$, $d\boldsymbol{W}_o$, and $d\boldsymbol{V}_o$. A little extra effort is required to get $d\boldsymbol{U}_c$ and $d\boldsymbol{W}_c$:
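
The shape check can be sketched in a few lines of plain Python. Everything here is illustrative: the sizes, the numbers, and the helper name `grad_U_i` are assumptions, and $d\boldsymbol{U}_i$ is taken to be $\boldsymbol{x}_t^T (d\boldsymbol{i} \odot \boldsymbol{i}_t \odot (1 - \boldsymbol{i}_t))$, which is what the sigmoid derivative gives for a row-vector input.

```python
def grad_U_i(x, i_gate, di):
    """dU_i = x^T (di * i * (1 - i)) for an input of size n and a gate of size m."""
    # Back through the sigmoid: its derivative is i * (1 - i).
    d_pre = [d * g * (1.0 - g) for d, g in zip(di, i_gate)]
    # Outer product of the (transposed) row vector x with d_pre gives an n x m matrix.
    return [[x_r * d_c for d_c in d_pre] for x_r in x]


x = [1.0, -2.0, 0.5]   # input row vector, n = 3
i_gate = [0.5, 0.25]   # input-gate activations, m = 2
di = [1.0, 2.0]        # upstream gradient w.r.t. the gate

dU_i = grad_U_i(x, i_gate, di)
print(len(dU_i), len(dU_i[0]))  # 3 2, the shape of U_i
```

If the transpose were on the wrong factor, the product would either fail to line up or come out as a scalar, so the shape check catches that class of mistakes cheaply.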

where $(\cdot)$ denotes the linear form in equation (13).

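Thus, $d\boldsymbol{U}_c = \boldsymbol{x}_t^T \left( d\boldsymbol{c} \odot \boldsymbol{i}_t \odot \left(1 - \tanh^2(\cdot)\right) \right)$. Similarly we can get $d\boldsymbol{W}_c$. Next, we need to calculate what exactly {$d\boldsymbol{i}$, $d\boldsymbol{f}$, $d\boldsymbol{o}$, and $d\boldsymbol{c}$} are, given $d\boldsymbol{h}$.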

where $(\cdot)$ represents the linear form in equation (14). And

And

where $(\cdot)$ represents the linear form in equation (13). And

Doing the calculation for equations (25-27) will yield what we wanted at the very beginning. Don’t forget the bias terms, whose gradients are easy to get. Furthermore, beware that the gradients of inputs shared across equations (11-14), such as $d\boldsymbol{x}_t$ and $d\boldsymbol{h}_{t-1}$, should be summed over the partial derivatives from each of those equations.
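
To make the gate gradients and the “summing” concrete, here is a minimal sketch for a single hidden unit, so every quantity is a scalar. It assumes the gate values `i`, `f` and the tanh candidate `g` are given, and uses the hypothetical names `a_o` for the output-gate pre-activation without its peephole term and `v_o` for the peephole weight. None of these names come from the paper.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def forward(i, f, g, c_prev, a_o, v_o):
    """c = f*c_prev + i*g, o = sigmoid(a_o + v_o*c), h = o*tanh(c)."""
    c = f * c_prev + i * g
    o = sigmoid(a_o + v_o * c)
    h = o * math.tanh(c)
    return h, c, o


def backward(dh, i, f, g, c_prev, a_o, v_o):
    """Gradients of the inputs given dh, the gradient w.r.t. h."""
    _, c, o = forward(i, f, g, c_prev, a_o, v_o)
    t = math.tanh(c)
    do = dh * t                   # h depends on o directly
    da_o = do * o * (1.0 - o)     # back through the output-gate sigmoid
    # c reaches h along two paths (tanh(c) and the peephole into o): sum them.
    dc = dh * o * (1.0 - t * t) + da_o * v_o
    return {"i": dc * g, "f": dc * c_prev, "g": dc * i,
            "c_prev": dc * f, "a_o": da_o}


args = {"i": 0.6, "f": 0.7, "g": -0.3, "c_prev": 0.8, "a_o": 0.1, "v_o": 0.5}
grads = backward(1.0, **args)

# Compare each analytic gradient with a central finite difference on h.
eps = 1e-6
max_err = 0.0
for name, analytic in grads.items():
    plus = dict(args)
    plus[name] += eps
    minus = dict(args)
    minus[name] -= eps
    numeric = (forward(**plus)[0] - forward(**minus)[0]) / (2 * eps)
    max_err = max(max_err, abs(numeric - analytic))

print(max_err < 1e-6)
```

In a full backward pass through time, `dc` would also pick up the gradient flowing in from the next time step, and the matrix-valued parameter gradients follow from these per-gate gradients as in the $d\boldsymbol{U}_i$ case.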

Gradient Checking

Make sure all manually written gradients pass a gradient check :)
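
A gradient check compares the hand-written gradient against a finite-difference estimate. Below is a minimal sketch for parameters stored as a flat list of floats. The function names, the step size, and the toy loss are assumptions chosen for the illustration.

```python
def numerical_gradient(loss, params, eps=1e-5):
    """Central-difference estimate of d loss / d params for a flat list of floats."""
    grad = []
    for k in range(len(params)):
        plus = list(params)
        plus[k] += eps
        minus = list(params)
        minus[k] -= eps
        grad.append((loss(plus) - loss(minus)) / (2.0 * eps))
    return grad


def relative_error(analytic, numeric):
    """Largest absolute difference, scaled by the largest gradient magnitude."""
    diff = max(abs(a - n) for a, n in zip(analytic, numeric))
    scale = max(max(abs(a), abs(n)) for a, n in zip(analytic, numeric))
    return diff / max(scale, 1e-12)


def toy_loss(p):
    return sum(v ** 3 for v in p)


params = [0.5, -1.5, 2.0]
manual_grad = [3.0 * v * v for v in params]   # correct derivative of v**3
buggy_grad = [2.0 * v for v in params]        # a deliberately wrong one

numeric = numerical_gradient(toy_loss, params)
print(relative_error(manual_grad, numeric) < 1e-6)   # True
print(relative_error(buggy_grad, numeric) < 1e-6)    # False
```

For a real network, flatten each parameter tensor, perturb one entry at a time, and run the full forward pass to get the loss. It is slow, so it is usually run on a tiny model and a handful of entries.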

Conclusion

Overall, deriving gradients manually is a lot of fun. But the whole process becomes unsustainable at the modeling stage, which involves constant changes in network architecture. Besides that, performance can be a big issue, especially compared with cuDNN’s built-in LSTM/GRU cells. Still, it’s good to know what effectively happens in those backward calls in modern learning libraries.

Updates 10-05-2018: Reverse mode differentiation

Updates 05-17-2019: Minor fixes