xref: /aosp_15_r20/external/pytorch/torch/_inductor/autoheuristic/autoheuristic_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import functools
2from typing import Any, Callable, Dict, List, Tuple
3
4import torch
5
6
# Type aliases used by the AutoHeuristic utilities below.
Feedback = float  # feedback collected for a choice is a float
Choice = str  # a choice is identified by a string
Value = Any  # a feature value; its type is not restricted

# String constants naming the choice and the feedback.
CHOICE_COL = "choice"
FEEDBACK_COL = "feedback"
13
14
class AHFeature:
    """
    One named entry of the context that AutoHeuristic stores (the context is a list of features).
    Besides its name and value, a feature records whether it is categorical (i.e., not a continuous
    variable); AutoHeuristic needs this information to learn a machine learning model.
    """

    def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
        # Plain attribute container: the arguments are stored unchanged.
        self.name, self.value, self.is_categorical = name, value, is_categorical
25
26
class AHOperation:
    """
    Derives an additional feature from data that AutoHeuristic has already collected.
    If features such as m, k and n are stored, features such as m*n or k*n can be
    computed from them afterwards, so they do not have to be stored themselves.
    An AHOperation bundles the name of such a derived feature with the function
    that computes it.
    """

    def __init__(
        self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
    ) -> None:
        self.is_categorical = is_categorical
        self.func = func
        self.name = name

    def apply_operation(self, data: Any) -> None:
        """Compute the derived value from ``data`` and store it in ``data`` under ``self.name``."""
        derived_value = self.func(data)
        data[self.name] = derived_value
45
46
class AHContext:
    """
    Specifies which information AutoHeuristic should store. For each choice, AutoHeuristic stores the
    context together with the collected feedback. The context could be something like the shape of a
    tensor, i.e., information that helps to learn a heuristic.
    """

    # Features in the order in which they were added.
    features: List[AHFeature]
    # Maps a name to its value; also receives the values written by applied operations.
    context_dict: Dict[str, Value]

    def __init__(self) -> None:
        self.context_dict = {}
        self.features = []

    def add_feature(
        self, name: str, value: Value, is_categorical: bool = False
    ) -> None:
        """Append a new feature and record its value under ``name`` in ``context_dict``."""
        new_feature = AHFeature(name, value, is_categorical=is_categorical)
        self.features.append(new_feature)
        self.context_dict[name] = value

    def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]:
        """Return ``(numerical_names, categorical_names)``, each in insertion order."""
        categorical_names = [f.name for f in self.features if f.is_categorical]
        numerical_names = [f.name for f in self.features if not f.is_categorical]
        return numerical_names, categorical_names

    def get_feature_names_csv(self) -> str:
        """Return the feature names joined by commas."""
        names = [f.name for f in self.features]
        return ",".join(names)

    def get_feature_values_csv(self) -> str:
        """Return the ``str()`` of every feature value joined by commas."""
        values = [str(f.value) for f in self.features]
        return ",".join(values)

    def get_value(self, name: str) -> Value:
        """Look up ``name`` in ``context_dict``; a missing name raises ``KeyError``."""
        stored = self.context_dict
        return stored[name]

    def apply_operations(self, operations: List[AHOperation]) -> None:
        """Apply every operation, in order, to ``context_dict``."""
        for operation in operations:
            operation.apply_operation(self.context_dict)
90
91
class AHMetadata:
    """
    Metadata describing a use of AutoHeuristic: the values used to identify the GPU, the available
    choices and a name.
    """

    def __init__(
        self,
        shared_memory: Any,
        device_capa: Tuple[int, int],
        choices: List[Choice],
        name: str,
    ) -> None:
        # use amount of shared_memory and device_capability to identify GPU
        # TODO(AlnisM): there might be a better way to do this
        self.shared_memory, self.device_capa = shared_memory, device_capa
        self.choices, self.name = choices, name

    def to_dict(self) -> Dict[str, Value]:
        """Return ``shared_memory``, ``device_capa`` and ``name`` as a dict (``choices`` is not included)."""
        return dict(
            shared_memory=self.shared_memory,
            device_capa=self.device_capa,
            name=self.name,
        )
113
114
def get_metadata_str_from_log(log_path: str) -> str:
    """Return the first line of the file at ``log_path`` with surrounding whitespace stripped."""
    with open(log_path, newline="") as log_file:
        first_line = log_file.readline()
    return first_line.strip()
119
120
def check_minsize(context: AHContext, minsize: int) -> bool:
    """Return whether the context values "m", "k" and "n" are all at least ``minsize``."""
    # Checked in the order m, k, n; stops at the first value that is too small.
    return all(context.get_value(dim) >= minsize for dim in ("m", "k", "n"))
