Source code for dprox.linop.scale
import numpy as np
import torch
from .base import LinOp
[docs]class scale(LinOp):
"""Multiplication scale*X with a fixed scalar.
"""
def __init__(self, scalar, arg):
assert np.isscalar(scalar)
self.scalar = scalar
super(scale, self).__init__([arg])
# ---------------------------------------------------------------------------- #
# Computation #
# ---------------------------------------------------------------------------- #
[docs] def forward(self, input, **kwargs):
"""The forward operator.
Reads from inputs and writes to outputs.
"""
return input * self.scalar
[docs] def adjoint(self, input, **kwargs):
"""The adjoint operator.
Reads from inputs and writes to outputs.
"""
return self.forward(input)
# ---------------------------------------------------------------------------- #
# Diagonal #
# ---------------------------------------------------------------------------- #
def is_gram_diag(self, freq=False):
"""Is the lin Gram diagonal (in the frequency domain)?
"""
return self.input_nodes[0].is_gram_diag(freq)
def is_diag(self, freq=False):
"""Is the lin op diagonal (in the frequency domain)?
"""
return self.input_nodes[0].is_diag(freq)
def get_diag(self, ref, freq=False):
"""Returns the diagonal representation (A^TA)^(1/2).
Parameters
----------
freq : bool
Is the diagonal representation in the frequency domain?
Returns
-------
dict of variable to ndarray
The diagonal operator acting on each variable.
"""
var_diags = self.input_nodes[0].get_diag(ref, freq) * self.scalar
return var_diags * torch.conj(var_diags)
# ---------------------------------------------------------------------------- #
# Property #
# ---------------------------------------------------------------------------- #
def norm_bound(self, input_mags):
"""Gives an upper bound on the magnitudes of the outputs given inputs.
Parameters
----------
input_mags : list
List of magnitudes of inputs.
Returns
-------
float
Magnitude of outputs.
"""
return abs(self.scalar) * input_mags[0]