Skip to content

Commit 49ad425

Browse files
authored
Merge pull request #1169 from rhayes777/feature/jax_cpu_batch_size_1
feature/jax_cpu_batch_size_1
2 parents 79e284c + 074b97b commit 49ad425

1 file changed

Lines changed: 13 additions & 3 deletions

File tree

autofit/non_linear/fitness.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def __init__(
4444
use_jax_vmap : bool = False,
4545
batch_size : Optional[int] = None,
4646
iterations_per_quick_update: Optional[int] = None,
47-
xp=np,
4847
):
4948
"""
5049
Interfaces with any non-linear search to fit the model to the data and return a log likelihood via
@@ -109,7 +108,8 @@ def __init__(
109108
self.model = model
110109
self.paths = paths
111110
self.fom_is_log_likelihood = fom_is_log_likelihood
112-
self.resample_figure_of_merit = resample_figure_of_merit or -xp.inf
111+
112+
self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf
113113
self.convert_to_chi_squared = convert_to_chi_squared
114114
self.store_history = store_history
115115

@@ -123,10 +123,20 @@ def __init__(
123123
if self.use_jax_vmap:
124124
self._call = self._vmap
125125

126+
if analysis._use_jax:
127+
128+
import jax
129+
130+
if jax.default_backend() == "cpu":
131+
132+
logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")
133+
134+
batch_size = 1
135+
126136
self.batch_size = batch_size
127137
self.iterations_per_quick_update = iterations_per_quick_update
128138
self.quick_update_max_lh_parameters = None
129-
self.quick_update_max_lh = -xp.inf
139+
self.quick_update_max_lh = -self._xp.inf
130140
self.quick_update_count = 0
131141

132142
if self.paths is not None:

0 commit comments

Comments
 (0)