# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import os
import unittest
from collections import namedtuple

from functorch_additional_op_db import additional_op_db

import torch
import torch.utils._pytree as pytree
from functorch import vmap
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db
from torch.testing._internal.common_modules import module_db
from torch.testing._internal.custom_op_db import custom_op_db


IS_FBCODE = os.getenv("FUNCTORCH_TEST_FBCODE") == "1"


def loop(op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values):
    outs = []
    out_spec = None
    for idx in range(batch_size):
        flat_args, args_spec = pytree.tree_flatten(batched_args)
        flat_dims, dims_spec = pytree.tree_flatten(in_dims)
        assert args_spec == dims_spec
        new_args = [
            a.select(in_dim, idx) if in_dim is not None else a
            for a, in_dim in zip(flat_args, flat_dims)
        ]
        out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
        flat_out, out_spec = pytree.tree_flatten(out)
        outs.append(flat_out)

    # use the same out_dim for all outputs
    if isinstance(out_dim, int):
        flat_out_dim = [out_dim for _ in flat_out]
    else:
        flat_out_dim, _ = pytree.tree_flatten(out_dim)

    outs = zip(*outs)

    result = []
    for i, out_lst in enumerate(outs):
        if flat_out_dim[i] is not None:
            if not all(isinstance(x, torch.Tensor) for x in out_lst):
                raise ValueError(
                    f"vmap `{op}` must only return "
                    "Tensors. Did you mean to set out_dims=None for this output?"
                )
            result.append(torch.stack(out_lst))
        else:
            # not batched over, result should be the same for all batches
            result.append(out_lst[0])
    return pytree.tree_unflatten(result, out_spec)


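# Illustrative (hypothetical) usage of `loop`: for a simple elementwise op it
# should agree with `vmap`. Kept commented out, like the debugging snippets
# further below, so importing this module has no side effects:
#
#   x = torch.randn(2, 3)
#   expected = vmap(torch.sin, in_dims=0, out_dims=0)(x)
#   actual = loop(torch.sin, (0,), 0, 2, x)
#   torch.testing.assert_close(actual, expected)

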
# Like the `loop` helper function above, but for two levels of vmap.
# If we ever need more levels than this, it is probably possible to generalize
# `loop`, but that seemed more complicated than it was worth.
def loop2(
    op,
    in_dims1,
    in_dims2,
    out_dim1,
    out_dim2,
    batch_size1,
    batch_size2,
    *batched_args,
    **kwarg_values,
):
    outs = []
    flat_args, args_spec = pytree.tree_flatten(batched_args)
    flat_dims1, dims_spec1 = pytree.tree_flatten(in_dims1)
    flat_dims2, dims_spec2 = pytree.tree_flatten(in_dims2)
    assert args_spec == dims_spec1
    assert args_spec == dims_spec2
    assert len(flat_dims1) == len(flat_dims2)
    for idx1 in range(batch_size1):
        out_split = []
        arg_split = [
            a.select(in_dim1, idx1) if in_dim1 is not None else a
            for a, in_dim1 in zip(flat_args, flat_dims1)
        ]
        for idx2 in range(batch_size2):
            new_args = [
                a.select(in_dim, idx2) if in_dim is not None else a
                for a, in_dim in zip(arg_split, flat_dims2)
            ]
            out = op(*pytree.tree_unflatten(new_args, args_spec), **kwarg_values)
            out_split.append(out)
        outs.append(out_split)

    loop_out = []
    for out_split in outs:
        if isinstance(out_split[0], torch.Tensor):
            loop_out.append(torch.stack(out_split, out_dim1))
        else:
            new_out = []
            for idx in range(len(out_split[0])):
                new_out.append(torch.stack([i[idx] for i in out_split], out_dim1))
            loop_out.append(new_out)

    new_out = []
    if isinstance(loop_out[0], torch.Tensor):
        # single Tensor output: stack across the outer batch dimension
        new_out = torch.stack(loop_out, out_dim2)
    else:
        for idx in range(len(loop_out[0])):
            new_out.append(torch.stack([i[idx] for i in loop_out], out_dim2))
    return new_out


def is_valid_inplace_sample_input(sample_input, op, inplace_variant):
    if inplace_variant is None:
        return False
    if sample_input.broadcasts_input:
        return False
    if not isinstance(sample_input.input, torch.Tensor):
        return False

    # Check if the input's dtype matches the output's dtype
    args = (sample_input.input,) + sample_input.args
    kwargs = sample_input.kwargs
    output_dtype = op(*args, **kwargs).dtype
    return sample_input.input.dtype == output_dtype


# This is kind of dangerous, please think carefully before using it.
# Known risks:
# - the return value had better not be mutated, so it's best to return
#   immutable types (e.g. prefer tuples to lists)
# - don't hash Tensors in a global context; that'll keep them around forever
def memoize(fn):
    memo = {}

    def wrapped(*args):
        if args not in memo:
            memo[args] = fn(*args)
        return memo[args]

    return wrapped


# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to make this any bigger if you're modifying it.
@memoize
def get_bdim_choices(num_tensors):
    choices = []

    # full of zeros
    choices.append((0,) * num_tensors)

    # All combinations of (-1, None)
    options = (-1, None)
    choices.extend(itertools.product(options, repeat=num_tensors))

    assert choices[-1] == (None,) * num_tensors
    return tuple(choices[:-1])


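# For example (hypothetical call), get_bdim_choices(2) enumerates
# ((0, 0), (-1, -1), (-1, None), (None, -1)): the all-zeros choice plus every
# (-1, None) combination except the all-None one, which is dropped above.

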
# NB: This is O(2 ** num_tensors).
# num_tensors ranges from 1 to 10, with 2-4 being most common.
# Try not to make this any bigger if you're modifying it.
def get_bdim_choices_batch_norm(
    num_tensors, _, running_mean=None, running_var=None, *args
):
    choices = []
    options = (-1, None)

    # instance norm turns these into unbatched 0 tensors, so we cannot batch
    # the input if either is not specified
    if running_mean is None or running_var is None:
        choices.append((None,) + (0,) * (num_tensors - 1))
        for choice in itertools.product(options, repeat=num_tensors - 1):
            choices.append((None,) + choice)

    else:
        # running_mean and running_var are specified as tensors. Batch norm doesn't
        # work if the input is batched but running_mean/var are unbatched, so this
        # tests all other cases
        choices.append((0,) * num_tensors)
        for choice in itertools.product(options, repeat=num_tensors):
            input_bdim = choice[0]
            running_mean_bdim = choice[1]
            running_var_bdim = choice[2]
            if input_bdim and (not running_mean_bdim or not running_var_bdim):
                continue
            choices.append(choice)

    assert choices[-1] == (None,) * num_tensors
    return tuple(choices[:-1])


def add_batch_dim(arg, bdim, batch_size=3):
    assert bdim == 0 or bdim == -1
    assert isinstance(arg, torch.Tensor)
    if bdim == 0:
        shape = [1] * len(arg.shape)
        shape.insert(bdim, batch_size)
        return (arg.repeat(shape), bdim)
    if bdim == -1:
        arg = arg.unsqueeze(-1).expand(*arg.shape, batch_size).contiguous()
        return (arg, bdim)


def construct_in_dims(bdim_choice_for_tensors, is_tensors):
    result = []
    bdim = iter(bdim_choice_for_tensors)
    for is_tensor in is_tensors:
        if not is_tensor:
            result.append(None)
            continue
        result.append(next(bdim))
    return tuple(result)


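# For example (hypothetical call): with a bdim choice of (0, -1) for the tensor
# arguments and is_tensors == [True, False, True], construct_in_dims returns
# (0, None, -1) -- non-tensor arguments always get an in_dim of None.

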
def is_batch_norm_training(op_name, kwarg_values):
    batch_norm_fns = (
        "nn.functional.batch_norm",
        "nn.functional.instance_norm",
    )  # instance norm calls batch norm
    if op_name not in batch_norm_fns:
        return False

    # batch norm and instance norm require the value to be a plain bool
    default_training = (
        op_name == "nn.functional.instance_norm"
    )  # instance norm defaults to training, batch norm doesn't
    is_training = tuple(
        arg for arg in tuple(kwarg_values.values()) if isinstance(arg, bool)
    )
    if len(is_training) == 0:
        return default_training
    else:
        assert len(is_training) == 1
        return is_training[0]


def generate_vmap_inputs(
    arg_values, kwarg_values, is_batch_norm_and_training=False, batch_size=2
):
    flat_args, arg_spec = pytree.tree_flatten(tuple(arg_values))
    is_tensors = [isinstance(a, torch.Tensor) for a in flat_args]
    num_tensors = sum(is_tensors)
    # For batch norm, if there's only an input, we can't batch it since
    # running_mean/var will be seen as unbatched tensors
    if num_tensors == 1 and is_batch_norm_and_training:
        return
    bdim_choices = (
        get_bdim_choices_batch_norm(num_tensors, *arg_values)
        if is_batch_norm_and_training
        else get_bdim_choices(num_tensors)
    )

    @memoize
    def get_batched_arg(arg, bdim):
        assert isinstance(arg, torch.Tensor)
        assert bdim is not None
        result, _ = add_batch_dim(arg, bdim, batch_size)
        return result

    for bdim_choice in bdim_choices:
        flat_in_dims = construct_in_dims(bdim_choice, is_tensors)

        flat_batched_args = tuple(
            arg if in_dim is None else get_batched_arg(arg, in_dim)
            for arg, in_dim in zip(flat_args, flat_in_dims)
        )
        batched_args = pytree.tree_unflatten(flat_batched_args, arg_spec)
        in_dims = pytree.tree_unflatten(flat_in_dims, arg_spec)
        yield batched_args, in_dims, kwarg_values


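# Illustrative (hypothetical) usage: each yielded item pairs batched arguments
# with matching in_dims, so it can be fed straight to `vmap`:
#
#   for batched_args, in_dims, kwargs in generate_vmap_inputs(
#       (torch.randn(3), torch.randn(3)), {}
#   ):
#       vmap(torch.add, in_dims=in_dims)(*batched_args, **kwargs)

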
def clone_if_tensor(x):
    if isinstance(x, torch.Tensor):
        return x.clone()
    return x


# Helper function to compare the output of `vmap` against the for-loop version.
def _compute_quantities_for_vmap_test(
    op,
    orig_batched_args,
    orig_kwarg_values,
    in_dims,
    out_dim,
    batch_size,
    compute_loop_out=True,
    clone_inputs=False,
):
    def maybe_clone_inputs():
        if clone_inputs:
            batched_args = pytree.tree_map(clone_if_tensor, orig_batched_args)
            kwarg_values = pytree.tree_map(clone_if_tensor, orig_kwarg_values)
            return batched_args, kwarg_values
        return orig_batched_args, orig_kwarg_values

    batched_args, kwarg_values = maybe_clone_inputs()

    if compute_loop_out:
        loop_out = loop(
            op, in_dims, out_dim, batch_size, *batched_args, **kwarg_values
        )
    else:
        loop_out = None

    # Used for debugging the resulting operations
    # from functorch import make_fx
    # def f(a):
    #     return op(a)
    # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
    # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
    batched_args, kwarg_values = maybe_clone_inputs()
    batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(
        *batched_args, **kwarg_values
    )

    # Tests the case where we dispatch to a batching rule with no bdims.
    # This should be handled by the autogenerated plumbing. For vmap support
    # added via manual plumbing, you may need to handle this specially.
    def add_bdim_if_tensor(x):
        if isinstance(x, torch.Tensor):
            return x.unsqueeze(1)
        return x

    def f(dummy, *args, **kwargs):
        return op(*args, **kwargs)

    dummy = torch.ones(batch_size, 1)
    vmapvmap_expected = pytree.tree_map(add_bdim_if_tensor, batched_out)

    inner_in_dims = (0,) + pytree.tree_map(lambda x: None, in_dims)
    outer_in_dims = (0,) + in_dims
    batched_args, kwarg_values = maybe_clone_inputs()
    vmapvmap_output = vmap(
        vmap(f, inner_in_dims, out_dims=out_dim), outer_in_dims, out_dims=out_dim
    )(dummy, *batched_args, **kwarg_values)

    yield (batched_out, loop_out, vmapvmap_output, vmapvmap_expected)


# Function with friendlier return types compared to
# `_compute_quantities_for_vmap_test`
def compute_quantities_for_vmap_test(
    op,
    orig_batched_args,
    orig_kwarg_values,
    in_dims,
    out_dim=0,
    batch_size=2,
    compute_loop_out=True,
    clone_inputs=False,
):
    for quantities in _compute_quantities_for_vmap_test(
        op,
        orig_batched_args,
        orig_kwarg_values,
        in_dims,
        out_dim,
        batch_size,
        compute_loop_out,
        clone_inputs,
    ):
        yield (quantities[0], quantities[1])
        yield (quantities[2], quantities[3])


def get_fallback_and_vmap_exhaustive(
    op,
    arg_values,
    kwarg_values,
    is_batch_norm_and_training=False,
    compute_loop_out=True,
):
    out_dim = 0
    batch_size = 2

    def make_batched(t):
        if isinstance(t, torch.Tensor):
            shape = list(t.shape)
            shape.insert(out_dim, batch_size)
            return t.expand(*shape)
        return t

    # Inputs generated by `generate_vmap_inputs` just copy/expand the unbatched
    # inputs over the batched dimension. Thus we can compute the expected value
    # once and just expand it based on the `out_dim` and `batch_size`.
    expected_unbatched = op(*arg_values, **kwarg_values)
    expected_batched = pytree.tree_map(make_batched, expected_unbatched)
    generator = generate_vmap_inputs(
        arg_values, kwarg_values, is_batch_norm_and_training
    )
    for batched_args, in_dims, kwarg_values in generator:
        for quantities in _compute_quantities_for_vmap_test(
            op,
            batched_args,
            kwarg_values,
            in_dims,
            out_dim,
            batch_size,
            compute_loop_out=False,
        ):
            assert quantities[1] is None
            yield (quantities[0], expected_batched)
            yield (quantities[2], quantities[3])


def opinfo_in_dict(opinfo, d):
    return (opinfo.name in d) or (f"{opinfo.name}.{opinfo.variant_test_name}" in d)


DecorateMeta = namedtuple(
    "DecorateMeta",
    [
        "op_name",
        "variant_name",
        "decorator",
        "device_type",
        "dtypes",
    ],
)


def decorate(
    op_name, variant_name="", *, decorator=None, device_type=None, dtypes=None
):
    assert decorator is not None
    return DecorateMeta(
        op_name=op_name,
        variant_name=variant_name,
        decorator=decorator,
        device_type=device_type,
        dtypes=dtypes,
    )


def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
    return decorate(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.expectedFailure,
        device_type=device_type,
        dtypes=dtypes,
    )


def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
    return decorate(
        op_name=op_name,
        variant_name=variant_name,
        decorator=unittest.skip("Skipped!"),
        device_type=device_type,
        dtypes=dtypes,
    )


def skipOps(test_case_name, base_test_name, to_skip):
    all_opinfos = op_db + additional_op_db + autograd_function_db + custom_op_db
    for decorate_meta in to_skip:
        matching_opinfos = [
            o
            for o in all_opinfos
            if o.name == decorate_meta.op_name
            and o.variant_test_name == decorate_meta.variant_name
        ]
        assert len(matching_opinfos) > 0, f"Couldn't find OpInfo for {decorate_meta}"
        assert len(matching_opinfos) == 1, (
            "OpInfos should be uniquely determined by their (name, variant_name). "
            f"Got more than one result for ({decorate_meta.op_name}, {decorate_meta.variant_name})"
        )
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        new_decorator = DecorateInfo(
            decorate_meta.decorator,
            test_case_name,
            base_test_name,
            device_type=decorate_meta.device_type,
            dtypes=decorate_meta.dtypes,
        )
        decorators.append(new_decorator)
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped


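# Typical (hypothetical) usage inside an OpInfo-parametrized test class; the
# class/test names and op entries below are placeholders, not real skip lists:
#
#   @skipOps("TestVmapOperatorsOpInfo", "test_vmap_exhaustive", (
#       xfail("some.op"),
#       skip("another.op", device_type="cuda"),
#   ))
#   def test_vmap_exhaustive(self, device, dtype, op):
#       ...

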
def decorateForModules(decorator, module_classes, device_type=None, dtypes=None):
    # This decorator doesn't modify fn in any way
    def wrapped(
        fn,
        module_classes=module_classes,
        decorator=decorator,
        device_type=device_type,
        dtypes=dtypes,
    ):
        name_parts = fn.__qualname__.split(".")
        assert (
            len(name_parts) == 2
        ), "Decorator only applies to a test function of a test class"
        test_case_name, base_test_name = name_parts
        for module_cls in module_classes:
            matching_module_infos = [
                m for m in module_db if m.module_cls == module_cls
            ]
            assert (
                len(matching_module_infos) == 1
            ), f"Couldn't find a single ModuleInfo for {module_cls}"
            module_info = matching_module_infos[0]
            decorators = list(module_info.decorators)
            new_decorator = DecorateInfo(
                decorator,
                test_case_name,
                base_test_name,
                device_type=device_type,
                dtypes=dtypes,
            )
            decorators.append(new_decorator)
            module_info.decorators = tuple(decorators)
        return fn

    return wrapped


def expectedFailureIf(condition):
    def decorator(fn):
        if condition:
            return unittest.expectedFailure(fn)
        return fn

    return decorator


def tol2(op_name, variant_name, override_dct, *, device_type=None):
    return (op_name, variant_name, override_dct, device_type)


def tol1(op_name, override_dct, *, device_type=None):
    return tol2(op_name, "", override_dct, device_type=device_type)


def opsToleranceOverride(test_case_name, base_test_name, overrides):
    all_opinfos = op_db + additional_op_db
    for override in overrides:
        op_name, variant_name, override_dct, device_type = override
        matching_opinfos = [
            o
            for o in all_opinfos
            if o.name == op_name and o.variant_test_name == variant_name
        ]
        assert len(matching_opinfos) == 1, f"Couldn't find OpInfo for {override}"
        opinfo = matching_opinfos[0]
        decorators = list(opinfo.decorators)
        decorators.append(
            DecorateInfo(
                toleranceOverride(override_dct),
                test_case_name,
                base_test_name,
                device_type=device_type,
            )
        )
        opinfo.decorators = tuple(decorators)

    # This decorator doesn't modify fn in any way
    def wrapped(fn):
        return fn

    return wrapped


class DisableVmapFallback:
    def __enter__(self):
        self.prev_state = torch._C._functorch._is_vmap_fallback_enabled()
        torch._C._functorch._set_vmap_fallback_enabled(False)

    def __exit__(self, *ignored):
        torch._C._functorch._set_vmap_fallback_enabled(self.prev_state)


def check_vmap_fallback(test_case, thunk, opinfo, dry_run=False):
    try:
        with DisableVmapFallback():
            thunk()
    except Exception:
        if not dry_run:
            raise
        if opinfo.variant_test_name:
            print(f"xfail('{opinfo.name}', '{opinfo.variant_test_name}'),")
        else:
            print(f"xfail('{opinfo.name}'),")


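# Illustrative (hypothetical) usage: run a test body with the vmap fallback
# disabled; with dry_run=True a failing op prints a ready-to-paste xfail(...)
# entry instead of raising. `self`, `op`, `sample_args`, and `opinfo` would
# come from the surrounding OpInfo test:
#
#   def thunk():
#       vmap(op)(*sample_args)
#   check_vmap_fallback(self, thunk, opinfo, dry_run=False)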