xref: /aosp_15_r20/external/pytorch/torch/_inductor/inductor_prims.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import logging
5from typing import Optional, Sequence
6
7import torch
8from torch import _prims, Tensor
9
10
# Logger named after this module (``__name__``).
log = logging.getLogger(__name__)
12
13
14def make_prim(
15    schema: str,
16    impl_aten,
17    return_type=_prims.RETURN_TYPE.NEW,
18    doc: str = "",
19    tags: Optional[Sequence[torch.Tag]] = None,
20):
21    if isinstance(return_type, tuple):
22
23        def meta(*args, **kwargs):
24            return tuple(_prims.TensorMeta(o) for o in impl_aten(*args, **kwargs))
25
26    else:
27
28        def meta(*args, **kwargs):
29            return _prims.TensorMeta(impl_aten(*args, **kwargs))
30
31    return _prims._make_prim(
32        schema=schema,
33        return_type=return_type,
34        meta=meta,
35        impl_aten=impl_aten,
36        doc=doc,
37        tags=tags,
38    )
39
40
def eager_force_stride(input_tensor: Tensor, stride) -> Tensor:
    """Return ``input_tensor`` laid out with exactly the strides ``stride``.

    No-op (returns ``input_tensor`` itself) when the tensor already has the
    requested strides; otherwise allocates a new tensor with that layout and
    copies the values into it.

    Args:
        input_tensor: tensor whose values must be preserved.
        stride: requested strides, one per dimension of ``input_tensor``;
            accepted as a list or a tuple.

    Returns:
        A tensor with the same shape, dtype, device and values as
        ``input_tensor`` whose ``stride()`` equals ``tuple(stride)``.
    """
    # ``Tensor.stride()`` is a tuple, but ``stride`` commonly arrives as a list
    # (SymInt[] schema argument); ``(1, 2) == [1, 2]`` is False, so normalize
    # before comparing or the no-op path can never be taken.
    if input_tensor.stride() == tuple(stride):
        return input_tensor
    # Allocate storage sized for the requested layout directly. The previous
    # ``clone().as_strided(...)`` copied the data twice and failed for strides
    # needing more storage than ``numel`` (e.g. padded rows).
    new_tensor = torch.empty_strided(
        input_tensor.shape,
        stride,
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    )
    new_tensor.copy_(input_tensor)
    return new_tensor
