r"""
This module exposes a TunableOp interface.

Some operations, such as GEMMs, could be implemented using more than one library
or more than one technique. For example, a GEMM could be implemented for CUDA or
ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
hipblaslt libraries allow the user to query for all possible algorithms and then
choose one. How does one know which implementation is the fastest and should be
chosen? That's what TunableOp provides.

Enabling TunableOp and Tuning Separately
========================================

The TunableOp feature is enabled separately from enabling the tuning phase
itself. Enabling TunableOp means that PyTorch will replace any standard
operators with their Tunable implementations. Any call to a TunableOp first
checks whether it has already been tuned for the given operator inputs. If so,
it will immediately call the tuned operation; no further tuning will take place
even when the tuning setting is enabled. Instead, if no tuning result is found
and tuning is enabled, the TunableOp will benchmark every registered
implementation of that operator for the given set of inputs and select the
fastest.
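
Both switches are available from Python through this module. The following is a
minimal sketch (assuming a CUDA- or ROCm-enabled build of PyTorch); the
functions used are the ones defined below::

    import torch.cuda.tunable as tunable

    # Turn the TunableOp feature on and allow new tunings to be recorded.
    tunable.enable(True)
    tunable.tuning_enable(True)

    # ... run the workload; untuned GEMM shapes are benchmarked and cached ...

    # A later run can reuse previously recorded results without re-tuning.
    tunable.enable(True)
    tunable.tuning_enable(False)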

File Input and Output
=====================

The first time any TunableOp is invoked, the internal database of tuned
operations will be prepared by attempting to read the results from the given
file. The default filename is 'tunableop_results.csv'. To support tuning when
multiple GPUs are used across multiple processes, the GPU device ordinal is
automatically inserted into the filename to avoid multiple processes overwriting
the same file.

If tuning is enabled and new tunings are discovered during the course of your
workload, it will also write out to this same filename with all tunings, both
the ones it read in at startup as well as the new ones found at runtime. This
can be used, for example, to build up a tunings file across many workloads by
reusing the same file. The output file is automatically created when the
application terminates. This behavior can be controlled by the C++ and Python
APIs but not the environment variables.

Assuming you specified a filename, you'll end up with a CSV file with contents
like so::

  Validator,PT_VERSION,2.2.0
  Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
  Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
  Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
  GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
  GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033

Note the "Validator" lines. If you change a library version, or ROCm version, or
PyTorch version, TunableOp will detect this and reject the tunings file because
the prior tunings are likely affected by other software changes.

The remaining lines are the tuned solutions for each TunableOp encountered
during your execution. Each line consists of 4 comma-separated fields: operator
name, operator parameters, solution name, and average execution time. The
execution time is an optional field. The CSV file can be edited, but with
caution. For example, the solution name (field 3) can be changed to "Default"
and it will fall back to the original PyTorch untuned implementation. Or, in the
case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
index you can override the solution that TunableOp selected by replacing the
value. The operator name and parameters (fields 1 and 2) are internally named
and should not be modified. In the case of GemmTunableOp, field 1 indicates the
datatype and whether the inputs are transposed (T) or not (N) and field 2
indicates the M, N, K input shapes.
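
The results filename and explicit reads/writes can also be controlled from
Python. A brief sketch, using an illustrative filename; ``set_filename``,
``read_file`` and ``write_file`` are defined below::

    import torch.cuda.tunable as tunable

    # Insert the device ordinal into the name so that each process in a
    # one-process-per-GPU job writes to its own file.
    tunable.set_filename("my_tunings.csv", insert_device_ordinal=True)

    # Optionally load previously recorded results before the workload runs ...
    tunable.read_file()

    # ... and flush the current results to disk without waiting for exit.
    tunable.write_file()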

There is an option to enable verbose output but it is only recommended for
debugging purposes. This will produce a lot of diagnostic messages but may be
useful to see if TunableOp is being used at all. Otherwise, TunableOp is
completely silent, besides file output, unless there is a warning or error
during its use. The verbose option is only available by setting the environment
variable PYTORCH_TUNABLEOP_VERBOSE=1.

A Note on Tuning Behavior
=========================

Tuning an operator consists of iterating through the list of registered
implementations and profiling each one. The profile is established by running a
single implementation in a loop multiple times and taking the average execution
time.

By default, each possible solution for a given operator will be run for either
100 iterations or as many iterations as can be run within 30ms, whichever is
smaller, and its average execution time will be calculated. The fastest solution
among all that were successfully profiled will be chosen. A profile might fail
if the given solution doesn't achieve the same accuracy as the default
implementation or if the solution returns an error code.
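
Both limits can be adjusted through this module; the values below are
illustrative only, not recommendations::

    import torch.cuda.tunable as tunable

    # Profile each candidate solution for at most 50 iterations ...
    tunable.set_max_tuning_iterations(50)

    # ... or at most 20 milliseconds, whichever limit is reached first.
    # At minimum one tuning iteration is always run.
    tunable.set_max_tuning_duration(20)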

Current Tunable Operators
=========================

TunableGemm for ROCm
--------------------

Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
PyTorch will function correctly when using TunableOp but the only solution
available to CUDA builds is the 'Default' implementation, i.e. the original
cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
given set of input arguments (transa, transb, m, n, k) will attempt to use the
fastest available implementation across both rocblas and hipblaslt.

Tuning Context
==============

The behavior of TunableOp is currently manipulated through environment
variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
torch.cuda.tunable python interfaces that wrap the C++ TuningContext. The
environment variables take precedence over any setting you manipulate using the
C++ or Python APIs.
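
The getter functions defined below report the settings currently in effect,
which can help confirm whether an environment variable is overriding a value
set from Python. A small sketch::

    import torch.cuda.tunable as tunable

    print(tunable.is_enabled())                 # is the TunableOp feature on?
    print(tunable.tuning_is_enabled())          # will new tunings be recorded?
    print(tunable.get_filename())               # results file in use
    print(tunable.get_max_tuning_iterations())  # per-solution iteration limit
    print(tunable.get_max_tuning_duration())    # per-solution time limit (ms)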

"""
from typing import Optional, Tuple

import torch


__all__ = [
    "enable",
    "is_enabled",
    "tuning_enable",
    "tuning_is_enabled",
    "set_max_tuning_duration",
    "get_max_tuning_duration",
    "set_max_tuning_iterations",
    "get_max_tuning_iterations",
    "set_filename",
    "get_filename",
    "get_results",
    "get_validators",
    "write_file_on_exit",
    "write_file",
    "read_file",
]


def enable(val: bool = True) -> None:
    r"""This is the big on/off switch for all TunableOp implementations."""
    torch._C._cuda_tunableop_enable(val)  # type: ignore[attr-defined]


def is_enabled() -> bool:
    r"""Returns whether the TunableOp feature is enabled."""
    return torch._C._cuda_tunableop_is_enabled()  # type: ignore[attr-defined]


def tuning_enable(val: bool = True) -> None:
    r"""Enable tuning of TunableOp implementations.

    When enabled, if a tuned entry isn't found, run the tuning step and record
    the entry.
    """
    torch._C._cuda_tunableop_tuning_enable(val)  # type: ignore[attr-defined]


def tuning_is_enabled() -> bool:
    r"""Returns whether TunableOp implementations can be tuned."""
    return torch._C._cuda_tunableop_tuning_is_enabled()  # type: ignore[attr-defined]


def set_max_tuning_duration(duration: int) -> None:
    r"""Set max time in milliseconds to spend tuning a given solution.

    If both max tuning duration and iterations are set, the smaller of the two
    will be honored. At minimum 1 tuning iteration will always be run.
    """
    torch._C._cuda_tunableop_set_max_tuning_duration(duration)  # type: ignore[attr-defined]


def get_max_tuning_duration() -> int:
    r"""Get max time to spend tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_duration()  # type: ignore[attr-defined]


def set_max_tuning_iterations(iterations: int) -> None:
    r"""Set max number of iterations to spend tuning a given solution.

    If both max tuning duration and iterations are set, the smaller of the two
    will be honored. At minimum 1 tuning iteration will always be run.
    """
    torch._C._cuda_tunableop_set_max_tuning_iterations(iterations)  # type: ignore[attr-defined]


def get_max_tuning_iterations() -> int:
    r"""Get max iterations to spend tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_iterations()  # type: ignore[attr-defined]


def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
    r"""Set the filename to use for input/output of tuning results.

    If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal
    will be added to the given filename automatically. This can be used in a
    1-process-per-gpu scenario to ensure all processes write to a separate file.
    """
    torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal)  # type: ignore[attr-defined]


def get_filename() -> str:
    r"""Get the results filename."""
    return torch._C._cuda_tunableop_get_filename()  # type: ignore[attr-defined]


def get_results() -> Tuple[str, str, str, float]:
    r"""Return all TunableOp results."""
    return torch._C._cuda_tunableop_get_results()  # type: ignore[attr-defined]


def get_validators() -> Tuple[str, str]:
    r"""Return the TunableOp validators."""
    return torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]


def write_file_on_exit(val: bool) -> None:
    r"""During Tuning Context destruction, write file to disk.

    This is useful as a final flush of your results to disk if your application
    terminates as a result of normal operation or an error. Manual flushing of
    your results can be achieved by manually calling ``write_file()``."""
    torch._C._cuda_tunableop_write_file_on_exit(val)  # type: ignore[attr-defined]


def write_file(filename: Optional[str] = None) -> bool:
    r"""Write results to a CSV file.

    If :attr:`filename` is not given, ``get_filename()`` is called.
    """
    if filename is None:
        filename = get_filename()
    return torch._C._cuda_tunableop_write_file(filename)  # type: ignore[attr-defined]


def read_file(filename: Optional[str] = None) -> bool:
    r"""Read results from a TunableOp CSV file.

    If :attr:`filename` is not given, ``get_filename()`` is called.
    """
    if filename is None:
        filename = get_filename()
    return torch._C._cuda_tunableop_read_file(filename)  # type: ignore[attr-defined]