r"""
This module exposes a TunableOp interface.

Some operations, such as GEMMs, could be implemented using more than one library
or more than one technique. For example, a GEMM could be implemented for CUDA or
ROCm using either the blas or blasLt libraries. Further, ROCm's rocblas and
hipblaslt libraries allow the user to query for all possible algorithms and then
choose one. How does one know which implementation is the fastest and should be
chosen? That's what TunableOp provides.

Enabling TunableOp and Tuning Separately
========================================

The TunableOp feature is enabled separately from enabling the tuning phase
itself. Enabling TunableOp means that PyTorch will replace any standard
operators with their Tunable implementations. Any call to a TunableOp first
checks whether it has already been tuned for the given operator inputs. If so,
it will immediately call the tuned operation; no further tuning will take place
even when the tuning setting is enabled. Instead, if no tuning result is found
and tuning is enabled, the TunableOp will benchmark every registered
implementation of that operator for the given set of inputs and select the
fastest.
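
For example, a minimal sketch of the two switches using the Python API in this
module, assuming a GPU is available (the tensor shapes are arbitrary)::

  import torch
  import torch.cuda.tunable as tunable

  tunable.enable(True)         # turn the TunableOp feature on
  tunable.tuning_enable(True)  # also allow new tunings to be collected

  a = torch.randn(1024, 512, device="cuda")
  b = torch.randn(512, 256, device="cuda")
  c = a @ b  # a GEMM; tuned on first encounter, the tuned result is reused afterwards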

File Input and Output
=====================

The first time any TunableOp is invoked, the internal database of tuned
operations will be prepared by attempting to read the results from the given
file. The default filename is 'tunableop_results.csv'. To support tuning when
multiple GPUs are used across multiple processes, the GPU device ordinal is
automatically inserted into the filename to avoid multiple processes overwriting
the same file.

If tuning is enabled and new tunings are discovered during the course of your
workload, the same file will also be written out with all tunings, both the ones
read in at startup and the new ones found at runtime. This can be used, for
example, to build up a tunings file across many workloads by reusing the same
file. The output file is automatically created when the application terminates.
This behavior can be controlled by the C++ and Python APIs but not the
environment variables.
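
A short sketch of controlling the file behavior from this module's Python API,
assuming a one-process-per-GPU setup (the filename is only an example)::

  import torch.cuda.tunable as tunable

  # insert the current device ordinal into the filename so each process
  # writes its own file
  tunable.set_filename("my_tunings.csv", insert_device_ordinal=True)

  tunable.write_file_on_exit(True)  # flush results when the TuningContext is destroyed
  # ... run your workload ...
  tunable.write_file()              # or flush manually at any point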

Assuming you specified a filename, you'll end up with a CSV file with contents
like so::

  Validator,PT_VERSION,2.2.0
  Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
  Validator,HIPBLASLT_VERSION,0.6.0-a9c5cc7
  Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
  GemmTunableOp_float_NT,nt_25088_4096_64,1219,1.262
  GemmTunableOp_float_NT,nt_4096_4096_64,1216,0.033

Note the "Validator" lines. If you change a library version, ROCm version, or
PyTorch version, TunableOp will detect this and reject the tunings file because
the prior tunings are likely affected by other software changes.

The remaining lines are the tuned solutions for each TunableOp encountered
during your execution. Each line consists of 4 comma-separated fields: operator
name, operator parameters, solution name, and average execution time. The
execution time is an optional field. The CSV file can be edited, but with
caution. For example, the solution name (field 3) can be changed to "Default"
and it will fall back to the original PyTorch untuned implementation. Or, in the
case of ROCm's hipBLAS or hipBLASLt libraries, if you know the specific solution
index you can override the solution that TunableOp selected by replacing the
value. The operator name and parameters (fields 1 and 2) are internally named
and should not be modified. In the case of GemmTunableOp, field 1 indicates the
datatype and whether the inputs are transposed (T) or not (N), and field 2
indicates the M, N, K input shapes.
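
The same information can also be inspected from Python via the accessors in
this module; a brief sketch::

  import torch.cuda.tunable as tunable

  # print whatever validators and results the TuningContext currently holds
  for validator in tunable.get_validators():
      print(validator)
  for result in tunable.get_results():
      print(result)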

There is an option to enable verbose output but it is only recommended for
debugging purposes. This will produce a lot of diagnostic messages but may be
useful to see if TunableOp is being used at all. Otherwise, TunableOp is
completely silent, besides file output, unless there is a warning or error
during its use. The verbose option is only available by setting the environment
variable PYTORCH_TUNABLEOP_VERBOSE=1.

A Note on Tuning Behavior
=========================

Tuning an operator consists of iterating through the list of registered
implementations and profiling each one. The profile is established by running a
single implementation in a loop multiple times and taking the average execution
time.

By default, each possible solution for a given operator will be run for either
100 iterations or as many iterations as can be run within 30ms, whichever is
smaller, and its average execution time will be calculated. The fastest solution
among all that were successfully profiled will be chosen. A profile might fail
if the given solution doesn't achieve the same accuracy as the default
implementation or if the solution returns an error code.
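
These limits can be adjusted through the Python API below; for example, a
sketch that allows a longer search per solution (the values shown are
arbitrary)::

  import torch.cuda.tunable as tunable

  tunable.set_max_tuning_iterations(1000)  # up to 1000 iterations per solution
  tunable.set_max_tuning_duration(100)     # but no more than 100ms per solution
  print(tunable.get_max_tuning_iterations(), tunable.get_max_tuning_duration())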

Current Tunable Operators
=========================

TunableGemm for ROCm
--------------------

Currently only a TunableGemm for ROCm is implemented. Note that CUDA builds of
PyTorch will function correctly when using TunableOp, but the only solution
available to CUDA builds is the 'Default' implementation, i.e. the original
cuBLAS default, now called through TunableOp. Any call to at::cuda::blas::gemm()
or ::bgemm() will be routed through TunableOp when enabled. Calling gemm() for a
given set of input arguments (transa, transb, m, n, k) will attempt to use the
fastest available implementation across both rocblas and hipblaslt.
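
From Python there is nothing extra to call once TunableOp is enabled; ordinary
GEMM-backed operations simply reach the tunable path. A small sketch, assuming
a GPU build with TunableOp available (shapes are arbitrary)::

  import torch
  import torch.cuda.tunable as tunable

  tunable.enable(True)
  tunable.tuning_enable(True)

  x = torch.randn(8, 128, 64, device="cuda")
  y = torch.randn(8, 64, 32, device="cuda")
  z = torch.bmm(x, y)                     # batched matmul, typically reaching ::bgemm()
  w = torch.nn.Linear(32, 16).cuda()(z)   # linear layers also bottom out in gemm()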

Tuning Context
==============

The behavior of TunableOp is currently manipulated through environment
variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the
torch.cuda.tunable Python interfaces that wrap the C++ TuningContext. The
environment variables take precedence over any setting you manipulate using the
C++ or Python APIs.
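
Because the environment variables win, a common pattern is to set them before
any CUDA work starts (equivalently, export them in the shell before launching
the process). A sketch from Python; the variable names other than
PYTORCH_TUNABLEOP_VERBOSE are assumptions following the same naming scheme and
should be verified against your PyTorch build::

  import os

  # assumed names, following the PYTORCH_TUNABLEOP_* scheme; set before torch touches CUDA
  os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"  # same effect as enable(True)
  os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"   # same effect as tuning_enable(True)
  os.environ["PYTORCH_TUNABLEOP_VERBOSE"] = "1"  # diagnostic output, see above

  import torch  # import after the environment is prepared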

"""
from typing import Optional, Tuple

import torch


__all__ = [
    "enable",
    "is_enabled",
    "tuning_enable",
    "tuning_is_enabled",
    "set_max_tuning_duration",
    "get_max_tuning_duration",
    "set_max_tuning_iterations",
    "get_max_tuning_iterations",
    "set_filename",
    "get_filename",
    "get_results",
    "get_validators",
    "write_file_on_exit",
    "write_file",
    "read_file",
]


def enable(val: bool = True) -> None:
    r"""This is the big on/off switch for all TunableOp implementations."""
    torch._C._cuda_tunableop_enable(val)  # type: ignore[attr-defined]


def is_enabled() -> bool:
    r"""Returns whether the TunableOp feature is enabled."""
    return torch._C._cuda_tunableop_is_enabled()  # type: ignore[attr-defined]


def tuning_enable(val: bool = True) -> None:
    r"""Enable tuning of TunableOp implementations.

    When enabled, if a tuned entry isn't found, run the tuning step and record
    the entry.
    """
    torch._C._cuda_tunableop_tuning_enable(val)  # type: ignore[attr-defined]


def tuning_is_enabled() -> bool:
    r"""Returns whether TunableOp implementations can be tuned."""
    return torch._C._cuda_tunableop_tuning_is_enabled()  # type: ignore[attr-defined]


def set_max_tuning_duration(duration: int) -> None:
    r"""Set max time in milliseconds to spend tuning a given solution.

    If both max tuning duration and iterations are set, the smaller of the two
    will be honored. At minimum 1 tuning iteration will always be run.
    """
    torch._C._cuda_tunableop_set_max_tuning_duration(duration)  # type: ignore[attr-defined]


def get_max_tuning_duration() -> int:
    r"""Get max time to spend tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_duration()  # type: ignore[attr-defined]


def set_max_tuning_iterations(iterations: int) -> None:
    r"""Set max number of iterations to spend tuning a given solution.

    If both max tuning duration and iterations are set, the smaller of the two
    will be honored. At minimum 1 tuning iteration will always be run.
    """
    torch._C._cuda_tunableop_set_max_tuning_iterations(iterations)  # type: ignore[attr-defined]


def get_max_tuning_iterations() -> int:
    r"""Get max iterations to spend tuning a given solution."""
    return torch._C._cuda_tunableop_get_max_tuning_iterations()  # type: ignore[attr-defined]


def set_filename(filename: str, insert_device_ordinal: bool = False) -> None:
    r"""Set the filename to use for input/output of tuning results.

    If :attr:`insert_device_ordinal` is ``True`` then the current device ordinal
    will be added to the given filename automatically. This can be used in a
    1-process-per-GPU scenario to ensure all processes write to a separate file.
    """
    torch._C._cuda_tunableop_set_filename(filename, insert_device_ordinal)  # type: ignore[attr-defined]


def get_filename() -> str:
    r"""Get the results filename."""
    return torch._C._cuda_tunableop_get_filename()  # type: ignore[attr-defined]


def get_results() -> Tuple[str, str, str, float]:
    r"""Return all TunableOp results."""
    return torch._C._cuda_tunableop_get_results()  # type: ignore[attr-defined]


def get_validators() -> Tuple[str, str]:
    r"""Return the TunableOp validators."""
    return torch._C._cuda_tunableop_get_validators()  # type: ignore[attr-defined]


def write_file_on_exit(val: bool) -> None:
    r"""During Tuning Context destruction, write file to disk.

    This is useful as a final flush of your results to disk if your application
    terminates as a result of normal operation or an error. Manual flushing of
    your results can be achieved by manually calling ``write_file()``."""
    torch._C._cuda_tunableop_write_file_on_exit(val)  # type: ignore[attr-defined]


def write_file(filename: Optional[str] = None) -> bool:
    r"""Write results to a CSV file.

    If :attr:`filename` is not given, ``get_filename()`` is called.
    """
    if filename is None:
        filename = get_filename()
    return torch._C._cuda_tunableop_write_file(filename)  # type: ignore[attr-defined]


def read_file(filename: Optional[str] = None) -> bool:
    r"""Read results from a TunableOp CSV file.

    If :attr:`filename` is not given, ``get_filename()`` is called.
    """
    if filename is None:
        filename = get_filename()
    return torch._C._cuda_tunableop_read_file(filename)  # type: ignore[attr-defined]