Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def logcdf(value, mu, sigma):
msg="sigma > 0",
)

def logccdf(value, mu, sigma):
    """Return the log complementary CDF (log survival function) at ``value``.

    The expression is built by ``normal_lccdf(mu, sigma, value)`` and wrapped
    in ``check_parameters`` together with the ``sigma > 0`` condition.
    """
    log_survival = normal_lccdf(mu, sigma, value)
    return check_parameters(
        log_survival,
        sigma > 0,
        msg="sigma > 0",
    )

def icdf(value, mu, sigma):
res = mu + sigma * -np.sqrt(2.0) * pt.erfcinv(2 * value)
res = check_icdf_value(res, value)
Expand Down
13 changes: 12 additions & 1 deletion pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
rv_size_is_none,
shape_from_dims,
)
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableOp, _icdf, _logccdf, _logcdf, _logprob
from pymc.logprob.basic import logp
from pymc.logprob.rewriting import logprob_rewrites_db
from pymc.printing import str_for_dist
Expand Down Expand Up @@ -150,6 +150,17 @@ def logcdf(op, value, *dist_params, **kwargs):
dist_params = [dist_params[i] for i in params_idxs]
return class_logcdf(value, *dist_params)

# Register the class-level ``logccdf`` (if the class defines one) on the
# ``_logccdf`` dispatcher for this distribution's ``rv_type``.
class_logccdf = clsdict.get("logccdf")
if class_logccdf:

    @_logccdf.register(rv_type)
    def logccdf(op, value, *dist_params, **kwargs):
        """Forward a ``_logccdf`` dispatch to the class-level ``logccdf``."""
        if isinstance(op, RandomVariable):
            # The two leading inputs (rng, size) are dropped; the rest are
            # passed on as the distribution parameters.
            _rng, _size, *params = dist_params
        elif params_idxs:
            # Pick the parameter inputs by index.
            params = [dist_params[idx] for idx in params_idxs]
        else:
            params = dist_params
        return class_logccdf(value, *params)

class_icdf = clsdict.get("icdf")
if class_icdf:

