xref: /aosp_15_r20/external/pytorch/torch/_inductor/exc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import os
5import tempfile
6import textwrap
7from functools import lru_cache
8
9
10if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
11
12    @lru_cache(None)
13    def _record_missing_op(target):
14        with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
15            fd.write(str(target) + "\n")
16
17else:
18
19    def _record_missing_op(target):  # type: ignore[misc]
20        pass
21
22
class OperatorIssue(RuntimeError):
    """Base error for problems tied to one operator call.

    Subclasses use ``operator_str`` to append a readable summary of the call
    (target, positional args, kwargs) to their message.
    """

    @staticmethod
    def operator_str(target, args, kwargs):
        """Format a call as ``target``/``args[i]``/``kwargs`` lines, each indented two spaces.

        The ``kwargs`` line is emitted only when ``kwargs`` is truthy.
        """
        parts = [f"target: {target}"]
        parts.extend(f"args[{idx}]: {value}" for idx, value in enumerate(args))
        if kwargs:
            parts.append(f"kwargs: {kwargs}")
        block = "\n".join(parts)
        return textwrap.indent(block, "  ")
32
33
class MissingOperatorWithoutDecomp(OperatorIssue):
    """Error reporting a missing lowering for ``target``.

    Per its name, this is the variant for operators with no decomposition
    either (raise sites are outside this file). The target is passed to
    ``_record_missing_op`` before the message is built.
    """

    def __init__(self, target, args, kwargs) -> None:
        _record_missing_op(target)
        call_summary = self.operator_str(target, args, kwargs)
        super().__init__("missing lowering\n" + call_summary)
38
39
class MissingOperatorWithDecomp(OperatorIssue):
    """Error reporting a missing lowering for ``target`` that has a decomposition.

    The target is passed to ``_record_missing_op``. The message is followed by
    a hint asking for the operator to be added to the ``decompositions`` list
    in ``torch._inductor.decomposition``.
    """

    def __init__(self, target, args, kwargs) -> None:
        _record_missing_op(target)
        call_summary = self.operator_str(target, args, kwargs)
        # The literal's body lines deliberately keep a 16-space indent so that
        # ``dedent`` produces exactly the same hint text, whatever ``target``
        # renders to.
        hint = textwrap.dedent(
            f"""

                There is a decomposition available for {target} in
                torch._decomp.get_decompositions().  Please add this operator to the
                `decompositions` list in torch._inductor.decomposition
                """
        )
        super().__init__("missing decomposition\n" + call_summary + hint)
54
55
class LoweringException(OperatorIssue):
    """Wrap an exception raised while lowering an operator call.

    The message is ``"<ExcType>: <exc>"`` followed by the indented
    ``operator_str`` summary of the failing call.
    """

    # NOTE(review): ``exc`` is only rendered into the message, not stored or
    # chained here; presumably raise sites use ``raise ... from exc`` - verify.
    def __init__(self, exc: Exception, target, args, kwargs) -> None:
        headline = f"{type(exc).__name__}: {exc}"
        call_summary = self.operator_str(target, args, kwargs)
        super().__init__(headline + "\n" + call_summary)
61
62
class SubgraphLoweringException(RuntimeError):
    """``RuntimeError`` subclass with no extra behavior.

    Presumably raised when lowering a subgraph fails; its raise sites are
    outside this file. A dedicated type lets callers catch it specifically.
    """
65
66
class InvalidCxxCompiler(RuntimeError):
    """Raised when no working C++ compiler is found via ``config.cpp.cxx``.

    The message names that config entry and the value it currently holds.
    """

    def __init__(self) -> None:
        # Imported lazily, only when the error is constructed (presumably to
        # avoid an import cycle with the config module - unverified).
        from . import config

        message = (
            f"No working C++ compiler found in {config.__name__}.cpp.cxx: "
            f"{config.cpp.cxx}"
        )
        super().__init__(message)
74
75
class CppWrapperCodeGenError(RuntimeError):
    """Error from C++ wrapper code generation.

    ``msg`` is prefixed with ``"C++ wrapper codegen error: "``.
    """

    def __init__(self, msg: str) -> None:
        prefix = "C++ wrapper codegen error: "
        super().__init__(f"{prefix}{msg}")
79
80
class CppCompileError(RuntimeError):
    """Raised when a compiler invocation fails.

    The message has the fixed layout::

        C++ compile error

        Command:
        <cmd joined with spaces>

        Output:
        <compiler output>

    ``cmd`` and ``output`` are kept as attributes. ``__reduce__`` is defined so
    the exception survives a pickle round trip (e.g. raised in a worker process
    and re-raised in the parent). The inherited ``BaseException.__reduce__``
    would rebuild it as ``cls(message)``, which fails because ``__init__``
    requires two arguments.
    """

    def __init__(self, cmd: list[str] | str, output: str | bytes) -> None:
        """Build the error message from the failed command and its output.

        Args:
            cmd: The compiler command as a list of arguments. A pre-joined
                command string is also accepted and shown verbatim.
            output: Captured compiler output. ``bytes`` are decoded as UTF-8,
                with undecodable bytes replaced by U+FFFD.
        """
        if isinstance(output, bytes):
            # Compiler diagnostics are not guaranteed to be valid UTF-8. A
            # strict decode would raise UnicodeDecodeError right here and hide
            # the compile failure this exception is meant to report.
            output = output.decode("utf-8", errors="replace")

        self.cmd = cmd
        self.output = output

        # " ".join on a plain string would put a space between every character.
        cmd_text = cmd if isinstance(cmd, str) else " ".join(cmd)
        super().__init__(
            f"C++ compile error\n\nCommand:\n{cmd_text}\n\nOutput:\n{output}"
        )

    def __reduce__(self):
        # Rebuild from the constructor arguments. Also carry ``__dict__`` so
        # any extra instance attributes (e.g. ``__notes__``) are restored.
        return (self.__class__, (self.cmd, self.output), self.__dict__)
101
102
class CUDACompileError(CppCompileError):
    """Compile failure from the CUDA toolchain.

    Behaves exactly like ``CppCompileError`` (same constructor and message
    layout); the distinct type lets callers tell CUDA compile failures apart.
    """
105