diff --git a/HISTORY.md b/HISTORY.md index 44aef2f7..c0d5be50 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -29,7 +29,8 @@ Our backwards-compatibility policy can be found [here](https://github.com/python ([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708)) - The {mod}`tomlkit ` preconf converter now passes date objects directly to _tomlkit_ for unstructuring. ([#707](https://github.com/python-attrs/cattrs/issues/707) [#708](https://github.com/python-attrs/cattrs/pull/708)) - +- Enum handling has been optimized by switching to hook factories, improving performance especially for plain enums. + ([#705](https://github.com/python-attrs/cattrs/pull/705)) ## 25.3.0 (2025-10-07) diff --git a/src/cattrs/converters.py b/src/cattrs/converters.py index 14b2bff6..54d67a48 100644 --- a/src/cattrs/converters.py +++ b/src/cattrs/converters.py @@ -47,6 +47,7 @@ is_mutable_set, is_optional, is_protocol, + is_subclass, is_tuple, is_typeddict, is_union_type, @@ -76,6 +77,7 @@ UnstructuredValue, UnstructureHook, ) +from .enums import enum_structure_factory, enum_unstructure_factory from .errors import ( IterableValidationError, IterableValidationNote, @@ -246,12 +248,12 @@ def __init__( lambda t: self.get_unstructure_hook(get_type_alias_base(t)), True, ), - (is_literal_containing_enums, self.unstructure), (is_mapping, self._unstructure_mapping), (is_sequence, self._unstructure_seq), (is_mutable_set, self._unstructure_seq), (is_frozenset, self._unstructure_seq), - (lambda t: issubclass(t, Enum), self._unstructure_enum), + (is_literal_containing_enums, self.unstructure), + (lambda t: is_subclass(t, Enum), enum_unstructure_factory, "extended"), (has, self._unstructure_attrs), (is_union_type, self._unstructure_union), (lambda t: t in ANIES, self.unstructure), @@ -298,6 +300,7 @@ def __init__( self._union_struct_registry.__getitem__, True, ), + (lambda t: is_subclass(t, Enum), enum_structure_factory, "extended"), (has, self._structure_attrs), ] ) @@ -308,7 +311,6 @@ def __init__( (bytes, self._structure_call), (int, self._structure_call), (float, self._structure_call), - (Enum, self._structure_enum), (Path, self._structure_call), ] ) @@ -630,12 +632,6 @@ def unstructure_attrs_astuple(self, obj: Any) -> tuple[Any, ...]: res.append(dispatch(a.type or v.__class__)(v)) return tuple(res) - def _unstructure_enum(self, obj: Enum) -> Any: - """Convert an enum to its unstructured value.""" - if "_value_" in obj.__class__.__annotations__: - return self._unstructure_func.dispatch(obj.value.__class__)(obj.value) - return obj.value - def _unstructure_seq(self, seq: Sequence[T]) -> Sequence[T]: """Convert a sequence to primitive equivalents.""" # We can reuse the sequence class, so tuples stay tuples. @@ -715,15 +711,6 @@ def _structure_simple_literal(val, type): raise Exception(f"{val} not in literal {type}") return val - def _structure_enum(self, val: Any, cl: type[Enum]) -> Enum: - """Structure ``val`` if possible and return the enum it corresponds to. - - Uses type hints for the "_value_" attribute if they exist to structure - the enum values before returning the result.""" - if "_value_" in cl.__annotations__: - val = self.structure(val, cl.__annotations__["_value_"]) - return cl(val) - @staticmethod def _structure_enum_literal(val, type): vals = {(x.value if isinstance(x, Enum) else x): x for x in type.__args__} diff --git a/src/cattrs/enums.py b/src/cattrs/enums.py new file mode 100644 index 00000000..b1ab5040 --- /dev/null +++ b/src/cattrs/enums.py @@ -0,0 +1,36 @@ +from collections.abc import Callable +from enum import Enum +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .converters import BaseConverter + + +def enum_unstructure_factory( + type: type[Enum], converter: "BaseConverter" +) -> Callable[[Enum], Any]: + """A factory for generating enum unstructure hooks. + + If the enum is a typed enum (has `_value_`), we use the underlying value's hook. + Otherwise, we use the value directly. + """ + if "_value_" in type.__annotations__: + return lambda e: converter.unstructure(e.value) + + return lambda e: e.value + + +def enum_structure_factory( + type: type[Enum], converter: "BaseConverter" +) -> Callable[[Any, type[Enum]], Enum]: + """A factory for generating enum structure hooks. + + If the enum is a typed enum (has `_value_`), we structure the value first. + Otherwise, we use the value directly. + """ + if "_value_" in type.__annotations__: + val_type = type.__annotations__["_value_"] + val_hook = converter.get_structure_hook(val_type) + return lambda v, _: type(val_hook(v, val_type)) + + return lambda v, _: type(v) diff --git a/src/cattrs/preconf/__init__.py b/src/cattrs/preconf/__init__.py index 27ce1f10..6cab7cb3 100644 --- a/src/cattrs/preconf/__init__.py +++ b/src/cattrs/preconf/__init__.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from datetime import datetime from enum import Enum -from typing import Any, Callable, ParamSpec, TypeVar, get_args +from typing import Any, ParamSpec, TypeVar, get_args from .._compat import is_subclass from ..converters import Converter, UnstructureHook diff --git a/src/cattrs/preconf/bson.py b/src/cattrs/preconf/bson.py index 49574893..9d28a5ba 100644 --- a/src/cattrs/preconf/bson.py +++ b/src/cattrs/preconf/bson.py @@ -99,11 +99,11 @@ def gen_structure_mapping(cl: Any) -> StructureHook: # datetime inherits from date, so identity unstructure hook used # here to prevent the date unstructure hook running. - converter.register_unstructure_hook(datetime, lambda v: v) + converter.register_unstructure_hook(datetime, identity) converter.register_structure_hook(datetime, validate_datetime) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) - converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity) converter.register_unstructure_hook_factory( is_literal_containing_enums, literals_with_enums_unstructure_factory ) diff --git a/src/cattrs/preconf/cbor2.py b/src/cattrs/preconf/cbor2.py index 6341d898..ad011c86 100644 --- a/src/cattrs/preconf/cbor2.py +++ b/src/cattrs/preconf/cbor2.py @@ -37,7 +37,7 @@ def configure_converter(converter: BaseConverter): ) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) - converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity) converter.register_unstructure_hook_factory( is_literal_containing_enums, literals_with_enums_unstructure_factory ) diff --git a/src/cattrs/preconf/json.py b/src/cattrs/preconf/json.py index 199c574d..355e6619 100644 --- a/src/cattrs/preconf/json.py +++ b/src/cattrs/preconf/json.py @@ -52,7 +52,7 @@ def configure_converter(converter: BaseConverter) -> None: converter.register_unstructure_hook_factory( is_literal_containing_enums, literals_with_enums_unstructure_factory ) - converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory(is_primitive_enum, lambda _: identity) configure_union_passthrough(Union[str, bool, int, float, None], converter) diff --git a/src/cattrs/preconf/msgpack.py b/src/cattrs/preconf/msgpack.py index 92876418..b0726da5 100644 --- a/src/cattrs/preconf/msgpack.py +++ b/src/cattrs/preconf/msgpack.py @@ -46,7 +46,7 @@ def configure_converter(converter: BaseConverter) -> None: converter.register_structure_hook( date, lambda v, _: datetime.fromtimestamp(v, timezone.utc).date() ) - converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity) converter.register_unstructure_hook_factory( is_literal_containing_enums, literals_with_enums_unstructure_factory ) diff --git a/src/cattrs/preconf/msgspec.py b/src/cattrs/preconf/msgspec.py index 6274a32b..123e2779 100644 --- a/src/cattrs/preconf/msgspec.py +++ b/src/cattrs/preconf/msgspec.py @@ -3,18 +3,27 @@ from __future__ import annotations from base64 import b64decode +from collections.abc import Callable from dataclasses import is_dataclass from datetime import date, datetime from enum import Enum from functools import partial -from typing import Any, Callable, TypeVar, Union, get_type_hints +from typing import Any, TypeVar, Union, get_type_hints from attrs import has as attrs_has from attrs import resolve_types from msgspec import Struct, convert, to_builtins from msgspec.json import Encoder, decode -from .._compat import fields, get_args, get_origin, is_bare, is_mapping, is_sequence +from .._compat import ( + fields, + get_args, + get_origin, + is_bare, + is_mapping, + is_sequence, + is_subclass, +) from ..cols import is_namedtuple from ..converters import BaseConverter, Converter from ..dispatch import UnstructureHook @@ -74,7 +83,9 @@ def configure_converter(converter: Converter) -> None: configure_passthroughs(converter) converter.register_unstructure_hook(Struct, to_builtins) - converter.register_unstructure_hook(Enum, identity) + converter.register_unstructure_hook_factory( + lambda t: is_subclass(t, Enum), lambda t, c: identity + ) converter.register_structure_hook(Struct, convert) converter.register_structure_hook(bytes, lambda v, _: b64decode(v)) diff --git a/src/cattrs/preconf/orjson.py b/src/cattrs/preconf/orjson.py index 0726ef04..88ba79a0 100644 --- a/src/cattrs/preconf/orjson.py +++ b/src/cattrs/preconf/orjson.py @@ -87,8 +87,8 @@ def key_handler(v): ), ] ) - converter.register_unstructure_hook_func( - partial(is_primitive_enum, include_bare_enums=True), identity + converter.register_unstructure_hook_factory( + partial(is_primitive_enum, include_bare_enums=True), lambda t: identity ) converter.register_unstructure_hook_factory( is_literal_containing_enums, literals_with_enums_unstructure_factory diff --git a/src/cattrs/preconf/ujson.py b/src/cattrs/preconf/ujson.py index 8f330615..6cb36652 100644 --- a/src/cattrs/preconf/ujson.py +++ b/src/cattrs/preconf/ujson.py @@ -47,7 +47,7 @@ def configure_converter(converter: BaseConverter): converter.register_structure_hook(datetime, lambda v, _: datetime.fromisoformat(v)) converter.register_unstructure_hook(date, lambda v: v.isoformat()) converter.register_structure_hook(date, lambda v, _: date.fromisoformat(v)) - converter.register_unstructure_hook_func(is_primitive_enum, identity) + converter.register_unstructure_hook_factory(is_primitive_enum, lambda t: identity) converter.register_unstructure_hook_factory( is_literal_containing_enums, literals_with_enums_unstructure_factory ) diff --git a/tests/test_function_dispatch.py b/tests/test_function_dispatch.py index 4641443d..c6e190ee 100644 --- a/tests/test_function_dispatch.py +++ b/tests/test_function_dispatch.py @@ -31,3 +31,16 @@ class Bar(Foo): assert dispatch.dispatch(Bar) == "foo" dispatch.register(lambda cls: issubclass(cls, Bar), "bar") assert dispatch.dispatch(Bar) == "bar" + + +def test_function_dispatch_exception(): + """Function dispatch gracefully handles exceptions in predicates.""" + dispatch = FunctionDispatch(BaseConverter()) + + def raising_predicate(cls): + raise ValueError("This predicate raises an error") + + dispatch.register(lambda cls: issubclass(cls, float), "float") + dispatch.register(raising_predicate, "error") + + assert dispatch.dispatch(float) == "float"