Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions src/cli/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@

import shtab

from src.lib.utils import client, client_configs, common, credentials, osmo_errors
from src.lib.utils import client, client_configs, credentials, common, osmo_errors

CRED_TYPES = ['REGISTRY', 'DATA', 'GENERIC']


def _save_config(data_cred: credentials.DataCredential):
def _save_config(data_cred: credentials.StaticDataCredential):
"""
Sets default config information
"""
Expand All @@ -46,7 +46,8 @@ def _save_config(data_cred: credentials.DataCredential):
config['auth']['data'][data_cred.endpoint] = {
'access_key_id': data_cred.access_key_id,
'access_key': data_cred.access_key.get_secret_value(),
'region': data_cred.region}
'region': data_cred.region,
}
with open(password_file, 'w', encoding='utf-8') as file:
yaml.dump(config, file)
os.chmod(password_file, stat.S_IREAD | stat.S_IWRITE)
Expand Down Expand Up @@ -97,21 +98,30 @@ def _run_set_command(service_client: client.ServiceClient, args: argparse.Namesp
print(f'File {value} cannot be found.')
sys.exit(1)

if args.type == 'GENERIC':
if args.type == 'DATA':
# Validate that the data credential is valid
try:
credentials.StaticDataCredential(**cred_payload)
except Exception as err: # pylint: disable=broad-except
raise osmo_errors.OSMOUserError(f'Invalid DATA credential: {str(err)}')

elif args.type == 'GENERIC':
cred_payload = {'credential': cred_payload}

result = service_client.request(client.RequestMethod.POST,
f'api/credentials/{args.name}',
payload={args.type.lower() + '_credential': cred_payload})
result = service_client.request(
client.RequestMethod.POST,
f'api/credentials/{args.name}',
payload={args.type.lower() + '_credential': cred_payload},
)

if args.format_type == 'json':
print(json.dumps(result, indent=2))
else:
print(f'Set {args.type} credential {args.name}.')

if args.type == 'DATA':
if 'endpoint' not in cred_payload:
raise osmo_errors.OSMOUserError('Endpoint is required for DATA credentials.')
_save_config(credentials.DataCredential(**cred_payload))
# Save the data credential to the client config
_save_config(credentials.StaticDataCredential(**cred_payload))


def _run_list_command(service_client: client.ServiceClient, args: argparse.Namespace):
Expand All @@ -132,11 +142,11 @@ def _run_list_command(service_client: client.ServiceClient, args: argparse.Names
for cred in result['credentials']:
cred['local'] = 'N/A'
if cred.get('cred_type', '') == 'DATA':
try:
client_configs.get_credentials(cred.get('profile', ''))
cred['local'] = 'Yes'
except osmo_errors.OSMOError:
cred['local'] = 'No'
data_cred = credentials.get_static_data_credential_from_config(
url=cred.get('profile', ''),
)
cred['local'] = 'Yes' if data_cred else 'No'

table.add_row([cred.get(column, '-') for column in columns])
print(f'{table.draw()}\n')

Expand Down Expand Up @@ -206,15 +216,15 @@ def setup_parser(parser: argparse._SubParsersAction):
'payload corresponding to each type of credential:\n'
'\n'
# pylint: disable=line-too-long
'+-----------------+---------------------------+---------------------------------------+\n'
'| Credential Type | Mandatory keys | Optional keys |\n'
'+-----------------+---------------------------+---------------------------------------+\n'
'| REGISTRY | auth | registry, username |\n'
'+-----------------+---------------------------+---------------------------------------+\n'
'| DATA | access_key_id, access_key | endpoint, region (default: us-east-1) |\n'
'+-----------------+---------------------------+---------------------------------------+\n'
'| GENERIC | | |\n'
'+-----------------+---------------------------+---------------------------------------+\n'
'+-----------------+-------------------------------------+-----------------------------+\n'
'| Credential Type | Mandatory keys | Optional keys |\n'
'+-----------------+-------------------------------------+-----------------------------+\n'
'| REGISTRY | auth | registry, username |\n'
'+-----------------+-------------------------------------+-----------------------------+\n'
'| DATA | access_key_id, access_key, endpoint | region (default: us-east-1) |\n'
'+-----------------+-------------------------------------+-----------------------------+\n'
'| GENERIC | | |\n'
'+-----------------+-------------------------------------+-----------------------------+\n'
# pylint: enable=line-too-long
'\n'
),
Expand Down
1 change: 0 additions & 1 deletion src/lib/data/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ osmo_py_library(
"//src/lib/data/storage",
"//src/lib/utils:cache",
"//src/lib/utils:client",
"//src/lib/utils:client_configs",
"//src/lib/utils:logging",
"//src/lib/utils:common",
"//src/lib/utils:osmo_errors",
Expand Down
11 changes: 4 additions & 7 deletions src/lib/data/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@
import diskcache
import pydantic

from .. import storage, constants
from .. import storage
from ..storage import constants
from ..storage.core import progress
from ...utils import client, client_configs, common, osmo_errors, paths
from ...utils import client, common, osmo_errors, paths


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -305,12 +306,8 @@ def _validate_source_path(
if re.fullmatch(constants.STORAGE_BACKEND_REGEX, source_path):
# Remote path logic
path_components = storage.construct_storage_backend(source_path)
user_credentials = client_configs.get_credentials(path_components.profile)
path_components.data_auth(
user_credentials.access_key_id,
user_credentials.access_key.get_secret_value(),
user_credentials.region,
storage.AccessType.READ,
access_type=storage.AccessType.READ,
)

return RemotePath(path_components, has_asterisk, priority)
Expand Down
23 changes: 3 additions & 20 deletions src/lib/data/dataset/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,7 @@ def upload_start(
location_result['path'],
cache_config=self.cache_config,
)
credentials = client_configs.get_credentials(path_components.profile)
path_components.data_auth(
credentials.access_key_id,
credentials.access_key.get_secret_value(),
credentials.region,
storage.AccessType.WRITE,
)
path_components.data_auth(access_type=storage.AccessType.WRITE)

# Parse and validate the input paths
local_paths, backend_paths = common.parse_upload_paths(input_paths)
Expand Down Expand Up @@ -398,24 +392,13 @@ def _update_dataset_start(
location_result['path'],
cache_config=self.cache_config,
)
credentials = client_configs.get_credentials(path_components.profile)

if remove_regex:
# Validate delete access
path_components.data_auth(
credentials.access_key_id,
credentials.access_key.get_secret_value(),
credentials.region,
storage.AccessType.DELETE,
)
path_components.data_auth(access_type=storage.AccessType.DELETE)
if add_paths:
# Validate write access
path_components.data_auth(
credentials.access_key_id,
credentials.access_key.get_secret_value(),
credentials.region,
storage.AccessType.WRITE,
)
path_components.data_auth(access_type=storage.AccessType.WRITE)

# If add_paths is provided, seperate paths and perform basic authentication
# against backend paths
Expand Down
10 changes: 2 additions & 8 deletions src/lib/data/dataset/migrating.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .. import storage
from ..storage import common as storage_common, copying
from ..storage.core import client, executor, progress, provider
from ...utils import cache, client_configs, osmo_errors
from ...utils import cache, osmo_errors


MANIFEST_REGEX_PATTERN = re.compile(r'.*\/manifests\/[0-9]+\.json$')
Expand Down Expand Up @@ -221,15 +221,9 @@ def migrate(
# Resolve the region for the destination storage backend.
# This is necessary for generating a valid regional HTTP URL for uploaded objects
# for certain storage backends (e.g. AWS S3).
destination_creds = client_configs.get_credentials(destination_backend.profile)
destination_region = destination_backend.region(
destination_creds.access_key_id,
destination_creds.access_key.get_secret_value(),
)
destination_region = destination_backend.region()

client_factory = destination_backend.client_factory(
access_key_id=destination_creds.access_key_id,
access_key=destination_creds.access_key.get_secret_value(),
region=destination_region,
)

Expand Down
12 changes: 3 additions & 9 deletions src/lib/data/dataset/updating.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from . import common, uploading
from .. import storage
from ..storage.core import executor
from ...utils import cache, client_configs, osmo_errors
from ...utils import cache, osmo_errors


#################################
Expand Down Expand Up @@ -224,17 +224,11 @@ def update(
# Resolve the region for the destination storage backend.
# This is necessary for generating a valid regional HTTP URL for uploaded objects
# for certain storage backends (e.g. AWS S3).
destination_creds = client_configs.get_credentials(destination.profile)
destination_region = destination.region(
destination_creds.access_key_id,
destination_creds.access_key.get_secret_value(),
)
destination_region = destination.region()

client_factory = destination.client_factory(
access_key_id=destination_creds.access_key_id,
access_key=destination_creds.access_key.get_secret_value(),
region=destination_region,
request_headers=request_headers,
region=destination_region,
)

manifest_cache = diskcache.Index()
Expand Down
19 changes: 5 additions & 14 deletions src/lib/data/dataset/uploading.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .. import storage
from ..storage import common as storage_common, uploading
from ..storage.core import client, executor, progress, provider
from ...utils import cache, client_configs, common as utils_common, osmo_errors
from ...utils import cache, common as utils_common, osmo_errors


##########################
Expand Down Expand Up @@ -143,12 +143,9 @@ def dataset_upload_remote_file_entry_generator(
)

# Create the base to be used for objects' HTTP URLs.
data_creds = storage_client.data_credential
data_cred = storage_client.data_credential
url_base = storage_backend.parse_uri_to_link(
storage_backend.region(
data_creds.access_key_id,
data_creds.access_key.get_secret_value(),
),
region=storage_backend.region(data_cred),
)

# Iterate over the objects in the remote path.
Expand Down Expand Up @@ -381,17 +378,11 @@ def upload(
# Resolve the region for the destination storage backend.
# This is necessary for generating a valid regional HTTP URL for uploaded objects
# for certain storage backends (e.g. AWS S3).
destination_creds = client_configs.get_credentials(destination.profile)
destination_region = destination.region(
destination_creds.access_key_id,
destination_creds.access_key.get_secret_value(),
)
destination_region = destination.region()

client_factory = destination.client_factory(
access_key_id=destination_creds.access_key_id,
access_key=destination_creds.access_key.get_secret_value(),
region=destination_region,
request_headers=request_headers,
region=destination_region,
)

manifest_cache = diskcache.Index()
Expand Down
4 changes: 2 additions & 2 deletions src/lib/data/storage/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ osmo_py_library(
srcs = glob(["*.py"]),
deps = [
"//src/lib/utils:client_configs",
"//src/lib/utils:credentials",
"//src/lib/utils:common",
"//src/lib/utils:logging",
"//src/lib/utils:paths",
"//src/lib/utils:osmo_errors",
"//src/lib/data/constants",
"//src/lib/data/storage/constants",
"//src/lib/data/storage/credentials",
"//src/lib/data/storage/core",
"//src/lib/data/storage/backends",
requirement("pydantic"),
Expand Down
3 changes: 2 additions & 1 deletion src/lib/data/storage/backends/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ osmo_py_library(
"//src/lib/utils:cache",
"//src/lib/utils:common",
"//src/lib/utils:osmo_errors",
"//src/lib/data/constants",
"//src/lib/data/storage/constants",
"//src/lib/data/storage/core",
"//src/lib/data/storage/credentials",
requirement("azure-storage-blob"),
requirement("azure-identity"),
requirement("boto3"),
Expand Down
20 changes: 14 additions & 6 deletions src/lib/data/storage/backends/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from azure.core import exceptions
from azure.storage import blob

from .. import credentials
from ..core import client, provider
from ....utils import common

Expand Down Expand Up @@ -270,13 +271,20 @@ def __next__(self) -> bytes:
return chunk


def create_client(
connection_string: str,
) -> blob.BlobServiceClient:
def create_client(data_cred: credentials.DataCredential) -> blob.BlobServiceClient:
"""
Creates a new Azure Blob Storage client.
"""
return blob.BlobServiceClient.from_connection_string(conn_str=connection_string)
match data_cred:
case credentials.StaticDataCredential():
return blob.BlobServiceClient.from_connection_string(
conn_str=data_cred.access_key.get_secret_value(),
)
case credentials.DefaultDataCredential():
raise NotImplementedError(
'Default data credentials are not supported yet')
case _ as unreachable:
assert_never(unreachable)


class AzureBlobStorageClient(client.StorageClient):
Expand Down Expand Up @@ -822,10 +830,10 @@ class AzureBlobStorageClientFactory(provider.StorageClientFactory):
Factory for the AzureBlobStorageClient.
"""

connection_string: str
data_cred: credentials.DataCredential

@override
def create(self) -> AzureBlobStorageClient:
return AzureBlobStorageClient(
lambda: create_client(self.connection_string),
lambda: create_client(self.data_cred),
)
Loading