Source code for dprox.linop.base

import abc
import copy

import numpy as np
import torch
import torch.nn as nn

from dprox.utils import to_torch_tensor


def cast_to_const(expr):
    """Converts a non-LinOp to a Constant.
    """
    from .constant import Constant
    return expr if isinstance(expr, LinOp) else Constant(expr)


[docs]class LinOp(nn.Module): """ Abstract class for all linear operator. """ class MultOutput(list): pass instanceCnt = 0 def __init__(self, input_nodes=[]): super(LinOp, self).__init__() self.input_nodes = nn.ModuleList([cast_to_const(node) for node in input_nodes]) # count id self.linop_id = LinOp.instanceCnt LinOp.instanceCnt += 1 # create a dummy parameter to automatically infer device self.dummy = torch.nn.parameter.Parameter(torch.tensor(0), requires_grad=False) # will be set later by proximal algorithm to indicate current iteration step self.step = 0 # ---------------------------------------------------------------------------- # # Computation # # ---------------------------------------------------------------------------- #
[docs] @abc.abstractmethod def forward(self, inputs): """The forward operator. Compute x -> Kx """ return NotImplemented
@abc.abstractmethod def adjoint(self, inputs): """The adjoint operator. Compute x -> K^Tx """ return NotImplemented # ---------------------------------------------------------------------------- # # Diagonal # # ---------------------------------------------------------------------------- #
[docs] def is_gram_diag(self, freq=False): """Is the lin op's Gram matrix K^TK diagonal (in the frequency domain)? """ return self.is_diag(freq)
[docs] def is_diag(self, freq=False): """Is the lin op K diagonal (in the frequency domain)? """ return False
[docs] def get_diag(self, freq=False): """Returns the diagonal representation (K^TK)^(1/2). Parameters ---------- freq : bool Is the diagonal representation in the frequency domain? Returns ------- dict of variable to ndarray The diagonal operator acting on each variable. """ return NotImplemented
# ---------------------------------------------------------------------------- # # Property # # ---------------------------------------------------------------------------- # @property def device(self): return self.dummy.device @property def variables(self): """Return the list of variables used in the LinOp. """ vars_ = [] for arg in self.input_nodes: vars_ += arg.variables unordered = list(set(vars_)) # Make unique, order by uuid. return sorted(unordered, key=lambda x: x.uuid) @property def constants(self): """Returns a list of constants in the LinOp. """ consts = [] for arg in self.input_nodes: consts += arg.constants return consts
[docs] def is_constant(self): """Is the LinOp constant? """ return len(self.variables()) == 0
@property def value(self): inputs = [] for node in self.input_nodes: inputs.append(node.value) output = self.forward(*inputs) return output @property def offset(self): """Get the constant offset. """ old_vals = {} for var in self.variables: old_vals[var] = var.value var.value = torch.zeros_like(var.value) offset = self.value # Restore old variable values. for var in self.variables: var.value = old_vals[var] return offset
[docs] def norm_bound(self, input_mags): """Gives an upper bound on the magnitudes of the outputs given inputs. Parameters ---------- input_mags : list List of magnitudes of inputs. Returns ------- float Magnitude of outputs. """ return NotImplemented
# ---------------------------------------------------------------------------- # # Util # # ---------------------------------------------------------------------------- # @property def T(self) -> 'LinOp': """ The transpose :math:`A^T` of this linear operator :math:`A`. """ op = self.clone() op.forward, op.adjoint = op.adjoint, op.forward return op @property def gram(self) -> 'LinOp': """ The gram :math:`A^TA` of this linear operator :math:`A`$`. """ op = self.clone() forward, adjoint = op.forward, op.adjoint op.forward = lambda inputs: adjoint(forward(inputs)) op.adjoint = lambda inputs: forward(adjoint(inputs)) return op
[docs] def clone(self) -> 'LinOp': """ The deep copy of this linear operator. """ return copy.deepcopy(self)
def unwrap(self, value): from .placeholder import Placeholder if isinstance(value, Placeholder): return value.value return to_torch_tensor(value, batch=True) # ---------------------------------------------------------------------------- # # Python Magic # # ---------------------------------------------------------------------------- # def __add__(self, other): """Lin Op + Lin Op. """ other = cast_to_const(other) from .sum import sum args = [] for elem in [self, other]: if isinstance(elem, sum): args += elem.input_nodes else: args += [elem] return sum(args) def __mul__(self, other): """Lin Op * Number. """ from .scale import scale # Can only divide by scalar constants. if np.isscalar(other): return scale(other, self) else: raise TypeError("Can only multiply by a scalar constant.") def __rmul__(self, other): """Called for Number * Lin Op. """ return self * other def __truediv__(self, other): """Lin Op / Number. """ return self.__div__(other) def __div__(self, other): """Lin Op / Number. """ from .scale import scale # Can only divide by scalar constants. if np.isscalar(other): return scale(1. / other, self) else: raise TypeError("Can only divide by a scalar constant.") def __sub__(self, other): """Called for lin op - other. """ return self + -other def __rsub__(self, other): """Called for other - lin_op. """ return -self + other def __neg__(self): """The negation of the Lin Op. """ return -1 * self def __rmatmul__(self, other): # other @ self from .constaints import matmul from .variable import Variable if not isinstance(self, Variable): print('only support variable') return matmul(self, other) def __str__(self): """Default to string is name of class. """ return self.__class__.__name__ __array_priority__ = 10000