-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Add _logccdf dispatcher for numerically stable log survival function in censored distributions
#7996
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add _logccdf dispatcher for numerically stable log survival function in censored distributions
#7996
Changes from all commits
87eba5d
81da946
15e5f64
15806c0
063af42
19b9979
20322a1
93734c2
2df5274
63c9327
628e6d5
36b8672
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,8 @@ | |
| from collections.abc import Sequence | ||
| from functools import singledispatch | ||
|
|
||
| import pytensor.tensor as pt | ||
|
|
||
| from pytensor.graph import Apply, Op, Variable | ||
| from pytensor.graph.utils import MetaType | ||
| from pytensor.tensor import TensorVariable | ||
|
|
@@ -108,6 +110,45 @@ def _logcdf_helper(rv, value, **kwargs): | |
| return logcdf | ||
|
|
||
|
|
||
| @singledispatch | ||
| def _logccdf( | ||
| op: Op, | ||
| value: TensorVariable, | ||
| *inputs: TensorVariable, | ||
| **kwargs, | ||
| ): | ||
| """Create a graph for the log complementary CDF (log survival function) of a ``RandomVariable``. | ||
|
|
||
| This function dispatches on the type of ``op``, which should be a subclass | ||
| of ``RandomVariable``. If you want to implement new logccdf graphs | ||
| for a ``RandomVariable``, register a new function on this dispatcher. | ||
|
|
||
| The log complementary CDF is defined as log(1 - CDF(x)), also known as the | ||
| log survival function. For distributions with a numerically stable implementation, | ||
| this should be used instead of computing log(1 - exp(logcdf)). | ||
| """ | ||
| raise NotImplementedError(f"LogCCDF method not implemented for {op}") | ||
|
|
||
|
|
||
| def _logccdf_helper(rv, value, **kwargs): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make this method do the try except fallback to log1mexp? So users/devs don't need to do it all the time
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! The helper now tries → 15e5f64 |
||
| """Helper that calls `_logccdf` dispatcher with fallback to log1mexp(logcdf). | ||
|
|
||
| If a numerically stable `_logccdf` implementation is registered for the | ||
| distribution, it will be used. Otherwise, falls back to computing | ||
| `log(1 - exp(logcdf))` which may be numerically unstable in the tails. | ||
| """ | ||
| try: | ||
| logccdf = _logccdf(rv.owner.op, value, *rv.owner.inputs, name=rv.name, **kwargs) | ||
| except NotImplementedError: | ||
| logcdf = _logcdf_helper(rv, value, **kwargs) | ||
| logccdf = pt.log1mexp(logcdf) | ||
|
|
||
| if rv.name: | ||
| logccdf.name = f"{rv.name}_logccdf" | ||
|
|
||
| return logccdf | ||
|
|
||
|
|
||
| @singledispatch | ||
| def _icdf( | ||
| op: Op, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.