@@ -571,12 +571,19 @@ def matmul(self, other: GrassmannTensor) -> GrassmannTensor:
571571
572572 def _group_edges (
573573 self ,
574- left_legs : typing . Iterable [ int ],
574+ pairs : tuple [ int , ...] | tuple [ tuple [ int , ...], tuple [ int , ...] ],
575575 ) -> tuple [GrassmannTensor , tuple [int , ...], tuple [int , ...]]:
576- left_legs = tuple (int (i ) for i in left_legs )
577- right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
578- assert set (left_legs ) | set (right_legs ) == set (range (self .tensor .dim ())), (
579- "Left/right must cover all tensor legs."
576+ if (isinstance (pairs , tuple ) and len (pairs )) and all (
577+ isinstance (x , tuple ) and all (isinstance (i , int ) for i in x ) for x in pairs
578+ ):
579+ left_legs = typing .cast (tuple [int , ...], pairs [0 ])
580+ right_legs = typing .cast (tuple [int , ...], pairs [1 ])
581+ else :
582+ left_legs = typing .cast (tuple [int , ...], pairs )
583+ right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
584+
585+ assert self ._check_pairs_coverage ((left_legs , right_legs )), (
586+ f"Input pairs must cover all dimension and disjoint, but got { (left_legs , right_legs )} "
580587 )
581588
582589 order = left_legs + right_legs
@@ -724,7 +731,17 @@ def _get_inv_order(self, order: tuple[int, ...]) -> tuple[int, ...]:
724731 inv [origin_idx ] = new_position
725732 return tuple (inv )
726733
727- def exponential (self , pairs : tuple [int , ...]) -> GrassmannTensor :
734+ def _check_pairs_coverage (self , pairs : tuple [tuple [int , ...], tuple [int , ...]]) -> bool :
735+ set0 = set (pairs [0 ])
736+ set1 = set (pairs [1 ])
737+
738+ are_disjoint = set0 .isdisjoint (set1 )
739+
740+ is_complete_union = (set0 | set1 ) == set (range (self .tensor .dim ()))
741+
742+ return are_disjoint and is_complete_union
743+
744+ def exponential (self , pairs : tuple [tuple [int , ...], tuple [int , ...]]) -> GrassmannTensor :
728745 tensor , left_legs , right_legs = self ._group_edges (pairs )
729746
730747 arrow_order = (False , True )
@@ -743,6 +760,10 @@ def exponential(self, pairs: tuple[int, ...]) -> GrassmannTensor:
743760 (even_left , odd_left ) = tensor .edges [0 ]
744761 (even_right , odd_right ) = tensor .edges [1 ]
745762
763+ assert even_left == even_right and odd_left == odd_right , (
764+ f"Parity blocks must be square, but got L=({ even_left } ,{ odd_left } ), R=({ even_right } ,{ odd_right } )"
765+ )
766+
746767 even_tensor = tensor .tensor [:even_left , :even_right ]
747768 odd_tensor = tensor .tensor [even_left :, even_right :]
748769
0 commit comments