# mypy: allow-untyped-defs
import inspect
import warnings
from functools import wraps
from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple, TypeVar
from typing_extensions import ParamSpec

import torch
import torch._prims_common as utils
from torch._prims_common import (
    CustomOutParamAnnotation,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    Number,
    NumberType,
    ShapeType,
    TensorLike,
    TensorLikeType,
)
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_unflatten


_T = TypeVar("_T")
_P = ParamSpec("_P")


@overload
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass


@overload
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass


@overload
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass


@overload
def _maybe_convert_to_dtype(a: None, dtype: torch.dtype) -> None:
    pass


# TODO: implement ref.cast with an option to enforce safe casting
def _maybe_convert_to_dtype(a, dtype):
    if isinstance(a, TensorLike):
        if a.dtype != dtype:
            return a.to(dtype)
        return a
    if isinstance(a, Number):
        return utils.dtype_to_type_ctor(dtype)(a)  # type: ignore[arg-type]
    if isinstance(a, Sequence):
        return tuple(_maybe_convert_to_dtype(x, dtype) for x in a)
    # Pass through None because some functions wrapped with the type promotion
    # wrapper might have optional args
    if a is None:
        return None

    raise ValueError(f"Received type {type(a)} that is neither a tensor nor a number!")


def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType:
    if not isinstance(a, Number):
        msg = f"Found unknown type {type(a)} when trying to convert scalars!"
        raise ValueError(msg)
    if not utils.is_weakly_lesser_type(type(a), typ):
        msg = f"Scalar {a} of type {type(a)} cannot be safely cast to type {typ}!"
        raise ValueError(msg)

    return typ(a)


def _annotation_has_type(*, typ, annotation):
    if hasattr(annotation, "__args__"):
        for a in annotation.__args__:
            if _annotation_has_type(typ=typ, annotation=a):
                return True
        return False

    return typ is annotation


class elementwise_type_promotion_wrapper:
    """
    Adds elementwise type promotion to a Python reference implementation.

    Takes two kwargs, type_promoting_args and type_promotion_kind.

    type_promoting_args must be a string Sequence specifying the argument names of all
    arguments that participate in type promotion (and should be type promoted). If the
    arg specifies a Sequence-type then every element of the Sequence will participate in
    type promotion.

    type_promotion_kind must be one of the kinds specified by ELEMENTWISE_TYPE_PROMOTION_KIND.
    See its documentation for details.

    The computed result dtype is overridden by the wrapped function's dtype argument
    when it is provided and not None.

    Other type promotion behavior, like validating the Python type of scalar arguments, must
    be handled separately.
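
    A minimal usage sketch (the decorated function below is hypothetical and only
    illustrates the expected call pattern):

        @elementwise_type_promotion_wrapper(
            type_promoting_args=("a", "b"),
            type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        )
        def my_add_ref(a, b):
            # a and b arrive here already cast to the common computation dtype;
            # the wrapper casts the returned tensor to the result dtype.
            return torch.add(a, b)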
    """

    def __init__(
        self,
        *,
        type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
        type_promoting_args: Optional[Sequence[str]] = None,
    ):
        self.type_promoting_arg_names = type_promoting_args
        self.type_promotion_kind = type_promotion_kind

    def __call__(self, fn: Callable) -> Callable:
        sig = inspect.signature(fn)

        @wraps(fn)
        def _fn(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            type_promoting_args = tuple(
                bound.arguments[x]
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            )

            flattened_type_promoting_args = pytree.arg_tree_leaves(*type_promoting_args)
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flattened_type_promoting_args,
                type_promotion_kind=self.type_promotion_kind,
            )

            promoted_args = {
                x: _maybe_convert_to_dtype(bound.arguments[x], compute_dtype)
                for x in self.type_promoting_arg_names  # type: ignore[union-attr]
                if x in bound.arguments.keys()
            }
            bound.arguments.update(promoted_args)

            result = fn(**bound.arguments)

            # Override the result_dtype if a dtype arg is present and not None
            if "dtype" in bound.arguments:
                maybe_dtype = bound.arguments["dtype"]
                if maybe_dtype:  # dtype cannot be None
                    result_dtype = maybe_dtype

            if isinstance(result, TensorLike):
                return _maybe_convert_to_dtype(result, result_dtype)
            if isinstance(result, Sequence):
                return tuple(_maybe_convert_to_dtype(x, result_dtype) for x in result)
            raise AssertionError(f"Unhandled result type: {type(result)}")

        _fn.__signature__ = sig  # type: ignore[attr-defined]
        return _fn


# Returns True if resize is necessary
def _resize_output_check(out: TensorLikeType, shape: ShapeType):
    # If the shapes are correct there's nothing to do
    if utils.same_shape(out.shape, shape):
        return False
    if out.numel() != 0:
        msg = (
            f"An output with one or more elements was resized since it had shape {str(out.shape)} "
            f"which does not match the required output shape {str(shape)}. "
            "This behavior is deprecated, and in a future PyTorch release outputs will not "
            "be resized unless they have zero elements. "
            "You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0)."
        )
        warnings.warn(msg)
    return True


# TODO: handle tuples of tensors
def _maybe_resize_out(
    out: TensorLikeType,
    shape: ShapeType,
    memory_format: Optional[torch.memory_format] = None,
):
    if _resize_output_check(out, shape):
        return out.resize_(shape, memory_format=memory_format)
    else:
        return out


def is_cpu_scalar(x: TensorLikeType) -> bool:
    return x.dim() == 0 and x.device.type == "cpu"


def _safe_copy_out(
    *, copy_from: TensorLikeType, copy_to: TensorLikeType, exact_dtype: bool = False
):
    # Checks same device
    if not is_cpu_scalar(copy_from) and copy_from.device != copy_to.device:
        msg = (
            f"Attempting to copy from device {copy_from.device} "
            f"to device {copy_to.device}, but cross-device copies are not allowed!"
        )
        raise RuntimeError(msg)

    # Checks safe cast
    if exact_dtype:
        torch._check(
            copy_from.dtype == copy_to.dtype,
            lambda: f"Expected out tensor to have dtype {copy_from.dtype} "
            f"but got {copy_to.dtype} instead",
        )
    else:
        torch._check(
            utils.can_safe_cast_to(cast_from=copy_from.dtype, cast_to=copy_to.dtype),
            lambda: f"Attempting to cast from {copy_from.dtype} to out tensor with dtype {copy_to.dtype}, "
            "but this can't be cast because it is not safe!",
        )

    return copy_to.copy_(copy_from)


def out_wrapper(
    *out_names: str,
    exact_dtype: bool = False,
    pass_is_out: bool = False,
    preserve_memory_format: bool = False,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    # The wrapped function needs to convert the output parameters to ensure
    # compatibility between the Python API (which always uses "out" as the
    # parameter name, and where out may be a tuple) and the ATen API (which may
    # have multiple output parameters and use different parameter names such as
    # "grad_input", "indices" or "values").

    default_out_names = ("out",)
    if len(out_names) == 0:
        # Use the default "out" name
        out_names = default_out_names

    is_tensor = len(out_names) == 1

    def maybe_compute_memory_format(t):
        return utils.suggest_memory_format(t) if preserve_memory_format else None

    def _out_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
        """
        Adds the out parameter to a Python reference.
        """
        out_type = (
            TensorLikeType
            if is_tensor
            else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
        )
        return_type = (
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
            )
        )

        sig = inspect.signature(fn)
        factory_kwargs = ("device", "dtype")
        is_factory_fn = all(p in sig.parameters for p in factory_kwargs)

        @wraps(fn)
        def _fn(*args: _P.args, out=None, **kwargs: _P.kwargs):
            if is_factory_fn and out is not None:
                for k in factory_kwargs:
                    out_attr = getattr(out, k)
                    if k not in kwargs:
                        kwargs[k] = out_attr
            if pass_is_out:
                result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
            else:
                result = fn(*args, **kwargs)
            assert (
                isinstance(result, TensorLike)
                and is_tensor
                or isinstance(result, Tuple)  # type: ignore[arg-type]
                and len(result) == len(out_names)  # type: ignore[arg-type]
            )
            if out is not None:
                # Naively you might expect this assert to be true, but
                # it's not:
                #
                #   assert type(out) == type(result)
                #
                # The reason is that functions under this wrapper can
                # get registered to the Meta dispatch key, and that
                # means they can be executed in a context where tensor
                # subclasses are disabled (with no_dispatch), which is a
                # handy way for an is-a tensor subclass (e.g.,
                # FakeTensor) to have the normal meta backend create a
                # meta tensor, to be wrapped once it gets returned.
                # In this situation, you will get a FakeTensor as
                # the output tensor, but not as the result, which will
                # be a normal meta tensor; this is perfectly
                # harmless.
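                # Both branches below copy the computed result into the
                # caller-provided out tensor(s), resizing them in place first if
                # their shapes do not match; `out`, not `result`, is returned.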
                if is_tensor:
                    assert isinstance(out, TensorLike)
                    # These two operations are done in-place
                    _maybe_resize_out(
                        out, result.shape, maybe_compute_memory_format(result)  # type: ignore[union-attr]
                    )
                    _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype)  # type: ignore[arg-type]
                else:
                    assert isinstance(out, Tuple)  # type: ignore[arg-type]
                    torch._check_type(
                        len(out) == len(result),  # type: ignore[arg-type]
                        lambda: f"expected tuple of {len(result)} elements but got {len(out)}",  # type: ignore[arg-type]
                    )
                    for r, o in zip(result, out):  # type: ignore[arg-type]
                        # These two operations are done in-place
                        _maybe_resize_out(o, r.shape, maybe_compute_memory_format(r))
                        _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=exact_dtype)  # type: ignore[arg-type]
            else:
                out = result
            # mypy does not see through the definition of out_type given that it's in a different scope
            return out if is_tensor else return_type(*out)  # type: ignore[operator]

        out_param = inspect.Parameter(
            "out",
            kind=inspect.Parameter.KEYWORD_ONLY,
            default=None,
            annotation=out_type,
        )
        # Mark that the function now returns a tuple
        assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
            sig.empty,
            out_type,
        )
        params = *sig.parameters.values(), out_param

        # If there's a Parameter.VAR_KEYWORD parameter (like **kwds), it must appear
        # after the out= parameter, which is Parameter.KEYWORD_ONLY. Sorting by
        # Parameter.kind guarantees that all the parameters are in legal order.
        params = sorted(params, key=lambda p: p.kind)

        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params, return_annotation=return_type  # type: ignore[arg-type]
        )

        _fn.__annotations__ = dict(getattr(fn, "__annotations__", {}))
        _fn.__annotations__["out"] = out_type
        _fn.__annotations__["return"] = return_type

        # In the special case of having a single tensor out parameter with a
        # name other than out, add a special annotation to name the parameter
        if is_tensor and out_names != default_out_names:
            _fn.__annotations__[CustomOutParamAnnotation] = out_names[0]

        # Add an indicator attribute that can be used in special cases
        # where having a function wrapped by `out_wrapper` is not desirable e.g.
        # jit
        _fn._torch_decompositions_out_wrapper = f"This function is wrapped by {out_wrapper.__module__}.out_wrapper"  # type: ignore[attr-defined]

        return _fn

    return _out_wrapper


def _maybe_remove_out_wrapper(fn: Callable):
    return inspect.unwrap(
        fn,
        stop=lambda f: not hasattr(f, "_torch_decompositions_out_wrapper"),
    )


def backwards_not_supported(prim):
    def redispatch_prim(args, kwargs):
        with torch._C._AutoDispatchBelowAutograd():
            old = torch._C._dispatch_tls_is_dispatch_key_excluded(
                torch._C.DispatchKey.ADInplaceOrView
            )
            return prim(*args, **kwargs)

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

        @staticmethod
        def backward(ctx, *args):
            raise RuntimeError("backwards not supported on prim")

    @wraps(prim)
    def _autograd_impl(*args, **kwargs):
        flat_args, args_spec = tree_flatten((args, kwargs))
        if torch.is_grad_enabled() and any(
            a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)
        ):
            # TODO: There is a subtle bug here: prims like copy_to
            # return their input argument after mutating it; and the custom
            # autograd function will incorrectly turn the result into
            # a view which will fail test_python_ref_executor tests.
            # At the moment, we sidestep this by observing that the
            # unit tests don't ever try to run the executor with
            # autograd, so we don't exercise the buggy case, but if
            # you ever want to feed autograd through this, be aware
            # of it! We need a way of properly implementing autograd
            # for mutating operations in Python to do this.
            return BackwardsNotSupported.apply(args_spec, *flat_args)
        else:
            return redispatch_prim(args, kwargs)

    return _autograd_impl


# TODO: when tracing this will add torch tensors and not TensorMeta objects
# to the trace -- we should fix this by adding a tracing context and NumberMeta classes
# TODO: this wrapper is currently untested
def elementwise_unary_scalar_wrapper(fn: Callable) -> Callable:
    """
    Allows unary operators that accept tensors to work with Python numbers.
    """
    sig = inspect.signature(fn)

    @wraps(fn)
    def _fn(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], Number):
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    return _fn
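

# A minimal sketch of how these wrappers are typically composed on a Python
# reference implementation (illustrative only; `_my_mul_ref` is hypothetical
# and not part of this module):
#
#     @out_wrapper()
#     @elementwise_type_promotion_wrapper(
#         type_promoting_args=("a", "b"),
#         type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
#     )
#     def _my_mul_ref(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
#         return torch.mul(a, b)
#
# With this ordering the inner wrapper performs type promotion on `a` and `b`
# and casts the result to the promoted dtype, while out_wrapper adds a
# keyword-only `out=` parameter that resizes the provided tensor (if necessary)
# and copies the promoted result into it.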