Expand Down
25 changes: 22 additions & 3 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.logprob.basic import icdf, logccdf, logcdf, logp
from pymc.math import logdiffexp
from pymc.pytensorf import collect_default_updates
from pymc.util import check_dist_not_registered
Expand Down Expand Up @@ -211,6 +211,23 @@ def _create_logcdf_exprs(
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf

@staticmethod
def _create_lower_logccdf_expr(
    base_rv: TensorVariable,
    value: TensorVariable,
    lower: TensorVariable,
) -> TensorVariable:
    """Build the logccdf graph of ``base_rv`` evaluated at the lower bound.

    ``value`` only serves as a broadcasting template: the bound is expanded
    with ``pt.full_like(value, ...)`` before being handed to ``logccdf``.
    Going through ``logccdf`` is numerically more stable than computing
    log(1 - exp(logcdf)) for distributions that register a logccdf method.
    """
    is_discrete = base_rv.type.dtype.startswith("int")
    # A left-truncated discrete RV must keep the mass sitting exactly on the
    # lower bound, so the expression is evaluated one step below it.
    bound = lower - 1 if is_discrete else lower
    bound_like_value = pt.full_like(value, bound, dtype=config.floatX)
    return logccdf(base_rv, bound_like_value, warn_rvs=False)

def update(self, node: Apply):
"""Return the update mapping for the internal RNGs.

Expand Down Expand Up @@ -401,7 +418,8 @@ def truncated_logprob(op, values, *inputs, **kwargs):
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
# Use numerically stable logccdf instead of log(1 - exp(logcdf))
lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower)
elif is_upper_bounded:
lognorm = upper_logcdf

Expand Down Expand Up @@ -438,7 +456,8 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
if is_lower_bounded and is_upper_bounded:
lognorm = logdiffexp(upper_logcdf, lower_logcdf)
elif is_lower_bounded:
lognorm = pt.log1mexp(lower_logcdf)
# Use numerically stable logccdf instead of log(1 - exp(logcdf))
lognorm = TruncatedRV._create_lower_logccdf_expr(base_rv, value, lower)
elif is_upper_bounded:
lognorm = upper_logcdf

Expand Down
2 changes: 2 additions & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from pymc.logprob.basic import (
conditional_logp,
icdf,
logccdf,
logcdf,
logp,
transformed_conditional_logp,
Expand All @@ -59,6 +60,7 @@

__all__ = (
"icdf",
"logccdf",
"logcdf",
"logp",
)
41 changes: 41 additions & 0 deletions pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from collections.abc import Sequence
from functools import singledispatch

import pytensor.tensor as pt

from pytensor.graph import Apply, Op, Variable
from pytensor.graph.utils import MetaType
from pytensor.tensor import TensorVariable
Expand Down Expand Up @@ -108,6 +110,45 @@ def _logcdf_helper(rv, value, **kwargs):
return logcdf


@singledispatch
def _logccdf(
    op: Op,
    value: TensorVariable,
    *inputs: TensorVariable,
    **kwargs,
):
    """Create a graph for the log complementary CDF of a ``RandomVariable``.

    The log complementary CDF, log(1 - CDF(x)), is also known as the log
    survival function. Where a numerically stable implementation exists it
    should be preferred over computing log(1 - exp(logcdf)).

    This is a single-dispatch function keyed on the type of ``op``, which
    should be a subclass of ``RandomVariable``. To provide a logccdf graph
    for a new ``RandomVariable``, register an implementation on this
    dispatcher. This default implementation always raises
    ``NotImplementedError``.
    """
    message = f"LogCCDF method not implemented for {op}"
    raise NotImplementedError(message)


def _logccdf_helper(rv, value, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this helper do the try/except fallback to log1mexp itself, so that users and devs don't need to repeat it at every call site?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! The helper now tries _logccdf first and automatically falls back to log1mexp(logcdf) if not implemented. Callers no longer need to handle the exception.

15e5f64

"""Helper that calls `_logccdf` dispatcher with fallback to log1mexp(logcdf).

If a numerically stable `_logccdf` implementation is registered for the
distribution, it will be used. Otherwise, falls back to computing
`log(1 - exp(logcdf))` which may be numerically unstable in the tails.
"""
try:
logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs)
except NotImplementedError:
logcdf = _logcdf_helper(rv, value, **kwargs)
logccdf = pt.log1mexp(logcdf)

if rv.name:
logccdf.name = f"{rv.name}_logccdf"

return logccdf


@singledispatch
def _icdf(
op: Op,
Expand Down
65 changes: 65 additions & 0 deletions pymc/logprob/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from pymc.logprob.abstract import (
MeasurableOp,
_icdf_helper,
_logccdf_helper,
_logcdf_helper,
_logprob,
_logprob_helper,
Expand Down Expand Up @@ -302,6 +303,70 @@ def normal_logcdf(value, mu, sigma):
return expr


def logccdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
    """Create a graph for the log complementary CDF (log survival function) of a random variable.

    The log complementary CDF is defined as log(1 - CDF(x)), also known as the
    log survival function. For distributions with a numerically stable implementation,
    this is more accurate than computing log(1 - exp(logcdf)).

    Parameters
    ----------
    rv : TensorVariable
    value : tensor_like
        Should be the same type (shape and dtype) as the rv.
    warn_rvs : bool, default True
        Warn if RVs were found in the logccdf graph.
        This can happen when a variable has other random variables as inputs.
        In that case, those random variables should be replaced by their respective values.

    Returns
    -------
    logccdf : TensorVariable

    Raises
    ------
    RuntimeError
        If the logccdf cannot be derived.

    Examples
    --------
    Create a compiled function that evaluates the logccdf of a variable

    .. code-block:: python

        import pymc as pm
        import pytensor.tensor as pt

        mu = pt.scalar("mu")
        rv = pm.Normal.dist(mu, 1.0)

        value = pt.scalar("value")
        rv_logccdf = pm.logccdf(rv, value)

        # Use .eval() for debugging
        print(rv_logccdf.eval({value: 0.9, mu: 0.0}))  # approx. -1.692493

        # Compile a function for repeated evaluations
        rv_logccdf_fn = pm.compile_pymc([value, mu], rv_logccdf)
        print(rv_logccdf_fn(value=0.9, mu=0.0))  # approx. -1.692493

    """
    # Cast the value to the dtype of the random variable before building the graph.
    value = pt.as_tensor_variable(value, dtype=rv.dtype)
    try:
        return _logccdf_helper(rv, value, **kwargs)
    except NotImplementedError:
        # The helper could not produce a graph for `rv` as given.
        # Try to rewrite rv
        fgraph = construct_ir_fgraph({rv: value})
        [ir_valued_rv] = fgraph.outputs
        [ir_rv, ir_value] = ir_valued_rv.owner.inputs
        expr = _logccdf_helper(ir_rv, ir_value, **kwargs)
        [expr] = cleanup_ir([expr])
        # The RV warning is only issued on this rewritten path.
        if warn_rvs:
            _warn_rvs_in_inferred_graph([expr])
        return expr


def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=True, **kwargs) -> TensorVariable:
"""Create a graph for the inverse CDF of a random variable.

Expand Down
3 changes: 2 additions & 1 deletion pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pymc.logprob.abstract import (
MeasurableElemwise,
_logccdf_helper,
_logcdf_helper,
_logprob,
_logprob_helper,
Expand Down Expand Up @@ -95,7 +96,7 @@ def comparison_logprob(op, values, base_rv, operand, **kwargs):
base_rv_op = base_rv.owner.op

logcdf = _logcdf_helper(base_rv, operand, **kwargs)
logccdf = pt.log1mexp(logcdf)
logccdf = _logccdf_helper(base_rv, operand, **kwargs)

condn_exp = pt.eq(value, np.array(True))

Expand Down
6 changes: 4 additions & 2 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import MeasurableElemwise, _logcdf, _logprob
from pymc.logprob.abstract import MeasurableElemwise, _logccdf_helper, _logcdf, _logprob
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, filter_measurable_variables

Expand Down Expand Up @@ -119,7 +119,9 @@ def clip_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
if not (isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))):
is_upper_bounded = True

