from __future__ import annotations

import contextlib
import functools
import hashlib
import os
import re
import sys
import textwrap
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from pathlib import Path
from typing import (
    Any,
    Callable,
    Generic,
    Iterable,
    Iterator,
    Literal,
    NoReturn,
    Sequence,
    TYPE_CHECKING,
    TypeVar,
)
from typing_extensions import Self

from torchgen.code_template import CodeTemplate


if TYPE_CHECKING:
    from argparse import Namespace


REPO_ROOT = Path(__file__).absolute().parent.parent


# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
# code we want.
#
# This is an OPEN enum (we may add more cases to it in the future), so be sure
# to explicitly specify with Literal[Target.XXX] or Literal[Target.XXX, Target.YYY]
# what targets are valid for your use.
class Target(Enum):
    # top level namespace (not including at)
    DEFINITION = auto()
    DECLARATION = auto()
    # TORCH_LIBRARY(...) { ... }
    REGISTRATION = auto()
    # namespace { ... }
    ANONYMOUS_DEFINITION = auto()
    # namespace cpu { ... }
    NAMESPACED_DEFINITION = auto()
    NAMESPACED_DECLARATION = auto()


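# Illustrative sketch (not from the original file): a generator that only
# supports a subset of targets would typically be annotated as
#
#     def gen_something(target: Literal[Target.DEFINITION, Target.DECLARATION]) -> str: ...
#
# so that mypy rejects callers passing any other Target member.
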
# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula
IDENT_REGEX = r"(^|\W){}($|\W)"
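# Illustrative usage (assumed, not from the original file):
#     re.search(IDENT_REGEX.format("foo"), "foo, bar")  # matches
#     re.search(IDENT_REGEX.format("foo"), "foobar")    # does not match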


# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if m is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
    name, _, params = m.groups()
    return name, params.split(", ")
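# Illustrative example (not from the original file): for a schema string like
#     "add.Tensor(Tensor self, Tensor other)"
# this returns ("add", ["Tensor self", "Tensor other"]); the overload suffix
# (".Tensor") is matched but discarded.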


T = TypeVar("T")
S = TypeVar("S")

# These two functions purposely return generators, in analogy to map(),
# so that it is always clear when you need to list() them


# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
            yield r


# Map over function that returns sequences and cat them all together
def concatMap(func: Callable[[T], Sequence[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        yield from func(x)
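# Illustrative examples (assumed, not from the original file):
#     list(mapMaybe(lambda x: x * 2 if x > 0 else None, [-1, 0, 2]))  # -> [4]
#     list(concatMap(lambda x: [x, x], [1, 2]))                       # -> [1, 1, 2, 2]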


# Conveniently add error context to exceptions raised.  Lets us
# easily say that an error occurred while processing a specific
# context.
@contextlib.contextmanager
def context(msg_fn: Callable[[], str]) -> Iterator[None]:
    try:
        yield
    except Exception as e:
        # TODO: this does the wrong thing with KeyError
        msg = msg_fn()
        msg = textwrap.indent(msg, "  ")
        msg = f"{e.args[0]}\n{msg}" if e.args else msg
        e.args = (msg,) + e.args[1:]
        raise
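# Illustrative usage (assumed names, not from the original file):
#     with context(lambda: f"while processing schema {schema}"):
#         process(schema)
# Any exception escaping the block gets the (lazily computed) context message
# appended, indented, to its first argument before being re-raised.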


# A little trick from https://github.com/python/mypy/issues/6366
# for getting mypy to do exhaustiveness checking
# TODO: put this somewhere else, maybe
def assert_never(x: NoReturn) -> NoReturn:
    raise AssertionError(f"Unhandled type: {type(x).__name__}")
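# Illustrative sketch (assumed usage, not from the original file): in an
# if/elif chain over a Literal[...] of Target members, calling assert_never()
# in the final else branch makes mypy report any unhandled case:
#     if target is Target.DEFINITION:
#         ...
#     elif target is Target.DECLARATION:
#         ...
#     else:
#         assert_never(target)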


@functools.lru_cache(maxsize=None)
def _read_template(template_fn: str) -> CodeTemplate:
    return CodeTemplate.from_file(template_fn)


# String hash that's stable across different executions, unlike builtin hash
def string_stable_hash(s: str) -> int:
    sha1 = hashlib.sha1(s.encode("latin1")).digest()
    return int.from_bytes(sha1, byteorder="little")
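# Added note: write_sharded() below relies on this stability, assigning each
# item to a shard via string_stable_hash(key) % num_shards so that shard
# contents are reproducible across runs and machines.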


# A small abstraction for writing out generated files and keeping track
# of what files have been written (so you can write out a list of output
# files)
class FileManager:
    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
        self.template_dir = template_dir
        self.filenames = set()
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        old_contents: str | None
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:
            # Create output directory if it doesn't exist
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)

    # Read the template file and substitute in the environment produced by env_callable
    # (a dict of substitutions, or a plain string that is returned verbatim).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
    ) -> str:
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
        if isinstance(env, dict):
            if "generated_comment" not in env:
                generator_default = REPO_ROOT / "torchgen" / "gen.py"
                try:
                    generator = Path(
                        sys.modules["__main__"].__file__ or generator_default
                    ).absolute()
                except (KeyError, AttributeError):
                    generator = generator_default.absolute()

                try:
                    generator_path = generator.relative_to(REPO_ROOT).as_posix()
                except ValueError:
                    generator_path = generator.name

                env = {
                    **env,  # copy the original dict instead of mutating it
                    "generated_comment": (
                        "@" + f"generated by {generator_path} from {template_fn}"
                    ),
                }
            template = _read_template(template_path)
            return template.substitute(env)
        elif isinstance(env, str):
            return env
        else:
            assert_never(env)

    def write_with_template(
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, f"duplicate file write {filename}"
        self.filenames.add(filename)
        if not self.dry_run:
            substitute_out = self.substitute_with_template(
                template_fn=template_fn,
                env_callable=env_callable,
            )
            self._write_if_changed(filename=filename, contents=substitute_out)

    def write(
        self,
        filename: str,
        env_callable: Callable[[], str | dict[str, Any]],
    ) -> None:
        self.write_with_template(filename, filename, env_callable)
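
    # Illustrative usage (assumed names, not from the original file):
    #     fm = FileManager(install_dir="out", template_dir="templates", dry_run=False)
    #     fm.write("Functions.h", lambda: {"declarations": decls})
    # This reads the Functions.h template, substitutes the provided placeholders
    # (plus an auto-added generated_comment), and rewrites out/Functions.h only
    # if its contents changed.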

    def write_sharded(
        self,
        filename: str,
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
    ) -> None:
        everything: dict[str, Any] = {"shard_id": "Everything"}
        shards: list[dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards

        if base_env is not None:
            for shard in all_shards:
                shard.update(base_env)

        for key in sharded_keys:
            for shard in all_shards:
                if key in shard:
                    assert isinstance(
                        shard[key], list
                    ), "sharded keys in base_env must be a list"
                    shard[key] = shard[key].copy()
                else:
                    shard[key] = []

        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v

        if self.dry_run:
            # Dry runs don't write any templates, so incomplete environments are fine
            items = ()

        for item in items:
            key = key_fn(item)
            sid = string_stable_hash(key) % num_shards
            env = env_callable(item)

            merge_env(shards[sid], env)
            merge_env(everything, env)

        dot_pos = filename.rfind(".")
        if dot_pos == -1:
            dot_pos = len(filename)
        base_filename = filename[:dot_pos]
        extension = filename[dot_pos:]

        for shard in all_shards:
            shard_id = shard["shard_id"]
            self.write_with_template(
                f"{base_filename}{shard_id}{extension}", filename, lambda: shard
            )

        # filenames is used to track compiled files, but FooEverything.cpp isn't meant to be compiled
        self.filenames.discard(
            f"{self.install_dir}/{base_filename}Everything{extension}"
        )
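
    # Illustrative sketch (assumed names, not from the original file):
    #     fm.write_sharded("RegisterCPU.cpp", items, key_fn=..., env_callable=...,
    #                      num_shards=2, sharded_keys={"definitions"})
    # renders the RegisterCPU.cpp template three times, producing
    # RegisterCPU_0.cpp, RegisterCPU_1.cpp and RegisterCPUEverything.cpp; the
    # "Everything" variant aggregates every item and is removed from
    # self.filenames because it is not meant to be compiled.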

    def write_outputs(self, variable_name: str, filename: str) -> None:
        """Write a file containing the list of all outputs which are
        generated by this script."""
        content = "set({}\n    {})".format(
            variable_name,
            "\n    ".join('"' + name + '"' for name in sorted(self.filenames)),
        )
        self._write_if_changed(filename, content)

    def template_dir_for_comments(self) -> str:
        """
        This needs to be deterministic. The template dir is an absolute path
        that varies across builds. So, just use the path relative to this file,
        which will point to the codegen source but will be stable.
        """
        return os.path.relpath(self.template_dir, os.path.dirname(__file__))


# Helper function to construct a FileManager from parsed command-line options
def make_file_manager(
    options: Namespace, install_dir: str | None = None
) -> FileManager:
    template_dir = os.path.join(options.source_path, "templates")
    install_dir = install_dir if install_dir else options.install_dir
    return FileManager(
        install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run
    )


# Helper function to create a pretty representation for dataclasses
def dataclass_repr(
    obj: Any,
    indent: int = 0,
    width: int = 80,
) -> str:
    # the built-in pprint module supports dataclasses from Python 3.10 onwards
    if sys.version_info >= (3, 10):
        from pprint import pformat

        return pformat(obj, indent, width)

    return _pformat(obj, indent=indent, width=width)
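
# Illustrative sketch (assumed dataclass, not from the original file): for
#     @dataclass
#     class Point:
#         x: int
#         y: int
# the pre-3.10 fallback _pformat(Point(1, 2), indent=0, width=80) yields
#     Point(x=1,
#           y=2)
# while pprint.pformat on Python >= 3.10 keeps such a short object on one line.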


def _pformat(
    obj: Any,
    indent: int,
    width: int,
    curr_indent: int = 0,
) -> str:
    assert is_dataclass(obj), f"obj should be a dataclass, received: {type(obj)}"

    class_name = obj.__class__.__name__
    # update current indentation level with class name
    curr_indent += len(class_name) + 1

    fields_list = [(f.name, getattr(obj, f.name)) for f in fields(obj) if f.repr]

    fields_str = []
    for name, attr in fields_list:
        # update the current indent level with the field name
        # dict, list, set and tuple also add indent as done in pprint
        _curr_indent = curr_indent + len(name) + 1
        if is_dataclass(attr):
            str_repr = _pformat(attr, indent, width, _curr_indent)
        elif isinstance(attr, dict):
            str_repr = _format_dict(attr, indent, width, _curr_indent)
        elif isinstance(attr, (list, set, tuple)):
            str_repr = _format_list(attr, indent, width, _curr_indent)
        else:
            str_repr = repr(attr)

        fields_str.append(f"{name}={str_repr}")

    indent_str = curr_indent * " "
    body = f",\n{indent_str}".join(fields_str)
    return f"{class_name}({body})"


def _format_dict(
    attr: dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    curr_indent += indent + 3
    dict_repr = []
    for k, v in attr.items():
        k_repr = repr(k)
        v_str = (
            _pformat(v, indent, width, curr_indent + len(k_repr))
            if is_dataclass(v)
            else repr(v)
        )
        dict_repr.append(f"{k_repr}: {v_str}")

    return _format(dict_repr, indent, width, curr_indent, "{", "}")


def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    indent: int,
    width: int,
    curr_indent: int,
) -> str:
    curr_indent += indent + 1
    list_repr = [
        _pformat(l, indent, width, curr_indent) if is_dataclass(l) else repr(l)
        for l in attr
    ]
    start, end = ("[", "]") if isinstance(attr, list) else ("(", ")")
    return _format(list_repr, indent, width, curr_indent, start, end)


def _format(
    fields_str: list[str],
    indent: int,
    width: int,
    curr_indent: int,
    start: str,
    end: str,
) -> str:
    delimiter, curr_indent_str = "", ""
    # if the rendered fields exceed the max width, place one element per line
    if len(repr(fields_str)) >= width:
        delimiter = "\n"
        curr_indent_str = " " * curr_indent

    indent_str = " " * indent
    body = f", {delimiter}{curr_indent_str}".join(fields_str)
    return f"{start}{indent_str}{body}{end}"


class NamespaceHelper:
    """A helper for constructing the namespace open and close strings for a nested set of namespaces.

    e.g. for namespace_str torch::lazy,

    prologue:
    namespace torch {
    namespace lazy {

    epilogue:
    } // namespace lazy
    } // namespace torch
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
        # cpp_namespace can be a colon joined string such as torch::lazy
        cpp_namespaces = namespace_str.split("::")
        assert (
            len(cpp_namespaces) <= max_level
        ), f"Codegen doesn't support more than {max_level} level(s) of custom namespace. Got {namespace_str}."
        self.cpp_namespace_ = namespace_str
        self.prologue_ = "\n".join([f"namespace {n} {{" for n in cpp_namespaces])
        self.epilogue_ = "\n".join(
            [f"}} // namespace {n}" for n in reversed(cpp_namespaces)]
        )
        self.namespaces_ = cpp_namespaces
        self.entity_name_ = entity_name

    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
        """
        Generate a helper from nested namespaces as well as a class/function name. E.g.: "torch::lazy::add"
        """
        names = namespaced_entity.split("::")
        entity_name = names[-1]
        namespace_str = "::".join(names[:-1])
        return NamespaceHelper(
            namespace_str=namespace_str, entity_name=entity_name, max_level=max_level
        )
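
    # Illustrative example (not from the original file):
    #     helper = NamespaceHelper.from_namespaced_entity("torch::lazy::add")
    #     helper.entity_name          # "add"
    #     helper.get_cpp_namespace()  # "torch::lazy"
    #     helper.prologue             # 'namespace torch {\nnamespace lazy {'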

    @property
    def prologue(self) -> str:
        return self.prologue_

    @property
    def epilogue(self) -> str:
        return self.epilogue_

    @property
    def entity_name(self) -> str:
        return self.entity_name_

    # Only a limited number of namespace levels is allowed (enforced in __init__ via max_level)
    def get_cpp_namespace(self, default: str = "") -> str:
        """
        Return the namespace string from joining all the namespaces by "::" (hence no leading "::").
        Return default if namespace string is empty.
        """
        return self.cpp_namespace_ if self.cpp_namespace_ else default


class OrderedSet(Generic[T]):
    storage: dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
        if iterable is None:
            self.storage = {}
        else:
            self.storage = dict.fromkeys(iterable)

    def __contains__(self, item: T) -> bool:
        return item in self.storage

    def __iter__(self) -> Iterator[T]:
        return iter(self.storage.keys())

    def update(self, items: OrderedSet[T]) -> None:
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
        ret: OrderedSet[T] = OrderedSet()
        ret.storage = self.storage.copy()
        return ret

    @staticmethod
    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
        ret = args[0].copy()
        for s in args[1:]:
            ret.update(s)
        return ret

    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
        return OrderedSet.union(self, other)

    def __ior__(self, other: OrderedSet[T]) -> Self:
        self.update(other)
        return self

    def __eq__(self, other: object) -> bool:
        if isinstance(other, OrderedSet):
            return self.storage == other.storage
        else:
            return set(self.storage.keys()) == other
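
# Illustrative usage (not from the original file): OrderedSet preserves
# insertion order because it is backed by a dict (ordered since Python 3.7):
#     s = OrderedSet(["b", "a"])
#     s.add("c")
#     list(s)                     # -> ["b", "a", "c"]
#     s | OrderedSet(["a", "d"])  # -> contains "b", "a", "c", "d" in that order
#     s == {"a", "b", "c"}        # -> True (compares against plain sets as a set)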
520