Skip to content

Commit 23749d8

Browse files
committed
dev(identity): add support for identity
- Add support for identity - Add self multiplication verification for `identity` - Add taylor expansion verification for `exponentiation` - Add assertation test for `exponentiation` and `identity` - Modify the arrow processing logic for `expoenntiation`
1 parent 96224e9 commit 23749d8

File tree

3 files changed

+182
-23
lines changed

3 files changed

+182
-23
lines changed

grassmann_tensor/tensor.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -744,17 +744,18 @@ def _check_pairs_coverage(self, pairs: tuple[tuple[int, ...], tuple[int, ...]])
744744
def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
745745
tensor, left_legs, right_legs = self._group_edges(pairs)
746746

747-
arrow_order = (False, True)
748-
edges_to_reverse = tuple(
749-
i for i, arrow in enumerate(arrow_order) if tensor.arrow[i] != arrow
747+
assert tensor.arrow in ((False, True), (True, False)), (
748+
f"Exponentiation requires arrow (False, True) or (True, False), but got {tensor.arrow}"
750749
)
751-
if edges_to_reverse:
752-
tensor = tensor.reverse(edges_to_reverse)
750+
751+
tensor_reverse_flag = tensor.arrow != (False, True)
752+
if tensor_reverse_flag:
753+
tensor = tensor.reverse((0, 1))
753754

754755
left_dim, right_dim = tensor.tensor.shape
755756

756757
assert left_dim == right_dim, (
757-
f"Exponential requires a square operator, but got {left_dim} x {right_dim}."
758+
f"Exponentiation requires a square operator, but got {left_dim} x {right_dim}."
758759
)
759760

760761
(even_left, odd_left) = tensor.edges[0]
@@ -774,8 +775,8 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
774775

775776
tensor_exp = dataclasses.replace(tensor, _tensor=tensor_exp)
776777

777-
if edges_to_reverse:
778-
tensor_exp = tensor_exp.reverse(tuple(edges_to_reverse))
778+
if tensor_reverse_flag:
779+
tensor_exp = tensor_exp.reverse((0, 1))
779780

780781
order = left_legs + right_legs
781782
edges_after_permute = tuple(self.edges[i] for i in order)
@@ -787,26 +788,36 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
787788

788789
return tensor_exp
789790

790-
def identity(self, pairs: tuple[int, ...]) -> GrassmannTensor:
791+
def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
791792
tensor, left_legs, right_legs = self._group_edges(pairs)
792793

793-
arrow_order = (False, True)
794-
edges_to_reverse = tuple(
795-
i for i, arrow in enumerate(arrow_order) if tensor.arrow[i] != arrow
794+
assert tensor.arrow in ((False, True), (True, False)), (
795+
f"Identity requires arrow (False, True) or (True, False), but got {tensor.arrow}"
796796
)
797797

798+
tensor_reverse_flag = tensor.arrow != (False, True)
799+
if tensor_reverse_flag:
800+
tensor = tensor.reverse((0, 1))
801+
798802
left_dim, right_dim = tensor.tensor.shape
799803

800804
assert left_dim == right_dim, (
801-
f"Exponential requires a square operator, but got {left_dim} x {right_dim}."
805+
f"Identity requires a square operator, but got {left_dim} x {right_dim}."
806+
)
807+
808+
(even_left, odd_left) = tensor.edges[0]
809+
(even_right, odd_right) = tensor.edges[1]
810+
811+
assert even_left == even_right and odd_left == odd_right, (
812+
f"Parity blocks must be square, but got L=({even_left},{odd_left}), R=({even_right},{odd_right})"
802813
)
803814

804815
I = torch.eye(left_dim, dtype=tensor.tensor.dtype, device=tensor.tensor.device) # noqa: E741
805816

806817
tensor_identity = dataclasses.replace(tensor, _tensor=I)
807818

808-
if edges_to_reverse:
809-
tensor_identity = tensor_identity.reverse(tuple(edges_to_reverse))
819+
if tensor_reverse_flag:
820+
tensor_identity = tensor_identity.reverse((0, 1))
810821

811822
order = left_legs + right_legs
812823
edges_after_permute = tuple(self.edges[i] for i in order)

tests/exponential_test.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import torch
22
import pytest
3+
from typing import TypeAlias
34

45
from grassmann_tensor import GrassmannTensor
56

7+
Tensor: TypeAlias = GrassmannTensor
8+
Pairs: TypeAlias = tuple[tuple[int, ...], tuple[int, ...]]
9+
610

711
def test_exponential() -> None:
812
a = GrassmannTensor(
9-
(True, True, True, True),
13+
(False, True, True, False),
1014
((4, 4), (8, 8), (4, 4), (8, 8)),
1115
torch.randn(8, 16, 8, 16, dtype=torch.float64),
1216
)
@@ -28,5 +32,79 @@ def test_exponential_assertation() -> None:
2832
((2, 2), (4, 4), (8, 8), (16, 16)),
2933
torch.randn(4, 8, 16, 32, dtype=torch.float64),
3034
)
31-
with pytest.raises(AssertionError, match="Exponential requires a square operator"):
35+
with pytest.raises(AssertionError, match="Exponentiation requires arrow"):
3236
a.exponential(((0, 2), (1, 3)))
37+
38+
b = GrassmannTensor(
39+
(False, True, False, True),
40+
((2, 2), (4, 4), (8, 8), (16, 16)),
41+
torch.randn(4, 8, 16, 32, dtype=torch.float64),
42+
)
43+
with pytest.raises(AssertionError, match="Exponentiation requires a square operator"):
44+
b.exponential(((0, 2), (1, 3)))
45+
46+
c = GrassmannTensor(
47+
(False, True, False, True),
48+
((1, 3), (3, 1), (3, 1), (3, 1)),
49+
torch.randn(4, 4, 4, 4, dtype=torch.float64),
50+
)
51+
with pytest.raises(AssertionError, match="Parity blocks must be square"):
52+
c.exponential(((0, 2), (1, 3)))
53+
54+
55+
@pytest.mark.parametrize(
56+
"tensor, pairs",
57+
[
58+
(
59+
GrassmannTensor(
60+
(False, True), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64)
61+
),
62+
((0,), (1,)),
63+
),
64+
(
65+
GrassmannTensor(
66+
(True, False), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64)
67+
),
68+
((0,), (1,)),
69+
),
70+
(
71+
GrassmannTensor(
72+
(False, False, True),
73+
((4, 4), (4, 4), (32, 32)),
74+
torch.randn(8, 8, 64, dtype=torch.float64),
75+
),
76+
((0, 1), (2,)),
77+
),
78+
(
79+
GrassmannTensor(
80+
(False, False, True, True),
81+
((4, 4), (8, 8), (4, 4), (8, 8)),
82+
torch.randn(8, 16, 8, 16, dtype=torch.float64),
83+
),
84+
((0, 1), (2, 3)),
85+
),
86+
],
87+
)
88+
def test_exponential_via_taylor_expansion(
89+
tensor: Tensor,
90+
pairs: Pairs,
91+
) -> None:
92+
tensor_exp = tensor.exponential(pairs)
93+
iter_tensor = tensor.identity(pairs)
94+
iter_tensor, _, _ = iter_tensor._group_edges(pairs)
95+
iter_tensor = iter_tensor.update_mask()
96+
tensor_group_edges, left_legs, right_legs = tensor._group_edges(pairs)
97+
tensor_group_edges = tensor_group_edges.update_mask()
98+
99+
tensor_taylor_expansion = iter_tensor
100+
for i in range(1, 50):
101+
iter_tensor @= tensor_group_edges / i
102+
tensor_taylor_expansion += iter_tensor
103+
104+
order = left_legs + right_legs
105+
edges_after_permute = tuple(tensor.edges[i] for i in order)
106+
tensor_taylor_expansion = tensor_taylor_expansion.reshape(edges_after_permute)
107+
inv_order = tensor._get_inv_order(order)
108+
tensor_taylor_expansion = tensor_taylor_expansion.permute(inv_order)
109+
110+
assert torch.allclose(tensor_taylor_expansion.tensor, tensor_exp.tensor)

