Merge pull request #13 from iongabrielion/amen_mm
amen_mm
ion-g-ion committed Aug 30, 2023
2 parents 245e002 + 03a20b5 commit 20c9816
Showing 5 changed files with 629 additions and 49 deletions.
22 changes: 22 additions & 0 deletions test_amen_mult.py
@@ -0,0 +1,22 @@
import torchtt as tntt
import torch as tn

A = tntt.random([(4, 4), (5, 5), (6, 6), (3, 3)], [1, 2, 3, 2, 1])
B = tntt.random([(4, 4), (5, 5), (6, 6), (3, 3)], [1, 3, 2, 2, 1])

Cr = A @ B

C = tntt.amen_mm(A, B, kickrank=8, verbose=False)


print((C-Cr).norm()/Cr.norm())

A = tntt.random([(4, 4), (5, 5), (6, 6), (3, 3)], [1, 2, 3, 2, 1])
B = tntt.random([(4, 3), (5, 2), (6, 5), (3, 6)], [1, 3, 2, 2, 1])

Cr = A @ B

C = tntt.amen_mm(A, B, kickrank=8, verbose=False)


print((C-Cr).norm()/Cr.norm())
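
The new test script exercises tntt.amen_mm against the plain @ product. For completeness, a minimal sketch of the companion matrix-vector routine tntt.amen_mv, also added in this change, follows the same pattern (the shapes and ranks below are illustrative, not taken from the commit):

import torchtt as tntt

# TT matrix and a TT vector whose mode sizes match the column sizes of A
A = tntt.random([(4, 4), (5, 5), (6, 6), (3, 3)], [1, 2, 3, 2, 1])
x = tntt.random([4, 5, 6, 3], [1, 3, 2, 2, 1])

yr = A @ x               # reference matvec in the TT format
y = tntt.amen_mv(A, x)   # AMEn-accelerated matvec

# the relative error should be close to machine precision
print((y - yr).norm() / yr.norm())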
102 changes: 71 additions & 31 deletions tests/test_algebra_2.py
@@ -7,11 +7,13 @@
import torch as tn
import numpy as np

err_rel = lambda t, ref : tn.linalg.norm(t-ref).numpy() / tn.linalg.norm(ref).numpy() if ref.shape == t.shape else np.inf



def err_rel(t, ref): return tn.linalg.norm(t-ref).numpy() / \
tn.linalg.norm(ref).numpy() if ref.shape == t.shape else np.inf


class TestLinalgAdvanced(unittest.TestCase):

basic_dtype = tn.float64

def test_dmrg_hadamard(self):
@@ -34,21 +36,20 @@ def test_dmrg_hadamard(self):
rel_error = (y-yf).norm().numpy()/y.norm().numpy()

self.assertLess(rel_error,1e-12,"DMRG elementwise multiplication.")



def test_dmrg_matvec(self):
"""
Test the fast matrix vector product using DMRG iterations.
"""
n = 32
A = tntt.random([(n,n)]*8,[1]+7*[3]+[1], dtype = tn.complex128)
Am = A + A
x = tntt.random([n]*8,[1]+7*[5]+[1], dtype = tn.complex128)
A = tntt.random([(n, n)]*8, [1]+7*[3]+[1], dtype=tn.complex128)
Am = A + A

x = tntt.random([n]*8, [1]+7*[5]+[1], dtype=tn.complex128)
xm = x + x
xm = xm + xm
# conventional method

# conventional method
y = 8 * (A @ x).round(1e-12)

# dmrg matvec
@@ -80,35 +81,74 @@ def test_amen_division(self):
"""
Test the division between tensors performed with AMEN optimization.
"""
N = [7,8,9,10]
xs = tntt.meshgrid([tn.linspace(0,1,n, dtype = self.basic_dtype) for n in N])
N = [7, 8, 9, 10]
xs = tntt.meshgrid(
[tn.linspace(0, 1, n, dtype=self.basic_dtype) for n in N])
x = xs[0]+xs[1]+xs[2]+xs[3]+xs[1]*xs[2]+(1-xs[3])*xs[2]+1
x = x.round(0)
y = tntt.ones(x.N, dtype = self.basic_dtype)
y = tntt.ones(x.N, dtype=self.basic_dtype)

a = y/x
b = 1/x
c = tn.tensor(1.0)/x

self.assertLess(err_rel(a.full(),y.full()/x.full()),1e-11,"AMEN division problem: TT and TT.")
self.assertLess(err_rel(b.full(),1/x.full()),1e-11,"AMEN division problem: scalar and TT.")
self.assertLess(err_rel(c.full(),1/x.full()),1e-11,"AMEN division problem: scalar and TT part 2.")


self.assertLess(err_rel(a.full(), y.full()/x.full()),
1e-11, "AMEN division problem: TT and TT.")
self.assertLess(err_rel(b.full(), 1/x.full()), 1e-11,
"AMEN division problem: scalar and TT.")
self.assertLess(err_rel(c.full(), 1/x.full()), 1e-11,
"AMEN division problem: scalar and TT part 2.")

def test_amen_division_preconditioned(self):
"""
Test the elementwise division using AMEN (use a preconditioner for the local subsystem).
"""
N = [7,8,9,10]
xs = tntt.meshgrid([tn.linspace(0,1,n, dtype = self.basic_dtype) for n in N])
N = [7, 8, 9, 10]
xs = tntt.meshgrid(
[tn.linspace(0, 1, n, dtype=self.basic_dtype) for n in N])
x = xs[0]+xs[1]+xs[2]+xs[3]+xs[1]*xs[2]+(1-xs[3])*xs[2]+1
x = x.round(0)
y = tntt.ones(x.N)

a = tntt.elementwise_divide(y,x,preconditioner = 'c')


self.assertLess(err_rel(a.full(),y.full()/x.full()),1e-11,"AMEN division problem (preconditioner): TT and TT.")



a = tntt.elementwise_divide(y, x, preconditioner='c')

self.assertLess(err_rel(a.full(), y.full()/x.full()), 1e-11,
"AMEN division problem (preconditioner): TT and TT.")

def test_amen_mv(self):
"""
Test the AMEn matvec.
"""

A = tntt.randn([(3, 4), (5, 6), (7, 8), (2, 3)], [1, 2, 2, 3, 1])
x = tntt.randn([4, 6, 8, 3], [1, 4, 3, 3, 1])

Cr = 25 * A @ x

A = A + A + A + A + A
x = x + x + x + x + x

C = tntt.amen_mv(A, x)

self.assertLess((C-Cr).norm()/Cr.norm(), 1e-11, "AMEN matvec.")

def test_amen_mm(self):
"""
Test the AMEn matmat.
"""

A = tntt.randn([(3, 4), (5, 6), (7, 8), (2, 3)], [1, 2, 2, 3, 1])
B = tntt.randn([(4, 2), (6, 4), (8, 5), (3, 7)], [1, 4, 3, 3, 1])

Cr = 25 * A @ B

A = A + A + A + A + A
B = B + B + B + B + B

C = tntt.amen_mm(A, B)

self.assertLess((C-Cr).norm()/Cr.norm(), 1e-11, "AMEN matmul.")


if __name__ == '__main__':
unittest.main()
unittest.main()
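
A note on the reference values in the new AMEn tests: both operands are replaced by five-fold sums, which inflates the TT ranks without changing the underlying operator or vector except by a factor of 5, so the exact product is 5 * 5 = 25 times the original one; hence Cr is precomputed as 25 * A @ x (or 25 * A @ B). A quick sanity check of that identity on plain dense tensors (purely illustrative, not part of the test suite):

import torch as tn

A = tn.rand(3, 4)
x = tn.rand(4)

lhs = (5 * A) @ (5 * x)       # the "inflated" product used in the tests
rhs = 25 * (A @ x)            # the precomputed reference
print(tn.allclose(lhs, rhs))  # True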
26 changes: 23 additions & 3 deletions torchtt/__init__.py
@@ -9,13 +9,12 @@
"""


from ._tt_base import TT
from ._extras import eye, zeros, kron, ones, random, randn, reshape, meshgrid , dot, elementwise_divide, numel, rank1TT, bilinear_form, diag, permute, load, save, cat, pad, shape_mn_to_tuple, shape_tuple_to_mn
# from .torchtt import TT, eye, zeros, kron, ones, random, randn, reshape, meshgrid , dot, elementwise_divide, numel, rank1TT, bilinear_form, diag, permute, load, save, cat, pad
from ._dmrg import dmrg_hadamard

__all__ = ['TT', 'eye', 'zeros', 'kron', 'ones', 'random', 'randn', 'reshape', 'meshgrid', 'dot', 'elementwise_divide', 'numel', 'rank1TT', 'bilinear_form', 'diag', 'permute', 'load', 'save', 'cat', 'pad', 'shape_mn_to_tuple', 'shape_tuple_to_mn', 'dmrg_hadamard']

from ._amen import amen_mm, amen_mv
from . import solvers
from . import grad
# from .grad import grad, watch, unwatch
@@ -24,3 +23,24 @@
from . import nn
from . import cpp
# from .errors import *

try:
import torchttcpp
_flag_use_cpp = True
except:
import warnings
warnings.warn(
"\x1B[33m\nC++ implementation not available. Using pure Python.\n\033[0m")
_flag_use_cpp = False


def cpp_enabled():
"""
Is the C++ backend enabled?
Returns:
bool: the flag
"""
return _flag_use_cpp

__all__ = ['TT', 'eye', 'zeros', 'kron', 'ones', 'random', 'randn', 'reshape', 'meshgrid', 'dot', 'elementwise_divide', 'numel', 'rank1TT', 'bilinear_form', 'diag', 'permute', 'load', 'save', 'cat', 'amen_mm', 'amen_mv', 'cpp_enabled', 'pad', 'shape_mn_to_tuple', 'shape_tuple_to_mn', 'dmrg_hadamard']
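
The updated __init__.py also tries to import the compiled torchttcpp extension and records the outcome in _flag_use_cpp, exposed through cpp_enabled(). A minimal sketch of querying the flag after import (assuming the package is importable as torchtt):

import torchtt as tntt

# True if the compiled torchttcpp backend was found at import time;
# otherwise the warning above is emitted and pure Python is used.
if tntt.cpp_enabled():
    print("C++ backend active")
else:
    print("falling back to the pure Python implementation")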