Computing Gradients

The field of matrix calculus, which is a part of multivariable calculus, is really not that different from single-variable calculus. In multivariable calculus, functions map $\mathbb{R}^n$ to $\mathbb{R}^m$, and their derivatives are expressed with gradients, Jacobians, and so on. Matrix calculus organizes these derivatives when the inputs and outputs are vectors and matrices. The principles are the same: you still apply the chain rule, the product rule, etc., just in higher dimensions and with some absurd-looking notation.

  • From calculus 101: given a function $f: (x_1, x_2, \cdots , x_n) \in \mathbb{R}^n \mapsto f(\mathbf{x}) \in \mathbb{R}$, the gradient $\partial f / \partial \mathbf{x} \in \mathbb{R}^n$ equals

$$ \frac{\partial f}{\partial \mathbf{x}} = \left(\frac{\partial f}{\partial x_1}, \cdots, \frac{\partial f}{\partial x_n} \right) $$

  • A linear transformation from $\mathbb{R}^n$ to $\mathbb{R}^m$ can be represented by an $m \times n$ matrix (with vectors written as columns).
  • The Jacobian of $\mathbf{f}: \mathbb{R}^n \to \mathbb{R}^m$ at $\mathbf{x}$ is the $m \times n$ matrix of partial derivatives. It is the linear map that best approximates how a small change in $\mathbf{x} \in \mathbb{R}^n$ changes $\mathbf{f}(\mathbf{x}) \in \mathbb{R}^m$. Entrywise, $\partial \mathbf{f} / \partial \mathbf{x} \in \mathbb{R}^{m \times n}$ equals

$$ \frac{\partial \mathbf{f}}{\partial \mathbf{x}} = \begin{bmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n} \end{bmatrix} \quad \text{so that } \left( \frac{\partial \mathbf{f}}{\partial \mathbf{x}} \right)_{ij} = \frac{\partial f_i}{\partial x_j} $$
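To see both definitions side by side, here is a small sketch in JAX (assuming, as the `grad`/`jnp` snippet at the end of this post suggests, that JAX is the library in use). The functions `f` and `g` and the point `x` are arbitrary illustrative choices. `jax.grad` returns the gradient of the scalar function, `jax.jacfwd` returns the $m \times n$ Jacobian, and both are compared against partial derivatives worked out by hand.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar function f: R^3 -> R
    return x[0] ** 2 + 3.0 * x[0] * x[1] + jnp.sin(x[2])

def g(x):
    # Illustrative vector function g: R^3 -> R^2
    return jnp.array([x[0] * x[1], x[1] + jnp.exp(x[2])])

x = jnp.array([1.0, 2.0, 0.5])

# Gradient: one partial derivative per input component, shape (3,)
grad_f = jax.grad(f)(x)
grad_by_hand = jnp.array([2 * x[0] + 3 * x[1], 3 * x[0], jnp.cos(x[2])])

# Jacobian: shape (m, n) = (2, 3); entry (i, j) is d g_i / d x_j
J = jax.jacfwd(g)(x)
J_by_hand = jnp.array([[x[1], x[0], 0.0],
                       [0.0, 1.0, jnp.exp(x[2])]])

print(grad_f, grad_by_hand)
print(J.shape)  # (2, 3)
```

Note that the Jacobian's row index follows the outputs and its column index follows the inputs, matching $(\partial \mathbf{f} / \partial \mathbf{x})_{ij} = \partial f_i / \partial x_j$.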

We often write the Jacobian $\partial \mathbf{f} / \partial \mathbf{x}$ as $\partial \mathbf{f}(\mathbf{x})$. It maps the tangent space of the domain of $\mathbf{f}$ at the point $\mathbf{x}$ to the tangent space of the codomain of $\mathbf{f}$ at the point $\mathbf{f}(\mathbf{x})$.
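The "Jacobian as a map between tangent spaces" view corresponds to a Jacobian-vector product. A minimal sketch, assuming JAX and an arbitrary illustrative map `f`: `jax.jvp` pushes a tangent vector `v` at the point `x` forward to $\partial \mathbf{f}(\mathbf{x})\, v$ without building the Jacobian, and the result agrees with multiplying the explicit Jacobian by `v`.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map f: R^3 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.5])   # point in the domain
v = jnp.array([0.1, -0.2, 0.3])  # tangent vector at x

# jax.jvp returns (f(x), ∂f(x) v) without materializing the Jacobian
y, tangent_out = jax.jvp(f, (x,), (v,))

# Same result via the explicit m x n Jacobian times v
J = jax.jacfwd(f)(x)
print(tangent_out, J @ v)
```

The output tangent lives in $\mathbb{R}^2$, the codomain, even though `v` lives in $\mathbb{R}^3$.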

Some building blocks

  1. Consider an elementwise layer, where a scalar function $f: \mathbb{R} \to \mathbb{R}$ is applied to each component of $\mathbf{z} \in \mathbb{R}^n$: $$ \mathbf{h} = f (\mathbf{z}) \quad \text{i.e.} \quad h_i = f(z_i) $$

$$ \left( \frac{\partial \mathbf{h}}{\partial \mathbf{z}} \right)_{ij} = \frac{\partial h_i}{\partial z_j} = \frac{\partial f(z_i)}{\partial z_j} = \begin{cases} f^\prime(z_i) & \text{if } i = j \\ 0 & \text{otherwise} \end{cases} \quad \therefore \frac{\partial \mathbf{h}}{\partial \mathbf{z}} = \text{diag}(f^\prime(\mathbf{z})) $$

  1. Consider a linear layer: $$ \mathbf{z} = \mathbf{W} \mathbf{x} + \mathbf{b} \quad \text{i.e.} \quad z_i = \sum_j W_{ij} x_j + b_i $$

$$ \left( \frac{\partial \mathbf{z}}{\partial \mathbf{x}} \right)_{ij} = \frac{\partial z_i}{\partial x_j} = W_{ij} \quad \therefore \frac{\partial \mathbf{z}}{\partial \mathbf{x}} = \mathbf{W} $$

$$ \left( \frac{\partial \mathbf{z}}{\partial \mathbf{b}} \right)_{ij} = \frac{\partial z_i}{\partial b_j} = \delta_{ij} \quad \therefore \frac{\partial \mathbf{z}}{\partial \mathbf{b}} = \mathbf{I} $$ where $\delta_{ij}$ is the Kronecker delta ($1$ if $i = j$, else $0$).

  1. For two vectors $\mathbf{u}, \mathbf{v} \in \mathbb{R}^n$, the dot product is $\mathbf{u}^T \mathbf{v} = \sum_j u_j v_j \in \mathbb{R}$. $$ \left( \frac{\partial (\mathbf{u}^T \mathbf{v})}{\partial \mathbf{u}} \right)_{i} = \frac{\partial}{\partial u_i} \sum_j u_j v_j = v_i \quad \text{(only the } j = i \text{ term depends on } u_i\text{)} \quad \therefore \frac{\partial (\mathbf{u}^T \mathbf{v})}{\partial \mathbf{u}} = \mathbf{v}^T $$
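The three building blocks can be checked numerically. This is a sketch assuming JAX, with $f = \tanh$ chosen as the elementwise function (so $f' = 1 - \tanh^2$) and small made-up values for $\mathbf{z}, \mathbf{W}, \mathbf{x}, \mathbf{b}, \mathbf{u}, \mathbf{v}$. One caveat: JAX arrays do not distinguish row from column vectors, so the $\mathbf{v}^T$ of the dot-product rule comes back as the 1-D array `v`.

```python
import jax
import jax.numpy as jnp

n, m = 4, 3
z = jnp.array([0.5, -1.0, 2.0, 0.0])
W = jnp.arange(m * n, dtype=jnp.float32).reshape(m, n) / 10.0
x = jnp.array([1.0, 2.0, 3.0, 4.0])
b = jnp.array([0.1, 0.2, 0.3])
u = jnp.array([1.0, -2.0, 3.0])
v = jnp.array([4.0, 5.0, 6.0])

# 1. Elementwise layer (f = tanh assumed): Jacobian is diag(f'(z))
J_h = jax.jacfwd(jnp.tanh)(z)
diag_fprime = jnp.diag(1.0 - jnp.tanh(z) ** 2)

# 2. Linear layer z = Wx + b: Jacobian wrt x is W, wrt b is the identity
linear = lambda x, b: W @ x + b
J_x = jax.jacfwd(linear, argnums=0)(x, b)
J_b = jax.jacfwd(linear, argnums=1)(x, b)

# 3. Dot product u^T v: derivative wrt u is v (returned as a 1-D array)
d_dot = jax.grad(lambda u: u @ v)(u)

print(J_h.shape, J_x.shape, J_b.shape, d_dot)
```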

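The choice of $\mathbf{v}^T$ rather than $\mathbf{v}$ follows from the convention above that the Jacobian of a map $\mathbb{R}^n \to \mathbb{R}^m$ is an $m \times n$ matrix: for a scalar-valued function ($m = 1$) it is a $1 \times n$ row vector. This matches the view of the Jacobian as a linear map on tangent vectors. Given an input $x \in \mathbb{R}^n$ and a tangent vector $v \in \mathbb{R}^n$, the output tangent vector in $\mathbb{R}^m$ is $\partial f(x)\, v$, i.e. $(x, v) \mapsto \partial f(x)\, v$, so $\partial f$ has the curried type $\partial f: \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$ (first the point, then the tangent vector). A $1 \times n$ row vector is exactly what turns a tangent vector in $\mathbb{R}^n$ into a scalar.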
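The curried type $\partial f: \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m$ can be made concrete with `jax.linearize`, which fixes the point and hands back the linear map on tangent vectors. A minimal sketch, assuming JAX, for the dot-product example $f(\mathbf{u}) = \mathbf{u}^T \mathbf{v}$ with illustrative values; the tangent vector is named `u_dot` here to avoid clashing with the fixed vector `v`.

```python
import jax
import jax.numpy as jnp

v = jnp.array([4.0, 5.0, 6.0])

def f(u):
    # f(u) = u^T v, a scalar-valued function (m = 1)
    return jnp.dot(u, v)

u = jnp.array([1.0, -2.0, 3.0])

# jax.linearize fixes the point u and returns f(u) together with the
# linear map u_dot -> ∂f(u) u_dot, i.e. ∂f partially applied to u
y, df_at_u = jax.linearize(f, u)

u_dot = jnp.array([0.5, 0.0, -1.0])
print(df_at_u(u_dot), v @ u_dot)  # both equal v^T u_dot = -4.0
```

Applying the row vector $\mathbf{v}^T$ to a tangent vector gives a scalar, which is why the derivative of a scalar function is naturally written as $\mathbf{v}^T$.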

A simple example: MLP and backpropagation

A multi-layer perceptron (MLP) with one hidden layer can be represented mathematically as follows: $$ s = \mathbf{u}^T \mathbf{h} \longleftarrow \mathbf{h} = f(\mathbf{z}) \longleftarrow \mathbf{z} = \mathbf{W} \mathbf{x} + \mathbf{b} \longleftarrow \mathbf{x} \text{ (input)} $$

$$ \begin{aligned} \frac{\partial s}{\partial \mathbf{z}} &= \frac{\partial s}{\partial \mathbf{h}} \frac{\partial \mathbf{h}}{\partial \mathbf{z}} \\ &= \mathbf{u}^T \text{diag}(f^\prime(\mathbf{z})) \\ &= (\mathbf{u} \circ f^\prime(\mathbf{z}))^T \\ &:= \delta \text{ (local error signal)} \end{aligned} $$
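As a sanity check on the local error signal, this sketch (assuming JAX, $f = \tanh$, and small made-up values for $\mathbf{W}, \mathbf{b}, \mathbf{u}, \mathbf{x}$) differentiates $s = \mathbf{u}^T f(\mathbf{z})$ with respect to $\mathbf{z}$ by autodiff and compares it with $\mathbf{u} \circ f'(\mathbf{z})$. Since JAX returns a 1-D array, the transpose in $\delta$ is invisible here.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes and values (not from the text); f = tanh is assumed
W = jnp.array([[0.2, -0.5, 1.0],
               [0.7, 0.1, -0.3]])   # maps R^3 -> R^2
b = jnp.array([0.1, -0.2])
u = jnp.array([1.5, -2.0])
x = jnp.array([1.0, 2.0, -1.0])

z = W @ x + b  # pre-activation

def s_of_z(z):
    # s = u^T f(z) with f = tanh, viewed as a function of z only
    return u @ jnp.tanh(z)

delta_autodiff = jax.grad(s_of_z)(z)
delta_by_hand = u * (1.0 - jnp.tanh(z) ** 2)  # u ∘ f'(z), since tanh' = 1 - tanh^2

print(delta_autodiff, delta_by_hand)
```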

$$ \begin{aligned} \therefore \frac{\partial s}{\partial \mathbf{b}} &= \frac{\partial s}{\partial \mathbf{h}} \frac{\partial \mathbf{h}}{\partial \mathbf{z}} \frac{\partial \mathbf{z}}{\partial \mathbf{b}} \\ &= \delta \mathbf{I} = \delta \end{aligned} $$

In addition,

$$ \left( \frac{\partial \mathbf{z}}{\partial \mathbf{W}} \right)_{ijk} = \frac{\partial z_i}{\partial W_{jk}} = \begin{cases} x_k & \text{if } i = j \\ 0 & \text{otherwise} \end{cases} $$ so $\partial \mathbf{z} / \partial \mathbf{W}$ is a third-order tensor: $z_i$ depends only on the $i$-th row of $\mathbf{W}$.

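$$ \begin{aligned} \therefore \frac{\partial s}{\partial \mathbf{W}} &= \frac{\partial s}{\partial \mathbf{h}} \frac{\partial \mathbf{h}}{\partial \mathbf{z}} \frac{\partial \mathbf{z}}{\partial \mathbf{W}}, \quad \text{i.e.} \quad \left( \frac{\partial s}{\partial \mathbf{W}} \right)_{jk} = \sum_i \delta_i \left( \frac{\partial \mathbf{z}}{\partial \mathbf{W}} \right)_{ijk} = \delta_j x_k \\ &= \delta^T \mathbf{x}^T \quad \text{(an outer product, arranged in the shape of } \mathbf{W}\text{)} \end{aligned} $$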
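The parameter gradients can be verified the same way. In this sketch (assuming JAX, $f = \tanh$, and made-up values), `jax.grad` with `argnums=(0, 1)` differentiates the score with respect to both $\mathbf{W}$ and $\mathbf{b}$; JAX returns each gradient in the shape of its parameter, which is the "arranged in the shape of $\mathbf{W}$" reading of $\delta^T \mathbf{x}^T$. The third-order tensor $\partial \mathbf{z} / \partial \mathbf{W}$ is also built with `jax.jacfwd` and compared with the case formula.

```python
import jax
import jax.numpy as jnp

n, m = 4, 3
W = jnp.arange(m * n, dtype=jnp.float32).reshape(m, n) / 10.0
b = jnp.array([0.1, -0.2, 0.3])
u = jnp.array([1.0, -2.0, 0.5])
x = jnp.array([1.0, 0.5, -1.0, 2.0])

def mlp_score(W, b):
    z = W @ x + b        # linear layer
    h = jnp.tanh(z)      # elementwise activation (assumed f = tanh)
    return u @ h         # s = u^T h

# Gradients with respect to both parameters
dW, db = jax.grad(mlp_score, argnums=(0, 1))(W, b)

# Hand-derived results: ds/db = delta, ds/dW = outer(delta, x)
z = W @ x + b
delta = u * (1.0 - jnp.tanh(z) ** 2)

# dz/dW has shape (m, m, n); entry [i, j, k] is x_k if i == j, else 0
dz_dW = jax.jacfwd(lambda W: W @ x + b)(W)
expected = jnp.einsum("ij,k->ijk", jnp.eye(m), x)

print(dW.shape, dz_dW.shape)
```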

Computer's turn

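```python
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)  # a new function that evaluates d/dx tanh(x)
print(grad_tanh(2.0))       # 1 - tanh(2)^2 ≈ 0.0707
```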

`grad` takes a function and returns a new function that computes its gradient. If you have a Python function `f` that evaluates the mathematical function $f$, then `grad(f)` is a Python function that evaluates the mathematical function $\nabla f$. That means `grad(f)(x)` is the value of the gradient $\nabla f(x)$.
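Because `grad` returns an ordinary Python function, its result can be called at a point or passed to `grad` again. A small sketch with the illustrative polynomial $f(x) = x^3$, so $f'(x) = 3x^2$ and $f''(x) = 6x$:

```python
from jax import grad

def f(x):
    return x ** 3

df = grad(f)     # a new Python function evaluating f'(x) = 3x^2
d2f = grad(df)   # grad applied again: f''(x) = 6x

print(df(3.0))   # 27.0
print(d2f(3.0))  # 18.0
```

Note that `grad` as used here expects a scalar-valued function and floating-point inputs; calling `df(3)` with an integer would raise an error.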