# Source code for dprox.linalg.solve.solver_plss

"""
PLSS: A Projected Linear Systems Solver.
Solves linear systems Ax = b for a matrix A of arbitrary shape (m = n, m > n, or m < n).
PLSS is suited to well-conditioned systems and PLSSW (its weighted variant) to ill-conditioned ones.
See https://epubs.siam.org/doi/10.1137/22M1509783
"""
from typing import Callable

import torch


def plss(
    A: Callable,
    b: torch.Tensor,
    x0: torch.Tensor = None,
    rtol: float = 1e-6,
    max_iters: int = 100,
    verbose: bool = False,
):
    """Projected Linear Systems Solver (PLSS), for well-conditioned systems.

    Iteratively solves ``A(x) = b`` for a matrix-free operator ``A``. Every
    step is a linear combination of outputs of ``A.adjoint``, so the iterates
    stay in ``x0 + range(A^T)``.

    Args:
        A (Callable): Forward operator ``x -> A(x)``. It must also expose an
            ``adjoint`` method computing ``y -> A^T y``.
        b (torch.Tensor): Right-hand side of the system ``Ax = b``.
        x0 (torch.Tensor, optional): Initial guess. Defaults to zeros shaped
            like ``A.adjoint(b)``, i.e. like ``x``. This differs from the
            shape of ``b`` when ``A`` is not square.
        rtol (float): Stop once ``||A(x) - b|| <= rtol * ||b||``.
            Defaults to 1e-6.
        max_iters (int): Maximum number of updates of ``x``. Defaults to 100.
        verbose (bool): If True, print an iteration counter on exit.
            Defaults to False.

    Returns:
        torch.Tensor: The evaluated iterate with the smallest residual norm.
    """
    if x0 is None:
        # x lives in the domain of A. A^T b has that shape (and b's dtype and
        # device), whereas zeros_like(b) is wrong whenever A is not square.
        x0 = torch.zeros_like(A.adjoint(b))

    tol = rtol * torch.linalg.vector_norm(b)

    xk = x0
    rk = A(xk) - b
    rk_norm = torch.linalg.vector_norm(rk)
    # The residual norm is not guaranteed to decrease, so keep the best iterate.
    x_min, rk_norm_min = xk, rk_norm

    k = 0
    if rk_norm <= tol:
        # x0 already meets the tolerance. Returning here also avoids the 0/0
        # in ``rk / rk_norm`` when the residual is exactly zero (e.g. b = 0).
        if verbose:
            print(k + 1)
        return xk

    # First step: scaled step along y = A^T r / ||r||.
    yk = A.adjoint(rk / rk_norm)
    pk = -(rk_norm / torch.sum(yk * yk)) * yk
    xk = xk + pk

    for k in range(1, max_iters):
        rk = A(xk) - b
        rk_norm = torch.linalg.vector_norm(rk)
        if not torch.isfinite(rk_norm):
            # Numerical breakdown (e.g. a vanishing denominator below):
            # stop and fall back to the best finite iterate.
            break
        if rk_norm_min >= rk_norm:
            x_min, rk_norm_min = xk, rk_norm
        if rk_norm <= tol:
            break

        yk = A.adjoint(rk / rk_norm)
        p2 = torch.sum(pk * pk)
        py = torch.sum(pk * yk)
        yy = torch.sum(yk * yk)
        nrp_ny = torch.sqrt(p2) * torch.sqrt(yy)
        # denom = ||p||^2 ||y||^2 - (p.y)^2 in factored form; by Cauchy-Schwarz
        # it is positive unless p and y are parallel.
        denom = (nrp_ny - py) * (nrp_ny + py)
        beta1 = (rk_norm * py) / denom
        beta2 = -(rk_norm * p2) / denom
        # The new step combines the previous step with the current y.
        pk = beta1 * pk + beta2 * yk
        xk = xk + pk

    rk_norm = torch.linalg.vector_norm(A(xk) - b)
    # Written as ``not (a <= b)`` so that a NaN residual also selects x_min.
    if not (rk_norm <= rk_norm_min):
        xk = x_min
    if verbose:
        print(k + 1)
    return xk
