Skip to content

Commit 72df5f3

Browse files
committed
Infer logp and logcdf of abs of discrete variables
1 parent da3db75 commit 72df5f3

File tree

2 files changed

+138
-73
lines changed

2 files changed

+138
-73
lines changed

pymc/logprob/transforms.py

Lines changed: 75 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
import numpy as np
4141
import pytensor.tensor as pt
4242

43-
from pytensor import scan
43+
from pytensor import graph_replace, scan
4444
from pytensor.gradient import jacobian
4545
from pytensor.graph.basic import Apply, Variable
4646
from pytensor.graph.fg import FunctionGraph
@@ -163,6 +163,8 @@ def __str__(self):
163163
class MeasurableTransform(MeasurableElemwise):
164164
"""A placeholder used to specify a log-likelihood for a transformed measurable variable."""
165165

166+
__props__ = ("scalar_op", "inplace_pattern", "is_discrete")
167+
166168
valid_scalar_types = (
167169
Exp,
168170
Log,
@@ -187,16 +189,55 @@ class MeasurableTransform(MeasurableElemwise):
187189
transform_elemwise: Transform
188190
measurable_input_idx: int
189191

190-
def __init__(self, *args, transform: Transform, measurable_input_idx: int, **kwargs):
192+
def __init__(
193+
self, *args, transform: Transform, measurable_input_idx: int, is_discrete: bool, **kwargs
194+
):
191195
self.transform_elemwise = transform
192196
self.measurable_input_idx = measurable_input_idx
197+
self.is_discrete = is_discrete
193198
super().__init__(*args, **kwargs)
194199

195200

201+
def abs_logprob(op, value, x, **kwargs):
202+
"""Compute the log-CDF graph for an absolute value transformation.
203+
204+
For `Y = |X|`, we have `PDF_Y(y) = PDF_Y(-y) + PDF_Y(y)`.
205+
Except for discrete distributions where there's a special case `P(Y=0) = P(X=0)`.
206+
"""
207+
logprob_pos = _logprob_helper(x, value)
208+
logprob_neg = graph_replace(logprob_pos, {value: -value})
209+
if op.is_discrete:
210+
logprob = pt.switch(
211+
pt.eq(value, 0),
212+
logprob_pos,
213+
pt.logaddexp(logprob_pos, logprob_neg),
214+
)
215+
else:
216+
logprob = pt.logaddexp(logprob_pos, logprob_neg)
217+
logprob = pt.where(value < 0, -np.inf, logprob)
218+
return logprob
219+
220+
221+
def abs_logcdf(op, value, x, **kwargs):
222+
"""Compute the log-CDF graph for an absolute value transformation.
223+
224+
For `Y = |X|`, we have `CDF_Y(y) = P(|X| <= y) = P(-y <= X <= y) = CDF_X(y) - CDF_X(-y)`.
225+
"""
226+
logcdf_pos = _logcdf_helper(x, value)
227+
neg_value = -value - 1 if op.is_discrete else -value
228+
logcdf_neg = graph_replace(logcdf_pos, {value: neg_value})
229+
logcdf = logdiffexp(logcdf_pos, logcdf_neg)
230+
logcdf = pt.where(value < 0, -np.inf, logcdf)
231+
return logcdf
232+
233+
196234
@_logprob.register(MeasurableTransform)
197235
def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs):
198236
"""Compute the log-probability graph for a `MeasurabeTransform`."""
199237
# TODO: Could other rewrites affect the order of inputs?
238+
if isinstance(op.scalar_op, Abs):
239+
return abs_logprob(op, values[0], *inputs, **kwargs)
240+
200241
(value,) = values
201242
other_inputs = list(inputs)
202243
measurable_input = other_inputs.pop(op.measurable_input_idx)
@@ -207,6 +248,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
207248