tests/identity_test.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,83 @@
1+
import pytest
12
import torch
3+
from typing import TypeAlias
24

35
from grassmann_tensor import GrassmannTensor
46

7+
Tensor: TypeAlias = GrassmannTensor
8+
Pairs: TypeAlias = tuple[tuple[int, ...], tuple[int, ...]]
59

6-
def test_identity() -> None:
10+
11+
def test_identity_assertation() -> None:
712
a = GrassmannTensor(
8-
(False, False, True, True),
9-
((4, 4), (8, 8), (4, 4), (8, 8)),
10-
torch.randn(8, 16, 8, 16, dtype=torch.float64),
13+
(True, True, True, True),
14+
((2, 2), (4, 4), (8, 8), (16, 16)),
15+
torch.randn(4, 8, 16, 32, dtype=torch.float64),
16+
)
17+
with pytest.raises(AssertionError, match="Identity requires arrow"):
18+
a.identity(((0, 2), (1, 3)))
19+
20+
b = GrassmannTensor(
21+
(False, True, False, True),
22+
((2, 2), (4, 4), (8, 8), (16, 16)),
23+
torch.randn(4, 8, 16, 32, dtype=torch.float64),
1124
)
12-
a.identity((0, 1))
13-
a.identity((2, 3))
25+
with pytest.raises(AssertionError, match="Identity requires a square operator"):
26+
b.identity(((0, 2), (1, 3)))
27+
28+
c = GrassmannTensor(
29+
(False, True, False, True),
30+
((1, 3), (3, 1), (3, 1), (3, 1)),
31+
torch.randn(4, 4, 4, 4, dtype=torch.float64),
32+
)
33+
with pytest.raises(AssertionError, match="Parity blocks must be square"):
34+
c.identity(((0, 2), (1, 3)))
35+
36+
37+
@pytest.mark.parametrize(
38+
"tensor, pairs",
39+
[
40+
(
41+
GrassmannTensor(
42+
(False, True), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64)
43+
),
44+
((0,), (1,)),
45+
),
46+
(
47+
GrassmannTensor(
48+
(True, False), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64)
49+
),
50+
((0,), (1,)),
51+
),
52+
(
53+
GrassmannTensor(
54+
(False, False, True),
55+
((4, 4), (4, 4), (32, 32)),
56+
torch.randn(8, 8, 64, dtype=torch.float64),
57+
),
58+
((0, 1), (2,)),
59+
),
60+
(
61+
GrassmannTensor(
62+
(False, False, True, True),
63+
((4, 4), (8, 8), (4, 4), (8, 8)),
64+
torch.randn(8, 16, 8, 16, dtype=torch.float64),
65+
),
66+
((0, 1), (2, 3)),
67+
),
68+
],
69+
)
70+
def test_identity_via_self_multiplication(
71+
tensor: Tensor,
72+
pairs: Pairs,
73+
) -> None:
74+
identity = tensor.identity(pairs)
75+
identity, _, _ = identity._group_edges(pairs)
76+
tensor, _, _ = tensor._group_edges(pairs)
77+
tensor_reverse_flag = tensor.arrow != (False, True)
78+
if tensor_reverse_flag:
79+
identity = identity.reverse((0, 1))
80+
tensor = tensor.reverse((0, 1))
81+
assert torch.allclose((identity @ identity).tensor, identity.tensor)
82+
assert torch.allclose((identity @ tensor).tensor, tensor.tensor)
83+
assert torch.allclose((tensor @ identity).tensor, tensor.tensor)

0 commit comments

Comments
 (0)