from __future__ import annotations

import contextlib
import functools
import hashlib
import os
import re
import sys
import textwrap
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from pathlib import Path
from typing import (
    Any,
    Callable,
    Generic,
    Iterable,
    Iterator,
    Literal,
    NoReturn,
    Sequence,
    TYPE_CHECKING,
    TypeVar,
)
from typing_extensions import Self

from torchgen.code_template import CodeTemplate


if TYPE_CHECKING:
    # Only needed for annotations; avoids importing argparse at runtime.
    from argparse import Namespace


# Repository root, assuming this file lives two directory levels below it
# (e.g. <repo>/torchgen/utils.py) -- NOTE(review): re-derive if the file moves.
REPO_ROOT = Path(__file__).absolute().parent.parent


# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
# what targets are valid for your use.
class Target(Enum):
    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()


# Matches "foo" in "foo, bar" but not "foobar".
# Used to search for the occurrence of a parameter in the derivative formula.
IDENT_REGEX = r"(^|\W){}($|\W)"


# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
    """Split a schema string like ``name.overload(a, b)`` into the name and a
    list of raw parameter strings (the overload suffix is discarded)."""
    matched = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if matched is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    name, _overload, params = matched.groups()
    return name, params.split(", ")


T = TypeVar("T")
S = TypeVar("S")

# The two helpers below intentionally return iterators, in analogy to map(),
# so callers don't forget to list() them when a concrete sequence is needed.


def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
    """Map ``func`` over ``xs``, omitting any ``None`` results."""
    yield from (result for result in map(func, xs) if result is not None)


def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    """Map ``func`` (which returns a sequence) over ``xs`` and flatten the
    resulting sequences into one stream."""
    yield from (elem for x in xs for elem in func(x))


# Conveniently add error context to exceptions raised: see the context()
# manager defined next.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    """Prepend context to any exception escaping the ``with`` block.

    ``msg_fn`` is only invoked when an exception actually occurs, so building
    the message may be arbitrarily expensive.  The message is indented and
    appended to the exception's first argument before re-raising.
    """
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError
        msg = msg_fn()
        msg = textwrap.indent(msg, "  ")
        msg = f"{e.args[0]}\n{msg}" if e.args else msg
        e.args = (msg,) + e.args[1:]
        raise


# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    """Fail loudly; only reachable when a supposedly-exhaustive dispatch
    missed a case."""
    raise AssertionError(f"Unhandled type: {type(x).__name__}")


@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    # Templates are substituted many times (once per shard); cache the parse.
    return CodeTemplate.from_file(template_fn)


def string_stable_hash(s: str) -> int:
    """String hash that's stable across different executions, unlike the
    builtin hash() (which is salted per process)."""
    sha1 = hashlib.sha1(s.encode("latin1")).digest()
    return int.from_bytes(sha1, byteorder="little")


class FileManager:
    """A small abstraction for writing out generated files and keeping track
    of what files have been written (so you can write out a list of output
    files afterwards; see ``write_outputs``).

    When ``dry_run`` is set, no files are written, but the output filenames
    are still recorded.
    """

    install_dir: str  # directory all outputs are written under
    template_dir: str  # directory template files are read from
    dry_run: bool
    filenames: set[str]  # every output path recorded so far

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        """Write ``contents`` to ``filename`` only if this changes the file,
        so unchanged outputs don't get their mtime bumped (avoiding spurious
        rebuilds)."""
        old_contents: str | None
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
    ) -> str:
        """Read ``template_fn`` (relative to ``template_dir``) and substitute
        it with the environment produced by ``env_callable``.

        A str environment is returned verbatim (no template is read).  A dict
        environment is substituted into the template; if it has no
        "generated_comment" entry, one is synthesized naming the generator
        script and the template file.
        """
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            if "generated_comment" not in env:
                generator_default = REPO_ROOT / "torchgen" / "gen.py"
                # Best effort: name the actual running script as the generator,
                # falling back to torchgen/gen.py.
                try:
                    generator = Path(
                        sys.modules["__main__"].__file__ or generator_default
                    ).absolute()
                except (KeyError, AttributeError):
                    generator = generator_default.absolute()

                # Prefer a repo-relative path for stability; outside the repo
                # just use the bare script name.
                try:
                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
                except ValueError:
                    generator_path = generator.name

                env = {
                    **env,  # copy the original dict instead of mutating it
                    "generated_comment": (
                        # "@" is concatenated separately so that this source
                        # file itself does not contain a literal marker
                        "@" + f"generated by {generator_path} from {template_fn}"
                    ),
                }
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Render ``template_fn`` with ``env_callable`` and write the result
        to ``install_dir/filename``, recording the output path.

        Raises AssertionError if the same output path is written twice.
        """
        # Bug fix: the filename argument was being ignored in favor of a
        # hard-coded path, so every call targeted the same file (and the
        # second call tripped the duplicate-write assert below).
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        """Write ``filename`` using the template of the same name."""
        self.write_with_template(filename, filename, env_callable)

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
    ) -> None:
        """Split ``items`` across ``num_shards`` output files (``Foo_0.cpp``,
        ``Foo_1.cpp``, ...) plus a ``FooEverything.cpp`` containing all items.

        Each item is assigned to a shard by a stable hash of ``key_fn(item)``,
        and the per-item environments from ``env_callable`` are merged into
        the shard's environment under ``sharded_keys``.
        """
        everything: dict[str, Any] = {"shard_id": "Everything"}
        shards: list[dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        # Ensure every sharded key starts out as an independent list in every
        # shard (copying any list supplied via base_env).
        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            # Stable hash keeps shard assignment reproducible across runs.
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            # The lambda is invoked synchronously inside write_with_template,
            # so capturing the loop variable here is safe.
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)

    def template_dir_for_comments(self) -> str:
        """
        This needs to be deterministic. The template dir is an absolute path
        that varies across builds. So, just use the path relative to this file,
        which will point to the codegen source but will be stable.
        """
        return os.path.relpath(self.template_dir, os.path.dirname(__file__))