logccdf = pt.log1mexp(logcdf)
# Use numerically stable logccdf (falls back to log1mexp if not available)
logccdf = _logccdf_helper(base_rv, value, **kwargs)

# For right clipped discrete RVs, we need to add an extra term
# corresponding to the pmf at the upper bound
if base_rv.dtype.startswith("int"):
Expand Down
5 changes: 4 additions & 1 deletion pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
MeasurableOp,
_icdf,
_icdf_helper,
_logccdf_helper,
_logcdf,
_logcdf_helper,
_logprob,
Expand Down Expand Up @@ -248,9 +249,11 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg

logcdf = _logcdf_helper(measurable_input, backward_value)
if is_discrete:
# For discrete distributions, use the logcdf at the previous value
logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
else:
logccdf = pt.log1mexp(logcdf)
# Use numerically stable logccdf (falls back to log1mexp if not available)
logccdf = _logccdf_helper(measurable_input, backward_value)

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
Expand Down
67 changes: 67 additions & 0 deletions tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,73 @@ def rv_op(cls, size=None, rng=None):
resized_rv = change_dist_size(rv, new_size=5, expand=True)
assert resized_rv.type.shape == (5,)

def test_logccdf_with_extended_signature(self):
    """Test logccdf registration for SymbolicRandomVariable with extended_signature.

    What: Tests that a custom Distribution subclass using SymbolicRandomVariable
    with an extended_signature can define a logccdf method that gets properly
    registered and dispatched.

    Why: The DistributionMeta metaclass has two code paths for registering
    distribution methods like logp, logcdf, logccdf:
    1. For standard RandomVariable ops: unpack (rng, size, *params)
    2. For SymbolicRandomVariable with extended_signature: use params_idxs

    This test specifically exercises path #2 (the params_idxs branch) to ensure
    logccdf works for custom distributions that wrap other distributions with
    additional graph structure.

    How:
    1. Creates a custom Distribution (TestDistWithLogccdf) that:
       - Uses a SymbolicRandomVariable with extended_signature
       - Wraps a Normal distribution internally
       - Defines a logccdf method using normal_lccdf
    2. Creates an instance with mu=0, sigma=1
    3. Evaluates pm.logccdf at value=0.5
    4. Compares against scipy.stats.norm.logsf reference

    The extended_signature "[rng],[size],(),()->[rng],()" means:
    - Inputs: rng, size, and two scalar params (mu, sigma)
    - Outputs: next_rng and scalar draws
    """
    from pymc.distributions.dist_math import normal_lccdf
    from pymc.distributions.distribution import Distribution

    class TestDistWithLogccdf(Distribution):
        # Create a SymbolicRandomVariable type with extended_signature
        rv_type = type(
            "TestRVWithLogccdf",
            (SymbolicRandomVariable,),
            {"extended_signature": "[rng],[size],(),()->[rng],()"},
        )

        @classmethod
        def dist(cls, mu, sigma, **kwargs):
            mu = pt.as_tensor(mu)
            sigma = pt.as_tensor(sigma)
            return super().dist([mu, sigma], **kwargs)

        @classmethod
        def rv_op(cls, mu, sigma, size=None, rng=None):
            rng = normalize_rng_param(rng)
            size = normalize_size_param(size)
            # Internally uses Normal, but wrapped in SymbolicRandomVariable
            next_rng, draws = Normal.dist(mu, sigma, size=size, rng=rng).owner.outputs
            return cls.rv_type(
                inputs=[rng, size, mu, sigma],
                outputs=[next_rng, draws],
                ndim_supp=0,
            )(rng, size, mu, sigma)

        # This logccdf will be registered via params_idxs path
        def logccdf(value, mu, sigma):
            return normal_lccdf(mu, sigma, value)

    rv = TestDistWithLogccdf.dist(0, 1)
    result = pm.logccdf(rv, 0.5).eval()
    # Standard normal survival function at 0.5 is ~0.3085, so its log is ~-1.176.
    expected = st.norm(0, 1).logsf(0.5)  # ≈ -1.176
    npt.assert_allclose(result, expected)


def test_distribution_op_registered():
"""Test that returned Ops are registered as virtual subclasses of the respective PyMC distributions."""
Expand Down
Loading