Matrix-free Differentiable Linear Solver#
In this tutorial, we provide a step-by-step derivation of the matrix-free differentiable linear solver used in \(\nabla\)-Prox.
Recall that our goal is to find the gradient of the output of a linear solver \(\bar{x}\)
with respect to the parameters in the solved linear system, such as \(\frac{\partial \bar{x}}{\partial K}\) and \(\frac{\partial \bar{x}}{\partial b}\).
[1]:
# Uncomment the following line to install dprox if you are running in an online Google Colab notebook
# !pip install dprox
[2]:
import torch
from torch.autograd.functional import jacobian
Naive Approach with Auto-Diff#
Let us first derive the gradient with auto-differentiation.
[3]:
torch.manual_seed(0)
theta = torch.randn((32,32), requires_grad=True) # define parameter of the linOp K
K = theta * 2
x = torch.randn((32))
b = K @ x
b = b.clone().detach().requires_grad_(True)
xhat = torch.linalg.solve(K, b)
loss = xhat.mean()
loss.backward()
print(theta.grad.shape)
print(b.grad.shape)
torch.Size([32, 32])
torch.Size([32])
Implicit Differentiation#
Auto-diff can efficiently differentiate fast direct linear solvers, but it is often intractable for iterative linear solvers, since backpropagating through the unrolled iterations requires storing their intermediate states.
In \(\nabla\)-Prox, we provide an optimized routine to compute the analytic derivatives of the outputs of (iterative) linear solvers with respect to the parameters \(\theta\) of the linear operators and the right-hand side \(b\).
Derivation of \(\frac{\partial \bar{x}}{\partial b}\)#
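Specifically, we differentiate both sides of \(K\bar{x} = b\) with respect to \(b\) and, holding \(b\) fixed, with respect to \(\theta\), to obtain
\[K \frac{\partial \bar{x}}{\partial b} = I, \qquad \frac{\partial K}{\partial \theta}\bar{x} + K \frac{\partial \bar{x}}{\partial \theta} = 0,\]
from which the gradient \(\frac{\partial \bar{x}}{\partial b} = K^{-1}\) can be easily derived. Typically, we are more interested in the gradient of a scalar loss function \(\mathcal{L}\) with respect to \(b\), which can be obtained with the chain rule:
\[\frac{\partial \mathcal{L}}{\partial b} = \left(\frac{\partial \bar{x}}{\partial b}\right)^{T} \frac{\partial \mathcal{L}}{\partial \bar{x}} = K^{-T} \frac{\partial \mathcal{L}}{\partial \bar{x}}.\]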
Since all the linear operators in our system are matrix-free, we cannot directly evaluate the above formula to compute the gradient. Instead, we transform it into the linear system
\[K^{T} \frac{\partial \mathcal{L}}{\partial b} = \frac{\partial \mathcal{L}}{\partial \bar{x}},\]
where the right-hand side is the gradient of \(\mathcal{L}\) with respect to \(\bar{x}\), which can be efficiently evaluated with auto-diff systems. The calculation of the gradient \(\frac{\partial \mathcal{L}}{\partial b}\) has thus been converted into solving a linear system, requiring significantly less memory.
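The above derivation assumes the following gradient layout, in which row i holds the derivatives of x_i and column j the derivatives with respect to b_j:
dx/db = [dx1/db1, dx1/db2, …, dx1/dbn; dx2/db1, dx2/db2, …, dx2/dbn; …; dxn/db1, dxn/db2, …, dxn/dbn]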
Note that the gradient layout of torch.autograd.functional.jacobian is the same as above.
See also: https://en.wikipedia.org/wiki/Matrix_calculus#Layout_conventions
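The layout statement can be checked numerically. The following is a minimal sketch, assuming a small, well-conditioned matrix `K` made up purely for illustration: it asks `torch.autograd.functional.jacobian` for \(\frac{\partial \bar{x}}{\partial b}\) and compares it with \(K^{-1}\), so that entry `J[i, j]` corresponds to \(\partial \bar{x}_i / \partial b_j\).

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
# Illustrative, well-conditioned system (an assumption, not taken from the tutorial cells).
K = torch.eye(4, dtype=torch.float64) + 0.1 * torch.randn(4, 4, dtype=torch.float64)
b = torch.randn(4, dtype=torch.float64)

# jacobian returns a tensor of shape [output, input], i.e. J[i, j] = d xbar_i / d b_j.
J = jacobian(lambda rhs: torch.linalg.solve(K, rhs), b)

print(J.shape)
print(torch.allclose(J, torch.linalg.inv(K)))
```

If the layout were transposed, the comparison would be against `torch.linalg.inv(K).T` instead, which differs for a non-symmetric `K`.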
[4]:
torch.manual_seed(0)
theta = torch.randn((5,5), requires_grad=True)
K = theta * 2
x = torch.randn(5)
b = K @ x
b = b.clone().detach().requires_grad_(True)
xhat = torch.linalg.solve(K, b)
xhat.retain_grad() # retain non-leaf gradient for analytical compute
loss = xhat.mean()
loss.backward()
# analytical gradient using implicit differentiation
db = torch.inverse(K.T) @ xhat.grad
db2 = torch.linalg.solve(K.T, xhat.grad)
# analytical gradient versus auto-grad
print(b.grad)
print((b.grad - db).abs().max())
print(torch.allclose(b.grad, db, rtol=1e-6))
print(torch.allclose(b.grad, db2, rtol=1e-6))
tensor([ 0.1406, -0.0938, 0.2671, -0.1739, -0.1323])
tensor(1.4901e-08, grad_fn=<MaxBackward1>)
True
True
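The cell above still forms \(K^{T}\) explicitly. As a rough sketch of how the same gradient could be obtained when \(K\) is only available as a function, the example below solves \(K^{T} g = \frac{\partial \mathcal{L}}{\partial \bar{x}}\) with conjugate gradient on the normal equations \(K K^{T} g = K \frac{\partial \mathcal{L}}{\partial \bar{x}}\). The names `K_apply`, `KT_apply`, and `conjugate_gradient` are hypothetical helpers written for this illustration and are not the dprox API; the choice of solver is likewise an assumption.

```python
import torch

torch.manual_seed(0)
n = 5
# Hidden dense matrix used only to define the operator and to check the result.
M = torch.eye(n, dtype=torch.float64) + 0.1 * torch.randn(n, n, dtype=torch.float64)

def K_apply(v):
    # Stand-in for a matrix-free operator: the solver only ever calls this function.
    return M @ v

def KT_apply(u):
    # Adjoint K^T u, obtained as a vector-Jacobian product of the forward operator.
    _, vjp_out = torch.autograd.functional.vjp(K_apply, torch.zeros_like(u), u)
    return vjp_out

def conjugate_gradient(A_apply, rhs, max_iter=50, tol=1e-12):
    # Plain CG for a symmetric positive-definite operator given as a function.
    x = torch.zeros_like(rhs)
    r = rhs.clone()
    p = r.clone()
    rs_old = r @ r
    for _ in range(max_iter):
        if rs_old.sqrt() < tol:
            break
        Ap = A_apply(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Gradient of loss = mean(xbar) with respect to xbar.
dL_dx = torch.full((n,), 1.0 / n, dtype=torch.float64)

# Solve K^T g = dL/dx through the normal equations (K K^T) g = K dL/dx.
dL_db = conjugate_gradient(lambda g: K_apply(KT_apply(g)), K_apply(dL_dx))

reference = torch.linalg.solve(M.T, dL_db.new_tensor(dL_dx))
print(torch.allclose(dL_db, reference, atol=1e-8))
```

The normal equations square the condition number of \(K\), so this particular solver is only a reasonable choice for well-conditioned operators such as the one constructed here.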
Derivation of \(\frac{\partial \bar{x}}{\partial \theta}\)#
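Similarly, the gradient \(\frac{\partial \mathcal{L}}{\partial \theta}\) with respect to the parameters \(\theta\) of the linear operator \(K\) can be derived as
\[\frac{\partial \mathcal{L}}{\partial \theta} = \left(\frac{\partial \bar{x}}{\partial \theta}\right)^{T} \frac{\partial \mathcal{L}}{\partial \bar{x}} = -\left(K^{-1} \frac{\partial K}{\partial \theta} \bar{x}\right)^{T} \frac{\partial \mathcal{L}}{\partial \bar{x}}.\]
Again, \(\frac{\partial K}{\partial \theta}\) cannot be evaluated directly as we consider matrix-free linear operators. To circumvent this obstacle, we use the fact that
\[\frac{\partial K}{\partial \theta} \bar{x} = \frac{\partial b}{\partial \theta}\]
to transform it into
\[\frac{\partial \mathcal{L}}{\partial \theta} = -\left(\frac{\partial b}{\partial \theta}\right)^{T} K^{-T} \frac{\partial \mathcal{L}}{\partial \bar{x}},\]
where \(\frac{\partial b}{\partial \theta}\) can be computed by backpropagating through the forward computation \(K\bar{x}=b\) with \(\bar{x}\) held fixed. As such, the calculation of the gradients \(\frac{\partial \mathcal{L}}{\partial b}\) and \(\frac{\partial \mathcal{L}}{\partial \theta}\) is converted into solving linear systems during backpropagation, without storing intermediate states, thereby significantly reducing memory consumption and saving computation time.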
Note that we assume \(\theta\) has the same shape as \(K\), so that the shapes in \(\frac{\partial K}{\partial \theta}\bar{x}=\frac{\partial b}{\partial \theta}\) are consistent. The gradient with respect to the real parameters \(\bar{\theta}\) can be tracked automatically by auto-diff if we know the function that transforms \(\theta\) into \(\bar{\theta}\).
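The product \(\left(\frac{\partial b}{\partial \theta}\right)^{T} g\) never requires the full Jacobian \(\frac{\partial b}{\partial \theta}\): it is a vector-Jacobian product that auto-diff evaluates in a single backward pass. The sketch below illustrates this under the same toy parameterization \(K = 2\theta\) used in this tutorial; the helper `K_of` and the particular values of `theta` and `b` are assumptions made for the example.

```python
import torch

torch.manual_seed(0)
n = 5
theta = (0.5 * torch.eye(n, dtype=torch.float64)
         + 0.05 * torch.randn(n, n, dtype=torch.float64)).requires_grad_(True)
b = torch.randn(n, dtype=torch.float64)

def K_of(theta):
    # Toy parameterization of the linear operator: K = 2 * theta.
    return theta * 2

# Reference: let autograd differentiate through the direct solver.
xbar = torch.linalg.solve(K_of(theta), b)
(ref_grad,) = torch.autograd.grad(xbar.mean(), theta)

# Implicit differentiation, step 1: g = K^{-T} dL/dxbar (one linear solve).
with torch.no_grad():
    dL_dx = torch.full((n,), 1.0 / n, dtype=torch.float64)  # gradient of mean(xbar)
    g = torch.linalg.solve(K_of(theta).T, dL_dx)

# Step 2: dL/dtheta = -(db/dtheta)^T g, where b(theta) = K(theta) xbar with xbar fixed.
b_of_theta = K_of(theta) @ xbar.detach()
(implicit_grad,) = torch.autograd.grad(b_of_theta, theta, grad_outputs=-g)

print(torch.allclose(implicit_grad, ref_grad))
```

Only one linear solve with \(K^{T}\) and one backward pass through the operator application are needed, regardless of the number of entries in \(\theta\).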
Reference Implementation with the Explicit Matrix
Suppose \(K \in \mathrm{R}^{R\times C}\), \(\theta \in \mathrm{R}^{R2 \times C2}\), \(x \in \mathrm{R}^{N}\), and \(b \in \mathrm{R}^{N}\).
Since \(K\) is a square matrix, \(R\), \(C\), \(R2\), \(C2\), and \(N\) all have the same value. We use different symbols only to better illustrate the gradient layout.
The distinct symbols may look confusing when checking shapes. Keep in mind, however, that we are interested in the gradient layout: because \(R\), \(C\), \(R2\), \(C2\), and \(N\) are all equal, every matrix multiplication below is valid.
[5]:
# define a linOp depending on the parameter theta
Kmat = lambda theta: theta * 2
dK_dtheta = jacobian(Kmat, theta) # [R x C] x [R2 x C2]
# In PyTorch, dK_dtheta @ xhat is recognized as batched matrix multiplication
# It would be [R x C] x [R2 x C2] @ [N x 1]
# so dK_dtheta @ xhat actually returns [R x C] x R2
# Method 1
# R x C @ [R x C] x R2 = R x C x R2
dxhat_dtheta = - K.inverse() @ dK_dtheta @ xhat
# Method 2
# In theory, dxhat_dtheta should be N x [R2 x C2],
# but torch.linalg.solve returns [R2 x C2] x N,
# Note: K = [R x C], -(dK_dtheta @ xhat) = [R2 x C2](batch size) x N x 1
dxhat_dtheta = torch.linalg.solve(K, -(dK_dtheta @ xhat).unsqueeze(-1)).squeeze(-1)
# Therefore, we do not need to transpose dxhat_dtheta here.
dloss_dtheta = dxhat_dtheta @ xhat.grad
print(torch.mean(torch.abs(dloss_dtheta - theta.grad)))
print(torch.allclose(dloss_dtheta, theta.grad, rtol=1e-6))
tensor(2.3991e-08, grad_fn=<MeanBackward0>)
True
[6]:
torch.manual_seed(0)
theta = torch.randn((5,5), requires_grad=True)
x = torch.randn(5)
def f(theta):
    K = theta * 2
    b = K @ x
    b = b.clone().detach().requires_grad_(True)
    xhat = torch.linalg.solve(K, b)
    return xhat
# Directly evaluate df_dtheta using auto-grad (note that this naive approach scales very poorly)
jab = jacobian(f, theta)
xhat = xhat.clone().detach().requires_grad_()
# xhat.retain_grad()
loss = xhat.mean()
loss.backward()
dtheta = jab.permute(1,2,0) @ xhat.grad
print(torch.mean(torch.abs(dtheta - dloss_dtheta)))
print(jab.permute(1,2,0)[0])
print(dxhat_dtheta[0])
tensor(3.3677e-08, grad_fn=<MeanBackward0>)
tensor([[-1.0499e-01, 1.0606e-02, 3.8049e-02, 9.7144e-05, -9.9203e-01],
[ 3.4649e-02, -3.5001e-03, -1.2557e-02, -3.2060e-05, 3.2739e-01],
[-1.1986e-01, 1.2108e-02, 4.3439e-02, 1.1091e-04, -1.1326e+00],
[ 4.9314e-02, -4.9815e-03, -1.7872e-02, -4.5629e-05, 4.6596e-01],
[-2.3773e-01, 2.4014e-02, 8.6152e-02, 2.1996e-04, -2.2462e+00]])
tensor([[-1.0499e-01, 1.0606e-02, 3.8049e-02, 9.7131e-05, -9.9203e-01],
[ 3.4649e-02, -3.5002e-03, -1.2557e-02, -3.2051e-05, 3.2739e-01],
[-1.1986e-01, 1.2108e-02, 4.3439e-02, 1.1087e-04, -1.1326e+00],
[ 4.9314e-02, -4.9815e-03, -1.7871e-02, -4.5622e-05, 4.6596e-01],
[-2.3773e-01, 2.4014e-02, 8.6152e-02, 2.1991e-04, -2.2462e+00]],
grad_fn=<SelectBackward0>)
Matrix-Free Reference Implementation
Suppose \(K \in \mathrm{R}^{R\times C}\), \(\theta \in \mathrm{R}^{R2 \times C2}\), \(x \in \mathrm{R}^{N}\), and \(b \in \mathrm{R}^{N}\).
Since \(K\) is a square matrix, \(R\), \(C\), \(R2\), \(C2\), and \(N\) all have the same value. We use different symbols only to better illustrate the gradient layout.
[7]:
torch.manual_seed(0)
theta = torch.randn((5,5), requires_grad=True)
K = theta * 2
x = torch.randn(5)
b = K @ x
b = b.clone().detach().requires_grad_(True)
xhat = torch.linalg.solve(K, b)
xhat.retain_grad()
loss = xhat.mean()
loss.backward()
def linop(theta):
    return (theta*2) @ xhat
# R2 x C2 here works like batch size
db_dtheta = jacobian(linop, theta).permute(1,2,0).unsqueeze(-1) # [R2 x C2] x N x 1
# in theory, dxhat_dtheta should be N x [R2 x C2],
# but torch.linalg.solve returns [R2 x C2] x N,
# Note: K = [R x C], db_dtheta = [R2 x C2](batch size) x N x 1
dxhat_dtheta = torch.linalg.solve(K, -db_dtheta).squeeze(-1)
# therefore, we don't need to transpose dxhat_dtheta here.
dLoss_dtheta = dxhat_dtheta @ xhat.grad
print(dLoss_dtheta.shape)
print(torch.mean(torch.abs(dLoss_dtheta - theta.grad) / torch.max(torch.abs(theta.grad))))
print(torch.allclose(dLoss_dtheta, theta.grad, rtol=1e-6))
torch.Size([5, 5])
tensor(2.6605e-08, grad_fn=<MeanBackward0>)
True
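The two derivations can be packaged into a single differentiable operation. The following is a minimal sketch of that idea, assuming the toy parameterization \(K = 2\theta\) from this tutorial; the class `ImplicitLinearSolve` is a hypothetical wrapper written for illustration and is not the dprox implementation. For readability it uses dense `torch.linalg.solve` calls, which a matrix-free version would replace with an iterative solver.

```python
import torch

class ImplicitLinearSolve(torch.autograd.Function):
    """Hypothetical wrapper (not the dprox API): xbar = K(theta)^{-1} b with K = 2 * theta."""

    @staticmethod
    def forward(ctx, theta, b):
        xbar = torch.linalg.solve(theta * 2, b)
        ctx.save_for_backward(theta, xbar)
        return xbar

    @staticmethod
    def backward(ctx, grad_xbar):
        theta, xbar = ctx.saved_tensors
        # dL/db solves K^T (dL/db) = dL/dxbar.
        grad_b = torch.linalg.solve((theta * 2).T, grad_xbar)
        # dL/dtheta = -(db/dtheta)^T dL/db, evaluated as a vector-Jacobian product
        # of the forward computation b(theta) = K(theta) xbar with xbar held fixed.
        with torch.enable_grad():
            theta_ = theta.detach().requires_grad_(True)
            b_of_theta = (theta_ * 2) @ xbar.detach()
            (grad_theta,) = torch.autograd.grad(b_of_theta, theta_, grad_outputs=-grad_b)
        return grad_theta, grad_b

torch.manual_seed(0)
n = 5
theta0 = 0.5 * torch.eye(n, dtype=torch.float64) + 0.05 * torch.randn(n, n, dtype=torch.float64)
b0 = torch.randn(n, dtype=torch.float64)

# Gradients from the implicit-differentiation wrapper.
theta = theta0.clone().requires_grad_(True)
b = b0.clone().requires_grad_(True)
ImplicitLinearSolve.apply(theta, b).mean().backward()

# Reference gradients from autograd through the direct solver.
theta_ref = theta0.clone().requires_grad_(True)
b_ref = b0.clone().requires_grad_(True)
torch.linalg.solve(theta_ref * 2, b_ref).mean().backward()

print(torch.allclose(theta.grad, theta_ref.grad), torch.allclose(b.grad, b_ref.grad))
```

The backward pass stores only \(\theta\) and \(\bar{x}\), so its memory footprint does not depend on how many iterations the forward solver would take.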