Source code for tt_sketch.sketch

"""Interface for the streaming and orthogonal sketching algorithms"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Type

import numpy as np
import numpy.typing as npt

from tt_sketch.drm import (
    ALL_DRM,
    DenseGaussianDRM,
    SparseGaussianDRM,
    TensorTrainDRM,
)
from tt_sketch.drm_base import DRM, CanIncreaseRank, CanSlice
from tt_sketch.sketch_container import SketchContainer
from tt_sketch.sketch_dispatch import SketchMethod, general_sketch
from tt_sketch.sketching_methods.abstract_methods import (
    CansketchCP,
    CansketchDense,
    CansketchSparse,
    CansketchTT,
)
from tt_sketch.tensor import Tensor, TensorTrain
from tt_sketch.utils import (
    ArrayList,
    TTRank,
    left_mul_pinv,
    process_tt_rank,
    right_mul_pinv,
)

# Maps each sketching-capability class to a DRM class.
# NOTE(review): presumably the DRM chosen by default for tensors with that
# capability; this mapping is not referenced in the visible part of the file.
DEFAULT_DRM = {
    CansketchDense: DenseGaussianDRM,
    CansketchSparse: SparseGaussianDRM,
    CansketchTT: TensorTrainDRM,
    CansketchCP: TensorTrainDRM,
}

# A blocked sketch: keys are ``(i, j)`` pairs where ``i`` indexes a block of
# the left DRM and ``j`` a block of the right DRM; values are the sketches
# computed for that pair of blocks.
BlockedSketch = Dict[Tuple[int, int], SketchContainer]


def hmt_sketch(
    tensor: Tensor,
    rank: TTRank,
    seed: Optional[int] = None,
    drm_type: Optional[Type[DRM]] = None,
    drm: Optional[DRM] = None,
    return_drm: bool = False,
) -> TensorTrain:
    """
    Sketch a tensor with the ``SketchMethod.hmt`` method.

    Only one DRM is involved; it is passed to ``general_sketch`` in the
    right-DRM position (the left DRM is ``None``).

    Parameters
    ----------
    tensor
        The tensor to sketch.
    rank
        Rank of the sketch. When no ``drm`` is supplied it is first passed
        through ``process_tt_rank`` with ``trim=True``.
    seed
        Seed handed to the DRM constructor. Drawn at random if ``None``.
    drm_type
        DRM class to instantiate when ``drm`` is not given; defaults to
        ``TensorTrainDRM``.
    drm
        Pre-built DRM. Its reversed rank must equal ``rank``.
    return_drm
        If true, return ``(sketched, drm)`` instead of just the sketch
        (mostly useful for testing).

    Raises
    ------
    ValueError
        If a supplied ``drm`` has a rank that does not match ``rank``.
    """
    if seed is None:
        seed = np.mod(hash(np.random.uniform()), 2**32)

    if drm is None:
        if drm_type is None:
            drm_type = TensorTrainDRM
        rank = process_tt_rank(rank, tensor.shape, trim=True)
        drm = drm_type(rank, transpose=True, shape=tensor.shape, seed=seed)
    elif tuple(drm.rank[::-1]) != rank:
        raise ValueError(
            f"Right rank {rank} does not match the rank of the DRM "
            f"{drm.rank}."
        )

    sketch = general_sketch(tensor, None, drm, method=SketchMethod.hmt)
    sketched = TensorTrain(sketch.Psi_cores)

    if return_drm:
        # Only one DRM exists here; the previous code referenced an
        # undefined ``right_drm`` and always raised NameError.
        return sketched, drm  # type: ignore
    return sketched
def orthogonal_sketch(
    tensor: Tensor,
    left_rank: TTRank,
    right_rank: TTRank,
    seed: Optional[int] = None,
    left_drm_type: Optional[Type[DRM]] = None,
    right_drm_type: Optional[Type[DRM]] = None,
    left_drm: Optional[DRM] = None,
    right_drm: Optional[DRM] = None,
    return_drm: bool = False,
) -> TensorTrain:
    """
    Perform an orthogonal sketch of a tensor.

    Parameters
    ----------
    tensor
        The tensor to sketch.
    left_rank, right_rank
        Ranks of the left and right sketches. Every entry of ``right_rank``
        must be strictly larger than the matching entry of ``left_rank``.
    seed
        Seed for the left DRM; the right DRM uses a seed derived from it.
        Drawn at random if ``None``.
    left_drm_type, right_drm_type
        DRM classes to instantiate when the corresponding DRM is not given.
        A missing type falls back to the other side's type, and then to
        ``TensorTrainDRM``.
    left_drm, right_drm
        Pre-built DRMs. Their ranks must match ``left_rank`` and
        ``right_rank`` (the right DRM stores its rank reversed).
    return_drm
        If true, return ``(sketched, left_drm, right_drm)`` (mostly useful
        for testing).

    Raises
    ------
    ValueError
        If the right rank is not larger than the left rank, or a supplied
        DRM has a rank that does not match the requested rank.
    """
    d = len(tensor.shape)

    right_rank_bigger = bool(
        np.all(np.array(left_rank) < np.array(right_rank))
    )
    if not right_rank_bigger:
        raise ValueError(
            f"The right rank needs to be larger than the left rank. "
            f"Left rank: {left_rank}, "
            f"right rank: {right_rank}"
        )

    if seed is None:
        seed = np.mod(hash(np.random.uniform()), 2**32)

    if left_drm is None:
        if left_drm_type is None:
            if right_drm_type is not None:
                left_drm_type = right_drm_type
            else:
                left_drm_type = TensorTrainDRM
        left_rank = process_tt_rank(left_rank, tensor.shape, trim=True)
        left_drm = left_drm_type(
            left_rank, transpose=False, shape=tensor.shape, seed=seed
        )
    elif left_drm.rank != left_rank:
        raise ValueError(
            f"Left rank {left_rank} does not match the rank of the DRM "
            f"{left_drm.rank}."
        )

    if right_drm is None:
        if right_drm_type is None:
            if left_drm_type is not None:
                right_drm_type = left_drm_type
            else:
                right_drm_type = TensorTrainDRM
        right_rank = process_tt_rank(right_rank, tensor.shape, trim=False)
        # Offset the seed by the tensor order so the right DRM differs from
        # the left one. ``hash(str(d))`` was used before, but string hashes
        # are salted per interpreter process, which made the right DRM
        # irreproducible across runs even with a fixed ``seed``.
        right_seed = np.mod(seed + d, 2**32)
        right_drm = right_drm_type(
            right_rank, transpose=True, shape=tensor.shape, seed=right_seed
        )
    elif tuple(right_drm.rank[::-1]) != right_rank:
        raise ValueError(
            f"Right rank {right_rank} does not match the rank of the DRM "
            f"{right_drm.rank}."
        )

    sketch = general_sketch(
        tensor, left_drm, right_drm, method=SketchMethod.orthogonal
    )
    sketched = TensorTrain(sketch.Psi_cores)

    if return_drm:
        # this really is mostly for testing purposes
        return sketched, left_drm, right_drm  # type: ignore
    return sketched
def stream_sketch(
    tensor: Tensor,
    left_rank: TTRank,
    right_rank: TTRank,
    seed: Optional[int] = None,
    left_drm_type: Optional[Type[DRM]] = None,
    right_drm_type: Optional[Type[DRM]] = None,
    left_drm: Optional[DRM] = None,
    right_drm: Optional[DRM] = None,
    return_drm: bool = False,
) -> SketchedTensorTrain:
    """
    Perform a streaming sketch of a tensor.

    Parameters
    ----------
    tensor
        The tensor to sketch.
    left_rank, right_rank
        Ranks of the left and right sketches. One of them must be strictly
        larger than the other in every entry.
    seed
        Seed for the left DRM; the right DRM uses a seed derived from it.
        Drawn at random if ``None``.
    left_drm_type, right_drm_type
        DRM classes to instantiate when the corresponding DRM is not given.
        A missing type falls back to the other side's type, and then to
        ``TensorTrainDRM``.
    left_drm, right_drm
        Pre-built DRMs. Their ranks must match ``left_rank`` and
        ``right_rank`` (the right DRM stores its rank reversed).
    return_drm
        If true, return ``(sketched, left_drm, right_drm)`` (mostly useful
        for testing).

    Raises
    ------
    ValueError
        If neither rank is consistently larger than the other, or a
        supplied DRM has a rank that does not match the requested rank.
    """
    d = len(tensor.shape)

    left_arr = np.array(left_rank)
    right_arr = np.array(right_rank)
    left_rank_bigger = bool(np.all(left_arr > right_arr))
    right_rank_bigger = bool(np.all(left_arr < right_arr))
    if not left_rank_bigger and not right_rank_bigger:
        raise ValueError(
            f"Left ranks or right ranks must be conistently larger or smaller "
            f"than the other. Left rank: {left_rank}, "
            f"right rank: {right_rank}"
        )

    if seed is None:
        seed = np.mod(hash(np.random.uniform()), 2**32)

    if left_drm is None:
        if left_drm_type is None:
            if right_drm_type is not None:
                left_drm_type = right_drm_type
            else:
                left_drm_type = TensorTrainDRM
        # The smaller of the two ranks is the one that gets trimmed.
        left_rank = process_tt_rank(
            left_rank, tensor.shape, trim=right_rank_bigger
        )
        left_drm = left_drm_type(
            left_rank, transpose=False, shape=tensor.shape, seed=seed
        )
    elif left_drm.rank != left_rank:
        raise ValueError(
            f"Left rank {left_rank} does not match the rank of the DRM "
            f"{left_drm.rank}."
        )

    if right_drm is None:
        if right_drm_type is None:
            if left_drm_type is not None:
                right_drm_type = left_drm_type
            else:
                right_drm_type = TensorTrainDRM
        right_rank = process_tt_rank(
            right_rank, tensor.shape, trim=left_rank_bigger
        )
        # Offset the seed by the tensor order so the right DRM differs from
        # the left one. ``hash(str(d))`` was used before, but string hashes
        # are salted per interpreter process, which made the right DRM
        # irreproducible across runs even with a fixed ``seed``.
        right_seed = np.mod(seed + d, 2**32)
        right_drm = right_drm_type(
            right_rank, transpose=True, shape=tensor.shape, seed=right_seed
        )
    elif tuple(right_drm.rank[::-1]) != right_rank:
        raise ValueError(
            f"Right rank {right_rank} does not match the rank of the DRM "
            f"{right_drm.rank}."
        )

    sketch = general_sketch(
        tensor, left_drm, right_drm, method=SketchMethod.streaming
    )
    sketched = SketchedTensorTrain(sketch, left_drm, right_drm)

    if return_drm:
        # this really is mostly for testing purposes
        return sketched, left_drm, right_drm  # type: ignore
    return sketched
@dataclass
class SketchedTensorTrain(Tensor):
    """
    Container for storing the output of the streaming sketch

    Stores the result of the sketch as well as the DRMs used for the sketching.
    Can be cheaply converted to a tensor train, or the sketch can be
    efficiently updated using the ``__add__`` method.
    """

    # The raw sketch (Psi cores and Omega matrices).
    sketch_: SketchContainer
    # DRMs that produced ``sketch_``; reused when the sketch is updated.
    left_drm: DRM
    right_drm: DRM

    @property
    def left_rank(self) -> Tuple[int, ...]:
        """Rank of the left DRM."""
        return self.left_drm.rank

    @property
    def right_rank(self) -> Tuple[int, ...]:
        """Rank of the right DRM, reversed relative to how the DRM stores it."""
        return self.right_drm.rank[::-1]

    @property
    def Psi_cores(self) -> ArrayList:
        """The Psi cores of the underlying sketch."""
        return self.sketch_.Psi_cores

    @property
    def size(self) -> int:
        """Total number of entries stored in the Psi cores and Omega matrices."""
        total_Psi_size = sum(Psi.size for Psi in self.Psi_cores)
        total_Omega_size = sum(Omega.size for Omega in self.Omega_mats)
        return total_Psi_size + total_Omega_size

    @property
    def Omega_mats(self) -> ArrayList:
        """The Omega matrices of the underlying sketch."""
        return self.sketch_.Omega_mats

    def __post_init__(self):
        # The sketched tensor has the shape recorded by its sketch.
        self.shape = self.sketch_.shape

    def C_cores(self, direction="auto") -> ArrayList:
        """Assemble tensor-train cores from the sketch.

        ``direction`` is forwarded to ``assemble_sketched_tt``.
        """
        return assemble_sketched_tt(self.sketch_, direction=direction)

    @property
    def T(self) -> SketchedTensorTrain:
        """Transpose: transposed sketch with the two (transposed) DRMs swapped."""
        new_sketch = self.sketch_.T
        return self.__class__(new_sketch, self.right_drm.T, self.left_drm.T)

    def to_tt(self) -> TensorTrain:
        """Convert to a ``TensorTrain`` built from ``C_cores()``."""
        return TensorTrain(self.C_cores())

    def to_numpy(self) -> npt.NDArray[np.float64]:
        """Convert to a dense array by going through ``to_tt``."""
        return self.to_tt().to_numpy()

    def __repr__(self) -> str:
        return (
            f"<Sketched tensor train of shape {self.shape} with left-rank "
            f"{self.left_rank} and right-rank {self.right_rank} "
            f"at {hex(id(self))}>"
        )

    def __add__(self, other: Tensor) -> SketchedTensorTrain:
        """Sketch ``other`` with this object's DRMs and add the two sketches."""
        other_sketch = stream_sketch(
            other,
            self.left_rank,
            self.right_rank,
            left_drm=self.left_drm,
            right_drm=self.right_drm,
        )
        new_sketch = self.sketch_ + other_sketch.sketch_
        return self.__class__(new_sketch, self.left_drm, self.right_drm)

    def increase_rank(
        self,
        tensor: Tensor,
        new_left_rank: TTRank,
        new_right_rank: TTRank,
    ) -> SketchedTensorTrain:
        """Increase the rank of the approximation by performing a new sketch.

        Requires DRM with support for the ``CanIncreaseRank`` protocol, which
        currently is only supported by ``SparseGaussianDRM``. The DRMs must
        also support ``CanSlice``, since only the blocks not covered by the
        existing sketch are computed.

        Parameters
        ----------
        tensor
            The tensor that was sketched; it is sketched again for the new
            blocks.
        new_left_rank, new_right_rank
            Target ranks, processed by ``process_tt_rank`` with
            ``trim=False``.

        Raises
        ------
        ValueError
            If either DRM does not support both slicing and rank increase.
        """
        new_left_rank = process_tt_rank(new_left_rank, tensor.shape, trim=False)
        new_right_rank = process_tt_rank(
            new_right_rank, tensor.shape, trim=False
        )

        for drm in (self.left_drm, self.right_drm):
            # ``increase_rank`` is called on the DRM below, so checking
            # ``CanSlice`` alone would let an unsupported DRM fail later with
            # an AttributeError instead of this ValueError.
            supported = isinstance(drm, CanSlice) and isinstance(
                drm, CanIncreaseRank
            )
            if not supported:
                drm_name = drm.__class__.__name__
                raise ValueError(
                    f"Increasing rank is not supported for DRM {drm_name}"
                )

        n_dims = len(tensor.shape)
        # Rank boundaries of the blocks: zero -> current rank -> new rank.
        left_rank_slices = [
            (0,) * (n_dims - 1),
            self.left_drm.rank,
            new_left_rank,
        ]
        right_rank_slices = [
            (0,) * (n_dims - 1),
            self.right_drm.rank[::-1],
            new_right_rank,
        ]
        left_drm = self.left_drm.increase_rank(new_left_rank)  # type: ignore
        right_drm = self.right_drm.increase_rank(new_right_rank)  # type: ignore

        # Block (0, 0) is the sketch we already have; only compute the rest.
        sketch_dict = _blocked_stream_sketch_components(
            tensor,
            left_drm,
            right_drm,
            left_rank_slices,
            right_rank_slices,
            excluded_entries=[(0, 0)],
        )
        sketch_dict[(0, 0)] = self.sketch_
        sketch = _assemble_blocked_stream_sketches(
            left_rank_slices, right_rank_slices, tensor.shape, sketch_dict
        )

        return self.__class__(sketch, left_drm, right_drm)

    def __mul__(self, other: float) -> SketchedTensorTrain:
        """Scale the sketch by ``other``, keeping the same DRMs."""
        return self.__class__(
            self.sketch_ * other, self.left_drm, self.right_drm
        )

    def dot(self, other: Tensor, reverse=False) -> float:
        """Dot product with ``other``, computed via the tensor-train form."""
        return self.to_tt().dot(other, reverse)
