22 changes: 20 additions & 2 deletions src/spatialdata/_core/spatialdata.py
@@ -1112,6 +1112,7 @@ def write(
consolidate_metadata: bool = True,
update_sdata_path: bool = True,
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
"""
Write the `SpatialData` object to a Zarr store.
@@ -1156,11 +1157,17 @@ def write(
unspecified, the element formats will be set to the latest element format compatible with the specified
SpatialData container format. All the formats and relationships between them are defined in
`spatialdata._io.format.py`.
compressor
A length-1 dictionary whose key is the compression algorithm to use for images and labels and whose value is
the compression level, between 0 and 9 inclusive. `lz4` and `zstd` are supported. If not specified, `lz4` with
compression level 5 is used. Bytes are automatically shuffled for more efficient compression.
"""
from spatialdata._io._utils import _resolve_zarr_store
from spatialdata._io._utils import _resolve_zarr_store, _validate_compressor_args
from spatialdata._io.format import _parse_formats

parsed = _parse_formats(sdata_formats)
_validate_compressor_args(compressor)

if isinstance(file_path, str):
file_path = Path(file_path)
@@ -1181,6 +1188,7 @@
element_name=element_name,
overwrite=False,
parsed_formats=parsed,
compressor=compressor,
)

if self.path != file_path and update_sdata_path:
@@ -1197,6 +1205,7 @@ def _write_element(
element_name: str,
overwrite: bool,
parsed_formats: dict[str, SpatialDataFormatType] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
from spatialdata._io.io_zarr import _get_groups_for_element

@@ -1230,13 +1239,15 @@
group=element_group,
name=element_name,
element_format=parsed_formats["raster"],
compressor=compressor,
)
elif element_type == "labels":
write_labels(
labels=element,
group=root_group,
name=element_name,
element_format=parsed_formats["raster"],
compressor=compressor,
)
elif element_type == "points":
write_points(
@@ -1265,6 +1276,7 @@ def write_element(
element_name: str | list[str],
overwrite: bool = False,
sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
) -> None:
"""
Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1280,6 +1292,11 @@
sdata_formats
It is recommended to leave this parameter equal to `None`. See more details in the documentation of
`SpatialData.write()`.
compressor
A length-1 dictionary whose key is the compression algorithm to use for images and labels and whose value is
the compression level, between 0 and 9 inclusive. `lz4` and `zstd` are supported. If not specified, `lz4` with
compression level 5 is used. Bytes are automatically shuffled for more efficient compression.

Notes
-----
@@ -1293,7 +1310,7 @@
if isinstance(element_name, list):
for name in element_name:
assert isinstance(name, str)
self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats)
self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats, compressor=compressor)
return

check_valid_name(element_name)
@@ -1327,6 +1344,7 @@ def write_element(
element_name=element_name,
overwrite=overwrite,
parsed_formats=parsed_formats,
compressor=compressor,
)
# After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting.
if self.has_consolidated_metadata():
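Not part of the diff, but for reviewers: a minimal usage sketch of the new `compressor` argument based on the signatures above. The array contents, store paths, and element names are made up.

```python
import numpy as np
from spatialdata import SpatialData
from spatialdata.models import Image2DModel

# Toy object with a single image element laid out as (c, y, x).
image = Image2DModel.parse(np.random.randint(0, 255, (3, 64, 64), dtype=np.uint8))
sdata = SpatialData(images={"img": image})

# Defaults to lz4 at compression level 5 when `compressor` is not given;
# here we request zstd at level 7 (length-1 dict {algorithm: level}, level 0-9).
sdata.write("data_zstd.zarr", compressor={"zstd": 7})

# Per-element writes accept the same argument (hypothetical second element).
sdata["img2"] = Image2DModel.parse(np.random.randint(0, 255, (3, 32, 32), dtype=np.uint8))
sdata.write_element("img2", compressor={"lz4": 3})
```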
19 changes: 19 additions & 0 deletions src/spatialdata/_io/_utils.py
@@ -543,3 +543,22 @@ def handle_read_errors(
else: # on_bad_files == BadFileHandleMethod.ERROR
# Let it raise exceptions
yield


def _validate_compressor_args(compressor_dict: dict[Literal["lz4", "zstd"], int] | None) -> None:
if compressor_dict:
if not isinstance(compressor_dict, dict):
raise TypeError(
f"Expected a dictionary with as key the type of compression to use for images and labels and "
f"as value the compression level which should be inclusive between 1 and 9. "
f"Got type: {type(compressor_dict)}"
)
if len(compressor_dict) != 1:
raise ValueError(
"Expected a dictionary with a single key indicating the type of compression, either 'lz4' or "
"'zstd' and an `int` inclusive between 1 and 9 as value representing the compression level."
)
if (compression := list(compressor_dict.keys())[0]) not in ["lz4", "zstd"]:
raise ValueError(f"Compression must either be `lz4` or `zstd`, got: {compression}.")
if not isinstance(value := list(compressor_dict.values())[0], int) or not (0 <= value <= 9):
raise ValueError(f"The compression level must be an integer inclusive between 0 and 9. Got: {value}")
88 changes: 84 additions & 4 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, cast

import dask.array as da
import numpy as np
@@ -153,6 +153,7 @@ def _write_raster(
name: str,
raster_format: RasterFormatType,
storage_options: JSONDict | list[JSONDict] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
label_metadata: JSONDict | None = None,
**metadata: str | JSONDict | list[JSONDict],
) -> None:
@@ -172,6 +173,8 @@
The format used to write the raster data.
storage_options
Additional options for writing the raster data, like chunks and compression.
compressor
Compression settings as a length-1 dictionary with a single `{compression: compression_level}` key-value pair.
label_metadata
Label metadata which can only be defined when writing 'labels'.
metadata
@@ -201,6 +204,7 @@
raster_data,
raster_format,
storage_options,
compressor=compressor,
**metadata,
)
elif isinstance(raster_data, DataTree):
@@ -211,6 +215,7 @@
raster_data,
raster_format,
storage_options,
compressor=compressor,
**metadata,
)
else:
@@ -225,13 +230,64 @@
group.attrs[ATTRS_KEY] = attrs


def _apply_compression(
storage_options: JSONDict | list[JSONDict],
compressor: dict[Literal["lz4", "zstd"], int] | None,
zarr_format: Literal[2, 3] = 3,
) -> JSONDict | list[JSONDict]:
"""Apply compression settings to storage options.

Parameters
----------
storage_options
Storage options for zarr arrays
compressor
Compression settings as a dictionary with a single key-value pair
zarr_format
The zarr format version (2 or 3)

Returns
-------
Updated storage options with compression settings
"""
# For zarr disk format v2, use numcodecs.Blosc
# For zarr disk format v3, use zarr.codecs.Blosc
from numcodecs import Blosc as BloscV2
from zarr.codecs import Blosc as BloscV3

if not compressor:
return storage_options

((compression, compression_level),) = compressor.items()

assert BloscV2.SHUFFLE == 1
blosc_v2 = BloscV2(cname=compression, clevel=compression_level, shuffle=1)
blosc_v3 = BloscV3(cname=compression, clevel=compression_level, shuffle=1)

def _update_dict(d: dict[str, Any]) -> None:
if zarr_format == 2:
d["compressor"] = blosc_v2
elif zarr_format == 3:
d["zarr_array_kwargs"] = {"compressors": [blosc_v3]}

if isinstance(storage_options, dict):
_update_dict(d=storage_options)
else:
assert isinstance(storage_options, list)
for option in storage_options:
_update_dict(d=option)

return storage_options


def _write_raster_dataarray(
raster_type: Literal["image", "labels"],
group: zarr.Group,
element_name: str,
raster_data: DataArray,
raster_format: RasterFormatType,
storage_options: JSONDict | list[JSONDict] | None = None,
storage_options: JSONDict | list[JSONDict] | None,
compressor: dict[Literal["lz4", "zstd"], int] | None,
**metadata: str | JSONDict | list[JSONDict],
) -> None:
"""Write raster data of type DataArray to disk.
@@ -250,6 +306,8 @@ def _write_raster_dataarray(
The format used to write the raster data.
storage_options
Additional options for writing the raster data, like chunks and compression.
compressor
Compression settings as a length-1 dictionary with a single `{compression: compression_level}` key-value pair.
metadata
Additional metadata for the raster element
"""
@@ -262,11 +320,19 @@
input_axes: tuple[str, ...] = tuple(raster_data.dims)
chunks = raster_data.chunks
parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)

# TODO: improve and test the behavior around storage option after merge <-------
if storage_options is not None:
if "chunks" not in storage_options and isinstance(storage_options, dict):
storage_options["chunks"] = chunks
else:
storage_options = {"chunks": chunks}

# Apply compression if specified
storage_options = _apply_compression(
storage_options, compressor, zarr_format=cast(Literal[2, 3], raster_format.zarr_format)
)

# Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
# We need this because the argument of write_image_ngff is called image while the argument of
# write_labels_ngff is called label.
@@ -297,7 +363,8 @@ def _write_raster_datatree(
element_name: str,
raster_data: DataTree,
raster_format: RasterFormatType,
storage_options: JSONDict | list[JSONDict] | None = None,
storage_options: JSONDict | list[JSONDict] | None,
compressor: dict[Literal["lz4", "zstd"], int] | None,
**metadata: str | JSONDict | list[JSONDict],
) -> None:
"""Write raster data of type DataTree to disk.
@@ -316,6 +383,8 @@
The format used to write the raster data.
storage_options
Additional options for writing the raster data, like chunks and compression.
compressor
Compression settings as a length-1 dictionary with a single `{compression: compression_level}` key-value pair.
metadata
Additional metadata for the raster element
"""
@@ -331,10 +400,17 @@
transformations = _get_transformations_xarray(xdata)
if transformations is None:
raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")

chunks = get_pyramid_levels(raster_data, "chunks")
# TODO: improve and test the behavior around storage option after merge <-------
if storage_options is None:
storage_options = [{"chunks": chunk} for chunk in chunks]
else:
storage_options = [{"chunks": chunk} for chunk in chunks]
# Apply compression if specified
storage_options = _apply_compression(storage_options, compressor, zarr_format=raster_format.zarr_format)

parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
storage_options = [{"chunks": chunk} for chunk in chunks]
ome_zarr_format = get_ome_zarr_format(raster_format)
dask_delayed = write_multi_scale_ngff(
pyramid=data,
@@ -366,6 +442,7 @@ def write_image(
name: str,
element_format: RasterFormatType = CurrentRasterFormat(),
storage_options: JSONDict | list[JSONDict] | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
**metadata: str | JSONDict | list[JSONDict],
) -> None:
_write_raster(
@@ -375,6 +452,7 @@
name=name,
raster_format=element_format,
storage_options=storage_options,
compressor=compressor,
**metadata,
)

@@ -386,6 +464,7 @@ def write_labels(
element_format: RasterFormatType = CurrentRasterFormat(),
storage_options: JSONDict | list[JSONDict] | None = None,
label_metadata: JSONDict | None = None,
compressor: dict[Literal["lz4", "zstd"], int] | None = None,
**metadata: JSONDict,
) -> None:
_write_raster(
@@ -395,6 +474,7 @@
name=name,
raster_format=element_format,
storage_options=storage_options,
compressor=compressor,
label_metadata=label_metadata,
**metadata,
)