Source code for dprox.linop.variable

import uuid

import torch

from .base import LinOp


[docs]class Variable(LinOp): """A variable. """ def __init__(self, shape=None, value=None, name=None): super(Variable, self).__init__([]) self.uuid = uuid.uuid1() self._value = value self.shape = shape self.varname = name self.initval = None # ---------------------------------------------------------------------------- # # Computation # # ---------------------------------------------------------------------------- #
[docs] def forward(self, inputs, **kwargs): """The forward operator. Reads from inputs and writes to outputs. """ return inputs
[docs] def adjoint(self, inputs, **kwargs): """The adjoint operator. Reads from inputs and writes to outputs. """ return inputs
# ---------------------------------------------------------------------------- # # Diagonal # # ---------------------------------------------------------------------------- # def is_diag(self, freq=False): """Is the lin op diagonal (in the frequency domain)? """ return True def get_diag(self, ref, freq=False): """Returns the diagonal representation (A^TA)^(1/2). Parameters ---------- freq : bool Is the diagonal representation in the frequency domain? Returns ------- dict of variable to ndarray The diagonal operator acting on each variable. """ return torch.ones(ref.shape) # ---------------------------------------------------------------------------- # # Property # # ---------------------------------------------------------------------------- # @property def variables(self): return [self] @property def value(self): return self._value.to(self.device) @value.setter def value(self, val): """Assign a value to the variable. """ self._value = val def norm_bound(self, input_mags): """Gives an upper bound on the magnitudes of the outputs given inputs. Parameters ---------- input_mags : list List of magnitudes of inputs. Returns ------- float Magnitude of outputs. """ return 1.0 # ---------------------------------------------------------------------------- # # Python Magic # # ---------------------------------------------------------------------------- # def __repr__(self): return f'Variable(id={self.uuid}, shape={self.shape}, value={"None" if self._value is None else "somevalue"})'