def _blocked_stream_sketch_components(
    tensor: Tensor,
    left_rm: CanSlice,
    right_drm: CanSlice,
    left_rank_slices: List[Tuple[int, ...]],
    right_rank_slices: List[Tuple[int, ...]],
    excluded_entries: Optional[Sequence[Tuple[int, int]]] = None,
) -> BlockedSketch:
    """Compute the streaming sketch of ``tensor`` for every pair of DRM blocks.

    Consecutive entries of ``left_rank_slices`` (resp. ``right_rank_slices``)
    delimit the blocks the left (resp. right) DRM is sliced into. The result
    maps ``(i, j)`` to the sketch obtained with left block ``i`` and right
    block ``j``; pairs listed in ``excluded_entries`` are skipped.
    """
    skipped = [] if excluded_entries is None else excluded_entries

    # Slice both DRMs up front, one slice per pair of consecutive boundaries.
    left_blocks = []
    for start, stop in zip(left_rank_slices, left_rank_slices[1:]):
        left_blocks.append(left_rm.slice(start, stop))
    right_blocks = []
    for start, stop in zip(right_rank_slices, right_rank_slices[1:]):
        right_blocks.append(right_drm.slice(start, stop))

    blocks = {}
    for i, left_block in enumerate(left_blocks):
        for j, right_block in enumerate(right_blocks):
            if (i, j) not in skipped:
                blocks[(i, j)] = general_sketch(
                    tensor,
                    left_block,
                    right_block,
                    method=SketchMethod.streaming,
                )
    return blocks
