Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions pytensor/link/jax/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,18 @@
"""


@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_AdvancedSubtensor1(op, node, **kwargs):
    """Build the JAX implementation of ``AdvancedSubtensor1``.

    ``op``, ``node`` and ``kwargs`` are accepted for the dispatch signature
    but are not consulted here.

    Returns
    -------
    callable
        ``advanced_subtensor1(x, ilist)``, which indexes ``x`` with the single
        index operand ``ilist``.
    """

    def advanced_subtensor1(x, ilist):
        # The index operand is applied directly through ``__getitem__``.
        selected = x[ilist]
        return selected

    return advanced_subtensor1


@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
Expand All @@ -47,10 +54,24 @@ def subtensor(x, *ilists):
return subtensor


@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
    """Build the JAX implementation of ``AdvancedIncSubtensor1``.

    Only ``AdvancedIncSubtensor1`` is registered on this function: the
    returned callable takes exactly one index operand, whereas
    ``IncSubtensor`` is dispatched separately by ``jax_funcify_IncSubtensor``.

    Parameters
    ----------
    op
        The Op being converted. Its ``set_instead_of_inc`` attribute (treated
        as ``False`` when absent) selects between overwriting and
        accumulating.
    node
        The Apply node being converted (unused).
    **kwargs
        Extra dispatch arguments (unused).

    Returns
    -------
    callable
        ``jax_fn(x, y, ilist)`` returning ``x.at[ilist].set(y)`` when
        ``set_instead_of_inc`` is true, otherwise ``x.at[ilist].add(y)``.
    """
    if getattr(op, "set_instead_of_inc", False):

        def jax_fn(x, y, ilist):
            # Overwrite the selected entries with ``y``.
            return x.at[ilist].set(y)

    else:

        def jax_fn(x, y, ilist):
            # Accumulate ``y`` into the selected entries.
            return x.at[ilist].add(y)

    return jax_fn


@jax_funcify.register(IncSubtensor)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list

if getattr(op, "set_instead_of_inc", False):

Expand All @@ -77,6 +98,8 @@ def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):

@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list

if getattr(op, "set_instead_of_inc", False):

def jax_fn(x, indices, y):
Expand All @@ -87,8 +110,11 @@ def jax_fn(x, indices, y):
def jax_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
return jax_fn(x, indices, y)

return advancedincsubtensor

Expand Down
49 changes: 26 additions & 23 deletions pytensor/link/numba/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
Expand All @@ -29,7 +28,7 @@
IncSubtensor,
Subtensor,
)
from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType
from pytensor.tensor.type_other import MakeSlice


def slice_new(self, start, stop, step):
Expand Down Expand Up @@ -239,28 +238,32 @@ def {function_name}({", ".join(input_names)}):
@register_funcify_and_cache_key(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor):
_x, _y, idxs = node.inputs[0], None, node.inputs[1:]
tensor_inputs = node.inputs[1:]
else:
_x, _y, *idxs = node.inputs

basic_idxs = [
idx
for idx in idxs
if (
isinstance(idx.type, NoneTypeT)
or (isinstance(idx.type, SliceType) and not is_full_slice(idx))
)
]
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]
tensor_inputs = node.inputs[2:]

# Reconstruct indexing information from idx_list and tensor inputs
basic_idxs = []
adv_idxs = []
input_idx = 0

for i, entry in enumerate(op.idx_list):
if isinstance(entry, slice):
# Basic slice index
basic_idxs.append(entry)
elif isinstance(entry, Type):
# Advanced tensor index
if input_idx < len(tensor_inputs):
idx_input = tensor_inputs[input_idx]
adv_idxs.append(
{
"axis": i,
"dtype": idx_input.type.dtype,
"bcast": idx_input.type.broadcastable,
"ndim": idx_input.type.ndim,
}
)
input_idx += 1

# Special implementation for consecutive integer vector indices
if (
Expand Down
25 changes: 18 additions & 7 deletions pytensor/link/pytorch/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
from pytensor.tensor.type_other import MakeSlice


def check_negative_steps(indices):
Expand Down Expand Up @@ -63,7 +63,10 @@ def makeslice(start, stop, step):
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
idx_list = op.idx_list

def advsubtensor(x, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
return x[indices]

Expand Down Expand Up @@ -102,12 +105,14 @@ def inc_subtensor(x, y, *flattened_indices):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)

if op.set_instead_of_inc:

def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
Expand All @@ -120,7 +125,8 @@ def adv_set_subtensor(x, y, *indices):

elif ignore_duplicates:

def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
Expand All @@ -132,13 +138,18 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices):
return adv_inc_subtensor_no_duplicates

else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
# Check if we have slice indexing in idx_list
has_slice_indexing = (
any(isinstance(entry, slice) for entry in idx_list) if idx_list else False
)
if has_slice_indexing:
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)

def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
Expand Down
40 changes: 40 additions & 0 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,46 @@ def do_constant_folding(self, fgraph, node):
return True


@_vectorize_node.register(Alloc)
def vectorize_alloc(op: Alloc, node: Apply, batch_val, *batch_shapes):
    """Vectorize an ``Alloc`` node whose value may carry leading batch dims.

    Parameters
    ----------
    op
        The ``Alloc`` Op of the original node.
    node
        The original (unbatched) Apply node; ``node.inputs[0]`` is the value
        and ``node.inputs[1:]`` are the shape inputs.
    batch_val
        The possibly batched replacement for the value input.
    *batch_shapes
        The replacements for the shape inputs.

    Returns
    -------
    Apply
        A new ``Alloc`` node. If any shape input gained dimensions, the
        result of ``vectorize_node_fallback`` is returned instead.
    """
    # Batched shape inputs are the complex case: defer to the generic fallback.
    if any(
        new_shp.type.ndim > old_shp.type.ndim
        for new_shp, old_shp in zip(batch_shapes, node.inputs[1:], strict=True)
    ):
        return vectorize_node_fallback(op, node, batch_val, *batch_shapes)

    core_val = node.inputs[0]
    total_ndim = batch_val.type.ndim
    n_batch = total_ndim - core_val.type.ndim

    # Nothing was batched: rebuild the node with the replacement inputs as-is.
    if n_batch == 0:
        return op.make_node(batch_val, *batch_shapes)

    # batch_val has shape (B1, B2, ..., core dims...); the leading sizes are
    # prepended to the requested output shape.
    leading_sizes = [batch_val.shape[axis] for axis in range(n_batch)]

    # Alloc broadcasts the value against the shape from right to left, so any
    # core dimensions the value lacks must be inserted as singleton axes
    # *between* the batch dimensions and the value's own dimensions.
    n_pad = len(batch_shapes) - core_val.type.ndim
    if n_pad > 0:
        order = [
            *range(n_batch),
            *(["x"] * n_pad),
            *range(n_batch, total_ndim),
        ]
        batch_val = batch_val.dimshuffle(order)

    return op.make_node(batch_val, *leading_sizes, *batch_shapes)


alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))

Expand Down
10 changes: 4 additions & 6 deletions pytensor/tensor/conv/abstract_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1886,9 +1886,7 @@ def frac_bilinear_upsampling(input, frac_ratio):
pad = double_pad // 2

# build pyramidal kernel
kern = bilinear_kernel_2D(ratio=ratio)[np.newaxis, np.newaxis, :, :].astype(
config.floatX
)
kern = bilinear_kernel_2D(ratio=ratio)[None, None, :, :].astype(config.floatX)

# add corresponding padding
pad_kern = pt.concatenate(
Expand Down Expand Up @@ -2019,7 +2017,7 @@ def bilinear_upsampling(
# upsampling rows
upsampled_row = conv2d_grad_wrt_inputs(
output_grad=concat_mat,
filters=kern[np.newaxis, np.newaxis, :, np.newaxis],
filters=kern[None, None, :, None],
input_shape=(up_bs, 1, row * ratio, concat_col),
filter_shape=(1, 1, None, 1),
border_mode=(pad, 0),
Expand All @@ -2030,7 +2028,7 @@ def bilinear_upsampling(
# upsampling cols
upsampled_mat = conv2d_grad_wrt_inputs(
output_grad=upsampled_row,
filters=kern[np.newaxis, np.newaxis, np.newaxis, :],
filters=kern[None, None, None, :],
input_shape=(up_bs, 1, row * ratio, col * ratio),
filter_shape=(1, 1, 1, None),
border_mode=(0, pad),
Expand All @@ -2042,7 +2040,7 @@ def bilinear_upsampling(
kern = bilinear_kernel_2D(ratio=ratio, normalize=True)
upsampled_mat = conv2d_grad_wrt_inputs(
output_grad=concat_mat,
filters=kern[np.newaxis, np.newaxis, :, :],
filters=kern[None, None, :, :],
input_shape=(up_bs, 1, row * ratio, col * ratio),
filter_shape=(1, 1, None, None),
border_mode=(pad, pad),
Expand Down
31 changes: 18 additions & 13 deletions pytensor/tensor/random/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,20 +237,22 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
return False

# Parse indices
if isinstance(subtensor_op, Subtensor):
if isinstance(subtensor_op, Subtensor | AdvancedSubtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else:
indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False

# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes)
or isinstance(getattr(idx, "type", None), NoneTypeT)
for idx in indices
):
return False

# Check that indexing does not act on support dims
batch_ndims = rv_op.batch_ndim(rv_node)
Expand All @@ -269,8 +271,11 @@ def is_nd_advanced_idx(idx, dtype) -> bool:
)
for idx in supp_indices:
if not (
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
(isinstance(idx, slice) and idx == slice(None))
or (
isinstance(getattr(idx, "type", None), SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
)
):
return False
n_discarded_idxs = len(supp_indices)
Expand Down
Loading