File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -44,7 +44,6 @@ def __init__(
4444 use_jax_vmap : bool = False ,
4545 batch_size : Optional [int ] = None ,
4646 iterations_per_quick_update : Optional [int ] = None ,
47- xp = np ,
4847 ):
4948 """
5049 Interfaces with any non-linear search to fit the model to the data and return a log likelihood via
@@ -109,7 +108,8 @@ def __init__(
109108 self .model = model
110109 self .paths = paths
111110 self .fom_is_log_likelihood = fom_is_log_likelihood
112- self .resample_figure_of_merit = resample_figure_of_merit or - xp .inf
111+
112+ self .resample_figure_of_merit = resample_figure_of_merit or - self ._xp .inf
113113 self .convert_to_chi_squared = convert_to_chi_squared
114114 self .store_history = store_history
115115
@@ -123,10 +123,20 @@ def __init__(
123123 if self .use_jax_vmap :
124124 self ._call = self ._vmap
125125
126+ if analysis ._use_jax :
127+
128+ import jax
129+
130+ if jax .default_backend () == "cpu" :
131+
132+ logger .info ("JAX using CPU backend, batch size set to 1 which will improve performance." )
133+
134+ batch_size = 1
135+
126136 self .batch_size = batch_size
127137 self .iterations_per_quick_update = iterations_per_quick_update
128138 self .quick_update_max_lh_parameters = None
129- self .quick_update_max_lh = - xp .inf
139+ self .quick_update_max_lh = - self . _xp .inf
130140 self .quick_update_count = 0
131141
132142 if self .paths is not None :
You can’t perform that action at this time.
0 commit comments