@@ -378,7 +378,7 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]:
378378 return self .regularization_matrix
379379
380380 # ids of values which are on edge so zero-d and not solved for.
381- ids_to_keep = jnp . array ( self .mapper_index_list , dtype = int )
381+ ids_to_keep = self .mapper_index_list
382382
383383 # Zero rows and columns in the matrix we want to ignore
384384 return self .regularization_matrix [ids_to_keep ][:, ids_to_keep ]
@@ -417,7 +417,7 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
417417 return self .curvature_reg_matrix
418418
419419 # ids of values which are on edge so zero-d and not solved for.
420- ids_to_keep = jnp . array ( self .mapper_index_list , dtype = int )
420+ ids_to_keep = self .mapper_index_list
421421
422422 # Zero rows and columns in the matrix we want to ignore
423423 return self .regularization_matrix [ids_to_keep ][:, ids_to_keep ]
@@ -525,7 +525,7 @@ def reconstruction_reduced(self) -> np.ndarray:
525525 return self .reconstruction
526526
527527 # ids of values which are on edge so zero-d and not solved for.
528- ids_to_keep = jnp . array ( self .mapper_index_list , dtype = int )
528+ ids_to_keep = self .mapper_index_list
529529
530530 # Zero rows and columns in the matrix we want to ignore
531531 return self .reconstruction [ids_to_keep ]
@@ -671,9 +671,9 @@ def regularization_term(self) -> float:
671671 if not self .has (cls = AbstractRegularization ):
672672 return 0.0
673673
674- return np .matmul (
674+ return jnp .matmul (
675675 self .reconstruction_reduced .T ,
676- np .matmul (self .regularization_matrix_reduced , self .reconstruction_reduced ),
676+ jnp .matmul (self .regularization_matrix_reduced , self .reconstruction_reduced ),
677677 )
678678
679679 @cached_property
0 commit comments