Skip to content

Commit ad1b0df

Browse files
committed
Merge pull request #77 from USTC-KnowledgeComputingLab/dev/add-support-for-identity
Add support for identity Close #77
2 parents 74c5fab + da11611 commit ad1b0df

File tree

6 files changed

+412
-72
lines changed

6 files changed

+412
-72
lines changed

grassmann_tensor/tensor.py

Lines changed: 216 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,10 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
308308
if (isinstance(new_shape_check, int) and new_shape_check == 1) or (
309309
new_shape_check == (1, 0)
310310
):
311-
arrow.append(False)
311+
if cursor_plan < len(self.arrow):
312+
arrow.append(self.arrow[cursor_plan])
313+
else:
314+
arrow.append(False)
312315
edges.append((1, 0))
313316
shape.append(1)
314317
cursor_plan += 1
@@ -573,22 +576,18 @@ def _group_edges(
573576
self,
574577
pairs: tuple[int, ...] | tuple[tuple[int, ...], tuple[int, ...]],
575578
) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]:
576-
if (isinstance(pairs, tuple) and len(pairs)) and all(
577-
isinstance(x, tuple) and all(isinstance(i, int) for i in x) for x in pairs
578-
):
579-
left_legs = typing.cast(tuple[int, ...], pairs[0])
580-
right_legs = typing.cast(tuple[int, ...], pairs[1])
581-
else:
582-
left_legs = typing.cast(tuple[int, ...], pairs)
583-
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
579+
return self.group_edges(self, pairs)
584580

585-
assert self._check_pairs_coverage((left_legs, right_legs)), (
586-
f"Input pairs must cover all dimension and disjoint, but got {(left_legs, right_legs)}"
587-
)
581+
@staticmethod
582+
def group_edges(
583+
tensor: GrassmannTensor,
584+
pairs: tuple[int, ...] | tuple[tuple[int, ...], tuple[int, ...]],
585+
) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]:
586+
left_legs, right_legs = GrassmannTensor.get_legs_pair(tensor.tensor.dim(), pairs)
588587

589588
order = left_legs + right_legs
590589

591-
tensor = self.permute(order)
590+
tensor = tensor.permute(order)
592591

