Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9baf9ca
Add Coords and StrongCoords typing aliases and standardize model/arvi…
aman-coder03 Dec 2, 2025
7fc1312
Fix ruff formatting, imports, and dev requirements
aman-coder03 Dec 3, 2025
9e60df8
Fix circular import by importing modelcontext from pymc.model.core
aman-coder03 Dec 3, 2025
46ed9ca
Fix circular import by lazily importing modelcontext in shape_from_dims
aman-coder03 Dec 3, 2025
8f5040d
Fix circular import by using only lazy modelcontext imports
aman-coder03 Dec 3, 2025
02f9053
Fix Model circular import using TYPE_CHECKING and lazy import
aman-coder03 Dec 3, 2025
c6f3f34
Fix lazy modelcontext import flagged by ruff
aman-coder03 Dec 3, 2025
5b4b1f6
Fix missing modelcontext import flagged by ruff
aman-coder03 Dec 3, 2025
ba05761
Fix circular import of Model in printing.py
aman-coder03 Dec 3, 2025
21b04fd
Move coords typing to pymc.typing and fix circular imports
aman-coder03 Dec 4, 2025
b5ca208
Fix ruff UP038 isinstance union style
aman-coder03 Dec 4, 2025
0f9cad7
Fix printing Model NameError and move Coord typing to pymc.typing
aman-coder03 Dec 4, 2025
26d756d
Move coords typing to pymc.typing and fix printing imports
aman-coder03 Dec 4, 2025
447f990
Remove implicit modelcontext fallback from shape_from_dims
aman-coder03 Dec 4, 2025
627e403
Fix shape_from_dims typing and remove circular import
aman-coder03 Dec 4, 2025
784de8b
Fix typing and import order
aman-coder03 Dec 4, 2025
dec6f7c
Removing comments from typing.py
aman-coder03 Dec 6, 2025
087ebe5
Remove deprecated pymc.typing module after moving aliases to pymc.util
aman-coder03 Dec 6, 2025
0016b6b
Fixes
aman-coder03 Dec 6, 2025
c930d5b
Apply suggestions from code review
aman-coder03 Dec 9, 2025
c651cb2
Update util.py
aman-coder03 Dec 9, 2025
fb07885
Refactor shape_from_dims usage to Model method
aman-coder03 Dec 9, 2025
7c22c16
implementing the suggested changes
aman-coder03 Dec 10, 2025
e823b9f
Use shared type aliases from pymc.util in shape_utils
aman-coder03 Dec 16, 2025
4f924a1
Fix pre-commit formatting issues
aman-coder03 Dec 16, 2025
14f77d6
Fix docstring
aman-coder03 Dec 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.
"""PyMC-ArviZ conversion code."""

from __future__ import annotations

import logging
import warnings

from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
TypeAlias,
cast,
)

Expand All @@ -38,13 +39,16 @@

import pymc

from pymc.model import Model, modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames
from pymc.model import modelcontext
from pymc.util import StrongCoords

if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
from pymc.model import Model

from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames

___all__ = [""]

Expand All @@ -56,6 +60,7 @@

# random variable object ...
Var = Any
DimsDict: TypeAlias = Mapping[str, Sequence[str]]


def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
Expand Down Expand Up @@ -85,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)


def find_observations(model: "Model") -> dict[str, Var]:
def find_observations(model: Model) -> dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
Expand All @@ -102,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]:
return observations


def find_constants(model: "Model") -> dict[str, Var]:
def find_constants(model: Model) -> dict[str, Var]:
"""If there are constants available, return them as a dictionary."""
model_vars = model.basic_RVs + model.deterministics + model.potentials
value_vars = set(model.rvs_to_values.values())
Expand All @@ -123,7 +128,9 @@ def find_constants(model: "Model") -> dict[str, Var]:
return constant_data