50
51
# Custom prims used for handling randomness
# Each assignment below runs at import time and builds a prim through
# make_prim; the callable passed second is the eager implementation, which
# make_prim also runs (wrapped in _prims.TensorMeta) as the meta function.
seed = make_prim(
    "inductor_seed(Device device) -> Tensor",
    # Eager: a 0-d tensor holding one random integer below 2**63 - 1.
    lambda device: torch.randint(2**63 - 1, [], device=device),
    doc="create a fresh seed (one per call) for use with inductor_rand",
    tags=(torch.Tag.nondeterministic_seeded,),
)
seeds = make_prim(
    "inductor_seeds(int count, Device device) -> Tensor",
    # Eager: a 1-d tensor of `count` random integers below 2**63 - 1.
    lambda count, device: torch.randint(2**63 - 1, [count], device=device),
    doc="Horizontal fusion of many inductor_seed() calls",
    tags=(torch.Tag.nondeterministic_seeded,),
)
lookup_seed = make_prim(
    # if inductor_lookup_seed changes, update partitioners.py
    "inductor_lookup_seed(Tensor seeds, int index) -> Tensor",
    # Eager: plain indexing into the tensor produced by inductor_seeds.
    lambda seeds, index: seeds[index],
    doc="Extract a single seed from the result of inductor_seeds()",
)
# NOTE: this module-level name matches the stdlib ``random`` module; do not
# ``import random`` in this file.
random = make_prim(
    "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor",
    # Eager: calls torch.<mode>(size) (per the doc string, rand or randn).
    # Only `seed.device` is read; the seed value itself is not used here, so
    # eager results are not determined by `seed`.
    lambda size, seed, mode: getattr(torch, mode)(size, device=seed.device),
    doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused",
)
randint = make_prim(
    "inductor_randint(SymInt low, SymInt high, SymInt[] size, Tensor seed) -> Tensor",
    # Eager: like inductor_random, only `seed.device` is read, not its value.
    lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device),
    doc="torch.randint() using backend-specific RNG that can be fused",
)
force_stride_order = make_prim(
    "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
    # Eager implementation is the module-level eager_force_stride helper.
    eager_force_stride,
    doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
)
_unsafe_index_put_ = make_prim(
    # Mutating schema: `self` is written in place and returned (Tensor(a!)).
    "_unsafe_index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)",
    # Eager fallback simply forwards to the regular aten.index_put_.
    # NOTE(review): the "doesn't issue device asserts" property in the doc
    # string presumably refers to inductor's lowering, not this eager path.
    lambda self, indices, values, accumulate=False: torch.ops.aten.index_put_(
        self, indices, values, accumulate
    ),
    doc="Unsafe index_put_ (doesn't issue device asserts)",
)
fma = make_prim(
    "fma(Tensor a, Tensor b, Tensor c) -> Tensor",
    # NOTE(review): this eager implementation is an ordinary multiply followed
    # by an add, so as written the product is rounded before the addition,
    # unlike the doc string's fused semantics; presumably the unrounded form
    # comes from a backend lowering — confirm.
    lambda a, b, c: (a * b) + c,
    doc="Fused multiply add: fma(a, b, c) -> (a * b) + c without rounding after the multiplication",
)
98
99
100def _low_memory_max_pool2d_with_offsets_aten(
101    self,
102    kernel_size,
103    stride,
104    padding,
105    dilation,
106    ceil_mode,
107):
108    vals, indices = torch.ops.aten.max_pool2d_with_indices(
109        self, kernel_size, stride, padding, dilation, ceil_mode
110    )
111
112    input_width = self.shape[-1]
113    kernel_width = kernel_size[1]
114
115    bh_shape = [1] * self.ndim
116    bh_shape[-2] = -1
117    bh = torch.arange(indices.shape[-2], dtype=torch.int64, device=self.device).view(
118        bh_shape
119    )
120
121    bw_shape = [1] * self.ndim
122    bw_shape[-1] = -1
123    bw = torch.arange(indices.shape[-1], dtype=torch.int64, device=self.device).view(
124        bw_shape
125    )
126
127    hbase = bh * stride[0] - padding[0]
128    wbase = bw * stride[1] - padding[1]
129
130    ih = indices // input_width
131    iw = indices - (ih * input_width)
132
133    h_inc = ih - hbase
134    w_inc = iw - wbase
135
136    offsets = h_inc * kernel_width + w_inc
137
138    return vals, offsets.to(torch.int8)
139
140
141def _low_memory_max_pool2d_offsets_to_indices_aten(
142    offsets, kernel_width, input_width, stride, padding
143):
144    offsets = offsets.to(torch.int64)
145    h_inc = offsets // kernel_width
146    w_inc = offsets - (h_inc * kernel_width)
147
148    bh_shape = [1] * offsets.ndim
149    bh_shape[-2] = -1
150    bh = torch.arange(offsets.shape[-2], dtype=torch.int64, device=offsets.device).view(
151        bh_shape
152    )
153
154    bw_shape = [1] * offsets.ndim
155    bw_shape[-1] = -1
156    bw = torch.arange(offsets.shape[-1], dtype=torch.int64, device=offsets.device).view(
157        bw_shape
158    )
159
160    hbase = bh * stride[0] - padding[0]
161    wbase = bw * stride[1] - padding[1]
162
163    ih = hbase + h_inc
164    iw = wbase + w_inc
165    return ih * input_width + iw
166
167
# Max-pool prim whose second output is an int8 offset within each pooling
# window rather than a full flat index; eager/meta behavior comes from
# _low_memory_max_pool2d_with_offsets_aten.
# NOTE(review): `dilation` is part of the schema, but the offset arithmetic in
# the eager implementation only forwards it to aten and does not use it.
_low_memory_max_pool2d_with_offsets = make_prim(
    "_low_memory_max_pool2d_with_offsets(Tensor self, SymInt[2] kernel_size, SymInt[2] stride,  SymInt[2] padding, SymInt[2] dilation, bool ceil_mode) -> (Tensor, Tensor)",  # noqa: B950
    _low_memory_max_pool2d_with_offsets_aten,
    # Two outputs (values, offsets): the tuple makes make_prim build a meta
    # function that wraps each output in TensorMeta.
    return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW),
    doc="Instead of returning indices, returns indices offsets.",
)

# Inverse of the encoding above: rebuilds flat input indices from the offsets
# given the kernel width, input width, stride and padding.
_low_memory_max_pool2d_offsets_to_indices = make_prim(
    "_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor",  # noqa: B950
    _low_memory_max_pool2d_offsets_to_indices_aten,
    doc="Convert small int offsets to regular indices.",
)
180