# Source code for tt_sketch.sketch_dispatch

"""
Dispatch of sketching routines for tensors and DRMs, and the general sketching
driver built on top of that dispatch.
"""

import enum
from functools import partial
from typing import Callable, Literal, Optional, Tuple

import numpy as np
import numpy.typing as npt
import scipy.linalg

from tt_sketch.drm import TensorTrainDRM
from tt_sketch.drm_base import DRM
from tt_sketch.sketch_container import SketchContainer
from tt_sketch.sketching_methods.abstract_methods import (
    CansketchCP,
    CansketchDense,
    CansketchSparse,
    CansketchTT,
    CanSketchTucker,
)
from tt_sketch.sketching_methods.cp_sketch import sketch_omega_cp, sketch_psi_cp
from tt_sketch.sketching_methods.dense_sketch import (
    sketch_omega_dense,
    sketch_psi_dense,
)
from tt_sketch.sketching_methods.sparse_sketch import (
    sketch_omega_sparse,
    sketch_psi_sparse,
)
from tt_sketch.sketching_methods.tensor_train_sketch import (
    sketch_omega_tt,
    sketch_psi_tt,
)
from tt_sketch.sketching_methods.tucker_sketch import (
    sketch_omega_tucker,
    sketch_psi_tucker,
)
from tt_sketch.tensor import (
    CPTensor,
    DenseTensor,
    SparseTensor,
    Tensor,
    TensorSum,
    TensorTrain,
    TuckerTensor,
)
from tt_sketch.utils import ArrayList, right_mul_pinv

# Maps each concrete tensor type to an abstract class imported from
# tt_sketch.sketching_methods.abstract_methods.
# NOTE(review): this table is not referenced in this chunk; presumably it names
# the interface a DRM must implement to sketch that tensor type — confirm
# against its callers.
ABSTRACT_TENSOR_SKETCH_DISPATCH = {
    SparseTensor: CansketchSparse,
    TensorTrain: CansketchTT,
    DenseTensor: CansketchDense,
    CPTensor: CansketchCP,
    TuckerTensor: CanSketchTucker,
}

# Maps each concrete tensor type to the *name* of the DRM method that sketches
# it; get_sketch_method resolves the name on the DRM with getattr. Lookup is by
# exact type (type(tensor)), so subclasses of these types are not matched here.
DRM_SKETCH_METHOD_DISPATCH = {
    SparseTensor: "sketch_sparse",
    TensorTrain: "sketch_tt",
    DenseTensor: "sketch_dense",
    CPTensor: "sketch_cp",
    TuckerTensor: "sketch_tucker",
}


# Maps each tensor type to the routine that computes an Omega matrix from a
# left and a right sketch. The TensorSum entry is registered further down, once
# sketch_omega_sum has been defined.
OMEGA_METHODS = {
    SparseTensor: sketch_omega_sparse,
    TensorTrain: sketch_omega_tt,
    DenseTensor: sketch_omega_dense,
    CPTensor: sketch_omega_cp,
    TuckerTensor: sketch_omega_tucker,
}

# Maps each tensor type to the routine that computes a Psi core from a left and
# a right sketch. The TensorSum entry is registered further down, once
# sketch_psi_sum has been defined.
PSI_METHODS = {
    SparseTensor: sketch_psi_sparse,
    TensorTrain: sketch_psi_tt,
    DenseTensor: sketch_psi_dense,
    CPTensor: sketch_psi_cp,
    TuckerTensor: sketch_psi_tucker,
}