127
128
def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """
    Apply a GPU-specific minimum size check to the context. The GPU is identified by the amount of
    shared memory and the device capability in ``metadata``; for any other GPU the result is True.
    """
    # (shared_memory, device_capa, minimum size passed to check_minsize)
    known_gpus = (
        (166912, (8, 0), 512),  # A100 precondition
        (232448, (9, 0), 768),  # H100 precondition
    )
    for shared_memory, device_capa, minsize in known_gpus:
        if metadata.shared_memory == shared_memory and metadata.device_capa == device_capa:
            return check_minsize(context, minsize)
    return True
137
138
def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """
    Return False if m > 128, k < 1024 or n < 1024. Otherwise return
    ``mat1_iscontig and not mat2_iscontig`` using the values stored in the context.
    ``metadata`` is not used.
    """
    m, k, n = (context.get_value(dim) for dim in ("m", "k", "n"))
    m_too_large = m > 128
    k_or_n_too_small = k < 1024 or n < 1024
    if m_too_large or k_or_n_too_small:
        return False
    # The contiguity values are only read once the shape check has passed.
    first_contig = context.get_value("mat1_iscontig")
    second_contig = context.get_value("mat2_iscontig")
    return first_contig and not second_contig
148
149
def get_mult_dims_ops() -> List[AHOperation]:
    """Return the operations "m*k", "m*n" and "k*n", each multiplying two entries of the data."""

    def product_of(lhs: str, rhs: str) -> Callable[[Any], Value]:
        # Bind the two keys now so that each operation multiplies its own pair.
        return lambda data: data[lhs] * data[rhs]

    specs = (("m*k", "m", "k"), ("m*n", "m", "n"), ("k*n", "k", "n"))
    return [AHOperation(op_name, product_of(lhs, rhs)) for op_name, lhs, rhs in specs]
155
156
def get_arith_intensity(data: Any) -> float:
    """
    Return m*k*n / (m*k + k*n + m*n) for the "m", "k" and "n" entries of ``data``,
    or 0.0 if any of the three is 0 (which would otherwise make the denominator 0).
    """
    m, k, n = data["m"], data["k"], data["n"]
    if any(dim == 0 for dim in (m, k, n)):
        return 0.0
    numerator = m * k * n
    denominator = m * k + k * n + m * n
    return numerator / denominator
164
165
def pad_mm_operations() -> List[AHOperation]:
    """
    Return the operations used for pad_mm: the dimension products, "k/(m*n)", "bfloat_perf_hit",
    "arith_intensity", followed by the needs-padding, multiple-of and is-contiguous operations.
    """

    def k_div_m_times_n(data: Any) -> Value:
        # NOTE: no guard against m*n == 0 here, unlike get_arith_intensity.
        return data["k"] / (data["m"] * data["n"])

    def bfloat_perf_hit(data: Any) -> bool:
        m, k, n = data["m"], data["k"], data["n"]
        # The dtype is compared via its string representation.
        is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
        return k > (m * 1024) and k > (n * 1024) and is_bfloat

    return [
        *get_mult_dims_ops(),
        AHOperation("k/(m*n)", k_div_m_times_n),
        AHOperation("bfloat_perf_hit", bfloat_perf_hit, is_categorical=True),
        AHOperation("arith_intensity", get_arith_intensity),
        *get_dims_need_padding_ops(),
        *get_dims_multiple_ops(),
        *get_is_contig_ops(),
    ]
197
198
def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
    """Return whether ``data[dim]`` lies in the closed interval [``lower``, ``upper``]."""
    return lower <= data[dim] <= upper
201
202
def between_ops() -> List[AHOperation]:
    """
    Return one categorical operation per dimension (m, k, n) and per interval
    (1-16, 17-32, 33-64, 65-128, 129-256), each checking whether the dimension lies in the interval.
    """
    dims = ("m", "k", "n")
    limits = ((1, 16), (17, 32), (33, 64), (65, 128), (129, 256))
    return [
        AHOperation(
            # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
            f"{lower}LEQ{dim}LEQ{upper}",
            functools.partial(between_op, dim=dim, lower=lower, upper=upper),
            is_categorical=True,
        )
        for dim in dims
        for lower, upper in limits
    ]
218
219
def pow2_op(data: Any, dim: str, exponent: int) -> bool:
    """Return whether ``data[dim]`` equals 2 raised to the power ``exponent``."""
    value = data[dim]
    return value == 2**exponent
222
223
def mm_operations() -> List[AHOperation]:
    """Return the dimension-product operations followed by the "arith_intensity" operation."""
    return [
        *get_mult_dims_ops(),
        AHOperation("arith_intensity", get_arith_intensity),
    ]