def assemble_sketched_tt(
    sketch: SketchContainer,
    direction="auto",
) -> ArrayList:
    """Reconstruct tensor-train cores from the Psi and Omega parts of a sketch.

    ``direction`` is ``"right"``, ``"left"`` or ``"auto"``. With ``"auto"``
    the direction is ``"left"`` when every left rank exceeds the matching
    right rank, and ``"right"`` otherwise. Any other value raises
    ``ValueError``.
    """
    cores = []

    if direction == "auto":
        left_is_larger = np.all(
            np.array(sketch.left_rank) > np.array(sketch.right_rank)
        )
        if left_is_larger:
            direction = "left"
        else:
            direction = "right"

    if direction not in ("right", "left"):
        raise ValueError(f"Unknown direction {direction}")

    if direction == "right":
        # All but the last Psi core get multiplied from the right by the
        # pseudo-inverse of the matching Omega; the last core is used as is.
        for Psi, Omega in zip(sketch.Psi_cores[:-1], sketch.Omega_mats):
            dims = Psi.shape
            flat = Psi.reshape(dims[0] * dims[1], dims[2])
            try:
                product = right_mul_pinv(flat, Omega)
            except ValueError:
                print(Psi.shape, Omega.shape)
                raise
            cores.append(product.reshape(dims[0], dims[1], Omega.shape[0]))
        cores.append(sketch.Psi_cores[-1])
        return cores

    # direction == "left": the first core is used as is, the others get
    # multiplied from the left by the pseudo-inverse of the matching Omega.
    cores.append(sketch.Psi_cores[0])
    for Psi, Omega in zip(sketch.Psi_cores[1:], sketch.Omega_mats):
        dims = Psi.shape
        flat = Psi.reshape(dims[0], dims[1] * dims[2])
        try:
            product = left_mul_pinv(Omega, flat)
        except ValueError:
            print(Psi.shape, Omega.shape)
            raise
        cores.append(product.reshape(Omega.shape[1], dims[1], dims[2]))
    return cores