def sketch_omega_sum(
    left_sketch_array: ArrayList,
    right_sketch_array: ArrayList,
    *,
    tensor: TensorSum,
    omega_shape: Tuple[int, int],
    **kwargs,
) -> npt.NDArray:
    """Compute the Omega matrix of a ``TensorSum``.

    The result is the sum of the Omega matrices of the individual summands.
    Each summand is sketched with the routine registered for its own type in
    ``OMEGA_METHODS``.

    Args:
        left_sketch_array: one left sketch per summand, in summand order.
        right_sketch_array: one right sketch per summand, in summand order.
        tensor: the sum whose summands are sketched.
        omega_shape: shape of the accumulated Omega matrix.
        **kwargs: forwarded unchanged to every per-summand routine.

    Returns:
        A float array of shape ``omega_shape``.
    """
    total = np.zeros(omega_shape)
    sketch_pairs = zip(left_sketch_array, right_sketch_array)
    for part, (left_part, right_part) in zip(tensor.tensors, sketch_pairs):
        sketch_fn = OMEGA_METHODS[type(part)]
        contribution = sketch_fn(
            left_part,
            right_part,
            tensor=part,
            omega_shape=omega_shape,
            **kwargs,
        )  # type: ignore
        total += contribution
    return total


# Register the sum routine so TensorSum dispatches like any other tensor type.
OMEGA_METHODS[TensorSum] = sketch_omega_sum


def sketch_psi_sum(
    left_sketch_array: Optional[ArrayList],
    right_sketch_array: Optional[ArrayList],
    *,
    tensor: TensorSum,
    psi_shape: Tuple[int, int, int],
    **kwargs,
) -> npt.NDArray:
    """Compute a Psi core of a ``TensorSum``.

    The result is the sum of the Psi cores of the individual summands. Each
    summand is sketched with the routine registered for its own type in
    ``PSI_METHODS``.

    Args:
        left_sketch_array: one left sketch per summand, or ``None``. With
            ``None``, every summand receives ``None`` as its left sketch
            (general_sketch passes ``None`` for the first core).
        right_sketch_array: one right sketch per summand, or ``None``. With
            ``None``, every summand receives ``None`` as its right sketch
            (general_sketch passes ``None`` for the last core).
        tensor: the sum whose summands are sketched.
        psi_shape: ``(r1, n, r2)`` shape of the accumulated core.
        **kwargs: forwarded unchanged to every per-summand routine.

    Returns:
        A float array of shape ``psi_shape``.
    """
    total = np.zeros(psi_shape)
    # Expand a missing sketch into one ``None`` per summand so that zip below
    # still pairs every summand with a (None) sketch.
    lefts = (
        (None,) * tensor.num_summands
        if left_sketch_array is None
        else left_sketch_array
    )
    rights = (
        (None,) * tensor.num_summands
        if right_sketch_array is None
        else right_sketch_array
    )

    for part, left_part, right_part in zip(tensor.tensors, lefts, rights):
        sketch_fn = PSI_METHODS[type(part)]
        total += sketch_fn(
            left_part,
            right_part,
            tensor=part,
            psi_shape=psi_shape,
            **kwargs,
        )  # type: ignore
    return total


# Register the sum routine so TensorSum dispatches like any other tensor type.
PSI_METHODS[TensorSum] = sketch_psi_sum


def sum_sketch(tensor: TensorSum, *, drm: DRM):
    """Sketch every summand of ``tensor`` with ``drm`` in lockstep.

    A sketch generator is created for each summand, using the method that
    ``get_sketch_method`` selects for that summand's type. The function then
    yields ``len(tensor.shape) - 1`` tuples. Each tuple holds the next sketch
    from every summand's generator, in summand order.
    """
    per_summand = [get_sketch_method(part, drm)(part) for part in tensor.tensors]
    num_steps = len(tensor.shape) - 1
    for _ in range(num_steps):
        yield tuple(next(gen) for gen in per_summand)


def get_sketch_method(tensor: Tensor, drm: DRM) -> Callable:
    """Return the callable that sketches ``tensor`` with ``drm``.

    For the tensor types listed in ``DRM_SKETCH_METHOD_DISPATCH`` (matched by
    exact type), the result is the corresponding bound method of ``drm``. For a
    ``TensorSum``, the result sketches every summand via ``sum_sketch``.

    Raises:
        ValueError: if the tensor type is not supported, or if ``drm`` does not
            provide the sketching method required for that tensor type.
    """
    tensor_type = type(tensor)
    if tensor_type in DRM_SKETCH_METHOD_DISPATCH:
        method_name = DRM_SKETCH_METHOD_DISPATCH[tensor_type]
        # A DRM lacking the method previously surfaced as a bare
        # AttributeError; report it the same way as any other unsupported
        # DRM/tensor combination instead.
        sketch_method = getattr(drm, method_name, None)
        if sketch_method is None:
            raise ValueError(f"DRM of type {type(drm)} can't sketch {type(tensor)}")
        return sketch_method
    if isinstance(tensor, TensorSum):
        return partial(sum_sketch, drm=drm)
    raise ValueError(f"DRM of type {type(drm)} can't sketch {type(tensor)}")


