Source code for dprox.proxfn.base

import numpy as np

from dprox.linop import LinOp, Placeholder, CompGraph
from dprox.utils import to_torch_tensor
import torch.nn as nn


def exists(x):
    return x is not None


def prox_scaled(prox, alpha):
    def _prox(v, lam):
        return prox(v, lam * alpha)
    return _prox


def prox_affine(prox, beta):
    def _prox(v, lam):
        return 1. / beta * prox(beta * v, beta * beta * lam)
    return _prox


def prox_translated(prox, b):
    def _prox(v, lam):
        return prox(v - b, lam) + b
    return _prox


[docs]class ProxFn(nn.Module): """ The abstract class for the proximal operator. f(x) = argmin_x f(x) + 1/(2*lam) * ||x-v||_2^2 """ def __init__(self, linop: LinOp, alpha=1, beta=1): super().__init__() self.linop = linop self.alpha = alpha self.beta = beta self.step = 0 self.dag = CompGraph(linop, zero_out_constant=True) @property def offset(self): return -self.linop.offset def unwrap(self, value): if isinstance(value, Placeholder): return value.value return to_torch_tensor(value, batch=True).to(self.linop.device) def eval(self, v): return NotImplementedError def prox(self, v, lam): """ v: [B,C,H,W], lam: [B] """ if len(lam.shape) == 1: lam = lam.view(lam.shape[0], 1, 1, 1) fn = self._prox fn = prox_scaled(fn, self.alpha) fn = prox_affine(fn, self.beta) fn = prox_translated(fn, self.offset) return fn(v, lam) def convex_conjugate_prox(self, v, lam): # use Moreau’s identity return v - self.prox(v / lam, lam) def _prox(self, v, lam): return NotImplementedError # def grad(self, x): # x_ = x.detach().requires_grad_(True) # self.eval(x_).backward() # return x_.grad def __mul__(self, other): if np.isscalar(other) and other > 0: self.alpha = other return self return TypeError("Can only multiply by a positive scalar.") def __rmul__(self, other): """Called for Number * ProxFn. """ return self * other def __add__(self, other): """ProxFn + ProxFn(s). """ if isinstance(other, ProxFn): return [self, other] elif type(other) == list: return [self] + other else: return NotImplemented def __radd__(self, other): """Called for list + ProxFn. """ if type(other) == list: return other + [self] else: return NotImplemented def __str__(self): return f'{self.__class__.__name__}'