# mypy: allow-untyped-defs
import inspect
import textwrap

import torch.jit
from torch.jit._builtins import _find_builtin


# This file is for generating documentation using sphinx autodoc.
# > help(torch.jit.supported_ops) will also give a nice list of the
# supported ops programmatically.


def _hidden(name):
    """Return True for single-underscore names such as ``_foo``.

    Dunder names (``__and__``) are deliberately *not* treated as hidden.
    """
    return name.startswith("_") and not name.startswith("__")


def _emit_type(type):
    """Render a schema type as text.

    NOTE: the parameter shadows the ``type`` builtin; the name is kept so
    that existing keyword callers keep working.
    """
    return str(type)


def _emit_arg(indent, i, arg):
    """Render one schema argument as ``name : type[=default]``.

    Every argument except the first (``i == 0``) is placed on its own line,
    prefixed by ``indent`` spaces, so the arguments line up in a column.
    """
    rendered = f"{arg.name} : {_emit_type(arg.type)}"
    default = arg.default_value
    if default is not None:
        rendered = f"{rendered}={default}"
    if i > 0:
        rendered = f"\n{' ' * indent}{rendered}"
    return rendered


def _emit_args(indent, arguments):
    """Render a sequence of schema arguments, comma separated."""
    return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))


def _emit_ret(ret):
    """Render the type of a single schema return value."""
    return _emit_type(ret.type)


def _emit_rets(returns):
    """Render the return annotation of a schema.

    A single return is rendered bare; any other count is wrapped in
    ``Tuple[...]``.
    """
    if len(returns) == 1:
        return _emit_ret(returns[0])
    return f"Tuple[{', '.join(_emit_ret(r) for r in returns)}]"


def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    """Render ``schema`` as a signature string ``mod.name(args) -> rets``.

    Args:
        mod: Prefix placed before ``name`` (joined with a dot), or ``None``
            for no prefix.
        name: Name of the function being rendered.
        schema: Object exposing ``arguments`` and ``returns`` sequences.
        arg_start: Index of the first argument to render (e.g. ``1`` to drop
            ``self`` for methods).
        padding: Extra spaces of indentation applied to continuation lines of
            the argument list.
    """
    qualified_name = name if mod is None else f"{mod}.{name}"
    # Continuation lines are indented past the opening parenthesis.
    indent = len(qualified_name) + 1 + padding
    args = _emit_args(indent, schema.arguments[arg_start:])
    return f"{qualified_name}({args}) -> {_emit_rets(schema.returns)}"


def _get_tensor_ops():
    """Collect signatures of ``aten::`` operators usable as Tensor methods.

    Returns:
        A ``(section title, list of signature strings)`` tuple.
    """

    def is_tensor_method(schema):
        # A Tensor method takes a first argument named ``self`` whose type
        # is a subtype of Tensor.
        if len(schema.arguments) == 0:
            return False
        first = schema.arguments[0]
        if first.name != "self":
            return False
        return bool(first.type.isSubtypeOf(torch._C.TensorType.get()))

    methods = []
    # Discover methods by looking up each public Tensor attribute as an op.
    for elem in dir(torch.Tensor):
        if _hidden(elem):
            continue
        for schema in torch._C._jit_get_schemas_for_operator("aten::" + elem):
            if is_tensor_method(schema):
                methods.append(_emit_schema("Tensor", elem, schema, arg_start=1))

    return "Supported Tensor Methods", methods


def _get_nn_functional_ops():
    """Collect signatures of scriptable ``torch.nn.functional`` functions and
    of builtins found in the modules known to contain builtins.

    Returns:
        A ``(section title, list of signature strings)`` tuple.

    Raises:
        RuntimeError: If the defining module of a function cannot be found.
    """
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(mod):
        attr = getattr(mod, elem)
        # Ignore non-functions and anything starting with an underscore
        # (this intentionally also excludes dunder names).
        if not inspect.isfunction(attr) or elem.startswith("_"):
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f"Module for {attr} not found")

        if "torch.nn.functional" not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # Compile the function and read back its schema.
            scripted = torch.jit.script(attr)
            functions.append(_emit_schema(name, elem, scripted.schema))
        except Exception:
            # Best effort: skip interpolate / boolean dispatched things that
            # cannot be scripted directly. KeyboardInterrupt and SystemExit
            # are intentionally allowed to propagate.
            continue

    # Iterate over modules that we know contain a lot of builtins
    for builtin_mod in torch.jit._builtins._modules_containing_builtins:
        name = builtin_mod.__name__
        for elem in dir(builtin_mod):
            builtin = _find_builtin(getattr(builtin_mod, elem))
            # Remove _tan but not __and__
            if builtin is None or _hidden(elem):
                continue
            for schema in torch._C._jit_get_schemas_for_operator(builtin):
                functions.append(_emit_schema(name, elem, schema))
    return "Supported PyTorch Functions", functions


def _get_builtins_helper():
    """Return the ``(fn, builtin_name)`` pairs from the registered builtin ops
    that are public and not defined in a ``torch._C`` module.
    """
    builtins = []
    for fn, _builtin_name in torch.jit._builtins._builtin_ops:
        mod = inspect.getmodule(fn)

        if not hasattr(fn, "__name__"):
            # typing classes
            continue
        if not mod:
            continue
        if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__):
            # skip internal-only methods
            continue
        if "torch._C" in mod.__name__:
            continue

        builtins.append((fn, _builtin_name))

    return builtins


