# mypy: ignore-errors

# OpInfo entries and sample-input helpers for higher-order operators (HOPs)
# such as map, cond, while_loop, auto_functionalize, and flex_attention.
# The entries are collected in `hop_db` at the bottom of this file.

import torch
import functools
from torch.testing import make_tensor
import unittest
from functorch.experimental.control_flow import map
from torch.testing._internal.opinfo.core import (
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and, custom_types
from torch.testing._internal.opinfo.core import DecorateInfo
from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask


def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)],
                      args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)))


def inner_f(x, y0, y1):
    return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())]


def simple_map(xs, y0, y1):
    def f(x, y0, y1):
        return inner_f(x, y0, y1)
    return map(f, xs, y0, y1)


def nested_map(xs, y0, y1):
    def f1(xx, y0, y1):
        def f2(x, y0, y1):
            return inner_f(x, y0, y1)
        return map(f2, xx, y0, y1)
    return map(f1, xs, y0, y1)


def triple_nested_map(xs, y0, y1):
    def f0(xs, y0, y1):
        def f1(xx, y0, y1):
            def f2(x, y0, y1):
                return inner_f(x, y0, y1)
            return map(f2, xx, y0, y1)
        return map(f1, xs, y0, y1)
    return map(f0, xs, y0, y1)
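

# Illustrative usage sketch (hypothetical helper, not part of the OpInfo database):
# shows the calling convention exercised by sample_inputs_map above -- `xs` is a
# pytree of tensors that control_flow.map iterates over along dim 0, while y0/y1
# are passed through unchanged to inner_f. Not invoked anywhere in this file.
def _example_simple_map_call():
    xs = [torch.rand(2, 2, 2) + 0.1, torch.rand(2, 2, 2) + 0.1]
    y0 = torch.rand(1) + 0.1
    y1 = torch.rand(1) + 0.1
    # Expected to return a list of two tensors, each stacked back to shape (2, 2, 2).
    return simple_map(xs, y0, y1)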


# Please consult with torch.export team before
# adding new entry to this list.
hop_that_doesnt_have_opinfo_test_allowlist = [
    "custom_function_call",
    "autograd_function_apply",
    "run_and_save_rng_state",
    "run_with_rng_state",
    "out_dtype",
    "trace_wrapped",
    "map",  # T183144629
    "map_impl",
    "with_effects",
    "strict_mode",
    "_export_tracepoint",
    "call_torchbind",
    "triton_kernel_wrapper_mutation",
    "triton_kernel_wrapper_functional",
    "hints_wrapper",
]

torch.library.define(
    "testlib::mutating_custom_op",
    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::mutating_custom_op", "cpu")
def foo_impl_cpu(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.impl("testlib::mutating_custom_op", "cuda")
def foo_impl_cuda(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z


@torch.library.register_fake("testlib::mutating_custom_op")
def foo_impl_abstract(x, z):
    return x, z, x + z


def sample_inputs_cond(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2))


def simple_cond(x):
    return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x])


def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2))


def simple_auto_functionalize(x, z):
    return torch.ops.testlib.mutating_custom_op(x, z)


def sample_inputs_flex_attention(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
    )

    def score_mod(score, b, h, m, n):
        return score + h

    q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3))
    block_mask = _create_empty_block_mask(q, k)
    yield SampleInput(
        q,
        k,
        v,
        score_mod,
        block_mask
    )

def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(
        make_tensor, device=device, dtype=dtype, requires_grad=False
    )
    yield SampleInput(
        torch.tensor(3),
        make_arg(2, 3, 4, low=0.1, high=2),
    )

def simple_while_loop(iter_t, x):
    def cond_fn(iter_t, x):
        return iter_t > 0

    def body_fn(iter_t, x):
        return iter_t - 1, x.cos()

    return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x))
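

# Illustrative usage sketch (hypothetical helper, not used by the database below):
# shows the carried-state convention of simple_while_loop -- cond_fn tests the loop
# counter, body_fn returns the next (counter, value) pair, and while_loop returns
# the final carry once cond_fn evaluates to False.
def _example_simple_while_loop_call():
    iter_t = torch.tensor(3)
    x = torch.rand(2, 3, 4) + 0.1
    # Expected to apply x.cos() three times, returning (tensor(0), cos(cos(cos(x)))).
    return simple_while_loop(iter_t, x)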


hop_db = [
    OpInfo(
        name="map",
        variant_test_name="simple",
        op=simple_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="nested",
        op=nested_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="map",
        variant_test_name="triple_nested",
        op=triple_nested_map,
        sample_inputs_func=sample_inputs_map,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
    ),
    OpInfo(
        name="cond",
        variant_test_name="simple",
        op=simple_cond,
        sample_inputs_func=sample_inputs_cond,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=True,
        # "torch.compile with aot_autograd does not currently support double backward."
        supports_gradgrad=False,
    ),
    OpInfo(
        name="while_loop",
        variant_test_name="simple",
        op=simple_while_loop,
        sample_inputs_func=sample_inputs_while_loop,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="auto_functionalize",
        variant_test_name="simple",
        op=simple_auto_functionalize,
        sample_inputs_func=sample_inputs_auto_functionalize,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        supports_autograd=False,
    ),
    OpInfo(
        name="flex_attention",
        variant_test_name="simple",
        op=flex_attention,
        sample_inputs_func=sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
    ),
    OpInfo(
        name="flex_attention_backward",
        variant_test_name="simple",
        op=flex_attention,
        sample_inputs_func=sample_inputs_flex_attention,
        dtypes=custom_types(torch.float16, torch.float32),
        supports_out=False,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        check_batched_forward_grad=False,
        check_inplace_batched_forward_grad=False,
        skips=(
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"),
            DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"),
        ),
    ),
]
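

# Illustrative consumption sketch (hypothetical helper, not part of the upstream
# file): shows how an OpInfo-driven test might walk hop_db, filtering entries by
# supported dtype and materializing their sample inputs without running the ops.
def _example_iterate_hop_db(device="cpu", dtype=torch.float32):
    collected = []
    for opinfo in hop_db:
        if dtype not in opinfo.supported_dtypes(device):
            continue
        for sample in opinfo.sample_inputs(device, dtype, requires_grad=False):
            collected.append((opinfo.name, opinfo.variant_test_name, sample))
    return collected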