xref: /aosp_15_r20/external/pytorch/torch/_library/infer_schema.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-decorators
2# mypy: allow-untyped-defs
3import inspect
4import typing
5from typing import List, Optional, Sequence, Union  # noqa: F401
6
7import torch
8from torch import device, dtype, Tensor, types
9from torch.utils._exposed_in import exposed_in
10
11
@exposed_in("torch.library")
def infer_schema(
    prototype_function: typing.Callable,
    /,
    *,
    mutates_args,
    op_name: Optional[str] = None,
) -> str:
    r"""Parses the schema of a given function with type hints. The schema is inferred from the
    function's type hints, and can be used to define a new operator.

    We make the following assumptions:

    * None of the outputs alias any of the inputs or each other.
    * | String type annotations "device, dtype, Tensor, types" without library specification are
      | assumed to be torch.*. Similarly, string type annotations "Optional, List, Sequence, Union"
      | without library specification are assumed to be typing.*.
    * | Only the args listed in ``mutates_args`` are being mutated. If ``mutates_args`` is "unknown",
      | it assumes that all inputs to the operator are being mutated.

    Callers (e.g. the custom ops API) are responsible for checking these assumptions.

    Args:
        prototype_function: The function whose type annotations the schema is inferred from.
        op_name (Optional[str]): The name of the operator in the schema. If ``op_name`` is None, then the
            name is not included in the inferred schema. Note that the input schema to
            ``torch.library.Library.define`` requires an operator name.
        mutates_args ("unknown" | Iterable[str]): The arguments that are mutated in the function.

    Returns:
        The inferred schema.

    Raises:
        ValueError: If ``mutates_args`` is a string other than "unknown"; if the function uses
            positional-only args, varargs or varkwargs; if a parameter is unannotated or has an
            unsupported annotation or default value; if a name in ``mutates_args`` is not a
            parameter or is not a Tensor (or collection of Tensors); or if the return
            annotation is missing or unsupported.

    Example:
        >>> def foo_impl(x: torch.Tensor) -> torch.Tensor:
        >>>     return x.sin()
        >>>
        >>> infer_schema(foo_impl, op_name="foo", mutates_args={})
        foo(Tensor x) -> Tensor
        >>>
        >>> infer_schema(foo_impl, mutates_args={})
        (Tensor x) -> Tensor
    """
    UNKNOWN_MUTATES = "unknown"
    sig = inspect.signature(prototype_function)

    def error_fn(what) -> typing.NoReturn:
        raise ValueError(
            f"infer_schema(func): {what} " f"Got func with signature {sig}"
        )

    def convert_type_string(annotation_type: str):
        # Stringified annotations are evaluated in this module's namespace. That is
        # why unqualified names such as Tensor / Optional resolve to torch.* /
        # typing.* (see the docstring). The strings come from the prototype
        # function's own annotations.
        try:
            return eval(annotation_type)
        except Exception:
            error_fn(
                f"Unsupported type annotation {annotation_type}. It is not a type."
            )

    # Validate mutates_args once, before looking at any parameter. Any other
    # iterable is materialized into a set. It is queried once per parameter and
    # again after the loop, so a one-shot iterator would otherwise be consumed.
    # None means "unknown": every Tensor-like input is treated as mutated.
    if isinstance(mutates_args, str):
        if mutates_args != UNKNOWN_MUTATES:
            raise ValueError(
                "mutates_args must either be a sequence of the names of "
                "the arguments that are mutated or the string 'unknown'. "
            )
        mutated_names = None
    else:
        mutated_names = set(mutates_args)

    params = []
    seen_args = set()
    saw_kwarg_only_arg = False
    for idx, (name, param) in enumerate(sig.parameters.items()):
        if not supported_param(param):
            error_fn("We do not support positional-only args, varargs, or varkwargs.")

        if param.kind == inspect.Parameter.KEYWORD_ONLY:
            # The first time we see a kwarg-only arg, add "*" to the schema.
            if not saw_kwarg_only_arg:
                params.append("*")
                saw_kwarg_only_arg = True

        if param.annotation is inspect.Parameter.empty:
            error_fn(f"Parameter {name} must have a type annotation.")

        # The annotation might have been stringified (e.g. by
        # `from __future__ import annotations`); convert it back to the actual type.
        annotation_type = param.annotation
        if isinstance(annotation_type, str):
            annotation_type = convert_type_string(annotation_type)

        if annotation_type not in SUPPORTED_PARAM_TYPES.keys():
            # typing.get_origin returns None for non-generic annotations (plain
            # classes such as `dict`). Those classes have no `__origin__` attribute.
            if typing.get_origin(annotation_type) is tuple:
                list_type = tuple_to_list(annotation_type)
                example_type_str = "\n\n"
                # Only suggest the list type if this type is supported.
                if list_type in SUPPORTED_PARAM_TYPES.keys():
                    example_type_str = f"For example, {list_type}.\n\n"
                error_fn(
                    f"Parameter {name} has unsupported type {param.annotation}. "
                    f"We do not support Tuple inputs in schema. As a workaround, please try to use List instead. "
                    f"{example_type_str}"
                    f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
                )
            else:
                error_fn(
                    f"Parameter {name} has unsupported type {param.annotation}. "
                    f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
                )

        schema_type = SUPPORTED_PARAM_TYPES[annotation_type]
        if mutated_names is None:
            # "unknown": mark every Tensor-like input as mutable, each with a
            # distinct alias annotation derived from its position.
            if schema_type.startswith("Tensor"):
                schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
        elif name in mutated_names:
            if not schema_type.startswith("Tensor"):
                error_fn(
                    f"Parameter {name} is in mutable_args but only Tensors or collections of Tensors can be mutated"
                )
            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor'):]}"
        seen_args.add(name)

        if param.default is inspect.Parameter.empty:
            params.append(f"{schema_type} {name}")
        else:
            default_repr = None
            if param.default is None or isinstance(param.default, (int, float, bool)):
                default_repr = str(param.default)
            elif isinstance(param.default, (str, torch.device)):
                default_repr = f'"{param.default}"'
            elif isinstance(param.default, torch.dtype):
                # str(torch.float32) == "torch.float32"; the schema wants "float32".
                dtype_repr = str(param.default)
                torch_dot = "torch."
                assert dtype_repr.startswith(torch_dot)
                default_repr = dtype_repr[len(torch_dot) :]
            else:
                error_fn(
                    f"Parameter {name} has an unsupported default value type {type(param.default)}. "
                    f"Please file an issue on GitHub so we can prioritize this."
                )
            params.append(f"{schema_type} {name}={default_repr}")

    if mutated_names is not None:
        mutates_args_not_seen = mutated_names - seen_args
        if mutates_args_not_seen:
            error_fn(
                f"{mutates_args_not_seen} in mutates_args were not found in "
                f"the custom op's signature. "
                f"mutates_args should contain the names of all args that the "
                f"custom op mutates, or just the string 'unknown' if you don't know."
            )

    return_annotation = sig.return_annotation
    if isinstance(return_annotation, str):
        return_annotation = convert_type_string(return_annotation)
    ret = parse_return(return_annotation, error_fn)
    if op_name is not None:
        return f"{op_name}({', '.join(params)}) -> {ret}"
    return f"({', '.join(params)}) -> {ret}"
