# xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/hints.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import collections
import typing
from dataclasses import fields
from enum import auto, Enum
from typing import Dict, List, Optional, Union


# NOTE: if these limits cause assertion failures, submit a PR to increase them
TRITON_MAX_BLOCK = {
    "X": 2048,
    "Y": 1024,
    "Z": 1024,
    "R": 4096 * 16,  # * 16 is multi-kernel only
}
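
# Illustrative use (not part of the original file): downstream code is expected to
# keep its block sizes within these caps, e.g. a hypothetical check such as
#   assert XBLOCK <= TRITON_MAX_BLOCK["X"], "XBLOCK exceeds TRITON_MAX_BLOCK['X']"
# where XBLOCK is a per-kernel block size chosen elsewhere.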


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


# Attempt to import AttrsDescriptor from Triton
try:
    from triton.compiler.compiler import AttrsDescriptor

    attrs_descriptor_available = True
    # Determine which optional fields are valid for this version of AttrsDescriptor
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
    attrs_descriptor_available = False

# Define `instance_descriptor` function with clear conditional handling
if attrs_descriptor_available:

    def instance_descriptor(
        divisible_by_16=None,
        equal_to_1=None,
        ids_of_folded_args=None,
        divisible_by_8=None,
    ):
        # Prepare the arguments for AttrsDescriptor
        kwargs = {
            "divisible_by_16": divisible_by_16,
            "equal_to_1": equal_to_1,
        }

        # Conditionally add optional fields if they are available in AttrsDescriptor
        if ids_of_folded_args_available:
            kwargs["ids_of_folded_args"] = ids_of_folded_args
        if divisible_by_8_available:
            kwargs["divisible_by_8"] = divisible_by_8

        # Instantiate AttrsDescriptor with the prepared arguments
        return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
        "instance_descriptor",
        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
        defaults=[(), (), (), ()],
    )
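
# Illustrative usage (not part of the original file): whichever branch defined
# `instance_descriptor`, it can be called with tuples of argument indices, e.g.
#   desc = instance_descriptor(divisible_by_16=(0, 1), equal_to_1=(3,))
# With AttrsDescriptor available, only the fields supported by the installed
# Triton version are forwarded to it; otherwise the namedtuple fallback records them.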


_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
    ELEMENTS_PER_WARP_32 = 0

    # Triton codegen tries to codegen the set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
    # which isn't valid Python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__
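
# Illustrative (not part of the original file): with the override above,
#   repr(AutotuneHint.ELEMENTS_PER_WARP_32) == "AutotuneHint.ELEMENTS_PER_WARP_32"
# so a repr() of a set of hints can be emitted directly into generated source.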


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    cc: int
    major: Optional[int] = None
    regs_per_multiprocessor: Optional[int] = None
    max_threads_per_multi_processor: Optional[int] = None
    multi_processor_count: Optional[int] = None

    @classmethod
    def create(cls, device):
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type if torch.version.hip is None else "hip"
        device_interface = get_interface_for_device(device)
        if device_type == "cuda":
            props = device_interface.get_device_properties(device)
            return cls(
                type=device_type,
                index=device.index,
                cc=device_interface.get_compute_capability(device),
                major=props.major,
                regs_per_multiprocessor=props.regs_per_multiprocessor,
                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
                multi_processor_count=props.multi_processor_count,
            )
        return cls(
            type=device_type,
            index=device.index,
            cc=device_interface.get_compute_capability(device),
        )
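
# Illustrative usage (not part of the original file), assuming a CUDA device:
#   props = DeviceProperties.create(torch.device("cuda:0"))
#   props.type   # "cuda" (or "hip" on ROCm builds)
#   props.cc     # compute capability reported by the device interface
# The multiprocessor fields are populated only when device_type == "cuda".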


class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: Optional[List[str]] = None
    stride: Optional[List[str]] = None
    offset: Optional[str] = None
    alias_of: Optional[str] = None

    def bindings_type(self):
        if self.ctype in ("half*", "bfloat16*"):
            return "uint16_t*"  # half not defined
        return self.ctype

    def halide_type(self):
        if self.ctype == "half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "bfloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # bfloat16 not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self):
        return self.shape is None

    def is_buffer(self):
        return self.shape is not None
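
# Illustrative (not part of the original file): a hypothetical half-precision
# buffer argument and the types it maps to:
#   spec = HalideInputSpec(ctype="half*", name="in_ptr0", shape=["64"], stride=["1"])
#   spec.bindings_type()  # "uint16_t*"
#   spec.halide_type()    # "halide_type_t(halide_type_float, 16)"
#   spec.is_buffer()      # True, since a shape is given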


class HalideMeta(typing.NamedTuple):
    argtypes: List[HalideInputSpec]
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
    cuda_device: Optional[int] = None

    def args(self):
        """Command-line args to pass to the Halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self):
        return self.cuda_device is not None

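# Illustrative (not part of the original file): a hypothetical HalideMeta and the
# generator command-line arguments it produces:
#   meta = HalideMeta(
#       argtypes=[],
#       target="host-cuda",
#       scheduler="Anderson2021",
#       scheduler_flags={"parallelism": 82},
#       cuda_device=0,
#   )
#   meta.args()     # ["target=host-cuda", "autoscheduler=Anderson2021",
#                   #  "autoscheduler.parallelism=82"]
#   meta.is_cuda()  # True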