# mypy: allow-untyped-defs
import torch
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload

if TYPE_CHECKING:
    from .node import Argument

__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
           "type_matches", "normalize_function", "normalize_module"]

@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Simple named tuple for wrapping args/kwargs pairs.
    """
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]

_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}

def _nonzero_schemas():
    signatures = []

    def nonzero(self):
        pass
    signatures.append(inspect.signature(nonzero))

    def nonzero(self, *, as_tuple : bool):  # type: ignore[no-redef]
        pass
    signatures.append(inspect.signature(nonzero))

    return signatures

_manual_overrides[torch.nonzero] = _nonzero_schemas()

class _FakeGlobalNamespace:
    def __getattr__(self, name):
        if name == 'torch':
            return torch
        raise RuntimeError('Expected a torch namespace lookup')

_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
                      'number' : numbers.Number, 'Future' : torch.jit.Future,
                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
                      'Storage': torch.UntypedStorage,
                      't': typing.TypeVar('t')}
for k in dir(typing):
    _type_eval_globals[k] = getattr(typing, k)

def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
    """
    Convert a TorchScript type to a Python type (including subtypes) by
    eval'ing the annotation_str. _type_eval_globals sets up names
    like "List" and "Future" to map to the actual types (typing.List and jit.Future).
    """
    return eval(ts_type.annotation_str, _type_eval_globals)

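# Illustrative sketch (never called by the library): the same evaluation rule can be
# applied to a raw annotation string to see what `_type_eval_globals` resolves it to.
# The strings below are hypothetical examples, not schemas read from any real op.
def _example_type_eval() -> None:
    assert eval("List[int]", _type_eval_globals) == List[int]
    assert eval("Optional[Tensor]", _type_eval_globals) == Optional[torch.Tensor]
    assert eval("NoneType", _type_eval_globals) is type(None)
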
def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    from inspect import Parameter
    parameters : List[Parameter] = []
    for arg in ts_schema.arguments:
        arg_type = _torchscript_type_to_python_type(arg.type)
        default = arg.default_value if arg.has_default_value() else Parameter.empty
        # TODO: Figure out if this is safe. It seems like when generating the type signatures for
        # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
        # argument name. Downstream, if someone converts that positional argument to a keyword
        # argument, the name mismatch will break things, so here we're going to normalize the
        # name to "input"
        name = arg.name if arg.name != 'self' else 'input'
        kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
        # "from" is a Python keyword, therefore it must be a POSITIONAL_ONLY argument
        if name == "from":
            assert kind == Parameter.POSITIONAL_OR_KEYWORD
            # The ParameterKind type is an internal implementation detail of the inspect
            # package, which makes it hard to annotate here
            kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
            # This makes all previous arguments positional-only as well
            for idx, p in enumerate(parameters):
                assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
                parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation)
        parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
    return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
    if len(return_types) == 0:
        return_type = None
    elif len(return_types) == 1:
        return_type = return_types[0]
    else:
        return_type = tuple(return_types)

    return inspect.Signature(parameters, return_annotation=return_type)

_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {}

def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    # Cached as it's called in the hot path of FakeTensor dispatch
    cache_key = ts_schema.name, ts_schema.overload_name
    cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
    if cache_val is not None:
        return cache_val

    res = _torchscript_schema_to_signature_impl(ts_schema)
    _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
    return res

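# Illustrative sketch (not invoked by this module): what the conversion produces for
# one concrete schema. The exact overload set and the rendered string depend on the
# PyTorch build, so the expected output in the comment is an assumption.
def _example_schema_to_signature() -> None:
    schema = torch.ops.aten.add.Tensor._schema
    sig = _torchscript_schema_to_signature(schema)
    # Typically renders as (input, other, *, alpha=1), with `self` renamed to `input`.
    print(sig)
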
@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)

    if signatures and schemas:
        matched_schemas = []

        # Iterate through all of the schemas and collect the ones whose signatures
        # the given args/kwargs bind to. Mutability is checked only when exactly one
        # schema matches.
        for candidate_signature, schema in zip(signatures, schemas):
            try:
                candidate_signature.bind(*args, **kwargs)
                matched_schemas.append((candidate_signature, schema))
            except TypeError:
                continue

        def throw_if_mutable(schema):
            if schema.is_mutable:
                raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
                                   f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
                                   f'are not supported')

        if len(matched_schemas) == 0:
            # Did not match any schema. Cannot check for mutation
            pass
        elif len(matched_schemas) == 1:
            # Matched exactly one schema, unambiguous
            _, schema_to_check = matched_schemas[0]
            throw_if_mutable(schema_to_check)
        else:
            # Ambiguous schema match. Since mutability checking is best effort,
            # do nothing.
            pass

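# Illustrative sketch (never called here, names chosen for the example): a purely
# functional call either matches a non-mutable schema or matches ambiguously, so
# check_for_mutable_operation lets it pass; only an unambiguous match against a
# mutable schema (e.g. an `out=` variant) raises RuntimeError.
def _example_check_for_mutable_operation() -> None:
    x = torch.randn(3)
    check_for_mutable_operation(torch.add, (x, x), {})
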
@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
            operator, or None if the operator signatures could not be retrieved. If
            return_schemas=True, returns a tuple containing the optional list of Python
            signatures and the optional list of TorchScript `FunctionSchema` objects.
    """
    if isinstance(op, OpOverload):
        schemas = [op._schema]
    elif isinstance(op, OpOverloadPacket):
        schemas = [getattr(op, overload)._schema for overload in op.overloads()]
    else:
        override = _manual_overrides.get(op)
        if override:
            return (override, None) if return_schemas else None

        aten_fn = torch.jit._builtins._find_builtin(op)

        if aten_fn is None:
            return (None, None) if return_schemas else None
        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)

    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
    return (signatures, schemas) if return_schemas else signatures

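# Illustrative usage sketch (not part of the module's behavior): fetching the overload
# signatures for a builtin op. `torch.add` is just an example target; the number of
# overloads returned depends on the build.
def _example_get_signature_for_torch_op() -> None:
    sigs = get_signature_for_torch_op(torch.add)
    if sigs is not None:
        for sig in sigs:
            print(sig)  # one inspect.Signature per aten::add overload
    # With return_schemas=True, the matching FunctionSchemas come back as well.
    sigs, schemas = get_signature_for_torch_op(torch.add, return_schemas=True)
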
@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Produces a type hint for the given argument.

    :func:`create_type_hint` looks for a type hint compatible with the input argument `x`.

    If `x` is a `list` or `tuple` of types, it looks for an element that is a superclass
    of all of the others and uses that as the `base_type` of the `List` or `Tuple` hint
    to be returned. If no such element is found, it defaults to `List[Any]`
    (or `Tuple[Any, ...]` for tuples).

    If `x` is neither a `list` nor a `tuple`, it returns `x` unchanged.
    """
    try:
        if isinstance(x, (list, tuple)):
            # todo(chilli): Figure out the right way for mypy to handle this
            if isinstance(x, list):
                def ret_type(x):
                    return List[x]  # type: ignore[valid-type]
            else:
                def ret_type(x):
                    return Tuple[x, ...]
            if len(x) == 0:
                return ret_type(Any)
            base_type = x[0]
            for t in x:
                if issubclass(t, base_type):
                    continue
                elif issubclass(base_type, t):
                    base_type = t
                else:
                    return ret_type(Any)
            return ret_type(base_type)
    except Exception:
        # We tried to create a type hint for list but failed.
        warnings.warn(f"We were not able to successfully create type hint from the type {x}")
    return x

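# Illustrative sketch (never called by the library): create_type_hint collapses a
# container of types into a parametrized hint.
def _example_create_type_hint() -> None:
    assert create_type_hint([int, bool]) == List[int]        # bool is a subclass of int
    assert create_type_hint((int, str)) == Tuple[Any, ...]   # no common superclass -> Any
    assert create_type_hint([]) == List[Any]                 # empty container -> Any
    assert create_type_hint(torch.Tensor) is torch.Tensor    # non-container passes through
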
@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
    sig_origin_type = getattr(signature_type, '__origin__', signature_type)

    if signature_type is argument_type:
        return True

    # Union types in signature. Given type needs to match one of the
    # contained types in the Union
    if sig_origin_type is typing.Union and signature_type != argument_type:
        sig_contained = signature_type.__args__
        return any(type_matches(c, argument_type) for c in sig_contained)

    if signature_type is List[int] and argument_type is int:
        # int can be promoted to List[int]
        return True

    if getattr(signature_type, '__origin__', None) in {list, List}:
        sig_el_type = signature_type.__args__[0]
        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
            return False
        if getattr(argument_type, '__origin__', None) in {list, List}:
            return issubclass(argument_type.__args__[0], sig_el_type)

        def is_homogeneous_tuple(t):
            if getattr(t, "__origin__", None) not in {tuple, Tuple}:
                return False
            contained = t.__args__
            if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
                return True
            return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)

        # Tuple[T] is accepted for List[T] parameters
        return is_homogeneous_tuple(argument_type)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True
    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
        return issubclass(argument_type, signature_type)

    return False

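# Illustrative sketch (never called here): a few of the promotions type_matches
# accepts when disambiguating overloads.
def _example_type_matches() -> None:
    assert type_matches(List[int], int)           # int is promoted to List[int]
    assert type_matches(Optional[int], int)       # Union/Optional matches any member
    assert type_matches(int, torch.dtype)         # dtypes are represented as ints in schemas
    assert type_matches(numbers.Number, float)    # Number accepts int and float
    assert not type_matches(List[int], List[str])
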
@compatibility(is_backward_compatible=False)
def normalize_function(
        target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the function's signature, and if
    `normalize_to_only_use_kwargs` is True the returned ArgsKwargsPair will
    contain exclusively kwargs (in the signature's positional order).
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs). Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.
    """
    if kwargs is None:
        kwargs = {}
    new_args_and_kwargs = None
    if not isinstance(target, types.BuiltinFunctionType) and not (
        isinstance(target, (OpOverloadPacket, OpOverload))
    ):
        target_for_analysis = target
        if target in boolean_dispatched:
            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
            # branches of the dispatch have exactly the same signature. If they do, use the `true`
            # branch signature for analysis. Otherwise, leave this un-normalized
            assert not isinstance(target, str)
            dispatched = boolean_dispatched[target]
            if_true, if_false = dispatched['if_true'], dispatched['if_false']
            if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
                return None
            target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
    else:
        assert callable(target)
        torch_op_schemas = get_signature_for_torch_op(target)
        matched_schemas = []
        if torch_op_schemas:
            # Iterate through all of the schemas until we find one that matches.
            # If one matches, populate `new_args_and_kwargs` with the new args/kwargs
            # values. If none matches, `new_args_and_kwargs` will remain None
            for candidate_signature in torch_op_schemas:
                try:
                    candidate_signature.bind(*args, **kwargs)
                    matched_schemas.append(candidate_signature)
                except TypeError:
                    continue

            if len(matched_schemas) == 0:
                # Did not match any schema. Cannot normalize
                pass
            elif len(matched_schemas) == 1:
                # Matched exactly one schema, unambiguous
                new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
                                                                             normalize_to_only_use_kwargs)
            else:
                if arg_types is not None or kwarg_types is not None:
                    arg_types = arg_types if arg_types else cast(Tuple[Any], ())
                    kwarg_types = kwarg_types if kwarg_types else {}
                    for candidate_signature in torch_op_schemas:
                        sig_matches = True
                        try:
                            bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
                            for arg_name, arg_type in bound_types.arguments.items():
                                param = candidate_signature.parameters[arg_name]
                                sig_matches = sig_matches and type_matches(param.annotation, arg_type)
                        except TypeError:
                            sig_matches = False
                        if sig_matches:
                            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
                                                                                         normalize_to_only_use_kwargs)
                            break
                else:
                    # Matched more than one schema. In this situation, the caller must provide the types of
                    # the arguments of the overload they expect.
                    schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
                    raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
                                       f'the schema match was ambiguous! Please provide argument types to '
                                       f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')

    return new_args_and_kwargs

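# Illustrative usage sketch (never executed by this module): normalizing a call to a
# kwargs-only form. The tensors and the choice of torch.add are just examples; the
# `arg_types` tuple is needed here because several add overloads bind the same args.
def _example_normalize_function() -> None:
    x, y = torch.randn(3), torch.randn(3)
    pair = normalize_function(
        torch.add, (x, y),
        arg_types=(torch.Tensor, torch.Tensor),  # disambiguates the overload
        normalize_to_only_use_kwargs=True,
    )
    if pair is not None:
        # Expected to contain only kwargs, e.g. keys like 'input', 'other', 'alpha'.
        print(pair.kwargs.keys())
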
@compatibility(is_backward_compatible=False)
def normalize_module(
        root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the module's `forward()` signature, and if
    `normalize_to_only_use_kwargs` is True the returned ArgsKwargsPair will
    contain exclusively kwargs (in the signature's positional order).
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Args:
        root (nn.Module): root module upon which we query modules
        target (str): Qualified name of the submodule whose call we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
                           f"have that target!") from e
    if hasattr(submod.__class__, '__name__'):
        classname = submod.__class__.__name__
        if getattr(torch.nn, classname, None) == submod.__class__:
            sig = inspect.signature(inspect.unwrap(submod.forward))
            if kwargs is None:
                kwargs = {}
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
                                                                         normalize_to_only_use_kwargs)
            return new_args_and_kwargs
    return None

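# Illustrative usage sketch (never executed here): normalization only applies to
# submodules whose class lives on the torch.nn namespace, such as nn.Linear below.
# The wrapper module and tensor shapes are arbitrary choices for the example.
def _example_normalize_module() -> None:
    class _Wrapper(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.lin = torch.nn.Linear(4, 4)

    root = _Wrapper()
    x = torch.randn(2, 4)
    pair = normalize_module(root, 'lin', (x,), normalize_to_only_use_kwargs=True)
    if pair is not None:
        print(pair.kwargs.keys())  # expected: just 'input' for nn.Linear.forward
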
def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
                                           kwargs : Dict[str, Any],
                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
    """
    Given a call target, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by
    this normalization.

    Args:

        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            this target is not supported.
    """

    # Don't currently support positional-only
    # or varargs (*args, **kwargs) signatures
    supported_parameter_types = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
    if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
        # Add an exception for one signature, which is common for random/uniform, i.e.:
        # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
        # `from` is a Python keyword, so functions with that signature end up with
        # positional-only args, but at the same time they could be dispatched as kwargs
        if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
            return None

    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    new_kwargs : Dict[str, Any] = {}
    new_args : List[Any] = []
    for i, param in enumerate(sig.parameters):
        if not normalize_to_only_use_kwargs and i < len(args):
            new_args.append(bound_args.arguments[param])
        else:
            new_kwargs[param] = bound_args.arguments[param]

    return ArgsKwargsPair(tuple(new_args), new_kwargs)