xref: /aosp_15_r20/external/executorch/exir/serde/schema_check.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import dataclasses
8import hashlib
9import re
10import typing
11from enum import IntEnum
12from typing import Any, Dict, Optional, Union
13
14from torch._export.serde import schema
15from torch._export.serde.union import _Union
16
17
class SchemaUpdateError(Exception):
    """Error raised by ``_check`` when a schema update condition does not hold."""
20
21
22def _check(x, msg):
23    if not x:
24        raise SchemaUpdateError(msg)
25
26
def _staged_schema():
    """Describe every public definition of the ``schema`` module as a plain dict.

    Names starting with ``_`` and objects whose ``__module__`` differs from
    ``schema.__name__`` are ignored. Each remaining class is recorded as::

        {"kind": "enum",   "fields": {member name: member value}}      # IntEnum
        {"kind": "union",  "fields": {field name: {"type", "default"}}} # _Union dataclass
        {"kind": "struct", "fields": {field name: {"type", "default"}}} # other dataclass

    ``"default"`` is present only when the dataclass field declares a default
    (or default factory) and holds ``str(default)``.

    Returns:
        The dict described above, plus ``"SCHEMA_VERSION"`` (as a list) and
        ``"TREESPEC_VERSION"`` copied from the module.

    Raises:
        AssertionError: On an unsupported annotation, an unknown module-level
            object, an ``Optional`` field whose default is not ``None``, or a
            non-positive version number.
    """
    staged: Dict[str, Any] = {}
    defs = {}

    def _handle_aggregate(ty):
        # Render the dataclass ``ty`` as {field name: {"type": ..., "default": ...}}.

        def dump_type(t):
            # Classes are written by name; string annotations must refer to a
            # definition collected in ``defs``.
            if isinstance(t, type):
                return t.__name__
            if isinstance(t, str):
                assert t in defs
                return t

            origin = typing.get_origin(t)
            if origin:
                type_args = typing.get_args(t)
                if origin == Union:
                    # The only accepted union is Optional[X], i.e. Union[X, None].
                    assert len(type_args) == 2 and type_args[1] == type(None)
                    return f"Optional[{dump_type(type_args[0])}]"
                if origin == list:
                    head = "List"
                elif origin == dict:
                    head = "Dict"
                elif origin == tuple:
                    if type_args == ():
                        return "Tuple[()]"
                    head = "Tuple"
                else:
                    raise AssertionError(f"Type {t} is not supported in export schema.")
                rendered = ", ".join([dump_type(x) for x in type_args])
                return f"{head}[{rendered}]"

            if t == ():
                return "()"
            raise AssertionError(f"Type {t} is not supported in export schema.")

        def dump_field(f):
            t = dump_type(f.type)
            entry = {"type": t}

            # Resolve the declared default, calling the factory when one is given.
            if f.default is not dataclasses.MISSING:
                value = f.default
            elif f.default_factory is not dataclasses.MISSING:
                value = f.default_factory()
            else:
                value = dataclasses.MISSING

            # An Optional field without a default still has value MISSING here,
            # so it is rejected as well.
            if t.startswith("Optional[") and value is not None:
                raise AssertionError(
                    f"Optional field {ty.__name__}.{f.name} must have default value to be None."
                )

            if value is not dataclasses.MISSING:
                entry["default"] = str(value)
            return entry

        return {f.name: dump_field(f) for f in dataclasses.fields(ty)}

    # Pass 1: collect all candidate definitions first, because dump_type
    # validates string annotations against the complete ``defs`` table.
    for name in dir(schema):
        if name.startswith("_"):
            continue

        value = getattr(schema, name)
        defined_elsewhere = (
            hasattr(value, "__module__") and value.__module__ != schema.__name__
        )
        if not defined_elsewhere:
            defs[name] = value

    # Pass 2: classify each definition. Insertion order follows ``defs``.
    for name, value in defs.items():
        if not isinstance(value, type):
            if isinstance(value, (int, tuple)):
                assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION")
                continue
            raise AssertionError(f"Unknown variable {name}: {value}")

        if issubclass(value, IntEnum):
            staged[name] = {"kind": "enum", "fields": {x.name: x.value for x in value}}
        elif dataclasses.is_dataclass(value):
            kind = "union" if issubclass(value, _Union) else "struct"
            staged[name] = {"kind": kind, "fields": _handle_aggregate(value)}
        else:
            raise AssertionError(f"Unknown schema type {name}: {value}")

    schema_version = list(defs["SCHEMA_VERSION"])
    staged["SCHEMA_VERSION"] = schema_version
    assert all(x > 0 for x in schema_version)

    treespec_version = defs["TREESPEC_VERSION"]
    staged["TREESPEC_VERSION"] = treespec_version
    assert treespec_version > 0

    return staged
