Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,24 @@ def template(self) -> "DictStrAny":
return self._template

@cached_property
def resolved_iterable(self) -> dict[str, list]:
def resolved_iterable(self) -> dict[str, Union[list, Mapping]]:
return self._resolve_matrix_data()

def _resolve_matrix_data(self) -> dict[str, list]:
def _resolve_matrix_data(self) -> dict[str, Union[list, Mapping]]:
try:
iterable = self.context.resolve(self.matrix_data, unwrap=False)
except (ContextError, ParseError) as exc:
format_and_raise(exc, f"'{self.where}.{self.name}.matrix'", self.relpath)

for key, value in iterable.items():
if not is_map_or_seq(value):
node = value.value if isinstance(value, Node) else value
typ = type(node).__name__
raise ResolveError(
f"failed to resolve '{self.where}.{self.name}.matrix.{key}'"
f" in '{self.relpath}': expected list/dictionary, got {typ}"
)

# Matrix entries will have `key` and `item` added to the context.
# Warn users if these are already in the context from the global vars.
self._warn_if_overwriting([self.pair.key, self.pair.value])
Expand All @@ -564,13 +573,27 @@ def normalized_iterable(self) -> dict[str, "DictStrAny"]:
assert isinstance(iterable, Mapping)

ret: dict[str, DictStrAny] = {}
matrix = {key: enumerate(v) for key, v in iterable.items()}
matrix = {}
for key, value in iterable.items():
if isinstance(value, Mapping):
# For mappings, use (key, value) pairs.
# Key is used for naming, value is used for context.
matrix[key] = [(to_str(k), v) for k, v in value.items()]
else:
# For sequences, use (name, value) pairs.
# Name is index-based for complex items, or value-based for simple items
items = []
for i, v in enumerate(value):
name = f"{key}{i}" if is_map_or_seq(v) else to_str(v)
items.append((name, v))
matrix[key] = items

for combination in product(*matrix.values()):
d: DictStrAny = {}
fragments: list[str] = []
for k, (i, v) in zip(matrix.keys(), combination):
for k, (name, v) in zip(matrix.keys(), combination):
d[k] = v
fragments.append(f"{k}{i}" if is_map_or_seq(v) else to_str(v))
fragments.append(name)

key = "-".join(fragments)
ret[key] = d
Expand Down
2 changes: 1 addition & 1 deletion dvc/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
VARS_SCHEMA = [str, dict]

