from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from dprox.utils.misc import batchify, to_ndarray, to_torch_tensor
from dprox.utils.psf2otf import psf2otf
from .base import LinOp
from .placeholder import Placeholder
class conv(LinOp):
    """Circular convolution of the input with a fixed kernel.

    The kernel is converted to a frequency response with ``psf2otf`` for each
    input shape encountered. Results are cached per ``(shape, device)`` so the
    host-to-device copy happens once instead of on every call.
    """

    def __init__(self, arg, kernel):
        self.kernel = to_ndarray(kernel)
        # Maps (tuple(shape), device-or-None) -> frequency response tensor.
        self.cache = {}
        super(conv, self).__init__([arg])

    def _FB(self, shape, device=None):
        """Return the kernel's frequency response for an input of ``shape``.

        Args:
            shape: Input shape, unpacked as ``(_, C, H, W)``.
            device: If given, the cached tensor is moved to (and cached on)
                this device; otherwise the tensor built from
                ``psf2otf``/``batchify`` is returned as-is.
        """
        key = (tuple(shape), device)
        if key not in self.cache:
            if device is None:
                _, C, H, W = shape
                FB = batchify(torch.from_numpy(psf2otf(self.kernel, [H, W, C])))
            else:
                FB = self._FB(shape).to(device)
            self.cache[key] = FB
        return self.cache[key]

    @staticmethod
    def _apply_transfer(x, transfer):
        """Multiply ``x`` by ``transfer`` in the 2-D Fourier domain (last two axes)."""
        Fx = torch.fft.fftn(x, dim=[-2, -1])
        # .float() casts the real part to float32 regardless of the input dtype.
        return torch.real(torch.fft.ifftn(transfer * Fx, dim=[-2, -1])).float()

    def forward(self, input, **kwargs):
        """Convolve ``input`` with the kernel (circular boundary)."""
        return self._apply_transfer(input, self._FB(input.shape, input.device))

    def adjoint(self, input, **kwargs):
        """Apply the adjoint: multiply by the conjugate frequency response."""
        FB = self._FB(input.shape, input.device)
        return self._apply_transfer(input, torch.conj(FB))

    def is_diag(self, freq=False):
        return freq and self.input_nodes[0].is_diag(freq)

    def get_diag(self, x, freq=False):
        """Return ``|FB|^2`` for an input shaped like ``x`` (frequency domain only)."""
        assert freq
        FB = self._FB(x.shape)
        self_diag = torch.abs(torch.conj(FB) * FB)
        return self_diag.to(self.device)

    def norm_bound(self, input_mags):
        """Upper bound on the output magnitude given the input magnitudes.

        Uses the kernel's L1 norm. Assuming ``psf2otf`` returns the DFT of the
        zero-padded, shifted kernel, every frequency satisfies
        ``|K(w)| <= sum |k|``. The bound therefore holds for any input shape
        without needing a cached transfer function.
        """
        return float(np.abs(self.kernel).sum()) * input_mags[0]