def _is_math_fn(fn):
    """Return True if ``fn`` is defined in the ``math`` module.

    Raises:
        RuntimeError: If the defining module of ``fn`` cannot be found.
    """
    mod = inspect.getmodule(fn)
    if not mod:
        raise RuntimeError(f"Module for {fn} not found")

    return mod.__name__ == "math"


def _iter_builtin_schema_strs(select_math):
    """Yield formatted signatures for the specially added builtins.

    Only builtins whose "is a math function" status equals ``select_math``
    are considered.
    """
    for fn, _builtin_name in _get_builtins_helper():
        if _is_math_fn(fn) != select_math:
            continue
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f"Module for {fn} not found")
        builtin = _find_builtin(fn)
        if builtin is None:
            continue
        for schema in torch._C._jit_get_schemas_for_operator(builtin):
            yield _emit_schema(mod.__name__, fn.__name__, schema)


def _get_torchscript_builtins():
    """Collect signatures of the non-``math`` builtin functions.

    Returns:
        A ``(section title, list of signature strings)`` tuple.
    """
    functions = list(_iter_builtin_schema_strs(select_math=False))
    return "TorchScript Builtin Functions", functions


def _get_math_builtins():
    """Collect signatures of the supported ``math`` module functions.

    Returns:
        A ``(section title, list of signature strings)`` tuple. Each entry is
        the formatted signature (e.g. ``math.ceil(...)``), consistent with the
        other sections, rather than the raw schema object.
    """
    functions = [
        schema_str
        for schema_str in _iter_builtin_schema_strs(select_math=True)
        # Skip Tensor ops that have the same name as math functions
        # (they will show up in the tensor methods section)
        if "Tensor" not in schema_str
    ]
    return "``math`` Module", functions


def _get_global_builtins():
    """Build the reST section describing supported Python built-in functions.

    Built-ins with registered operator schemas are listed as signatures; the
    remaining ones are listed in a table together with an explanatory note.

    Returns:
        A ``(section title, reST text)`` tuple.

    Raises:
        KeyError: If a built-in has neither a schema nor an entry in the
            explanations table.
    """
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        "print",
        "tuple",
        "float",
        "complex",
        "int",
        "bool",
        "str",
        "getattr",
        "hasattr",
        "isinstance",
        "len",
        "hex",
        "oct",
        "round",
        "hash",
        "min",
        "max",
        "abs",
        "all",
        "divmod",
        "list",
        "ord",
        "chr",
        "bin",
        "range",
        "zip",
        "enumerate",
        "sorted",
    ]

    # Built-ins whose operator is not simply ``aten::<name>``.
    op_renames = {
        "bool": "aten::Bool",
        "int": "aten::Int",
        "float": "aten::Float",
        "complex": "aten::Complex",
        "abs": "prim::abs",
        "max": "prim::max",
        "min": "prim::min",
        "range": "fake::does_not_exist",
    }

    schemaless_op_explanations = {
        "print": "Print any value",
        "tuple": "Lists cannot be converted to tuples with this method since their size is not statically known",
        "getattr": "Attribute name must be a literal string",
        "hasattr": "Attribute name must be a literal string",
        "isinstance": "Result is static",
        "zip": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "enumerate": "Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.",
        "range": "Can only be used as an iterator in a for loop",
    }

    magic_methods = [
        ("complex", "__complex__"),
        ("float", "__float__"),
        ("int", "__int__"),
        ("bool", "__bool__"),
        ("str", "__str__"),
        ("len", "__len__"),
        ("hex", "__hex__"),
        ("oct", "__oct__"),
    ]

    magic_methods_rows = [
        f'"{fn}", "``{magic_method}``"' for fn, magic_method in magic_methods
    ]

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        op_name = op_renames.get(fn, f"aten::{fn}")
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        if len(schemas) > 0:
            schematized_ops.extend(
                _emit_schema(None, fn, s, padding=0) for s in schemas
            )
            # Blank line between the overload groups of different functions.
            schematized_ops.append("")
        else:
            schemaless_ops.append(
                f'":external+python:py:obj:`{fn}`", "{schemaless_op_explanations[fn]}"'
            )

    schematized_ops_str = textwrap.indent("\n".join(schematized_ops), "\t")
    schemaless_ops_str = textwrap.indent("\n".join(schemaless_ops), "\t")
    magic_methods_rows_str = textwrap.indent("\n".join(magic_methods_rows), "\t")
    section = f"""
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{schemaless_ops_str}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{magic_methods_rows_str}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{schematized_ops_str}
    """

    return "Python Built-in Functions", section


def _list_supported_ops():
    """Assemble the full reST document listing every supported op section.

    Each section gets a link target derived from its title, a ``~`` underlined
    heading, and either a literal block of signatures or pre-rendered text.
    """

    def emit_block(decls):
        # Render the declarations as an indented reST literal block.
        return "\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n".format(
            "".join(f"    {d}\n\n" for d in decls)
        )

    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )
    sections = []
    for fn in op_gathering_fns:
        header, items = fn()
        link_target = header.replace("`", "").replace("-", "").lower().replace(" ", "-")
        underline = "~" * len(header)
        if isinstance(items, str):
            # Already rendered reST text.
            content = f"{header}\n{underline}\n{items}\n"
        else:
            content = f"{header}\n{underline}\n{emit_block(items)}"
        sections.append(f".. _{link_target}:" + "\n\n" + content)

    return "".join(sections)


__doc__ = _list_supported_ops()