# mypy: allow-untyped-defs
import collections
import typing
from dataclasses import fields
from enum import auto, Enum
from typing import Dict, List, Optional, Union


# NOTE: if asserts against these sizes fail, submit a PR to increase them
TRITON_MAX_BLOCK = {
    "X": 2048,
    "Y": 1024,
    "Z": 1024,
    "R": 4096 * 16,  # * 16 is multi-kernel only
}


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


# Attempt to import AttrsDescriptor from Triton
try:
    from triton.compiler.compiler import AttrsDescriptor

    attrs_descriptor_available = True
    # Determine which optional fields this version of AttrsDescriptor supports
    attr_desc_fields = {f.name for f in fields(AttrsDescriptor)}
    ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields
    divisible_by_8_available = "divisible_by_8" in attr_desc_fields
except ImportError:
    attrs_descriptor_available = False

# Define `instance_descriptor` with clear conditional handling
if attrs_descriptor_available:

    def instance_descriptor(
        divisible_by_16=None,
        equal_to_1=None,
        ids_of_folded_args=None,
        divisible_by_8=None,
    ):
        # Prepare the arguments for AttrsDescriptor
        kwargs = {
            "divisible_by_16": divisible_by_16,
            "equal_to_1": equal_to_1,
        }

        # Conditionally add the optional fields if AttrsDescriptor supports them
        if ids_of_folded_args_available:
            kwargs["ids_of_folded_args"] = ids_of_folded_args
        if divisible_by_8_available:
            kwargs["divisible_by_8"] = divisible_by_8

        # Instantiate AttrsDescriptor with the prepared arguments
        return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
        "instance_descriptor",
        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
        defaults=[(), (), (), ()],
    )


_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
    ELEMENTS_PER_WARP_32 = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
    # which isn't valid Python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    cc: int
    major: Optional[int] = None
    regs_per_multiprocessor: Optional[int] = None
    max_threads_per_multi_processor: Optional[int] = None
    multi_processor_count: Optional[int] = None

    @classmethod
    def create(cls, device):
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type if torch.version.hip is None else "hip"
        device_interface = get_interface_for_device(device)
        if device_type == "cuda":
            props = device_interface.get_device_properties(device)
            return cls(
                type=device_type,
                index=device.index,
                cc=device_interface.get_compute_capability(device),
                major=props.major,
                regs_per_multiprocessor=props.regs_per_multiprocessor,
                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
                multi_processor_count=props.multi_processor_count,
            )
        return cls(
            type=device_type,
            index=device.index,
            cc=device_interface.get_compute_capability(device),
        )


class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: Optional[List[str]] = None
    stride: Optional[List[str]] = None
    offset: Optional[str] = None
    alias_of: Optional[str] = None

    def bindings_type(self):
        if self.ctype in ("half*", "bfloat16*"):
            return "uint16_t*"  # half not defined
        return self.ctype

    def halide_type(self):
        if self.ctype == "half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "bfloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # bfloat16 not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self):
        return self.shape is None

    def is_buffer(self):
        return self.shape is not None


class HalideMeta(typing.NamedTuple):
    argtypes: List[HalideInputSpec]
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
    cuda_device: Optional[int] = None

    def args(self):
        """Command line args to pass to the Halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self):
        return self.cuda_device is not None
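

# Illustrative usage sketch: shows how HalideMeta composes the command-line
# arguments for the Halide generator. The target, scheduler, and flag values
# below are hypothetical placeholders, not values the runtime requires.
if __name__ == "__main__":
    _example_meta = HalideMeta(
        argtypes=[HalideInputSpec(ctype="float*", name="in_ptr0", shape=["1024"])],
        target="host-cuda",  # hypothetical Halide target string
        scheduler="Anderson2021",  # hypothetical autoscheduler name
        scheduler_flags={"parallelism": 82},  # hypothetical scheduler flag
        cuda_device=0,
    )
    # Expected output:
    #   ['target=host-cuda', 'autoscheduler=Anderson2021', 'autoscheduler.parallelism=82']
    print(_example_meta.args())
    print(_example_meta.is_cuda())  # True, since cuda_device is set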