125
126
127def _diff_schema(dst, src):
128    additions = {key: src[key] for key in src.keys() - dst.keys()}
129    subtractions = {key: dst[key] for key in dst.keys() - src.keys()}
130
131    common_keys = src.keys() & dst.keys()
132
133    versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"}
134    common_keys -= versions
135
136    for key in common_keys:
137        src_kind = src[key]["kind"]
138        src_fields = src[key]["fields"]
139        dst_kind = dst[key]["kind"]
140        dst_fields = dst[key]["fields"]
141        _check(
142            src_kind == dst_kind,
143            f"Type {key} changed kind from {dst_kind} to {src_kind}",
144        )
145        assert isinstance(src_fields, dict) and isinstance(dst_fields, dict)
146        added_fields = {
147            key: src_fields[key] for key in src_fields.keys() - dst_fields.keys()
148        }
149        subtracted_fields = {
150            key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys()
151        }
152        common_fields = src_fields.keys() & dst_fields.keys()
153
154        for field in common_fields:
155            src_field = src_fields[field]
156            dst_field = dst_fields[field]
157            if src_kind == "struct":
158                _check(
159                    src_field["type"] == dst_field["type"],
160                    f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
161                )
162                if "default" in src_field and "default" not in dst_field:
163                    added_fields[field] = {}
164                    added_fields[field]["default"] = src_field["default"]
165                if "default" not in src_field and "default" in dst_field:
166                    subtracted_fields[field] = {}
167                    subtracted_fields[field]["default"] = dst_field["default"]
168            elif src_kind == "enum":
169                _check(
170                    src_field == dst_field,
171                    f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}",
172                )
173            elif src_kind == "union":
174                _check(
175                    src_field["type"] == dst_field["type"],
176                    f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}",
177                )
178            else:
179                raise AssertionError(f"Unknown kind {src_kind}: {key}")
180        if len(added_fields) > 0:
181            assert key not in additions
182            additions[key] = {}
183            additions[key]["fields"] = added_fields
184        if len(subtracted_fields) > 0:
185            assert key not in subtractions
186            subtractions[key] = {}
187            subtractions[key]["fields"] = subtracted_fields
188
189    return additions, subtractions
190
191
192def _hash_schema(s):
193    return hashlib.sha256(repr(s).encode("utf-8")).hexdigest()
194
195
196@dataclasses.dataclass
197class _Commit:
198    result: Dict[str, Any]
199    checksum_result: str
200    path: str
201    additions: Dict[str, Any]
202    subtractions: Dict[str, Any]
203    base: Dict[str, Any]
204    checksum_base: Optional[str]
205
206
def update_schema():
    """Stage the current schema and diff it against the packaged ``schema.yaml``.

    When ``schema.yaml`` exists as a resource of this package, it is loaded as
    the base schema and its ``checksum<<...>>`` marker is extracted. Otherwise
    the base is a placeholder whose version entries are ``None``.

    Returns:
        A ``_Commit`` holding the staged schema, its checksum, the yaml path,
        the additions/subtractions relative to the base, the base itself and
        the base checksum (``None`` when no ``schema.yaml`` exists).

    Raises:
        SchemaUpdateError: If ``schema.yaml`` exists but has no checksum marker,
            or if the diff detects a disallowed change.
    """
    import importlib.resources

    if not importlib.resources.is_resource(__package__, "schema.yaml"):
        checksum_base = None
        dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}
    else:
        content = importlib.resources.read_text(__package__, "schema.yaml")
        match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
        _check(match is not None, "checksum not found in schema.yaml")
        # Redundant at runtime after _check; kept as in the original
        # (presumably to narrow ``match`` for type checkers).
        assert match is not None
        checksum_base = match.group(1)
        from yaml import load, Loader

        # NOTE(review): the full ``Loader`` is used; this reads a resource
        # shipped with the package rather than external input.
        dst = load(content, Loader=Loader)
        assert isinstance(dst, dict)

    src = _staged_schema()
    additions, subtractions = _diff_schema(dst, src)
    checksum_result = _hash_schema(src)
    yaml_path = __package__.replace(".", "/") + "/schema.yaml"
    return _Commit(
        result=src,
        checksum_result=checksum_result,
        path=yaml_path,
        additions=additions,
        subtractions=subtractions,
        base=dst,
        checksum_base=checksum_base,
    )
235
236
def check(commit: "_Commit", force_unsafe: bool = False):
    """Decide which schema version a staged change requires, and explain why.

    Args:
        commit: Diff between the base and the staged schema; its ``base``,
            ``result``, ``additions`` and ``subtractions`` attributes are read.
        force_unsafe: When true, skip the minor-version detection and use the
            ``SCHEMA_VERSION`` of the staged schema as the next version.

    Returns:
        ``(next_version, reason)``. ``next_version`` is ``None`` when nothing
        changed, otherwise a ``[major, minor]`` list (or the staged
        ``SCHEMA_VERSION`` when ``force_unsafe`` is set). ``reason`` is a
        newline-separated explanation of every change that drove the decision.
    """
    next_version = None
    reason = ""

    # Step 1: Detect major schema updates.
    for k, v in commit.additions.items():
        # Entirely new types are compatible; only existing types are inspected.
        if k not in commit.base:
            continue
        kind = commit.result[k]["kind"]
        for f, d in v["fields"].items():
            # A struct field without a default is an incompatible addition.
            if "default" not in d and kind == "struct":
                reason += (
                    f"Field {k}.{f} is added to schema.py without a default value as an incomparible change "
                    + "which requires major version bump.\n"
                )
                next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    for k, v in commit.subtractions.items():
        # Entirely removed types are not treated as a major change here.
        if k not in commit.result:
            continue
        for f in v["fields"]:
            # Append rather than assign, so reasons gathered for incompatible
            # additions and for earlier removed fields are not discarded.
            reason += f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n"
        next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    if force_unsafe:
        reason += "--force-unsafe is used."
        next_version = commit.result["SCHEMA_VERSION"]
    else:
        # Step 2: Detect minor schema updates.
        if next_version is None and len(commit.additions) > 0:
            for k, v in commit.additions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is added to schema.py as an compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]
        if next_version is None and len(commit.subtractions) > 0:
            for k, v in commit.subtractions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is removed from schema.py as an compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]

    return next_version, reason
292