1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import dataclasses 8import hashlib 9import re 10import typing 11from enum import IntEnum 12from typing import Any, Dict, Optional, Union 13 14from torch._export.serde import schema 15from torch._export.serde.union import _Union 16 17 18class SchemaUpdateError(Exception): 19 pass 20 21 22def _check(x, msg): 23 if not x: 24 raise SchemaUpdateError(msg) 25 26 27def _staged_schema(): 28 ret: Dict[str, Any] = {} 29 defs = {} 30 31 def _handle_aggregate(ty): 32 def dump_type(t): 33 if isinstance(t, type): 34 return t.__name__ 35 elif isinstance(t, str): 36 assert t in defs 37 return t 38 elif o := typing.get_origin(t): 39 # Lemme know if there's a better way to do this. 40 if o == list: 41 head = "List" 42 elif o == dict: 43 head = "Dict" 44 elif o == tuple: 45 if typing.get_args(t) == (): 46 return "Tuple[()]" 47 head = "Tuple" 48 elif o == Union: 49 args = typing.get_args(t) 50 assert len(args) == 2 and args[1] == type(None) 51 return f"Optional[{dump_type(args[0])}]" 52 else: 53 raise AssertionError(f"Type {t} is not supported in export schema.") 54 return ( 55 f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]" 56 ) 57 elif t == (): 58 return "()" 59 else: 60 raise AssertionError(f"Type {t} is not supported in export schema.") 61 62 def dump_field(f): 63 t = dump_type(f.type) 64 ret = {"type": t} 65 66 value = dataclasses.MISSING 67 if f.default is not dataclasses.MISSING: 68 value = f.default 69 elif f.default_factory is not dataclasses.MISSING: 70 value = f.default_factory() 71 72 if t.startswith("Optional[") and value is not None: 73 raise AssertionError( 74 f"Optional field {ty.__name__}.{f.name} must have default value to be None." 75 ) 76 77 if value is not dataclasses.MISSING: 78 default = str(value) 79 ret["default"] = default 80 return ret 81 82 return {f.name: dump_field(f) for f in dataclasses.fields(ty)} 83 84 def _handle_int_enum(name, ty): 85 ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}} 86 87 def _handle_struct(name, ty): 88 ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)} 89 90 def _handle_union(name, ty): 91 ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)} 92 93 for name in dir(schema): 94 if name.startswith("_"): 95 continue 96 97 value = getattr(schema, name) 98 99 if hasattr(value, "__module__") and value.__module__ != schema.__name__: 100 continue 101 102 defs[name] = value 103 104 for name, value in defs.items(): 105 if isinstance(value, type): 106 if issubclass(value, IntEnum): 107 _handle_int_enum(name, value) 108 elif dataclasses.is_dataclass(value): 109 if issubclass(value, _Union): 110 _handle_union(name, value) 111 else: 112 _handle_struct(name, value) 113 else: 114 raise AssertionError(f"Unknown schema type {name}: {value}") 115 elif isinstance(value, (int, tuple)): 116 assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION") 117 else: 118 raise AssertionError(f"Unknown variable {name}: {value}") 119 120 ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"]) 121 assert all(x > 0 for x in ret["SCHEMA_VERSION"]) 122 ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"] 123 assert ret["TREESPEC_VERSION"] > 0 124 return ret 125 126 127def _diff_schema(dst, src): 128 additions = {key: src[key] for key in src.keys() - dst.keys()} 129 subtractions = {key: dst[key] for key in dst.keys() - src.keys()} 130 131 common_keys = src.keys() & dst.keys() 132 133 versions = {"SCHEMA_VERSION", "TREESPEC_VERSION"} 134 common_keys -= versions 135 136 for key in common_keys: 137 src_kind = src[key]["kind"] 138 src_fields = src[key]["fields"] 139 dst_kind = dst[key]["kind"] 140 dst_fields = dst[key]["fields"] 141 _check( 142 src_kind == dst_kind, 143 f"Type {key} changed kind from {dst_kind} to {src_kind}", 144 ) 145 assert isinstance(src_fields, dict) and isinstance(dst_fields, dict) 146 added_fields = { 147 key: src_fields[key] for key in src_fields.keys() - dst_fields.keys() 148 } 149 subtracted_fields = { 150 key: dst_fields[key] for key in dst_fields.keys() - src_fields.keys() 151 } 152 common_fields = src_fields.keys() & dst_fields.keys() 153 154 for field in common_fields: 155 src_field = src_fields[field] 156 dst_field = dst_fields[field] 157 if src_kind == "struct": 158 _check( 159 src_field["type"] == dst_field["type"], 160 f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", 161 ) 162 if "default" in src_field and "default" not in dst_field: 163 added_fields[field] = {} 164 added_fields[field]["default"] = src_field["default"] 165 if "default" not in src_field and "default" in dst_field: 166 subtracted_fields[field] = {} 167 subtracted_fields[field]["default"] = dst_field["default"] 168 elif src_kind == "enum": 169 _check( 170 src_field == dst_field, 171 f"Value of the enum field {key}.{field} changed from {dst_field} to {src_field}", 172 ) 173 elif src_kind == "union": 174 _check( 175 src_field["type"] == dst_field["type"], 176 f"Type of the field {key}.{field} changed from {dst_field['type']} to {src_field['type']}", 177 ) 178 else: 179 raise AssertionError(f"Unknown kind {src_kind}: {key}") 180 if len(added_fields) > 0: 181 assert key not in additions 182 additions[key] = {} 183 additions[key]["fields"] = added_fields 184 if len(subtracted_fields) > 0: 185 assert key not in subtractions 186 subtractions[key] = {} 187 subtractions[key]["fields"] = subtracted_fields 188 189 return additions, subtractions 190 191 192def _hash_schema(s): 193 return hashlib.sha256(repr(s).encode("utf-8")).hexdigest() 194 195 196@dataclasses.dataclass 197class _Commit: 198 result: Dict[str, Any] 199 checksum_result: str 200 path: str 201 additions: Dict[str, Any] 202 subtractions: Dict[str, Any] 203 base: Dict[str, Any] 204 checksum_base: Optional[str] 205 206 207def update_schema(): 208 import importlib.resources 209 210 if importlib.resources.is_resource(__package__, "schema.yaml"): 211 content = importlib.resources.read_text(__package__, "schema.yaml") 212 match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content) 213 _check(match is not None, "checksum not found in schema.yaml") 214 assert match is not None 215 checksum_base = match.group(1) 216 from yaml import load, Loader 217 218 dst = load(content, Loader=Loader) 219 assert isinstance(dst, dict) 220 else: 221 checksum_base = None 222 dst = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None} 223 224 src = _staged_schema() 225 additions, subtractions = _diff_schema(dst, src) 226 return _Commit( 227 result=src, 228 checksum_result=_hash_schema(src), 229 path=__package__.replace(".", "/") + "/schema.yaml", 230 additions=additions, 231 subtractions=subtractions, 232 base=dst, 233 checksum_base=checksum_base, 234 ) 235 236 237def check(commit: _Commit, force_unsafe: bool = False): 238 next_version = None 239 reason = "" 240 # Step 1: Detect major schema updates. 241 if len(commit.additions) > 0: 242 for k, v in commit.additions.items(): 243 if k not in commit.base: 244 continue 245 kind = commit.result[k]["kind"] 246 fields = v["fields"] 247 for f, d in fields.items(): 248 if "default" not in d and kind == "struct": 249 reason += ( 250 f"Field {k}.{f} is added to schema.py without a default value as an incomparible change " 251 + "which requires major version bump.\n" 252 ) 253 next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] 254 255 if len(commit.subtractions) > 0: 256 for k, v in commit.subtractions.items(): 257 if k not in commit.result: 258 continue 259 for f in v["fields"]: 260 reason = f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n" 261 next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1] 262 263 if force_unsafe: 264 reason += "--force-unsafe is used." 265 next_version = commit.result["SCHEMA_VERSION"] 266 else: 267 # Step 2: Detect minor schema updates. 268 if next_version is None and len(commit.additions) > 0: 269 for k, v in commit.additions.items(): 270 for f in v["fields"]: 271 reason += ( 272 f"Field {k}.{f} is added to schema.py as an compatible change " 273 + "which still requires minor version bump.\n" 274 ) 275 next_version = [ 276 commit.base["SCHEMA_VERSION"][0], 277 commit.base["SCHEMA_VERSION"][1] + 1, 278 ] 279 if next_version is None and len(commit.subtractions) > 0: 280 for k, v in commit.subtractions.items(): 281 for f in v["fields"]: 282 reason += ( 283 f"Field {k}.{f} is removed from schema.py as an compatible change " 284 + "which still requires minor version bump.\n" 285 ) 286 next_version = [ 287 commit.base["SCHEMA_VERSION"][0], 288 commit.base["SCHEMA_VERSION"][1] + 1, 289 ] 290 291 return next_version, reason 292