# mypy: allow-untyped-defs
# Owner(s): ["module: onnx"]
from __future__ import annotations

import abc
import dataclasses
import inspect
import logging
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING

import torch
import torch._ops
import torch.fx
import torch.fx.traceback as fx_traceback
from torch import _prims_common, _refs
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    wrappers as _prims_common_wrappers,
)
from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
from torch._refs.nn import functional as _functional_refs
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree


if TYPE_CHECKING:
    from types import ModuleType


logger = logging.getLogger(__name__)

# TODO(bowbao): move to type utils.
_SCALAR_TYPE_TENSOR_DTYPE_MAP: Mapping[type, torch.dtype] = {
    bool: torch.bool,
    int: torch.int64,
    float: torch.float32,
    complex: torch.complex32,
}


def _try_getclosurevars(func):
    try:
        return inspect.getclosurevars(func)
    except TypeError:
        return None


@dataclasses.dataclass
class TypePromotionSnapshot:
    """Type promotion snapshot for an fx node and its inputs.

    Contains the promoted dtype for args and kwargs that need promoting.
    Contains the expected node output dtype.
    """

    args_dtypes: Mapping[int, torch.dtype]
    """Mapping from arg position to dtype to promote to."""

    kwargs_dtypes: Mapping[str, torch.dtype]
    """Mapping from kwarg name to dtype to promote to."""

    out_dtype: torch.dtype
    """Expected output dtype of the node."""


class TypePromotionRule(abc.ABC):
    """Base class for type promotion rule per 'torch.ops.{namespace}.{op_name}'."""

    def __init__(self, namespace: str, op_name: str):
        self.namespace = namespace
        self.op_name = op_name

    # Make this abstract as well because subclass needs to override __eq__().
    # A class that overrides __eq__() and does not define __hash__() will have its __hash__() implicitly set to None.
    # Ref: https://docs.python.org/3/reference/datamodel.html#object.__hash__
    @abc.abstractmethod
    def __hash__(self) -> int: ...

    @abc.abstractmethod
    def __repr__(self): ...

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool: ...

    def is_valid(self) -> bool:
        """Check if the rule is valid."""
        # This always returns a module. If the module does not exist it will be created.
        module = getattr(torch.ops, self.namespace)
        py_op = getattr(module, self.op_name, None)
        if py_op is None:
            logger.warning(
                "Cannot find op: %s in module: %s", self.op_name, self.namespace
            )
            return False
        if not isinstance(py_op, torch._ops.OpOverloadPacket):
            logger.warning(
                "Op: torch.ops.%s.%s is not an OpOverloadPacket, got: %s",
                self.namespace,
                self.op_name,
                type(py_op),
            )
            return False

        return True

    @abc.abstractmethod
    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        """Preview type promotion results for the provided args and kwargs.

        Returns a TypePromotionSnapshot object that contains the promoted dtypes for
        the arguments and the expected output dtype.
        """
        ...
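
    # Illustrative sketch only (the snapshot values below are an assumption of what a
    # DEFAULT elementwise rule reports, not a doctest): calling
    #   rule.preview_type_promotion((fp16_tensor, fp32_tensor), {})
    # on a rule for torch.ops.aten.add is expected to yield roughly
    #   TypePromotionSnapshot(
    #       args_dtypes={0: torch.float32, 1: torch.float32},
    #       kwargs_dtypes={},
    #       out_dtype=torch.float32,
    #   )
    # i.e. both tensor args are promoted to the common dtype and the output keeps it.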


class ElementwiseTypePromotionRule(TypePromotionRule):
    """Defines how to perform elementwise type promotion for 'torch.ops.{namespace}.{op_name}'."""

    _USE_OPMATH: bool = False
    """Whether to use opmath to compute the promoted input dtype.
    If used, upcasts will be inserted everywhere for lower precision models.
    Set to False and have torchlib handle upcasts in op implementation internally.
    """

    def __init__(
        self,
        namespace: str,
        op_name: str,
        promote_args_positions: Sequence[int],
        promote_kwargs_names: Sequence[str],
        promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND,
    ):
        """Constructs a TypePromotionRule for elementwise operators.

        Args:
            namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.add'.
            op_name: Name of the op. E.g. 'add' in 'torch.ops.aten.add'.
            promote_args_positions: Positions of args to promote.
            promote_kwargs_names: Names of kwargs to promote.
            promotion_kind: Type promotion kind. Refer to [_prims_common.elementwise_dtypes](https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py) for detail.  # noqa: B950
        """
        super().__init__(namespace, op_name)
        self.promote_args_positions = promote_args_positions
        self.promote_kwargs_names = promote_kwargs_names
        self.promotion_kind = promotion_kind

    def __repr__(self):
        return (
            f"ElementwiseTypePromotionRule('{self.namespace}', '{self.op_name}', "
            f"{self.promote_args_positions}, {self.promote_kwargs_names}, {self.promotion_kind})"
        )

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ElementwiseTypePromotionRule):
            return False
        return (
            self.namespace == __value.namespace
            and self.op_name == __value.op_name
            and self.promote_args_positions == __value.promote_args_positions
            and self.promote_kwargs_names == __value.promote_kwargs_names
            and self.promotion_kind == __value.promotion_kind
        )

    def __hash__(self) -> int:
        return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()

    def _consolidate_input_dtype(
        self, computed_dtype: torch.dtype, result_dtype: torch.dtype
    ) -> torch.dtype:
        """
        Although opmath is the right thing to do to retain on-par precision, it inserts
        upcasts everywhere in the graph. This is particularly hard for backends to optimize
        since there is no way to differentiate between inserted upcasts and model code
        casts. Hence we consolidate the input dtype to the result dtype to avoid this.
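
        For example, under DEFAULT promotion two float16 inputs have an opmath
        computation dtype of torch.float32 but a result dtype of torch.float16; with
        opmath disabled the inputs are kept at float16.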
        """
        if not self._USE_OPMATH and self.promotion_kind in (
            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
            _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        ):
            return result_dtype
        return computed_dtype

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        candidate_args = {
            i: args[i]
            for i in self.promote_args_positions
            if i < len(args) and args[i] is not None
        }
        candidate_kwargs = {
            name: kwargs[name]
            for name in self.promote_kwargs_names
            if name in kwargs and kwargs[name] is not None
        }

        computed_dtype, result_dtype = _prims_common.elementwise_dtypes(
            *_pytree.arg_tree_leaves(*candidate_args.values(), **candidate_kwargs),
            type_promotion_kind=self.promotion_kind,
        )

        consolidated_input_dtype = self._consolidate_input_dtype(
            computed_dtype, result_dtype
        )

        return TypePromotionSnapshot(
            dict.fromkeys(candidate_args.keys(), consolidated_input_dtype),
            dict.fromkeys(candidate_kwargs.keys(), consolidated_input_dtype),
            result_dtype,
        )


class DivElementwiseTypePromotionRule(ElementwiseTypePromotionRule):
    """Reference type promotion rule from torch._refs.div.

    Rule depends on the value of the `rounding_mode` argument.
    """

    def __init__(self):
        super().__init__(
            "aten",
            "div",
            promote_args_positions=(0, 1),
            promote_kwargs_names=(),
            promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        )

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        rounding_mode = kwargs.get("rounding_mode", None)
        if rounding_mode is None:
            # true_divide
            self.promotion_kind = (
                _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
            )
            return super().preview_type_promotion(args, kwargs)
        if rounding_mode == "trunc":
            # trunc_divide
            self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            return super().preview_type_promotion(args, kwargs)
        if rounding_mode == "floor":
            # floor_divide
            self.promotion_kind = _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            return super().preview_type_promotion(args, kwargs)
        raise ValueError(f"Unknown rounding_mode: {rounding_mode}")


class ReductionTypePromotionRule(TypePromotionRule):
    def __init__(
        self,
        namespace: str,
        op_name: str,
        promotion_kind: _prims_common.REDUCTION_OUTPUT_TYPE_KIND,
    ):
        """Constructs a TypePromotionRule for reduction operators.

        Args:
            namespace: Namespace of the op. E.g. 'aten' in 'torch.ops.aten.sum'.
            op_name: Name of the op. E.g. 'sum' in 'torch.ops.aten.sum'.
            promotion_kind: Type promotion kind. Refer to [_prims_common.reduction_dtypes](https://github.com/pytorch/pytorch/blob/main/torch/_prims_common/__init__.py) for detail.  # noqa: B950
        """
        super().__init__(namespace, op_name)
        self.promotion_kind = promotion_kind

    def __repr__(self):
        return f"ReductionTypePromotionRule('{self.namespace}', '{self.op_name}', {self.promotion_kind})"

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ReductionTypePromotionRule):
            return False
        return (
            self.namespace == __value.namespace
            and self.op_name == __value.op_name
            and self.promotion_kind == __value.promotion_kind
        )

    def __hash__(self) -> int:
        return f"{type(self)}:{self.namespace}.{self.op_name}".__hash__()

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        assert (
            len(args) >= 1
        ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
        arg = args[0]
        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
        dtype: torch.dtype | None = kwargs.get("dtype", None)

        computation_dtype, result_dtype = _prims_common.reduction_dtypes(
            arg, self.promotion_kind, dtype
        )
        if result_dtype is None:
            # Inspecting the code, this can only happen when `promotion_kind` is `KEEP_PROMOTED_TYPE`.
            # Hence set it to the same as computation_dtype.
            result_dtype = computation_dtype

        return TypePromotionSnapshot(
            {0: computation_dtype},
            {},
            result_dtype,
        )


class AllOrAnyReductionTypePromotionRule(ReductionTypePromotionRule):
    """Reference type promotion rule from torch.ops.aten.all or torch.ops.aten.any.

    This is a special case where the computation dtype is always torch.bool.
    The result dtype is torch.uint8 if the input dtype is uint8, otherwise torch.bool.
    """

    def __init__(self, op_name: str):
        super().__init__(
            "aten",
            op_name,
            _prims_common.REDUCTION_OUTPUT_TYPE_KIND.ALWAYS_BOOL,
        )

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        assert (
            len(args) >= 1
        ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
        arg = args[0]
        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
        computation_dtype = torch.bool
        # Preserves uint8 -- probably a legacy mask thing.
        result_dtype = torch.uint8 if arg.dtype == torch.uint8 else torch.bool
        return TypePromotionSnapshot(
            {0: computation_dtype},
            {},
            result_dtype,
        )


class SumLikeReductionTypePromotionRule(ReductionTypePromotionRule):
    """Reference type promotion rule from torch.ops.aten.sum.

    This is a special case where the computation dtype is always torch.int64 for an
    integral arg, unless overridden by the `dtype` kwarg.
    """

    def preview_type_promotion(
        self, args: tuple, kwargs: dict
    ) -> TypePromotionSnapshot:
        assert (
            len(args) >= 1
        ), f"Reduction op torch.ops.{self.namespace}.{self.op_name} expects at least one argument"
        arg = args[0]
        assert isinstance(arg, torch.Tensor), f"{type(arg)=} is not torch.Tensor"
        dtype: torch.dtype | None = kwargs.get("dtype", None)
        # The below logic is copied from `torch/_refs/__init__.py` reduction ops impl.
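        # As a concrete illustration (not an exhaustive spec): summing an int32 tensor
        # with no `dtype` kwarg resolves `dtype` to torch.int64 below, matching eager
        # torch.sum, which accumulates boolean/integral inputs in int64.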
        if dtype is None:
            if _prims_common.is_boolean_dtype(
                arg.dtype
            ) or _prims_common.is_integer_dtype(arg.dtype):
                dtype = torch.int64
            else:
                dtype = arg.dtype
        return super().preview_type_promotion(args, {"dtype": dtype})


# NOTE: [Update type promotion rule]
# BELOW TABLE IS GENERATED FROM `TypePromotionRuleSetGenerator.generate_from_torch_refs`.
# DO NOT EDIT MANUALLY !!!
# For missing rules or discrepancies, please
# 1. Run `pytest test/onnx/test_fx_type_promotion.py` to validate if the generated rule set is current.
#    If it is not, update with new generated set.
# 2. If discrepancies still exist, consider debugging torch._refs or report a bug.
# 3. If rules are still missing, add them to `_EXTRA_TYPE_PROMOTION_RULE_SET` or report a bug.
# Check `TypePromotionRule` class for how each rule is defined and used.
_GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = {
    ElementwiseTypePromotionRule(
        "aten", "abs", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "abs_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "acosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "add", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "add_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcdiv", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcdiv_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcmul", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addcmul_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "addr", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "asinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan2_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "atanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_left_shift",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_left_shift_",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_right_shift",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "bitwise_right_shift_",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "bitwise_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cat", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "cauchy", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cauchy_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ceil", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ceil_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "celu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "celu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "clamp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "clamp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "copysign", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "copysign_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "cosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "deg2rad", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "deg2rad_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "digamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "elu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "eq", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "eq_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "erf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erf_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfinv", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "erfinv_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exp_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "expm1", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "expm1_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exponential", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "exponential_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fill", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "floor_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmax", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmin", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmod", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "fmod_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "frac", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "frac_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "gcd", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "gcd_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ge", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "ge_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "gelu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "geometric", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "geometric_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "glu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "gt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "gt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "hardtanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "heaviside", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "heaviside_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "huber_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "hypot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "hypot_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "i0", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "i0_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igamma", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igamma_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igammac", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "igammac_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "isfinite", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isnan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isneginf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isposinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "isreal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "l1_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lcm", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lcm_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "le", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "le_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "leaky_relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lerp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lerp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lgamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lgamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log10", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log10_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log1p", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log1p_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log_normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "log_normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logaddexp", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logaddexp2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logical_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "logit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "logsumexp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "lt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "lt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "maximum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "minimum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mish", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mish_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mse_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mul", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "mul_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "ne", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "ne_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "neg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "neg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "nextafter", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "nextafter_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "poisson_nll_loss",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "pow", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "pow_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rad2deg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rad2deg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "reciprocal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "reciprocal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "remainder", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "remainder_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "round", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rsqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "rsub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "selu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sgn", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sgn_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sigmoid", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sigmoid_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sign", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sign_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "signbit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
    ),
    ElementwiseTypePromotionRule(
        "aten", "sin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "smooth_l1_loss",
        [0, 1],
        [],
        ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
    ),
    ElementwiseTypePromotionRule(
        "aten", "softplus", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "square", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "square_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
    ),
    ElementwiseTypePromotionRule(
        "aten", "sub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "sub_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "tanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "threshold", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "threshold_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "true_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "true_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "trunc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    ),
    ElementwiseTypePromotionRule(
        "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
    ),
    ElementwiseTypePromotionRule(
        "aten", "xlogy", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
    ElementwiseTypePromotionRule(
        "aten", "xlogy_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    ),
}

# Manually curated extra type promotion rules. Please see NOTE [Update type promotion rule]
# before adding new rules.
_EXTRA_TYPE_PROMOTION_RULE_SET = {
    # torch._refs skips type promotion decoration for `clamp_min` and `clamp_max` since
    # the call is routed to the decorated `aten.clamp` op.
    ElementwiseTypePromotionRule(
        "aten",
        "clamp_max",
        promote_args_positions=(0, 1),
        promote_kwargs_names=(),
        promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    ElementwiseTypePromotionRule(
        "aten",
        "clamp_min",
        promote_args_positions=(0, 1),
        promote_kwargs_names=(),
        promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    ),
    # torch.ops.aten.div.Tensor_mode applies different type promotion rules
    # depending on the value of the `mode` argument.
    DivElementwiseTypePromotionRule(),
    # Manually curating reduction ops since the logic is written inside the op reference
    # implementation.
    AllOrAnyReductionTypePromotionRule("all"),
    AllOrAnyReductionTypePromotionRule("any"),
    ReductionTypePromotionRule(
        "aten",
        "amax",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    ReductionTypePromotionRule(
        "aten",
        "amin",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    # torch.ops.aten.mean is a special case that does not need type promotion.
    ReductionTypePromotionRule(
        "aten",
        "std",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    ),
    ReductionTypePromotionRule(
        "aten",
        "std_mean",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    ),
    ReductionTypePromotionRule(
        "aten",
        "var",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "cumprod",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "cumsum",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "prod",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
    SumLikeReductionTypePromotionRule(
        "aten",
        "sum",
        promotion_kind=_prims_common.REDUCTION_OUTPUT_TYPE_KIND.SAME,
    ),
}


class ElementwiseTypePromotionRuleSetGenerator:
    """Hackily distills info from reference ops decorated with the elementwise type promotion wrapper.

    The goal is to retrieve the decorator

    ```python
    @elementwise_type_promotion_wrapper(
        type_promoting_args=("a", "b"),
        type_promotion_kind=type_promotion_kind,
    )
    ```

    from the reference ops. It provides info on which arguments are promoted
    and what kind of promotion is applied.
    """

    @classmethod
    def generate_from_torch_refs(cls) -> set[ElementwiseTypePromotionRule]:
        """Parse type promotion rules from reference ops under torch._refs."""
        rule_set = set()
        rule_set.update(cls._parse_torch_refs(_refs))
        rule_set.update(cls._parse_torch_refs(_nn_refs))
        rule_set.update(cls._parse_torch_refs(_linalg_refs))
        rule_set.update(cls._parse_torch_refs(_special_refs))
        rule_set.update(cls._parse_torch_refs(_functional_refs))
        return rule_set

    @classmethod
    def _parse_torch_refs(
        cls, ref_module: ModuleType
    ) -> set[ElementwiseTypePromotionRule]:
        logger.info("Processing module: %s", ref_module.__name__)
        rule_set = set()
        for name in ref_module.__all__:
            decorated_op = getattr(ref_module, name)
            rule = cls._parse_type_promotion_rule_from_refs_op(decorated_op)
            if rule is not None and rule.is_valid():
                rule_set.add(rule)

        return rule_set

    @classmethod
    def _parse_type_promotion_rule_from_refs_op(
        cls,
        decorated_op: Callable,
    ) -> ElementwiseTypePromotionRule | None:
        """Retrieve and parse the type promotion decorator from an op under torch._refs."""
        fn = decorated_op
        type_promo_wrapper = None
        while fn_closure_vars := _try_getclosurevars(fn):
            if "fn" not in fn_closure_vars.nonlocals:
                break
            if "self" in fn_closure_vars.nonlocals and isinstance(
                fn_closure_vars.nonlocals["self"],
                _prims_common_wrappers.elementwise_type_promotion_wrapper,
            ):
                type_promo_wrapper = fn_closure_vars.nonlocals["self"]
                break
            fn = fn_closure_vars.nonlocals["fn"]

        if type_promo_wrapper is not None:
            signature = inspect.signature(decorated_op)

            pos = 0
            promote_args_positions = []
            promote_kwargs_names = []

            if type_promo_wrapper.type_promoting_arg_names is not None:
                for name, param in signature.parameters.items():
                    if name in type_promo_wrapper.type_promoting_arg_names:
                        if param.kind in (
                            param.POSITIONAL_OR_KEYWORD,
                            param.POSITIONAL_ONLY,
                        ):
                            promote_args_positions.append(pos)
                        elif param.kind == param.KEYWORD_ONLY:
                            promote_kwargs_names.append(name)
                    pos += 1

            return ElementwiseTypePromotionRule(
                "aten",
                decorated_op.__name__,
                promote_args_positions=promote_args_positions,
                promote_kwargs_names=promote_kwargs_names,
                promotion_kind=type_promo_wrapper.type_promotion_kind,
            )

        logger.warning(
            "Cannot find type promotion rule for: %s.%s",
            decorated_op.__module__,
            decorated_op.__name__,
        )
        return None


class TypePromotionTable:
    """Type promotion table for torch.ops."""

    def __init__(self):
        self._rule_table = {}
        for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET:
            self.add_rule(rule)
        for rule in _EXTRA_TYPE_PROMOTION_RULE_SET:
            self.add_rule(rule)

    def add_rule(self, rule: TypePromotionRule) -> None:
        """Add a type promotion rule for a python op in a torch.ops module.

        Args:
            rule: Type promotion rule.

        Raises:
            ValueError: If the rule is invalid.
        """
        if not rule.is_valid():
            raise ValueError(f"Invalid type promotion rule: {rule}")
        self._rule_table[f"{rule.namespace}.{rule.op_name}"] = rule

    def get_rule(self, py_op: torch._ops.OpOverloadPacket) -> TypePromotionRule | None:
        """Get type promotion rule for a python op under 'torch.ops.<namespace>'."""
        return self._rule_table.get(str(py_op), None)


def get_type_promotion_rule(
    diagnostic: diagnostics.Diagnostic,
    node: torch.fx.Node,
    type_promotion_table: TypePromotionTable,
) -> TypePromotionRule | None:
    """Get type promotion rule for a node.

    Args:
        diagnostic: Diagnostic object.
        node: Node to get type promotion rule for.
        type_promotion_table: Type promotion table.

    Returns:
        Type promotion rule for the node. None if no rule is found or if the node does
        not represent a torch operator.
    """
    op = node.target
    if not isinstance(op, torch._ops.OpOverload):
        # TODO(bowbao): diagnostic.emit and diagnostic.set_message api.
        diagnostic.message = (
            f"Skipped for {diagnostics.format_argument(node)}: "
            f"node.target is not OpOverload. Got type: {type(op)}"
        )
        return None
    if (rule := type_promotion_table.get_rule(op.overloadpacket)) is None:
        diagnostic.message = (
            f"Skipped for {diagnostics.format_argument(node)}: "
            f"Cannot find type promotion rule for op: {op}"
        )
        return None

    diagnostic.info("Found type promotion rule: %s", rule)
    return rule


class _OpTraceDispatchMode(_python_dispatch.TorchDispatchMode):
    """Trace ops that were dispatched.

    Utilize the dispatch mechanism in [`__torch_dispatch__`](https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)
    to trace op overloads that were dispatched to. This is used to find the compatible
    op overload for a given op overload packet for different sets of args and kwargs.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.traced_ops = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.traced_ops.append(func)
        return func(*args, **kwargs)


def find_compatible_op_overload(
    op: torch._ops.OpOverloadPacket, args: tuple, kwargs: dict
) -> torch._ops.OpOverload:
    """Find compatible OpOverload for an OpOverloadPacket using provided args and kwargs.

    Each "call_function" fx.Node in the fx.GraphModule has a target that represents a torch._ops.OpOverload.
    The OpOverload contains an OpOverloadPacket that holds all the available overloads for the operation.

    During the type promotion pass, there are cases where the types of the args and kwargs may change,
    such as promoting Python numbers to tensors. Consequently, the original OpOverload might not be
    compatible with the updated args and kwargs. This function is used to identify the compatible
    OpOverload for the given args and kwargs.

    Args:
        op: OpOverloadPacket to find compatible OpOverload for.
        args: The positional arguments to consider for compatibility.
        kwargs: The keyword arguments to consider for compatibility.

    Returns:
        torch._ops.OpOverload: The compatible OpOverload found for the given args and kwargs.

    Raises:
        RuntimeError: If no compatible op overload is found.

    Examples:
        >>> import torch
        >>> packet = torch.ops.aten.pow
        >>> args = (torch.tensor([1.0, 2.0]), 2)
        >>> find_compatible_op_overload(packet, args, {})._overloadname
        'Tensor_Scalar'
        >>> args = (torch.tensor([1.0, 2.0]), torch.tensor(2.0))
        >>> find_compatible_op_overload(packet, args, {})._overloadname
        'Tensor_Tensor'
    """
    # Utilize the dispatch mechanism to find the compatible op overload.
    op_trace_dispatch_mode = _OpTraceDispatchMode()
    with op_trace_dispatch_mode:
        op(*args, **kwargs)
    assert (
        len(op_trace_dispatch_mode.traced_ops) >= 1
    ), "Expected at least 1 traced op, got 0"

    new_op_overload = op_trace_dispatch_mode.traced_ops[0]
    assert isinstance(
        new_op_overload, torch._ops.OpOverload
    ), f"Expected OpOverload, got {type(new_op_overload)}"
    assert (
        new_op_overload.overloadpacket == op
    ), f"Expected same OpOverload packet, got {new_op_overload.overloadpacket} != {op}"

    return new_op_overload


class _TypePromotionInterpreter(torch.fx.Interpreter):
    """Interpreter that inserts type promotion for each node."""

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        type_promotion_table: TypePromotionTable,
    ):
        super().__init__(module)
        self.diagnostic_context = diagnostic_context
        self.type_promotion_table = type_promotion_table

    def _run_node_and_set_meta(self, node) -> Any:
        """Run node and set meta according to `fx_traceback.get_current_meta()`.

        This should be used on new nodes or nodes that have been modified.
        By default `Interpreter.run_node` does not update `node.meta`.
        Set `node.meta` to the current meta, except for `node.meta["val"]`, which is
        recomputed.
        """
        out = super().run_node(node)
        # Update interpreter env state with new output value.
        self.env[node] = out
        node.meta.update(
            (k, v)
            for k, v in fx_traceback.get_current_meta().items()
            if k not in node.meta
        )
        node.meta["val"] = proxy_tensor.extract_val(out)
        return out

    def _create_node(
        self,
        graph: torch.fx.Graph,
        op_type: str,
        target: torch.fx.node.Target,
        args: tuple,
        kwargs: dict,
    ) -> torch.fx.Node:
        """Create a node and set its metadata."""
        assert op_type in (
            "call_function",
            "call_method",
            "get_attr",
            "call_module",
            "placeholder",
            "output",
        ), f"Unexpected op_type: {op_type}"
        node = getattr(graph, op_type)(target, args, kwargs)
        self._run_node_and_set_meta(node)
        return node

    def _rerun_node_after_type_promotion(
        self,
        diagnostic: diagnostics.Diagnostic,
        node: torch.fx.Node,
        expected_out_dtype: torch.dtype,
    ) -> None:
        """Rerun a node after type promotion and update node.meta["val"] with the output value."""
        node_val = node.meta.get("val", None)
        assert node_val is not None, f"Node {node} node.meta['val'] is not set."
        args, kwargs = self.fetch_args_kwargs_from_env(node)
        target = node.target
        assert isinstance(
            target, torch._ops.OpOverload
        ), f"Expected OpOverload, got {type(target)}"
        node.target = find_compatible_op_overload(target.overloadpacket, args, kwargs)

        new_node_val = self._run_node_and_set_meta(node)
        assert isinstance(new_node_val, type(node_val)), (
            f"run_node output type should not change between runs. "
            f"Got {type(new_node_val)}, expect {type(node_val)}."
        )

        if isinstance(node_val, torch.Tensor):
            prev_node_dtype = node_val.dtype

            assert prev_node_dtype == expected_out_dtype, (
                f"node.meta['val'].dtype({prev_node_dtype}) does not agree with "
                f"type promotion rule({expected_out_dtype})."
            )

            if new_node_val.dtype != expected_out_dtype:
                # With explicit type promotion, the expected result dtype may not be
                # the same as the computation dtype. This is referred to as "op math".
                # We need to explicitly cast the output back to the expected dtype.
                # See more about the "op math" topic at `_prims_common.elementwise_dtypes`.
                graph = node.graph
                with graph.inserting_after(node):
                    output_cast_node = self._create_node(
                        graph,
                        "call_function",
                        torch.ops.prims.convert_element_type.default,
                        (node,),
                        {"dtype": expected_out_dtype},
                    )
                    node.replace_all_uses_with(output_cast_node)
                    output_cast_node.args = (node,)
                    diagnostic.info(
                        "Node '%s' output dtype becomes %s due to op math. "
                        "Cast back to %s.",
                        node,
                        new_node_val.dtype,
                        expected_out_dtype,
                    )

        elif fx_type_utils.is_torch_symbolic_type(node_val):
            raise NotImplementedError(
                "Type promotion does not support node output of sym types."
            )
        elif isinstance(node_val, (list, tuple)):
            raise NotImplementedError(
                "Type promotion does not support node output of list or tuple."
            )
        else:
            raise RuntimeError(f"Unexpected node output type: {type(node_val)}.")

    def _maybe_promote_arg(
        self,
        diagnostic: diagnostics.Diagnostic,
        node: torch.fx.Node,
        fx_arg: torch.fx.node.Argument,
        dtype: torch.dtype | None,
    ) -> torch.fx.node.Argument:
        """Promote fx_arg to dtype if necessary."""
        if dtype is None:
            diagnostic.info(
                "Argument %s is not promoted. Not mentioned by type promotion rule.",
                fx_arg,
            )
            return fx_arg

        if isinstance(fx_arg, torch.fx.Node):
            arg_val = self.env[fx_arg]
            if isinstance(arg_val, torch.Tensor):
                if (old_dtype := arg_val.dtype) != dtype:
                    # Promote tensor to dtype.
                    graph = node.graph
                    with graph.inserting_before(node):
                        diagnostic.info(
                            "Argument %s(%s) is promoted to %s.",
                            fx_arg,
                            old_dtype,
                            dtype,
                        )
                        return self._create_node(
                            graph,
                            "call_function",
                            torch.ops.prims.convert_element_type.default,
                            (fx_arg,),
                            {"dtype": dtype},
                        )
                diagnostic.info(
                    "Argument %s is not promoted. Already %s.", fx_arg, dtype
                )
                return fx_arg
            elif fx_type_utils.is_torch_symbolic_type(arg_val):
                arg_type = type(arg_val)
                equivalent_dtype = fx_type_utils.from_scalar_type_to_torch_dtype(
                    arg_type
                )
                assert equivalent_dtype is not None, f"Unexpected arg_type: {arg_type}"
                if equivalent_dtype != dtype:
                    # Promote Sym number to tensor of dtype.
                    graph = node.graph
                    with graph.inserting_before(node):
                        diagnostic.info(
                            "Argument %s(Scalar of equivalent dtype: %s) "
                            "is promoted to %s.",
                            fx_arg,
                            equivalent_dtype,
                            dtype,
                        )
                        return self._create_node(
                            graph,
                            "call_function",
                            torch.ops.aten.scalar_tensor.default,
                            (fx_arg,),
                            {"dtype": dtype},
                        )
                diagnostic.info(
                    "Argument %s is not promoted. Already %s.", fx_arg, dtype
                )
                return fx_arg
        elif (
            equivalent_dtype := fx_type_utils.from_scalar_type_to_torch_dtype(
                type(fx_arg)
            )
        ) is not None:
            if equivalent_dtype != dtype:
                # Promote number to tensor of dtype.
                # The op should have an overload that supports tensor for this arg, otherwise
                # the type promotion rule should not suggest promoting this arg.
                graph = node.graph
                with graph.inserting_before(node):
                    diagnostic.info(
                        "Argument %s(Scalar of equivalent dtype: %s) "
                        "is promoted to %s.",
                        fx_arg,
                        equivalent_dtype,
                        dtype,
                    )
                    return self._create_node(
                        graph,
                        "call_function",
                        torch.ops.aten.scalar_tensor.default,
                        (fx_arg,),
                        {"dtype": dtype},
                    )
            diagnostic.info("Argument %s is not promoted. Already %s.", fx_arg, dtype)
            return fx_arg
        elif isinstance(fx_arg, (tuple, list)):
            diagnostic.info(
                "Argument %s is a tuple/list. Promoting each element.", fx_arg
            )
            return type(fx_arg)(
                self._maybe_promote_arg(diagnostic, node, fx_arg_elem, dtype)
                for fx_arg_elem in fx_arg
            )

        raise NotImplementedError(f"Unknown fx arg type: {type(fx_arg)}")

    def _maybe_promote_node(
        self,
        diagnostic: diagnostics.Diagnostic,
        node: torch.fx.Node,
        rule: TypePromotionRule,
    ) -> torch.fx.Node:
        """Promote node inputs and outputs according to type promotion rule."""
        args, kwargs = self.fetch_args_kwargs_from_env(node)
        type_promotion_info = rule.preview_type_promotion(args, kwargs)
        new_args = []
        new_kwargs = {}
        for i, arg in enumerate(node.args):
            new_args.append(
                self._maybe_promote_arg(
                    diagnostic, node, arg, type_promotion_info.args_dtypes.get(i, None)
                )
            )

        for name, arg in node.kwargs.items():
            new_kwargs[name] = self._maybe_promote_arg(
                diagnostic, node, arg, type_promotion_info.kwargs_dtypes.get(name, None)
            )
        new_args = tuple(new_args)

        if node.args != new_args or node.kwargs != new_kwargs:
            diagnostic.message = f"Applied type promotion for {node}. "
            node.args = new_args
            node.kwargs = new_kwargs
            self._rerun_node_after_type_promotion(
                diagnostic, node, type_promotion_info.out_dtype
            )
        else:
            diagnostic.message = f"Type promotion not needed for {node}. "

        return node

    @diagnostics.diagnose_call(
        rule=diagnostics.rules.fx_node_insert_type_promotion,
        level=diagnostics.levels.NONE,
    )
    def run_node(self, node: torch.fx.Node) -> Any:
        """This method is an override which inserts type promotion nodes as needed.

        For each `call_function` node, an initial check is conducted to determine if a type
        promotion rule is applicable. If a relevant rule exists, type casting nodes are
        introduced for the corresponding arguments. The OpOverload of the node is updated
        to one that accommodates the promoted types. If the output type differs, a type
        casting node is inserted for the output.

        The call `super().run_node(node)` is guaranteed to be invoked for each node.
        In the case of new or modified nodes, the result of `super().run_node(node)` is
        used to update its `node.meta["val"]` value.
        """
        diagnostic = self.diagnostic_context.inflight_diagnostic()
        with self._set_current_node(node):
            if node.op != "call_function":
                diagnostic.message = f"Skipped {node}: not a call_function."
            elif rule := get_type_promotion_rule(
                diagnostic, node, self.type_promotion_table
            ):
                self._maybe_promote_node(diagnostic, node, rule)

        return super().run_node(node)


class InsertTypePromotion(_pass.Transform):
    """Explicitly insert type promotion ops to the graph.

    This class subclasses `_pass.Transform` to provide graph level diagnostic tracking.
    Underneath, the main pass is driven by `_TypePromotionInterpreter`, a subclass of
    `torch.fx.Interpreter` that interprets the fx.Graph and performs the insertion of
    type promotion operations.

    The interpreter is extended with the ability to track diagnostic information for each node.

    By re-running the new and modified nodes using the interpreter, we can update the
    metadata, specifically the fake tensor stored under node.meta["val"], and ensure it
    reflects the latest changes.

    See [FXE0015: fx_node_insert_type_promotion](https://pytorch.org/docs/main/generated/onnx_dynamo_diagnostics_rules/FXE0015%3Afx-node-insert-type-promotion.html) for more details.  # noqa: B950
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        type_promotion_table: TypePromotionTable | None = None,
    ):
        super().__init__(diagnostic_context, module)
        self.interpreter = _TypePromotionInterpreter(
            diagnostic_context, module, type_promotion_table or TypePromotionTable()
        )

    def _fetch_fake_args(
        self,
    ) -> Sequence[
        fake_tensor.FakeTensor
        | float
        | int
        | bool
        | torch.SymInt
        | torch.SymFloat
        | torch.SymBool
        | None
    ]:
        """Fetch fake args from fx graph.

        For each argument, try to fetch fake tensor from the matching placeholder node.
        """
        fake_args = []
        for node in self.module.graph.nodes:
            if node.op == "placeholder":
                try:
                    # Meta value can be torch.Tensor, int, float, bool,
                    # torch.SymInt, torch.SymFloat, torch.SymBool.
                    meta_value = node.meta.get("val", None)
                except RuntimeError as e:
                    if not node.users:
                        # If the placeholder is not used, we can safely ignore it and put
                        # None as placeholder.
                        meta_value = None
                    else:
                        raise RuntimeError(
                            "Cannot fetch symbolic fake args from fx graph. "
                            "InsertTypePromotion pass needs to run with pre-existing fake args; "
                            "otherwise the pass will produce inaccurate dynamic shapes."
                        ) from e

                fake_args.append(meta_value)
        return fake_args

    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        assert not args, (
            "`InsertTypePromotion` deduces symbolic fake arguments from the graph. "
            "It does not accept concrete arguments as input because this pass requires "
            "re-running the graph. When executed with newly faked concrete arguments, "
            "the pass loses the symbolic dynamic shape information."
        )
        assert not kwargs, "`kwargs` is not supported"

        fake_args = self._fetch_fake_args()
        fake_mode = self.fake_mode
        assert fake_mode is not None, "Cannot detect fake_mode."

        with fake_tensor.unset_fake_temporarily(), (
            fake_mode
        ), fx_traceback.preserve_node_meta():
            self.interpreter.run(*fake_args)

        return self.module
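

# A minimal usage sketch (hedged; `diagnostic_context` and `graph_module` are assumed
# to be supplied by the exporter, and the exact `_pass.Transform` entry point may vary):
#
#   table = TypePromotionTable()
#   rule = table.get_rule(torch.ops.aten.add)  # rule registered under "aten.add"
#   promoted_module = InsertTypePromotion(diagnostic_context, graph_module, table).run()
#
# `run()` re-interprets the graph with fake tensors, inserting
# `prims.convert_element_type` / `aten.scalar_tensor` casts wherever a rule applies.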