# mypy: allow-untyped-defs
from __future__ import annotations

import logging
from typing import Optional, Sequence

import torch
from torch import _prims, Tensor


log = logging.getLogger(__name__)


def make_prim(
    schema: str,
    impl_aten,
    return_type=_prims.RETURN_TYPE.NEW,
    doc: str = "",
    tags: Optional[Sequence[torch.Tag]] = None,
):
    if isinstance(return_type, tuple):

        def meta(*args, **kwargs):
            return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs))

    else:

        def meta(*args, **kwargs):
            return _prims.TensorMeta(impl_aten(*args, **kwargs))

    return _prims._make_prim(
        schema=schema,
        return_type=return_type,
        meta=meta,
        impl_aten=impl_aten,
        doc=doc,
        tags=tags,
    )


def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
    if input_tensor.stride() == stride:
        return input_tensor
    new_tensor = input_tensor.clone().as_strided(
        input_tensor.shape,
        stride,
    )
    new_tensor.copy_(input_tensor)
    return new_tensor


# Custom prims used for handling randomness
seed = make_prim(
    "inductor_seed(Device device) -> Tensor",
    lambda device: torch.randint(2**63 - 1, [], device=device),
    doc="Create a fresh seed (one per call) for use with inductor_random",
    tags=(torch.Tag.nondeterministic_seeded,),
)
seeds = make_prim(
    "inductor_seeds(int count, Device device) -> Tensor",
    lambda count, device: torch.randint(2**63 - 1, [count], device=device),
    doc="Horizontal fusion of many inductor_seed() calls",
    tags=(torch.Tag.nondeterministic_seeded,),
)
lookup_seed = make_prim(
    # if inductor_lookup_seed changes, update partitioners.py
    "inductor_lookup_seed(Tensor seeds, int index) -> Tensor",
    lambda seeds, index: seeds[index],
    doc="Extract a single seed from the result of inductor_seeds()",
)
random = make_prim(
    "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor",
    lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device),
    doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused",
)
randint = make_prim(
    "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor",
    lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device),
    doc="torch.randint() using backend-specific RNG that can be fused",
)
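# A minimal sketch of how the randomness prims above compose (an illustration
# of the intended call pattern, not code Inductor emits verbatim):
# inductor_seeds batches many inductor_seed() calls into a single tensor,
# inductor_lookup_seed extracts the seed for an individual op, and
# inductor_random/inductor_randint consume it. Note that in the eager
# fallbacks above, the seed tensor only determines the output device.
#
#   s = seeds(3, torch.device("cpu"))                # tensor of 3 int64 seeds
#   r0 = random([4, 4], lookup_seed(s, 0), "rand")   # like torch.rand
#   r1 = random([4, 4], lookup_seed(s, 1), "randn")  # like torch.randn
#   r2 = randint(0, 10, [4], lookup_seed(s, 2))      # like torch.randint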
force_stride_order = make_prim(
    "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
    eager_force_stride,
    doc="Force the stride order for the input tensor. No-op if the input tensor already has the given strides; does a copy otherwise",
)
_unsafe_index_put_ = make_prim(
    "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
    lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_(
        self, indices, values, accumulate
    ),
    doc="Unsafe index_put_ (doesn't issue device asserts)",
)
fma = make_prim(
    "fma(Tensor a, Tensor b, Tensor c) -> Tensor",
    lambda a, b, c: (a * b) + c,
    doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication",
)


def _low_memory_max_pool2d_with_offsets_aten(
    self,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    # Compute the pooling result with aten, then re-encode the argmax
    # indices as small per-window offsets.
    vals, indices = torch.ops.aten.max_pool2d_with_indices(
        self, kernel_size, stride, padding, dilation, ceil_mode
    )

    input_width = self.shape[-1]
    kernel_width = kernel_size[1]

    # Broadcastable output-row and output-column index tensors.
    bh_shape = [1] * self.ndim
    bh_shape[-2] = -1
    bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view(
        bh_shape
    )

    bw_shape = [1] * self.ndim
    bw_shape[-1] = -1
    bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view(
        bw_shape
    )

    # Top-left corner of each pooling window, in input coordinates.
    hbase = bh * stride[0] - padding[0]
    wbase = bw * stride[1] - padding[1]

    # Recover the (ih, iw) input coordinates from the flat indices.
    ih = indices // input_width
    iw = indices - (ih * input_width)

    # Displacement of the argmax within its window, packed into a single
    # small integer (fits in int8 for the kernel sizes this is used with).
    h_inc = ih - hbase
    w_inc = iw - wbase

    offsets = h_inc * kernel_width + w_inc

    return vals, offsets.to(torch.int8)


def _low_memory_max_pool2d_offsets_to_indices_aten(
    offsets, kernel_width, input_width, stride, padding
):
    # Inverse of the encoding above: unpack (h_inc, w_inc) from each offset
    # and rebuild the flat index into the input plane.
    offsets = offsets.to(torch.int64)
    h_inc = offsets // kernel_width
    w_inc = offsets - (h_inc * kernel_width)

    bh_shape = [1] * offsets.ndim
    bh_shape[-2] = -1
    bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view(
        bh_shape
    )

    bw_shape = [1] * offsets.ndim
    bw_shape[-1] = -1
    bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view(
        bw_shape
    )

    hbase = bh * stride[0] - padding[0]
    wbase = bw * stride[1] - padding[1]

    ih = hbase + h_inc
    iw = wbase + w_inc
    return ih * input_width + iw


_low_memory_max_pool2d_with_offsets = make_prim(
    "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)",  # noqa: B950
    _low_memory_max_pool2d_with_offsets_aten,
    return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW),
    doc="Instead of returning indices, returns index offsets within each pooling window.",
)

_low_memory_max_pool2d_offsets_to_indices = make_prim(
    "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor",  # noqa: B950
    _low_memory_max_pool2d_offsets_to_indices_aten,
    doc="Convert small int offsets to regular indices.",
)
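# A minimal round-trip sketch exercising the two eager implementations above
# (an illustration, not part of the module's API; assumes the offsets fit in
# int8, which holds for small kernels). The int8 offsets encode the argmax
# position within each pooling window, and the decode recovers the flat
# indices that aten.max_pool2d_with_indices would have returned:
#
#   x = torch.randn(1, 1, 8, 8)
#   vals, offsets = _low_memory_max_pool2d_with_offsets_aten(
#       x, [2, 2], [2, 2], [0, 0], [1, 1], False
#   )
#   indices = _low_memory_max_pool2d_offsets_to_indices_aten(
#       offsets, kernel_width=2, input_width=8, stride=[2, 2], padding=[0, 0]
#   )
#   _, ref = torch.ops.aten.max_pool2d_with_indices(
#       x, [2, 2], [2, 2], [0, 0], [1, 1], False
#   )
#   assert torch.equal(indices, ref)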