208249
# Some transformations, like squaring may produce multiple backward values
209250
if isinstance(backward_value, tuple):
251+
if op.is_discrete:
252+
# Discrete variables tend to have the tricky x=0 case, get out if we don't have a custom implementation
253+
raise NotImplementedError(
254+
"Logprob of transformed discrete variables with non-injective transforms not implemented"
255+
)
210256
input_logprob = pt.logaddexp(
211257
*(
212258
_logprob_helper(measurable_input, backward_val, **kwargs)
@@ -225,8 +271,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
225271
ndim_supp = value.ndim - input_logprob.ndim
226272
jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))
227273

274+
# Discrete transformations do not need the jacobian adjustment
275+
logprob = input_logprob if op.is_discrete else input_logprob + jacobian
276+
228277
# The jacobian is used to ensure a value in the supported domain was provided
229-
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
278+
return pt.switch(pt.isnan(jacobian), -np.inf, logprob)
230279

231280

232281
MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
@@ -236,6 +285,10 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
236285
@_logcdf.register(MeasurableTransform)
237286
def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
238287
"""Compute the log-CDF graph for a `MeasurabeTransform`."""
288+
if isinstance(op.scalar_op, Abs):
289+
# Special case for absolute value transformation
290+
return abs_logcdf(op, value, *inputs, **kwargs)
291+
239292
other_inputs = list(inputs)
240293
measurable_input = other_inputs.pop(op.measurable_input_idx)
241294
backward_value = op.transform_elemwise.backward(value, *other_inputs)
@@ -245,10 +298,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
245298
if isinstance(backward_value, tuple):
246299
raise NotImplementedError
247300

248-
is_discrete = measurable_input.type.dtype.startswith("int")
249-
250301
logcdf = _logcdf_helper(measurable_input, backward_value)
251-
if is_discrete:
302+
if op.is_discrete:
252303
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
253304
else:
254305
logccdf = pt.log1mexp(logcdf)
@@ -275,9 +326,6 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
275326
# We don't know if this Op is monotonically increasing/decreasing
276327
raise NotImplementedError
277328

278-
if is_discrete:
279-
return logcdf
280-
281329
# The jacobian is used to ensure a value in the supported domain was provided
282330
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
283331
return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
@@ -286,13 +334,12 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
286334
@_icdf.register(MeasurableTransform)
287335
def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs):
288336
"""Compute the inverse CDF graph for a `MeasurabeTransform`."""
337+
if op.is_discrete:
338+
raise NotImplementedError("icdf of transformed discrete variables not implemented")
339+
289340
other_inputs = list(inputs)
290341
measurable_input = other_inputs.pop(op.measurable_input_idx)
291342

292-
# Do not apply rewrite to discrete variables
293-
if measurable_input.type.dtype.startswith("int"):
294-
raise NotImplementedError("icdf of transformed discrete variables not implemented")
295-
296343
if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
297344
pass
298345
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
@@ -323,7 +370,7 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
323370
# Fail if transformation is not injective
324371
# A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
325372
if isinstance(op.transform_elemwise.backward(icdf, *other_inputs), tuple):
326-
raise NotImplementedError
373+
raise NotImplementedError("icdf of non-injective transformations not implemented")
327374

328375
return icdf
329376

@@ -481,15 +528,22 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
481528
[measurable_input] = measurable_inputs
482529
[measurable_output] = node.outputs
483530

484-
# Do not apply rewrite to discrete variables except for their addition and negation
485-
if measurable_input.type.dtype.startswith("int"):
531+
# Do not apply rewrite to discrete variables except if:
532+
# 1. Operation retains a discrete output
533+
# 2. Operation doesn't create holes in the support
534+
# Reason:
535+
# 1. Due to a limitation in our IR we don't know the type of the MeasurableVariable
536+
# We don't want to make other rewrites think they are dealing with continuous variables when they are not
537+
# 2. We don't want to add cumbersome within-domain checks
538+
is_discrete = measurable_input.type.dtype.startswith("int")
539+
if is_discrete:
540+
if not measurable_output.type.dtype.startswith("int"):
541+
return None
486542
if not (
487-
find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
543+
isinstance(node.op.scalar_op, Add | Abs)
544+
or find_negated_var(measurable_output) is not None
488545
):
489546
return None
490-
# Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
491-
if not measurable_output.type.dtype.startswith("int"):
492-
return None
493547

