from pathlib import Path
from typing import List, Union
import torch
import torch.nn.functional as F
import torchlight as tl
import torchlight.nn as tlnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from dprox import *
from dprox.contrib.optic import Dataset
from dprox.proxfn import ProxFn
from dprox.utils import *
from . import opt
from .admm import ADMM, ADMM_vxu, LinearizedADMM
from .base import Algorithm
from .hqs import HQS
from .pc import PockChambolle
from .pgd import ProximalGradientDescent
from .specialization import DEQSolver, UnrolledSolver, AutoTuneSolver, build_unrolled_solver
SOLVERS = {
'admm': ADMM,
'admm_vxu': ADMM_vxu,
'ladmm': LinearizedADMM,
'hqs': HQS,
'pc': PockChambolle,
'pgd': ProximalGradientDescent,
}
SPECAILIZATIONS = {
'deq': DEQSolver,
'rl': AutoTuneSolver,
'unroll': build_unrolled_solver,
}
[docs]def compile(
prox_fns: List[ProxFn],
method: str = 'admm',
device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
**kwargs
):
"""
Compile the given objective (in terms of a list of proxable functions) into a proximal solver.
>>> solver = compile(data_term+reg_term, method='admm')
Args:
prox_fns (List[ProxFn]): A list or the sum of proxable functions.
method (str): A string that specifies the name of the optimization method to use. Defaults to `admm`.
Valid methods include [`admm`, `admm_vxu`, `ladmm`, `hqs`, `pc`, `pgd`].
device (Union[str, torch.device]): The device (CPU or GPU) on which the solver should run.
It can be either a string ('cpu' or 'cuda') or a `torch.device` object. Defaults to cuda if avaliable.
Returns:
An instance of a solver object that is created using the specified algorithm and proximal functions.
"""
algorithm: Algorithm = SOLVERS[method]
device = torch.device(device) if isinstance(device, str) else device
psi_fns, omega_fns = algorithm.partition(prox_fns)
solver = algorithm.create(psi_fns, omega_fns, **kwargs)
solver = solver.to(device)
return solver
[docs]def specialize(
solver: Algorithm,
method: str = 'deq',
device: Union[str, torch.device] = 'cuda' if torch.cuda.is_available() else 'cpu',
**kwargs
):
"""
Specialize the given solver based on the given method.
>>> deq_solver = specialize(solver, method='deq')
>>> rl_solver = specialize(solver, method='rl')
>>> unroll_solver = specialize(solver, method='unroll')
Args:
solver (Algorithm): the proximal solver that need to be specialized.
method (str): the strategy for the specialization. Choose from [`deq`, `rl`, `unroll`].
device (Union[str, torch.device]): The device (CPU or GPU) on which the solver should run.
It can be either a string ('cpu' or 'cuda') or a `torch.device` object. Defaults to cuda if avaliable
Returns:
The specialized solver.
"""
solver = SPECAILIZATIONS[method](solver, **kwargs)
device = torch.device(device) if isinstance(device, str) else device
solver = solver.to(device)
return solver
def optimize(
prox_fns: List[ProxFn],
merge=False,
absorb=False
):
if absorb:
prox_fns = opt.absorb.absorb_all_linops(prox_fns)
return prox_fns
def visualize():
pass
def train(
solver=None,
**kwargs,
):
if solver is None:
return _train(**kwargs)
if isinstance(solver, AutoTuneSolver):
return solver.train(**kwargs)
else:
raise ValueError(f'Training {solver} is not supported yet.')
def _train(
model,
step_fn,
dataset='BSD500',
savedir='saved',
epochs=10,
bs=2,
lr=1e-4,
resume=None,
):
savedir = Path(savedir)
savedir.mkdir(exist_ok=True, parents=True)
logger = tl.logging.Logger(savedir)
# ----------------- Start Training ------------------------ #
root = hf.download_dataset(dataset, force_download=False)
dataset = Dataset(root)
loader = DataLoader(dataset, batch_size=bs, shuffle=True)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4, weight_decay=1e-3
)
tlnn.utils.adjust_learning_rate(optimizer, lr)
epoch = 0
gstep = 0
best_psnr = 0
imgdir = savedir / 'imgs'
imgdir.mkdir(exist_ok=True, parents=True)
if resume:
ckpt = torch.load(savedir / resume)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
epoch = ckpt['epoch'] + 1
gstep = ckpt['gstep'] + 1
best_psnr = ckpt['best_psnr']
def save_ckpt(name, psnr=0):
ckpt = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'gstep': gstep,
'psnr': psnr,
'best_psnr': best_psnr,
}
torch.save(ckpt, savedir / name)
save_ckpt('last.pth')
while epoch < epochs:
tracker = tl.trainer.util.MetricTracker()
pbar = tqdm(total=len(loader), dynamic_ncols=True, desc=f'Epcoh[{epoch}]')
for i, batch in enumerate(loader):
gt, inp, pred = step_fn(batch)
loss = F.mse_loss(gt, pred)
loss.backward()
optimizer.step()
optimizer.zero_grad()
psnr = tl.metrics.psnr(pred, gt)
loss = loss.item()
tracker.update('loss', loss)
tracker.update('psnr', psnr)
pbar.set_postfix({'loss': f'{tracker["loss"]:.4f}',
'psnr': f'{tracker["psnr"]:.4f}'})
pbar.update()
gstep += 1
logger.info('Epoch {} Loss={} LR={}'.format(epoch, tracker['loss'], tlnn.utils.get_learning_rate(optimizer)))
save_ckpt('last.pth', tracker['psnr'])
pbar.close()
epoch += 1