Skip to content

Commit b4aa5f5

Browse files
feat!: convert parameter schema from voluptuous to msgspec
1 parent a63271c commit b4aa5f5

3 files changed

Lines changed: 112 additions & 69 deletions

File tree

src/taskgraph/parameters.py

Lines changed: 95 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,17 @@
1010
from io import BytesIO
1111
from pprint import pformat
1212
from subprocess import CalledProcessError
13+
from typing import Optional, Union
1314
from unittest.mock import Mock
1415
from urllib.parse import urlparse
1516
from urllib.request import urlopen
1617

1718
import mozilla_repo_urls
18-
from voluptuous import ALLOW_EXTRA, Any, Optional, Required, Schema
19+
import msgspec
1920

2021
from taskgraph.util import json, yaml
2122
from taskgraph.util.readonlydict import ReadOnlyDict
22-
from taskgraph.util.schema import validate_schema
23+
from taskgraph.util.schema import Schema, validate_schema
2324
from taskgraph.util.taskcluster import find_task_id, get_artifact_url
2425
from taskgraph.util.vcs import get_repository
2526

@@ -30,43 +31,50 @@ class ParameterMismatch(Exception):
3031

3132
#: Schema for base parameters.
3233
#: Please keep this list sorted and in sync with docs/reference/parameters.rst
33-
base_schema = Schema(
34+
base_schema = Schema.from_dict(
3435
{
35-
Required("base_repository"): str,
36-
Optional("base_ref"): str,
37-
Required("base_rev"): str,
38-
Required("build_date"): int,
39-
Required("build_number"): int,
40-
Required("do_not_optimize"): [str],
41-
Required("enable_always_target"): Any(bool, [str]),
42-
Required("existing_tasks"): {str: str},
43-
Required("files_changed"): [str],
44-
Required("filters"): [str],
45-
Required("head_ref"): str,
46-
Required("head_repository"): str,
47-
Required("head_rev"): str,
48-
Required("head_tag"): str,
49-
Required("level"): str,
50-
Required("moz_build_date"): str,
51-
Required("next_version"): Any(str, None),
52-
Required("optimize_strategies"): Any(str, None),
53-
Required("optimize_target_tasks"): bool,
54-
Required("owner"): str,
55-
Required("project"): str,
56-
Required("pushdate"): int,
57-
Required("pushlog_id"): str,
58-
Required("repository_type"): str,
36+
"base_repository": str,
37+
"base_ref": Optional[str],
38+
"base_rev": str,
39+
"build_date": int,
40+
"build_number": int,
41+
"do_not_optimize": list[str],
42+
"enable_always_target": Union[bool, list[str]],
43+
"existing_tasks": dict[str, str],
44+
"files_changed": list[str],
45+
"filters": list[str],
46+
"head_ref": str,
47+
"head_repository": str,
48+
"head_rev": str,
49+
"head_tag": str,
50+
"level": str,
51+
"moz_build_date": str,
52+
"next_version": Optional[str],
53+
"optimize_strategies": Optional[str],
54+
"optimize_target_tasks": bool,
55+
"owner": str,
56+
"project": str,
57+
"pushdate": int,
58+
"pushlog_id": str,
59+
"repository_type": str,
5960
# target-kinds is not included, since it should never be
6061
# used at run-time
61-
Required("target_tasks_method"): str,
62-
Required("tasks_for"): str,
63-
Required("version"): Any(str, None),
64-
Optional("code-review"): {
65-
Required("phabricator-build-target"): str,
66-
},
67-
}
62+
"target_tasks_method": str,
63+
"tasks_for": str,
64+
"version": Optional[str],
65+
"code-review": Schema.from_dict(
66+
{"phabricator-build-target": str},
67+
name="CodeReviewConfig",
68+
optional=True,
69+
),
70+
},
71+
name="BaseParametersSchema",
72+
forbid_unknown_fields=False,
73+
kw_only=True,
6874
)
6975

76+
_parameter_extensions: list = []
77+
7078