def orth_step(
    Psi: npt.NDArray[np.float64], Omega: Optional[npt.NDArray[np.float64]]
) -> npt.NDArray[np.float64]:
    """Orthogonalize a three-way sketch core.

    This is the orthogonalization step of the orthogonal sketching algorithm;
    ``general_sketch`` also uses it, with ``Omega=None``, for the HMT variant.

    ``Psi`` is viewed as a ``(Psi.shape[0] * Psi.shape[1], Psi.shape[2])``
    matrix. If ``Omega`` is given, that matrix is first passed through
    ``right_mul_pinv(matrix, Omega)`` (presumably ``matrix @ pinv(Omega)`` —
    confirm in tt_sketch.utils). The Q factor of an economic QR decomposition
    is then reshaped to ``(Psi.shape[0], Psi.shape[1], k)``. Here ``k`` is
    ``Psi.shape[2]`` when ``Omega`` is None and ``Omega.shape[0]`` otherwise.
    """
    left_rank, mode_size, right_rank = Psi.shape[0], Psi.shape[1], Psi.shape[2]
    target_rank = right_rank if Omega is None else Omega.shape[0]

    flat = Psi.reshape((left_rank * mode_size, right_rank))
    if Omega is not None:
        flat = right_mul_pinv(flat, Omega)

    q_factor, _ = scipy.linalg.qr(flat, mode="economic")
    return q_factor.reshape(left_rank, mode_size, target_rank)
class OrthogTTDRM:
    """Represents the orthogonalized TT used as left-sketch for psi.

    The cores of the wrapped ``TensorTrainDRM`` are supplied one at a time
    through ``add_core``. The object is an iterator: every ``next()`` returns
    the next left sketch produced by the DRM's sketching generator for
    ``tensor``.

    The first ``add_core`` call must happen before the first ``next()``,
    because that call is what creates the underlying generator.
    NOTE(review): presumably the generator reads ``self.drm.cores`` lazily, so
    cores appended later are visible to later ``next()`` calls — confirm in
    TensorTrainDRM.
    """

    def __init__(self, rank, tensor):
        self.rank = rank
        self.drm = TensorTrainDRM(rank, tensor.shape, transpose=False, cores=[])
        self.generator = None  # created lazily by the first add_core call
        self.tensor = tensor
        self.sketch_method = get_sketch_method(tensor, self.drm)

    def add_core(self, core):
        """Append ``core`` to the DRM, creating the sketch generator on first use."""
        self.drm.cores.append(core)
        if self.generator is None:
            self.generator = self.sketch_method(self.tensor)

    def __iter__(self):
        # Completes the iterator protocol so the object works with iter(),
        # for-loops and itertools, not only with next().
        return self

    def __next__(self):
        return next(self.generator)
class SketchMethod(enum.Enum):
    """Selects the variant of the sketching algorithm run by ``general_sketch``."""

    # Left sketches come straight from the left DRM; Omega matrices are
    # computed, and the Psi cores are not orthogonalized.
    streaming = "streaming"
    # Omega matrices are computed. Left sketches come from an OrthogTTDRM built
    # from the previously computed cores. Every core except the last goes
    # through orth_step(Psi, Omega).
    orthogonal = "orthogonal"
    # A left DRM is optional (right_drm.T supplies ranks when it is omitted),
    # and no Omega matrices are computed. Left sketches come from an
    # OrthogTTDRM, and every core except the last goes through
    # orth_step(Psi, None).
    hmt = "hmt"
