Source code for dprox.linop.sum

import torch

from .base import LinOp


class sum(LinOp):
    """Linear operator that sums its inputs.

    NOTE: inside this module the class name shadows the builtin ``sum``.
    """

    def __init__(self, input_nodes):
        """Create the operator over ``input_nodes``.

        Parameters
        ----------
        input_nodes : list
            Child nodes whose outputs are summed; passed unchanged to
            ``LinOp.__init__``.
        """
        super().__init__(input_nodes)

    def forward(self, *inputs, **kwargs):
        """Return the element-wise sum of all ``inputs``.

        The accumulator is created with ``torch.zeros_like(inputs[0])``, so
        the result takes the shape, dtype and device of the first input.
        Every other input is moved to that device before being added in
        place to the accumulator; the inputs themselves are not modified.
        All inputs should have the same shape.

        Parameters
        ----------
        *inputs : torch.Tensor
            Tensors to add together. At least one is required.
        **kwargs
            Accepted for interface compatibility; unused here.

        Returns
        -------
        torch.Tensor
            The accumulated sum.

        Raises
        ------
        IndexError
            If called without any input.
        """
        output = torch.zeros_like(inputs[0])
        for term in inputs:
            output += term.to(output.device)
        return output

    def adjoint(self, input, **kwargs):
        """Spread ``input`` to every child node (adjoint of summation).

        Parameters
        ----------
        input : torch.Tensor
            Value to hand to each input node. The same object is reused for
            every child; no copy is made.
        **kwargs
            Accepted for interface compatibility; unused here.

        Returns
        -------
        LinOp.MultOutput or torch.Tensor
            A ``LinOp.MultOutput`` with one entry per input node when there
            is more than one node, otherwise the single entry itself.
        """
        outputs = LinOp.MultOutput()
        for _ in self.input_nodes:
            outputs.append(input)
        if len(outputs) > 1:
            return outputs
        return outputs[0]

    def is_diag(self, freq=False):
        """Is the lin op diagonal (in the frequency domain)?

        True only when every input node reports ``is_diag(freq)``.
        """
        return all(arg.is_diag(freq) for arg in self.input_nodes)

    def is_gram_diag(self, freq=False):
        """Is the Gram operator of the lin op diagonal (in the frequency domain)?

        True only when every input node reports ``is_gram_diag(freq)``.
        """
        return all(arg.is_gram_diag(freq) for arg in self.input_nodes)

    def get_diag(self, ref, freq=False):
        """Returns the diagonal representation (A^TA)^(1/2).

        Parameters
        ----------
        ref
            Forwarded unchanged to the first input node's ``get_diag``.
        freq : bool
            Is the diagonal representation in the frequency domain?

        Returns
        -------
        Whatever ``self.input_nodes[0].get_diag(ref, freq)`` returns.
        """
        # NOTE(review): only the first input node is consulted; the diagonals
        # of the remaining input nodes are not accumulated. Confirm that this
        # is intended for sums over more than one node.
        return self.input_nodes[0].get_diag(ref, freq)

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list or torch.Tensor
            Magnitudes of the inputs.

        Returns
        -------
        float or torch.Tensor
            Sum of the input magnitudes. A tensor argument is reduced with
            ``torch.sum`` and yields a tensor.
        """
        if torch.is_tensor(input_mags):
            return torch.sum(input_mags)
        # ``torch.sum`` rejects plain Python sequences, and the builtin
        # ``sum`` is shadowed by this class's name, so accumulate explicitly.
        total = 0
        for mag in input_mags:
            total = total + mag
        return total
class copy(sum):
    """Operator whose forward/adjoint are the parent's adjoint/forward.

    ``forward`` delegates to the parent class's ``adjoint`` and ``adjoint``
    delegates to the parent class's ``forward``.
    """

    def __init__(self, arg):
        """Build the operator over the single node ``arg``."""
        # The parent constructor expects a list of input nodes.
        nodes = [arg]
        super().__init__(nodes)

    def forward(self, inputs, **kwargs):
        """The forward operator.

        Passes ``inputs`` and ``kwargs`` to the parent's ``adjoint`` and
        returns its result.
        """
        spread = super().adjoint(inputs, **kwargs)
        return spread

    def adjoint(self, *inputs, **kwargs):
        """The adjoint operator.

        Passes all positional ``inputs`` and ``kwargs`` to the parent's
        ``forward`` and returns its result.
        """
        combined = super().forward(*inputs, **kwargs)
        return combined

    def norm_bound(self, input_mags):
        """Gives an upper bound on the magnitudes of the outputs given inputs.

        Parameters
        ----------
        input_mags : list
            List of magnitudes of inputs.

        Returns
        -------
        float
            Magnitude of outputs, taken as the first entry of ``input_mags``.
        """
        first_mag = input_mags[0]
        return first_mag