def _assemble_blocked_stream_sketches(
    left_rank_slices: List[Tuple[int, ...]],
    right_rank_slices: List[Tuple[int, ...]],
    shape: Tuple[int, ...],
    sketch_dict: BlockedSketch,
) -> SketchContainer:
    """Paste per-block sketches into one sketch of the full rank.

    A zero sketch with the last entries of ``left_rank_slices`` and
    ``right_rank_slices`` as ranks is created, and each block ``(i, j)`` of
    ``sketch_dict`` is written into the region delimited by the ``i``-th and
    ``i + 1``-th left boundaries and the ``j``-th and ``j + 1``-th right
    boundaries.
    """
    full_left_rank = tuple(left_rank_slices[-1])
    full_right_rank = tuple(right_rank_slices[-1])
    assembled = SketchContainer.zero(shape, full_left_rank, full_right_rank)

    for (i, j), block in sketch_dict.items():
        # Pad the boundaries so that the outer edges of the first and last
        # core span the single index 0.
        left_start = (0,) + left_rank_slices[i]
        left_stop = (1,) + left_rank_slices[i + 1]
        right_start = right_rank_slices[j] + (0,)
        right_stop = right_rank_slices[j + 1] + (1,)

        for mu, Psi in enumerate(block.Psi_cores):
            rows = slice(left_start[mu], left_stop[mu])
            cols = slice(right_start[mu], right_stop[mu])
            assembled.Psi_cores[mu][rows, :, cols] = Psi

        for mu, Omega in enumerate(block.Omega_mats):
            rows = slice(left_start[mu + 1], left_stop[mu + 1])
            cols = slice(right_start[mu], right_stop[mu])
            assembled.Omega_mats[mu][rows, cols] = Omega

    return assembled
