1from __future__ import annotations 2 3import contextlib 4import functools 5from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union 6 7import torchgen.local as local 8from torchgen.model import ( 9 BackendIndex, 10 DispatchKey, 11 NativeFunction, 12 NativeFunctionsGroup, 13 NativeFunctionsViewGroup, 14) 15from torchgen.utils import context, S, T 16 17 18# Helper functions for defining generators on things in the model 19 20F = TypeVar( 21 "F", 22 NativeFunction, 23 NativeFunctionsGroup, 24 NativeFunctionsViewGroup, 25 Union[NativeFunction, NativeFunctionsGroup], 26 Union[NativeFunction, NativeFunctionsViewGroup], 27) 28 29F2 = TypeVar( 30 "F2", 31 NativeFunction, 32 NativeFunctionsGroup, 33 Optional[NativeFunction], 34 bool, 35 str, 36) 37 38F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction]) 39 40 41@contextlib.contextmanager 42def native_function_manager( 43 g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction, 44) -> Iterator[None]: 45 if isinstance(g, NativeFunctionsGroup): 46 # By default, we associate all errors with structured native functions 47 # with the out variant. In some cases, it might be better to have 48 # a more specific place to hang things; if so, use 49 # native_function_manager again on the inside 50 f = g.out 51 elif isinstance(g, NativeFunctionsViewGroup): 52 # We associate errors with the view operator 53 f = g.view 54 else: 55 f = g 56 with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"): 57 with local.parametrize( 58 use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors, 59 use_ilistref_for_tensor_lists=f.part_of_structured_group, 60 ): 61 yield 62 63 64# Given a function that operates on NativeFunction, wrap it into a new function 65# that sets some appropriate context managers for that native function. 
# Any function that calls into the api modules MUST be wrapped with one of the
# decorators below. They establish the torchgen.local settings those modules
# read, and reading those settings before they have been set is an error.


def _run_in_context(anchor: Any, func: Callable[..., T], *args: Any) -> T:
    """Call ``func(*args)`` while ``native_function_manager(anchor)`` is active."""
    with native_function_manager(anchor):
        return func(*args)


def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Decorate a one-argument generator so it runs in its argument's context."""

    @functools.wraps(func)
    def decorated(f: F) -> T:
        return _run_in_context(f, func, f)

    return decorated


def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Decorate a two-argument generator; the context comes from ``f`` only.

    The second argument ``f2`` is forwarded untouched and does not influence
    which native function the errors are attributed to.
    """

    @functools.wraps(func)
    def decorated(f: F, f2: F2) -> T:
        return _run_in_context(f, func, f, f2)

    return decorated


def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method form of ``with_native_function``: the context comes from ``f``."""

    @functools.wraps(func)
    def decorated(slf: S, f: F) -> T:
        return _run_in_context(f, func, slf, f)

    return decorated


def method_with_nested_native_function(
    func: Callable[[S, F3], T]
) -> Callable[[S, F3], T]:
    """Method decorator whose context comes from ``f[0]``.

    ``f`` itself (the whole tuple or list) is what gets passed to ``func``.
    """

    @functools.wraps(func)
    def decorated(slf: S, f: F3) -> T:
        return _run_in_context(f[0], func, slf, f)

    return decorated


def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
    """Decorate a generator that receives its BackendIndex as an explicit argument.

    This is the alternative to capturing the BackendIndex in a closure; the
    context still comes from ``f``.
    """

    @functools.wraps(func)
    def decorated(f: F, backend_index: BackendIndex) -> T:
        return _run_in_context(f, func, f, backend_index)

    return decorated


def with_native_function_and_indices(
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
    """Decorate a generator that receives a DispatchKey -> BackendIndex mapping.

    The mapping is forwarded as-is; the context comes from ``f``.
    """

    @functools.wraps(func)
    def decorated(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
        return _run_in_context(f, func, f, backend_indices)

    return decorated