# mypy: allow-untyped-defs
# Owner(s): ["module: onnx"]
from __future__ import annotations

import abc
import dataclasses
import inspect
import logging
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING

import torch
import torch._ops
import torch.fx
import torch.fx.traceback as fx_traceback
from torch import _prims_common, _refs
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    wrappers as _prims_common_wrappers,
)
from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
from torch._refs.nn import functional as _functional_refs
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree


if TYPE_CHECKING:
    from types import ModuleType


logger = logging.getLogger(__name__)

# TODO(bowbao): move to type utils.
_SCALAR_TYPE_TENSOR_DTYPE_MAP: Mapping[type, torch.dtype] = {
    bool: torch.bool,
    int: torch.int64,
    float: torch.float32,
    complex: torch.complex32,
}


def _try_getclosurevars(func):
    try:
        return inspect.getclosurevars(func)
    except TypeError:
        return None


@dataclasses.dataclass
class TypePromotionSnapshot:
    """Type promotion snapshot for a fx node and its inputs.

    Contains the promoted dtypes for the args and kwargs that need promoting.
    Contains the expected node output dtype.
    """

    args_dtypes: Mapping[int, torch.dtype]
    """Mapping from arg position to dtype to promote to."""

    kwargs_dtypes: Mapping[str, torch.dtype]
    """Mapping from kwarg name to dtype to promote to."""

    out_dtype: torch.dtype
    """Expected output dtype of the node."""


class TypePromotionRule(abc.ABC):
    """Base class for type promotion rule per 'torch.ops.{namespace}.{op_name}'."""

    def __init__(self, namespace: str, op_name: str):
        self.namespace = namespace
        self.op_name = op_name

    # Make this abstract as well because subclass needs to override __eq__().
    # A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
    # Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__
    @abc.abstractmethod
    def __hash__(self) -> int: ...

    @abc.abstractmethod
    def __repr__(self): ...

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool: ...

    def is_valid(self) -> bool:
        """Check if the rule is valid."""
        # This always returns a module. If the module does not exist it will be created.
        module = getattr(torch.ops, self.namespace)
        py_op = getattr(module, self.op_name, None)
        if py_op is None:
            logger.warning(
                "Cannot find op: %s in module: %s", self.op_name, self.namespace
            )
            return False
        if not isinstance(py_op, torch._ops.OpOverloadPacket):
            logger.warning(
                "Op: torch.ops.%s.%s is not an OpOverloadPacket, got: %s",
                self.namespace,
                self.op_name,
                type(py_op),
            )
            return False

        return True

    @abc.abstractmethod
    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        """Preview type promotion results for provided set of args and kwargs.

        Returns a TypePromotionSnapshot object that contains the promoted dtypes for
        the arguments and the expected output dtype.
        """
        ...


class ElementwiseTypePromotionRule(TypePromotionRule):
    """Defines how to perform elementwise type promotion for 'torch.ops.{namespace}.{op_name}'."""

    _USE_OPMATH: bool = False
    """Whether to use opmath to compute the promoted input dtype.

    If enabled, upcasts are inserted throughout the graph for lower-precision models.
    This is kept False so that torchlib handles upcasts internally in its op
    implementations.
    """

    def __init__(
        self,
        namespace: str,
        op_name: str,
        promote_args_positions: Sequence[int],
        promote_kwargs_names: Sequence[str],
        promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND,
    ):
        """Constructs a TypePromotionRule for elementwise operators.

        Args:
            namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.add'.
            op_name: Name of the op. E.g. 'add' in 'torch.ops.aten.add'.
            promote_args_positions: Positions of args to promote.
            promote_kwargs_names: Names of kwargs to promote.
            promotion_kind: Type promotion kind. Refer to [_prims_common.elementwise_dtypes](https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py) for detail.  # noqa: B950
        """
        super().__init__(namespace, op_name)
        self.promote_args_positions = promote_args_positions
        self.promote_kwargs_names = promote_kwargs_names
        self.promotion_kind = promotion_kind

    def __repr__(self):
        return (
            f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', "
            f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
        )

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ElementwiseTypePromotionRule):
            return False
        return (
            self.namespace == __value.namespace
            and self.op_name == __value.op_name
            and self.promote_args_positions == __value.promote_args_positions
            and self.promote_kwargs_names == __value.promote_kwargs_names
            and self.promotion_kind == __value.promotion_kind
        )

    def __hash__(self) -> int:
        return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()

    def _consolidate_input_dtype(
        self, computed_dtype: torch.dtype, result_dtype: torch.dtype
    ) -> torch.dtype:
        """
        Although opmath is the right thing to do to retain on-par precision, it inserts
        upcasts everywhere in the graph. This is particularly hard for the backend to
        optimize since there is no way to differentiate between inserted upcasts and
        casts written in the model code. Hence we consolidate the input dtype to the
        result dtype to avoid this.
        """
        if not self._USE_OPMATH and self.promotion_kind in (
            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        ):
            return result_dtype
        return computed_dtype

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        candidate_args = {
            i: args[i]
            for i in self.promote_args_positions
            if i < len(args) and args[i] is not None
        }
        candidate_kwargs = {
            name: kwargs[name]
            for name in self.promote_kwargs_names
            if name in kwargs and kwargs[name] is not None
        }

        computed_dtype, result_dtype = _prims_common.elementwise_dtypes(
            *_pytree.arg_tree_leaves(*candidate_args.values(), **candidate_kwargs),
            type_promotion_kind=self.promotion_kind,
        )

        consolidated_input_dtype = self._consolidate_input_dtype(
            computed_dtype, result_dtype
        )

        return TypePromotionSnapshot(
            dict.fromkeys(candidate_args.keys(), consolidated_input_dtype),
            dict.fromkeys(candidate_kwargs.keys(), consolidated_input_dtype),
            result_dtype,
        )

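# Example (illustrative sketch for ElementwiseTypePromotionRule; exact dtypes
# assume torch's default type promotion): previewing promotion for `aten.add`
# with mixed float16/float32 inputs. With `_USE_OPMATH` off and the DEFAULT
# promotion kind, the inputs and the output are all consolidated to float32.
#
#   >>> rule = ElementwiseTypePromotionRule(
#   ...     "aten", "add", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
#   ... )
#   >>> snapshot = rule.preview_type_promotion(
#   ...     (torch.rand(2, dtype=torch.float16), torch.rand(2, dtype=torch.float32)), {}
#   ... )
#   >>> snapshot.args_dtypes
#   {0: torch.float32, 1: torch.float32}
#   >>> snapshot.out_dtype
#   torch.float32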

class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
    """Reference type promotion rule from torch._refs.div.

    Rule depends on the value of the `rounding_mode` argument.
    """

    def __init__(self):
        super().__init__(
            "aten",
            "div",
            promote_args_positions=(0, 1),
            promote_kwargs_names=(),
            promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        )

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        rounding_mode = kwargs.get("rounding_mode", None)
        if rounding_mode is None:
            # true_divide
            self.promotion_kind = (
                _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
            )
            return super().preview_type_promotion(args, kwargs)
        if rounding_mode == "trunc":
            # trunc_divide
            self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            return super().preview_type_promotion(args, kwargs)
        if rounding_mode == "floor":
            # floor_divide
            self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            return super().preview_type_promotion(args, kwargs)
        raise ValueError(f"Unknown rounding_mode: {rounding_mode}")

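# Example (illustrative sketch; assumes torch's default dtype is float32):
# `aten.div` promotes integer inputs to float when no `rounding_mode` is given
# (true division), but keeps the integer dtype when `rounding_mode="floor"`.
#
#   >>> rule = DivElementwiseTypePromotionRule()
#   >>> int_args = (torch.tensor([4, 6]), torch.tensor([2, 3]))
#   >>> rule.preview_type_promotion(int_args, {}).out_dtype
#   torch.float32
#   >>> rule.preview_type_promotion(int_args, {"rounding_mode": "floor"}).out_dtype
#   torch.int64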

class ReductionTypePromotionRule(TypePromotionRule):
    def __init__(
        self,
        namespace: str,
        op_name: str,
        promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND,
    ):
        """Constructs a TypePromotionRule for reduction operators.

        Args:
            namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.sum'.
            op_name: Name of the op. E.g. 'sum' in 'torch.ops.aten.sum'.
            promotion_kind: Type promotion kind. Refer to [_prims_common.reduction_dtypes](https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py) for detail.  # noqa: B950
        """
        super().__init__(namespace, op_name)
        self.promotion_kind = promotion_kind

    def __repr__(self):
        return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ReductionTypePromotionRule):
            return False
        return (
            self.namespace == __value.namespace
            and self.op_name == __value.op_name
            and self.promotion_kind == __value.promotion_kind
        )

    def __hash__(self) -> int:
        return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        assert (
            len(args) >= 1
        ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
        arg = args[0]
        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
        dtype: torch.dtype | None = kwargs.get("dtype", None)

        computation_dtype, result_dtype = _prims_common.reduction_dtypes(
            arg, self.promotion_kind, dtype
        )
        if result_dtype is None:
            # Inspecting code, this can only happen when `promotion_kind` is `KEEP_PROMOTED_TYPE`.
            # Hence set same as computation_dtype.
            result_dtype = computation_dtype

        return TypePromotionSnapshot(
            {0: computation_dtype},
            {},
            result_dtype,
        )


class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
    """Reference type promotion rule from torch.ops.aten.all or torch.ops.aten.any.

    This is a special case where computation dtype is always torch.bool.
    The result dtype is torch.uint8 if the input tensor's dtype is uint8, otherwise torch.bool.
    """

    def __init__(self, op_name: str):
        super().__init__(
            "aten",
            op_name,
            _prims_common.REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL,
        )

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        assert (
            len(args) >= 1
        ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
        arg = args[0]
        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
        computation_dtype = torch.bool
        # Preserves uint8 -- probably a legacy mask thing
        result_dtype = torch.uint8 if arg.dtype == torch.uint8 else torch.bool
        return TypePromotionSnapshot(
            {0: computation_dtype},
            {},
            result_dtype,
        )

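# Example (illustrative sketch): `aten.any` always computes in torch.bool; the
# result stays uint8 only when the input tensor itself is uint8.
#
#   >>> rule = AllOrAnyReductionTypePromotionRule("any")
#   >>> rule.preview_type_promotion((torch.zeros(3, dtype=torch.uint8),), {}).out_dtype
#   torch.uint8
#   >>> rule.preview_type_promotion((torch.zeros(3, dtype=torch.int32),), {}).out_dtype
#   torch.bool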

class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
    """Reference type promotion rule from torch.ops.aten.sum.

    This is a special case where computation dtype is always torch.int64 for integral arg,
    unless overridden by `dtype` kwarg.
    """

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        assert (
            len(args) >= 1
        ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
        arg = args[0]
        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
        dtype: torch.dtype | None = kwargs.get("dtype", None)
        # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
        if dtype is None:
            if _prims_common.is_boolean_dtype(
                arg.dtype
            ) or _prims_common.is_integer_dtype(arg.dtype):
                dtype = torch.int64
            else:
                dtype = arg.dtype
        return super().preview_type_promotion(args, {"dtype": dtype})

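# Example (illustrative sketch): sum-like reductions compute in torch.int64 for
# bool/integer inputs unless an explicit `dtype` kwarg overrides it.
#
#   >>> rule = SumLikeReductionTypePromotionRule(
#   ...     "aten", "sum", promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME
#   ... )
#   >>> rule.preview_type_promotion((torch.ones(3, dtype=torch.int32),), {}).out_dtype
#   torch.int64
#   >>> rule.preview_type_promotion(
#   ...     (torch.ones(3, dtype=torch.int32),), {"dtype": torch.float32}
#   ... ).out_dtype
#   torch.float32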

# NOTE: [Update type promotion rule]
# BELOW TABLE IS GENERATED FROM `TypePromotionRuleSetGenerator.generate_from_torch_refs`.
# DO NOT EDIT MANUALLY !!!
# For missing rules or discrepancies, please
# 1. Run `pytest test/onnx/test_fx_type_promotion.py` to validate if the generated rule set is current.
#    If it is not, update with new generated set.
# 2. If discrepancies still exist, consider debugging torch._refs or report a bug.
# 3. If rules are still missing, add them to `_EXTRA_TYPE_PROMOTION_RULE_SET` or report a bug.
# Check `TypePromotionRule` class for how each rule is defined and used.
_GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = {
    ElementwiseTypePromotionRule(
        "aten", "abs", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "abs_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "add", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "add_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcdiv", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcdiv_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcmul", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcmul_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addr", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan2_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_left_shift",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_left_shift_",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_right_shift",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_right_shift_",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cat", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "cauchy", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cauchy_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ceil", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ceil_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "celu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "celu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "clamp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "clamp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "copysign", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "copysign_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "deg2rad", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "deg2rad_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "digamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "elu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "eq", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "eq_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "erf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erf_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfinv", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfinv_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "expm1", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "expm1_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exponential", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exponential_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fill", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmax", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmin", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmod", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmod_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "frac", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "frac_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "gcd", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "gcd_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ge", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "ge_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "gelu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "geometric", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "geometric_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "glu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "gt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "gt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "hardtanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "heaviside", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "heaviside_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "huber_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "hypot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "hypot_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "i0", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "i0_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igamma", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igamma_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igammac", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igammac_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "isfinite", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isnan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isneginf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isposinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isreal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "l1_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lcm", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lcm_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "le", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "le_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "leaky_relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lerp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lerp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lgamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lgamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log10", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log10_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log1p", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log1p_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log_normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log_normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logaddexp", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logaddexp2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logsumexp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "lt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "maximum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "minimum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mish", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mish_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mse_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mul", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mul_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ne", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "ne_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "neg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "neg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "nextafter", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "nextafter_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "poisson_nll_loss",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "pow", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "pow_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rad2deg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rad2deg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "reciprocal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "reciprocal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "remainder", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "remainder_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "round", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rsqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rsub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "selu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sgn", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sgn_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sigmoid", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sigmoid_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sign", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sign_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "signbit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "sin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "smooth_l1_loss",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "softplus", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "square", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "square_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "sub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sub_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "threshold", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "threshold_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "true_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "true_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "trunc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "xlogy", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "xlogy_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
}

# Manually curated extra type promotion rules. Please see NOTE [Update type promotion rule]
# before adding new rules.
_EXTRA_TYPE_PROMOTION_RULE_SET = {
    # torch._refs skips type promotion decoration for `clamp_min` and `clamp_max` since
    # the call is routed to the decorated `aten.clamp` op.
    ElementwiseTypePromotionRule(
        "aten",
        "clamp_max",
        promote_args_positions=(0, 1),
        promote_kwargs_names=(),
        promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "clamp_min",
        promote_args_positions=(0, 1),
        promote_kwargs_names=(),
        promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    # torch.ops.aten.div.Tensor_mode applies different type promotion rules
    # depending on the value of the `mode` argument.
    DivElementwiseTypePromotionRule(),
    # Manually curating reduction ops since the logic is written inside the op reference
    # implementation.
    AllOrAnyReductionTypePromotionRule("all"),
    AllOrAnyReductionTypePromotionRule("any"),
    ReductionTypePromotionRule(
        "aten",
        "amax",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    ReductionTypePromotionRule(
        "aten",
        "amin",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    # torch.ops.aten.mean is a special case that does not need type promotion.
    ReductionTypePromotionRule(
        "aten",
        "std",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    ),
    ReductionTypePromotionRule(
        "aten",
        "std_mean",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    ),
    ReductionTypePromotionRule(
        "aten",
        "var",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "cumprod",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "cumsum",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "prod",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "sum",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
}


class ElementwiseTypePromotionRuleSetGenerator:
    """Hackily distills info from reference ops decorated with the elementwise type promotion wrapper.

    The goal is to retrieve the decorator

    ```python
        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=type_promotion_kind,
        )
    ```

    from the reference ops. It provides info on which arguments are promoted
    and what kind of promotion is applied.
    """

    @classmethod
    def generate_from_torch_refs(cls) -> set[ElementwiseTypePromotionRule]:
        """Parse type promotion rules from reference ops under torch._refs."""
        rule_set = set()
        rule_set.update(cls._parse_torch_refs(_refs))
        rule_set.update(cls._parse_torch_refs(_nn_refs))
        rule_set.update(cls._parse_torch_refs(_linalg_refs))
        rule_set.update(cls._parse_torch_refs(_special_refs))
        rule_set.update(cls._parse_torch_refs(_functional_refs))
        return rule_set

    @classmethod
    def _parse_torch_refs(
        cls, ref_module: ModuleType
    ) -> set[ElementwiseTypePromotionRule]:
        logger.info("Processing module: %s", ref_module.__name__)
        rule_set = set()
        for name in ref_module.__all__:
            decorated_op = getattr(ref_module, name)
            rule = cls._parse_type_promotion_rule_from_refs_op(decorated_op)
            if rule is not None and rule.is_valid():
                rule_set.add(rule)

        return rule_set

    @classmethod
    def _parse_type_promotion_rule_from_refs_op(
        cls,
        decorated_op: Callable,
    ) -> ElementwiseTypePromotionRule | None:
        """Retrieve and parse type promotion decorator from op under torch._refs."""
        fn = decorated_op
        type_promo_wrapper = None
        while fn_closure_vars := _try_getclosurevars(fn):
            if "fn" not in fn_closure_vars.nonlocals:
                break
            if "self" in fn_closure_vars.nonlocals and isinstance(
                fn_closure_vars.nonlocals["self"],
                _prims_common_wrappers.elementwise_type_promotion_wrapper,
            ):
                type_promo_wrapper = fn_closure_vars.nonlocals["self"]
                break
            fn = fn_closure_vars.nonlocals["fn"]

        if type_promo_wrapper is not None:
            signature = inspect.signature(decorated_op)

            pos = 0
            promote_args_positions = []
            promote_kwargs_names = []

            if type_promo_wrapper.type_promoting_arg_names is not None:
                for name, param in signature.parameters.items():
                    if name in type_promo_wrapper.type_promoting_arg_names:
                        if param.kind in (
                            param.POSITIONAL_OR_KEYWORD,
                            param.POSITIONAL_ONLY,
                        ):
                            promote_args_positions.append(pos)
                        elif param.kind == param.KEYWORD_ONLY:
                            promote_kwargs_names.append(name)
                    pos += 1

            return ElementwiseTypePromotionRule(
                "aten",
                decorated_op.__name__,
                promote_args_positions=promote_args_positions,
                promote_kwargs_names=promote_kwargs_names,
                promotion_kind=type_promo_wrapper.type_promotion_kind,
            )

        logger.warning(
            "Cannot find type promotion rule for: %s.%s",
            decorated_op.__module__,
            decorated_op.__name__,
        )
        return None

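# Example (illustrative sketch; the exact rule set depends on the installed torch
# version): regenerating the elementwise rules from torch._refs and checking that
# a familiar op was picked up.
#
#   >>> rules = ElementwiseTypePromotionRuleSetGenerator.generate_from_torch_refs()
#   >>> any(rule.op_name == "add" for rule in rules)
#   True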

class TypePromotionTable:
    """Type promotion table for torch.ops."""

    def __init__(self):
        self._rule_table = {}
        for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET:
            self.add_rule(rule)
        for rule in _EXTRA_TYPE_PROMOTION_RULE_SET:
            self.add_rule(rule)

    def add_rule(self, rule: TypePromotionRule) -> None:
        """Add a type promotion rule for a python op in a torch.ops module.

        Args:
            rule: Type promotion rule.

        Raises:
            ValueError: If the rule is invalid.
        """
        if not rule.is_valid():
            raise ValueError(f"Invalid type promotion rule: {rule}")
        self._rule_table[f"{rule.namespace}.{rule.op_name}"] = rule

    def get_rule(self, py_op: torch._ops.OpOverloadPacket) -> TypePromotionRule | None:
        """Get type promotion rule for a python op under 'torch.ops.<namespace>'."""
        return self._rule_table.get(str(py_op), None)

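# Example (illustrative sketch): looking up the promotion rule for an
# OpOverloadPacket. Lookups are keyed by "<namespace>.<op_name>", which is what
# `str(py_op)` yields for packets under torch.ops.
#
#   >>> table = TypePromotionTable()
#   >>> table.get_rule(torch.ops.aten.add).op_name
#   'add'
#   >>> table.get_rule(torch.ops.aten.view) is None
#   True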

def get_type_promotion_rule(
    diagnostic: diagnostics.Diagnostic,
    node: torch.fx.Node,
    type_promotion_table: TypePromotionTable,
) -> TypePromotionRule | None:
    """Get type promotion rule for a node.

    Args:
        diagnostic: Diagnostic object.
        node: Node to get type promotion rule for.
        type_promotion_table: Type promotion table.

    Returns:
        Type promotion rule for the node. None if no rule is found or if the node is not
        representing a torch operator.
    """
    op = node.target
    if not isinstance(op, torch._ops.OpOverload):
        # TODO(bowbao): diagnostic.emit and diagnostic.set_message api.
        diagnostic.message = (
            f"Skipped for {diagnostics.format_argument(node)}: "
            f"node.target is not OpOverload. Got type: {type(op)}"
        )
        return None
    if (rule := type_promotion_table.get_rule(op.overloadpacket)) is None:
        diagnostic.message = (
            f"Skipped for {diagnostics.format_argument(node)}: "
            f"Cannot find type promotion rule for op: {op}"
        )
        return None

    diagnostic.info("Found type promotion rule: %s", rule)
    return rule


class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
    """Trace ops that were dispatched.

    Utilize the dispatch mechanism in [`__torch_dispatch__`](https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)
    to trace op overloads that were dispatched to. This is used to find the compatible
    op overload for an op overload packet given a particular set of args and kwargs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.traced_ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.traced_ops.append(func)
        # `kwargs` may be None when called directly; normalize before unpacking.
        return func(*args, **(kwargs or {}))

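# Example (illustrative sketch): the dispatch mode records the OpOverload that the
# dispatcher selects, which is how the helper below resolves an overload packet to
# a concrete overload for a given set of args.
#
#   >>> mode = _OpTraceDispatchMode()
#   >>> with mode:
#   ...     _ = torch.ops.aten.mul(torch.tensor([1.0, 2.0]), torch.tensor(3.0))
#   >>> mode.traced_ops[0]._overloadname
#   'Tensor'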

def find_compatible_op_overload(
    op: torch._ops.OpOverloadPacket, args: tuple, kwargs: dict
) -> torch._ops.OpOverload:
    """Find compatible OpOverload for an OpOverloadPacket using provided args and kwargs.

    Each "call_function" fx.Node in the fx.GraphModule has a target that represents a torch._ops.OpOverload.
    The OpOverload contains an OpOverloadPacket that holds all the available overloads for the operation.

    During the type promotion pass, there are cases where the types of the args and kwargs may change,
    such as promoting Python numbers to tensors. Consequently, the original OpOverload might not be
    compatible with the updated args and kwargs. This function is used to identify the compatible
    OpOverload for the given args and kwargs.

    Args:
        op: OpOverloadPacket to find compatible OpOverload for.
        args: The positional arguments to consider for compatibility.
        kwargs: The keyword arguments to consider for compatibility.

    Returns:
        torch._ops.OpOverload: The compatible OpOverload found for the given args and kwargs.

    Raises:
        RuntimeError: If no compatible op overload is found.

    Examples:
        >>> import torch
        >>> packet = torch.ops.aten.pow
        >>> args = (torch.tensor([1.0, 2.0]), 2)
        >>> find_compatible_op_overload(packet, args, {})._overloadname
        'Tensor_Scalar'
        >>> args = (torch.tensor([1.0, 2.0]), torch.tensor(2.0))
        >>> find_compatible_op_overload(packet, args, {})._overloadname
        'Tensor_Tensor'
    """
    # Utilize the dispatch mechanism to find the compatible op overload.
    op_trace_dispatch_mode = _OpTraceDispatchMode()
    with op_trace_dispatch_mode:
        op(*args, **kwargs)
    assert (
        len(op_trace_dispatch_mode.traced_ops) >= 1
    ), "Expected at least 1 traced op, got 0"

    new_op_overload = op_trace_dispatch_mode.traced_ops[0]
    assert isinstance(
        new_op_overload, torch._ops.OpOverload
    ), f"Expected OpOverload, got {type(new_op_overload)}"
    assert (
        new_op_overload.overloadpacket == op
    ), f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}"

    return new_op_overload


class _TypePromotionInterpreter(torch.fx.Interpreter):
    """Interpreter that inserts type promotion for each node."""

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        type_promotion_table: TypePromotionTable,
    ):
        super().__init__(module)
        self.diagnostic_context = diagnostic_context
        self.type_promotion_table = type_promotion_table

    def _run_node_and_set_meta(self, node) -> Any:
        """Run node and set meta according to `fx_traceback.get_current_meta()`.

        This should be used on new nodes or nodes that have been modified.
        By default `Interpreter.run_node` does not update `node.meta`.
        Fill in `node.meta` from the current meta for keys that are not already set,
        and recompute `node.meta["val"]` from the new output.
1368        """
1369        out = super().run_node(node)
1370        # Update interpreter env state with new output value.
1371        self.env[node] = out
1372        node.meta.update(
1373            (k, v)
1374            for k, v in fx_traceback.get_current_meta().items()
1375            if k not in node.meta
1376        )
1377        node.meta["val"] = proxy_tensor.extract_val(out)
1378        return out
1379
1380    def _create_node(
1381        self,
1382        graph: torch.fx.Graph,
1383        op_type: str,
1384        target: torch.fx.node.Target,
1385        args: tuple,
1386        kwargs: dict,
1387    ) -> torch.fx.Node:
1388        """Create a node and set its metadata."""
1389        assert op_type in (
1390            "call_function",
1391            "call_method",
1392            "get_attr",
1393            "call_module",
1394            "placeholder",
1395            "output",
1396        ), f"Unexpected op_type: {op_type}"
1397        node = getattr(graph, op_type)(target, args, kwargs)
1398        self._run_node_and_set_meta(node)
1399        return node
1400
1401    def _rerun_node_after_type_promotion(
1402        self,
1403        diagnostic: diagnostics.Diagnostic,
1404        node: torch.fx.Node,
1405        expected_out_dtype: torch.dtype,
1406    ) -> None:
1407        """Rerun a node after type promotion and update node.meta["val"] with the output value."""
1408        node_val = node.meta.get("val", None)
        assert node_val is not None, f"node.meta['val'] is not set for node {node}."
1410        args, kwargs = self.fetch_args_kwargs_from_env(node)
1411        target = node.target
1412        assert isinstance(
1413            target, torch._ops.OpOverload
1414        ), f"Expected OpOverload, got {type(target)}"
1415        node.target = find_compatible_op_overload(target.overloadpacket, args, kwargs)
1416
1417        new_node_val = self._run_node_and_set_meta(node)
1418        assert isinstance(new_node_val, type(node_val)), (
1419            f"run_node output type should not change between runs. "
            f"Got {type(new_node_val)}, expected {type(node_val)}."
1421        )
1422
1423        if isinstance(node_val, torch.Tensor):
1424            prev_node_dtype = node_val.dtype
1425
1426            assert prev_node_dtype == expected_out_dtype, (
1427                f"node.meta['val'].dtype({prev_node_dtype}) does not agree with "
1428                f"type promotion rule({expected_out_dtype})."
1429            )
1430
1431            if new_node_val.dtype != expected_out_dtype:
1432                # With explicit type promotion, the expected result dtype may not be
1433                # the same as the computation dtype. This is referred to as "op math".
1434                # We need to explicitly cast the output back to the expected dtype.
1435                # See more about "op math" topic at `_prims_common.elementwise_dtypes`.
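                # Illustrative example (assuming the DEFAULT elementwise promotion
                # kind for `aten.add`): adding two float16 tensors is computed in
                # float32, so the rerun produces a float32 value and the cast below
                # restores the expected float16 output.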
1436                graph = node.graph
1437                with graph.inserting_after(node):
1438                    output_cast_node = self._create_node(
1439                        graph,
1440                        "call_function",
1441                        torch.ops.prims.convert_element_type.default,
1442                        (node,),
1443                        {"dtype": expected_out_dtype},
1444                    )
1445                    node.replace_all_uses_with(output_cast_node)
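                    # `replace_all_uses_with` also rewired the cast node's own input
                    # to point at itself, so restore it to consume the original node.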
1446                    output_cast_node.args = (node,)
1447                    diagnostic.info(
1448                        "Node '%s' output dtype becomes %s due to op math. "
1449                        "Cast back to %s.",
1450                        node,
1451                        new_node_val.dtype,
1452                        expected_out_dtype,
1453                    )
1454
1455        elif fx_type_utils.is_torch_symbolic_type(node_val):
1456            raise NotImplementedError(
1457                "Type promotion does not support node output of sym types."
1458            )
1459        elif isinstance(node_val, (list, tuple)):
1460            raise NotImplementedError(
1461                "Type promotion does not support node output of list or tuple."
1462            )
1463        else:
1464            raise RuntimeError(f"Unexpected node output type: {type(node_val)}.")
1465
1466    def _maybe_promote_arg(
1467        self,
1468        diagnostic: diagnostics.Diagnostic,
1469        node: torch.fx.Node,
1470        fx_arg: torch.fx.node.Argument,
1471        dtype: torch.dtype | None,
1472    ) -> torch.fx.node.Argument:
1473        """Promote fx_arg to dtype if necessary."""
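        # Illustrative examples of the cases handled below: a float16 tensor arg asked
        # to be float32 gets a `prims.convert_element_type` cast inserted before `node`;
        # a Python int arg asked to be float32 is wrapped via `aten.scalar_tensor`;
        # an arg that already has the requested dtype is returned unchanged.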
1474        if dtype is None:
1475            diagnostic.info(
1476                "Argument %s is not promoted. Not mentioned by type promotion rule.",
1477                fx_arg,
1478            )
1479            return fx_arg
1480
1481        if isinstance(fx_arg, torch.fx.Node):
1482            arg_val = self.env[fx_arg]
1483            if isinstance(arg_val, torch.Tensor):
1484                if (old_dtype := arg_val.dtype) != dtype:
1485                    # Promote tensor to dtype.
1486                    graph = node.graph
1487                    with graph.inserting_before(node):
1488                        diagnostic.info(
1489                            "Argument %s(%s) is promoted to %s.",
1490                            fx_arg,
1491                            old_dtype,
1492                            dtype,
1493                        )
1494                        return self._create_node(
1495                            graph,
1496                            "call_function",
1497                            torch.ops.prims.convert_element_type.default,
1498                            (fx_arg,),
1499                            {"dtype": dtype},
1500                        )
1501                diagnostic.info(
1502                    "Argument %s is not promoted. Already %s.", fx_arg, dtype
1503                )
1504                return fx_arg
1505            elif fx_type_utils.is_torch_symbolic_type(arg_val):
1506                arg_type = type(arg_val)
1507                equivalent_dtype = fx_type_utils.from_scalar_type_to_torch_dtype(
1508                    arg_type
1509                )
1510                assert equivalent_dtype is not None, f"Unexpected arg_type: {arg_type}"
1511                if equivalent_dtype != dtype:
1512                    # Promote Sym number to tensor of dtype.
1513                    graph = node.graph
1514                    with graph.inserting_before(node):
1515                        diagnostic.info(
1516                            "Argument %s(Scalar of equivalent dtype: %s) "
1517                            "is promoted to %s.",
1518                            fx_arg,
1519                            equivalent_dtype,
1520                            dtype,
1521                        )
1522                        return self._create_node(
1523                            graph,
1524                            "call_function",
1525                            torch.ops.aten.scalar_tensor.default,
1526                            (fx_arg,),
1527                            {"dtype": dtype},
1528                        )
1529                diagnostic.info(
1530                    "Argument %s is not promoted. Already %s.", fx_arg, dtype
1531                )
1532                return fx_arg
1533        elif (
1534            equivalent_dtype := fx_type_utils.from_scalar_type_to_torch_dtype(
1535                type(fx_arg)
1536            )
1537        ) is not None:
1538            if equivalent_dtype != dtype:
1539                # Promote number to tensor of dtype.
                # The op should have an overload that accepts a tensor for this arg;
                # otherwise the type promotion rule should not suggest promoting it.
1542                graph = node.graph
1543                with graph.inserting_before(node):
1544                    diagnostic.info(
1545                        "Argument %s(Scalar of equivalent dtype: %s) "
1546                        "is promoted to %s.",
1547                        fx_arg,
1548                        equivalent_dtype,
1549                        dtype,
1550                    )
1551                    return self._create_node(
1552                        graph,
1553                        "call_function",
1554                        torch.ops.aten.scalar_tensor.default,
1555                        (fx_arg,),
1556                        {"dtype": dtype},
1557                    )
1558            diagnostic.info("Argument %s is not promoted. Already %s.", fx_arg, dtype)
1559            return fx_arg
1560        elif isinstance(fx_arg, (tuple, list)):
1561            diagnostic.info(
1562                "Argument %s is a tuple/list. Promoting each element.", fx_arg
1563            )
1564            return type(fx_arg)(
1565                self._maybe_promote_arg(diagnostic, node, fx_arg_elem, dtype)
1566                for fx_arg_elem in fx_arg
1567            )
1568
1569        raise NotImplementedError(f"Unknown fx arg type: {type(fx_arg)}")
1570
1571    def _maybe_promote_node(
1572        self,
1573        diagnostic: diagnostics.Diagnostic,
1574        node: torch.fx.Node,
1575        rule: TypePromotionRule,
1576    ) -> torch.fx.Node:
1577        """Promote node inputs and outputs according to type promotion rule."""
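        # The promotion is previewed from the current fake values; each arg/kwarg is
        # promoted independently, and the node is rerun (with its overload re-resolved)
        # only if something actually changed.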
1578        args, kwargs = self.fetch_args_kwargs_from_env(node)
1579        type_promotion_info = rule.preview_type_promotion(args, kwargs)
1580        new_args = []
1581        new_kwargs = {}
1582        for i, arg in enumerate(node.args):
1583            new_args.append(
1584                self._maybe_promote_arg(
1585                    diagnostic, node, arg, type_promotion_info.args_dtypes.get(i, None)
1586                )
1587            )
1588
1589        for name, arg in node.kwargs.items():
1590            new_kwargs[name] = self._maybe_promote_arg(
1591                diagnostic, node, arg, type_promotion_info.kwargs_dtypes.get(name, None)
1592            )
1593        new_args = tuple(new_args)
1594
1595        if node.args != new_args or node.kwargs != new_kwargs:
1596            diagnostic.message = f"Applied type promotion for {node}. "
1597            node.args = new_args
1598            node.kwargs = new_kwargs
1599            self._rerun_node_after_type_promotion(
1600                diagnostic, node, type_promotion_info.out_dtype
1601            )
1602        else:
1603            diagnostic.message = f"Type promotion not needed for {node}. "
1604
1605        return node
1606
1607    @diagnostics.diagnose_call(
1608        rule=diagnostics.rules.fx_node_insert_type_promotion,
1609        level=diagnostics.levels.NONE,
1610    )
1611    def run_node(self, node: torch.fx.Node) -> Any:
1612        """This method is an override which inserts type promotion nodes as needed.
1613
1614        For each `call_function` node, an initial check is conducted to determine if a type
1615        promotion rule is applicable. If a relevant rule exists, type casting nodes are
1616        introduced for the corresponding arguments. The OpOverload of the node is updated
        to one that accommodates the promoted types. Should the output type differ,
        a type casting node is inserted for this output.
1619
1620        The call `super().run_node(node)` is guaranteed to be invoked for each node.
1621        In the case of new or modified nodes, the result of `super().run_node(node)` is
1622        used to update its `node.meta["val"]` value.
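
        Illustrative example: for `aten.add(x_fp16, y_fp32)`, a cast of `x` to float32
        is inserted before the node and the node is rerun with the promoted argument.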
1623        """
1624        diagnostic = self.diagnostic_context.inflight_diagnostic()
1625        with self._set_current_node(node):
1626            if node.op != "call_function":
1627                diagnostic.message = f"Skipped {node}: not a call_function."
1628            elif rule := get_type_promotion_rule(
1629                diagnostic, node, self.type_promotion_table
1630            ):
1631                self._maybe_promote_node(diagnostic, node, rule)
1632
1633        return super().run_node(node)
1634
1635
1636class InsertTypePromotion(_pass.Transform):
    """Explicitly insert type promotion ops into the graph.

    This class subclasses `_pass.Transform` to provide graph-level diagnostic tracking.
    Underneath, the main pass is driven by `_TypePromotionInterpreter`, a subclass of
    `torch.fx.Interpreter` that interprets the fx.Graph and performs the insertion of
    type promotion operations.

    The interpreter is extended with the ability to track diagnostic information for
    each node.
1645
1646    By re-running the new and modified nodes using the interpreter, we can update the
1647    metadata, specifically the fake tensor stored under node.meta["val"], and ensure it
1648    reflects the latest changes.
1649
1650    See [FXE0015: fx_node_insert_type_promotion](https://pytorch.org/docs/main/generated/onnx_dynamo_diagnostics_rules/FXE0015%3Afx-node-insert-type-promotion.html) for more details.  # noqa: B950
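
    Example (illustrative sketch, assuming `graph_module` was produced with fake tensors
    so its placeholder nodes carry `meta["val"]`):

        pass_ = InsertTypePromotion(diagnostic_context, graph_module)
        promoted_module = pass_.run()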
1651    """
1652
1653    def __init__(
1654        self,
1655        diagnostic_context: diagnostics.DiagnosticContext,
1656        module: torch.fx.GraphModule,
1657        type_promotion_table: TypePromotionTable | None = None,
1658    ):
1659        super().__init__(diagnostic_context, module)
1660        self.interpreter = _TypePromotionInterpreter(
1661            diagnostic_context, module, type_promotion_table or TypePromotionTable()
1662        )
1663
1664    def _fetch_fake_args(
1665        self,
1666    ) -> Sequence[
1667        fake_tensor.FakeTensor
1668        | float
1669        | int
1670        | bool
1671        | torch.SymInt
1672        | torch.SymFloat
1673        | torch.SymBool
1674        | None
1675    ]:
1676        """Fetch fake args from fx graph.
1677
        For each argument, try to fetch the fake tensor from the matching placeholder node.
1679        """
1680        fake_args = []
1681        for node in self.module.graph.nodes:
1682            if node.op == "placeholder":
1683                try:
1684                    # Meta value can be torch.Tensor, int, float, bool,
1685                    # torch.SymInt, torch.SymFloat, torch.SymBool.
                    meta_value = node.meta.get("val", None)
1687                except RuntimeError as e:
1688                    if not node.users:
                        # If the placeholder is not used, we can safely ignore it and
                        # use None as the placeholder value.
1691                        meta_value = None
1692                    else:
1693                        raise RuntimeError(
1694                            "Cannot fetch symbolic fake args from fx graph. "
                            "The InsertTypePromotion pass needs to run with pre-existing fake args; "
                            "otherwise the pass will produce inaccurate dynamic shapes."
1697                        ) from e
1698
1699                fake_args.append(meta_value)
1700        return fake_args
1701
1702    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
1703        assert not args, (
1704            "`InsertTypePromotion` deduces symbolic fake arguments from the graph. "
1705            "It does not accept concrete arguments as input because this pass requires "
1706            "re-running the graph. When executed with newly faked concrete arguments, "
1707            "the pass loses the symbolic dynamic shape information."
1708        )
1709        assert not kwargs, "`kwargs` is not supported"
1710
1711        fake_args = self._fetch_fake_args()
1712        fake_mode = self.fake_mode
1713        assert fake_mode is not None, "Cannot detect fake_mode."
1714
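        # Clear any ambient fake mode, re-enter the graph's own fake mode so the rerun
        # reuses the existing symbolic shapes, and preserve node meta so newly created
        # nodes inherit stack traces and other metadata from the nodes being interpreted.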
1715        with fake_tensor.unset_fake_temporarily(), (
1716            fake_mode
1717        ), fx_traceback.preserve_node_meta():
1718            self.interpreter.run(*fake_args)
1719
1720        return self.module
1721