"""Implements the TT-GMRES algorithm for solving linear systems in the
TT-format, as described in Dolgov, arXiv:1206.5512, but with the rounding
step optionally replaced by sketching."""
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
import logging
from math import ceil
from time import perf_counter
from typing import Dict, List, Literal, Optional, Tuple, Union, Any
import numpy as np
import numpy.typing as npt
import scipy.linalg
from tt_sketch.sketch import orthogonal_sketch, stream_sketch
from tt_sketch.tensor import Tensor, TensorSum, TensorTrain
from tt_sketch.utils import (
ArrayList,
TTRank,
dematricize,
matricize,
process_tt_rank,
)
class TTLinearMap(ABC):
    """Interface for linear maps that act on tensors in the TT-format.

    Concrete subclasses provide ``in_shape`` and ``out_shape`` and implement
    ``__call__``.
    """

    in_shape: Tuple[int, ...]
    out_shape: Tuple[int, ...]

    @abstractmethod
    def __call__(self, other: TensorTrain) -> TensorTrain:
        """Apply the linear map to ``other`` and return the result."""
class MPO(Tensor, TTLinearMap):
    """Matrix product operator (MPO): a linear map in the tensor train format.

    The operator is stored as a list of order-4 cores of shape
    ``(rank[mu-1], in_shape[mu], out_shape[mu], rank[mu])``.
    """

    # Size of axis 1 of every core.
    in_shape: Tuple[int, ...]
    # Size of axis 2 of every core.
    out_shape: Tuple[int, ...]
    # Size of axis 0 of every core except the first.
    rank: Tuple[int, ...]
    # Per-mode products ``in_shape[mu] * out_shape[mu]``.
    shape: Tuple[int, ...]
    cores: ArrayList

    def __init__(self, cores: ArrayList) -> None:
        """Store ``cores`` and derive the shape and rank attributes from them.

        The list is stored by reference, not copied.
        """
        self.cores = cores
        self.in_shape = tuple(C.shape[1] for C in cores)
        self.out_shape = tuple(C.shape[2] for C in cores)
        self.rank = tuple(C.shape[0] for C in cores[1:])
        self.shape = tuple(
            s1 * s2 for s1, s2 in zip(self.in_shape, self.out_shape)
        )

    @property
    def size(self) -> int:
        """Total number of entries stored in all cores."""
        return sum(C.size for C in self.cores)

    @property
    def T(self) -> MPO:
        """Transpose of the linear map.

        Axes 1 and 2 of every core are swapped, which exchanges ``in_shape``
        and ``out_shape``. This is the transposition of a linear map, which is
        different from the transposition of other tensors.
        """
        new_cores = [C.transpose((0, 2, 1, 3)) for C in self.cores]
        return self.__class__(new_cores)

    def to_tt(self) -> TensorTrain:
        """Return a ``TensorTrain`` whose cores merge axes 1 and 2 of each
        MPO core into a single axis of size ``in_shape[mu] * out_shape[mu]``.
        """
        new_cores = [
            C.reshape(C.shape[0], C.shape[1] * C.shape[2], C.shape[3])
            for C in self.cores
        ]
        return TensorTrain(new_cores)

    def to_numpy(mpo) -> npt.NDArray:
        """Contract to dense array of shape
        ``(in_shape[0], out_shape[0], ..., in_shape[d-1], out_shape[d-1])``

        The instance is bound to the name ``mpo`` instead of ``self``.
        """
        res = mpo.cores[0]
        # Drop the leading axis of the first core.
        res = res.reshape(res.shape[1:])
        for C in mpo.cores[1:]:
            # Contract the trailing rank axis with the next core.
            res = np.einsum("...i,ijkl->...jkl", res, C)
        # Drop the trailing axis left over from the last core.
        res = res.reshape(res.shape[:-1])
        return res

    def __call__(self, other: TensorTrain) -> TensorTrain:
        """Apply the MPO to the tensor train ``other``.

        Each MPO core is contracted with the matching core of ``other`` over
        axis 1 of the MPO core. The left (right) rank axes of both cores are
        merged, so the ranks of the result are products of the two ranks.
        """
        new_cores = []
        for M, C in zip(self.cores, other.cores):
            MC = np.einsum("ijkl,ajb->iaklb", M, C)
            MC = MC.reshape(
                MC.shape[0] * MC.shape[1],
                MC.shape[2],
                MC.shape[3] * MC.shape[4],
            )
            new_cores.append(MC)
        return TensorTrain(new_cores)

    @classmethod
    def random(
        cls,
        rank: TTRank,
        in_shape: Tuple[int, ...],
        out_shape: Tuple[int, ...],
    ) -> MPO:
        """Create an MPO with normally distributed cores.

        Every core is scaled to have Frobenius norm ``sqrt(s1 * s2)``, with
        ``s1`` and ``s2`` the input and output size of that mode.
        """
        prod_shape = tuple(s1 * s2 for s1, s2 in zip(in_shape, out_shape))
        rank = process_tt_rank(rank, prod_shape, trim=True)
        cores = []
        for r1, s1, s2, r2 in zip(
            (1,) + rank, in_shape, out_shape, rank + (1,)
        ):
            C = np.random.normal(size=(r1, s1, s2, r2))
            # NOTE(review): for s1 != s2 the reshape only makes the shapes
            # agree, the sum is then not a true symmetrization -- confirm.
            C += C.transpose(0, 2, 1, 3).reshape(C.shape)  # symmetrize
            C = C * np.sqrt(s1 * s2) / np.linalg.norm(C)
            cores.append(C)
        return cls(cores)

    @classmethod
    def eye(cls, shape) -> MPO:
        """Create an MPO in which every core is an ``s x s`` identity matrix
        reshaped to ``(1, s, s, 1)``, for each ``s`` in ``shape``."""
        cores = []
        for s in shape:
            C = np.eye(s, s)
            C = C.reshape(1, C.shape[0], C.shape[1], 1)
            cores.append(C)
        return cls(cores)

    def __mul__(self, other: float) -> MPO:
        """Return a new MPO with the first core multiplied by ``other``.

        The remaining cores are shared with ``self``; ``self`` is left
        unchanged.
        """
        # Copy the list first: writing into ``self.cores`` directly would
        # also rescale this operator.
        new_cores = list(self.cores)
        new_cores[0] = new_cores[0] * other
        return self.__class__(new_cores)
