# mypy: allow-untyped-defs
import inspect
import textwrap

import torch.jit
from torch.jit._builtins import _find_builtin


# this file is for generating documentation using sphinx autodoc
# > help(torch.jit.supported_ops) will also give a nice listing of the
# supported ops programmatically


def _hidden(name):
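    """Return True for single-underscore-prefixed names (e.g. ``_tan``) and
    False for public names and for dunders such as ``__and__``."""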
    return name.startswith("_") and not name.startswith("__")


def _emit_type(type):
    return str(type)


def _emit_arg(indent, i, arg):
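    """Format one schema argument as ``name : Type`` (with ``=default`` when a
    default value is present); every argument after the first starts on a new
    line indented so the arguments roughly line up in the rendered schema."""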
    v = f"{arg.name} : {_emit_type(arg.type)}"
    default = arg.default_value
    if default is not None:
        v = f"{v}={str(default)}"
    if i > 0:
        v = f"\n{' ' * indent}{v}"
    return v


def _emit_args(indent, arguments):
    return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))


def _emit_ret(ret):
    return _emit_type(ret.type)


def _emit_rets(returns):
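    """Render the return annotation: a single return value is emitted as its
    type, multiple return values are collapsed into ``Tuple[...]``."""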
    if len(returns) == 1:
        return _emit_ret(returns[0])
    return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]"


def _emit_schema(mod, name, schema, arg_start=0, padding=4):
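    """Render one operator schema as a signature line of the rough shape
    ``mod.name(arg : Type, other : Type=default) -> ReturnType``.

    ``arg_start`` skips leading arguments (such as ``self`` for Tensor
    methods) and ``padding`` widens the indent used when argument lists wrap
    onto continuation lines.
    """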
    if mod is None:
        qualified_name = name
    else:
        qualified_name = f"{mod}.{name}"
    schema_str = (
        f"{qualified_name}"
        f"({_emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:])}) "
        f"-> {_emit_rets(schema.returns)}"
    )
    return schema_str


def _get_tensor_ops():
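    """Collect schemas for the public ``torch.Tensor`` methods whose first
    argument is a Tensor named ``self`` and return them (with ``self``
    omitted) under the "Supported Tensor Methods" heading."""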
    def is_tensor_method(schema):
        if len(schema.arguments) == 0:
            return False
        self = schema.arguments[0]
        if self.name != "self":
            return False
        if not self.type.isSubtypeOf(torch._C.TensorType.get()):
            return False
        return True

    methods = []
    # discover methods
    for elem in dir(torch.Tensor):
        if not _hidden(elem):
            schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
            for schema in schemas:
                if is_tensor_method(schema):
                    methods.append(_emit_schema("Tensor", elem, schema, arg_start=1))

    return "Supported Tensor Methods", methods


def _get_nn_functional_ops():
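    """Gather schemas for functions reachable from ``torch.nn.functional``
    (by scripting each public function to recover its schema) plus the
    builtins registered for the modules listed in
    ``torch.jit._builtins._modules_containing_builtins``."""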
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(torch.nn.functional):
        attr = getattr(mod, elem)
        if not inspect.isfunction(attr) or _hidden(elem[0]):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f"Module for {attr} not found")

        if "torch.nn.functional" not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            scripted_schema = scripted.schema
            functions.append(_emit_schema(name, elem, scripted_schema))
        except:  # noqa: B001,E722
            # Skip interpolate / boolean dispatched things
            pass

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        functions.append(_emit_schema(name, elem, schema))
    return "Supported PyTorch Functions", functions


def _get_builtins_helper():
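    """Filter ``torch.jit._builtins._builtin_ops`` down to named, public
    functions that do not live in ``torch._C`` and return the remaining
    ``(fn, builtin_name)`` pairs."""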
    builtins = []
    for fn, _builtin_name in torch.jit._builtins._builtin_ops:
        mod = inspect.getmodule(fn)

        if not hasattr(fn, "__name__"):
            # typing classes
            continue
        if not mod:
            continue
        if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__):
            # skip internal-only methods
            continue

        if "torch._C" in mod.__name__:
            continue

        builtins.append((fn, _builtin_name))

    return builtins


def _is_math_fn(fn):
    mod = inspect.getmodule(fn)
    if not mod:
        raise RuntimeError(f"Module for {fn} not found")

    return mod.__name__ == "math"


def _get_torchscript_builtins():
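    """Emit schemas for the specially registered builtins that do not come
    from the ``math`` module (those are listed separately by
    ``_get_math_builtins``)."""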
    functions = []
    builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                functions.append(_emit_schema(mod.__name__, fn.__name__, schema))

    return "TorchScript Builtin Functions", functions


def _get_math_builtins():
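    """Emit schemas for the ``math``-module builtins, skipping overloads whose
    schema mentions ``Tensor`` since those already show up in the Tensor
    methods section."""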
    functions = []
    builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
                if "Tensor" in schema_str:
                    # Skip Tensor ops that have the same name as math functions
                    # (they will show up in the tensor methods section)
                    continue
                functions.append(schema_str)

    return "``math`` Module", functions


def _get_global_builtins():
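    """Build the "Python Built-in Functions" section: built-ins that have an
    operator schema are listed with their signatures, the schemaless ones go
    into a csv-table of notes, and a third table maps built-ins to the magic
    methods they call on TorchScript classes."""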
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        "print",
        "tuple",
        "float",
        "complex",
        "int",
        "bool",
        "str",
        "getattr",
        "hasattr",
        "isinstance",
        "len",
        "hex",
        "oct",
        "round",
        "hash",
        "min",
        "max",
        "abs",
        "all",
        "divmod",
        "list",
        "ord",
        "chr",
        "bin",
        "range",
        "zip",
        "enumerate",
        "sorted",
    ]

    op_renames = {
        "bool": "aten::Bool",
        "int": "aten::Int",
        "float": "aten::Float",
        "complex": "aten::Complex",
        "abs": "prim::abs",
        "max": "prim::max",
        "min": "prim::min",
        "range": "fake::does_not_exist",
    }

    schemaless_op_explanations = {
        "print": "Print any value",
        "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
        "getattr": "Attribute name must be a literal string",
        "hasattr": "Attribute name must be a literal string",
        "isinstance": "Result is static",
        "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "range": "Can only be used as an iterator in a for loop",
    }

    magic_methods = [
        ("complex", "__complex__"),
        ("float", "__float__"),
        ("int", "__int__"),
        ("bool", "__bool__"),
        ("str", "__str__"),
        ("len", "__len__"),
        ("hex", "__hex__"),
        ("oct", "__oct__"),
    ]

    magic_methods_rows = []
    for fn, magic_method in magic_methods:
        magic_methods_rows.append(f'"{fn}", "``{magic_method}``"')

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        op_name = f"aten::{fn}"
        if fn in op_renames:
            op_name = op_renames[fn]
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        for s in schemas:
            schematized_ops.append(_emit_schema(None, fn, s, padding=0))
        if len(schemas) > 0:
            schematized_ops.append("")
        else:
            table_row = (
                f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
            )
            schemaless_ops.append(table_row)

    schematized_ops_str = "\n".join(schematized_ops)
    schemaless_ops_str = "\n".join(schemaless_ops)
    magic_methods_rows_str = "\n".join(magic_methods_rows)
    schematized_ops_str = textwrap.indent(schematized_ops_str, "\t")
    schemaless_ops_str = textwrap.indent(schemaless_ops_str, "\t")
    magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, "\t")
    section = f"""
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{schemaless_ops_str}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{magic_methods_rows_str}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{schematized_ops_str}
    """

    return "Python Built-in Functions", section


def _list_supported_ops():
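    """Concatenate the sections produced by the gathering functions above into
    a single reStructuredText body, each under its own header and link
    target; the result becomes this module's ``__doc__``."""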
    def emit_block(decls):
        return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format(
            "".join(f"    {d}\n\n" for d in decls)
        )

    body = ""
    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )
    for fn in op_gathering_fns:
        header, items = fn()
        link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
        if isinstance(items, str):
            section = f"{header}\n{'~' * len(header)}\n{items}\n"
        else:
            section = f"{header}\n{'~' * len(header)}\n{emit_block(items)}"
        section = f".. _{link_target}:" + "\n\n" + section
        body += section

    return body


__doc__ = _list_supported_ops()
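# A quick way to inspect the generated listing outside of the Sphinx build
# (a usage sketch, not something the docs pipeline requires):
#
#     import torch.jit.supported_ops
#     print(torch.jit.supported_ops.__doc__)
#
# help(torch.jit.supported_ops) shows the same content, as noted at the top
# of this file.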