1# mypy: allow-untyped-defs 2import inspect 3from collections import defaultdict 4from functools import wraps 5from itertools import chain 6from typing import Callable, Dict, List, Sequence, TypeVar, Union 7from typing_extensions import ParamSpec 8 9import torch 10import torch.library 11from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket 12from torch._prims_common import CustomOutParamAnnotation 13from torch.utils import _pytree as pytree 14 15 16__all__ = [ 17 "decomposition_table", 18 "pre_autograd_decomposition_table", 19 "meta_table", 20 "register_decomposition", 21 "get_decompositions", 22 "core_aten_decompositions", 23] 24 25_T = TypeVar("_T") 26_P = ParamSpec("_P") 27 28# TODO: relax key type here; torch registrations should be possible to; but 29# right now this type is accurate 30global_decomposition_table: Dict[ 31 str, Dict[torch._ops.OperatorBase, Callable] 32] = defaultdict(dict) 33 34decomposition_table = global_decomposition_table["post_autograd"] 35pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"] 36meta_table = global_decomposition_table["meta"] 37 38 39def _add_op_to_registry(registry, op, fn): 40 """ 41 This is an internal API for adding an op to the decomposition table. 42 43 If op is OpOverload, it will be added to the registry directly. 44 If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry. 45 """ 46 overloads: List[Union[torch._ops.OperatorBase]] = [] 47 if isinstance(op, HigherOrderOperator): 48 # There's no concept of overloads for HigherOrderOperator 49 registry[op] = fn 50 return 51 elif isinstance(op, OpOverload): 52 overloads.append(op) 53 else: 54 assert isinstance(op, OpOverloadPacket) 55 for ol in op.overloads(): 56 overloads.append(getattr(op, ol)) 57 58 for op_overload in overloads: 59 if op_overload in registry: 60 raise RuntimeError(f"duplicate registrations for {op_overload}") 61 # TorchScript dumps a bunch of extra nonsense overloads 62 # which don't have corresponding dispatcher entries, we need 63 # to filter those out, e.g aten.add.float_int 64 if torch._C._dispatch_has_kernel(op_overload.name()): 65 registry[op_overload] = fn 66 67 68def _convert_out_params(f): 69 out_annotation = f.__annotations__.get("out") 70 71 # If there are no out params, do not wrap the function. 72 if not out_annotation: 73 return f 74 75 # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this 76 if getattr(out_annotation, "__origin__", None) is tuple: 77 sig = inspect.signature(f) 78 out_names = sig.return_annotation._fields 79 # If out is a tuple, we need to register a function that unpacks all the out 80 # elements as this is what native_functions.yaml expects 81 82 @wraps(f) 83 def _fn(*args, **kwargs): 84 out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) 85 # Either all of the out kwargs are set or none of them 86 is_none = out_kwargs[0] is None 87 assert all((o is None) == is_none for o in out_kwargs) 88 return f(*args, **kwargs, out=None if is_none else out_kwargs) 89 90 out_params = [ 91 inspect.Parameter( 92 o, 93 kind=inspect.Parameter.KEYWORD_ONLY, 94 default=None, 95 annotation=t, 96 ) 97 for o, t in zip(out_names, out_annotation.__args__) 98 ] 99 # Drop the out parameter and concatenate the new kwargs in the signature 100 params = chain((v for k, v in sig.parameters.items() if k != "out"), out_params) 101 _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] 102 parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] 103 ) 104 # Drop the out parameter and concatenate the new kwargs in the annotations 105 _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} 106 for o in out_params: 107 _fn.__annotations__[o.name] = o.annotation 108 109 # Propagate that this function is wrapped by `out_wrapper` 110 _fn._torch_decompositions_out_wrapper = f._torch_decompositions_out_wrapper # type: ignore[attr-defined] 111 112 return _fn 113 114 # Alternatively, there may be a single tensor out parameter with a name 115 # other than "out". This will need special treatment and is indicated by an 116 # annotation, which we will remove here so it is not exposed after wrapping. 117 custom_out_param_name = f.__annotations__.pop(CustomOutParamAnnotation, None) 118 if custom_out_param_name: 119 120 @wraps(f) 121 def _fn(*args, **kwargs): 122 out_kwarg = kwargs.pop(custom_out_param_name, None) 123 return f(*args, **kwargs, out=out_kwarg) 124 125 out_param = inspect.Parameter( 126 custom_out_param_name, 127 kind=inspect.Parameter.KEYWORD_ONLY, 128 default=None, 129 annotation=out_annotation, 130 ) 131 132 # Drop the out parameter and concatenate the new kwarg in the signature 133 sig = inspect.signature(f) 134 params = chain( 135 (v for k, v in sig.parameters.items() if k != "out"), (out_param,) 136 ) 137 _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] 138 parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] 139 ) 140 141 # Drop the out parameter and concatenate the new kwargs in the annotations 142 _fn.__annotations__ = {k: v for k, v in f.__annotations__.items() if k != "out"} 143 _fn.__annotations__[out_param.name] = out_param.annotation 144 145 return _fn 146 147 return f 148 149 150def register_decomposition( 151 aten_op, registry=None, *, type="post_autograd", unsafe=False 152) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: 153 """ 154 A decorator to register a function as a decomposition to the Python 155 decomposition table. Use it like this:: 156 157 @register_decomposition(torch.ops.aten.clamp_min) 158 def clamp_min(x): 159 return torch.clamp(self, min=min) 160 161 If you are writing a new decomposition, consider contributing it 162 directly to PyTorch in torch._decomp.decompositions. 163 164 This API is experimental; we are almost certainly going to extend 165 the API when we make decompositions eligible for use in transforms (e.g., 166 autograd) and not just backend tracing, where we then need to know if a 167 decomposition can be used to simulate a transform. 168 169 By default, we also will register it to the Meta key of dispatcher, 170 and replace the c++ Meta implementation if there is already one. 171 172 unsafe kwarg is for reuse of this function for registering non-function 173 things 174 """ 175 176 assert type in {"post_autograd", "pre_autograd", "meta"} 177 178 def decomposition_decorator(fn: Callable[_P, _T]) -> Callable[_P, _T]: 179 orig_fn = fn 180 if not unsafe: 181 fn = _convert_out_params(fn) 182 183 nonlocal registry 184 if registry is None: 185 registry = global_decomposition_table[type] 186 187 def register(op): 188 _add_op_to_registry(registry, op, fn) 189 190 # To handle allowing multiple aten_ops at once 191 pytree.tree_map_(register, aten_op) 192 return orig_fn 193 194 return decomposition_decorator 195 196 197def get_decompositions( 198 aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]], 199 type: str = "post_autograd", 200) -> Dict[torch._ops.OperatorBase, Callable]: 201 """ 202 Retrieve a dictionary of decompositions corresponding to the list of 203 operator overloads and overload packets passed as input. Overload 204 packets will include all decomposed overloads in the packet. If there is 205 no decomposition for a requested operator, it is silently ignored. 206 207 This API is experimental; we are almost certainly going to give an alternate, 208 more recommended formulation, where a user provides the set of operators 209 they know how to implement, and we provide decompositions for everything 210 not in this set. 211 """ 212 assert type in {"post_autograd", "pre_autograd", "meta"} 213 214 registry = global_decomposition_table[type] 215 packets_to_overloads = defaultdict(list) 216 for opo in registry: 217 if isinstance(opo, (OpOverload, OpOverloadPacket)): 218 packets_to_overloads[opo.overloadpacket].append(opo) 219 decompositions: Dict[torch._ops.OperatorBase, Callable] = {} 220 for op in aten_ops: 221 if isinstance(op, OpOverloadPacket) and op in packets_to_overloads: 222 for op_overload in packets_to_overloads[op]: 223 decompositions[op_overload] = registry[op_overload] 224 elif isinstance(op, (torch._ops.OperatorBase)) and op in registry: 225 decompositions[op] = registry[op] 226 return decompositions 227 228 229def remove_decompositions( 230 decompositions: Dict[torch._ops.OperatorBase, Callable], 231 aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]], 232) -> None: 233 """ 234 Given a dictionary of decompositions obtained from get_decompositions(), removes 235 operators associated with a list of operator overloads and overload packets passed 236 as input. If the decomposition dictionary does not contain a decomposition that is 237 specified to be removed, it is silently ignored. 238 """ 239 for op in aten_ops: 240 if isinstance(op, OpOverloadPacket): 241 for overload_name in op.overloads(): 242 opo = getattr(op, overload_name) 243 decompositions.pop(opo, None) 244 elif isinstance(op, OpOverload): 245 decompositions.pop(op, None) 246 247 248# populate the table 249import torch._decomp.decompositions 250import torch._refs 251 252 253# See NOTE [Core ATen Ops] 254# 255# list was copied from torch/_inductor/decomposition.py 256# excluding decompositions that results in prim ops 257# Resulting opset of decomposition is core aten ops 258def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: 259 aten = torch.ops.aten 260 return get_decompositions( 261 [ 262 aten.addcdiv, 263 aten.addcdiv_, 264 aten.addcmul, 265 aten.addcmul_, 266 aten.addr, 267 aten.affine_grid_generator, 268 aten.alias_copy, 269 aten.all, 270 aten.aminmax, 271 aten.arange.default, 272 aten.arange.start, 273 aten.avg_pool2d_backward, 274 aten.baddbmm, 275 aten.binary_cross_entropy, 276 aten.binary_cross_entropy_backward, 277 aten.binary_cross_entropy_with_logits, 278 aten.block_diag, 279 aten.celu, 280 aten.celu_, 281 aten.channel_shuffle, 282 aten.clamp_max, 283 aten.clamp_min, 284 aten.col2im, 285 aten.count_nonzero, 286 aten.linalg_cross, 287 aten.cudnn_batch_norm, 288 aten.cudnn_batch_norm_backward, 289 aten.miopen_batch_norm_backward, 290 aten.deg2rad, 291 aten.deg2rad_, 292 aten.detach, 293 aten.diag_embed, 294 aten.diagonal_backward, 295 aten.dot, 296 aten.vdot, 297 aten.elu, 298 aten.elu_, 299 aten.elu_backward, 300 aten._embedding_bag, 301 aten.embedding_dense_backward, 302 aten.empty_like, 303 aten._euclidean_dist.default, 304 aten.expand_as, 305 aten.expand_copy, 306 aten.eye, 307 aten.fill, 308 aten.fill_, 309 aten.floor_divide, 310 aten.frac, 311 aten.frac_, 312 aten._fused_moving_avg_obs_fq_helper, 313 aten.gelu_, 314 aten.gelu_backward, 315 aten.glu, 316 aten.glu_backward, 317 aten.hardshrink, 318 aten.hardsigmoid, 319 aten.hardsigmoid_, 320 aten.hardsigmoid_backward, 321 aten.hardswish, 322 aten.hardswish_, 323 aten.hardswish_backward, 324 aten.hardtanh_, 325 aten.hardtanh_backward, 326 aten.heaviside, 327 aten.heaviside_, 328 aten.huber_loss, 329 aten.huber_loss_backward, 330 aten.im2col, 331 aten.index_add, 332 aten.index_add_, 333 aten.index_copy, 334 aten.index_copy_, 335 aten.index_fill, 336 aten.index_fill_, 337 aten.isin, 338 aten.isneginf, 339 aten.isposinf, 340 aten.l1_loss, 341 aten._lazy_clone, 342 aten._test_parallel_materialize, 343 aten.leaky_relu_, 344 aten.leaky_relu_backward, 345 aten.lerp, 346 aten.lerp_, 347 aten.linspace, 348 aten.logaddexp, 349 aten.logaddexp2, 350 aten.logit, 351 aten.logit_, 352 aten.logit_backward, 353 aten.log_sigmoid_backward, 354 aten.log_sigmoid_forward, 355 aten._log_softmax_backward_data, 356 aten.logspace, 357 aten.logsumexp.default, 358 aten.masked_fill, 359 aten.masked_fill_, 360 aten.mish, 361 aten.mish_, 362 aten.mse_loss, 363 aten.mse_loss_backward, 364 aten.multi_margin_loss, 365 aten.multilabel_margin_loss_forward, 366 aten.mv, 367 aten.mvlgamma, 368 aten.mvlgamma_, 369 aten.nansum, 370 aten.nan_to_num, 371 aten.nan_to_num_, 372 aten.narrow, 373 aten.native_batch_norm_backward, 374 aten.native_dropout_backward, 375 aten.native_group_norm_backward, 376 aten.native_layer_norm_backward, 377 aten.new_empty, 378 aten.new_full, 379 aten.new_ones, 380 aten.new_zeros, 381 aten.nll_loss2d_forward, 382 aten.nll_loss2d_backward, 383 aten.nll_loss_backward, 384 aten.nll_loss_forward, 385 aten.norm, 386 aten.ones, 387 aten.ones_like, 388 aten.pixel_shuffle, 389 aten.pixel_unshuffle, 390 aten._prelu_kernel, 391 aten._prelu_kernel_backward, 392 aten._reshape_alias, 393 aten.rad2deg, 394 aten.rad2deg_, 395 aten.reflection_pad1d, 396 aten.reflection_pad1d_backward, 397 aten.reflection_pad2d, 398 aten.reflection_pad2d_backward, 399 aten.reflection_pad3d, 400 aten.reflection_pad3d_backward, 401 aten.replication_pad1d, 402 aten.replication_pad2d, 403 aten.replication_pad3d, 404 aten.renorm, 405 aten.renorm_, 406 aten.replication_pad2d, 407 aten.resize_as, 408 aten.roll, 409 aten.rot90, 410 aten.rrelu_with_noise, 411 aten.rrelu_with_noise_, 412 aten.rsub, 413 aten._safe_softmax, 414 aten._scaled_dot_product_flash_attention_for_cpu.default, 415 aten.select_backward, 416 aten.select_scatter, 417 aten.sgn, 418 aten.sgn_, 419 aten.sigmoid_backward, 420 aten.silu, 421 aten.silu_, 422 aten.silu_backward, 423 aten.sinc, 424 aten.sinc_, 425 aten.slice_backward, 426 aten.smooth_l1_loss, 427 aten.smooth_l1_loss_backward, 428 aten.soft_margin_loss, 429 aten.soft_margin_loss_backward, 430 aten._softmax_backward_data, 431 aten.softplus, 432 aten.softplus_backward, 433 aten.softshrink, 434 aten.special_entr, 435 aten.special_log_ndtr, 436 aten.special_xlog1py, 437 aten.split.Tensor, 438 aten.split_with_sizes_copy, 439 aten.squeeze.default, 440 aten.squeeze.dim, 441 aten.std, 442 aten.std_mean, 443 aten.stack, 444 aten.sum.default, 445 aten.sum.out, 446 aten.t, 447 aten.t_copy, 448 aten.take, 449 aten.tanh_backward, 450 aten.threshold, 451 aten.threshold_, 452 aten.threshold_backward, 453 aten.trace, 454 aten.transpose.int, 455 aten.tril, 456 aten.tril_, 457 aten.triu, 458 aten.triu_, 459 aten.unbind, 460 aten.unfold_backward, 461 aten.unfold_copy, 462 aten._unsafe_index, 463 aten._unsafe_index_put, 464 aten._unsafe_masked_index, 465 aten._unsafe_masked_index_put_accumulate, 466 aten.unsafe_split.Tensor, 467 aten.unsafe_split_with_sizes, 468 aten.unsqueeze_copy, 469 aten._unsafe_view, 470 aten.upsample_linear1d, 471 aten.upsample_bilinear2d, 472 aten.upsample_trilinear3d, 473 aten.upsample_nearest2d_backward, 474 aten.view_as_complex, 475 aten.xlogy, 476 aten.xlogy_, 477 aten.zero, 478 aten.zero_, 479 aten.zeros, 480 aten.zeros_like, 481 aten._chunk_cat, 482 aten._weight_norm_interface, 483 ] 484 ) 485