class TTPrecond(TTLinearMap):
    """Preconditioner that only touches a single core of a tensor train.

    Calling the object (``backward_call``) applies the inverse of the matrix
    ``A`` to the core at position ``mode``, using a QR factorization of ``A``
    that is computed once at construction. ``forward_call`` multiplies the
    same core by ``A`` itself.
    """

    def __init__(self, A, shape, mode=0):
        self.A = A
        # Factorize once so every solve only needs a triangular system.
        q_factor, r_factor = np.linalg.qr(A)
        self.Q = q_factor
        self.R = r_factor
        self.mode = mode
        self.in_shape = self.out_shape = shape

    def _map_core(self, other, transform):
        """Deep-copy the cores of ``other``, pass the matricized core at
        ``self.mode`` through ``transform`` and wrap the cores in a new
        ``TensorTrain``."""
        cores = deepcopy(other.cores)
        core = cores[self.mode]
        unfolded = matricize(core, mode=1, mat_shape=True)
        cores[self.mode] = dematricize(
            transform(unfolded), mode=1, shape=core.shape
        )
        return TensorTrain(cores)

    def backward_call(self, other: TensorTrain) -> TensorTrain:
        """Solve with ``A`` on the selected core: ``R x = Q^T c``."""
        return self._map_core(
            other,
            lambda mat: scipy.linalg.solve_triangular(
                self.R, (self.Q.T) @ mat
            ),
        )

    def forward_call(self, other: TensorTrain) -> TensorTrain:
        """Multiply the selected core by ``A``."""
        return self._map_core(other, lambda mat: self.A @ mat)

    __call__ = backward_call