def get_drm_capabilities():
    """Report, for every DRM class in ``ALL_DRM``, which capabilities it has.

    Returns a dict keyed by DRM class name; each value is a dict mapping a
    capability class name to whether the DRM is a subclass of it.
    """
    checked_capabilities = (
        CanSlice,
        CanIncreaseRank,
        CansketchSparse,
        CansketchDense,
        CansketchTT,
    )
    return {
        drm.__name__: {
            capability.__name__: issubclass(drm, capability)
            for capability in checked_capabilities
        }
        for drm in ALL_DRM
    }
def blocked_stream_sketch(
    tensor: Tensor,
    left_drm: CanSlice,
    right_drm: CanSlice,
    left_rank_slices: List[Tuple[int, ...]],
    right_rank_slices: List[Tuple[int, ...]],
) -> SketchContainer:
    """Do a blocked sketch.

    Its use is mainly theoretical, since this would only be faster in a
    distributed setting (which isn't properly supported).

    Both DRMs must support ``CanSlice``; otherwise ``ValueError`` is raised.
    """
    for drm in (left_drm, right_drm):
        if isinstance(drm, CanSlice):
            continue
        drm_name = drm.__class__.__name__
        raise ValueError(f"Blocked sketch not supported for DRM {drm_name}")

    blocks = _blocked_stream_sketch_components(
        tensor,
        left_drm,
        right_drm,
        left_rank_slices,
        right_rank_slices,
    )
    return _assemble_blocked_stream_sketches(
        left_rank_slices,
        right_rank_slices,
        tensor.shape,
        blocks,
    )