228
229
def mixed_mm_operations() -> List[AHOperation]:
    """Return the operations of ``mm_operations()`` followed by those of ``between_ops()``."""
    return [*mm_operations(), *between_ops()]
232
233
def is_multiple(data: Any, dim: str, mult: int) -> bool:
    """Return whether ``data[dim]`` is divisible by ``mult`` without remainder."""
    remainder = data[dim] % mult
    return remainder == 0
236
237
def get_dims_multiple_ops() -> List[AHOperation]:
    """
    Return one categorical operation per dimension (m, k, n) and per multiple (2, 4, 8, 16, 32),
    each checking whether the dimension is a multiple of that number.
    """
    dims = ("m", "k", "n")
    multiples = (2, 4, 8, 16, 32)
    return [
        AHOperation(
            f"{dim}_multiple_{mult}",
            functools.partial(is_multiple, dim=dim, mult=mult),
            is_categorical=True,
        )
        for dim in dims
        for mult in multiples
    ]
250
251
def get_dims_need_padding_ops() -> List[AHOperation]:
    """
    Return three operations derived from the strides and padded lengths in the data:
    "mat1_innermost_needs_padding", "mat2_innermost_needs_padding" (both categorical) and
    "num_dims_needs_padding".
    """

    def mat1_innermost_needs_padding_fn(data: Any) -> bool:
        stride_0 = data["mat1_stride_0"]
        stride_1 = data["mat1_stride_1"]
        m_pad = data["m_padded_length"]
        k_pad = data["k_padded_length"]
        # True if a stride is 1 and the padded length paired with that stride is nonzero.
        via_stride_0 = stride_0 == 1 and m_pad != 0
        via_stride_1 = stride_1 == 1 and k_pad != 0
        return bool(via_stride_0 or via_stride_1)

    def mat2_innermost_needs_padding_fn(data: Any) -> bool:
        stride_0 = data["mat2_stride_0"]
        stride_1 = data["mat2_stride_1"]
        k_pad = data["k_padded_length"]
        n_pad = data["n_padded_length"]
        # Same rule as for mat1, with k paired with stride 0 and n with stride 1.
        via_stride_0 = stride_0 == 1 and k_pad != 0
        via_stride_1 = stride_1 == 1 and n_pad != 0
        return bool(via_stride_0 or via_stride_1)

    def num_dims_needs_padding_fn(data: Any) -> int:
        padded_lengths = (
            data["m_padded_length"],
            data["k_padded_length"],
            data["n_padded_length"],
        )
        # Count how many of the three padded lengths are nonzero.
        return sum(1 for length in padded_lengths if length != 0)

    return [
        AHOperation(
            "mat1_innermost_needs_padding",
            mat1_innermost_needs_padding_fn,
            is_categorical=True,
        ),
        AHOperation(
            "mat2_innermost_needs_padding",
            mat2_innermost_needs_padding_fn,
            is_categorical=True,
        ),
        AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn),
    ]
304
305
def get_is_contig_ops() -> List[AHOperation]:
    """
    Return the categorical operations "mat1_iscontig" and "mat2_iscontig". Each checks that the
    matrix's stride 0 equals a dimension (k for mat1, n for mat2) and that its stride 1 equals 1.
    """

    def mat1_is_contig_fn(data: Any) -> bool:
        outer_stride = data["mat1_stride_0"]
        inner_stride = data["mat1_stride_1"]
        row_length = data["k"]
        return outer_stride == row_length and inner_stride == 1

    def mat2_is_contig_fn(data: Any) -> bool:
        outer_stride = data["mat2_stride_0"]
        inner_stride = data["mat2_stride_1"]
        row_length = data["n"]
        return outer_stride == row_length and inner_stride == 1

    return [
        AHOperation("mat1_iscontig", mat1_is_contig_fn, is_categorical=True),
        AHOperation("mat2_iscontig", mat2_is_contig_fn, is_categorical=True),
    ]
328
329
def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None:
    """Add one feature per entry of ``stride`` to ``context``, named ``{name}_stride_{index}``."""
    for index, stride_value in enumerate(stride):
        feature_name = f"{name}_stride_{index}"
        context.add_feature(feature_name, stride_value)
333
334
def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
    """
    Add the categorical feature "using_tf32" to ``context``. For ``torch.float32`` the value is
    ``torch.backends.cuda.matmul.allow_tf32``; for any other dtype it is the string "not_float_32".
    """
    if dtype == torch.float32:
        tf32_value = torch.backends.cuda.matmul.allow_tf32
    else:
        tf32_value = "not_float_32"
    context.add_feature("using_tf32", tf32_value, is_categorical=True)
340