# Source code for dprox.linalg.solve.solver_cg

from typing import Callable

import torch
import numpy as np


def bdot(x: torch.Tensor, y: torch.Tensor):
    """
    Compute a per-sample dot product of two tensors.

    The leading dimension is treated as the batch dimension; all remaining
    dimensions are flattened before multiplying and summing. One-dimensional
    inputs are treated as plain vectors and reduced with `torch.dot`.

    Args:
      x (torch.Tensor): First operand, laid out as [batch_size, ...].
      y (torch.Tensor): Second operand, expected to match `x` in shape.

    Returns:
      A tensor of shape [batch_size] holding one dot product per sample
      (a 0-dim tensor when the inputs are one-dimensional).

    Raises:
      ValueError: If `x` and `y` do not have the same number of dimensions.
    """
    rank = x.ndim
    # Only the number of dimensions is validated here, not the exact sizes.
    if rank != y.ndim:
        raise ValueError('The input of `bdot` should have the same shape.')

    if rank == 1:
        return x.dot(y)

    x_rows = x.reshape(x.shape[0], -1)
    y_rows = y.reshape(y.shape[0], -1)
    return (x_rows * y_rows).sum(dim=-1)


def expand(x: torch.Tensor, ref: torch.Tensor):
    """
    Pad `x` with trailing singleton dimensions up to the rank of `ref`.

    Args:
      x (torch.Tensor): The tensor to pad; may have any shape.
      ref (torch.Tensor): The tensor whose number of dimensions is the target.

    Returns:
      `x` with size-1 dimensions appended until it has as many dimensions as
      `ref`. If `x` already has at least as many dimensions, it is returned
      unchanged.
    """
    missing = ref.dim() - x.dim()
    # range() of a non-positive count is empty, so `x` passes through as-is.
    for _ in range(missing):
        x = x.unsqueeze(-1)
    return x


def ravel(x: torch.Tensor):
    """
    Flatten every dimension except the first, which is treated as the batch dimension.

    Args:
      x (torch.Tensor): The input tensor of any shape.

    Returns:
      A tensor of shape [batch_size, -1]. A one-dimensional input is returned as is.
    """
    is_vector = x.ndim == 1
    return x if is_vector else x.reshape(x.shape[0], -1)


def _safe_div(num: torch.Tensor, den: torch.Tensor):
    """Element-wise `num / den` that yields 0 wherever `den` is exactly 0."""
    nonzero = den != 0
    # Substitute 1 for zero denominators so no NaN/Inf is ever produced.
    safe_den = torch.where(nonzero, den, torch.ones_like(den))
    return torch.where(nonzero, num / safe_den, torch.zeros_like(num))


def cg(
    A: Callable,
    b: torch.Tensor,
    x0: torch.Tensor = None,
    rtol: float = 1e-6,
    max_iters: int = 100,
    verbose: bool = False
):
    """
    Conjugate gradient method for solving a linear system of equations.

    Inputs with more than one dimension are treated as a batch along the
    first dimension: dot products, step sizes and residual norms are computed
    per sample. The iteration stops once every sample satisfies
    ``||r||_2 <= rtol * ||b||_2``.

    Args:
      A (Callable): A callable representing the forward operator A(x) of a matrix free linear operator.
        The conjugate gradient method assumes the operator is symmetric positive definite.
      b (torch.Tensor): The right-hand side of the linear system of equations `Ax = b`.
      x0 (torch.Tensor): The initial guess for the solution. If not provided, it is initialized to zeros.
        It is never modified in place.
      rtol (float): Relative tolerance for the convergence criterion. Defaults to 1e-6.
      max_iters (int): The maximum number of iterations. Defaults to 100. The iteration count is
        additionally capped at the number of elements of `b`.
      verbose (bool): Whether to print intermediate information. Defaults to False.

    Returns:
      The solution `x` to the linear system `Ax=b` using the conjugate gradient method.
    """
    x = torch.zeros_like(b) if x0 is None else x0

    # Build the residual out of place: `A` may return its input (or a view of
    # it), and in-place arithmetic on that result would overwrite `x0`.
    r = b - A(x)

    # Per-sample relative tolerance.
    cg_tol = rtol * torch.linalg.norm(ravel(b), 2, dim=-1)

    gamma_prev = p = None
    cg_iter = min(max_iters, b.numel())
    for it in range(cg_iter):
        # `dim=-1` is required: without it a 2-D input is reduced with the
        # matrix spectral norm rather than one vector norm per sample.
        normr = torch.linalg.norm(ravel(r), 2, dim=-1)
        if torch.all(normr <= cg_tol):
            if verbose:
                print("Converged at CG Iter %03d" % it)
            break

        gamma = expand(bdot(r, r), x)

        # Search direction. Safe division keeps samples whose residual is
        # already exactly zero from turning into NaN (0 / 0) while the rest
        # of the batch keeps iterating.
        if it > 0:
            beta = _safe_div(gamma, gamma_prev)
            p = r + beta * p
        else:
            p = r

        q = A(p)
        alpha = _safe_div(gamma, expand(bdot(p, q), x))

        x = x + alpha * p   # update approximation vector
        r = r - alpha * q   # update residual

        gamma_prev = gamma
        if verbose:
            print(f'Not converged, r norm={normr.tolist()}')
    return x
