diff --git a/src/cli/credential.py b/src/cli/credential.py index 922e2733..50032127 100644 --- a/src/cli/credential.py +++ b/src/cli/credential.py @@ -26,12 +26,12 @@ import shtab -from src.lib.utils import client, client_configs, common, credentials, osmo_errors +from src.lib.utils import client, client_configs, credentials, common, osmo_errors CRED_TYPES = ['REGISTRY', 'DATA', 'GENERIC'] -def _save_config(data_cred: credentials.DataCredential): +def _save_config(data_cred: credentials.StaticDataCredential): """ Sets default config information """ @@ -46,7 +46,8 @@ def _save_config(data_cred: credentials.DataCredential): config['auth']['data'][data_cred.endpoint] = { 'access_key_id': data_cred.access_key_id, 'access_key': data_cred.access_key.get_secret_value(), - 'region': data_cred.region} + 'region': data_cred.region, + } with open(password_file, 'w', encoding='utf-8') as file: yaml.dump(config, file) os.chmod(password_file, stat.S_IREAD | stat.S_IWRITE) @@ -97,21 +98,30 @@ def _run_set_command(service_client: client.ServiceClient, args: argparse.Namesp print(f'File {value} cannot be found.') sys.exit(1) - if args.type == 'GENERIC': + if args.type == 'DATA': + # Validate that the data credential is valid + try: + credentials.StaticDataCredential(**cred_payload) + except Exception as err: # pylint: disable=broad-except + raise osmo_errors.OSMOUserError(f'Invalid DATA credential: {str(err)}') + + elif args.type == 'GENERIC': cred_payload = {'credential': cred_payload} - result = service_client.request(client.RequestMethod.POST, - f'api/credentials/{args.name}', - payload={args.type.lower() + '_credential': cred_payload}) + result = service_client.request( + client.RequestMethod.POST, + f'api/credentials/{args.name}', + payload={args.type.lower() + '_credential': cred_payload}, + ) + if args.format_type == 'json': print(json.dumps(result, indent=2)) else: print(f'Set {args.type} credential {args.name}.') if args.type == 'DATA': - if 'endpoint' not in cred_payload: - raise osmo_errors.OSMOUserError('Endpoint is required for DATA credentials.') - _save_config(credentials.DataCredential(**cred_payload)) + # Save the data credential to the client config + _save_config(credentials.StaticDataCredential(**cred_payload)) def _run_list_command(service_client: client.ServiceClient, args: argparse.Namespace): @@ -132,11 +142,11 @@ def _run_list_command(service_client: client.ServiceClient, args: argparse.Names for cred in result['credentials']: cred['local'] = 'N/A' if cred.get('cred_type', '') == 'DATA': - try: - client_configs.get_credentials(cred.get('profile', '')) - cred['local'] = 'Yes' - except osmo_errors.OSMOError: - cred['local'] = 'No' + data_cred = credentials.get_static_data_credential_from_config( + url=cred.get('profile', ''), + ) + cred['local'] = 'Yes' if data_cred else 'No' + table.add_row([cred.get(column, '-') for column in columns]) print(f'{table.draw()}\n') @@ -206,15 +216,15 @@ def setup_parser(parser: argparse._SubParsersAction): 'payload corresponding to each type of credential:\n' '\n' # pylint: disable=line-too-long - '+-----------------+---------------------------+---------------------------------------+\n' - '| Credential Type | Mandatory keys | Optional keys |\n' - '+-----------------+---------------------------+---------------------------------------+\n' - '| REGISTRY | auth | registry, username |\n' - '+-----------------+---------------------------+---------------------------------------+\n' - '| DATA | access_key_id, access_key | endpoint, region (default: 
us-east-1) |\n' - '+-----------------+---------------------------+---------------------------------------+\n' - '| GENERIC | | |\n' - '+-----------------+---------------------------+---------------------------------------+\n' + '+-----------------+-------------------------------------+-----------------------------+\n' + '| Credential Type | Mandatory keys | Optional keys |\n' + '+-----------------+-------------------------------------+-----------------------------+\n' + '| REGISTRY | auth | registry, username |\n' + '+-----------------+-------------------------------------+-----------------------------+\n' + '| DATA | access_key_id, access_key, endpoint | region (default: us-east-1) |\n' + '+-----------------+-------------------------------------+-----------------------------+\n' + '| GENERIC | | |\n' + '+-----------------+-------------------------------------+-----------------------------+\n' # pylint: enable=line-too-long '\n' ), diff --git a/src/lib/data/dataset/BUILD b/src/lib/data/dataset/BUILD index 00d14744..d8b9c64a 100644 --- a/src/lib/data/dataset/BUILD +++ b/src/lib/data/dataset/BUILD @@ -26,7 +26,6 @@ osmo_py_library( "//src/lib/data/storage", "//src/lib/utils:cache", "//src/lib/utils:client", - "//src/lib/utils:client_configs", "//src/lib/utils:logging", "//src/lib/utils:common", "//src/lib/utils:osmo_errors", diff --git a/src/lib/data/dataset/common.py b/src/lib/data/dataset/common.py index 510294f1..72ad07d5 100644 --- a/src/lib/data/dataset/common.py +++ b/src/lib/data/dataset/common.py @@ -32,9 +32,10 @@ import diskcache import pydantic -from .. import storage, constants +from .. import storage +from ..storage import constants from ..storage.core import progress -from ...utils import client, client_configs, common, osmo_errors, paths +from ...utils import client, common, osmo_errors, paths logger = logging.getLogger(__name__) @@ -305,12 +306,8 @@ def _validate_source_path( if re.fullmatch(constants.STORAGE_BACKEND_REGEX, source_path): # Remote path logic path_components = storage.construct_storage_backend(source_path) - user_credentials = client_configs.get_credentials(path_components.profile) path_components.data_auth( - user_credentials.access_key_id, - user_credentials.access_key.get_secret_value(), - user_credentials.region, - storage.AccessType.READ, + access_type=storage.AccessType.READ, ) return RemotePath(path_components, has_asterisk, priority) diff --git a/src/lib/data/dataset/manager.py b/src/lib/data/dataset/manager.py index d507b09c..a10a86ff 100644 --- a/src/lib/data/dataset/manager.py +++ b/src/lib/data/dataset/manager.py @@ -259,13 +259,7 @@ def upload_start( location_result['path'], cache_config=self.cache_config, ) - credentials = client_configs.get_credentials(path_components.profile) - path_components.data_auth( - credentials.access_key_id, - credentials.access_key.get_secret_value(), - credentials.region, - storage.AccessType.WRITE, - ) + path_components.data_auth(access_type=storage.AccessType.WRITE) # Parse and validate the input paths local_paths, backend_paths = common.parse_upload_paths(input_paths) @@ -398,24 +392,13 @@ def _update_dataset_start( location_result['path'], cache_config=self.cache_config, ) - credentials = client_configs.get_credentials(path_components.profile) if remove_regex: # Validate delete access - path_components.data_auth( - credentials.access_key_id, - credentials.access_key.get_secret_value(), - credentials.region, - storage.AccessType.DELETE, - ) + path_components.data_auth(access_type=storage.AccessType.DELETE) if 
add_paths: # Validate write access - path_components.data_auth( - credentials.access_key_id, - credentials.access_key.get_secret_value(), - credentials.region, - storage.AccessType.WRITE, - ) + path_components.data_auth(access_type=storage.AccessType.WRITE) # If add_paths is provided, seperate paths and perform basic authentication # against backend paths diff --git a/src/lib/data/dataset/migrating.py b/src/lib/data/dataset/migrating.py index 26704c92..af4e11d2 100644 --- a/src/lib/data/dataset/migrating.py +++ b/src/lib/data/dataset/migrating.py @@ -30,7 +30,7 @@ from .. import storage from ..storage import common as storage_common, copying from ..storage.core import client, executor, progress, provider -from ...utils import cache, client_configs, osmo_errors +from ...utils import cache, osmo_errors MANIFEST_REGEX_PATTERN = re.compile(r'.*\/manifests\/[0-9]+\.json$') @@ -221,15 +221,9 @@ def migrate( # Resolve the region for the destination storage backend. # This is necessary for generating a valid regional HTTP URL for uploaded objects # for certain storage backends (e.g. AWS S3). - destination_creds = client_configs.get_credentials(destination_backend.profile) - destination_region = destination_backend.region( - destination_creds.access_key_id, - destination_creds.access_key.get_secret_value(), - ) + destination_region = destination_backend.region() client_factory = destination_backend.client_factory( - access_key_id=destination_creds.access_key_id, - access_key=destination_creds.access_key.get_secret_value(), region=destination_region, ) diff --git a/src/lib/data/dataset/updating.py b/src/lib/data/dataset/updating.py index 5318697b..88c19cef 100644 --- a/src/lib/data/dataset/updating.py +++ b/src/lib/data/dataset/updating.py @@ -29,7 +29,7 @@ from . import common, uploading from .. import storage from ..storage.core import executor -from ...utils import cache, client_configs, osmo_errors +from ...utils import cache, osmo_errors ################################# @@ -224,17 +224,11 @@ def update( # Resolve the region for the destination storage backend. # This is necessary for generating a valid regional HTTP URL for uploaded objects # for certain storage backends (e.g. AWS S3). - destination_creds = client_configs.get_credentials(destination.profile) - destination_region = destination.region( - destination_creds.access_key_id, - destination_creds.access_key.get_secret_value(), - ) + destination_region = destination.region() client_factory = destination.client_factory( - access_key_id=destination_creds.access_key_id, - access_key=destination_creds.access_key.get_secret_value(), - region=destination_region, request_headers=request_headers, + region=destination_region, ) manifest_cache = diskcache.Index() diff --git a/src/lib/data/dataset/uploading.py b/src/lib/data/dataset/uploading.py index 24d8f7e6..e318aedf 100644 --- a/src/lib/data/dataset/uploading.py +++ b/src/lib/data/dataset/uploading.py @@ -31,7 +31,7 @@ from .. import storage from ..storage import common as storage_common, uploading from ..storage.core import client, executor, progress, provider -from ...utils import cache, client_configs, common as utils_common, osmo_errors +from ...utils import cache, common as utils_common, osmo_errors ########################## @@ -143,12 +143,9 @@ def dataset_upload_remote_file_entry_generator( ) # Create the base to be used for objects' HTTP URLs. 
- data_creds = storage_client.data_credential + data_cred = storage_client.data_credential url_base = storage_backend.parse_uri_to_link( - storage_backend.region( - data_creds.access_key_id, - data_creds.access_key.get_secret_value(), - ), + region=storage_backend.region(data_cred), ) # Iterate over the objects in the remote path. @@ -381,17 +378,11 @@ def upload( # Resolve the region for the destination storage backend. # This is necessary for generating a valid regional HTTP URL for uploaded objects # for certain storage backends (e.g. AWS S3). - destination_creds = client_configs.get_credentials(destination.profile) - destination_region = destination.region( - destination_creds.access_key_id, - destination_creds.access_key.get_secret_value(), - ) + destination_region = destination.region() client_factory = destination.client_factory( - access_key_id=destination_creds.access_key_id, - access_key=destination_creds.access_key.get_secret_value(), - region=destination_region, request_headers=request_headers, + region=destination_region, ) manifest_cache = diskcache.Index() diff --git a/src/lib/data/storage/BUILD b/src/lib/data/storage/BUILD index 0592a551..a24d2465 100644 --- a/src/lib/data/storage/BUILD +++ b/src/lib/data/storage/BUILD @@ -33,12 +33,12 @@ osmo_py_library( srcs = glob(["*.py"]), deps = [ "//src/lib/utils:client_configs", - "//src/lib/utils:credentials", "//src/lib/utils:common", "//src/lib/utils:logging", "//src/lib/utils:paths", "//src/lib/utils:osmo_errors", - "//src/lib/data/constants", + "//src/lib/data/storage/constants", + "//src/lib/data/storage/credentials", "//src/lib/data/storage/core", "//src/lib/data/storage/backends", requirement("pydantic"), diff --git a/src/lib/data/storage/backends/BUILD b/src/lib/data/storage/backends/BUILD index 1fb26aab..51fb0f33 100644 --- a/src/lib/data/storage/backends/BUILD +++ b/src/lib/data/storage/backends/BUILD @@ -26,8 +26,9 @@ osmo_py_library( "//src/lib/utils:cache", "//src/lib/utils:common", "//src/lib/utils:osmo_errors", - "//src/lib/data/constants", + "//src/lib/data/storage/constants", "//src/lib/data/storage/core", + "//src/lib/data/storage/credentials", requirement("azure-storage-blob"), requirement("azure-identity"), requirement("boto3"), diff --git a/src/lib/data/storage/backends/azure.py b/src/lib/data/storage/backends/azure.py index 58547a88..a0079a0b 100644 --- a/src/lib/data/storage/backends/azure.py +++ b/src/lib/data/storage/backends/azure.py @@ -29,6 +29,7 @@ from azure.core import exceptions from azure.storage import blob +from .. import credentials from ..core import client, provider from ....utils import common @@ -270,13 +271,20 @@ def __next__(self) -> bytes: return chunk -def create_client( - connection_string: str, -) -> blob.BlobServiceClient: +def create_client(data_cred: credentials.DataCredential) -> blob.BlobServiceClient: """ Creates a new Azure Blob Storage client. """ - return blob.BlobServiceClient.from_connection_string(conn_str=connection_string) + match data_cred: + case credentials.StaticDataCredential(): + return blob.BlobServiceClient.from_connection_string( + conn_str=data_cred.access_key.get_secret_value(), + ) + case credentials.DefaultDataCredential(): + raise NotImplementedError( + 'Default data credentials are not supported yet') + case _ as unreachable: + assert_never(unreachable) class AzureBlobStorageClient(client.StorageClient): @@ -822,10 +830,10 @@ class AzureBlobStorageClientFactory(provider.StorageClientFactory): Factory for the AzureBlobStorageClient. 
""" - connection_string: str + data_cred: credentials.DataCredential @override def create(self) -> AzureBlobStorageClient: return AzureBlobStorageClient( - lambda: create_client(self.connection_string), + lambda: create_client(self.data_cred), ) diff --git a/src/lib/data/storage/backends/backends.py b/src/lib/data/storage/backends/backends.py index e8b3ffa8..e6abfd3f 100644 --- a/src/lib/data/storage/backends/backends.py +++ b/src/lib/data/storage/backends/backends.py @@ -32,8 +32,8 @@ import pydantic from . import azure, s3, common +from .. import constants, credentials from ..core import client, header -from ... import constants from ....utils import cache, osmo_errors @@ -132,19 +132,20 @@ def _get_extra_headers( @override def client_factory( self, - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential | None = None, request_headers: List[header.RequestHeaders] | None = None, **kwargs: Any, ) -> s3.S3StorageClientFactory: """ Returns a factory for creating storage clients. """ - region = kwargs.get('region', None) or self.region(access_key_id, access_key) + region = kwargs.get('region', None) or self.region(data_cred) + + if data_cred is None: + data_cred = self.resolved_data_credential return s3.S3StorageClientFactory( # pylint: disable=unexpected-keyword-arg - access_key_id=access_key_id, - access_key=access_key, + data_cred=data_cred, region=region, scheme=self.scheme, endpoint_url=self.auth_endpoint if self.auth_endpoint else None, @@ -242,9 +243,7 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, - region: str | None = None, + data_cred: credentials.DataCredential | None = None, access_type: common.AccessType | None = None, ): # pylint: disable=unused-argument @@ -254,24 +253,34 @@ def data_auth( if _skip_data_auth(): return - if ':' in access_key_id: - namespace = access_key_id.split(':')[1] - else: - namespace = f'AUTH_{access_key_id}' - if namespace != self.namespace: - raise osmo_errors.OSMOCredentialError( - f'Data key validation error: access_key_id {access_key_id} is ' + - f'not valid for the {self.namespace} namespace.') + if data_cred is None: + data_cred = self.resolved_data_credential - if region is None: - region = self.region(access_key_id, access_key) + match data_cred: + case credentials.StaticDataCredential(): + access_key_id = data_cred.access_key_id + + if ':' in access_key_id: + namespace = access_key_id.split(':')[1] + else: + namespace = f'AUTH_{access_key_id}' + + if namespace != self.namespace: + raise osmo_errors.OSMOCredentialError( + f'Data key validation error: access_key_id {access_key_id} is ' + + f'not valid for the {self.namespace} namespace.') + + case credentials.DefaultDataCredential(): + raise NotImplementedError( + 'Default data credentials are not supported for Swift backend') + case _ as unreachable: + assert_never(unreachable) s3_client = s3.create_client( - access_key_id=access_key_id, - access_key=access_key, + data_cred=data_cred, scheme=self.scheme, endpoint_url=self.auth_endpoint, - region=region, + region=self.region(data_cred), ) def _validate_auth(): @@ -285,7 +294,7 @@ def _validate_auth(): except client.OSMODataStorageClientError as err: if err.message == 'AuthorizationHeaderMalformed': raise osmo_errors.OSMOCredentialError( - f'Data key validation error: region {region} is not valid: ' + f'Data key validation error: region {data_cred.region} is not valid: ' f'{err.__cause__}') if err.message == 'SignatureDoesNotMatch': raise 
osmo_errors.OSMOCredentialError( @@ -297,8 +306,7 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential | None = None, ) -> str: """ Infer the region of the bucket via provided credentials. @@ -308,11 +316,26 @@ def region( if self._region is not None: return self._region + if data_cred is None: + data_cred = self.resolved_data_credential + + match data_cred: + case credentials.StaticDataCredential(): + pass + case credentials.DefaultDataCredential(): + raise NotImplementedError( + 'Default data credentials are not supported for Swift backend') + case _ as unreachable: + assert_never(unreachable) + + if data_cred.region is not None: + return data_cred.region + s3_client = s3.create_client( - access_key_id=access_key_id, - access_key=access_key, + data_cred=data_cred, scheme=self.scheme, endpoint_url=self.auth_endpoint, + # No region, we need to get it from the bucket location ) def _get_region() -> str: @@ -343,6 +366,12 @@ class S3Backend(Boto3Backend): description='Whether the backend supports batch delete.', ) + supports_environment_auth: Literal[True] = pydantic.Field( + default=True, + const=True, + description='Whether the backend supports environment authentication.', + ) + # Cache the region to avoid re-computing it _region: str | None = pydantic.PrivateAttr(default=None) @@ -407,9 +436,7 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, - region: str | None = None, + data_cred: credentials.DataCredential | None = None, access_type: common.AccessType | None = None, ): """ @@ -426,14 +453,22 @@ def data_auth( elif access_type == common.AccessType.DELETE: action.append('s3:DeleteObject') - if region is None: - region = self.region(access_key_id, access_key) - - session = boto3.Session( - aws_access_key_id=access_key_id, - aws_secret_access_key=access_key, - region_name=region, - ) + if data_cred is None: + data_cred = self.resolved_data_credential + + match data_cred: + case credentials.StaticDataCredential(): + session = boto3.Session( + aws_access_key_id=data_cred.access_key_id, + aws_secret_access_key=data_cred.access_key.get_secret_value(), + region_name=self.region(data_cred), + ) + case credentials.DefaultDataCredential(): + session = boto3.Session( + region_name=self.region(data_cred), + ) + case _ as unreachable: + assert_never(unreachable) iam_client: mypy_boto3_iam.client.IAMClient = session.client('iam') sts_client: mypy_boto3_sts.client.STSClient = session.client('sts') @@ -459,24 +494,19 @@ def _validate_auth(): for result in results['EvaluationResults']: if result['EvalDecision'] != 'allowed': raise osmo_errors.OSMOCredentialError( - f'Data key validation error: access_key_id {access_key_id} ' + - f'has no {result["EvalActionName"]} access for s3://{path}') + f'Data key validation error: no {result["EvalActionName"]} ' + f'access for s3://{path}') try: _ = client.execute_api(_validate_auth, s3.S3ErrorHandler()) except client.OSMODataStorageClientError as err: - if err.message in ('SignatureDoesNotMatch', 'InvalidClientTokenId'): - raise osmo_errors.OSMOCredentialError( - f'Data key validation error: access_key_id {access_key_id} is not valid: ' - f'{err.__cause__}') raise osmo_errors.OSMOCredentialError( - f'Data key validation error: {err.__cause__}') + f'Data key validation error: {err.message}: {err.__cause__}') @override def region( self, - access_key_id: str, - access_key: str, + data_cred: 
credentials.DataCredential | None = None, ) -> str: """ Infer the region of the bucket via provided credentials. @@ -486,9 +516,14 @@ def region( if self._region is not None: return self._region + if data_cred is None: + data_cred = self.resolved_data_credential + + if data_cred.region is not None: + return data_cred.region + s3_client = s3.create_client( - access_key_id=access_key_id, - access_key=access_key, + data_cred=data_cred, scheme=self.scheme, ) @@ -587,9 +622,7 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, - region: str | None = None, + data_cred: credentials.DataCredential | None = None, access_type: common.AccessType | None = None, ): """ @@ -598,16 +631,23 @@ def data_auth( if _skip_data_auth(): return - if region is None: - region = self.region(access_key_id, access_key) - - s3_client = s3.create_client( - access_key_id=access_key_id, - access_key=access_key, - scheme=self.scheme, - endpoint_url=self.auth_endpoint, - region=region, - ) + if data_cred is None: + data_cred = self.resolved_data_credential + + match data_cred: + case credentials.StaticDataCredential(): + s3_client = s3.create_client( + data_cred=data_cred, + scheme=self.scheme, + endpoint_url=self.auth_endpoint, + region=self.region(data_cred), + ) + case credentials.DefaultDataCredential(): + # TODO: Implement Google Cloud Storage DAL for keyless authentication + raise NotImplementedError( + 'Default data credentials are not supported for GS backend yet') + case _ as unreachable: + assert_never(unreachable) # TODO: Have more detailed validation for different access types def _validate_auth(): @@ -619,29 +659,22 @@ def _validate_auth(): try: _ = client.execute_api(_validate_auth, s3.S3ErrorHandler()) except client.OSMODataStorageClientError as err: - if err.message == 'AuthorizationHeaderMalformed': - raise osmo_errors.OSMOCredentialError( - f'Data key validation error: region {region} is not valid: ' - f'{err.__cause__}') - if err.message == 'SignatureDoesNotMatch': - raise osmo_errors.OSMOCredentialError( - f'Data key validation error: access_key_id {access_key_id} is not valid: ' - f'{err.__cause__}') raise osmo_errors.OSMOCredentialError( - f'Data key validation error: {err.__cause__}') + f'Data key validation error: {err.message}: {err.__cause__}') # TODO: Figure out how to correctly find region @override def region( self, - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential | None = None, ) -> str: - # pylint: disable=unused-argument """ Infer the region of the bucket via provided credentials. 
""" - return constants.DEFAULT_GS_REGION + if data_cred is None: + data_cred = self.resolved_data_credential + + return data_cred.region or constants.DEFAULT_GS_REGION class TOSBackend(Boto3Backend): @@ -726,29 +759,31 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, - region: str | None = None, + data_cred: credentials.DataCredential | None = None, access_type: common.AccessType | None = None, ): - # pylint: disable=unused-argument """ Validates if the data is valid for the backend """ if _skip_data_auth(): return - if region is None: - # If region is not provided, we need to extract it from the netloc - region = self.region('', '') - - s3_client = s3.create_client( - access_key_id=access_key_id, - access_key=access_key, - scheme=self.scheme, - endpoint_url=self.auth_endpoint, - region=region, - ) + if data_cred is None: + data_cred = self.resolved_data_credential + + match data_cred: + case credentials.StaticDataCredential(): + s3_client = s3.create_client( + data_cred=data_cred, + scheme=self.scheme, + endpoint_url=self.auth_endpoint, + region=self.region(data_cred), + ) + case credentials.DefaultDataCredential(): + raise NotImplementedError( + 'Default data credentials are not supported for TOS backend') + case _ as unreachable: + assert_never(unreachable) def _validate_auth(): if self.container: @@ -759,24 +794,11 @@ def _validate_auth(): try: client.execute_api(_validate_auth, s3.S3ErrorHandler()) except client.OSMODataStorageClientError as err: - if err.message == 'AuthorizationHeaderMalformed': - raise osmo_errors.OSMOCredentialError( - f'Data key validation error: region {region} is not valid: ' - f'{err.__cause__}') - if err.message == 'SignatureDoesNotMatch': - raise osmo_errors.OSMOCredentialError( - f'Data key validation error: access_key_id {access_key_id} is not valid: ' - f'{err.__cause__}') raise osmo_errors.OSMOCredentialError( - f'Data key validation error: {err.__cause__}') + f'Data key validation error: {err.message}: {err.__cause__}') @override - def region( - self, - access_key_id: str, - access_key: str, - ) -> str: - # pylint: disable=unused-argument + def region(self, _: credentials.DataCredential | None = None) -> str: # netloc = tos-s3-. 
return self.netloc[len('tos-s3-'):].split('.')[0] @@ -802,6 +824,12 @@ class AzureBlobStorageBackend(common.StorageBackend): description='The storage account of the Azure Blob Storage backend.', ) + supports_environment_auth: Literal[True] = pydantic.Field( + default=True, + const=True, + description='Whether the backend supports environment authentication.', + ) + @override @classmethod def create( @@ -866,9 +894,7 @@ def parse_uri_to_link(self, region: str) -> str: @override def data_auth( self, - access_key_id: str, - access_key: str, - region: str | None = None, + data_cred: credentials.DataCredential | None = None, access_type: common.AccessType | None = None, ): # pylint: disable=unused-argument @@ -878,8 +904,11 @@ def data_auth( if _skip_data_auth(): return + if data_cred is None: + data_cred = self.resolved_data_credential + def _validate_auth(): - with azure.create_client(access_key) as service_client: + with azure.create_client(data_cred) as service_client: if self.container: with service_client.get_container_client(self.container) as container_client: container_client.get_container_properties() @@ -900,10 +929,8 @@ def _validate_auth(): @override def region( self, - access_key_id: str, - access_key: str, + _: credentials.DataCredential | None = None, ) -> str: - # pylint: disable=unused-argument # Azure Blob Storage does not encode region in the URLs, we will simply # use the default region to conform to the interface. return self.default_region @@ -916,8 +943,7 @@ def default_region(self) -> str: @override def client_factory( self, - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential | None = None, request_headers: List[header.RequestHeaders] | None = None, **kwargs: Any, ) -> azure.AzureBlobStorageClientFactory: @@ -925,9 +951,10 @@ def client_factory( """ Returns a factory for creating storage clients. """ - return azure.AzureBlobStorageClientFactory( # pylint: disable=unexpected-keyword-arg - connection_string=access_key, - ) + if data_cred is None: + data_cred = self.resolved_data_credential + + return azure.AzureBlobStorageClientFactory(data_cred=data_cred) def construct_storage_backend( diff --git a/src/lib/data/storage/backends/common.py b/src/lib/data/storage/backends/common.py index d10f2ea2..895cc117 100644 --- a/src/lib/data/storage/backends/common.py +++ b/src/lib/data/storage/backends/common.py @@ -20,6 +20,8 @@ import abc import enum +import functools +import logging import os import pathlib from urllib import parse @@ -28,6 +30,11 @@ import pydantic from ..core import header, provider +from ..credentials import credentials +from ....utils import osmo_errors + + +logger = logging.getLogger(__name__) class StorageBackendType(enum.Enum): @@ -89,7 +96,13 @@ class StoragePath: ) -class StorageBackend(abc.ABC, pydantic.BaseModel, extra=pydantic.Extra.forbid): +class StorageBackend( + abc.ABC, + pydantic.BaseModel, + extra=pydantic.Extra.forbid, + arbitrary_types_allowed=True, + keep_untouched=(functools.cached_property,), # Don't serialize cached properties +): """ Represents information about a storage backend. 
""" @@ -101,6 +114,7 @@ class StorageBackend(abc.ABC, pydantic.BaseModel, extra=pydantic.Extra.forbid): path: str override_endpoint: str | None = None + supports_environment_auth: bool = False @classmethod @abc.abstractmethod @@ -184,24 +198,36 @@ def parse_uri_to_link(self, region: str) -> str: @abc.abstractmethod def data_auth( self, - access_key_id: str, - access_key: str, - region: str | None = None, + data_cred: credentials.DataCredential | None = None, access_type: AccessType | None = None, ): """ - Validates if the access id and key can perform action + Validates if the access id and key can perform action. + + If no data credential is provided, it will be resolved via resolved_data_credential. + + Args: + data_cred: The data credential to use for the validation. + access_type: The access type to validate. """ pass @abc.abstractmethod def region( self, - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential | None = None, ) -> str: """ - Infer the region of the bucket via provided credentials. + Infer the region of the bucket from the storage backend. + + Some backends may not require a data credential to infer the region. + If no data credential is provided, it will be resolved via resolved_data_credential. + + Args: + data_cred: The data credential to use for the region inference. + + Returns: + The region of the bucket. """ pass @@ -232,12 +258,46 @@ def to_storage_path( @abc.abstractmethod def client_factory( self, - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential | None = None, request_headers: List[header.RequestHeaders] | None = None, **kwargs: Any, ) -> provider.StorageClientFactory: """ Returns a factory for creating storage clients. + + If no data credential is provided, it will be resolved via resolved_data_credential. + + Args: + data_cred: The data credential to use for the client factory. + request_headers: The request headers to use for the client factory. + **kwargs: Additional keyword arguments to pass to the client factory. + + Returns: + A factory for creating storage clients. """ pass + + @functools.cached_property + def resolved_data_credential(self) -> credentials.DataCredential: + """ + Resolve the data credential for the storage backend. + + Returns: + The resolved data credential. + + Raises: + OSMOCredentialError: If the data credential is not found. + """ + data_cred = credentials.get_static_data_credential_from_config(self.profile) + + if data_cred is not None: + return data_cred + + if self.supports_environment_auth: + return credentials.DefaultDataCredential( + endpoint=self.profile, + region=None, + ) + + raise osmo_errors.OSMOCredentialError( + f'Data credential not found for {self.profile}') diff --git a/src/lib/data/storage/backends/s3.py b/src/lib/data/storage/backends/s3.py index f8b5e198..9dd81d6c 100644 --- a/src/lib/data/storage/backends/s3.py +++ b/src/lib/data/storage/backends/s3.py @@ -25,7 +25,7 @@ import re import time from typing import Any, Callable, Dict, Generator, List, Tuple, Type -from typing_extensions import override +from typing_extensions import assert_never, override import boto3.exceptions import boto3.s3.transfer @@ -38,6 +38,7 @@ import mypy_boto3_s3.type_defs from . import common +from .. 
import credentials from ..core import client, provider from ....utils import common as utils_common @@ -507,8 +508,7 @@ def close(self) -> None: def create_client( - access_key_id: str, - access_key: str, + data_cred: credentials.DataCredential, scheme: str, *, endpoint_url: str | None = None, @@ -522,8 +522,7 @@ def create_client( without proper synchronization. Args: - access_key_id: The access key ID. - access_key: The access key. + data_cred: The data credential. scheme: The scheme of the storage. endpoint_url: The endpoint URL. region: The region. @@ -538,14 +537,27 @@ def _get_client() -> mypy_boto3_s3.S3Client: session = boto3.Session() _add_request_headers(session, extra_headers) - return session.client( - 's3', - endpoint_url=endpoint_url, - aws_access_key_id=access_key_id, - aws_secret_access_key=access_key, - region_name=region, - config=config, - ) + match data_cred: + case credentials.StaticDataCredential(): + # Uses direct credentials (e.g. access key and secret key) + return session.client( + 's3', + endpoint_url=endpoint_url, + aws_access_key_id=data_cred.access_key_id, + aws_secret_access_key=data_cred.access_key.get_secret_value(), + region_name=region, + config=config, + ) + case credentials.DefaultDataCredential(): + # Uses ambient credentials (e.g. Environment variables, Workload Identity, etc.) + return session.client( + 's3', + endpoint_url=endpoint_url, + region_name=region, + config=config, + ) + case _ as unreachable: + assert_never(unreachable) return client.execute_api( _get_client, @@ -1137,8 +1149,7 @@ class S3StorageClientFactory(provider.StorageClientFactory): Factory for the S3StorageClient. """ - access_key_id: str - access_key: str + data_cred: credentials.DataCredential region: str scheme: str endpoint_url: str | None = dataclasses.field(default=None) @@ -1149,8 +1160,7 @@ class S3StorageClientFactory(provider.StorageClientFactory): def create(self) -> S3StorageClient: return S3StorageClient( lambda: create_client( - self.access_key_id, - self.access_key, + self.data_cred, scheme=self.scheme, endpoint_url=self.endpoint_url, region=self.region, diff --git a/src/lib/data/storage/backends/tests/BUILD b/src/lib/data/storage/backends/tests/BUILD index 646363f4..05e2ac1f 100644 --- a/src/lib/data/storage/backends/tests/BUILD +++ b/src/lib/data/storage/backends/tests/BUILD @@ -24,5 +24,6 @@ osmo_py_test( deps = [ "//src/lib/data/storage/backends", "//src/lib/data/storage/core", + "//src/lib/data/storage/credentials", ], ) diff --git a/src/lib/data/storage/backends/tests/test_backends.py b/src/lib/data/storage/backends/tests/test_backends.py index 5cffeffd..f0f4e65b 100644 --- a/src/lib/data/storage/backends/tests/test_backends.py +++ b/src/lib/data/storage/backends/tests/test_backends.py @@ -23,7 +23,9 @@ from unittest import mock from src.lib.data.storage.backends import backends, s3 +from src.lib.data.storage.credentials import credentials from src.lib.data.storage.core import header +from src.lib.utils import osmo_errors class TestBackends(unittest.TestCase): @@ -43,14 +45,19 @@ def test_s3_extra_headers(self): header.DownloadRequestHeaders(headers={'x-download-header': 'test-unsupported-header'}), ] - # Act - s3_client_factory = s3_backend.client_factory( + data_cred = credentials.StaticDataCredential( + endpoint='s3://test-bucket/test-key', access_key_id='test-access-key-id', access_key='test-access-key', - request_headers=request_headers, region='us-east-1', ) + # Act + s3_client_factory = s3_backend.client_factory( + data_cred=data_cred, + 
request_headers=request_headers, + ) + # Assert self.assertEqual( s3_client_factory.extra_headers, @@ -85,11 +92,16 @@ def test_s3_extra_headers_is_registered(self, mock_session_class): 'before-call.s3.UploadPart': {'x-upload-header': 'test-upload-header'}, 'before-call.s3.CompleteMultipartUpload': {'x-upload-header': 'test-upload-header'}, } + data_cred = credentials.StaticDataCredential( + endpoint='s3://test-bucket/test-key', + access_key_id='test-access-key-id', + access_key='test-access-key', + region='us-east-1', + ) # Act s3.create_client( - access_key_id='test-access-key-id', - access_key='test-access-key', + data_cred=data_cred, scheme='s3', extra_headers=extra_headers ) @@ -144,6 +156,72 @@ def test_path_backend_contains_sub_path(self): self.assertTrue(storage_backend_2 in storage_backend_1) self.assertTrue(storage_backend_1 not in storage_backend_2) + @mock.patch( + 'src.lib.data.storage.credentials.credentials.get_static_data_credential_from_config', + return_value=None, + ) + def test_environment_auth_support(self, mock_get_config): + """ + Test that S3/Azure support environment authentication while other backends do not. + When no static credential is found in config: + - S3 and Azure should return DefaultDataCredential + - Other backends (Swift, GS, TOS) should raise OSMOCredentialError + """ + # pylint: disable=unused-argument + test_cases = [ + # Backends that support environment auth + ( + 's3://test-bucket/test-key', + 's3://test-bucket', + True, + ), + ( + 'azure://testaccount/testcontainer/testpath', + 'azure://testaccount', + True, + ), + # Backends that do NOT support environment auth + ( + 'swift://test.example.com/AUTH_namespace/testcontainer/testpath', + 'swift://test.example.com/AUTH_namespace', + False, + ), + ( + 'gs://test-bucket/test-key', + 'gs://test-bucket', + False, + ), + ( + 'tos://tos-cn-beijing.volces.com/test-bucket/test-key', + 'tos://tos-cn-beijing.volces.com/test-bucket', + False, + ), + ] + + for uri, expected_profile, supports_env_auth in test_cases: + with self.subTest(uri=uri, supports_env_auth=supports_env_auth): + # Arrange + backend = backends.construct_storage_backend(uri=uri) + + if supports_env_auth: + # Act + data_cred = backend.resolved_data_credential + + # Assert + self.assertIsInstance( + data_cred, + credentials.DefaultDataCredential, + f'{uri} should support environment auth' + ) + self.assertEqual(data_cred.endpoint, expected_profile) + else: + # Act & Assert + with self.assertRaises(osmo_errors.OSMOCredentialError) as context: + _ = backend.resolved_data_credential + + self.assertIn('Data credential not found', str(context.exception)) + self.assertIn(expected_profile, str(context.exception)) + if __name__ == '__main__': unittest.main() diff --git a/src/lib/data/storage/client.py b/src/lib/data/storage/client.py index 9a07f5f9..b0759de3 100644 --- a/src/lib/data/storage/client.py +++ b/src/lib/data/storage/client.py @@ -30,7 +30,9 @@ from . import ( backends, common, + constants, copying, + credentials, deleting, downloading, streaming, @@ -40,11 +42,8 @@ ) from .backends import common as backends_common from .core import executor, header -from .. import constants from ...utils import ( cache, - client_configs, - credentials, logging as logging_utils, osmo_errors, paths, @@ -134,7 +133,8 @@ def create( .. important:: - If data_credential is not provided, it will be resolved from the file system. + If data_credential is not provided, it will be resolved from the host system + (i.e. file system, environment variables, etc.). 
:param str | None storage_uri: The storage URI to use for the client. :param backends_common.StorageBackend | None storage_backend: The storage backend to use for @@ -239,26 +239,43 @@ def create( description='Headers to apply to all requests of this client.', ) - @functools.cached_property - def data_credential(self) -> credentials.DataCredential: + @pydantic.root_validator(skip_on_failure=True) + @classmethod + def validate_data_credential_endpoint(cls, values): """ - Resolves the data credential. + Validates that the data credential endpoint matches the storage backend profile. """ - # Validate data credential input if provided. - if self.data_credential_input is not None: - # Validate that the data credential endpoint matches the storage backend profile. + data_credential_input = values.get('data_credential_input') + if data_credential_input is not None: + storage_uri = values.get('storage_uri') + cache_config = values.get('cache_config') + + # Construct backends to validate profiles match data_cred_backend = backends.construct_storage_backend( - uri=self.data_credential_input.endpoint, - cache_config=self.cache_config, + uri=data_credential_input.endpoint, + cache_config=cache_config, ) - if data_cred_backend.profile != self.storage_backend.profile: + storage_backend = backends.construct_storage_backend( + uri=storage_uri, + cache_config=cache_config, + ) + + if data_cred_backend.profile != storage_backend.profile: raise osmo_errors.OSMOCredentialError( 'Credential endpoint must match the storage backend profile') - return ( - self.data_credential_input or - client_configs.get_credentials(self.storage_backend.profile) - ) + return values + + @functools.cached_property + def data_credential(self) -> credentials.DataCredential: + """ + Resolves the data credential. + """ + if self.data_credential_input is not None: + return self.data_credential_input + + # Resolve the data credential from the storage backend + return self.storage_backend.resolved_data_credential @functools.cached_property def storage_backend(self) -> backends_common.StorageBackend: @@ -273,19 +290,6 @@ def storage_backend(self) -> backends_common.StorageBackend: cache_config=self.cache_config, ) - @functools.cached_property - def storage_auth(self) -> common.StorageAuth: - """ - Storage backend authentication parameters. 
- - :return: The storage authentication parameters - :rtype: common.StorageAuth - """ - return common.StorageAuth( - user=self.data_credential.access_key_id, - key=self.data_credential.access_key.get_secret_value() - ) - def _validate_remote_path( self, remote_path: str | None, @@ -474,9 +478,7 @@ def _upload_with_paths( request_headers.append(header.UploadRequestHeaders(headers=extra_headers)) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=request_headers, ) @@ -580,9 +582,7 @@ def _upload_with_worker_inputs( request_headers.append(header.UploadRequestHeaders(headers=extra_headers)) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=request_headers, ) @@ -714,9 +714,7 @@ def _copy_with_paths( ) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=[ header.ClientHeaders(headers=self.headers), ] if self.headers else None, @@ -864,9 +862,7 @@ def _download_with_paths( ) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=[ header.ClientHeaders(headers=self.headers), ] if self.headers else None, @@ -955,9 +951,7 @@ def _download_with_worker_inputs( Downloads data using a list of DownloadWorkerInput objects. """ client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=[ header.ClientHeaders(headers=self.headers), ] if self.headers else None, @@ -1009,9 +1003,7 @@ def list_objects( ) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=[ header.ClientHeaders(headers=self.headers), ] if self.headers else None, @@ -1097,9 +1089,7 @@ def get_object_stream( validated_remote_path = self._validate_remote_path(remote_path) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=[ header.ClientHeaders(headers=self.headers), ] if self.headers else None, @@ -1189,9 +1179,7 @@ def delete_objects( ) client_factory = self.storage_backend.client_factory( - access_key_id=self.storage_auth.user, - access_key=self.storage_auth.key, - region=self.data_credential.region, + data_cred=self.data_credential, request_headers=[ header.ClientHeaders(headers=self.headers), ] if self.headers else None, @@ -1262,7 +1250,8 @@ def create( Either storage_uri or storage_backend must be provided, not both. .. important:: - If data_credential is not provided, it will be resolved from the file system. + If data_credential is not provided, it will be resolved from the host system + (i.e. file system, environment variables, etc.). 
:param str | None storage_uri: The object URI to use for the client. :param backends_common.StorageBackend | None storage_backend: The object storage backend to diff --git a/src/lib/data/constants/BUILD b/src/lib/data/storage/constants/BUILD similarity index 100% rename from src/lib/data/constants/BUILD rename to src/lib/data/storage/constants/BUILD diff --git a/src/lib/data/constants/__init__.py b/src/lib/data/storage/constants/__init__.py similarity index 100% rename from src/lib/data/constants/__init__.py rename to src/lib/data/storage/constants/__init__.py diff --git a/src/lib/data/constants/constants.py b/src/lib/data/storage/constants/constants.py similarity index 100% rename from src/lib/data/constants/constants.py rename to src/lib/data/storage/constants/constants.py diff --git a/src/lib/data/storage/credentials/BUILD b/src/lib/data/storage/credentials/BUILD new file mode 100644 index 00000000..ef9d2c4b --- /dev/null +++ b/src/lib/data/storage/credentials/BUILD @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//bzl:py.bzl", "osmo_py_binary", "osmo_py_library") +load("@osmo_python_deps//:requirements.bzl", "requirement") + +osmo_py_library( + name = "credentials", + srcs = glob(["*.py"]), + deps = [ + "//src/lib/data/storage/constants", + "//src/lib/utils:client_configs", + "//src/lib/utils:osmo_errors", + requirement("pydantic"), + requirement("pyyaml"), + ], +) diff --git a/src/lib/data/storage/credentials/__init__.py b/src/lib/data/storage/credentials/__init__.py new file mode 100644 index 00000000..40906e3d --- /dev/null +++ b/src/lib/data/storage/credentials/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Credentials for the data module. +""" + +from .credentials import * diff --git a/src/lib/data/storage/credentials/credentials.py b/src/lib/data/storage/credentials/credentials.py new file mode 100644 index 00000000..d90f99d1 --- /dev/null +++ b/src/lib/data/storage/credentials/credentials.py @@ -0,0 +1,134 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +""" +Credentials for the data module. +""" + +import abc +import os +import re +from typing import Union + +import pydantic +import yaml + +from .. import constants +from ....utils import client_configs, osmo_errors + + +class DataCredentialBase(pydantic.BaseModel, abc.ABC, extra=pydantic.Extra.forbid): + """ + Base class for data credentials (i.e. credentials with endpoint and region). + """ + endpoint: str = pydantic.Field( + ..., + description='The endpoint URL for the data service', + ) + region: str | None = pydantic.Field( + default=None, + description='The region for the data service', + ) + + @pydantic.validator('endpoint') + @classmethod + def validate_endpoint(cls, value: str) -> constants.StorageCredentialPattern: + """ + Validates endpoint. Returns the value of parsed job_id if valid. + """ + if not re.fullmatch(constants.STORAGE_CREDENTIAL_REGEX, value): + raise osmo_errors.OSMOUserError(f'Invalid endpoint: {value}') + return value.rstrip('/') + + +class StaticDataCredential(DataCredentialBase, abc.ABC, extra=pydantic.Extra.forbid): + """ + Static data credentials (i.e. credentials with access_key_id and access_key) for a data backend. + """ + access_key_id: str = pydantic.Field( + ..., + description='The authentication key for a data backend', + ) + + access_key: pydantic.SecretStr = pydantic.Field( + ..., + description='The encrypted authentication secret for a data backend', + ) + + def to_decrypted_dict(self) -> dict[str, str]: + output = { + 'access_key_id': self.access_key_id, + 'access_key': self.access_key.get_secret_value(), + 'endpoint': self.endpoint, + } + + if self.region: + output['region'] = self.region + + return output + + +class DefaultDataCredential(DataCredentialBase, extra=pydantic.Extra.forbid): + """ + Data credential that delegates resolution to the underlying SDK. + + Uses the SDK's default credential chain (e.g., Azure's DefaultAzureCredential, + boto3's credential resolution) which may include environment variables, + workload identity, instance metadata, and other provider-specific methods. + + Intentionally left empty as all credential resolution is handled by the SDK. + """ + pass + + +DataCredential = Union[ + StaticDataCredential, + DefaultDataCredential, +] + + +def get_static_data_credential_from_config(url: str) -> StaticDataCredential | None: + """ + Get a matching static data credential from the config file. + + Args: + url: The URL of the data service. + Returns: + The static data credential or None if not found. 
+ """ + config_dir = client_configs.get_client_config_dir(create=False) + config_file = os.path.join(config_dir, 'config.yaml') + + if not os.path.exists(config_file): + return None + + with open(config_file, 'r', encoding='utf-8') as file: + configs = yaml.safe_load(file.read()) + + if 'auth' in configs and 'data' in configs['auth'] and url in configs['auth']['data']: + data_cred_dict = configs['auth']['data'][url] + data_cred = StaticDataCredential( + access_key_id=data_cred_dict['access_key_id'], + access_key=pydantic.SecretStr(data_cred_dict['access_key']), + endpoint=url, + region=data_cred_dict['region'], + ) + + return data_cred + + return None diff --git a/src/lib/data/storage/mux.py b/src/lib/data/storage/mux.py index f5a5f3de..942c040d 100644 --- a/src/lib/data/storage/mux.py +++ b/src/lib/data/storage/mux.py @@ -126,16 +126,12 @@ def bind(self, storage_profile: str) -> provider.StorageClientProvider: # Pre-init factory outside of the lock # May be redundantly computed by multiple threads and that is okay. - data_cred = client_configs.get_credentials(storage_profile) storage_backend = backends.construct_storage_backend( uri=storage_profile, profile=True, cache_config=self._cache_config, ) client_factory = storage_backend.client_factory( - access_key_id=data_cred.access_key_id, - access_key=data_cred.access_key.get_secret_value(), - region=data_cred.region, request_headers=self._client_factory.request_headers, **self._client_factory.kwargs, ) diff --git a/src/lib/utils/BUILD b/src/lib/utils/BUILD index fa673f83..dfae8470 100644 --- a/src/lib/utils/BUILD +++ b/src/lib/utils/BUILD @@ -72,8 +72,6 @@ osmo_py_library( requirement("pyyaml"), ":cache", ":common", - ":credentials", - ":osmo_errors", ], ) @@ -96,7 +94,8 @@ osmo_py_library( deps = [ requirement("pydantic"), ":osmo_errors", - "//src/lib/data/constants", + "//src/lib/data/storage/constants", + "//src/lib/data/storage/credentials", ], ) @@ -158,7 +157,7 @@ osmo_py_library( deps = [ ":common", ":osmo_errors", - "//src/lib/data/constants", + "//src/lib/data/storage/constants", ], ) diff --git a/src/lib/utils/client_configs.py b/src/lib/utils/client_configs.py index cafe87d3..a34231f1 100644 --- a/src/lib/utils/client_configs.py +++ b/src/lib/utils/client_configs.py @@ -16,16 +16,15 @@ SPDX-License-Identifier: Apache-2.0 """ -import functools import os from typing import Optional import yaml -from . import cache, common, credentials, osmo_errors +from . 
import cache, common -def get_client_config_dir() -> str: +def get_client_config_dir(create: bool = True) -> str: """ Get path of directory where config files should be stored """ override_dir = os.getenv(common.OSMO_CONFIG_OVERRIDE) xdg_config = os.getenv('XDG_CONFIG_HOME') @@ -42,12 +41,14 @@ def get_client_config_dir() -> str: else: config_dir = os.path.expanduser('~/.config/osmo') - os.makedirs(config_dir, exist_ok=True) + if create: + os.makedirs(config_dir, exist_ok=True) + return config_dir def get_cache_config() -> Optional[cache.CacheConfig]: - osmo_directory = get_client_config_dir() + osmo_directory = get_client_config_dir(create=False) password_file = osmo_directory + '/config.yaml' if os.path.isfile(password_file): @@ -58,29 +59,6 @@ def get_cache_config() -> Optional[cache.CacheConfig]: return None -@functools.lru_cache() -def get_credentials(url: str) -> credentials.DataCredential: - osmo_directory = get_client_config_dir() - password_file = osmo_directory + '/config.yaml' - - if os.path.isfile(password_file): - with open(password_file, 'r', encoding='utf-8') as file: - configs = yaml.safe_load(file.read()) - if url in configs['auth']['data']: - data_cred_dict = configs['auth']['data'][url] - data_cred = credentials.DataCredential( - access_key_id=data_cred_dict['access_key_id'], - access_key=data_cred_dict['access_key'], - endpoint=url, - region=data_cred_dict['region'], - ) - return data_cred - raise osmo_errors.OSMOError(f'Credential not set for {url}. Please set credentials using: \n' + - 'osmo credential set my_cred --type DATA ' + - '--payload access_key_id=your_s3_username access_key=your_s3_key' + - ' endpoint=your_endpoint region=endpoint_region') - - def get_client_state_dir() -> str: """ Get path of directory where state info should be stored, like logs """ override_dir = os.getenv(common.OSMO_STATE_OVERRIDE) diff --git a/src/lib/utils/credentials.py b/src/lib/utils/credentials.py index 912eb2c8..176adb97 100644 --- a/src/lib/utils/credentials.py +++ b/src/lib/utils/credentials.py @@ -16,12 +16,12 @@ SPDX-License-Identifier: Apache-2.0 """ -import re - import pydantic -from . import osmo_errors -from ..data import constants +from ..data.storage import credentials + +StaticDataCredential = credentials.StaticDataCredential +get_static_data_credential_from_config = credentials.get_static_data_credential_from_config CREDNAMEREGEX = r'^[a-zA-Z]([a-zA-Z0-9_-]*[a-zA-Z0-9])?$' @@ -35,59 +35,3 @@ class RegistryCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): pydantic.SecretStr(''), description='The authentication token for the Docker registry', ) - - -class BasicDataCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): - """ Authentication information for a data service without endpoint and region info. """ - access_key_id: str = pydantic.Field( - description='The authentication key for the data service') - access_key: pydantic.SecretStr = pydantic.Field( - description='The authentication secret for the data service') - - -class DataCredential(BasicDataCredential, extra=pydantic.Extra.forbid): - """ - Authentication information for a data service. - """ - - endpoint: str = pydantic.Field( - ..., - description='The endpoint URL for the data service', - ) - - region: str = pydantic.Field( - constants.DEFAULT_BOTO3_REGION, - description='The region for the data service', - ) - - @pydantic.validator('endpoint') - @classmethod - def validate_endpoint(cls, value: str) -> constants.StorageCredentialPattern: - """ - Validates endpoint. 
Returns the value of parsed job_id if valid. - """ - if not re.fullmatch(constants.STORAGE_CREDENTIAL_REGEX, value): - raise osmo_errors.OSMOUserError(f'Invalid endpoint: {value}') - return value.rstrip('/') - - -class DecryptedDataCredential(BasicDataCredential, extra=pydantic.Extra.ignore): - """ - Basic data cred with access_key decrypted. - """ - - access_key: str = pydantic.Field( # type: ignore[assignment] - ..., - description='The authentication secret for the data service', - ) - - region: str = pydantic.Field( - constants.DEFAULT_BOTO3_REGION, - description='The region for the data service', - ) - - -def decrypt(base_cred: DataCredential) -> DecryptedDataCredential: - cred_dict = base_cred.dict() - cred_dict['access_key'] = cred_dict['access_key'].get_secret_value() - return DecryptedDataCredential(**cred_dict) diff --git a/src/lib/utils/validation.py b/src/lib/utils/validation.py index 73155bf6..bd688639 100644 --- a/src/lib/utils/validation.py +++ b/src/lib/utils/validation.py @@ -22,7 +22,7 @@ import re from . import common, osmo_errors -from ..data import constants +from ..data.storage import constants def positive_integer(x: int): diff --git a/src/service/core/config/tests/test_config_history.py b/src/service/core/config/tests/test_config_history.py index f10de55a..f11c0d1e 100644 --- a/src/service/core/config/tests/test_config_history.py +++ b/src/service/core/config/tests/test_config_history.py @@ -1232,9 +1232,11 @@ def test_get_config_diff(self): test_bucket = connectors.BucketConfig( check_key=False, dataset_path='swift://test-endpoint/AUTH_test-team/dev/testuser/datasets', - default_credential=credentials.BasicDataCredential( - access_key='test-secret', + default_credential=credentials.StaticDataCredential( + access_key=pydantic.SecretStr('test-secret'), access_key_id='testuser:AUTH_team-osmo', + endpoint='swift://test-endpoint/AUTH_test-team/', + region='us-east-1', ), description='My fancy dataset', mode='read-write', diff --git a/src/service/core/data/data_service.py b/src/service/core/data/data_service.py index 22545ef0..7ab18b1b 100755 --- a/src/service/core/data/data_service.py +++ b/src/service/core/data/data_service.py @@ -44,17 +44,20 @@ def create_uuid() -> str: return base64.urlsafe_b64encode(unique_id.bytes).decode('utf-8')[:-2] +# TODO: add this to client side for dataset validation (to handle workload identity credentials) def validate_user_cred(postgres: connectors.PostgresConnector, user: str, location: str, access_type: storage.AccessType): """ Validates that the user has the specifc access type to a backend location """ backend_info = storage.construct_storage_backend(location) + user_cred = postgres.get_data_cred(user, backend_info.profile) - backend_info.data_auth(user_cred.access_key_id, - user_cred.access_key, - user_cred.region, - access_type) + if not user_cred: + raise osmo_errors.OSMOCredentialError( + f'Could not find {backend_info.profile} credential for user {user}.') + + backend_info.data_auth(user_cred, access_type) def get_dataset(postgres: connectors.PostgresConnector, bucket: str, name: str) -> Any: diff --git a/src/service/core/workflow/objects.py b/src/service/core/workflow/objects.py index 43a5ea49..3e9e3630 100644 --- a/src/service/core/workflow/objects.py +++ b/src/service/core/workflow/objects.py @@ -19,12 +19,13 @@ import collections import datetime import math -from typing import Any, Dict, List, NamedTuple, Optional, Set +from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Set import yaml import pydantic from 
src.lib.data import storage +from src.lib.data.storage.credentials import credentials as data_credentials from src.lib.utils import credentials, common, osmo_errors, priority as wf_priority import src.lib.utils.logging from src.utils.job import app, common as task_common, jobs, kb_objects, task, workflow @@ -546,7 +547,23 @@ class CredentialRecord(NamedTuple): payload: str -class UserRegistryCredential(credentials.RegistryCredential, extra=pydantic.Extra.forbid): +class CredentialProtocol(Protocol): + """ Protocol for credentials. """ + @staticmethod + def type() -> connectors.CredentialType: + pass + + def to_db_row(self, user: str, postgres: connectors.PostgresConnector) -> CredentialRecord: + pass + + def valid_cred(self, workflow_config: connectors.WorkflowConfig): + pass + + +class UserRegistryCredential( + credentials.RegistryCredential, + extra=pydantic.Extra.forbid, +): """ Authentication information for a Docker registry. """ auth: str = pydantic.Field( description='The authentication token for the Docker registry') # type: ignore @@ -572,33 +589,62 @@ def valid_cred(self, workflow_config: connectors.WorkflowConfig): raise osmo_errors.OSMOCredentialError('Registry authentication failed.') - -class UserDataCredential(credentials.DataCredential, extra=pydantic.Extra.forbid): +class UserDataCredential( + data_credentials.DataCredentialBase, + extra=pydantic.Extra.forbid, +): """ Authentication information for a data service. """ + + access_key_id: str = pydantic.Field( + ..., + description='The authentication key for a data backend', + ) + access_key: str = pydantic.Field( - description='The authentication secret for the data service') # type: ignore + ..., + description='The authentication secret for a data backend', + ) @staticmethod def type() -> connectors.CredentialType: return connectors.CredentialType.DATA def to_db_row(self, user: str, postgres: connectors.PostgresConnector) -> CredentialRecord: - payload = {'access_key_id': self.access_key_id, - 'access_key': self.access_key, - 'region': self.region} + payload = { + 'access_key_id': self.access_key_id, + 'access_key': self.access_key, + } + + if self.region: + payload['region'] = self.region + payload = postgres.encrypt_dict(payload, user) - return CredentialRecord(self.type().value, - self.endpoint, - connectors.PostgresConnector.encode_hstore(payload)) + + return CredentialRecord( + self.type().value, + self.endpoint, + connectors.PostgresConnector.encode_hstore(payload), + ) def valid_cred(self, workflow_config: connectors.WorkflowConfig): storage_info = storage.construct_storage_backend(self.endpoint, True) if storage_info.scheme in workflow_config.credential_config.disable_data_validation: return - storage_info.data_auth(self.access_key_id, self.access_key, self.region) + data_cred = data_credentials.StaticDataCredential( + endpoint=self.endpoint, + access_key_id=self.access_key_id, + access_key=pydantic.SecretStr(self.access_key), + region=self.region, + ) + + storage_info.data_auth(data_cred) -class UserCredential(pydantic.BaseModel, extra=pydantic.Extra.forbid): + +class UserCredential( + pydantic.BaseModel, + extra=pydantic.Extra.forbid, +): """ Generic authentication information. 
""" credential: Dict[str, str] = pydantic.Field( description='The credential dictionary that contains authentication information' @@ -659,6 +705,17 @@ def validate_credential(cls, values): # pylint: disable=no-self-argument f'Exactly one of the following must be set {cls.__fields__.keys()}') return values + def get_credential(self) -> CredentialProtocol: + if self.registry_credential is not None: + return self.registry_credential + elif self.data_credential is not None: + return self.data_credential + elif self.generic_credential is not None: + return self.generic_credential + else: + raise osmo_errors.OSMOUserError( + f'Exactly one of the following must be set: {self.__fields__.keys()}') + class CredentialGetResponse(pydantic.BaseModel): """ Credential Response. """ diff --git a/src/service/core/workflow/workflow_service.py b/src/service/core/workflow/workflow_service.py index 4b8bcc86..9f10934b 100644 --- a/src/service/core/workflow/workflow_service.py +++ b/src/service/core/workflow/workflow_service.py @@ -909,7 +909,7 @@ def get_user_credential( @router_credentials.post('/api/credentials/{cred_name}') def set_user_credential( cred_name: str, - credential: objects.CredentialOptions, + credential_option: objects.CredentialOptions, user_header: Optional[str] = fastapi.Header(alias=login.OSMO_USER_HEADER, default=None)): """ Post/Update user credentials """ @@ -923,11 +923,11 @@ def set_user_credential( if not rows: context.database.secret_manager.add_new_user(user_name) - credential_set = getattr(credential, credential.__fields_set__.pop()) + credential = credential_option.get_credential() workflow_config = context.database.get_workflow_configs() - credential_set.valid_cred(workflow_config) + credential.valid_cred(workflow_config) try: - cmd_arg = credential_set.to_db_row(user_name, context.database) + cmd_arg = credential.to_db_row(user_name, context.database) context.database.execute_commit_command(objects.UserCredential.commit_cmd(), tuple([user_name, cred_name]) + cmd_arg) logging.info('Saved credential %s on the server.', cred_name) diff --git a/src/utils/connectors/BUILD b/src/utils/connectors/BUILD index 041cb4b3..afd64fea 100644 --- a/src/utils/connectors/BUILD +++ b/src/utils/connectors/BUILD @@ -12,8 +12,8 @@ osmo_py_library( requirement("pydantic"), requirement("redis"), requirement("pyyaml"), - "//src/lib/data/constants", "//src/lib/data/storage", + "//src/lib/data/storage/constants", "//src/lib/utils:cache", "//src/lib/utils:common", "//src/lib/utils:credentials", diff --git a/src/utils/connectors/postgres.py b/src/utils/connectors/postgres.py index ed790898..995531db 100644 --- a/src/utils/connectors/postgres.py +++ b/src/utils/connectors/postgres.py @@ -40,7 +40,8 @@ from starlette.datastructures import Headers from starlette.types import ASGIApp, Receive, Scope, Send -from src.lib.data import constants, storage +from src.lib.data import storage +from src.lib.data.storage import constants from src.lib.utils import (common, credentials, jinja_sandbox, login, osmo_errors, role) from src.utils import auth, notify @@ -1217,7 +1218,7 @@ def func(new_encrypted: str): self.execute_commit_command(cmd, (new_encrypted,)) return func - def get_data_cred(self, user: str, profile: str) -> credentials.DecryptedDataCredential: + def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCredential | None: """ Fetch data credentials by profile. 
""" select_data_cmd = PostgresSelectCommand( table='credential', @@ -1225,25 +1226,27 @@ def get_data_cred(self, user: str, profile: str) -> credentials.DecryptedDataCre condition_args=[user, CredentialType.DATA.value, profile]) row = self.execute_fetch_command(*select_data_cmd.get_args()) if row: - return credentials.DecryptedDataCredential(**self.decrypt_credential(row[0])) + return credentials.StaticDataCredential( + endpoint=profile, + **self.decrypt_credential(row[0]), + ) else: # Check default bucket creds for bucket in self.get_dataset_configs().buckets.values(): bucket_info = storage.construct_storage_backend(bucket.dataset_path) if bucket_info.profile == profile: if bucket.default_credential: - return credentials.DecryptedDataCredential( + return credentials.StaticDataCredential( region=bucket.region, access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.access_key.get_secret_value(), - endpoint=bucket_info.profile + access_key=bucket.default_credential.access_key, + endpoint=bucket_info.profile, ) break - raise osmo_errors.OSMOCredentialError( - f'Could not find {profile} credential for user {user}.') + return None - def get_all_data_creds(self, user: str) -> Dict[str, credentials.DecryptedDataCredential]: + def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCredential]: """ Fetch all data credentials for user. """ select_data_cmd = PostgresSelectCommand( table='credential', @@ -1251,18 +1254,22 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.DecryptedDataCr condition_args=[user, CredentialType.DATA.value]) rows = self.execute_fetch_command(*select_data_cmd.get_args()) - user_creds = {cred.profile: credentials.DecryptedDataCredential( - **self.decrypt_credential(cred)) - for cred in rows} + user_creds = { + cred.profile: credentials.StaticDataCredential( + endpoint=cred.profile, + **self.decrypt_credential(cred), + ) + for cred in rows + } # Add default bucket creds for bucket in self.get_dataset_configs().buckets.values(): bucket_info = storage.construct_storage_backend(bucket.dataset_path) if bucket_info.profile not in user_creds and bucket.default_credential: - user_creds[bucket_info.profile] = credentials.DecryptedDataCredential( + user_creds[bucket_info.profile] = credentials.StaticDataCredential( region=bucket.region, access_key_id=bucket.default_credential.access_key_id, - access_key=bucket.default_credential.access_key.get_secret_value(), + access_key=bucket.default_credential.access_key, endpoint=bucket_info.profile ) return user_creds @@ -2275,7 +2282,7 @@ def construct_path(endpoint: str, bucket: str, path: str): class LogConfig(ExtraArgBaseModel): """ Config for storing information about data. """ - credential: credentials.DataCredential | None = None + credential: credentials.StaticDataCredential | None = None class WorkflowInfo(ExtraArgBaseModel): @@ -2292,7 +2299,7 @@ def validate_name(self, name: str): class DataConfig(ExtraArgBaseModel): """ Config for storing information about data. 
""" - credential: credentials.DataCredential | None = None + credential: credentials.StaticDataCredential | None = None base_url: str = '' # Timeout in mins for osmo-ctrl to retry connecting to the OSMO service until exiting the task @@ -2331,7 +2338,7 @@ class BucketConfig(ExtraArgBaseModel): # Default cred to use doesn't have one # Only applies to workflow operations, NOT user cli since we cannot forward the credential # to the user - default_credential: credentials.BasicDataCredential | None = None + default_credential: credentials.StaticDataCredential | None = None def valid_access(self, bucket_name: str, access_type: BucketModeAccess): if not ((access_type == BucketModeAccess.READ and\ diff --git a/src/utils/job/BUILD b/src/utils/job/BUILD index 5728454d..d1e926c1 100644 --- a/src/utils/job/BUILD +++ b/src/utils/job/BUILD @@ -84,8 +84,8 @@ osmo_py_library( requirement("pyyaml"), requirement("requests"), requirement("urllib3"), - "//src/lib/data/constants", "//src/lib/data/storage", + "//src/lib/data/storage/constants", "//src/lib/utils:cache", "//src/lib/utils:common", "//src/lib/utils:credentials", diff --git a/src/utils/job/task.py b/src/utils/job/task.py index 5f0fa9d5..d2836adf 100644 --- a/src/utils/job/task.py +++ b/src/utils/job/task.py @@ -33,7 +33,8 @@ import urllib3 # type: ignore import yaml -from src.lib.data import constants, storage +from src.lib.data import storage +from src.lib.data.storage import constants from src.lib.utils import (cache, common, credentials, jinja_sandbox, osmo_errors, priority as wf_priority) from src.utils.job import common as task_common, kb_objects @@ -88,8 +89,10 @@ def create_login_dict(user: str, -def create_config_dict(data_info: Dict[str, credentials.DecryptedDataCredential], - cache_config: Optional[cache.CacheConfig] = None) -> Dict: +def create_config_dict( + data_info: dict[str, credentials.StaticDataCredential], + cache_config: cache.CacheConfig | None = None, +) -> dict: ''' Creates the config dict where the input should be a dict containing key values like: url: @@ -98,11 +101,16 @@ def create_config_dict(data_info: Dict[str, credentials.DecryptedDataCredential] ''' data = { 'auth': { - 'data': {data_key: data_value.dict() for data_key, data_value in data_info.items()} + 'data': { + data_key: data_value.to_decrypted_dict() + for data_key, data_value in data_info.items() + } } } + if cache_config: data['cache'] = cache_config.dict() + return data @@ -2132,7 +2140,7 @@ def convert_to_pod_spec( service_config: connectors.ServiceConfig | None = None, dataset_config: connectors.DatasetConfig | None = None, pool_info: connectors.Pool | None = None, - data_endpoints: Dict[str, credentials.DecryptedDataCredential] | None = None, + data_endpoints: Dict[str, credentials.StaticDataCredential] | None = None, skip_refresh_token: bool = False, ) -> Tuple[Dict, Dict[str, kb_objects.FileMount]]: """ @@ -2174,8 +2182,6 @@ def convert_to_pod_spec( input_urls: List[str] = [] input_datasets: List[str] = [] - service_creds = credentials.decrypt(workflow_config.workflow_data.credential) - disabled_data = workflow_config.credential_config.disable_data_validation # TODO: Make extra_args a dumped json to be parsed by osmo-ctrl for index, spec_input in enumerate(task_spec.inputs): @@ -2309,7 +2315,9 @@ def convert_to_pod_spec( service_profile = storage.construct_storage_backend( workflow_config.workflow_data.credential.endpoint).profile - service_config_yaml = create_config_dict({service_profile: service_creds}) + service_config_yaml = create_config_dict({ + 
service_profile: workflow_config.workflow_data.credential, + }) # User CLI login config login_file = File(path='/login', @@ -2597,14 +2605,18 @@ def decode_hstore(tasks: str) -> Set[str]: return {tp[0] for tp in re.findall(f'"({task_common.NAMEREGEX})"=>"NULL"', tasks)} -def fetch_creds(user: str, data_creds: Dict[str, credentials.DecryptedDataCredential], path: str, - disabled_data: Optional[List[str]] = None) -> Dict: +def fetch_creds( + user: str, + data_creds: dict[str, credentials.StaticDataCredential], + path: str, + disabled_data: list[str] | None = None, +) -> credentials.StaticDataCredential | None: backend_info = storage.construct_storage_backend(path) if backend_info.profile not in data_creds: if not disabled_data or backend_info.scheme not in disabled_data: raise osmo_errors.OSMOCredentialError( f'Could not find {backend_info.profile} credential for user {user}.') - return {} + return None - return data_creds[backend_info.profile].dict() + return data_creds[backend_info.profile] diff --git a/src/utils/job/workflow.py b/src/utils/job/workflow.py index 48526d51..1e6a7dd4 100644 --- a/src/utils/job/workflow.py +++ b/src/utils/job/workflow.py @@ -647,15 +647,22 @@ def _fetch_bucket_info(dataset_info: common.DatasetStructure)\ data_cred = task.fetch_creds(user, user_creds, bucket_info.uri) - # Get if user has access to READ or WRITE - if is_input and bucket_info.uri not in seen_uri_input: - bucket_info.data_auth(data_cred['access_key_id'], data_cred['access_key'], - data_cred['region'], storage.AccessType.READ) - seen_uri_input.add(bucket_info.uri) - if not is_input and bucket_info.uri not in seen_uri_output: - bucket_info.data_auth(data_cred['access_key_id'], data_cred['access_key'], - data_cred['region'], storage.AccessType.WRITE) - seen_uri_output.add(bucket_info.uri) + if data_cred is None: + # User does not have any credentials, check if the backend + # supports environment authentication + if not bucket_info.supports_environment_auth: + raise osmo_errors.OSMOCredentialError( + f'Could not validate access to {bucket_info.uri} for user {user}.') + else: + # Check if user credentials have access to READ + if is_input and bucket_info.uri not in seen_uri_input: + bucket_info.data_auth(data_cred, storage.AccessType.READ) + seen_uri_input.add(bucket_info.uri) + + # Check if user credentials have access to WRITE + if not is_input and bucket_info.uri not in seen_uri_output: + bucket_info.data_auth(data_cred, storage.AccessType.WRITE) + seen_uri_output.add(bucket_info.uri) for input_data_spec in group_task.inputs: _validate_input_output(input_data_spec, True)