Skip to content

Commit 93157b8

Browse files
Jammy2211Jammy2211
authored andcommitted
full JAX success
1 parent 1b5b64f commit 93157b8

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

autoarray/inversion/inversion/abstract.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def regularization_matrix_reduced(self) -> Optional[np.ndarray]:
378378
return self.regularization_matrix
379379

380380
# ids of values which are on edge so zero-d and not solved for.
381-
ids_to_keep = jnp.array(self.mapper_index_list, dtype=int)
381+
ids_to_keep = self.mapper_index_list
382382

383383
# Zero rows and columns in the matrix we want to ignore
384384
return self.regularization_matrix[ids_to_keep][:, ids_to_keep]
@@ -417,7 +417,7 @@ def curvature_reg_matrix_reduced(self) -> Optional[np.ndarray]:
417417
return self.curvature_reg_matrix
418418

419419
# ids of values which are on edge so zero-d and not solved for.
420-
ids_to_keep = jnp.array(self.mapper_index_list, dtype=int)
420+
ids_to_keep = self.mapper_index_list
421421

422422
# Zero rows and columns in the matrix we want to ignore
423423
return self.regularization_matrix[ids_to_keep][:, ids_to_keep]
@@ -525,7 +525,7 @@ def reconstruction_reduced(self) -> np.ndarray:
525525
return self.reconstruction
526526

527527
# ids of values which are on edge so zero-d and not solved for.
528-
ids_to_keep = jnp.array(self.mapper_index_list, dtype=int)
528+
ids_to_keep = self.mapper_index_list
529529

530530
# Zero rows and columns in the matrix we want to ignore
531531
return self.reconstruction[ids_to_keep]
@@ -671,9 +671,9 @@ def regularization_term(self) -> float:
671671
if not self.has(cls=AbstractRegularization):
672672
return 0.0
673673

674-
return np.matmul(
674+
return jnp.matmul(
675675
self.reconstruction_reduced.T,
676-
np.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
676+
jnp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
677677
)
678678

679679
@cached_property

0 commit comments

Comments
 (0)