def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]:
def coords_and_dims_for_inferencedata(
model: Model,
) -> tuple[StrongCoords, DimsDict]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
Expand Down Expand Up @@ -265,7 +272,7 @@ def __init__(

self.observations = find_observations(self.model)

def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]:
"""Split MultiTrace object into posterior and warmup.

Returns
Expand Down Expand Up @@ -491,7 +498,7 @@ def to_inference_data(self):


def to_inference_data(
trace: Optional["MultiTrace"] = None,
trace: MultiTrace | None = None,
*,
prior: Mapping[str, Any] | None = None,
posterior_predictive: Mapping[str, Any] | None = None,
Expand All @@ -500,7 +507,7 @@ def to_inference_data(
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
model: Optional["Model"] = None,
model: Model | None = None,
save_warmup: bool | None = None,
include_transformed: bool = False,
) -> InferenceData:
Expand Down Expand Up @@ -568,8 +575,8 @@ def to_inference_data(
### perhaps we should have an inplace argument?
def predictions_to_inference_data(
predictions,
posterior_trace: Optional["MultiTrace"] = None,
model: Optional["Model"] = None,
posterior_trace: MultiTrace | None = None,
model: Model | None = None,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
Expand Down
3 changes: 1 addition & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
convert_size,
find_size,
rv_size_is_none,
shape_from_dims,
)
from pymc.logprob.abstract import MeasurableOp, _icdf, _logcdf, _logprob
from pymc.logprob.basic import logp
Expand Down Expand Up @@ -522,7 +521,7 @@ def __new__(
# finally, observed, to determine the shape of the variable.
if kwargs.get("size") is None and kwargs.get("shape") is None:
if dims is not None:
kwargs["shape"] = shape_from_dims(dims, model)
kwargs["shape"] = model.symbolic_shape_from_dims(dims)
elif observed is not None:
kwargs["shape"] = tuple(observed.shape)

Expand Down
61 changes: 17 additions & 44 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

"""Common shape operations to broadcast samples from probability distributions for stochastic nodes in PyMC."""

from __future__ import annotations

import warnings

from collections.abc import Sequence
from functools import singledispatch
from types import EllipsisType
from typing import Any, TypeAlias, cast
from typing import Any, cast

import numpy as np

Expand All @@ -33,18 +35,25 @@
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorVariable

from pymc.model import modelcontext
from pymc.pytensorf import convert_observed_data
from pymc.exceptions import ShapeError
from pymc.pytensorf import PotentialShapeType, convert_observed_data
from pymc.util import (
Dims,
DimsWithEllipsis,
Shape,
Size,
StrongDims,
StrongDimsWithEllipsis,
StrongShape,
StrongSize,
)

__all__ = [
"change_dist_size",
"rv_size_is_none",
"to_tuple",
]

from pymc.exceptions import ShapeError
from pymc.pytensorf import PotentialShapeType


def to_tuple(shape):
"""Convert ints, arrays, and Nones to tuples.
Expand Down Expand Up @@ -85,19 +94,6 @@ def _check_shape_type(shape):
return tuple(out)


# User-provided can be lazily specified as scalars
Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
Dims: TypeAlias = str | Sequence[str | None]
DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]

# After conversion to vectors
StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...]
StrongDims: TypeAlias = Sequence[str]
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]


def convert_dims(dims: Dims | None) -> StrongDims | None:
"""Process a user-provided dims variable into None or a valid dims tuple."""
if dims is None:
Expand Down Expand Up @@ -164,31 +160,6 @@ def convert_size(size: Size) -> StrongSize | None:
)


def shape_from_dims(dims: StrongDims, model) -> StrongShape:
"""Determine shape from a `dims` tuple.

Parameters
----------
dims : array-like
A vector of dimension names or None.
model : pm.Model
The current model on stack.

Returns
-------
dims : tuple of (str or None)
Names or None for all RV dimensions.
"""
# Dims must be known already
unknowndim_dims = set(dims) - set(model.dim_lengths)
if unknowndim_dims:
raise KeyError(
f"Dimensions {unknowndim_dims} are unknown to the model and cannot be used to specify a `shape`."
)

return tuple(model.dim_lengths[dname] for dname in dims)