494548
# Check that other inputs are not potentially measurable, in which case this rewrite
495549
# would be invalid
@@ -545,6 +599,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
545599
scalar_op=scalar_op,
546600
transform=transform,
547601
measurable_input_idx=measurable_input_idx,
602+
is_discrete=is_discrete,
548603
)
549604
transform_out = transform_op.make_node(*transform_inputs).default_output()
550605
return [transform_out]

tests/logprob/test_transforms.py

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from pytensor.graph.basic import equal_computations
4545

4646
from pymc.distributions.continuous import Cauchy, ChiSquared
47-
from pymc.distributions.discrete import Bernoulli
47+
from pymc.distributions.discrete import Bernoulli, DiscreteUniform
4848
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
4949
from pymc.logprob.transforms import (
5050
ArccoshTransform,
@@ -285,6 +285,27 @@ def test_loc_transform_rv(self, rv_size, loc_type, addition):
285285
sp.stats.norm(loc_test_val, 1).ppf(q_test_val),
286286
)
287287

288+
def test_shifted_discrete_rv_transform(self):
289+
p = 0.7
290+
rv = Bernoulli.dist(p=p) + 5
291+
vv = rv.type()
292+
293+
rv_logp_fn = pytensor.function([vv], logp(rv, vv))
294+
assert rv_logp_fn(4) == -np.inf
295+
np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
296+
np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
297+
assert rv_logp_fn(7) == -np.inf
298+
299+
rv_logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
300+
assert rv_logcdf_fn(4) == -np.inf
301+
np.testing.assert_allclose(rv_logcdf_fn(5), np.log(1 - p))
302+
np.testing.assert_allclose(rv_logcdf_fn(6), 0)
303+
assert rv_logcdf_fn(7) == 0
304+
305+
# icdf not supported yet
306+
with pytest.raises(NotImplementedError):
307+
icdf(rv, 0)
308+
288309
@pytest.mark.parametrize(
289310
"rv_size, scale_type, product",
290311
[
@@ -337,6 +358,23 @@ def test_negated_rv_transform(self):
337358
np.testing.assert_allclose(x_logcdf_fn(-1.5), sp.stats.halfnorm.logsf(1.5))
338359
np.testing.assert_allclose(x_icdf_fn(0.3), -sp.stats.halfnorm.ppf(1 - 0.3))
339360

361+
def test_negated_discrete_rv_transform(self):
362+
p = 0.7
363+
rv = -Bernoulli.dist(p=p, shape=(4,))
364+
vv = rv.type()
365+
366+
# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
367+
logp_fn = pytensor.function([vv], logp(rv, vv))
368+
np.testing.assert_allclose(
369+
logp_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), np.log(1 - p), -np.inf]
370+
)
371+
372+
logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
373+
np.testing.assert_allclose(logcdf_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), 0, 0])
374+
375+
with pytest.raises(NotImplementedError):
376+
icdf(rv, [-2, -1, 0, 1])
377+
340378
def test_subtracted_rv_transform(self):
341379
# Choose base RV that is asymmetric around zero
342380
x_rv = 5.0 - pt.random.normal(1.0)
@@ -501,21 +539,33 @@ def test_negative_value_frac_power_transform_logp(self, power):
501539
assert np.isneginf(x_logp_fn(-2.5))
502540

503541

504-
@pytest.mark.parametrize("test_val", (2.5, -2.5))
505-
def test_absolute_rv_transform(test_val):
506-
x_rv = pt.abs(pt.random.normal())
507-
y_rv = pt.random.halfnormal()
542+
@pytest.mark.parametrize("continuous", (True, False))
543+
def test_absolute_rv_transform(continuous):
544+
if continuous:
545+
x_rv = pt.abs(pt.random.normal(size=(5,)))
546+
ref_rv = pt.random.halfnormal(size=(5,))
547+
else:
548+
x_rv = pt.abs(DiscreteUniform.dist(-4, 4, size=(5,)))
549+
# |x_rv| = DiscreteUniform(0,4) with P(X=0) halved relative to other values
550+
# We can use a Categorical to representh this
551+
ref_rv = pt.random.categorical(
552+
p=np.array([1, 2, 2, 2, 2]) / 9,
553+
size=(5,),
554+
)
508555

