Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions scim2_server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,18 @@
from scim2_models import Attribute
from scim2_models import BaseModel
from scim2_models import CaseExact
from scim2_models import Error
from scim2_models import Extension
from scim2_models import Meta
from scim2_models import Resource
from scim2_models import ResourceType
from scim2_models import Schema
from scim2_models import SearchRequest
from scim2_models import Uniqueness
from scim2_models import UniquenessException
from werkzeug.http import generate_etag

from scim2_server.filter import evaluate_filter
from scim2_server.operators import ResolveSortOperator
from scim2_server.utils import SCIMException
from scim2_server.utils import get_by_alias


Expand Down Expand Up @@ -120,10 +119,9 @@ def query_resources(
Resources. The List must contain a copy of resources.
Mutating elements in the List must not modify the data
stored in the backend.
:raises SCIMException: If the backend only supports querying for
one resource type at a time, setting resource_type_id to
None the backend may raise a
SCIMException(Error.make_too_many_error()).
:raises TooManyException: If the backend only supports querying
for one resource type at a time, setting resource_type_id to
None the backend may raise TooManyException.
"""
raise NotImplementedError

Expand Down Expand Up @@ -353,7 +351,7 @@ def create_resource(
if existing_resource.meta.resource_type == resource_type_id:
existing_value = unique_attribute.get_attribute(existing_resource)
if existing_value == new_value:
raise SCIMException(Error.make_uniqueness_error())
raise UniquenessException()

self.resources.append(resource)
return resource
Expand Down Expand Up @@ -392,7 +390,7 @@ def update_resource(
existing_resource
)
if existing_value == new_value:
raise SCIMException(Error.make_uniqueness_error())
raise UniquenessException()

self.resources[found_res_idx] = updated_resource
return updated_resource
Expand Down
5 changes: 2 additions & 3 deletions scim2_server/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from scim2_filter_parser import ast as scim2ast
from scim2_models import BaseModel
from scim2_models import CaseExact
from scim2_models import Error
from scim2_models import InvalidFilterException

from scim2_server.utils import SCIMException
from scim2_server.utils import get_by_alias
from scim2_server.utils import parse_new_value

Expand Down Expand Up @@ -125,4 +124,4 @@ def check_comparable_value(value):
status code 400) with "scimType" of "invalidFilter"."
"""
if isinstance(value, bytes | bool | NoneType):
raise SCIMException(Error.make_invalid_filter_error())
raise InvalidFilterException()
37 changes: 20 additions & 17 deletions scim2_server/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
from scim2_filter_parser.parser import SCIMParser
from scim2_models import BaseModel
from scim2_models import CaseExact
from scim2_models import Error
from scim2_models import InvalidPathException
from scim2_models import InvalidValueException
from scim2_models import Mutability
from scim2_models import MutabilityException
from scim2_models import NoTargetException
from scim2_models import PatchOperation
from scim2_models import Required
from scim2_models import Resource
from scim2_models import Returned
from scim2_models import SensitiveException

from scim2_server.filter import evaluate_filter
from scim2_server.utils import SCIMException
from scim2_server.utils import get_by_alias
from scim2_server.utils import get_or_create
from scim2_server.utils import handle_extension
Expand Down Expand Up @@ -55,7 +58,7 @@ def parse_attribute_path(attribute_path: str | None) -> dict[str, Any] | None:

match = ATTRIBUTE_PATH_REGEX.match(attribute_path)
if not match:
raise SCIMException(Error.make_invalid_path_error())
raise InvalidPathException()
parse_attribute_path.cache[attribute_path] = match.groupdict()
return match.groupdict()

Expand Down Expand Up @@ -148,7 +151,7 @@ def match_multi_valued_attribute_sub(
attribute_name = get_by_alias(type(model), attribute)
multi_valued_attribute = get_or_create(model, attribute_name, True)
if not isinstance(multi_valued_attribute, list):
raise SCIMException(Error.make_invalid_path_error())
raise InvalidPathException()
token_stream = SCIMLexer().tokenize(condition)
condition = SCIMParser().parse(token_stream)
self.init_return(model, attribute_name, sub_attribute, self.value)
Expand All @@ -160,13 +163,13 @@ def match_multi_valued_attribute(
self, attribute: str, condition: str, model: BaseModel
):
if self.REQUIRES_VALUE and not isinstance(self.value, dict):
raise SCIMException(Error.make_invalid_value_error())
raise InvalidValueException()
attribute_name = get_by_alias(type(model), attribute)
multi_valued_attribute = get_or_create(
model, attribute_name, self.REQUIRES_VALUE
)
if not isinstance(multi_valued_attribute, list):
raise SCIMException(Error.make_invalid_path_error())
raise InvalidPathException()
token_stream = SCIMLexer().tokenize(condition)
condition = SCIMParser().parse(token_stream)
if self.REQUIRES_VALUE:
Expand Down Expand Up @@ -196,7 +199,7 @@ def match_complex_attribute(self, attribute: str, model: BaseModel, sub_path: st
self.match_attribute(sub_path, value)
else:
if not isinstance(complex_attribute, BaseModel):
raise SCIMException(Error.make_invalid_path_error())
raise InvalidPathException()
self.match_attribute(sub_path, complex_attribute)

def match_attribute(self, attribute: str, model: BaseModel):
Expand All @@ -205,9 +208,9 @@ def match_attribute(self, attribute: str, model: BaseModel):

def call_on_root(self, model: Resource):
if not self.OPERATE_ON_ROOT:
raise SCIMException(Error.make_no_target_error())
raise NoTargetException()
if not isinstance(self.value, dict):
raise SCIMException(Error.make_invalid_value_error())
raise InvalidValueException()
for k, v in self.value.items():
ext, scim_name = handle_extension(model, k)
if ext == model:
Expand All @@ -233,7 +236,7 @@ def operation(cls, model: BaseModel, attribute: str, value: Any):
return

if model.get_field_annotation(alias, Mutability) == Mutability.read_only:
raise SCIMException(Error.make_mutability_error())
raise MutabilityException()

if model.get_field_multiplicity(alias):
if getattr(model, alias) is None:
Expand All @@ -247,7 +250,7 @@ def operation(cls, model: BaseModel, attribute: str, value: Any):
model.get_field_annotation(alias, Required) == Required.true
and not new_value
):
raise SCIMException(Error.make_invalid_value_error())
raise InvalidValueException()
setattr(model, alias, new_value)


Expand All @@ -268,10 +271,10 @@ def operation(cls, model: BaseModel, attribute: str, value: Any):
Mutability.read_only,
Mutability.immutable,
):
raise SCIMException(Error.make_mutability_error())
raise MutabilityException()

if model.get_field_annotation(alias, Required) == Required.true:
raise SCIMException(Error.make_invalid_value_error())
raise InvalidValueException()

setattr(model, alias, None)

Expand All @@ -283,21 +286,21 @@ class ReplaceOperator(Operator):
def operation(cls, model: BaseModel, attribute: str, value: Any):
alias = get_by_alias(type(model), attribute)
if model.get_field_multiplicity(alias) and not isinstance(value, list):
raise SCIMException(Error.make_invalid_value_error())
raise InvalidValueException()

existing_value = getattr(model, alias)
new_value = parse_new_value(model, alias, value)
if new_value == existing_value:
return

if model.get_field_annotation(alias, Mutability) == Mutability.read_only:
raise SCIMException(Error.make_mutability_error())
raise MutabilityException()

if (
model.get_field_annotation(alias, Required) == Required.true
and not new_value
):
raise SCIMException(Error.make_invalid_value_error())
raise InvalidValueException()
setattr(model, alias, new_value)


Expand Down Expand Up @@ -372,7 +375,7 @@ def init_return(
model.get_field_annotation(alias, Mutability) == Mutability.write_only
or model.get_field_annotation(alias, Returned) == Returned.never
):
raise SCIMException(Error.make_sensitive_error())
raise SensitiveException()

@classmethod
def operation(
Expand Down
4 changes: 2 additions & 2 deletions scim2_server/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from scim2_models import Resource
from scim2_models import ResourceType
from scim2_models import Schema
from scim2_models import SCIMException
from scim2_models import SearchRequest
from scim2_models import ServiceProviderConfig
from scim2_models import Sort
Expand All @@ -39,7 +40,6 @@

from scim2_server.backend import Backend
from scim2_server.operators import patch_resource
from scim2_server.utils import SCIMException
from scim2_server.utils import merge_resources


Expand Down Expand Up @@ -537,7 +537,7 @@ def wsgi_app(self, request: Request, environ):
return self.make_error(Error(status=e.code, detail=e.description))
except SCIMException as e:
self.log.exception(e)
return self.make_error(e.scim_error)
return self.make_error(e.to_error())
except ValidationError as e:
self.log.exception(e)
return self.make_error(Error(status=400, detail=str(e)))
Expand Down
29 changes: 12 additions & 17 deletions scim2_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,16 @@
from pydantic import EmailStr
from pydantic import ValidationError
from scim2_models import BaseModel
from scim2_models import Error
from scim2_models import Extension
from scim2_models import InvalidValueException
from scim2_models import Mutability
from scim2_models import MutabilityException
from scim2_models import NoTargetException
from scim2_models import Resource
from scim2_models import ResourceType
from scim2_models import Schema


class SCIMException(Exception):
"""A wrapper class, because an "Error" does not inherit from Exception and should not be raised."""

def __init__(self, scim_error: Error):
self.scim_error = scim_error


def load_json_resource(json_name: str) -> list:
"""Load a JSON document from the scim2_server package resources."""
fp = importlib.resources.files("scim2_server") / "resources" / json_name
Expand Down Expand Up @@ -69,7 +64,7 @@ def merge_resources(target: Resource, updates: BaseModel):
if mutability == Mutability.immutable and getattr(
target, set_attribute
) not in (None, new_value):
raise SCIMException(Error.make_mutability_error())
raise MutabilityException()
setattr(target, set_attribute, new_value)


Expand All @@ -82,8 +77,8 @@ def get_by_alias(
:param scim_name: SCIM attribute name
:param allow_none: Allow returning None if attribute is not found
:return: pydantic attribute name
:raises SCIMException: If no attribute is found and allow_none is
False
:raises NoTargetException: If no attribute is found and allow_none
is False
"""
try:
return next(
Expand All @@ -94,7 +89,7 @@ def get_by_alias(
except StopIteration as e:
if allow_none:
return None
raise SCIMException(Error.make_no_target_error()) from e
raise NoTargetException() from e


def get_or_create(
Expand All @@ -107,15 +102,15 @@ def get_or_create(
:param check_mutability: If True, validate that the attribute is
mutable
:return: A complex attribute model
:raises SCIMException: If attribute is not mutable and
:raises MutabilityException: If attribute is not mutable and
check_mutability is True
"""
if check_mutability:
if model.get_field_annotation(attribute_name, Mutability) in (
Mutability.read_only,
Mutability.immutable,
):
raise SCIMException(Error.make_mutability_error())
raise MutabilityException()
ret = getattr(model, attribute_name, None)
if not ret:
if model.get_field_multiplicity(attribute_name):
Expand Down Expand Up @@ -164,8 +159,8 @@ def model_validate_from_dict(field_root_type: type[BaseModel], value: dict) -> A
def parse_new_value(model: BaseModel, attribute_name: str, value: Any) -> Any:
"""Given a model and attribute name, attempt to parse a new value so that the type matches the type expected by the model.

:raises SCIMException: If attribute can not be mapped to the
required type
:raises InvalidValueException: If attribute can not be mapped to
the required type
"""
field_root_type = model.get_field_root_type(attribute_name)
try:
Expand Down Expand Up @@ -195,5 +190,5 @@ def parse_new_value(model: BaseModel, attribute_name: str, value: Any) -> Any:
else:
new_value = field_root_type(value)
except (AttributeError, TypeError, ValueError, ValidationError) as e:
raise SCIMException(Error.make_invalid_value_error()) from e
raise InvalidValueException() from e
return new_value
Loading
Loading