def general_sketch(
    tensor: Tensor,
    left_drm: Optional[DRM],
    right_drm: DRM,
    method: SketchMethod,
) -> SketchContainer:
    """General algorithm for sketching a tensor.

    Does the heavy lifting for the streaming, orthogonal and HMT sketching
    algorithms.

    Args:
        tensor: tensor to sketch. Its exact type selects the Omega/Psi routines
            from ``OMEGA_METHODS`` / ``PSI_METHODS``.
        left_drm: DRM used for the left sketches. Required unless ``method`` is
            ``SketchMethod.hmt``. For HMT it may be None, in which case
            ``right_drm.T`` supplies the left ranks.
        right_drm: DRM used for the right sketches.
        method: the algorithm variant, as a ``SketchMethod`` member or its
            string value (e.g. ``"orthogonal"``).

    Returns:
        A ``SketchContainer`` with one Psi core per tensor mode. It also holds
        one Omega matrix per mode boundary (``len(tensor.shape) - 1``), except
        for HMT, where the Omega list is empty.

    Raises:
        ValueError: if ``method`` is not a valid ``SketchMethod``, or if
            ``left_drm`` is None for a non-HMT method.
    """
    # Normalize first. Otherwise an unrecognized value would fail every
    # comparison below and silently run the streaming branches.
    method = SketchMethod(method)
    orthogonalize_left = method in (SketchMethod.hmt, SketchMethod.orthogonal)
    n_dims = len(tensor.shape)

    if method != SketchMethod.hmt:
        if left_drm is None:
            raise ValueError(f"left_drm must be provided for method '{method}'")
        left_contractions = list(get_sketch_method(tensor, left_drm)(tensor))
    right_contractions = list(get_sketch_method(tensor, right_drm)(tensor))
    if left_drm is None:
        # Only needed for rank/shape information in the HMT case.
        left_drm = right_drm.T

    # Right ranks are consumed in reverse order; compute the reversal once.
    right_ranks = right_drm.rank[::-1]

    # Compute Omega matrices (not used by HMT).
    Omega_mats: ArrayList = []
    if method != SketchMethod.hmt:
        omega_method = OMEGA_METHODS[type(tensor)]
        for mu in range(n_dims - 1):
            omega_shape = (left_drm.rank[mu], right_ranks[mu])
            Omega_mats.append(
                omega_method(
                    left_contractions[mu],
                    right_contractions[mu],
                    tensor=tensor,
                    mu=mu,
                    omega_shape=omega_shape,
                )  # type: ignore
            )

    if orthogonalize_left:
        left_psi_drm = OrthogTTDRM(left_drm.rank, tensor)

    # Compute Psi cores.
    psi_method = PSI_METHODS[type(tensor)]
    Psi_cores: ArrayList = []
    for mu in range(n_dims):
        if mu > 0:
            if orthogonalize_left:
                # Feed the previous (already orthogonalized) core into the
                # left DRM, then take the next left sketch from it.
                left_psi_drm.add_core(Psi_cores[-1])
                left_sketch = next(left_psi_drm)
            else:
                left_sketch = left_contractions[mu - 1]
            r1 = left_drm.rank[mu - 1]
        else:
            left_sketch = None
            r1 = 1

        if mu < n_dims - 1:
            right_sketch = right_contractions[mu]
            r2 = right_ranks[mu]
        else:
            right_sketch = None
            r2 = 1

        psi_shape = (r1, tensor.shape[mu], r2)
        Psi = psi_method(
            left_sketch, right_sketch, tensor=tensor, mu=mu, psi_shape=psi_shape
        )  # type: ignore

        # The last core is kept as is; the streaming variant never
        # orthogonalizes.
        if mu < n_dims - 1:
            if method == SketchMethod.orthogonal:
                Psi = orth_step(Psi, Omega_mats[mu])
            elif method == SketchMethod.hmt:
                Psi = orth_step(Psi, None)
        Psi_cores.append(Psi)

    return SketchContainer(Psi_cores, Omega_mats)