class TTLinearMapSum:
    """Sum of linear maps: takes a TT (or a sum of TTs) and returns a sum of
    TTs.

    The object is a thin container around a list of ``TTLinearMap`` objects
    that all share the same ``in_shape`` and ``out_shape``.
    """

    in_shape: Tuple[int, ...]
    out_shape: Tuple[int, ...]
    linear_maps: List[TTLinearMap]

    def __init__(self, linear_maps: List[TTLinearMap]) -> None:
        self.linear_maps = linear_maps
        if len(linear_maps) == 0:
            raise ValueError("linear_maps cannot be empty")

        # The first map fixes the shapes every other map has to agree with.
        reference = linear_maps[0]
        self.in_shape = reference.in_shape
        self.out_shape = reference.out_shape
        for candidate in linear_maps[1:]:
            if candidate.in_shape != self.in_shape:
                raise ValueError("in_shape mismatch")
            if candidate.out_shape != self.out_shape:
                raise ValueError("out_shape mismatch")

    def __call__(
        self, input_tensor: Union[TensorTrain, TensorSum[TensorTrain]]
    ) -> TensorSum[TensorTrain]:
        """Apply every stored map to every summand of ``input_tensor``.

        The result is ordered map by map: all summands under the first map,
        then all summands under the second map, and so on.
        """
        if isinstance(input_tensor, TensorTrain):
            summands = [input_tensor]
        else:
            summands = input_tensor.tensors
        return TensorSum(
            [
                linear_map(summand)
                for linear_map in self.linear_maps
                for summand in summands
            ]
        )
# def tt_weighted_sum_sketched(
# x0: TensorTrain,
# coeffs: npt.NDArray,
# tt_list: List[TensorTrain],
# tolerance: float,
# max_rank: Tuple[int, ...],
# round: bool = False,
# ):
# """Sketched weighted sum of tensor trains."""
# x_sum = TensorSum([x0])
# for coeff, tt in zip(coeffs, tt_list):
# x_sum += coeff * tt
# x = round_tt_sum(x_sum, tolerance, max_rank, False)
# return x
# def tt_weighted_sum_exact(
# x0: TensorTrain,
# coeffs: npt.NDArray,
# tt_list: List[TensorTrain],
# tolerance: float,
# max_rank: Tuple[int, ...],
# ):
# """Weighted sum of tensor trains rounded to ``max_rank``"""
# x = x0
# for coeff, tt in zip(coeffs, tt_list):
# x = x.add(coeff * tt)
# x = x.round(tolerance, max_rank)
# return x
# def tt_weighted_sum(
# x0: TensorTrain,
# coeffs: npt.NDArray,
# tt_list: List[TensorTrain],
# tolerance: float,
# max_rank: Tuple[int, ...],
# exact: bool = False,
# oversample: int = 5,
# ):
# x_sum = TensorSum([x0])
# for coeff, tt in zip(coeffs, tt_list):
# x_sum += coeff * tt
# x = round_tt_sum(x_sum, max_rank, tolerance, exact, oversample)
# return x
# Names of the rounding strategies accepted by the ``method`` argument of
# ``round_tt_sum``; ``None`` means the sum is returned without rounding.
ROUNDING_MODE = Literal["exact", "pairwise", "sketch", "orth_sketch", None]
def round_tt_sum(
    tt_sum: TensorSum[TensorTrain],
    max_rank: TTRank,
    eps: Optional[float] = None,
    method: ROUNDING_MODE = "sketch",
    oversample_factor: float = 2,
) -> TensorTrain:
    """Compress a sum of tensor trains into a single tensor train.

    Supported values of ``method``:

    - "exact": Add all TTs into one big TT, then round it once.
    - "pairwise": Add the TTs one at a time, rounding after every addition.
    - "sketch": Round using a streaming sketch.
    - "orth_sketch": Round using an orthogonal sketch.
    - ``None``: Skip rounding and return ``tt_sum`` itself (mostly useful for
      debugging).

    ``eps`` is only passed on by the "exact" and "pairwise" methods, and
    ``oversample_factor`` only by the two sketching methods, where the right
    sketch rank is ``ceil(oversample_factor * r)`` for every left rank ``r``.

    Raises ``ValueError`` for any other value of ``method``.
    """
    if method is None:
        return tt_sum  # type: ignore

    if method == "exact":
        terms = tt_sum.tensors
        total = terms[0]
        for term in terms[1:]:
            total = total.add(term)
        return total.round(eps, max_rank)

    if method == "pairwise":
        running = tt_sum.tensors[0]
        for term in tt_sum.tensors[1:]:
            running = running.add(term).round(eps=eps, max_rank=max_rank)
        return running

    if method in ("sketch", "orth_sketch"):
        left_rank = process_tt_rank(max_rank, tt_sum.shape, trim=True)
        right_rank = tuple(ceil(r * oversample_factor) for r in left_rank)
        if method == "sketch":
            sketched = stream_sketch(
                tt_sum, left_rank=left_rank, right_rank=right_rank
            )
            return sketched.to_tt()
        return orthogonal_sketch(
            tt_sum, left_rank=left_rank, right_rank=right_rank
        )

    raise ValueError(f"Unknown rounding method: {method}")