def plssw(
    A: Callable,
    b: torch.Tensor,
    Wh: torch.Tensor,
    x0: torch.Tensor = None,
    rtol: float = 1e-6,
    max_iters: int = 100,
    verbose: bool = False,
):
    """Weighted Projected Linear Systems Solver (PLSSW), for ill-conditioned systems.

    Same iteration as PLSS, but step lengths are measured in the weighted
    norm ``||Wh * p||``. Every step is a combination of ``Whi**2 * A^T(...)``
    terms, where ``Whi`` is the element-wise inverse of ``Wh``.

    Args:
        A (Callable): Forward operator ``x -> A(x)``. It must also expose an
            ``adjoint`` method computing ``y -> A^T y``.
        b (torch.Tensor): Right-hand side of the system ``Ax = b``.
        Wh (torch.Tensor): Element-wise weights, multiplied with x-shaped
            tensors. Zero entries get a zero inverse weight, so the
            corresponding components of ``x`` are never updated.
        x0 (torch.Tensor, optional): Initial guess. Defaults to zeros shaped
            like ``A.adjoint(b)``, i.e. like ``x``.
        rtol (float): Stop once ``||A(x) - b|| <= rtol * ||b||``.
            Defaults to 1e-6.
        max_iters (int): Maximum number of updates of ``x``. This is the same
            convention as ``plss``. Defaults to 100.
        verbose (bool): If True, print an iteration counter on exit.
            Defaults to False.

    Returns:
        torch.Tensor: The evaluated iterate with the smallest residual norm.
    """
    # Element-wise inverse weights; zero weights give 1/0 = inf, which is mapped to 0.
    Whi = 1 / Wh
    Whi[torch.isinf(Whi)] = 0

    if x0 is None:
        # x lives in the domain of A, whose shape A^T b reproduces.
        x0 = torch.zeros_like(A.adjoint(b))

    tol = rtol * torch.linalg.vector_norm(b)

    xk = x0
    ck = A(xk) - b
    nck = torch.linalg.vector_norm(ck)
    # The residual norm is not guaranteed to decrease, so keep the best iterate.
    xkmin, nckmin = xk, nck

    k = 0
    if nck <= tol:
        # x0 already meets the tolerance. Returning here also avoids the 0/0
        # in ``ck / nck`` when the residual is exactly zero (e.g. b = 0).
        if verbose:
            print(k + 1)
        return xk

    # First step: weighted step along Whi**2 * y, with y = A^T r / ||r||.
    yk = A.adjoint(ck / nck)
    zk = Whi * yk
    pk = -(nck / torch.sum(zk * zk)) * (Whi * zk)
    xk = xk + pk

    # range(1, max_iters): together with the first step above, this gives at
    # most max_iters updates, matching plss.
    for k in range(1, max_iters):
        ck = A(xk) - b
        nck = torch.linalg.vector_norm(ck)
        if not torch.isfinite(nck):
            # Numerical breakdown: stop and fall back to the best finite iterate.
            break
        if nckmin >= nck:
            xkmin, nckmin = xk, nck
        if nck <= tol:
            break

        yk = A.adjoint(ck / nck)
        zk = Whi * yk
        # Weighted analogues of the PLSS inner products.
        Wp = Wh * pk
        p2 = torch.sum(Wp * Wp)
        py = torch.sum(pk * yk)
        yy = torch.sum(zk * zk)
        nrp_ny = torch.sqrt(p2) * torch.sqrt(yy)
        # denom = p2 * yy - py^2 in factored form.
        denom = (nrp_ny - py) * (nrp_ny + py)
        beta1 = (nck * py) / denom
        beta2 = -(nck * p2) / denom
        pk = beta1 * pk + beta2 * (Whi * zk)
        xk = xk + pk

    nck = torch.linalg.vector_norm(A(xk) - b)
    # Written as ``not (a <= b)`` so that a NaN residual also selects xkmin.
    if not (nck <= nckmin):
        xk = xkmin
    if verbose:
        print(k + 1)
    return xk