# Helper function to generate a FileManager from parsed command-line options.
def make_file_manager(
    options: Namespace, install_dir: str | None = None
) -> FileManager:
    """Create a FileManager whose templates live under
    ``options.source_path/templates``; ``install_dir`` overrides
    ``options.install_dir`` when provided."""
    template_dir = os.path.join(options.source_path, "templates")
    return FileManager(
        install_dir=install_dir or options.install_dir,
        template_dir=template_dir,
        dry_run=options.dry_run,
    )


# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    """Pretty-print a dataclass instance."""
    if sys.version_info < (3, 10):
        # Older pprint doesn't understand dataclasses; use our fallback.
        return _pformat(obj, indent=indent, width=width)

    # The built-in pprint module supports dataclasses from Python 3.10.
    from pprint import pformat

    return pformat(obj, indent, width)


def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    """Recursively render a dataclass, indenting nested fields to line up
    under their field names (mimicking pprint's layout)."""
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"

    cls_name = obj.__class__.__name__
    # Account for "ClassName(" when indenting nested lines.
    curr_indent += len(cls_name) + 1

    parts: list[str] = []
    for fld in fields(obj):
        if not fld.repr:
            continue
        value = getattr(obj, fld.name)
        # Nested values indent past "fieldname=", as pprint does for
        # dict/list/set/tuple containers.
        inner_indent = curr_indent + len(fld.name) + 1
        if is_dataclass(value):
            rendered = _pformat(value, indent, width, inner_indent)
        elif isinstance(value, dict):
            rendered = _format_dict(value, indent, width, inner_indent)
        elif isinstance(value, (list, set, tuple)):
            rendered = _format_list(value, indent, width, inner_indent)
        else:
            rendered = repr(value)
        parts.append(f"{fld.name}={rendered}")

    separator = ",\n" + " " * curr_indent
    return f"{cls_name}({separator.join(parts)})"


def _format_dict(
    attr: dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a dict field, recursing into dataclass values."""
    curr_indent += indent + 3
    entries: list[str] = []
    for dict_key, dict_value in attr.items():
        key_repr = repr(dict_key)
        if is_dataclass(dict_value):
            value_repr = _pformat(dict_value, indent, width, curr_indent + len(key_repr))
        else:
            value_repr = repr(dict_value)
        entries.append(f"{key_repr}: {value_repr}")

    return _format(entries, indent, width, curr_indent, "{", "}")


def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    """Render a list/set/tuple field, recursing into dataclass elements."""
    curr_indent += indent + 1
    rendered = [
        _pformat(elem, indent, width, curr_indent) if is_dataclass(elem) else repr(elem)
        for elem in attr
    ]
    if isinstance(attr, list):
        start, end = "[", "]"
    else:
        start, end = "(", ")"
    return _format(rendered, indent, width, curr_indent, start, end)


def _format(
    fields_str: list[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    """Join rendered elements, one per line when the total exceeds ``width``."""
    if len(repr(fields_str)) >= width:
        delimiter = "\n"
        curr_indent_str = " " * curr_indent
    else:
        delimiter = ""
        curr_indent_str = ""

    indent_str = " " * indent
    body = f", {delimiter}{curr_indent_str}".join(fields_str)
    return f"{start}{indent_str}{body}{end}"
class NamespaceHelper:
    """A helper for constructing the namespace open and close strings for a
    nested set of namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
        # namespace_str is a "::"-joined string such as torch::lazy
        parts = namespace_str.split("::")
        assert (
            len(parts) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join(f"namespace {ns} {{" for ns in parts)
        self.epilogue_ = "\n".join(
            f"}} // namespace {ns}" for ns in reversed(parts)
        )
        self.namespaces_ = parts
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
        """
        Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
        """
        *namespace_parts, entity_name = namespaced_entity.split("::")
        return NamespaceHelper(
            namespace_str="::".join(namespace_parts),
            entity_name=entity_name,
            max_level=max_level,
        )

    @property
    def prologue(self) -> str:
        return self.prologue_

    @property
    def epilogue(self) -> str:
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        return self.entity_name_

    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::"
        (hence no leading "::"); return ``default`` when the namespace string
        is empty.
        """
        return self.cpp_namespace_ or default


class OrderedSet(Generic[T]):
    """A set preserving insertion order, backed by dict key ordering."""

    storage: dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
        self.storage = {} if iterable is None else dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        return iter(self.storage)

    def update(self, items: OrderedSet[T]) -> None:
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
        duplicate: OrderedSet[T] = OrderedSet()
        duplicate.storage = dict(self.storage)
        return duplicate

    @staticmethod
    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
        result = args[0].copy()
        for extra in args[1:]:
            result.update(extra)
        return result

    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
        return OrderedSet.union(self, other)

    def __ior__(self, other: OrderedSet[T]) -> Self:
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        if isinstance(other, OrderedSet):
            return self.storage == other.storage
        return set(self.storage) == other