593592
left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
594593
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
@@ -597,6 +596,37 @@ def _group_edges(
597596

598597
return tensor, left_legs, right_legs
599598

599+
@staticmethod
600+
def get_legs_pair(
601+
dim: int, pairs: tuple[int, ...] | tuple[tuple[int, ...], tuple[int, ...]]
602+
) -> tuple[tuple[int, ...], tuple[int, ...]]:
603+
def check_pairs_coverage(dim: int, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> bool:
604+
set0 = set(pairs[0])
605+
set1 = set(pairs[1])
606+
607+
are_disjoint = set0.isdisjoint(set1)
608+
609+
is_complete_union = (set0 | set1) == set(range(dim))
610+
611+
no_duplicates = len(pairs[0]) + len(pairs[1]) == dim
612+
613+
return are_disjoint and is_complete_union and no_duplicates
614+
615+
if (isinstance(pairs, tuple) and len(pairs)) and all(
616+
isinstance(x, tuple) and all(isinstance(i, int) for i in x) for x in pairs
617+
):
618+
left_legs = typing.cast(tuple[int, ...], pairs[0])
619+
right_legs = typing.cast(tuple[int, ...], pairs[1])
620+
else:
621+
left_legs = typing.cast(tuple[int, ...], pairs)
622+
right_legs = tuple(i for i in range(dim) if i not in left_legs)
623+
624+
assert check_pairs_coverage(dim, (left_legs, right_legs)), (
625+
f"Input pairs must cover all dimension and disjoint, but got {(left_legs, right_legs)}"
626+
)
627+
628+
return left_legs, right_legs
629+
600630
def svd(
601631
self,
602632
free_names_u: tuple[int, ...],
@@ -622,7 +652,25 @@ def svd(
622652
if isinstance(cutoff, tuple):
623653
assert len(cutoff) == 2, "The length of cutoff must be 2 if cutoff is a tuple."
624654

625-
tensor, left_legs, right_legs = self._group_edges(free_names_u)
655+
left_legs, right_legs = GrassmannTensor.get_legs_pair(self.tensor.dim(), free_names_u)
656+
order = left_legs + right_legs
657+
tensor = self.permute(order)
658+
659+
arrow_reverse = tuple(i for i, current in enumerate(tensor.arrow) if current)
660+
if arrow_reverse:
661+
tensor = tensor.reverse(arrow_reverse).reverse(arrow_reverse).reverse(arrow_reverse)
662+
663+
left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
664+
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
665+
tensor = tensor.reshape((left_dim, right_dim))
666+
667+
origin_arrow_left = tuple(self.arrow[i] for i in left_legs)
668+
origin_arrow_right = tuple(self.arrow[i] for i in right_legs)
669+
670+
arrow_reverse_left = tuple(i for i, current in enumerate(origin_arrow_left) if current)
671+
arrow_reverse_right = tuple(
672+
i + 1 for i, current in enumerate(origin_arrow_right) if current
673+
)
626674

627675
(even_left, odd_left) = tensor.edges[0]
628676
(even_right, odd_right) = tensor.edges[1]
@@ -700,7 +748,7 @@ def svd(
700748
(Vh_even_trunc.shape[1], Vh_odd_trunc.shape[1]),
701749
)
702750

703-
U = GrassmannTensor(_arrow=(True, True), _edges=U_edges, _tensor=U_tensor)
751+
U = GrassmannTensor(_arrow=(False, True), _edges=U_edges, _tensor=U_tensor)
704752
S = GrassmannTensor(
705753
_arrow=(
706754
False,
@@ -709,52 +757,141 @@ def svd(
709757
_edges=S_edges,
710758
_tensor=torch.diag(S_tensor),
711759
)
712-
Vh = GrassmannTensor(_arrow=(False, True), _edges=Vh_edges, _tensor=Vh_tensor)
713-
# Split
714-
left_arrow = [self.arrow[i] for i in left_legs]
715-
left_edges = [self.edges[i] for i in left_legs]
760+
Vh = GrassmannTensor(_arrow=(False, False), _edges=Vh_edges, _tensor=Vh_tensor)
716761

717-
right_arrow = [self.arrow[i] for i in right_legs]
762+
left_edges = [self.edges[i] for i in left_legs]
718763
right_edges = [self.edges[i] for i in right_legs]
719764

720765
U = U.reshape((*left_edges, U_edges[1]))
721-
U._arrow = tuple(left_arrow + [True])
766+
U = U.reverse(arrow_reverse_left)
722767

723768
Vh = Vh.reshape((Vh_edges[0], *right_edges))
724-
Vh._arrow = tuple([False] + right_arrow)
769+
Vh = Vh.reverse(arrow_reverse_right)
725770

726771
return U, S, Vh
727772

728-
def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]:
729-
inv = [0] * self.tensor.dim()
773+
@staticmethod
774+
def get_inv_order(dim: int, order: tuple[int, ...]) -> tuple[int, ...]:
775+
inv = [0] * dim
730776
for new_position, origin_idx in enumerate(order):
731777
inv[origin_idx] = new_position
732778
return tuple(inv)
733779

734-
def _check_pairs_coverage(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> bool:
735-
set0 = set(pairs[0])
736-
set1 = set(pairs[1])
780+
def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]:
781+
return self.get_inv_order(self.tensor.dim(), order)
782+
783+
@staticmethod
784+
def contract(
785+
a: GrassmannTensor,
786+
b: GrassmannTensor,
787+
a_leg: int | tuple[int, ...],
788+
b_leg: int | tuple[int, ...],
789+
) -> GrassmannTensor:
790+
contract_lengths = []
791+
for leg in (a_leg, b_leg):
792+
if isinstance(leg, int):
793+
contract_lengths.append(1)
794+
elif isinstance(leg, tuple):
795+
contract_lengths.append(len(leg))
796+
assert all(a.arrow[i] == a.arrow[leg[0]] for i in leg), (
797+
"All the legs that need to be contracted must have the same arrow"
798+
)
799+
800+
contract_length_a, contract_length_b = contract_lengths
801+
802+
a_leg_tuple = (a_leg,) if isinstance(a_leg, int) else a_leg
803+
b_leg_tuple = (b_leg,) if isinstance(b_leg, int) else b_leg
737804

738-
are_disjoint = set0.isdisjoint(set1)
805+
a_range_list = tuple(range(a.tensor.dim()))
806+
b_range_list = tuple(range(b.tensor.dim()))
739807

740-
is_complete_union = (set0 | set1) == set(range(self.tensor.dim()))
808+
a_contract_set = set(a_leg_tuple)
809+
b_contract_set = set(b_leg_tuple)
741810

742-
return are_disjoint and is_complete_union
811+
order_a = tuple(i for i in a_range_list if i not in a_contract_set) + a_leg_tuple
812+
order_b = b_leg_tuple + tuple(i for i in b_range_list if i not in b_contract_set)
813+
814+
tensor_a = a.permute(order_a)
815+
tensor_b = b.permute(order_b)
816+
817+
assert (tensor_a.arrow[-1], tensor_b.arrow[0]) in ((False, True), (True, False)), (
818+
f"Contract requires arrow (False, True) or (True, False), but got {tensor_a.arrow[-1], tensor_b.arrow[0]}"
819+
)
820+
821+
arrow_after_permute_a = tensor_a.arrow
822+
arrow_after_permute_b = tensor_b.arrow
823+
824+
edge_after_permute_a = tensor_a.edges
825+
edge_after_permute_b = tensor_b.edges
826+
827+
arrow_expected_a = [i >= a.tensor.dim() - contract_length_a for i in range(a.tensor.dim())]
828+
arrow_expected_b = [i >= contract_length_b for i in range(b.tensor.dim())]
829+
830+
arrow_reverse_a = tuple(
831+
i
832+
for i, (cur, exp) in enumerate(zip(arrow_after_permute_a, arrow_expected_a))
833+
if cur != exp
834+
)
835+
arrow_reverse_b = tuple(
836+
i
837+
for i, (cur, exp) in enumerate(zip(arrow_after_permute_b, arrow_expected_b))
838+
if cur != exp
839+
)
840+
841+
if arrow_reverse_a:
842+
tensor_a = (
843+
tensor_a.reverse(arrow_reverse_a).reverse(arrow_reverse_a).reverse(arrow_reverse_a)
844+
)
845+
if arrow_reverse_b:
846+
tensor_b = (
847+
tensor_b.reverse(arrow_reverse_b).reverse(arrow_reverse_b).reverse(arrow_reverse_b)
848+
)
849+
850+
tensor_a = tensor_a.reshape(
851+
(
852+
math.prod(tensor_a.tensor.shape[:-contract_length_a]),
853+
math.prod(tensor_a.tensor.shape[-contract_length_a:]),
854+
)
855+
)
856+
tensor_b = tensor_b.reshape(
857+
(
858+
math.prod(tensor_b.tensor.shape[:contract_length_b]),
859+
math.prod(tensor_b.tensor.shape[contract_length_b:]),
860+
)
861+
)
862+
863+
c = tensor_a @ tensor_b
864+
865+
c = c.reshape(
866+
(edge_after_permute_a[:-contract_length_a] + edge_after_permute_b[contract_length_b:])
867+
)
868+
869+
arrow_reverse_c = tuple(
870+
[i for i in arrow_reverse_a if i < a.tensor.dim() - contract_length_a]
871+
+ [
872+
(a.tensor.dim() - contract_length_a) + (i - contract_length_b)
873+
for i in arrow_reverse_b
874+
if i >= contract_length_b
875+
]
876+
)
877+
c = c.reverse(arrow_reverse_c)
878+
return c
743879

744880
def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
745881
tensor, left_legs, right_legs = self._group_edges(pairs)
746882

747-
arrow_order = (False, True)
748-
edges_to_reverse = tuple(
749-
i for i, arrow in enumerate(arrow_order) if tensor.arrow[i] != arrow
883+
assert tensor.arrow in ((False, True), (True, False)), (
884+
f"Exponentiation requires arrow (False, True) or (True, False), but got {tensor.arrow}"
750885
)
751-
if edges_to_reverse:
752-
tensor = tensor.reverse(edges_to_reverse)
886+
887+
tensor_reverse_flag = tensor.arrow != (False, True)
888+
if tensor_reverse_flag:
889+
tensor = tensor.reverse((0, 1))
753890

754891
left_dim, right_dim = tensor.tensor.shape
755892

756893
assert left_dim == right_dim, (
757-
f"Exponential requires a square operator, but got {left_dim} x {right_dim}."
894+
f"Exponentiation requires a square operator, but got {left_dim} x {right_dim}."
758895
)
759896

760897
(even_left, odd_left) = tensor.edges[0]
@@ -774,8 +911,8 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
774911

775912
tensor_exp = dataclasses.replace(tensor, _tensor=tensor_exp)
776913

777-
if edges_to_reverse:
778-
tensor_exp = tensor_exp.reverse(tuple(edges_to_reverse))
914+
if tensor_reverse_flag:
915+
tensor_exp = tensor_exp.reverse((0, 1))
779916

780917
order = left_legs + right_legs
781918
edges_after_permute = tuple(self.edges[i] for i in order)
@@ -787,6 +924,47 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
787924

788925
return tensor_exp
789926

927+
def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
928+
tensor, left_legs, right_legs = self._group_edges(pairs)
929+
930+
assert tensor.arrow in ((False, True), (True, False)), (
931+
f"Identity requires arrow (False, True) or (True, False), but got {tensor.arrow}"
932+
)
933+
934+
tensor_reverse_flag = tensor.arrow != (False, True)
935+
if tensor_reverse_flag:
936+
tensor = tensor.reverse((0, 1))
937+
938+
left_dim, right_dim = tensor.tensor.shape
939+
940+
assert left_dim == right_dim, (
941+
f"Identity requires a square operator, but got {left_dim} x {right_dim}."
942+
)
943+
944+
(even_left, odd_left) = tensor.edges[0]
945+
(even_right, odd_right) = tensor.edges[1]
946+
947+
assert even_left == even_right and odd_left == odd_right, (
948+
f"Parity blocks must be square, but got L=({even_left},{odd_left}), R=({even_right},{odd_right})"
949+
)
950+
951+
I = torch.eye(left_dim, dtype=tensor.tensor.dtype, device=tensor.tensor.device) # noqa: E741
952+
953+
tensor_identity = dataclasses.replace(tensor, _tensor=I)
954+
955+
if tensor_reverse_flag:
956+
tensor_identity = tensor_identity.reverse((0, 1))
957+
958+
order = left_legs + right_legs
959+
edges_after_permute = tuple(self.edges[i] for i in order)
960+
tensor_identity = tensor_identity.reshape(edges_after_permute)
961+
962+
inv_order = self._get_inv_order(order)
963+
964+
tensor_identity = tensor_identity.permute(inv_order)
965+
966+
return tensor_identity
967+
790968
def __post_init__(self) -> None:
791969
assert len(self._arrow) == self._tensor.dim(), (
792970
f"Arrow length ({len(self._arrow)}) must match tensor dimensions ({self._tensor.dim()})."

tests/contract_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import torch
2+
import pytest
3+
4+
from grassmann_tensor import GrassmannTensor
5+
6+
7+
def test_contract_assertion() -> None:
8+
a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1, dtype=torch.float64))
9+
b = GrassmannTensor(
10+
(False, True, False, True),
11+
((2, 2), (4, 4), (8, 8), (16, 16)),
12+
torch.randn(4, 8, 16, 32, dtype=torch.float64),
13+
)
14+
with pytest.raises(AssertionError, match="Contract requires arrow"):
15+
_ = a.contract(a, b, 0, 0)
16+
with pytest.raises(
17+
AssertionError, match="All the legs that need to be contracted must have the same arrow"
18+
):
19+
_ = a.contract(a, b, 0, (0, 1))

0 commit comments

Comments
 (0)