# %%
import concurrent.futures
import multiprocessing
from functools import reduce
from operator import mul
from typing import Generator, List, Optional, Sequence, Tuple, Union
import numpy as np
import numpy.typing as npt
import scipy.linalg
# from numpy.random import Generator, SeedSequence, default_rng
from numpy.typing import ArrayLike
ArrayList = List[npt.NDArray[np.float64]]
ArrayGenerator = Generator[npt.NDArray[np.float64], None, None]
TTRank = Union[int, Tuple[int, ...]]
[docs]def hilbert_tensor(n_dims: int, size: int) -> npt.NDArray:
"""Create a Hilbert tensor of specified size and dimensionality."""
grid = np.meshgrid(*([np.arange(size)] * n_dims))
hilbert = 1 / (np.sum(np.array(grid), axis=0) + 1)
return hilbert
[docs]def sqrt_tensor(shape: Tuple[int, ...], a=-0.2, b=2) -> npt.NDArray:
"""Create a tensor of specified shape with square root of a sum a grid.
Values of grid entries vary between a and b."""
def sqrt_sum(X):
return np.sqrt(np.abs(np.sum(X, axis=0)))
vals = [np.linspace(a, b, s) for s in shape]
grid = np.stack(np.meshgrid(*vals))
X = sqrt_sum(grid)
X /= np.linalg.norm(X)
return X
[docs]def power_decay_tensor(
shape: Tuple[int], pow: float = 2.0, seed=None
) -> ArrayLike:
"""Create tensor of specified shape such that singular values of each
unfolding decay with a power law."""
# if seed is not None:
# np.random.seed(np.mod(seed, 2**32 - 1))
seq = SeedSequence(seed)
A_seed = seq.generate_state(1)[0]
A = random_normal(shape=shape, seed=A_seed)
for mode in range(len(A.shape)):
A_mat = matricize(A, mode)
U, S, V = np.linalg.svd(A_mat, full_matrices=False)
S /= S[0]
S *= 1 / np.arange(1, len(S) + 1) ** pow
A_mat = U @ np.diag(S) @ V
A = dematricize(A_mat, mode, A.shape)
return A
[docs]def matricize(
A: npt.NDArray, mode: Union[int, Sequence[int]], mat_shape: bool = False
):
"""Matricize tensor ``A`` with respect to ``mode``.
If mode is an int, return matrix. If mode is a tuple, return tensor of order
``len(mode)+1``, unless ``mat_shape=True``"""
if isinstance(mode, int):
mode = (mode,)
else: # Try casting to tuple
mode = tuple(mode)
perm = mode + tuple(i for i in range(len(A.shape)) if i not in mode)
A = np.transpose(A, perm)
right_shape = (np.prod(A.shape[len(mode) :], dtype=int),)
if mat_shape:
left_shape = (np.prod(A.shape[: len(mode)], dtype=int),)
else:
left_shape = A.shape[: len(mode)] # type: ignore
A = A.reshape(left_shape + right_shape)
return A
[docs]def dematricize(A, mode, shape):
"""Undo matricization of ``A`` with respect to ``mode``. Needs ``shape`` of
original tensor."""
current_shape = [A.shape[0]] + [s for i, s in enumerate(shape) if i != mode]
current_shape = tuple(current_shape)
A = A.reshape(current_shape)
perm = list(range(1, len(shape)))
perm = perm[:mode] + [0] + perm[mode:]
A = np.transpose(A, perm)
return A
[docs]def right_mul_pinv(A, B, cond=None):
"""Compute numerically stable product ``A@np.linalg.pinv(B)``"""
lstsq = scipy.linalg.lstsq(B.T, A.T, cond=cond)
return lstsq[0].T
[docs]def left_mul_pinv(A, B, cond=None):
"""Compute numerically stable product ``np.linalg.pinv(A)@B``"""
lstsq = scipy.linalg.lstsq(A, B, cond=cond)
return lstsq[0]
[docs]def projector(X: npt.NDArray, Y: Optional[npt.NDArray] = None) -> npt.NDArray:
r"""Compute oblique projector :math:`\mathcal P_{X,Y}`"""
if Y is None:
Y = X
P = X @ np.linalg.pinv(Y.T @ X) @ Y.T
return P
[docs]def trim_ranks(
dims: Tuple[int, ...], ranks: Tuple[int, ...]
) -> Tuple[int, ...]:
"""Return TT-rank to which TT can be exactly reduced
A tt-rank can never be more than the product of the dimensions on the left
or right of the rank. Furthermore, any internal edge in the TT cannot have
rank higher than the product of any two connected supercores. Ranks are
iteratively reduced for each edge to satisfy these two requirements until
the requirements are all satisfied.
"""
ranks_trimmed = list(ranks)
for i, r in enumerate(ranks_trimmed):
dim_left = reduce(mul, dims[: i + 1], 1)
dim_right = reduce(mul, dims[i + 1 :], 1)
ranks_trimmed[i] = min(r, dim_left, dim_right)
changed = True
ranks_trimmed = [1] + ranks_trimmed + [1]
for _ in range(100):
changed = False
for i, d in enumerate(dims):
if ranks_trimmed[i + 1] > ranks_trimmed[i] * d:
changed = True
ranks_trimmed[i + 1] = ranks_trimmed[i] * d
if ranks_trimmed[i] > d * ranks_trimmed[i + 1]:
changed = True
ranks_trimmed[i] = d * ranks_trimmed[i + 1]
if not changed:
break
return tuple(ranks_trimmed[1:-1])
[docs]def process_tt_rank(
rank: TTRank, shape: Tuple[int, ...], trim: bool
) -> Tuple[int, ...]:
"""
Process TT rank, and check validity. Makes sure rank is a tuple.
If ``trim=True``, ranks are trimmed to the smallest possible lossless value.
"""
# check if rank is iterable, if not use constant rank
try:
rank_tuple = tuple(rank) # type: ignore
except TypeError:
rank_tuple = (rank,) * (len(shape) - 1) # type: ignore
if len(rank_tuple) != len(shape) - 1:
raise ValueError(
f"TT-rank {rank_tuple} doesn't have right number of elements"
)
if trim:
rank_tuple = trim_ranks(shape, rank_tuple)
return rank_tuple
[docs]class MultithreadedRNG:
"""
Multithreaded standard normal random number generator.
Copy pasta from numpy docs
"""
def __init__(self, shape, seed=None, threads=None):
if threads is None:
threads = multiprocessing.cpu_count()
self.threads = threads
seq = np.random.SeedSequence(seed)
self._random_generators = [
np.random.default_rng(s) for s in seq.spawn(threads)
]
self.shape = shape
n = np.prod(shape)
self.executor = concurrent.futures.ThreadPoolExecutor(threads)
self.values = np.empty(n)
self.step = np.ceil(n / threads).astype(np.int_)
self.fill()
[docs] def fill(self):
def _fill(random_state, out, first, last):
random_state.standard_normal(out=out[first:last])
futures = {}
for i in range(self.threads):
args = (
_fill,
self._random_generators[i],
self.values,
i * self.step,
(i + 1) * self.step,
)
futures[self.executor.submit(*args)] = i
concurrent.futures.wait(futures)
self.values = self.values.reshape(self.shape)
# def __del__(self):
# self.executor.shutdown(False)
[docs]def random_normal(shape, seed=None):
"""
Generate multi-threaded random numbers
"""
return MultithreadedRNG(shape, seed).values