Skip to content

Commit e0889db

Browse files
committed
Make the names of arguments of overriding methods consistent
1 parent 20e5038 commit e0889db

File tree

4 files changed

+18
-21
lines changed

4 files changed

+18
-21
lines changed

pytensor/link/c/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class Generic(CType, Singleton):
7373
def filter(self, data, strict=False, allow_downcast=None):
7474
return data
7575

76-
def is_valid_value(self, a):
76+
def is_valid_value(self, data, strict: bool = True) -> bool:
7777
return True
7878

7979
def c_declare(self, name, sub, check_input=True):

pytensor/sparse/type.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,31 +98,28 @@ def clone(
9898
shape = self.shape
9999
return type(self)(format, dtype, shape=shape, **kwargs)
100100

101-
def filter(self, value, strict=False, allow_downcast=None):
102-
if isinstance(value, Variable):
101+
def filter(self, data, strict: bool = False, allow_downcast=None):
102+
if isinstance(data, Variable):
103103
raise TypeError(
104104
"Expected an array-like object, but found a Variable: "
105105
"maybe you are trying to call a function on a (possibly "
106106
"shared) variable instead of a numeric array?"
107107
)
108108

109-
if (
110-
isinstance(value, self.format_cls[self.format])
111-
and value.dtype == self.dtype
112-
):
113-
return value
109+
if isinstance(data, self.format_cls[self.format]) and data.dtype == self.dtype:
110+
return data
114111

115112
if strict:
116113
raise TypeError(
117-
f"{value} is not sparse, or not the right dtype (is {value.dtype}, "
114+
f"{data} is not sparse, or not the right dtype (is {data.dtype}, "
118115
f"expected {self.dtype})"
119116
)
120117

121118
# The input format could be converted here
122119
if allow_downcast:
123-
sp = self.format_cls[self.format](value, dtype=self.dtype)
120+
sp = self.format_cls[self.format](data, dtype=self.dtype)
124121
else:
125-
data = self.format_cls[self.format](value)
122+
data = self.format_cls[self.format](data)
126123
up_dtype = ps.upcast(self.dtype, data.dtype)
127124
if up_dtype != self.dtype:
128125
raise TypeError(f"Expected {self.dtype} dtype but got {data.dtype}")
@@ -209,8 +206,8 @@ def values_eq(self, a, b):
209206
and abs(a - b).sum() == 0.0
210207
)
211208

212-
def is_valid_value(self, a):
213-
return scipy.sparse.issparse(a) and (a.format == self.format)
209+
def is_valid_value(self, data, strict: bool = True):
210+
return scipy.sparse.issparse(data) and (data.format == self.format)
214211

215212
def get_shape_info(self, obj):
216213
obj = self.filter(obj)

pytensor/tensor/type.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import warnings
33
from collections.abc import Iterable
4-
from typing import TYPE_CHECKING, Literal, Optional
4+
from typing import TYPE_CHECKING, Any, Literal, Optional
55

66
import numpy as np
77
import numpy.typing as npt
@@ -630,8 +630,8 @@ def c_code_cache_version(self):
630630

631631

632632
class DenseTypeMeta(MetaType):
633-
def __instancecheck__(self, o):
634-
if type(o) is TensorType or isinstance(o, DenseTypeMeta):
633+
def __instancecheck__(self, instance: Any) -> bool:
634+
if type(instance) is TensorType or isinstance(instance, DenseTypeMeta):
635635
return True
636636
return False
637637

pytensor/tensor/variable.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from collections.abc import Iterable
55
from numbers import Number
6-
from typing import TypeVar
6+
from typing import Any, TypeVar
77

88
import numpy as np
99

@@ -1117,8 +1117,8 @@ def __deepcopy__(self, memo):
11171117

11181118

11191119
class DenseVariableMeta(MetaType):
1120-
def __instancecheck__(self, o):
1121-
if type(o) is TensorVariable or isinstance(o, DenseVariableMeta):
1120+
def __instancecheck__(self, instance: Any) -> bool:
1121+
if type(instance) is TensorVariable or isinstance(instance, DenseVariableMeta):
11221122
return True
11231123
return False
11241124

@@ -1132,8 +1132,8 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):
11321132

11331133

11341134
class DenseConstantMeta(MetaType):
1135-
def __instancecheck__(self, o):
1136-
if type(o) is TensorConstant or isinstance(o, DenseConstantMeta):
1135+
def __instancecheck__(self, instance: Any) -> bool:
1136+
if type(instance) is TensorConstant or isinstance(instance, DenseConstantMeta):
11371137
return True
11381138
return False
11391139

0 commit comments

Comments
 (0)