Image Deraining
Proximal algorithms are typically used for, and best suited to, inverse problems with known forward models. Here, we explore proximal algorithms for problems where the forward model is unknown and is learned alongside the optimization.
We consider image deraining, i.e., removing rain streaks from an image, as an example of this class of problems. The problem is severely ill-posed because the decomposition into a rain layer and a latent background layer is unknown a priori.
As such, researchers typically employ end-to-end black-box neural networks that learn to predict latent clean images from paired datasets. Because it can compile differentiable solvers, ∇-Prox supports learnable forward operators by making LinOp learnable for joint optimization.
[1]:
# uncomment the following line to install dprox if you are running in an online Google Colab notebook
# !pip install dprox
[1]:
import torch
from dprox import *
from dprox.utils import *
Here is an example of the original image and the rainy image, taken from Rain100L.
[2]:
b = hf.load_image('examples/derain/derain_input.png')
gt = hf.load_image('examples/derain/derain_target.png')
imshow(gt, b, titles=['original image', 'rainy observation'], off_axis=True)
To solve the problem, we can unroll the iterations of the proximal gradient descent algorithm and learn both the forward model, i.e., the degradation operator LearnableDegOp, and the unrolled image prior unrolled_prior.
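As a rough illustration of what unrolling means, the sketch below runs a fixed number of proximal gradient steps, each with its own step size. It is not ∇-Prox code: it uses plain NumPy, a fixed random matrix stands in for the learned degradation operator, and soft-thresholding (the proximal operator of an L1 penalty) stands in for the learned prior. The names `soft_threshold`, `objective`, and `unrolled_pgd` are hypothetical and chosen only for this example.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||v||_1 (assumed stand-in for a learned prior)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def objective(A, b, x, lam):
    """Data term plus regularizer: 0.5 * ||A x - b||^2 + lam * ||x||_1."""
    return 0.5 * np.sum((A @ x - b) ** 2) + lam * np.sum(np.abs(x))


def unrolled_pgd(A, b, step_sizes, lam):
    """Run one proximal gradient step per entry of step_sizes."""
    x = np.zeros(A.shape[1])
    for step in step_sizes:
        grad = A.T @ (A @ x - b)                         # gradient of the data term
        x = soft_threshold(x - step * grad, step * lam)  # proximal step on the regularizer
    return x


rng = np.random.default_rng(0)
A = rng.standard_normal((20, 10))   # fixed operator; a learned one would replace it
x_true = np.zeros(10)
x_true[[2, 7]] = [1.5, -2.0]
b = A @ x_true

lam = 0.1
lipschitz = np.linalg.norm(A, 2) ** 2   # largest singular value squared
step_sizes = [1.0 / lipschitz] * 7      # seven unrolled iterations
x_hat = unrolled_pgd(A, b, step_sizes, lam)

print("objective at start:", objective(A, b, np.zeros(10), lam))
print("objective after 7 steps:", objective(A, b, x_hat, lam))
```

In a learned variant, the per-iteration step sizes, the operator, and the prior would all be trainable parameters, and gradients would flow through the unrolled loop.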
[3]:
from dprox.contrib.derain import LearnableDegOp
# custom linop
A = LearnableDegOp().cuda()
def forward_fn(input, step): return A.forward(input, step)
def adjoint_fn(input, step): return A.adjoint(input, step)
raining = LinOpFactory(forward_fn, adjoint_fn)
# build solver
x = Variable()
data_term = sum_squares(raining(x), b)
reg_term = unrolled_prior(x)
obj = data_term + reg_term
solver = compile(obj, method='pgd')
# load parameters
ckpt = hf.load_checkpoint('image_deraining/derain_pdg.pth')
A.load_state_dict(ckpt['linop'])
reg_term.load_state_dict(ckpt['prior'])
rhos = ckpt['rhos']
with torch.no_grad():
    out = solver.solve(x0=b, rhos=rhos, max_iter=7)
out = to_ndarray(out, debatch=True) + b
imshow(gt, out, titles=['original image', f'derained result (PSNR: {psnr(out, gt):.3f})'], off_axis=True)
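When a linear operator is built from a separate forward and adjoint function, as with forward_fn and adjoint_fn above, it can be useful to check that the two are consistent. A common sanity check is the dot-product test, which verifies that ⟨A x, y⟩ equals ⟨x, Aᵀ y⟩ for random x and y. The sketch below shows the idea in plain NumPy with a 1-D convolution; the functions `conv_forward`, `conv_adjoint`, `wrong_adjoint`, and `dot_product_test` are hypothetical and are not part of dprox.

```python
import numpy as np

kernel = np.array([0.5, 0.3, 0.2])  # asymmetric kernel, chosen only for illustration


def conv_forward(x):
    """Forward operator: full 1-D convolution with the kernel."""
    return np.convolve(x, kernel, mode="full")


def conv_adjoint(y):
    """Adjoint of the full convolution: valid-mode correlation with the same kernel."""
    return np.correlate(y, kernel, mode="valid")


def wrong_adjoint(y):
    """Deliberately incorrect adjoint: convolves instead of correlating."""
    return np.convolve(y, kernel, mode="valid")


def dot_product_test(forward, adjoint, n, rng):
    """Return the relative mismatch between <A x, y> and <x, A^T y>."""
    x = rng.standard_normal(n)
    y = rng.standard_normal(forward(x).shape)
    lhs = np.vdot(forward(x), y)
    rhs = np.vdot(x, adjoint(y))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-12)


rng = np.random.default_rng(0)
err_good = dot_product_test(conv_forward, conv_adjoint, 16, rng)
err_bad = dot_product_test(conv_forward, wrong_adjoint, 16, rng)
print("matching adjoint, relative mismatch:", err_good)
print("mismatched adjoint, relative mismatch:", err_bad)
```

A mismatch near machine precision suggests the pair is consistent. For a learned operator, the same check could be rerun after training, since the adjoint must track the updated parameters.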