def tt_sum_gmres(
    A: TTLinearMapSum,
    b: TensorTrain,
    max_rank: TTRank,
    precond: Optional[TTPrecond] = None,
    final_round_rank: Optional[TTRank] = None,
    x0: Optional[TensorTrain] = None,
    tolerance: float = 1e-6,
    maxiter: int = 100,
    symmetric: bool = False,
    rounding_method: ROUNDING_MODE = "pairwise",
    rounding_method_final: Optional[ROUNDING_MODE] = None,
    save_basis: bool = False,
    verbose: bool = False,
) -> Tuple[TensorTrain, Dict[str, List]]:
    """
    GMRES solver for TTLinearMapSum.

    The TTLinearMapSum takes as input a TT and returns a sum of TTs. This means
    additional rounding steps are required in comparison to a version of
    TT-GMRES where the output of the linear map is a TT.

    Parameters
    ----------
    A
        Linear map of the system ``A(x) = b``; its input and output shapes
        must be equal.
    b
        Right-hand side.
    max_rank
        Rank passed to ``round_tt_sum`` for every rounding step inside the
        iteration.
    precond
        Optional preconditioner. If given, it is applied to ``b`` and to
        every summand returned by ``A``.
    final_round_rank
        Rank used to round the final result; defaults to ``max_rank``.
    x0
        Initial guess; defaults to a rank-1 zero tensor train.
    tolerance
        The iteration stops once the estimated residual norm divided by the
        norm of ``b`` drops below this value. It also sets the relaxed
        rounding accuracy ``delta`` of every iteration.
    maxiter
        Maximum number of iterations; must be at least 1.
    symmetric
        If true, a new vector is only orthogonalized against the (up to)
        three most recent basis vectors instead of the whole basis.
    rounding_method
        Method passed to ``round_tt_sum`` during the iteration.
    rounding_method_final
        Method used to round the final result; defaults to
        ``rounding_method``.
    save_basis
        If true, ``H_matrix``, ``nu_list`` and ``y`` are stored in the
        returned history.
    verbose
        If true, the relative residual is logged with ``logging.info`` at the
        start of every iteration.

    Returns
    -------
    The rounded solution and a history dictionary with the per-iteration
    lists ``"w_norm"``, ``"rank"``, ``"residual_norm"``, ``"step_time"``,
    ``"step_time_with_res_norm"`` and ``"delta"``, and the scalars
    ``"final_round_time"`` and ``"total_time"``.

    Raises
    ------
    ValueError
        If the shapes of ``A``, ``b`` and ``x0`` are inconsistent, if ``A``
        does not map a shape to itself, or if ``maxiter < 1``.
    """
    if final_round_rank is None:
        final_round_rank = max_rank
    if rounding_method_final is None:
        rounding_method_final = rounding_method

    if A.out_shape != b.shape:
        raise ValueError("Output shape of linear map doesn't match RHS")
    if x0 is not None and x0.shape != A.in_shape:
        raise ValueError("Input shape of liner map doesn't match initial value")
    if A.out_shape != A.in_shape:
        raise ValueError("TT-GMRES only works for automorphisms")
    if maxiter < 1:
        # Without a single iteration there is no Krylov basis to build the
        # solution from.
        raise ValueError("maxiter must be at least 1")

    max_rank = process_tt_rank(max_rank, A.in_shape, trim=True)

    if x0 is None:
        # TODO: check whether init with zero or random init is better
        # x0 = TensorTrain.random(shape=A.in_shape, rank=max_rank)
        x0 = TensorTrain.zero(shape=A.in_shape, rank=1)

    def apply_A_pr(x: TensorTrain) -> TensorSum[TensorTrain]:
        """Apply ``A`` and, if present, the preconditioner to every summand."""
        res = A(x)
        if precond is not None:
            res = TensorSum([precond(r) for r in res.tensors])
        return res

    b_pr = precond(b) if precond is not None else b
    # NOTE(review): residuals are normalized by the norm of the
    # unpreconditioned ``b`` even when a preconditioner is used -- confirm.
    b_norm = b.norm()

    initial_time = perf_counter()
    residual = b_pr - apply_A_pr(x0)
    residual_rounded = round_tt_sum(
        residual, max_rank=max_rank, method=rounding_method  # type: ignore
    )
    residual_norm = residual_rounded.norm()
    beta = residual_norm
    # Krylov basis; the first vector is the normalized initial residual.
    nu_list: List[TensorTrain] = [residual_rounded / beta]
    H_matrix = np.zeros((maxiter + 1, maxiter))

    history: Dict[str, Any] = defaultdict(list)
    history["w_norm"].append(nu_list[-1].norm())
    history["rank"].append(residual_rounded.rank)
    history["residual_norm"].append(residual_norm / b_norm)
    history["step_time"].append(perf_counter() - initial_time)

    for j in range(maxiter):
        current_time = perf_counter()
        # Relaxed rounding accuracy: the smaller the current residual
        # relative to the initial one, the coarser the rounding may be.
        delta = tolerance / (residual_norm / beta)
        if verbose:
            logging.info(
                f"Iteration {j + 1}/{maxiter}, residual norm: {residual_norm / b_norm:.4e}"
            )

        w_sum = apply_A_pr(nu_list[-1])
        w_rounded = round_tt_sum(
            w_sum, eps=delta, max_rank=max_rank, method=rounding_method
        )

        min_j = max(0, j - 2) if symmetric else 0
        for i in range(min_j, j + 1):
            H_matrix[i, j] = w_rounded.dot(nu_list[i])

        # Do Gram-Schmidt orthogonalization
        w_sum = (
            w_rounded
            - TensorSum(nu_list[min_j : j + 1]) * H_matrix[min_j : j + 1, j]
        )
        w_rounded = round_tt_sum(
            w_sum, eps=delta, max_rank=max_rank, method=rounding_method
        )
        H_matrix[j + 1, j] = w_rounded.norm()
        nu_list.append(w_rounded / H_matrix[j + 1, j])
        history["step_time"].append(perf_counter() - current_time)

        # Compute residual norm
        H_red = H_matrix[: j + 2, : j + 1]
        e1 = np.zeros(j + 2)
        e1[0] = beta
        y = np.linalg.lstsq(H_red, e1, rcond=None)[0]
        # ``lstsq`` reports the sum of *squared* residuals, and an empty
        # array when ``H_red`` is rank deficient. Evaluating the norm
        # directly keeps ``residual_norm`` a true norm for the stopping test
        # and for the next ``delta``, and also works in the rank deficient
        # case.
        residual_norm = float(np.linalg.norm(H_red @ y - e1))
        history["step_time_with_res_norm"].append(perf_counter() - current_time)
        history["residual_norm"].append(residual_norm / b_norm)
        history["rank"].append(w_rounded.rank)
        history["w_norm"].append(H_matrix[j + 1, j])
        history["delta"].append(delta)

        if residual_norm / b_norm < tolerance:
            break

    # Compute final result and round
    y = y[: j + 1]
    # The basis vector appended in the last iteration is not used.
    nu_list = nu_list[: j + 1]
    current_time = perf_counter()
    result = x0 + TensorSum(nu_list) * y
    result_rounded = round_tt_sum(
        result,
        eps=None,
        max_rank=final_round_rank,
        method=rounding_method_final,
    )
    history["final_round_time"] = perf_counter() - current_time
    history["total_time"] = perf_counter() - initial_time

    if save_basis:
        history["H_matrix"] = H_matrix
        history["nu_list"] = nu_list
        history["y"] = y

    return result_rounded, history