def _otf_pad_amounts(target, size):
    """Return ``(before, after)`` padding that grows an axis from ``size`` to ``target``.

    An odd difference puts the extra element before. An even non-zero
    difference is shifted by one towards "before" (``d/2 + 1, d/2 - 1``), so
    that the subsequent ``ifftshift`` moves the PSF centre to index 0.
    Negative amounts (PSF larger than target) make ``F.pad`` crop.
    """
    diff = target - size
    if diff == 0:
        return 0, 0
    if diff % 2:
        # -(-d // 2) == ceil(d / 2) and d // 2 == floor(d / 2) for ints of either sign.
        return -(-diff // 2), diff // 2
    return diff // 2 + 1, diff // 2 - 1


def psf2otf2(psf, output_size):
    """Convert a 4-D PSF into its optical transfer function (OTF).

    The PSF is zero-padded (or cropped, if larger) to the spatial size
    ``output_size[2:]``. It is then circularly shifted over the two spatial
    axes only, so its centre pixel lands at ``(0, 0)``, and the 2-D FFT is
    taken over those axes.

    Args:
        psf: Tensor of shape ``(B, C, fh, fw)``.
        output_size: Sequence whose entries ``[2]`` and ``[3]`` are the target
            height and width (e.g. the ``.shape`` of an NCHW tensor).

    Returns:
        Complex tensor of shape ``(B, C, output_size[2], output_size[3])``.
    """
    _, _, fh, fw = psf.shape
    # Height and width are padded independently so non-square PSFs/targets work.
    pad_top, pad_bottom = _otf_pad_amounts(output_size[2], fh)
    pad_left, pad_right = _otf_pad_amounts(output_size[3], fw)
    if any((pad_top, pad_bottom, pad_left, pad_right)):
        padded = F.pad(input=psf, pad=[pad_left, pad_right, pad_top, pad_bottom], mode="constant")
    else:
        padded = psf
    # Shift only the spatial axes; shifting all axes would permute batch/channel entries.
    padded = torch.fft.ifftshift(padded, dim=(-2, -1))
    return torch.fft.fft2(padded)
class conv_doe(LinOp):
    """Convolution of a 4-D input with a (possibly learnable) PSF.

    With ``circular=False`` (the default) the convolution is linearised. The
    input, whose last two axes are ``H x W``, is zero-padded onto a
    ``2H x 2H`` canvas, convolved circularly via the FFT, and cropped back to
    ``H x W``. With ``circular=True`` the FFT convolution is applied at the
    input's own size.

    Args:
        arg: Input linear operator.
        psf: Point spread function. A tensor/array is wrapped in an
            ``nn.Parameter`` (via ``to_torch_tensor(..., batch=True)``). A
            ``Placeholder`` re-creates that parameter each time its value
            changes.
        circular: Use circular instead of linearised convolution.
    """

    def __init__(
        self,
        arg: LinOp,
        psf: Union[Placeholder, torch.Tensor, np.ndarray],
        circular: bool = False,
    ):
        super().__init__([arg])
        self._psf = psf
        self.circular = circular
        if isinstance(psf, Placeholder):
            # Rebuild the learnable parameter whenever the placeholder value changes.
            # NOTE(review): self.psf only exists once on_change has fired; confirm
            # that Placeholder.change invokes the callback for an already-set value.
            def on_change(val):
                self.psf = nn.parameter.Parameter(val)

            self._psf.change(on_change)
        else:
            self.psf = nn.parameter.Parameter(to_torch_tensor(psf, batch=True))

    def _get_psf(self, device):
        """Return the current PSF tensor on ``device``.

        forward, adjoint and get_diag all fetch the PSF here, so that adjoint()
        always uses the same kernel as forward().
        """
        return self.unwrap(self.psf).to(device)

    @staticmethod
    def _linear_pads(height, width):
        """Return ``(top, bottom, left, right)`` zero padding onto a 2H x 2H canvas.

        The extra rows/columns are split as evenly as possible, with the odd
        one going first (top/left). If ``width > 2 * height`` the horizontal
        amounts are negative, and ``F.pad`` then crops the input instead.
        """
        side = 2 * height
        dh, dw = side - height, side - width
        # -(-d // 2) == ceil(d / 2) and d // 2 == floor(d / 2) for ints of either sign.
        return -(-dh // 2), dh // 2, -(-dw // 2), dw // 2

    def _fft_conv(self, img, conjugate):
        """Shared forward/adjoint body; ``conjugate=True`` applies the adjoint.

        The adjoint of crop(circconv(pad(x))) is crop(circcorr(pad(x))) with
        the same pad/crop window, i.e. the same pipeline with a conjugated OTF.
        """
        height, width = img.shape[2], img.shape[3]
        psf = self._get_psf(img.device)
        if not self.circular:
            top, bottom, left, right = self._linear_pads(height, width)
            img = F.pad(input=img, pad=[left, right, top, bottom], mode="constant")
        otf = psf2otf2(psf, img.shape)
        if conjugate:
            otf = torch.conj(otf)
        Fx = torch.fft.fftn(img, dim=[-2, -1])
        # .float() casts the real part to float32 regardless of the input dtype.
        output = torch.real(torch.fft.ifftn(otf * Fx, dim=[-2, -1])).float()
        if not self.circular:
            # Crop with explicit end indices: ``x[a:-b]`` is empty when b == 0
            # (e.g. H == 1, or W == 2H or 2H - 1).
            output = output[:, :, top:top + height, left:left + width]
        return output

    def forward(self, img, **kwargs):
        """Convolve ``img`` with the PSF."""
        return self._fft_conv(img, conjugate=False)

    def adjoint(self, img, **kwargs):
        """Apply the adjoint of :meth:`forward` (correlation with the PSF)."""
        return self._fft_conv(img, conjugate=True)

    def is_diag(self, freq=False):
        return freq and self.input_nodes[0].is_diag(freq)

    def get_diag(self, x, freq=False):
        """Return ``|OTF|^2`` at the shape of ``x`` (frequency domain only).

        NOTE(review): this ignores ``self.circular``. In the linearised mode,
        forward() works on a padded canvas, so this diagonal is only an
        approximation of that operator. Confirm this is intended.
        """
        assert freq
        otf = psf2otf2(self._get_psf(x.device), x.shape)
        self_diag = torch.abs(torch.conj(otf) * otf)
        return self_diag.to(self.device)

    def norm_bound(self, input_mags):
        """Upper bound on the output magnitude given the input magnitudes.

        For each (batch, channel) PSF, ``|OTF(w)| <= sum |psf|``. Padding and
        cropping have norm at most 1. The largest per-channel L1 norm therefore
        bounds the operator norm in both modes.
        """
        psf = self.unwrap(self.psf).detach()
        bound = psf.abs().sum(dim=(-2, -1)).max().item()
        return bound * input_mags[0]