Linear System Solver

Unified Interface

dprox.linalg.linear_solve(A: Module, b: Tensor, config: LinearSolveConfig = LinearSolveConfig(rtol=1e-06, max_iters=100, verbose=False, solver_type='cg', solver_kwargs={}))

Solves the linear system of equations Ax = b with an analytic gradient.

Parameters:
  • A (torch.nn.Module) – The linear operator, given as a torch.nn.Module; calling A(x) applies its forward operator to x.

  • b (torch.Tensor) – The right-hand side of the linear system Ax = b.

  • config (LinearSolveConfig) – An instance of LinearSolveConfig holding the solver options, including the convergence tolerance, the maximum number of iterations, and the solver type.

Returns:

The solution of Ax = b.
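One common way to obtain an analytic gradient for x = A⁻¹b, rather than backpropagating through every solver iteration, is the adjoint identity: for a scalar loss L(x), ∂L/∂b = A⁻ᵀ ∂L/∂x, so the backward pass needs one extra solve with Aᵀ. Whether linear_solve uses exactly this construction is an assumption here. The pure-Python check below compares the identity with central finite differences on a small 2×2 system; the helpers `solve2`, `transpose`, and `loss` are hypothetical and only serve the illustration.

```python
# Sketch (not dprox code): gradient of a loss through x = A^{-1} b
# via the adjoint system A^T lam = dL/dx, checked by finite differences.

def solve2(M, rhs):
    """Solve the 2x2 system M x = rhs with Cramer's rule."""
    (m00, m01), (m10, m11) = M
    det = m00 * m11 - m01 * m10
    return [(rhs[0] * m11 - m01 * rhs[1]) / det,
            (m00 * rhs[1] - m10 * rhs[0]) / det]

def transpose(M):
    return [list(row) for row in zip(*M)]

A = [[4.0, 1.0], [2.0, 3.0]]   # non-symmetric, so A^T differs from A
b = [1.0, 2.0]
target = [0.5, -0.5]

def loss(rhs):
    """L = 0.5 * ||x - target||^2 where x solves A x = rhs."""
    x = solve2(A, rhs)
    return 0.5 * sum((xi - ti) ** 2 for xi, ti in zip(x, target))

# Analytic gradient: dL/dx = x - target, then dL/db = A^{-T} dL/dx.
x = solve2(A, b)
dL_dx = [xi - ti for xi, ti in zip(x, target)]
grad_analytic = solve2(transpose(A), dL_dx)

# Central finite differences on each component of b.
h = 1e-6
grad_fd = []
for i in range(len(b)):
    b_plus, b_minus = list(b), list(b)
    b_plus[i] += h
    b_minus[i] -= h
    grad_fd.append((loss(b_plus) - loss(b_minus)) / (2 * h))

print("analytic:", grad_analytic)
print("finite differences:", grad_fd)
```

The point of the identity is cost: the backward pass is one more linear solve of the same size, independent of how many iterations the forward solve took.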

class dprox.linalg.LinearSolveConfig(rtol: float = 1e-06, max_iters: int = 100, verbose: bool = False, solver_type: str = 'cg', solver_kwargs: dict = <factory>)

Defines default configuration parameters for solving linear equations.

Parameters:
  • rtol (float) – The relative tolerance for convergence. Defaults to 1e-6.

  • max_iters (int) – The maximum number of iterations. Defaults to 100.

  • verbose (bool) – Whether to print progress updates while solving. Defaults to False.

  • solver_type (str) – The type of solver to use, e.g. 'cg' for conjugate gradient. Defaults to 'cg'.

  • solver_kwargs (dict) – Additional keyword arguments passed to the solver function. Defaults to an empty dict.
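The `<factory>` default in the signature is how Python renders a dataclass field declared with `default_factory`, which suggests that solver_kwargs gets a fresh empty dict for every config instance. The sketch below mirrors the documented fields in a hypothetical `LinearSolveConfigSketch` class to show that behavior; it is not the dprox class itself, and the `"example_option"` key is made up.

```python
from dataclasses import dataclass, field

@dataclass
class LinearSolveConfigSketch:
    """Hypothetical mirror of the documented LinearSolveConfig fields."""
    rtol: float = 1e-6
    max_iters: int = 100
    verbose: bool = False
    solver_type: str = "cg"
    # A mutable default must come from a factory; dataclasses render
    # such a default as <factory> in the generated signature.
    solver_kwargs: dict = field(default_factory=dict)

# Override only the options that differ from the defaults.
cfg = LinearSolveConfigSketch(rtol=1e-8, max_iters=500)
other = LinearSolveConfigSketch()

cfg.solver_kwargs["example_option"] = 1   # hypothetical key, for illustration
print(cfg)
print(other.solver_kwargs)  # unaffected: each instance owns its own dict
```

Using a factory avoids the classic shared-mutable-default bug, where mutating one config's solver_kwargs would silently change every other config.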

Linear Solvers

dprox.linalg.solve.solver_cg.cg(A: Callable, b: Tensor, x0: Tensor | None = None, rtol: float = 1e-06, max_iters: int = 100, verbose: bool = False)

Conjugate gradient method for solving a linear system of equations.

Parameters:
  • A (Callable) – A callable that applies the forward operator A(x) of a matrix-free linear operator.

  • b (torch.Tensor) – The right-hand side of the linear system Ax = b.

  • x0 (torch.Tensor) – The initial guess for the solution. If not provided, it is initialized to a vector of zeros.

  • rtol (float) – Relative tolerance for the convergence criterion. Defaults to 1e-6.

  • max_iters (int) – The maximum number of iterations. Defaults to 100.

  • verbose (bool) – Whether to log intermediate information. Defaults to False.

Returns:

The solution x to the linear system Ax = b, computed with the conjugate gradient method.
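A minimal pure-Python sketch of matrix-free CG that mirrors the cg parameters (A as a callable, x0 defaulting to zeros, rtol, max_iters). It assumes A is symmetric positive definite and that rtol bounds ‖b − Ax‖ relative to ‖b‖; both that stopping rule and the name `cg_sketch` are assumptions, not dprox's code. Vectors are plain lists so the example has no dependencies, whereas the documented function works on torch tensors.

```python
import math

def cg_sketch(A, b, x0=None, rtol=1e-6, max_iters=100):
    """Matrix-free conjugate gradient for a symmetric positive definite A.

    A is any callable mapping a list of floats to a list of floats.
    Stops when ||b - A x|| <= rtol * ||b|| (assumed criterion).
    """
    n = len(b)
    x = list(x0) if x0 is not None else [0.0] * n    # zero initial guess by default
    r = [bi - ai for bi, ai in zip(b, A(x))]           # residual b - A x
    p = list(r)                                        # first search direction
    rs = sum(ri * ri for ri in r)
    b_norm = math.sqrt(sum(bi * bi for bi in b))
    for _ in range(max_iters):
        if math.sqrt(rs) <= rtol * b_norm:
            break
        Ap = A(p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))   # exact line search step
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        # Next direction is A-conjugate to the previous ones.
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x

# Usage: a matrix-free operator for the SPD matrix [[4, 1], [1, 3]].
def A(v):
    return [4 * v[0] + 1 * v[1], 1 * v[0] + 3 * v[1]]

x = cg_sketch(A, [1.0, 2.0])
print(x)
```

Only A(p) is ever needed, which is why a matrix-free callable suffices; in exact arithmetic CG terminates in at most n iterations for an n×n SPD system.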

dprox.linalg.solve.solver_cg.pcg(A: Callable, b: Tensor, x0: Tensor | None = None, rtol: float = 1e-06, max_iters: int = 100, verbose: bool = False, Minv: Callable | None = None)

Preconditioned conjugate gradient method for solving a linear system of equations. Identical to cg above, except that it can be preconditioned via Minv.

Parameters:
  • A (Callable) – A callable that applies the forward operator A(x) of a matrix-free linear operator.

  • b (torch.Tensor) – The right-hand side of the linear system Ax = b.

  • x0 (torch.Tensor) – The initial guess for the solution. If not provided, it is initialized to a vector of zeros.

  • rtol (float) – Relative tolerance for the convergence criterion. Defaults to 1e-6.

  • max_iters (int) – The maximum number of iterations. Defaults to 100.

  • verbose (bool) – Whether to log intermediate information. Defaults to False.

  • Minv (Callable) – A callable that applies the preconditioner M⁻¹ to a vector. Defaults to None.

Returns:

The solution x to the linear system Ax = b, computed with the preconditioned conjugate gradient method.
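The sketch below adds a preconditioner to the CG recurrence. Minv is assumed to apply M⁻¹ for a symmetric positive definite M, here a Jacobi (diagonal) preconditioner, and in this sketch Minv=None falls back to plain CG. The names `pcg_sketch` and `jacobi` are hypothetical, the stopping rule ‖b − Ax‖ ≤ rtol·‖b‖ is an assumption, and plain lists stand in for tensors.

```python
import math

def pcg_sketch(A, b, x0=None, rtol=1e-6, max_iters=100, Minv=None):
    """Preconditioned CG; Minv applies M^{-1} (identity if None)."""
    n = len(b)
    if Minv is None:
        Minv = lambda v: list(v)                # no preconditioning: plain CG
    x = list(x0) if x0 is not None else [0.0] * n
    r = [bi - ai for bi, ai in zip(b, A(x))]    # residual b - A x
    z = Minv(r)                                 # preconditioned residual
    p = list(z)
    rz = sum(ri * zi for ri, zi in zip(r, z))
    b_norm = math.sqrt(sum(bi * bi for bi in b))
    for _ in range(max_iters):
        if math.sqrt(sum(ri * ri for ri in r)) <= rtol * b_norm:
            break
        Ap = A(p)
        alpha = rz / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        z = Minv(r)
        rz_new = sum(ri * zi for ri, zi in zip(r, z))
        p = [zi + (rz_new / rz) * pi for zi, pi in zip(z, p)]
        rz = rz_new
    return x

# Usage: SPD tridiagonal operator with a Jacobi (diagonal) preconditioner.
diag = [10.0, 20.0, 30.0]

def A(v):
    return [10 * v[0] + v[1], v[0] + 20 * v[1] + v[2], v[1] + 30 * v[2]]

def jacobi(v):
    return [vi / di for vi, di in zip(v, diag)]

b = [1.0, 2.0, 3.0]
x = pcg_sketch(A, b, Minv=jacobi)
print(x)
```

A good preconditioner makes M⁻¹A closer to the identity, which clusters its eigenvalues and reduces the iteration count; the diagonal is the cheapest such choice.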

dprox.linalg.solve.solver_minres.minres(A: Callable, b: Tensor, x0: Tensor | None = None, rtol: float = 1e-06, max_iters: int = 100, verbose: bool = False, Minv: Callable | None = None, eps: float = 1e-25, shifts: Tensor | None = None, value: float | None = None)

Performs MINRES to solve \((A + \alpha \sigma I) x = b\). Solutions for multiple shifts \(\sigma\) are computed simultaneously.

Parameters:
  • A (Callable) – A callable that applies the forward operator A(x) of a matrix-free linear operator.

  • b (torch.Tensor) – The right-hand side of the linear system Ax = b.

  • x0 (torch.Tensor) – The initial guess for the solution. If not provided, it is initialized to a vector of zeros.

  • rtol (float) – Relative tolerance for the convergence criterion. Defaults to 1e-6.

  • max_iters (int) – The maximum number of iterations. Defaults to 100.

  • verbose (bool) – Whether to log intermediate information. Defaults to False.

  • Minv (Callable) – A callable that applies the preconditioner M⁻¹ to a vector. Defaults to None.

  • shifts (torch.Tensor) – The shift values \(\sigma\). If set to None, then \(\sigma=0\).

  • value (float) – The multiplicative constant \(\alpha\). If set to None, then \(\alpha=0\).

Returns:

The solutions \(x\). Their shape corresponds to the shapes of b and shifts.

Note

The MINRES solver does not support Unrolling mode.
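To illustrate the shifted system above, here is a pure-Python sketch of unpreconditioned MINRES (Lanczos tridiagonalization plus a running Givens QR, so the residual norm |eta| is tracked without forming b − Ax). It assumes a symmetric, possibly indefinite operator, x0 = 0, and the stopping rule ‖b − Ax‖ ≤ rtol·‖b‖. The documented minres handles all shifts in one run by sharing the Lanczos basis; for clarity the hypothetical `shifted_minres_sketch` instead wraps the operator as v ↦ A(v) + ασv and solves once per shift. Minv and eps are omitted.

```python
import math

def _dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def minres_sketch(A, b, rtol=1e-6, max_iters=100):
    """Unpreconditioned MINRES for symmetric (possibly indefinite) A, x0 = 0."""
    n = len(b)
    x = [0.0] * n
    beta = math.sqrt(_dot(b, b))
    if beta == 0.0:
        return x
    b_norm, eta = beta, beta                  # |eta| is the current residual norm
    v_prev, v = [0.0] * n, [bi / beta for bi in b]
    w_prev2, w_prev = [0.0] * n, [0.0] * n
    c2, s2, c1, s1 = 1.0, 0.0, 1.0, 0.0       # Givens rotations G_{k-2}, G_{k-1}
    beta_off = 0.0                            # super-diagonal entry of column k
    for _ in range(max_iters):
        # Lanczos step: A v_k = beta_k v_{k-1} + alpha_k v_k + beta_{k+1} v_{k+1}
        Av = A(v)
        alpha = _dot(v, Av)
        u = [Av[i] - alpha * v[i] - beta_off * v_prev[i] for i in range(n)]
        beta_next = math.sqrt(_dot(u, u))
        # Apply the two previous rotations to the new tridiagonal column.
        epsilon = s2 * beta_off
        delta_bar = c2 * beta_off
        delta = c1 * delta_bar + s1 * alpha
        gamma_bar = -s1 * delta_bar + c1 * alpha
        # New rotation annihilates beta_next below the diagonal.
        gamma = math.hypot(gamma_bar, beta_next)
        if gamma == 0.0:
            break                             # singular system: stop
        c, s = gamma_bar / gamma, beta_next / gamma
        tau, eta = c * eta, -s * eta
        # Direction w_k solves W R = V column by column; update the iterate.
        w = [(v[i] - delta * w_prev[i] - epsilon * w_prev2[i]) / gamma
             for i in range(n)]
        x = [x[i] + tau * w[i] for i in range(n)]
        if abs(eta) <= rtol * b_norm or beta_next == 0.0:
            break
        v_prev, v = v, [ui / beta_next for ui in u]
        w_prev2, w_prev = w_prev, w
        c2, s2, c1, s1 = c1, s1, c, s
        beta_off = beta_next
    return x

def shifted_minres_sketch(A, b, shifts, value, rtol=1e-6, max_iters=100):
    """Solve (A + value * sigma * I) x = b separately for each shift sigma."""
    solutions = []
    for sigma in shifts:
        t = value * sigma
        def shifted(v, t=t):                  # v -> A(v) + t * v
            return [a + t * vi for a, vi in zip(A(v), v)]
        solutions.append(minres_sketch(shifted, b, rtol, max_iters))
    return solutions

# Usage: symmetric indefinite operator [[1, 2], [2, -1]].
def A(v):
    return [v[0] + 2 * v[1], 2 * v[0] - v[1]]

xs = shifted_minres_sketch(A, [1.0, 0.0], shifts=[0.0, 3.0], value=1.0)
print(xs)
```

Unlike CG, MINRES minimizes ‖b − Ax‖ over the Krylov subspace and therefore tolerates indefinite symmetric systems such as the one in the usage example.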

dprox.linalg.solve.solver_plss.plss(A: Callable, b: Tensor, x0: Tensor | None = None, rtol: float = 1e-06, max_iters: int = 100, verbose: bool = False)

Projected Linear Systems Solver (PLSS), intended for well-conditioned systems.

Parameters:
  • A (Callable) – A callable that applies the forward operator A(x) of a matrix-free linear operator.

  • b (torch.Tensor) – The right-hand side of the linear system Ax = b.

  • x0 (torch.Tensor) – The initial guess for the solution. If not provided, it is initialized to a vector of zeros.

  • rtol (float) – Relative tolerance for the convergence criterion. Defaults to 1e-6.

  • max_iters (int) – The maximum number of iterations. Defaults to 100.

  • verbose (bool) – Whether to log intermediate information. Defaults to False.

Returns:

The solution x to the linear system Ax = b, computed with the PLSS method.
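Whichever solver is used, a returned x can be checked against the tolerance by computing the relative residual ‖b − Ax‖ / ‖b‖, assuming that is the quantity rtol bounds (the docs only say "relative tolerance"). This is also why conditioning matters: the relative error obeys ‖x − x*‖ / ‖x*‖ ≤ κ(A) · ‖b − Ax‖ / ‖b‖, so a small residual guarantees a small error only when A is well-conditioned. The helper name `relative_residual` is hypothetical, and plain lists stand in for tensors.

```python
import math

def relative_residual(A, x, b):
    """Return ||b - A(x)|| / ||b|| for a matrix-free operator A on lists."""
    r = [bi - ai for bi, ai in zip(b, A(x))]
    b_norm = math.sqrt(sum(bi * bi for bi in b))
    r_norm = math.sqrt(sum(ri * ri for ri in r))
    return r_norm / b_norm if b_norm > 0 else r_norm

# Example: check candidate solutions of [[2, 0], [0, 4]] x = [2, 4].
def A(v):
    return [2 * v[0], 4 * v[1]]

print(relative_residual(A, [1.0, 1.0], [2.0, 4.0]))   # exact solution
print(relative_residual(A, [1.0, 0.9], [2.0, 4.0]))   # perturbed solution
```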