1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import inspect 4import typing 5from typing import List, Optional, Sequence, Union # noqa: F401 6 7import torch 8from torch import device, dtype, Tensor, types 9from torch.utils._exposed_in import exposed_in 10 11 12@exposed_in("torch.library") 13def infer_schema( 14 prototype_function: typing.Callable, 15 /, 16 *, 17 mutates_args, 18 op_name: Optional[str] = None, 19) -> str: 20 r"""Parses the schema of a given function with type hints. The schema is inferred from the 21 function's type hints, and can be used to define a new operator. 22 23 We make the following assumptions: 24 25 * None of the outputs alias any of the inputs or each other. 26 * | String type annotations "device, dtype, Tensor, types" without library specification are 27 | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union" 28 | without library specification are assumed to be typing.*. 29 * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown", 30 | it assumes that all inputs to the operator are being mutates. 31 32 Callers (e.g. the custom ops API) are responsible for checking these assumptions. 33 34 Args: 35 prototype_function: The function from which to infer a schema for from its type annotations. 36 op_name (Optional[str]): The name of the operator in the schema. If ``name`` is None, then the 37 name is not included in the inferred schema. Note that the input schema to 38 ``torch.library.Library.define`` requires a operator name. 39 mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function. 40 41 Returns: 42 The inferred schema. 43 44 Example: 45 >>> def foo_impl(x: torch.Tensor) -> torch.Tensor: 46 >>> return x.sin() 47 >>> 48 >>> infer_schema(foo_impl, op_name="foo", mutates_args={}) 49 foo(Tensor x) -> Tensor 50 >>> 51 >>> infer_schema(foo_impl, mutates_args={}) 52 (Tensor x) -> Tensor 53 """ 54 UNKNOWN_MUTATES = "unknown" 55 sig = inspect.signature(prototype_function) 56 57 def error_fn(what): 58 raise ValueError( 59 f"infer_schema(func): {what} " f"Got func with signature {sig})" 60 ) 61 62 def convert_type_string(annotation_type: str): 63 try: 64 return eval(annotation_type) 65 except Exception as e: 66 error_fn( 67 f"Unsupported type annotation {annotation_type}. It is not a type." 68 ) 69 70 params = [] 71 seen_args = set() 72 saw_kwarg_only_arg = False 73 for idx, (name, param) in enumerate(sig.parameters.items()): 74 if not supported_param(param): 75 error_fn("We do not support positional-only args, varargs, or varkwargs.") 76 77 if param.kind == inspect.Parameter.KEYWORD_ONLY: 78 # The first time we see a kwarg-only arg, add "*" to the schema. 79 if not saw_kwarg_only_arg: 80 params.append("*") 81 saw_kwarg_only_arg = True 82 83 if param.annotation is inspect.Parameter.empty: 84 error_fn(f"Parameter {name} must have a type annotation.") 85 86 # The annotation might be converted to a string by annotation, 87 # we convert it to the actual type. 88 annotation_type = param.annotation 89 if type(annotation_type) == str: 90 annotation_type = convert_type_string(annotation_type) 91 92 if annotation_type not in SUPPORTED_PARAM_TYPES.keys(): 93 if annotation_type.__origin__ is tuple: 94 list_type = tuple_to_list(annotation_type) 95 example_type_str = "\n\n" 96 # Only suggest the list type if this type is supported. 97 if list_type in SUPPORTED_PARAM_TYPES.keys(): 98 example_type_str = f"For example, {list_type}.\n\n" 99 error_fn( 100 f"Parameter {name} has unsupported type {param.annotation}. " 101 f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. " 102 f"{example_type_str}" 103 f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." 104 ) 105 else: 106 error_fn( 107 f"Parameter {name} has unsupported type {param.annotation}. " 108 f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}." 109 ) 110 111 schema_type = SUPPORTED_PARAM_TYPES[annotation_type] 112 if type(mutates_args) == str: 113 if mutates_args != UNKNOWN_MUTATES: 114 raise ValueError( 115 "mutates_args must either be a sequence of the names of " 116 "the arguments that are mutated or the string 'unknown'. " 117 ) 118 if schema_type.startswith("Tensor"): 119 schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" 120 elif name in mutates_args: 121 if not schema_type.startswith("Tensor"): 122 error_fn( 123 f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated" 124 ) 125 schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}" 126 seen_args.add(name) 127 if param.default is inspect.Parameter.empty: 128 params.append(f"{schema_type} {name}") 129 else: 130 default_repr = None 131 if param.default is None or isinstance(param.default, (int, float, bool)): 132 default_repr = str(param.default) 133 elif isinstance(param.default, (str, torch.device)): 134 default_repr = f'"{param.default}"' 135 elif isinstance(param.default, torch.dtype): 136 dtype_repr = str(param.default) 137 torch_dot = "torch." 138 assert dtype_repr.startswith(torch_dot) 139 default_repr = dtype_repr[len(torch_dot) :] 140 else: 141 error_fn( 142 f"Parameter {name} has an unsupported default value type {type(param.default)}. " 143 f"Please file an issue on GitHub so we can prioritize this." 144 ) 145 params.append(f"{schema_type} {name}={default_repr}") 146 if mutates_args != UNKNOWN_MUTATES: 147 mutates_args_not_seen = set(mutates_args) - seen_args 148 if len(mutates_args_not_seen) > 0: 149 error_fn( 150 f"{mutates_args_not_seen} in mutates_args were not found in " 151 f"the custom op's signature. " 152 f"mutates_args should contain the names of all args that the " 153 f"custom op mutates, or just the string 'unknown' if you don't know." 154 ) 155 return_annotation = sig.return_annotation 156 if type(return_annotation) == str: 157 return_annotation = convert_type_string(return_annotation) 158 ret = parse_return(return_annotation, error_fn) 159 if op_name is not None: 160 return f"{op_name}({', '.join(params)}) -> {ret}" 161 return f"({', '.join(params)}) -> {ret}" 162 163 164def derived_types( 165 base_type, cpp_type, list_base, optional_base_list, optional_list_base 166): 167 result = [ 168 (base_type, cpp_type), 169 (typing.Optional[base_type], f"{cpp_type}?"), 170 ] 171 172 def derived_seq_types(typ): 173 return [ 174 typing.Sequence[typ], # type: ignore[valid-type] 175 typing.List[typ], # type: ignore[valid-type] 176 ] 177 178 if list_base: 179 for seq_typ in derived_seq_types(base_type): 180 result.append((seq_typ, f"{cpp_type}[]")) # type: ignore[valid-type] 181 if optional_base_list: 182 for seq_typ in derived_seq_types(typing.Optional[base_type]): 183 result.append((seq_typ, f"{cpp_type}?[]")) # type: ignore[valid-type] 184 if optional_list_base: 185 for seq_typ in derived_seq_types(base_type): # type: ignore[valid-type] 186 result.append((typing.Optional[seq_typ], f"{cpp_type}[]?")) # type: ignore[valid-type] 187 return result 188 189 190def get_supported_param_types(): 191 data = [ 192 # (python type, schema type, type[] variant, type?[] variant, type[]? variant 193 (Tensor, "Tensor", True, True, False), 194 (int, "SymInt", True, False, True), 195 (float, "float", True, False, True), 196 (bool, "bool", True, False, True), 197 (str, "str", False, False, False), 198 (types.Number, "Scalar", True, False, False), 199 (dtype, "ScalarType", False, False, False), 200 (device, "Device", False, False, False), 201 ] 202 result = [] 203 for line in data: 204 result.extend(derived_types(*line)) 205 return dict(result) 206 207 208SUPPORTED_RETURN_TYPES = { 209 Tensor: "Tensor", 210 typing.List[Tensor]: "Tensor[]", 211 int: "SymInt", 212 float: "float", 213 bool: "bool", 214 types.Number: "Scalar", 215} 216 217 218def parse_return(annotation, error_fn): 219 if annotation is None: 220 return "()" 221 222 if annotation is inspect.Parameter.empty: 223 error_fn("No return type annotation was provided. Please add one.") 224 225 origin = typing.get_origin(annotation) 226 if origin is not tuple: 227 if annotation not in SUPPORTED_RETURN_TYPES.keys(): 228 error_fn( 229 f"Return has unsupported type {annotation}. " 230 f"The valid types are: {SUPPORTED_RETURN_TYPES}." 231 ) 232 return SUPPORTED_RETURN_TYPES[annotation] 233 234 args = typing.get_args(annotation) 235 for arg in args: 236 if arg not in SUPPORTED_RETURN_TYPES: 237 error_fn( 238 f"Return has unsupported type {annotation}. " 239 f"The valid types are: {SUPPORTED_RETURN_TYPES}." 240 ) 241 242 return "(" + ", ".join([SUPPORTED_RETURN_TYPES[arg] for arg in args]) + ")" 243 244 245SUPPORTED_PARAM_TYPES = get_supported_param_types() 246 247 248def supported_param(param: inspect.Parameter) -> bool: 249 return param.kind in ( 250 inspect.Parameter.POSITIONAL_OR_KEYWORD, 251 inspect.Parameter.KEYWORD_ONLY, 252 ) 253 254 255def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.List]: 256 """ 257 Convert `tuple_type` into a list type with the same type arguments. Assumes that `tuple_type` is typing.Tuple type. 258 """ 259 type_args = getattr(tuple_type, "__args__", None) 260 # Account for different python versions, e.g. python 3.8 would give () 261 # but python 3.12 would give None. 262 if tuple_type is typing.Tuple or type_args == () or type_args is None: 263 # Handle the case of an empty tuple type 264 return typing.List 265 elif len(type_args) == 1: 266 # General case: create a List with the same type arguments 267 return typing.List[type_args[0]] # type: ignore[valid-type] 268 elif len(type_args) == 2 and type_args[1] is Ellipsis: # type: ignore[valid-type] 269 return typing.List[type_args[0]] # type: ignore[valid-type] 270 else: 271 return typing.List[typing.Union[tuple(type_args)]] # type: ignore[misc] 272