509-
x_vv = x_rv.clone()
510-
y_vv = y_rv.clone()
511-
x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
512-
with pytest.raises(NotImplementedError):
513-
logcdf(x_rv, x_vv)
556+
x_vv = x_rv.type()
557+
ref_vv = ref_rv.type()
558+
# Not working with logs because it's easier to debug for discrete case
559+
x_pdf_fn = pytensor.function([x_vv], pt.exp(logp(x_rv, x_vv)))
560+
x_cdf_fn = pytensor.function([x_vv], pt.exp(logcdf(x_rv, x_vv)))
514561
with pytest.raises(NotImplementedError):
515562
icdf(x_rv, x_vv)
516563

517-
y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
518-
np.testing.assert_allclose(x_logp_fn(test_val), y_logp_fn(test_val))
564+
ref_pdf_fn = pytensor.function([ref_vv], pt.exp(logp(ref_rv, ref_vv)))
565+
ref_cdf_fn = pytensor.function([ref_vv], pt.exp(logcdf(ref_rv, ref_vv)))
566+
test_val = np.array([-2.5, -2.0, 0, 2.0, 2.5], dtype=x_vv.dtype)
567+
np.testing.assert_allclose(x_pdf_fn(test_val), ref_pdf_fn(test_val))
568+
np.testing.assert_allclose(x_cdf_fn(test_val), ref_cdf_fn(test_val))
519569

520570

521571
@pytest.mark.parametrize(
@@ -690,51 +740,11 @@ def test_not_implemented_discrete_rv_transform():
690740
with pytest.raises(RuntimeError, match="could not be derived"):
691741
conditional_logp({y_rv: y_rv.clone()})
692742

693-
y_rv = 5 * pt.random.poisson(1)
743+
y_rv = 5.5 * pt.random.poisson(1)
694744
with pytest.raises(RuntimeError, match="could not be derived"):
695745
conditional_logp({y_rv: y_rv.clone()})
696746

697747

698-
def test_negated_discrete_rv_transform():
699-
p = 0.7
700-
rv = -Bernoulli.dist(p=p, shape=(4,))
701-
vv = rv.type()
702-
703-
# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
704-
logp_fn = pytensor.function([vv], logp(rv, vv))
705-
np.testing.assert_allclose(
706-
logp_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), np.log(1 - p), -np.inf]
707-
)
708-
709-
logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
710-
np.testing.assert_allclose(logcdf_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), 0, 0])
711-
712-
with pytest.raises(NotImplementedError):
713-
icdf(rv, [-2, -1, 0, 1])
714-
715-
716-
def test_shifted_discrete_rv_transform():
717-
p = 0.7
718-
rv = Bernoulli.dist(p=p) + 5
719-
vv = rv.type()
720-
721-
rv_logp_fn = pytensor.function([vv], logp(rv, vv))
722-
assert rv_logp_fn(4) == -np.inf
723-
np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
724-
np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
725-
assert rv_logp_fn(7) == -np.inf
726-
727-
rv_logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
728-
assert rv_logcdf_fn(4) == -np.inf
729-
np.testing.assert_allclose(rv_logcdf_fn(5), np.log(1 - p))
730-
np.testing.assert_allclose(rv_logcdf_fn(6), 0)
731-
assert rv_logcdf_fn(7) == 0
732-
733-
# icdf not supported yet
734-
with pytest.raises(NotImplementedError):
735-
icdf(rv, 0)
736-
737-
738748
@pytest.mark.xfail(reason="Check not implemented yet")
739749
def test_invalid_broadcasted_transform_rv_fails():
740750
loc = pt.vector("loc")

0 commit comments

Comments
 (0)