162
163
def derived_types(
    base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
    """Expand one base type into ``(python annotation, schema type string)`` pairs.

    Two pairs are always produced: ``base_type -> cpp_type`` and
    ``Optional[base_type] -> cpp_type?``. Each flag then adds two more pairs, one for
    the ``Sequence`` spelling and one for the ``List`` spelling, in this order:

    * ``list_base``:          ``Sequence[base_type]``           -> ``cpp_type[]``
    * ``optional_base_list``: ``Sequence[Optional[base_type]]`` -> ``cpp_type?[]``
    * ``optional_list_base``: ``Optional[Sequence[base_type]]`` -> ``cpp_type[]?``

    Returns:
        A list of the pairs, in the order described above.
    """

    def as_sequences(elem):
        # Both container spellings map to the same schema type.
        return (
            typing.Sequence[elem],  # type: ignore[valid-type]
            typing.List[elem],  # type: ignore[valid-type]
        )

    optional_base = typing.Optional[base_type]
    pairs = [(base_type, cpp_type), (optional_base, f"{cpp_type}?")]

    if list_base:
        pairs.extend((seq, f"{cpp_type}[]") for seq in as_sequences(base_type))
    if optional_base_list:
        pairs.extend((seq, f"{cpp_type}?[]") for seq in as_sequences(optional_base))
    if optional_list_base:
        pairs.extend(
            (typing.Optional[seq], f"{cpp_type}[]?")  # type: ignore[valid-type]
            for seq in as_sequences(base_type)
        )
    return pairs
188
189
def get_supported_param_types():
    """Build the mapping from accepted parameter annotations to schema type strings.

    Each row of the table below names a base Python type and its schema spelling.
    It also flags which container variants (``T[]``, ``T?[]``, ``T[]?``) are
    accepted. ``derived_types`` expands each row into concrete annotation/schema
    pairs, and all rows are merged into a single dict.
    """
    table = (
        # (python type, schema type, type[] variant, type?[] variant, type[]? variant)
        (Tensor, "Tensor", True, True, False),
        (int, "SymInt", True, False, True),
        (float, "float", True, False, True),
        (bool, "bool", True, False, True),
        (str, "str", False, False, False),
        (types.Number, "Scalar", True, False, False),
        (dtype, "ScalarType", False, False, False),
        (device, "Device", False, False, False),
    )
    return {
        py_type: schema_type
        for row in table
        for py_type, schema_type in derived_types(*row)
    }
206
207
# Return annotations accepted by parse_return, mapped to their schema type strings.
# Unlike the parameter table, no Optional or Sequence variants are derived here.
# parse_return handles Tuple[...] returns separately, by looking up each element
# in this table. Insertion order is user-visible, because the dict is formatted
# into parse_return's "valid types" error message.
SUPPORTED_RETURN_TYPES = {
    Tensor: "Tensor",
    typing.List[Tensor]: "Tensor[]",
    int: "SymInt",
    float: "float",
    bool: "bool",
    types.Number: "Scalar",
}
216
217
def parse_return(annotation, error_fn):
    """Translate a return annotation into the schema's return-type string.

    ``None`` maps to ``"()"`` (no outputs). A tuple annotation maps to a
    parenthesized, comma-separated list of its element types. Any other
    annotation must be a key of ``SUPPORTED_RETURN_TYPES``.

    Args:
        annotation: The (already de-stringified) return annotation.
        error_fn: Called with a message when the annotation is missing or
            unsupported; callers pass a function that raises.

    Returns:
        The schema return-type string.
    """
    if annotation is None:
        return "()"

    if annotation is inspect.Parameter.empty:
        error_fn("No return type annotation was provided. Please add one.")

    def report_unsupported():
        error_fn(
            f"Return has unsupported type {annotation}. "
            f"The valid types are: {SUPPORTED_RETURN_TYPES}."
        )

    if typing.get_origin(annotation) is tuple:
        elements = typing.get_args(annotation)
        for elem in elements:
            if elem not in SUPPORTED_RETURN_TYPES:
                report_unsupported()
        rendered = ", ".join([SUPPORTED_RETURN_TYPES[elem] for elem in elements])
        return f"({rendered})"

    if annotation not in SUPPORTED_RETURN_TYPES:
        report_unsupported()
    return SUPPORTED_RETURN_TYPES[annotation]
243
244
# Parameter annotations accepted by infer_schema, mapped to schema type strings.
# Built once at import time from the table in get_supported_param_types.
SUPPORTED_PARAM_TYPES = get_supported_param_types()
246
247
def supported_param(param: inspect.Parameter) -> bool:
    """Return whether *param*'s kind can be expressed in an operator schema.

    Only ordinary (positional-or-keyword) and keyword-only parameters qualify.
    Positional-only parameters, ``*args`` and ``**kwargs`` do not.
    """
    kind = param.kind
    return (
        kind is inspect.Parameter.POSITIONAL_OR_KEYWORD
        or kind is inspect.Parameter.KEYWORD_ONLY
    )
253
254
def tuple_to_list(tuple_type: typing.Type[typing.Tuple]) -> typing.Type[typing.List]:
    """
    Convert `tuple_type` into a list type with the same element types.

    Assumes that `tuple_type` is a tuple annotation (``typing.Tuple`` or ``tuple``,
    optionally subscripted). Mapping:

    * bare or empty tuple (``Tuple``, ``tuple``, ``Tuple[()]``) -> ``typing.List``
    * ``Tuple[T]`` and ``Tuple[T, ...]``                          -> ``List[T]``
    * ``Tuple[T1, T2, ...]`` (heterogeneous)                      -> ``List[Union[T1, T2, ...]]``
    """
    type_args = getattr(tuple_type, "__args__", None)
    # Account for different python versions. A bare tuple type may report
    # __args__ as () or not define it at all (None here). The empty tuple type
    # Tuple[()] reports () on Python 3.11+ but ((),) on older versions. Treating
    # ((),) as a one-element tuple would make typing.List[()] raise TypeError.
    if (
        tuple_type is typing.Tuple
        or type_args is None
        or type_args == ()
        or type_args == ((),)
    ):
        return typing.List
    if len(type_args) == 1 or (len(type_args) == 2 and type_args[1] is Ellipsis):
        # Homogeneous tuple: Tuple[T] or Tuple[T, ...].
        return typing.List[type_args[0]]  # type: ignore[valid-type]
    # Heterogeneous tuple: the list element type is the union of all elements.
    return typing.List[typing.Union[tuple(type_args)]]  # type: ignore[misc]
272