def find_size(
shape: StrongShape | None,
size: StrongSize | None,
Expand Down Expand Up @@ -403,6 +374,8 @@ def get_support_shape(
assert isinstance(dims, tuple)
if len(dims) < ndim_supp:
raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}")
from pymc.model.core import modelcontext

model = modelcontext(None)
inferred_support_shape = [
model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0)
Expand Down
13 changes: 8 additions & 5 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
)
from pymc.util import (
UNSET,
Coords,
CoordValue,
StrongCoords,
WithMemoization,
_UnsetType,
get_transformed_name,
Expand Down Expand Up @@ -453,7 +456,7 @@ def _validate_name(name):
def __init__(
self,
name="",
coords=None,
coords: Coords | None = None,
check_bounds=True,
*,
model: _UnsetType | None | Model = UNSET,
Expand Down Expand Up @@ -488,7 +491,7 @@ def __init__(
self.deterministics = treelist()
self.potentials = treelist()
self.data_vars = treelist()
self._coords = {}
self._coords: StrongCoords = {}
self._dim_lengths = {}
self.add_coords(coords)

Expand Down Expand Up @@ -907,7 +910,7 @@ def unobserved_RVs(self):
return self.free_RVs + self.deterministics

@property
def coords(self) -> dict[str, tuple | None]:
def coords(self) -> StrongCoords:
"""Coordinate values for model dimensions."""
return self._coords

Expand All @@ -919,7 +922,7 @@ def dim_lengths(self) -> dict[str, TensorVariable]:
"""
return self._dim_lengths

def shape_from_dims(self, dims):
def symbolic_shape_from_dims(self, dims):
shape = []
if len(set(dims)) != len(dims):
raise ValueError("Can not contain the same dimension name twice.")
Expand All @@ -937,7 +940,7 @@ def shape_from_dims(self, dims):
def add_coord(
self,
name: str,
values: Sequence | np.ndarray | None = None,
values: CoordValue = None,
*,
length: int | Variable | None = None,
):
Expand Down
9 changes: 8 additions & 1 deletion pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.


from __future__ import annotations

import re

from functools import partial
from typing import TYPE_CHECKING

from pytensor.compile import SharedVariable
from pytensor.graph.basic import Constant
Expand All @@ -26,7 +29,9 @@
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT

from pymc.model import Model
if TYPE_CHECKING:
from pymc.model import Model


__all__ = [
"str_for_dist",
Expand Down Expand Up @@ -302,6 +307,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
# register our custom pretty printer in ipython shells
import IPython

from pymc.model.core import Model

IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty)
IPython.lib.pretty.for_type(Model, _default_repr_pretty)
except (ModuleNotFoundError, AttributeError):
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataClassState:
def equal_dataclass_values(v1, v2):
if v1.__class__ != v2.__class__:
return False
if isinstance(v1, (list, tuple)): # noqa: UP038
if isinstance(v1, list | tuple):
return len(v1) == len(v2) and all(
equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
)
Expand Down
44 changes: 41 additions & 3 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,59 @@
import re

from collections import namedtuple
from collections.abc import Sequence
from collections.abc import Hashable, Mapping, Sequence
from copy import deepcopy
from typing import cast
from types import EllipsisType
from typing import TypeAlias, cast

import arviz
import cloudpickle
import numpy as np
import xarray

from cachetools import LRUCache, cachedmethod
from pytensor import Variable
from pytensor.compile import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.tensor.variable import TensorVariable

from pymc.exceptions import BlockModelAccessError

CoordValue: TypeAlias = Sequence[Hashable] | np.ndarray | None
"""User-provided values for a single coordinate dimension."""

Coords: TypeAlias = Mapping[str, CoordValue]
"""Mapping from dimension name to its coordinate values."""

StrongCoordValue: TypeAlias = tuple[Hashable, ...] | None
"""Normalized coordinate values stored internally."""

StrongCoords: TypeAlias = Mapping[str, StrongCoordValue]
"""Mapping from dimension name to normalized coordinate values."""

StrongDims: TypeAlias = tuple[str, ...]
"""Tuple of dimension names after validation."""

StrongShape: TypeAlias = tuple[int, ...]
"""Fully-resolved numeric shape used internally."""

Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
"""User-provided shape specification before normalization."""

Dims: TypeAlias = str | Sequence[str | None]
"""User-provided dimension names before normalization."""

DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
"""User-provided dimension names that may include ellipsis."""

Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]
"""User-provided size specification before normalization."""

StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
"""Normalized dimension names that may include ellipsis."""

StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]
"""Normalized symbolic size used internally."""


class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""
Expand Down