@@ -308,7 +308,10 @@ def reshape(self, new_shape: tuple[int | tuple[int, int], ...]) -> GrassmannTens
308308 if (isinstance (new_shape_check , int ) and new_shape_check == 1 ) or (
309309 new_shape_check == (1 , 0 )
310310 ):
311- arrow .append (False )
311+ if cursor_plan < len (self .arrow ):
312+ arrow .append (self .arrow [cursor_plan ])
313+ else :
314+ arrow .append (False )
312315 edges .append ((1 , 0 ))
313316 shape .append (1 )
314317 cursor_plan += 1
@@ -573,22 +576,18 @@ def _group_edges(
573576 self ,
574577 pairs : tuple [int , ...] | tuple [tuple [int , ...], tuple [int , ...]],
575578 ) -> tuple [GrassmannTensor , tuple [int , ...], tuple [int , ...]]:
576- if (isinstance (pairs , tuple ) and len (pairs )) and all (
577- isinstance (x , tuple ) and all (isinstance (i , int ) for i in x ) for x in pairs
578- ):
579- left_legs = typing .cast (tuple [int , ...], pairs [0 ])
580- right_legs = typing .cast (tuple [int , ...], pairs [1 ])
581- else :
582- left_legs = typing .cast (tuple [int , ...], pairs )
583- right_legs = tuple (i for i in range (self .tensor .dim ()) if i not in left_legs )
579+ return self .group_edges (self , pairs )
584580
585- assert self ._check_pairs_coverage ((left_legs , right_legs )), (
586- f"Input pairs must cover all dimension and disjoint, but got { (left_legs , right_legs )} "
587- )
581+ @staticmethod
582+ def group_edges (
583+ tensor : GrassmannTensor ,
584+ pairs : tuple [int , ...] | tuple [tuple [int , ...], tuple [int , ...]],
585+ ) -> tuple [GrassmannTensor , tuple [int , ...], tuple [int , ...]]:
586+ left_legs , right_legs = GrassmannTensor .get_legs_pair (tensor .tensor .dim (), pairs )
588587
589588 order = left_legs + right_legs
590589
591- tensor = self .permute (order )
590+ tensor = tensor .permute (order )
592591
593592 left_dim = math .prod (tensor .tensor .shape [: len (left_legs )])
594593 right_dim = math .prod (tensor .tensor .shape [len (left_legs ) :])
@@ -597,6 +596,37 @@ def _group_edges(
597596
598597 return tensor , left_legs , right_legs
599598
599+ @staticmethod
600+ def get_legs_pair (
601+ dim : int , pairs : tuple [int , ...] | tuple [tuple [int , ...], tuple [int , ...]]
602+ ) -> tuple [tuple [int , ...], tuple [int , ...]]:
603+ def check_pairs_coverage (dim : int , pairs : tuple [tuple [int , ...], tuple [int , ...]]) -> bool :
604+ set0 = set (pairs [0 ])
605+ set1 = set (pairs [1 ])
606+
607+ are_disjoint = set0 .isdisjoint (set1 )
608+
609+ is_complete_union = (set0 | set1 ) == set (range (dim ))
610+
611+ no_duplicates = len (pairs [0 ]) + len (pairs [1 ]) == dim
612+
613+ return are_disjoint and is_complete_union and no_duplicates
614+
615+ if (isinstance (pairs , tuple ) and len (pairs )) and all (
616+ isinstance (x , tuple ) and all (isinstance (i , int ) for i in x ) for x in pairs
617+ ):
618+ left_legs = typing .cast (tuple [int , ...], pairs [0 ])
619+ right_legs = typing .cast (tuple [int , ...], pairs [1 ])
620+ else :
621+ left_legs = typing .cast (tuple [int , ...], pairs )
622+ right_legs = tuple (i for i in range (dim ) if i not in left_legs )
623+
624+ assert check_pairs_coverage (dim , (left_legs , right_legs )), (
625+ f"Input pairs must cover all dimension and disjoint, but got { (left_legs , right_legs )} "
626+ )
627+
628+ return left_legs , right_legs
629+
600630 def svd (
601631 self ,
602632 free_names_u : tuple [int , ...],
@@ -622,7 +652,25 @@ def svd(
622652 if isinstance (cutoff , tuple ):
623653 assert len (cutoff ) == 2 , "The length of cutoff must be 2 if cutoff is a tuple."
624654
625- tensor , left_legs , right_legs = self ._group_edges (free_names_u )
655+ left_legs , right_legs = GrassmannTensor .get_legs_pair (self .tensor .dim (), free_names_u )
656+ order = left_legs + right_legs
657+ tensor = self .permute (order )
658+
659+ arrow_reverse = tuple (i for i , current in enumerate (tensor .arrow ) if current )
660+ if arrow_reverse :
661+ tensor = tensor .reverse (arrow_reverse ).reverse (arrow_reverse ).reverse (arrow_reverse )
662+
663+ left_dim = math .prod (tensor .tensor .shape [: len (left_legs )])
664+ right_dim = math .prod (tensor .tensor .shape [len (left_legs ) :])
665+ tensor = tensor .reshape ((left_dim , right_dim ))
666+
667+ origin_arrow_left = tuple (self .arrow [i ] for i in left_legs )
668+ origin_arrow_right = tuple (self .arrow [i ] for i in right_legs )
669+
670+ arrow_reverse_left = tuple (i for i , current in enumerate (origin_arrow_left ) if current )
671+ arrow_reverse_right = tuple (
672+ i + 1 for i , current in enumerate (origin_arrow_right ) if current
673+ )
626674
627675 (even_left , odd_left ) = tensor .edges [0 ]
628676 (even_right , odd_right ) = tensor .edges [1 ]
@@ -700,7 +748,7 @@ def svd(
700748 (Vh_even_trunc .shape [1 ], Vh_odd_trunc .shape [1 ]),
701749 )
702750
703- U = GrassmannTensor (_arrow = (True , True ), _edges = U_edges , _tensor = U_tensor )
751+ U = GrassmannTensor (_arrow = (False , True ), _edges = U_edges , _tensor = U_tensor )
704752 S = GrassmannTensor (
705753 _arrow = (
706754 False ,
@@ -709,52 +757,141 @@ def svd(
709757 _edges = S_edges ,
710758 _tensor = torch .diag (S_tensor ),
711759 )
712- Vh = GrassmannTensor (_arrow = (False , True ), _edges = Vh_edges , _tensor = Vh_tensor )
713- # Split
714- left_arrow = [self .arrow [i ] for i in left_legs ]
715- left_edges = [self .edges [i ] for i in left_legs ]
760+ Vh = GrassmannTensor (_arrow = (False , False ), _edges = Vh_edges , _tensor = Vh_tensor )
716761
717- right_arrow = [self .arrow [i ] for i in right_legs ]
762+ left_edges = [self .edges [i ] for i in left_legs ]
718763 right_edges = [self .edges [i ] for i in right_legs ]
719764
720765 U = U .reshape ((* left_edges , U_edges [1 ]))
721- U . _arrow = tuple ( left_arrow + [ True ] )
766+ U = U . reverse ( arrow_reverse_left )
722767
723768 Vh = Vh .reshape ((Vh_edges [0 ], * right_edges ))
724- Vh . _arrow = tuple ([ False ] + right_arrow )
769+ Vh = Vh . reverse ( arrow_reverse_right )
725770
726771 return U , S , Vh
727772
728- def _get_inv_order (self , order : tuple [int , ...]) -> tuple [int , ...]:
729- inv = [0 ] * self .tensor .dim ()
773+ @staticmethod
774+ def get_inv_order (dim : int , order : tuple [int , ...]) -> tuple [int , ...]:
775+ inv = [0 ] * dim
730776 for new_position , origin_idx in enumerate (order ):
731777 inv [origin_idx ] = new_position
732778 return tuple (inv )
733779
734- def _check_pairs_coverage (self , pairs : tuple [tuple [int , ...], tuple [int , ...]]) -> bool :
735- set0 = set (pairs [0 ])
736- set1 = set (pairs [1 ])
780+ def _get_inv_order (self , order : tuple [int , ...]) -> tuple [int , ...]:
781+ return self .get_inv_order (self .tensor .dim (), order )
782+
783+ @staticmethod
784+ def contract (
785+ a : GrassmannTensor ,
786+ b : GrassmannTensor ,
787+ a_leg : int | tuple [int , ...],
788+ b_leg : int | tuple [int , ...],
789+ ) -> GrassmannTensor :
790+ contract_lengths = []
791+ for leg in (a_leg , b_leg ):
792+ if isinstance (leg , int ):
793+ contract_lengths .append (1 )
794+ elif isinstance (leg , tuple ):
795+ contract_lengths .append (len (leg ))
796+ assert all (a .arrow [i ] == a .arrow [leg [0 ]] for i in leg ), (
797+ "All the legs that need to be contracted must have the same arrow"
798+ )
799+
800+ contract_length_a , contract_length_b = contract_lengths
801+
802+ a_leg_tuple = (a_leg ,) if isinstance (a_leg , int ) else a_leg
803+ b_leg_tuple = (b_leg ,) if isinstance (b_leg , int ) else b_leg
737804
738- are_disjoint = set0 .isdisjoint (set1 )
805+ a_range_list = tuple (range (a .tensor .dim ()))
806+ b_range_list = tuple (range (b .tensor .dim ()))
739807
740- is_complete_union = (set0 | set1 ) == set (range (self .tensor .dim ()))
808+ a_contract_set = set (a_leg_tuple )
809+ b_contract_set = set (b_leg_tuple )
741810
742- return are_disjoint and is_complete_union
811+ order_a = tuple (i for i in a_range_list if i not in a_contract_set ) + a_leg_tuple
812+ order_b = b_leg_tuple + tuple (i for i in b_range_list if i not in b_contract_set )
813+
814+ tensor_a = a .permute (order_a )
815+ tensor_b = b .permute (order_b )
816+
817+ assert (tensor_a .arrow [- 1 ], tensor_b .arrow [0 ]) in ((False , True ), (True , False )), (
818+ f"Contract requires arrow (False, True) or (True, False), but got { tensor_a .arrow [- 1 ], tensor_b .arrow [0 ]} "
819+ )
820+
821+ arrow_after_permute_a = tensor_a .arrow
822+ arrow_after_permute_b = tensor_b .arrow
823+
824+ edge_after_permute_a = tensor_a .edges
825+ edge_after_permute_b = tensor_b .edges
826+
827+ arrow_expected_a = [i >= a .tensor .dim () - contract_length_a for i in range (a .tensor .dim ())]
828+ arrow_expected_b = [i >= contract_length_b for i in range (b .tensor .dim ())]
829+
830+ arrow_reverse_a = tuple (
831+ i
832+ for i , (cur , exp ) in enumerate (zip (arrow_after_permute_a , arrow_expected_a ))
833+ if cur != exp
834+ )
835+ arrow_reverse_b = tuple (
836+ i
837+ for i , (cur , exp ) in enumerate (zip (arrow_after_permute_b , arrow_expected_b ))
838+ if cur != exp
839+ )
840+
841+ if arrow_reverse_a :
842+ tensor_a = (
843+ tensor_a .reverse (arrow_reverse_a ).reverse (arrow_reverse_a ).reverse (arrow_reverse_a )
844+ )
845+ if arrow_reverse_b :
846+ tensor_b = (
847+ tensor_b .reverse (arrow_reverse_b ).reverse (arrow_reverse_b ).reverse (arrow_reverse_b )
848+ )
849+
850+ tensor_a = tensor_a .reshape (
851+ (
852+ math .prod (tensor_a .tensor .shape [:- contract_length_a ]),
853+ math .prod (tensor_a .tensor .shape [- contract_length_a :]),
854+ )
855+ )
856+ tensor_b = tensor_b .reshape (
857+ (
858+ math .prod (tensor_b .tensor .shape [:contract_length_b ]),
859+ math .prod (tensor_b .tensor .shape [contract_length_b :]),
860+ )
861+ )
862+
863+ c = tensor_a @ tensor_b
864+
865+ c = c .reshape (
866+ (edge_after_permute_a [:- contract_length_a ] + edge_after_permute_b [contract_length_b :])
867+ )
868+
869+ arrow_reverse_c = tuple (
870+ [i for i in arrow_reverse_a if i < a .tensor .dim () - contract_length_a ]
871+ + [
872+ (a .tensor .dim () - contract_length_a ) + (i - contract_length_b )
873+ for i in arrow_reverse_b
874+ if i >= contract_length_b
875+ ]
876+ )
877+ c = c .reverse (arrow_reverse_c )
878+ return c
743879
744880 def exponential (self , pairs : tuple [tuple [int , ...], tuple [int , ...]]) -> GrassmannTensor :
745881 tensor , left_legs , right_legs = self ._group_edges (pairs )
746882
747- arrow_order = (False , True )
748- edges_to_reverse = tuple (
749- i for i , arrow in enumerate (arrow_order ) if tensor .arrow [i ] != arrow
883+ assert tensor .arrow in ((False , True ), (True , False )), (
884+ f"Exponentiation requires arrow (False, True) or (True, False), but got { tensor .arrow } "
750885 )
751- if edges_to_reverse :
752- tensor = tensor .reverse (edges_to_reverse )
886+
887+ tensor_reverse_flag = tensor .arrow != (False , True )
888+ if tensor_reverse_flag :
889+ tensor = tensor .reverse ((0 , 1 ))
753890
754891 left_dim , right_dim = tensor .tensor .shape
755892
756893 assert left_dim == right_dim , (
757- f"Exponential requires a square operator, but got { left_dim } x { right_dim } ."
894+ f"Exponentiation requires a square operator, but got { left_dim } x { right_dim } ."
758895 )
759896
760897 (even_left , odd_left ) = tensor .edges [0 ]
@@ -774,8 +911,8 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
774911
775912 tensor_exp = dataclasses .replace (tensor , _tensor = tensor_exp )
776913
777- if edges_to_reverse :
778- tensor_exp = tensor_exp .reverse (tuple ( edges_to_reverse ))
914+ if tensor_reverse_flag :
915+ tensor_exp = tensor_exp .reverse (( 0 , 1 ))
779916
780917 order = left_legs + right_legs
781918 edges_after_permute = tuple (self .edges [i ] for i in order )
@@ -787,6 +924,47 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
787924
788925 return tensor_exp
789926
927+ def identity (self , pairs : tuple [tuple [int , ...], tuple [int , ...]]) -> GrassmannTensor :
928+ tensor , left_legs , right_legs = self ._group_edges (pairs )
929+
930+ assert tensor .arrow in ((False , True ), (True , False )), (
931+ f"Identity requires arrow (False, True) or (True, False), but got { tensor .arrow } "
932+ )
933+
934+ tensor_reverse_flag = tensor .arrow != (False , True )
935+ if tensor_reverse_flag :
936+ tensor = tensor .reverse ((0 , 1 ))
937+
938+ left_dim , right_dim = tensor .tensor .shape
939+
940+ assert left_dim == right_dim , (
941+ f"Identity requires a square operator, but got { left_dim } x { right_dim } ."
942+ )
943+
944+ (even_left , odd_left ) = tensor .edges [0 ]
945+ (even_right , odd_right ) = tensor .edges [1 ]
946+
947+ assert even_left == even_right and odd_left == odd_right , (
948+ f"Parity blocks must be square, but got L=({ even_left } ,{ odd_left } ), R=({ even_right } ,{ odd_right } )"
949+ )
950+
951+ I = torch .eye (left_dim , dtype = tensor .tensor .dtype , device = tensor .tensor .device ) # noqa: E741
952+
953+ tensor_identity = dataclasses .replace (tensor , _tensor = I )
954+
955+ if tensor_reverse_flag :
956+ tensor_identity = tensor_identity .reverse ((0 , 1 ))
957+
958+ order = left_legs + right_legs
959+ edges_after_permute = tuple (self .edges [i ] for i in order )
960+ tensor_identity = tensor_identity .reshape (edges_after_permute )
961+
962+ inv_order = self ._get_inv_order (order )
963+
964+ tensor_identity = tensor_identity .permute (inv_order )
965+
966+ return tensor_identity
967+
790968 def __post_init__ (self ) -> None :
791969 assert len (self ._arrow ) == self ._tensor .dim (), (
792970 f"Arrow length ({ len (self ._arrow )} ) must match tensor dimensions ({ self ._tensor .dim ()} )."
0 commit comments