7179
def get_contents(path):
7280
with open(path) as fh:
@@ -83,11 +91,21 @@ def _get_defaults(repo_root=None):
8391
repo_path = repo_root or os.getcwd()
8492
try:
8593
repo = get_repository(repo_path)
86-
except RuntimeError:
87-
# Use fake values if no repo is detected.
88-
repo = Mock(branch="", head_rev="", tool="git")
94+
# Resolve git-backed attributes eagerly so any subprocess failures
95+
# (e.g. Windows "dubious ownership" when safe.directory isn't honored)
96+
# are caught by the except below instead of escaping later.
97+
branch = repo.branch
98+
head_rev = repo.head_rev
99+
tool = repo.tool
100+
files_changed = repo.get_changed_files("AM")
101+
except (RuntimeError, CalledProcessError):
102+
# Use fake values if no repo is detected or git refuses to operate.
103+
repo = Mock()
89104
repo.get_url.return_value = ""
90-
repo.get_changed_files.return_value = []
105+
branch = ""
106+
head_rev = ""
107+
tool = "git"
108+
files_changed = []
91109

92110
try:
93111
repo_url = repo.get_url()
@@ -110,11 +128,11 @@ def _get_defaults(repo_root=None):
110128
"do_not_optimize": [],
111129
"enable_always_target": True,
112130
"existing_tasks": {},
113-
"files_changed": lambda: repo.get_changed_files("AM"),
131+
"files_changed": files_changed,
114132
"filters": ["target_tasks_method"],
115-
"head_ref": repo.branch or repo.head_rev,
133+
"head_ref": branch or head_rev,
116134
"head_repository": repo_url,
117-
"head_rev": repo.head_rev,
135+
"head_rev": head_rev,
118136
"head_tag": "",
119137
"level": "3",
120138
"moz_build_date": datetime.now().strftime("%Y%m%d%H%M%S"),
@@ -125,7 +143,7 @@ def _get_defaults(repo_root=None):
125143
"project": project,
126144
"pushdate": int(time.time()),
127145
"pushlog_id": "0",
128-
"repository_type": repo.tool,
146+
"repository_type": tool,
129147
"target_tasks_method": "default",
130148
"tasks_for": "",
131149
"version": get_version(repo_path),
@@ -143,19 +161,27 @@ def extend_parameters_schema(schema, defaults_fn=None):
143161
graph-configuration.
144162
145163
Args:
146-
schema (Schema): The voluptuous.Schema object used to describe extended
147-
parameters.
164+
schema: A msgspec ``Schema`` subclass describing extended parameters.
148165
defaults_fn (function): A function which takes no arguments and returns a
149166
dict mapping parameter name to default value in the
150167
event strict=False (optional).
151168
"""
152-
global base_schema
153169
global defaults_functions
154-
base_schema = base_schema.extend(schema)
170+
if not (isinstance(schema, type) and issubclass(schema, msgspec.Struct)):
171+
raise TypeError(
172+
"extend_parameters_schema requires a msgspec Schema subclass; "
173+
f"got {type(schema).__name__}"
174+
)
175+
_parameter_extensions.append(schema)
155176
if defaults_fn:
156177
defaults_functions.append(defaults_fn)
157178

158179

180+
def _schema_key_names(schema) -> set:
181+
"""Return the data-level field names declared by a parameters schema."""
182+
return {f.encode_name for f in msgspec.structs.fields(schema)}
183+
184+
159185
class Parameters(ReadOnlyDict):
160186
"""An immutable dictionary with nicer KeyError messages on failure"""
161187

@@ -214,11 +240,30 @@ def _fill_defaults(repo_root=None, **kwargs):
214240
return kwargs
215241

216242
def check(self):
217-
schema = (
218-
base_schema if self.strict else base_schema.extend({}, extra=ALLOW_EXTRA)
219-
)
243+
data = dict(self.copy())
220244
try:
221-
validate_schema(schema, self.copy(), "Invalid parameters:")
245+
# Validate core fields against just the subset of data owned by the
246+
# base schema. Extension keys are validated separately below, and a
247+
# strict-mode check rejects anything unknown to either.
248+
base_keys = _schema_key_names(base_schema)
249+
base_data = {k: v for k, v in data.items() if k in base_keys}
250+
validate_schema(base_schema, base_data, "Invalid parameters:")
251+
252+
# Validate each registered extension against the keys it declares.
253+
allowed = set(base_keys)
254+
for ext in _parameter_extensions:
255+
ext_keys = _schema_key_names(ext)
256+
allowed |= ext_keys
257+
ext_data = {k: data[k] for k in ext_keys if k in data}
258+
validate_schema(ext, ext_data, "Invalid parameters:")
259+
260+
# Strict mode: reject any data key not covered by base or extensions.
261+
if self.strict:
262+
unknown = sorted(set(data) - allowed)
263+
if unknown:
264+
raise Exception(
265+
"Invalid parameters:\nunknown keys: " + ", ".join(unknown)
266+
)
222267
except Exception as e:
223268
raise ParameterMismatch(str(e))
224269

taskcluster/self_taskgraph/custom_parameters.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

55
import os
6+
from typing import Annotated, Optional
67

7-
from voluptuous import All, Any, Range, Required
8+
import msgspec
89

910
from taskgraph.parameters import extend_parameters_schema
11+
from taskgraph.util.schema import Schema
1012

1113

1214
def get_defaults(repo_root):
@@ -15,14 +17,15 @@ def get_defaults(repo_root):
1517
}
1618

1719

18-
extend_parameters_schema(
19-
{
20-
Required("pull_request_number"): Any(All(int, Range(min=1)), None),
21-
},
22-
defaults_fn=get_defaults,
20+
CustomParametersSchema = Schema.from_dict(
21+
{"pull_request_number": Optional[Annotated[int, msgspec.Meta(ge=1)]]},
22+
name="CustomParametersSchema",
2323
)
2424

2525

26+
extend_parameters_schema(CustomParametersSchema, defaults_fn=get_defaults)
27+
28+
2629
def decision_parameters(graph_config, parameters):
2730
if parameters["tasks_for"] == "github-release":
2831
parameters["target_tasks_method"] = "release"

test/test_parameters.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import gzip
88
import os
99
from base64 import b64decode
10+
from typing import Optional
1011
from unittest import TestCase, mock
1112

1213
import mozilla_repo_urls
1314
import pytest
14-
from voluptuous import Optional, Required, Schema
1515

1616
import taskgraph # noqa: F401
1717
from taskgraph import parameters
@@ -21,6 +21,7 @@
2121
extend_parameters_schema,
2222
load_parameters_file,
2323
)
24+
from taskgraph.util.schema import Schema
2425

2526
from .mockedopen import MockedOpen
2627

@@ -274,20 +275,16 @@ def test_parameters_format_spec(spec, expected):
274275

275276

276277
def test_extend_parameters_schema(monkeypatch):
277-
monkeypatch.setattr(
278-
parameters,
279-
"base_schema",
280-
Schema(
281-
{
282-
Required("foo"): str,
283-
}
284-
),
285-
)
278+
FooSchema = Schema.from_dict({"foo": str}, name="FooSchema")
279+
BarSchema = Schema.from_dict({"bar": Optional[bool]}, name="BarSchema")
280+
281+
monkeypatch.setattr(parameters, "base_schema", FooSchema)
286282
monkeypatch.setattr(
287283
parameters,
288284
"defaults_functions",
289285
list(parameters.defaults_functions),
290286
)
287+
monkeypatch.setattr(parameters, "_parameter_extensions", [])
291288

292289
with pytest.raises(ParameterMismatch):
293290
Parameters(strict=False).check()
@@ -296,9 +293,7 @@ def test_extend_parameters_schema(monkeypatch):
296293
Parameters(foo="1", bar=True).check()
297294

298295
extend_parameters_schema(
299-
{
300-
Optional("bar"): bool,
301-
},
296+
BarSchema,
302297
defaults_fn=lambda root: {"foo": "1", "bar": False},
303298
)
304299

0 commit comments

Comments
 (0)