Skip to content

Commit 74c5fab

Browse files
authored
Merge pull request #78 from USTC-KnowledgeComputingLab/fix/exponential-processing-logic
fix(exponential): fix exponential processing logic Close #78
2 parents 1049077 + ff013d1 commit 74c5fab

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

grassmann_tensor/tensor.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -571,12 +571,19 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
571571

572572
def _group_edges(
573573
self,
574-
left_legs: typing.Iterable[int],
574+
pairs: tuple[int, ...] | tuple[tuple[int, ...], tuple[int, ...]],
575575
) -> tuple[GrassmannTensor, tuple[int, ...], tuple[int, ...]]:
576-
left_legs = tuple(int(i) for i in left_legs)
577-
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
578-
assert set(left_legs) | set(right_legs) == set(range(self.tensor.dim())), (
579-
"Left/right must cover all tensor legs."
576+
if (isinstance(pairs, tuple) and len(pairs)) and all(
577+
isinstance(x, tuple) and all(isinstance(i, int) for i in x) for x in pairs
578+
):
579+
left_legs = typing.cast(tuple[int, ...], pairs[0])
580+
right_legs = typing.cast(tuple[int, ...], pairs[1])
581+
else:
582+
left_legs = typing.cast(tuple[int, ...], pairs)
583+
right_legs = tuple(i for i in range(self.tensor.dim()) if i not in left_legs)
584+
585+
assert self._check_pairs_coverage((left_legs, right_legs)), (
586+
f"Input pairs must cover all dimension and disjoint, but got {(left_legs, right_legs)}"
580587
)
581588

582589
order = left_legs + right_legs
@@ -724,7 +731,17 @@ def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]:
724731
inv[origin_idx] = new_position
725732
return tuple(inv)
726733

727-
def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
734+
def _check_pairs_coverage(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> bool:
735+
set0 = set(pairs[0])
736+
set1 = set(pairs[1])
737+
738+
are_disjoint = set0.isdisjoint(set1)
739+
740+
is_complete_union = (set0 | set1) == set(range(self.tensor.dim()))
741+
742+
return are_disjoint and is_complete_union
743+
744+
def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
728745
tensor, left_legs, right_legs = self._group_edges(pairs)
729746

730747
arrow_order = (False, True)
@@ -743,6 +760,10 @@ def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
743760
(even_left, odd_left) = tensor.edges[0]
744761
(even_right, odd_right) = tensor.edges[1]
745762

763+
assert even_left == even_right and odd_left == odd_right, (
764+
f"Parity blocks must be square, but got L=({even_left},{odd_left}), R=({even_right},{odd_right})"
765+
)
766+
746767
even_tensor = tensor.tensor[:even_left, :even_right]
747768
odd_tensor = tensor.tensor[even_left:, even_right:]
748769

tests/exponential_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@ def test_exponential() -> None:
1010
((4, 4), (8, 8), (4, 4), (8, 8)),
1111
torch.randn(8, 16, 8, 16, dtype=torch.float64),
1212
)
13-
a.exponential((0, 3))
13+
b = a.exponential(((0, 3), (1, 2)))
14+
c = a.exponential(((0, 3), (2, 1)))
15+
assert not torch.allclose(b.tensor, c.tensor)
1416

1517

1618
def test_exponential_with_empty_parity_block() -> None:
1719
a = GrassmannTensor((False, True), ((1, 0), (1, 0)), torch.randn(1, 1))
18-
a.exponential((0,))
20+
a.exponential(((0,), (1,)))
1921
b = GrassmannTensor((False, True), ((0, 1), (0, 1)), torch.randn(1, 1))
20-
b.exponential((0,))
22+
b.exponential(((0,), (1,)))
2123

2224

2325
def test_exponential_assertation() -> None:
@@ -27,4 +29,4 @@ def test_exponential_assertation() -> None:
2729
torch.randn(4, 8, 16, 32, dtype=torch.float64),
2830
)
2931
with pytest.raises(AssertionError, match="Exponential requires a square operator"):
30-
a.exponential((0, 2))
32+
a.exponential(((0, 2), (1, 3)))

0 commit comments

Comments
 (0)