Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dvc/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def __call__(self, data):
"feature": FeatureSchema(
{
Optional("machine", default=False): Bool,
Optional("fast_yaml", default=False): Bool,
},
),
"plots": {
Expand Down
2 changes: 2 additions & 0 deletions dvc/dependency/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,14 @@ def read_params(
return {}

assert self.repo
use_fast_yaml = self.repo.config.get("feature", {}).get("fast_yaml", False)
try:
return read_param_file(
self.repo.fs,
self.fs_path,
list(self.params) if self.params else None,
flatten=flatten,
use_fast_yaml=use_fast_yaml,
)
except ParseError as exc:
raise BadParamFileError(f"Unable to read parameters from '{self}'") from exc
Expand Down
47 changes: 45 additions & 2 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from dvc.utils import relpath
from dvc.utils.collections import apply_diff
from dvc.utils.objects import cached_property
from dvc.utils.serialize import dump_yaml, modify_yaml
from dvc.utils.serialize import (
FastYAMLParseError,
dump_yaml,
dump_yaml_fast,
load_yaml_fast,
modify_yaml,
)

if TYPE_CHECKING:
from dvc.repo import Repo
Expand Down Expand Up @@ -85,6 +91,10 @@ def __init__(self, repo, path, verify=True, **kwargs):
self.path = path
self.verify = verify

@cached_property
def _use_fast_yaml(self) -> bool:
return self.repo.config.get("feature", {}).get("fast_yaml", False)

def __repr__(self):
return f"{self.__class__.__name__}: {relpath(self.path, self.repo.root_dir)}"

Expand Down Expand Up @@ -152,6 +162,7 @@ def _load_yaml(self, **kwargs: Any) -> tuple[Any, str]:
self.path,
self.SCHEMA, # type: ignore[arg-type]
self.repo.fs,
use_fast_yaml=self._use_fast_yaml,
**kwargs,
)

Expand Down Expand Up @@ -428,12 +439,44 @@ def dump_stages(self, stages, **kwargs):
if not stages:
return

if self._use_fast_yaml:
self._dump_stages_fast(stages, **kwargs)
else:
self._dump_stages_slow(stages, **kwargs)

def _dump_stages_fast(self, stages, **kwargs):
try:
data = load_yaml_fast(self.path, self.repo.fs)
except (FileNotFoundError, FastYAMLParseError):
data = {}

if not data:
data = {"schema": "2.0"}
logger.info("Generating lock file '%s'", self.relpath)

data.setdefault("stages", {})
is_modified = False
log_updated = False

for stage in stages:
stage_data = serialize.to_lockfile(stage, **kwargs)
if data["stages"].get(stage.name, {}) != stage_data.get(stage.name, {}):
is_modified = True
if not log_updated:
logger.info("Updating lock file '%s'", self.relpath)
log_updated = True
data["stages"].update(stage_data)

if is_modified:
dump_yaml_fast(self.path, data, fs=self.repo.fs)
self.repo.scm_context.track_file(self.relpath)

def _dump_stages_slow(self, stages, **kwargs):
is_modified = False
log_updated = False
with modify_yaml(self.path, fs=self.repo.fs) as data:
if not data:
data.update({"schema": "2.0"})
# order is important, meta should always be at the top
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: restore comments

logger.info("Generating lock file '%s'", self.relpath)

data["stages"] = data.get("stages", {})
Expand Down
19 changes: 13 additions & 6 deletions dvc/parsing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,11 @@ def select(self, key: str, unwrap: bool = False):

@classmethod
def load_from(
cls, fs, path: str, select_keys: Optional[list[str]] = None
cls,
fs,
path: str,
select_keys: Optional[list[str]] = None,
use_fast_yaml: bool = False,
) -> "Context":
from dvc.utils.serialize import load_path

Expand All @@ -358,7 +362,7 @@ def load_from(
if fs.isdir(path):
raise ParamsLoadError(f"'{path}' is a directory")

data = load_path(path, fs)
data = load_path(path, fs, use_fast_yaml=use_fast_yaml)
if not isinstance(data, Mapping):
typ = type(data).__name__
raise ParamsLoadError(
Expand All @@ -383,7 +387,9 @@ def merge_update(self, other: "Context", overwrite=False):
raise ReservedKeyError(matches)
return super().merge_update(other, overwrite=overwrite)

def merge_from(self, fs, item: str, wdir: str, overwrite=False):
def merge_from(
self, fs, item: str, wdir: str, overwrite=False, use_fast_yaml: bool = False
):
path, _, keys_str = item.partition(":")
path = fs.normpath(fs.join(wdir, path))

Expand All @@ -393,7 +399,7 @@ def merge_from(self, fs, item: str, wdir: str, overwrite=False):
return # allow specifying complete filepath multiple times
self.check_loaded(path, item, select_keys)

ctx = Context.load_from(fs, path, select_keys)
ctx = Context.load_from(fs, path, select_keys, use_fast_yaml=use_fast_yaml)

try:
self.merge_update(ctx, overwrite=overwrite)
Expand Down Expand Up @@ -428,11 +434,12 @@ def load_from_vars(
wdir: str,
stage_name: Optional[str] = None,
default: Optional[str] = None,
use_fast_yaml: bool = False,
):
if default:
to_import = fs.join(wdir, default)
if fs.exists(to_import):
self.merge_from(fs, default, wdir)
self.merge_from(fs, default, wdir, use_fast_yaml=use_fast_yaml)
else:
msg = "%s does not exist, it won't be used in parametrization"
logger.trace(msg, to_import)
Expand All @@ -441,7 +448,7 @@ def load_from_vars(
for index, item in enumerate(vars_):
assert isinstance(item, (str, dict))
if isinstance(item, str):
self.merge_from(fs, item, wdir)
self.merge_from(fs, item, wdir, use_fast_yaml=use_fast_yaml)
else:
joiner = "." if stage_name else ""
meta = Meta(source=f"{stage_name}{joiner}vars[{index}]")
Expand Down
21 changes: 16 additions & 5 deletions dvc/repo/metrics/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,26 @@ def _extract_metrics(metrics, path: str):
return ret


def _read_metric(fs: "FileSystem", path: str, **load_kwargs) -> Any:
val = load_path(path, fs, **load_kwargs)
def _read_metric(
fs: "FileSystem", path: str, use_fast_yaml: bool = False, **load_kwargs
) -> Any:
val = load_path(path, fs, use_fast_yaml=use_fast_yaml, **load_kwargs)
val = _extract_metrics(val, path)
return val or {}


def _read_metrics(
fs: "FileSystem", metrics: Iterable[str], **load_kwargs
fs: "FileSystem",
metrics: Iterable[str],
use_fast_yaml: bool = False,
**load_kwargs,
) -> Iterator[tuple[str, Union[Exception, Any]]]:
for metric in metrics:
try:
yield metric, _read_metric(fs, metric, **load_kwargs)
yield (
metric,
_read_metric(fs, metric, use_fast_yaml=use_fast_yaml, **load_kwargs),
)
except Exception as exc: # noqa: BLE001
logger.debug(exc)
yield metric, exc
Expand Down Expand Up @@ -158,9 +166,12 @@ def _gather_metrics(
# the result and convert to appropriate repo-relative os.path.
files = _collect_metrics(repo, targets=targets, stages=stages, outs_only=outs_only)
data = {}
use_fast_yaml = repo.config.get("feature", {}).get("fast_yaml", False)

fs = repo.dvcfs
for fs_path, result in _read_metrics(fs, files, cache=True):
for fs_path, result in _read_metrics(
fs, files, use_fast_yaml=use_fast_yaml, cache=True
):
repo_path = fs_path.lstrip(fs.root_marker)
repo_os_path = os.sep.join(fs.parts(repo_path))
if not isinstance(result, Exception):
Expand Down
16 changes: 13 additions & 3 deletions dvc/stage/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,25 @@ def _load_cache(self, key, value):
from voluptuous import Invalid

from dvc.schema import COMPILED_LOCK_FILE_STAGE_SCHEMA
from dvc.utils.serialize import YAMLFileCorruptedError, load_yaml
from dvc.utils.serialize import (
FastYAMLParseError,
YAMLFileCorruptedError,
load_yaml,
load_yaml_fast,
)

path = self._get_cache_path(key, value)
use_fast_yaml = self.repo.config.get("feature", {}).get("fast_yaml", False)

try:
return COMPILED_LOCK_FILE_STAGE_SCHEMA(load_yaml(path))
if use_fast_yaml:
data = load_yaml_fast(path)
else:
data = load_yaml(path)
return COMPILED_LOCK_FILE_STAGE_SCHEMA(data)
except FileNotFoundError:
return None
except (YAMLFileCorruptedError, Invalid):
except (YAMLFileCorruptedError, FastYAMLParseError, Invalid):
logger.warning("corrupted cache file '%s'.", relpath(path))
os.unlink(path)
return None
Expand Down
7 changes: 6 additions & 1 deletion dvc/utils/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
)


def load_path(fs_path, fs, **kwargs):
def load_path(fs_path, fs, use_fast_yaml=False, **kwargs):
suffix = fs.suffix(fs_path).lower()
if use_fast_yaml and suffix in (".yaml", ".yml"):
try:
return load_yaml_fast(fs_path, fs=fs) # noqa: F405
except FastYAMLParseError: # noqa: F405
return load_yaml(fs_path, fs=fs, **kwargs) # noqa: F405
loader = LOADERS[suffix]
return loader(fs_path, fs=fs, **kwargs)

Expand Down
104 changes: 104 additions & 0 deletions dvc/utils/serialize/_yaml.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import io
from collections import OrderedDict
from contextlib import contextmanager
Expand All @@ -17,6 +18,13 @@ def __init__(self, path):
super().__init__(path, "YAML file structure is corrupted")


class FastYAMLParseError(YAMLError):
def __init__(self, path, text, exc):
self.text = text
self.original_exc = exc
super().__init__(path, str(exc))


def load_yaml(path, fs=None, **kwargs):
return _load_data(path, parser=parse_yaml, fs=fs, **kwargs)

Expand Down Expand Up @@ -79,3 +87,99 @@ def dumps_yaml(d):
def modify_yaml(path, fs=None):
with _modify_data(path, parse_yaml_for_update, _dump, fs=fs) as d:
yield d


@functools.cache
def _get_safe_loader():
import re

import yaml

base_loader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)

class YAML12SafeLoader(base_loader): # type: ignore[valid-type,misc]
pass

YAML12SafeLoader.yaml_implicit_resolvers = {
k: [
(tag, regexp)
for tag, regexp in v
if tag not in ("tag:yaml.org,2002:bool", "tag:yaml.org,2002:int")
]
for k, v in base_loader.yaml_implicit_resolvers.copy().items()
}

bool_pattern = re.compile(r"^(?:true|True|TRUE|false|False|FALSE)$")
YAML12SafeLoader.add_implicit_resolver(
"tag:yaml.org,2002:bool",
bool_pattern,
list("tTfF"),
)

int_pattern = re.compile(r"^[-+]?(?:[0-9]+|0o[0-7]+|0x[0-9a-fA-F]+)$")
YAML12SafeLoader.add_implicit_resolver(
"tag:yaml.org,2002:int",
int_pattern,
list("-+0123456789"),
)

def construct_yaml_int_yaml12(loader, node):
value = loader.construct_scalar(node)
value = value.replace("_", "")
sign = 1
if value.startswith("-"):
sign = -1
value = value[1:]
elif value.startswith("+"):
value = value[1:]
if value.startswith(("0x", "0o")):
return sign * int(value, 0)
return sign * int(value.lstrip("0") or "0")

YAML12SafeLoader.add_constructor("tag:yaml.org,2002:int", construct_yaml_int_yaml12)

return YAML12SafeLoader


def _get_safe_dumper():
import yaml

try:
return yaml.CSafeDumper
except AttributeError:
return yaml.SafeDumper


def parse_yaml_fast(text, path):
import yaml

try:
return yaml.load(text, Loader=_get_safe_loader()) or {} # noqa: S506
except yaml.YAMLError as exc:
raise FastYAMLParseError(path, text, exc) from exc


def load_yaml_fast(path, fs=None):
from ._common import EncodingError

open_fn = fs.open if fs else open
try:
with open_fn(path, encoding="utf-8") as fd:
text = fd.read()
except UnicodeDecodeError as exc:
raise EncodingError(path, "utf-8") from exc
return parse_yaml_fast(text, path)


def dump_yaml_fast(path, data, fs=None):
import yaml

open_fn = fs.open if fs else open
with open_fn(path, "w", encoding="utf-8") as fd:
yaml.dump(
data,
fd,
Dumper=_get_safe_dumper(),
sort_keys=False,
default_flow_style=False,
)
Loading