@@ -132,14 +132,19 @@ def parameter_string(self) -> str:
132132 """A human-readable string summarizing the prior's lower and upper limits."""
133133 return f"lower_limit = { self .lower_limit } , upper_limit = { self .upper_limit } "
134134
135- def value_for (self , unit : float ) -> float :
135+ def value_for (self , unit , xp = np ) :
136136 """
137137 Returns a physical value from an input unit value according to the limits of the uniform prior.
138138
139139 Parameters
140140 ----------
141141 unit
142142 A unit value between 0 and 1.
143+ xp
144+ Array-module to dispatch on (``numpy`` or ``jax.numpy``). Default ``numpy``.
145+ The NumPy path preserves the historical ``float(round(..., 14))`` snap
146+ (used as a hash key in ``model.priors``); the JAX path uses the
147+ closed-form ``lower + (upper - lower) * unit`` so the trace stays symbolic.
143148
144149 Returns
145150 -------
@@ -153,9 +158,11 @@ def value_for(self, unit: float) -> float:
153158
154159 physical_value = prior.value_for(unit=0.2)
155160 """
156- return float (
157- round (super ().value_for (unit ), 14 )
158- )
161+ if xp is np :
162+ return float (
163+ round (super ().value_for (unit ), 14 )
164+ )
165+ return self .lower_limit + (self .upper_limit - self .lower_limit ) * unit
159166
160167 def log_prior_from_value (self , value , xp = np ):
161168 """
@@ -166,7 +173,10 @@ def log_prior_from_value(self, value, xp=np):
166173
167174 For a UniformPrior this is always zero, provided the value is between the lower and upper limit.
168175 """
169- return 0.0
176+ if xp is np :
177+ return 0.0
178+ in_bounds = (value >= self .lower_limit ) & (value <= self .upper_limit )
179+ return xp .where (in_bounds , xp .zeros_like (value ), - xp .inf )
170180
171181 @property
172182 def limits (self ) -> Tuple [float , float ]:
0 commit comments