# mypy: ignore-errors

"""OpInfo entries for higher-order operators (HOPs).

``hop_db`` below pairs small callables that exercise higher-order ops
(map, cond, while_loop, auto_functionalize, flex_attention) with
sample-input generators; it is consumed by HOP tests such as the
``TestHOP`` export tests referenced in the skips below.
"""

import torch
import functools
from torch.testing import make_tensor
import unittest
from functorch.experimental.control_flow import map
from torch.testing._internal.opinfo.core import (
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and, custom_types
from torch.testing._internal.opinfo.core import DecorateInfo
from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask

def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
                      args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))

def inner_f(x, y0, y1):
    return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]

def simple_map(xs, y0, y1):
    def f(x, y0, y1):
        return inner_f(x, y0, y1)
    return map(f, xs, y0, y1)

def nested_map(xs, y0, y1):
    def f1(xx, y0, y1):
        def f2(x, y0, y1):
            return inner_f(x, y0, y1)
        return map(f2, xx, y0, y1)
    return map(f1, xs, y0, y1)

def triple_nested_map(xs, y0, y1):
    def f0(xs, y0, y1):
        def f1(xx, y0, y1):
            def f2(x, y0, y1):
                return inner_f(x, y0, y1)
            return map(f2, xx, y0, y1)
        return map(f1, xs, y0, y1)
    return map(f0, xs, y0, y1)
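
# Illustrative sketch (an assumption added for clarity, not part of the test
# database): how `simple_map` above is expected to be called. The shapes
# mirror `sample_inputs_map`, and `_example_map_usage` is a hypothetical name.
def _example_map_usage():
    xs = [torch.randn(2, 2, 2), torch.randn(2, 2, 2)]
    y0, y1 = torch.randn(1), torch.randn(1)
    # `map` applies the body to each slice of `xs` along dim 0 and stacks the
    # per-slice results, so both outputs keep the leading dimension of 2.
    return simple_map(xs, y0, y1)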


# Please consult with the torch.export team before
# adding a new entry to this list.
hop_that_doesnt_have_opinfo_test_allowlist = [
    "custom_function_call",
    "autograd_function_apply",
    "run_and_save_rng_state",
    "run_with_rng_state",
    "out_dtype",
    "trace_wrapped",
    "map",  # T183144629
    "map_impl",
    "with_effects",
    "strict_mode",
    "_export_tracepoint",
    "call_torchbind",
    "triton_kernel_wrapper_mutation",
    "triton_kernel_wrapper_functional",
    "hints_wrapper",
]

torch.library.define(
    "testlib::mutating_custom_op",
    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::mutating_custom_op", "cpu")
def foo_impl_cpu(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.impl("testlib::mutating_custom_op", "cuda")
def foo_impl_cuda(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.register_fake("testlib::mutating_custom_op")
def foo_impl_abstract(x, z):
    return x, z, x + z


def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))


def simple_cond(x):
    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])
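
# Illustrative sketch (assumption, for clarity only): `torch.cond` evaluates
# the predicate and runs exactly one branch; both branches must return the
# same output structure. The input mirrors `sample_inputs_cond`.
def _example_cond_usage():
    x = torch.rand(2, 2, 2) + 0.1
    # With 8 elements drawn from (0.1, 1.1), the sum typically exceeds 2,
    # selecting the `cos` branch; otherwise the `sin` branch runs instead.
    (out,) = simple_cond(x)
    return out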


def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2))


def simple_auto_functionalize(x, z):
    return torch.ops.testlib.mutating_custom_op(x, z)
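
# Illustrative sketch (assumption, for clarity only): the custom op defined
# above mutates both inputs in place and also returns them plus their sum,
# which is what exercises the auto_functionalize path.
def _example_mutating_custom_op_usage():
    x, z = torch.ones(2), torch.ones(2)
    out_x, out_z, out_sum = simple_auto_functionalize(x, z)
    # After the call, x and z have been mutated to all 6.0 and out_sum is
    # all 12.0 (6.0 + 6.0).
    return out_x, out_z, out_sum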


def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    def score_mod(score, b, h, m, n):
        return score + h

    q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
    block_mask = _create_empty_block_mask(q, k)
    yield SampleInput(
        q,
        k,
        v,
        score_mod,
        block_mask
    )
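
# Illustrative sketch (assumption, for clarity only): how `flex_attention`
# consumes a sample like the one generated above. Shapes are
# (batch=2, heads=2, seq_len=128, head_dim=8), matching
# `sample_inputs_flex_attention`; outside of `torch.compile` this may fall
# back to a slower, non-fused implementation.
def _example_flex_attention_usage():
    q, k, v = (torch.randn(2, 2, 128, 8) for _ in range(3))

    def score_mod(score, b, h, m, n):
        # Shift every attention score by the head index before softmax.
        return score + h

    block_mask = _create_empty_block_mask(q, k)
    return flex_attention(q, k, v, score_mod, block_mask)
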

def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(
        torch.tensor(3),
        make_arg(2, 3, 4, low=0.1, high=2),
    )

def simple_while_loop(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0

    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
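
# Illustrative sketch (assumption, for clarity only): `while_loop` keeps
# applying `body_fn` to the carried values while `cond_fn` returns a true
# scalar. With an iteration counter of 3, the body runs three times.
def _example_while_loop_usage():
    iter_t = torch.tensor(3)
    x = torch.rand(2, 3, 4)
    final_iter, final_x = simple_while_loop(iter_t, x)
    # final_iter is tensor(0) and final_x equals x.cos().cos().cos().
    return final_iter, final_x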


hop_db = [
    OpInfo(
        name="map",
        variant_test_name="simple",
        op=simple_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="nested",
        op=nested_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="triple_nested",
        op=triple_nested_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="cond",
        variant_test_name="simple",
        op=simple_cond,
        sample_inputs_func=sample_inputs_cond,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="while_loop",
        variant_test_name="simple",
        op=simple_while_loop,
        sample_inputs_func=sample_inputs_while_loop,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="auto_functionalize",
        variant_test_name="simple",
        op=simple_auto_functionalize,
        sample_inputs_func=sample_inputs_auto_functionalize,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="flex_attention",
        variant_test_name="simple",
        op=flex_attention,
        sample_inputs_func=sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
    ),
    OpInfo(
        name="flex_attention_backward",
        variant_test_name="simple",
        op=flex_attention,
        sample_inputs_func=sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
    )
]
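
# Illustrative sketch (assumption, for clarity only): how a test harness might
# walk this database. The device, dtype, and helper name are placeholders;
# some entries (e.g. flex_attention) may additionally require torch.compile.
def _example_hop_db_usage():
    for op in hop_db:
        for sample in op.sample_inputs("cpu", torch.float32, requires_grad=False):
            # A SampleInput stores the first positional argument in `.input`
            # and the remaining ones in `.args` / `.kwargs`.
            op.op(sample.input, *sample.args, **sample.kwargs)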