# Source code for tt_sketch.drm.tensor_train_drm

from typing import Optional, Tuple, Union

import numpy as np
from tt_sketch.drm_base import CanSlice, handle_transpose
from tt_sketch.sketching_methods.abstract_methods import (
    CansketchCP,
    CansketchDense,
    CansketchSparse,
    CansketchTT,
    CanSketchTucker,
)
from tt_sketch.tensor import (
    CPTensor,
    DenseTensor,
    SparseTensor,
    TensorTrain,
    TuckerTensor,
)
from tt_sketch.utils import ArrayGenerator, ArrayList


# TODO: Store DRM as a tensor.TensorTrain
class TensorTrainDRM(
    CansketchSparse,
    CansketchTT,
    CansketchCP,
    CanSlice,
    CansketchDense,
    CanSketchTucker,
):
    """
    Tensor train DRM.

    Sketches with partial contractions of a fixed TT.

    Every ``sketch_*`` method is a generator. It contracts the DRM cores from
    the left, one mode at a time, and yields one matrix per step. The DRM rank
    axis of each yielded matrix is restricted to
    ``self.rank_min[mu]:self.rank_max[mu]``. Only the yielded matrix is
    sliced; the running contraction keeps the full stored rank.
    """

    # DRM cores. Core ``mu`` is indexed as (left rank, mode index, right rank);
    # ``sketch_sparse`` gathers along the middle axis with mode indices.
    cores: ArrayList

    def __init__(
        self,
        rank: Union[Tuple[int, ...], int],
        shape: Tuple[int, ...],
        transpose: bool,
        seed: Optional[int] = None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        rank
            DRM rank(s), forwarded to the base-class constructor.
        shape
            Shape, forwarded to the base-class constructor. The random TT is
            drawn for ``self.shape``, reversed when ``transpose`` is set.
        transpose
            Whether the DRM acts in transposed (reversed-mode) order.
        seed
            Forwarded to the base class. The random TT is drawn with
            ``self.seed``.
        **kwargs
            Forwarded to the base class. If it contains ``"cores"``, those
            cores are stored as-is and no random TT is drawn.
        """
        super().__init__(rank, shape, transpose, seed=seed, **kwargs)
        tt_shape = self.shape[::-1] if transpose else self.shape
        # NOTE(review): the rank is used unreversed in both orientations.
        # Presumably the base class already stores ``true_rank`` in the
        # matching order; confirm.
        tt_rank = self.true_rank
        if "cores" in kwargs:
            self.cores = kwargs["cores"]
        else:
            tt = TensorTrain.random(
                tt_shape, tt_rank, self.seed, norm_goal="norm-preserve"
            )
            # The last core is dropped: the sketches below only contract
            # the first d - 1 modes.
            self.cores = tt.cores[:-1]

    @handle_transpose
    def sketch_sparse(self, tensor: SparseTensor) -> ArrayGenerator:
        """
        Sketch a sparse tensor.

        For each stored core ``mu``, the DRM is contracted over modes
        ``0..mu`` at the positions listed in ``tensor.indices``. This gives one
        row per listed entry. Yields the transpose of that matrix, with its
        rank columns restricted to ``rank_min[mu]:rank_max[mu]``.
        """
        for mu, core in enumerate(self.cores):
            # Gather the core at the mode-``mu`` indices of every entry:
            # shape (left rank, n_entries, right rank).
            core_slice = core[:, tensor.indices[mu], :]
            if mu == 0:
                # Drop the leading left-rank axis (must have size 1 for this
                # reshape to succeed).
                lr_contract = core_slice.reshape(core_slice.shape[1:])
            else:
                # Per-entry contraction of the running (n_entries, rank)
                # matrix with this core's slices over the shared rank index.
                lr_contract = np.einsum("ijk,ji->jk", core_slice, lr_contract)
            sketch_mat = lr_contract[:, self.rank_min[mu] : self.rank_max[mu]]
            yield sketch_mat.T

    @handle_transpose
    def sketch_tt(self, tensor: TensorTrain) -> ArrayGenerator:
        """
        Sketch a tensor train.

        Contracts the cores of ``tensor`` against the DRM cores from the left,
        for ``mu = 0, ..., d - 2``. Yields matrices indexed as
        (tensor right rank, DRM right rank restricted to
        ``rank_min[mu]:rank_max[mu]``).
        """
        n_dims = len(self.shape)
        for mu in range(n_dims - 1):
            tensor_core = tensor.cores[mu]
            drm_core = self.cores[mu]
            if mu == 0:
                # Contract over both the left-rank axis and the mode index.
                lr_contract = np.einsum("ijk,ijl->kl", tensor_core, drm_core)
            else:
                lr_contract = np.einsum(
                    "ij,ikl,jkm->lm",
                    lr_contract,
                    tensor_core,
                    drm_core,
                    optimize="optimal",
                )
            yield lr_contract[:, self.rank_min[mu] : self.rank_max[mu]]

    @handle_transpose
    def sketch_cp(self, tensor: CPTensor) -> ArrayGenerator:
        """
        Sketch a CP tensor.

        ``tensor.cores[mu]`` is contracted as (mode index, CP rank) against the
        DRM core's mode axis, accumulating from the left for
        ``mu = 0, ..., d - 2``. Yields matrices indexed as (CP rank, DRM right
        rank restricted to ``rank_min[mu]:rank_max[mu]``).
        """
        n_dims = len(self.shape)
        for mu in range(n_dims - 1):
            tensor_core = tensor.cores[mu]
            drm_core = self.cores[mu]
            if mu == 0:
                lr_contract = np.einsum("ij,lik->jk", tensor_core, drm_core)
            else:
                # The CP rank index ``i`` is shared (not summed) across modes.
                lr_contract = np.einsum(
                    "ij,ki,jkl->il",
                    lr_contract,
                    tensor_core,
                    drm_core,
                    optimize="optimal",
                )
            yield lr_contract[:, self.rank_min[mu] : self.rank_max[mu]]

    @handle_transpose
    def sketch_dense(self, tensor: DenseTensor) -> ArrayGenerator:
        """
        Yield the DRM matrices used to sketch a dense tensor.

        ``tensor`` itself is not read here. The DRM cores ``0..mu`` are
        contracted into one matrix whose rows run over the flattened leading
        axes (left rank of core 0, then modes ``0..mu``). That matrix is
        yielded transposed, with its rank columns restricted to
        ``rank_min[mu]:rank_max[mu]``. This applies the same slicing as the
        other sketch methods.
        """
        n_dims = len(self.shape)
        first_core = self.cores[0]
        partial_contraction = first_core.reshape(-1, first_core.shape[-1])
        yield partial_contraction[:, self.rank_min[0] : self.rank_max[0]].T
        for mu in range(1, n_dims - 1):
            partial_contraction = np.einsum(
                "ij,jkl->ikl", partial_contraction, self.cores[mu]
            )
            partial_contraction = partial_contraction.reshape(
                -1, partial_contraction.shape[-1]
            )
            # Only the yielded matrix is sliced; the full rank is carried
            # into the next contraction.
            yield partial_contraction[:, self.rank_min[mu] : self.rank_max[mu]].T

    @handle_transpose
    def sketch_tucker(self, tensor: TuckerTensor) -> ArrayGenerator:
        """
        Sketch a Tucker tensor.

        Each DRM core's mode axis is first contracted with the transposed
        Tucker factor ``tensor.factors[mu].T``. The result is then chained
        onto the running left contraction. Yields matrices whose rows run over
        the flattened first axes of the factors ``0..mu``. Their columns are
        the DRM rank restricted to ``rank_min[mu]:rank_max[mu]``.
        """
        n_dims = len(self.shape)
        partial_contraction = np.einsum(
            "ijk,jl->ilk", self.cores[0], tensor.factors[0].T
        )
        # Flatten using the array's own rank axis rather than ``self.rank[0]``.
        # The stored cores are drawn with ``true_rank`` (see ``__init__``), so
        # the two may differ when the DRM is sliced.
        partial_contraction = partial_contraction.reshape(
            -1, partial_contraction.shape[-1]
        )
        yield partial_contraction[:, self.rank_min[0] : self.rank_max[0]]
        for mu in range(1, n_dims - 1):
            core_reduced = np.einsum(
                "jkl,km->jml", self.cores[mu], tensor.factors[mu].T
            )
            partial_contraction = np.einsum(
                "ij,jml->iml", partial_contraction, core_reduced
            )
            partial_contraction = partial_contraction.reshape(
                -1, partial_contraction.shape[-1]
            )
            yield partial_contraction[:, self.rank_min[mu] : self.rank_max[mu]]