diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index abc5e6ec..221dfc47 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1112,6 +1112,7 @@ def write(
         consolidate_metadata: bool = True,
         update_sdata_path: bool = True,
         sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1156,11 +1157,17 @@ def write(
             unspecified, the element formats will be set to the latest element format compatible with the specified
             SpatialData container format. All the formats and relationships between them are defined in
             `spatialdata._io.format.py`.
+        compressor
+            A length-1 dictionary `{compression_type: compression_level}` controlling how images and labels are
+            compressed. The compression level must be an integer between 0 and 9 (inclusive); the supported
+            compression types are `lz4` and `zstd`. If not specified, `lz4` with compression level 5 is used.
+            Bytes are automatically shuffled for more efficient compression.
         """
-        from spatialdata._io._utils import _resolve_zarr_store
+        from spatialdata._io._utils import _resolve_zarr_store, _validate_compressor_args
         from spatialdata._io.format import _parse_formats

         parsed = _parse_formats(sdata_formats)
+        _validate_compressor_args(compressor)

         if isinstance(file_path, str):
             file_path = Path(file_path)
@@ -1181,6 +1188,7 @@ def write(
                     element_name=element_name,
                     overwrite=False,
                     parsed_formats=parsed,
+                    compressor=compressor,
                 )

         if self.path != file_path and update_sdata_path:
@@ -1197,6 +1205,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         parsed_formats: dict[str, SpatialDataFormatType] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     ) -> None:
         from spatialdata._io.io_zarr import _get_groups_for_element
@@ -1230,6 +1239,7 @@ def _write_element(
                 group=element_group,
                 name=element_name,
                 element_format=parsed_formats["raster"],
+                compressor=compressor,
             )
         elif element_type == "labels":
             write_labels(
@@ -1237,6 +1247,7 @@ def _write_element(
                 group=root_group,
                 name=element_name,
                 element_format=parsed_formats["raster"],
+                compressor=compressor,
             )
         elif element_type == "points":
             write_points(
@@ -1265,6 +1276,7 @@ def write_element(
         element_name: str | list[str],
         overwrite: bool = False,
         sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
+        compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     ) -> None:
         """
         Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1280,6 +1292,11 @@ def write_element(
         sdata_formats
             It is recommended to leave this parameter equal to `None`. See more details in the documentation of
             `SpatialData.write()`.
+        compressor
+            A length-1 dictionary `{compression_type: compression_level}` controlling how images and labels are
+            compressed. The compression level must be an integer between 0 and 9 (inclusive); the supported
+            compression types are `lz4` and `zstd`. If not specified, `lz4` with compression level 5 is used.
+            Bytes are automatically shuffled for more efficient compression.

         Notes
         -----
@@ -1293,7 +1310,7 @@ def write_element(
         if isinstance(element_name, list):
             for name in element_name:
                 assert isinstance(name, str)
-                self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats)
+                self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats, compressor=compressor)
             return

         check_valid_name(element_name)
@@ -1327,6 +1344,7 @@ def write_element(
             element_name=element_name,
             overwrite=overwrite,
             parsed_formats=parsed_formats,
+            compressor=compressor,
         )
         # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting.
         if self.has_consolidated_metadata():
diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py
index b58d6744..f82f3370 100644
--- a/src/spatialdata/_io/_utils.py
+++ b/src/spatialdata/_io/_utils.py
@@ -543,3 +543,22 @@ def handle_read_errors(
     else:  # on_bad_files == BadFileHandleMethod.ERROR
         # Let it raise exceptions
         yield
+
+
+def _validate_compressor_args(compressor_dict: dict[Literal["lz4", "zstd"], int] | None) -> None:
+    if compressor_dict is not None:
+        if not isinstance(compressor_dict, dict):
+            raise TypeError(
+                "Expected a dictionary with the compression type to use for images and labels as key and the "
+                "compression level, an integer between 0 and 9 (inclusive), as value. "
+                f"Got type: {type(compressor_dict)}"
+            )
+        if len(compressor_dict) != 1:
+            raise ValueError(
+                "Expected a dictionary with a single key indicating the type of compression, either 'lz4' or "
+                "'zstd', and an `int` between 0 and 9 (inclusive) as value representing the compression level."
+            )
+        if (compression := list(compressor_dict.keys())[0]) not in ["lz4", "zstd"]:
+            raise ValueError(f"Compression must be either `lz4` or `zstd`, got: {compression}.")
+        if not isinstance(value := list(compressor_dict.values())[0], int) or not (0 <= value <= 9):
+            raise ValueError(f"The compression level must be an integer between 0 and 9 (inclusive). Got: {value}")
diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py
index bc8206db..b8eaddfc 100644
--- a/src/spatialdata/_io/io_raster.py
+++ b/src/spatialdata/_io/io_raster.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any, Literal
+from typing import Any, Literal, cast

 import dask.array as da
 import numpy as np
@@ -153,6 +153,7 @@ def _write_raster(
     name: str,
     raster_format: RasterFormatType,
     storage_options: JSONDict | list[JSONDict] | None = None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     label_metadata: JSONDict | None = None,
     **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
@@ -172,6 +173,8 @@ def _write_raster(
         The format used to write the raster data.
     storage_options
         Additional options for writing the raster data, like chunks and compression.
+    compressor
+        Compression settings as a length-1 dictionary with a single `{compression: compression_level}` pair.
     label_metadata
         Label metadata which can only be defined when writing 'labels'.
     metadata
@@ -201,6 +204,7 @@ def _write_raster(
             raster_data,
             raster_format,
             storage_options,
+            compressor=compressor,
             **metadata,
         )
     elif isinstance(raster_data, DataTree):
@@ -211,6 +215,7 @@ def _write_raster(
             raster_data,
             raster_format,
             storage_options,
+            compressor=compressor,
             **metadata,
         )
     else:
@@ -225,13 +230,64 @@ def _write_raster(
         group.attrs[ATTRS_KEY] = attrs


+def _apply_compression(
+    storage_options: JSONDict | list[JSONDict],
+    compressor: dict[Literal["lz4", "zstd"], int] | None,
+    zarr_format: Literal[2, 3] = 3,
+) -> JSONDict | list[JSONDict]:
+    """Apply compression settings to storage options.
+
+    Parameters
+    ----------
+    storage_options
+        Storage options for zarr arrays.
+    compressor
+        Compression settings as a dictionary with a single key-value pair.
+    zarr_format
+        The zarr format version (2 or 3).
+
+    Returns
+    -------
+    Updated storage options with compression settings.
+    """
+    # For the zarr on-disk format v2, use numcodecs.Blosc;
+    # for the zarr on-disk format v3, use zarr.codecs.Blosc.
+    from numcodecs import Blosc as BloscV2
+    from zarr.codecs import Blosc as BloscV3
+
+    if not compressor:
+        return storage_options
+
+    ((compression, compression_level),) = compressor.items()
+
+    # shuffle=1 enables Blosc's byte-shuffle filter, which reorders bytes for
+    # more efficient compression.
+    assert BloscV2.SHUFFLE == 1
+    blosc_v2 = BloscV2(cname=compression, clevel=compression_level, shuffle=1)
+    blosc_v3 = BloscV3(cname=compression, clevel=compression_level, shuffle=1)
+
+    def _update_dict(d: dict[str, Any]) -> None:
+        if zarr_format == 2:
+            d["compressor"] = blosc_v2
+        elif zarr_format == 3:
+            d["zarr_array_kwargs"] = {"compressors": [blosc_v3]}
+
+    if isinstance(storage_options, dict):
+        _update_dict(d=storage_options)
+    else:
+        assert isinstance(storage_options, list)
+        for option in storage_options:
+            _update_dict(d=option)
+
+    return storage_options
+
+
 def _write_raster_dataarray(
     raster_type: Literal["image", "labels"],
     group: zarr.Group,
     element_name: str,
     raster_data: DataArray,
     raster_format: RasterFormatType,
-    storage_options: JSONDict | list[JSONDict] | None = None,
+    storage_options: JSONDict | list[JSONDict] | None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None,
     **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
     """Write raster data of type DataArray to disk.
@@ -250,6 +306,8 @@ def _write_raster_dataarray(
     storage_options
         Additional options for writing the raster data, like chunks and compression.
+    compressor
+        Compression settings as a length-1 dictionary with a single `{compression: compression_level}` pair.
     metadata
         Additional metadata for the raster element
     """
@@ -262,11 +320,19 @@ def _write_raster_dataarray(
     input_axes: tuple[str, ...] = tuple(raster_data.dims)
     chunks = raster_data.chunks
     parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
+
+    # TODO: improve and test the behavior around storage options after merge
     if storage_options is not None:
         if "chunks" not in storage_options and isinstance(storage_options, dict):
             storage_options["chunks"] = chunks
     else:
         storage_options = {"chunks": chunks}
+
+    # Apply compression if specified
+    storage_options = _apply_compression(
+        storage_options, compressor, zarr_format=cast(Literal[2, 3], raster_format.zarr_format)
+    )
+
     # Scaler needs to be None since we are passing the data already downscaled for the multiscale case.
     # We need this because the argument of write_image_ngff is called image while the argument of
     # write_labels_ngff is called label.
@@ -297,7 +363,8 @@ def _write_raster_datatree(
     element_name: str,
     raster_data: DataTree,
     raster_format: RasterFormatType,
-    storage_options: JSONDict | list[JSONDict] | None = None,
+    storage_options: JSONDict | list[JSONDict] | None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None,
     **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
     """Write raster data of type DataTree to disk.
@@ -316,6 +383,8 @@ def _write_raster_datatree(
     storage_options
         Additional options for writing the raster data, like chunks and compression.
+    compressor
+        Compression settings as a length-1 dictionary with a single `{compression: compression_level}` pair.
     metadata
         Additional metadata for the raster element
     """
@@ -331,10 +400,17 @@ def _write_raster_datatree(
         transformations = _get_transformations_xarray(xdata)
         if transformations is None:
             raise ValueError(f"{element_name} does not have any transformations and can therefore not be written.")
+
     chunks = get_pyramid_levels(raster_data, "chunks")
+    # TODO: improve and test the behavior around storage options after merge;
+    # any user-provided storage_options are currently replaced with per-level
+    # chunking so that each pyramid level is written with its own chunk shape.
+    storage_options = [{"chunks": chunk} for chunk in chunks]
+    # Apply compression if specified
+    storage_options = _apply_compression(
+        storage_options, compressor, zarr_format=cast(Literal[2, 3], raster_format.zarr_format)
+    )

     parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=raster_format)
-    storage_options = [{"chunks": chunk} for chunk in chunks]
     ome_zarr_format = get_ome_zarr_format(raster_format)
     dask_delayed = write_multi_scale_ngff(
         pyramid=data,
@@ -366,6 +442,7 @@ def write_image(
     name: str,
     element_format: RasterFormatType = CurrentRasterFormat(),
     storage_options: JSONDict | list[JSONDict] | None = None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     **metadata: str | JSONDict | list[JSONDict],
 ) -> None:
     _write_raster(
@@ -375,6 +452,7 @@ def write_image(
         name=name,
         raster_format=element_format,
         storage_options=storage_options,
+        compressor=compressor,
         **metadata,
     )
@@ -386,6 +464,7 @@ def write_labels(
     element_format: RasterFormatType = CurrentRasterFormat(),
     storage_options: JSONDict | list[JSONDict] | None = None,
     label_metadata: JSONDict | None = None,
+    compressor: dict[Literal["lz4", "zstd"], int] | None = None,
     **metadata: JSONDict,
 ) -> None:
     _write_raster(
@@ -395,6 +474,7 @@ def write_labels(
         name=name,
         raster_format=element_format,
         storage_options=storage_options,
+        compressor=compressor,
         label_metadata=label_metadata,
         **metadata,
     )
diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index 11855a22..a32774e0 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -3,7 +3,7 @@
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal

 import dask.dataframe as dd
 import numpy as np
@@ -163,6 +163,75 @@ def test_roundtrip(
         sdata2.write(tmpdir2, sdata_formats=sdata_container_format)
         _are_directories_identical(tmpdir, tmpdir2, exclude_regexp="[1-9][0-9]*.*")

+    def test_compression_roundtrip(
+        self,
+        tmp_path: str,
+        full_sdata: SpatialData,
+        sdata_container_format: SpatialDataContainerFormatType,
+    ):
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+        with pytest.raises(TypeError, match="Expected a dictionary with the"):
+            full_sdata.write(tmpdir, compressor="faulty", sdata_formats=sdata_container_format)
+        with pytest.raises(ValueError, match="Expected a dictionary with a single"):
single"): + full_sdata.write(tmpdir, compressor={"zstd": 8, "other_item": 4}, sdata_formats=sdata_container_format) + with pytest.raises(ValueError, match="Compression must either"): + full_sdata.write(tmpdir, compressor={"faulty": 8}, sdata_formats=sdata_container_format) + with pytest.raises(ValueError, match="Compression must either"): + full_sdata.write(tmpdir, compressor={"The compression level": 10}, sdata_formats=sdata_container_format) + + full_sdata.write(tmpdir, compressor={"zstd": 8}, sdata_formats=sdata_container_format) + + # sourcery skip: no-loop-in-tests + for element in ["image2d", "image2d_multiscale", "labels2d", "labels2d_multiscale"]: + element_type = "images" if element.startswith("image") else "labels" + arr = zarr.open_group(tmpdir / element_type, mode="r")[element]["0"] + compressor = arr.compressors[0] + + # TODO: all these tests fail because the compression arguments are not passed to Dask + if sdata_container_format.zarr_format == 2: + assert compressor.cname == "zstd" + assert compressor.clevel == 8 + elif sdata_container_format.zarr_format == 3: + from zarr.codecs.zstd import ZstdCodec + + assert isinstance(compressor, ZstdCodec) + assert compressor.level == 8 + + @pytest.mark.parametrize("compressor", [{"lz4": 3}, {"zstd": 7}]) + @pytest.mark.parametrize("element", [("images", "image2d"), ("labels", "labels2d")]) + def test_write_element_compression( + self, + tmp_path: str, + full_sdata: SpatialData, + compressor: dict[Literal["lz4", "zstd"], int], + element: str, + sdata_container_format: SpatialDataContainerFormatType, + ): + tmpdir = Path(tmp_path) / "compression.zarr" + sdata = SpatialData() + sdata.write(tmpdir, sdata_formats=sdata_container_format) + + sdata["element"] = full_sdata[element[1]] + sdata.write_element("element", compressor=compressor, sdata_formats=sdata_container_format) + + arr = zarr.open_group(tmpdir / element[0], mode="r")["element"]["0"] + compression = arr.compressors[0] + + # TODO: all these tests fail because the compression arguments are not passed to Dask + if sdata_container_format.zarr_format == 2: + assert compression.cname == list(compressor.keys())[0] + assert compression.clevel == list(compressor.values())[0] + elif sdata_container_format.zarr_format == 3: + from zarr.codecs import ZstdCodec + + compressor_name = list(compressor.keys())[0] + if compressor_name == "zstd": + assert isinstance(compression, ZstdCodec) + # TODO: fix + # elif compressor_name == 'lz4': + # assert isinstance(compression, ???) + assert compression.level == list(compressor.values())[0] + def test_incremental_io_list_of_elements( self, shapes: SpatialData,