STAGE_DEFINITION = {
MATRIX_KWD: {str: vol.Any(str, list)},
MATRIX_KWD: {str: vol.Any(str, list, dict)},
vol.Required(StageParams.PARAM_CMD): vol.Any(str, list),
vol.Optional(StageParams.PARAM_WDIR): str,
vol.Optional(StageParams.PARAM_DEPS): [str],
Expand Down
90 changes: 89 additions & 1 deletion tests/func/parsing/test_matrix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from dvc.parsing import DataResolver, MatrixDefinition
from dvc.parsing import DataResolver, MatrixDefinition, ResolveError
from dvc.schema import COMPILED_MULTI_STAGE_SCHEMA

MATRIX_DATA = {
"os": ["win", "linux"],
Expand Down Expand Up @@ -91,3 +92,90 @@ def test_matrix_key_present(tmp_dir, dvc, matrix):
"build@linux-3.8-dict1-list0": {"cmd": "echo linux-3.8-dict1-list0"},
"build@linux-3.8-dict1-list1": {"cmd": "echo linux-3.8-dict1-list1"},
}


def test_matrix_schema_allows_mapping():
    """Feed the multi-stage schema a stage whose matrix axis is a mapping.

    There are no assertions: the test passes as long as calling the compiled
    schema on the document does not raise.
    """
    # The "models" axis is a dict of named entries instead of a list.
    models_axis = {"goo": {"val": 1}, "baz": {"val": 2}}
    build_stage = {
        "matrix": {"models": models_axis},
        "cmd": "echo ${item.models.val}",
    }
    COMPILED_MULTI_STAGE_SCHEMA({"stages": {"build": build_stage}})


# Mapping-valued matrix axis shared by the mapping tests: used inline as the
# "models" axis and also dumped to params.yaml under "map_param".
MAPPING_MATRIX_DATA = {
    "goo": {"val": 1},
    "baz": {"val": 2},
}


@pytest.mark.parametrize(
    "matrix",
    [
        {"models": MAPPING_MATRIX_DATA},
        {"models": "${map_param}"},
    ],
)
def test_matrix_with_mapping(tmp_dir, dvc, matrix):
    """Resolve a matrix whose single axis is a mapping.

    The axis is given either inline or as an interpolation of ``map_param``
    from params.yaml. Each mapping key is expected to name the generated
    stage (``build@<key>``) and the matching value to be reachable through
    ``${item.models}`` in the command.
    """
    # params.yaml is written before the resolver is created.
    params_file = tmp_dir / "params.yaml"
    params_file.dump({"map_param": MAPPING_MATRIX_DATA})

    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_data = {"matrix": matrix, "cmd": "echo ${item.models.val}"}
    definition = MatrixDefinition(resolver, resolver.context, "build", stage_data)

    expected = {
        "build@goo": {"cmd": "echo 1"},
        "build@baz": {"cmd": "echo 2"},
    }
    assert definition.resolve_all() == expected


@pytest.mark.parametrize(
    "matrix",
    [
        {"models": MAPPING_MATRIX_DATA, "ver": [1, 2]},
        {"models": "${map_param}", "ver": "${ver}"},
    ],
)
def test_matrix_mixed_mapping_and_list(tmp_dir, dvc, matrix):
    """Resolve a matrix that combines a mapping axis with a list axis.

    Both axes are given either inline or as interpolations from params.yaml.
    The expected stage names join the mapping key and the list value with
    ``-`` (for example ``build@goo-1``), one stage per combination.
    """
    # params.yaml is written before the resolver is created.
    params = {"map_param": MAPPING_MATRIX_DATA, "ver": [1, 2]}
    (tmp_dir / "params.yaml").dump(params)

    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_data = {"matrix": matrix, "cmd": "echo ${item.models.val} ${item.ver}"}
    definition = MatrixDefinition(resolver, resolver.context, "build", stage_data)

    expected = {
        "build@goo-1": {"cmd": "echo 1 1"},
        "build@goo-2": {"cmd": "echo 1 2"},
        "build@baz-1": {"cmd": "echo 2 1"},
        "build@baz-2": {"cmd": "echo 2 2"},
    }
    assert definition.resolve_all() == expected


@pytest.mark.parametrize(
    "matrix",
    [
        {"models": MAPPING_MATRIX_DATA},
        {"models": "${map_param}"},
    ],
)
def test_matrix_mapping_key_present(tmp_dir, dvc, matrix):
    """Check that ``${key}`` is available when the matrix axis is a mapping.

    For a mapping axis, ``${key}`` in the command is expected to resolve to
    the mapping key that also names the generated stage.
    """
    # params.yaml is written before the resolver is created.
    params_file = tmp_dir / "params.yaml"
    params_file.dump({"map_param": MAPPING_MATRIX_DATA})

    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_data = {"matrix": matrix, "cmd": "echo ${key}"}
    definition = MatrixDefinition(resolver, resolver.context, "build", stage_data)

    expected = {
        "build@goo": {"cmd": "echo goo"},
        "build@baz": {"cmd": "echo baz"},
    }
    assert definition.resolve_all() == expected


@pytest.mark.parametrize("matrix_value", ["${foo}", "${dct.model1}", "foobar"])
def test_matrix_expects_list_or_dict(tmp_dir, dvc, matrix_value):
    """Check that a matrix axis resolving to a plain string is rejected.

    Each parametrized value is a string, either literal or interpolated from
    params.yaml. Resolution must raise ``ResolveError`` with a message that
    names the offending axis and the received type.
    """
    # params.yaml is written before the resolver is created.
    params = {"foo": "bar", "dct": {"model1": "a-out"}}
    (tmp_dir / "params.yaml").dump(params)

    resolver = DataResolver(dvc, tmp_dir.fs_path, {})
    stage_data = {"matrix": {"dim": matrix_value}, "cmd": "echo ${item.dim}"}
    definition = MatrixDefinition(resolver, resolver.context, "build", stage_data)

    with pytest.raises(ResolveError) as exc_info:
        definition.resolve_all()

    # Checked after the context manager exits, so the asserts actually run.
    message = str(exc_info.value)
    assert "expected list/dictionary, got str" in message
    assert "stages.build.matrix.dim" in message