def cg2(
    A,
    b,
    x0=None,
    rtol=1e-6,
    max_iters=500,
    verbose=False
):
    """
    Unbatched conjugate gradient solver for `A x = b`.

    The whole tensor is treated as a single vector (all dimensions are
    flattened for the inner products). Iteration stops once the squared
    2-norm of the residual is at or below `rtol`, or after `max_iters`
    iterations.

    Args:
      A (Callable): Callable applying the linear operator to a tensor shaped like `b`.
        The conjugate gradient method assumes it is symmetric positive definite.
      b (torch.Tensor): Right-hand side of the system.
      x0 (torch.Tensor): Optional initial guess. Defaults to a tensor of ones shaped like `b`.
      rtol (float): Threshold compared against the squared residual norm
        (an absolute threshold, despite the name). Defaults to 1e-6.
      max_iters (int): Maximum number of iterations. Defaults to 500.
      verbose (bool): Whether to print progress information. Defaults to False.

    Returns:
      The approximate solution `x`.
    """
    if x0 is not None:
        if verbose:
            print('use x init')
        x = x0
    else:
        x = torch.ones_like(b)

    r = b - A(x)
    d = r
    rnorm = r.ravel() @ r.ravel()

    # An initial guess that already meets the tolerance must be returned
    # directly; iterating on a zero residual would compute 0 / 0 and fill
    # the result with NaN.
    if rnorm <= rtol:
        return x

    for it in range(max_iters):
        Ad = A(d)
        alpha = rnorm / (d.ravel() @ Ad.ravel())
        x = x + alpha * d
        r = r - alpha * Ad

        rnorm_new = r.ravel() @ r.ravel()
        if rnorm_new <= rtol:
            if verbose:
                print(f'converge at iter={it}, rtol={rtol}')
            break

        beta = rnorm_new / rnorm
        d = r + beta * d
        rnorm = rnorm_new

    return x
def pcg(
    A: Callable,
    b: torch.Tensor,
    x0: torch.Tensor = None,
    rtol: float = 1e-6,
    max_iters: int = 100,
    verbose: bool = False,
    Minv: Callable = None,
):
    """
    Preconditioned conjugate gradient method for solving a linear system of equations.

    The same idea as :func:`cg`, except it can be preconditioned via `Minv`
    and it is unbatched: the whole tensor is treated as a single vector.

    Args:
      A (Callable): A callable representing the forward operator A(x) of a matrix free linear operator.
        The conjugate gradient method assumes the operator is symmetric positive definite.
      b (torch.Tensor): The right-hand side of the linear system of equations `Ax = b`.
      x0 (torch.Tensor): The initial guess for the solution. If not provided, it is initialized to
        a tensor of ones.
      rtol (float): Convergence threshold. Iteration stops once the infinity norm of the residual
        is at or below this value (an absolute threshold, despite the name). Defaults to 1e-6.
      max_iters (int): The maximum number of iterations. Defaults to 100.
      verbose (bool): Whether to print a summary line when the solver finishes. Defaults to False.
      Minv (Callable): A callable applying the preconditioner to a residual. Defaults to the identity.

    Returns:
      The solution `x` to the linear system `Ax=b` using the conjugate gradient method.
    """
    norm_ord = float('inf')

    if Minv is None:
        def Minv(x):
            return x

    x = x0 if x0 is not None else torch.ones_like(b)

    r = A(x) - b
    rnorm = torch.linalg.vector_norm(r.ravel(), ord=norm_ord)

    num_iters = 0
    # Skip the iteration when the initial guess already meets the tolerance:
    # a zero residual would otherwise give alpha = 0 / 0 and a NaN result.
    if not rnorm <= rtol:
        y = Minv(r)
        p = -y
        for step in range(max_iters):
            Ap = A(p)
            ry = r.ravel() @ y.ravel()
            alpha = ry / (p.ravel() @ Ap.ravel())

            x = x + alpha * p
            r = r + alpha * Ap
            y = Minv(r)

            beta = (r.ravel() @ y.ravel()) / ry
            p = -y + beta * p

            num_iters = step + 1
            rnorm = torch.linalg.vector_norm(r.ravel(), ord=norm_ord)
            if rnorm <= rtol:
                break

    if verbose:
        # bnorm is reported for reference only; it is not part of the stop test.
        bnorm = torch.linalg.vector_norm(b.ravel(), ord=norm_ord)
        print(f'#IT: {num_iters}; bnorm: {bnorm:.3e}; rnorm: {rnorm:.3e}; rtol: {rtol:.3e}')

    return x