import numpy as np
from .base import LinOp
[docs]class vstack(LinOp):
"""Vectorizes and stacks inputs.
"""
def __init__(self, input_nodes):
super(vstack, self).__init__(input_nodes)
# ---------------------------------------------------------------------------- #
# Computation #
# ---------------------------------------------------------------------------- #
[docs] def forward(self, *inputs, **kwargs):
"""The forward operator.
Reads from inputs and writes to outputs.
"""
if len(inputs) > 1:
return LinOp.MultOutput(inputs)
return inputs[0]
[docs] def adjoint(self, *inputs, **kwargs):
"""The adjoint operator.
Reads from inputs and writes to outputs.
"""
if len(inputs) > 1:
return LinOp.MultOutput(inputs)
return inputs[0]
# ---------------------------------------------------------------------------- #
# Diagonal #
# ---------------------------------------------------------------------------- #
def is_gram_diag(self, freq=False):
"""Is the lin op's Gram matrix diagonal (in the frequency domain)?
"""
return all([arg.is_gram_diag(freq) for arg in self.input_nodes])
def get_diag(self, freq=False):
"""Returns the diagonal representation (A^TA)^(1/2).
Parameters
----------
freq : bool
Is the diagonal representation in the frequency domain?
Returns
-------
dict of variable to ndarray
The diagonal operator acting on each variable.
"""
var_diags = {var: np.zeros(var.size) for var in self.variables()}
for arg in self.input_nodes:
arg_diags = arg.get_diag(freq)
for var, diag in arg_diags.items():
var_diags[var] = var_diags[var] + diag * np.conj(diag)
# Get (A^TA)^{1/2}
for var in self.variables():
var_diags[var] = np.sqrt(var_diags[var])
return var_diags
# ---------------------------------------------------------------------------- #
# Property #
# ---------------------------------------------------------------------------- #
def norm_bound(self, input_mags):
"""Gives an upper bound on the magnitudes of the outputs given inputs.
Parameters
----------
input_mags : list
List of magnitudes of inputs.
Returns
-------
float
Magnitude of outputs.
"""
return np.linalg.norm(input_mags, 2)
[docs]class split(vstack):
def __init__(self, output_nodes):
self.output_nodes = output_nodes
self.input_nodes = []
super(split, self).__init__(output_nodes)
[docs] def forward(self, *inputs, **kwargs):
"""The forward operator.
Reads from inputs and writes to outputs.
"""
return super(split, self).adjoint(*inputs, **kwargs)
[docs] def adjoint(self, *inputs, **kwargs):
"""The adjoint operator.
Reads from inputs and writes to outputs.
"""
return super(split, self).forward(*inputs, **kwargs)
def norm_bound(self, input_mags):
"""Gives an upper bound on the magnitudes of the outputs given inputs.
Parameters
----------
input_mags : list
List of magnitudes of inputs.
Returns
-------
float
Magnitude of outputs.
"""
return input_mags[0]