1# mypy: allow-untyped-decorators 2# mypy: allow-untyped-defs 3import torch 4import functools 5from torch.testing import make_tensor 6from torch.testing._internal.opinfo.core import ( 7 OpInfo, 8 SampleInput, 9) 10from torch.testing._internal.common_dtype import all_types_and 11import numpy as np 12from torch.testing._internal.autograd_function_db import ( 13 sample_inputs_numpy_cube, 14 sample_inputs_numpy_mul, 15 sample_inputs_numpy_mul_scalar, 16 sample_inputs_numpy_sort, 17 sample_inputs_numpy_take, 18) 19from torch import Tensor 20from torch.types import Number 21from typing import * # noqa: F403 22 23# Note: [custom op db] 24# 25# This is a collection of custom operator test cases written as OpInfos 26# so they can easily be consumed by OpInfo-based tests to check if subsystems 27# support them correctly. 28 29def to_numpy(tensor): 30 return tensor.cpu().numpy() 31 32@torch.library.custom_op("_torch_testing::numpy_cube", mutates_args=()) 33def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]: 34 x_np = to_numpy(x) 35 dx = torch.tensor(3 * x_np ** 2, device=x.device) 36 return torch.tensor(x_np ** 3, device=x.device), dx 37 38@numpy_cube.register_fake 39def _(x): 40 return x.clone(), x.clone() 41 42def numpy_cube_setup_context(ctx, inputs, output): 43 x, = inputs 44 cube, dx = output 45 ctx.save_for_backward(x, dx) 46 47def numpy_cube_backward(ctx, grad_out, grad_dx): 48 x, dx = ctx.saved_tensors 49 grad_x = numpy_mul(grad_out, dx) + 6 * numpy_mul(grad_dx, x) 50 return grad_x 51 52numpy_cube.register_autograd(numpy_cube_backward, setup_context=numpy_cube_setup_context) 53 54def numpy_cube_vmap(info, in_dims, x): 55 result = numpy_cube(x) 56 return result, (in_dims[0], in_dims[0]) 57 58numpy_cube.register_vmap(numpy_cube_vmap) 59 60@torch.library.custom_op("_torch_testing::numpy_mul", mutates_args=()) 61def numpy_mul(x: Tensor, y: Tensor) -> Tensor: 62 return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device) 63 64@numpy_mul.register_fake 65def _(x, y): 66 
assert x.device == y.device 67 return (x * y).contiguous() 68 69def numpy_mul_setup_context(ctx, inputs, output): 70 ctx.save_for_backward(*inputs) 71 72def numpy_mul_backward(ctx, grad_out): 73 x, y = ctx.saved_tensors 74 grad_x = grad_out * y if ctx.needs_input_grad[0] else None 75 grad_y = grad_out * x if ctx.needs_input_grad[1] else None 76 return grad_x, grad_y 77 78numpy_mul.register_autograd(numpy_mul_backward, setup_context=numpy_mul_setup_context) 79 80def numpy_mul_vmap(info, in_dims, x, y): 81 x_bdim, y_bdim = in_dims 82 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 83 y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1) 84 result = x * y 85 result = result.movedim(-1, 0) 86 return result, 0 87 88numpy_mul.register_vmap(numpy_mul_vmap) 89 90@torch.library.custom_op("_torch_testing::numpy_mul_scalar", mutates_args=()) 91def numpy_mul_scalar(x: Tensor, *, scalar: float) -> Tensor: 92 return torch.tensor(to_numpy(x) * scalar, device=x.device) 93 94@numpy_mul_scalar.register_fake 95def _(x, *, scalar): 96 return (x * scalar).contiguous() 97 98def numpy_mul_scalar_setup_context(ctx, inputs, keyword_only_inputs, output): 99 ctx.scalar = keyword_only_inputs["scalar"] 100 101def numpy_mul_scalar_backward(ctx, grad_out): 102 grad_x = grad_out * ctx.scalar 103 return grad_x 104 105numpy_mul_scalar.register_autograd(numpy_mul_scalar_backward, setup_context=numpy_mul_scalar_setup_context) 106 107def numpy_mul_scalar_vmap(info, in_dims, x, *, scalar): 108 x_bdim, = in_dims 109 x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1) 110 result = x * scalar 111 result = result.movedim(-1, 0) 112 return result, 0 113 114numpy_mul_scalar.register_vmap(numpy_mul_scalar_vmap) 115 116@torch.library.custom_op("_torch_testing::numpy_sort", mutates_args=()) 117def numpy_sort(x: Tensor, dim: int) -> Tuple[Tensor, Tensor, Tensor]: 118 device = x.device 119 x = to_numpy(x) 120 ind = np.argsort(x, axis=dim) 121 ind_inv = 
np.argsort(ind, axis=dim) 122 result = np.take_along_axis(x, ind, axis=dim) 123 return ( 124 torch.tensor(result, device=device), 125 torch.tensor(ind, device=device), 126 torch.tensor(ind_inv, device=device), 127 ) 128 129@numpy_sort.register_fake 130def _(x, dim): 131 return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long) 132 133def numpy_sort_setup_context(ctx, inputs, output): 134 out, ind, ind_inv = output 135 ctx.dim = inputs[1] 136 ctx.save_for_backward(ind, ind_inv) 137 ctx.mark_non_differentiable(ind, ind_inv) 138 139def numpy_sort_backward(ctx, grad_out, grad_ind, grad_ind_inv): 140 ind, ind_inv = ctx.saved_tensors 141 return numpy_take(grad_out, ind_inv, ind, ctx.dim), None 142 143numpy_sort.register_autograd(numpy_sort_backward, setup_context=numpy_sort_setup_context) 144 145def numpy_sort_vmap(info, in_dims, x, dim): 146 x_bdim, _ = in_dims 147 x = x.movedim(x_bdim, 0) 148 dim = dim if dim >= 0 else dim + x.dim() - 1 149 result = numpy_sort(x, dim + 1) 150 return result, (0, 0, 0) 151 152numpy_sort.register_vmap(numpy_sort_vmap) 153 154@torch.library.custom_op("_torch_testing::numpy_take", mutates_args=()) 155def numpy_take(x: Tensor, ind: Tensor, ind_inv: Tensor, dim: int) -> Tensor: 156 device = x.device 157 x = to_numpy(x) 158 ind = to_numpy(ind) 159 return torch.tensor(np.take_along_axis(x, ind, dim), device=device) 160 161@numpy_take.register_fake 162def _(x, ind, ind_inv, dim): 163 assert x.device == ind.device 164 assert x.device == ind_inv.device 165 assert ind.dtype == torch.long 166 assert ind_inv.dtype == torch.long 167 return torch.empty_like(x) 168 169def numpy_take_setup_context(ctx, inputs, output): 170 x, ind, ind_inv, dim = inputs 171 ctx.dim = dim 172 ctx.save_for_backward(ind, ind_inv) 173 174def numpy_take_backward(ctx, grad_out): 175 ind, ind_inv = ctx.saved_tensors 176 grad_x = numpy_take(grad_out, ind_inv, ind, ctx.dim) 177 return grad_x, None, None, None 178 
numpy_take.register_autograd(numpy_take_backward, setup_context=numpy_take_setup_context)


def numpy_take_vmap(info, in_dims, x, ind, ind_inv, dim):
    """Batching rule for ``numpy_take``.

    Every tensor argument has its batch dim moved to the front (unbatched
    tensors are expanded to ``info.batch_size``), then the gather runs along
    ``dim`` shifted past the new leading batch dim.
    """
    x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

    # Wrap a negative ``dim`` against the *logical* (per-example) rank of x:
    # its physical rank, minus one when x carries a batch dimension.
    logical_rank = x.dim() if x_bdim is None else x.dim() - 1
    dim = dim if dim >= 0 else dim + logical_rank

    def expand_bdim(t, t_bdim):
        # Give every tensor a leading batch dim of size info.batch_size.
        if t_bdim is None:
            return t.expand(info.batch_size, *t.shape)
        return t.movedim(t_bdim, 0)

    x = expand_bdim(x, x_bdim)
    ind = expand_bdim(ind, ind_bdim)
    ind_inv = expand_bdim(ind_inv, ind_inv_bdim)

    # +1 skips the batch dimension now sitting at position 0.
    return numpy_take(x, ind, ind_inv, dim + 1), 0


numpy_take.register_vmap(numpy_take_vmap)


@torch.library.custom_op("_torch_testing::numpy_nonzero", mutates_args=())
def numpy_nonzero(x: Tensor) -> Tensor:
    """Return the indices of nonzero entries of ``x`` as a (num_nonzero, x.dim()) tensor."""
    x_np = to_numpy(x)
    res = np.stack(np.nonzero(x_np), axis=1)
    # NOTE(review): presumably rejects 0/1-row results because tracing assumes
    # unbacked sizes are >= 2 — confirm before relaxing.
    if res.shape[0] <= 1:
        raise RuntimeError("not supported")
    return torch.tensor(res, device=x.device)


@numpy_nonzero.register_fake
def _(x):
    # The number of nonzero rows is data-dependent, so it is an unbacked symint.
    ctx = torch._custom_op.impl.get_ctx()
    i0 = ctx.create_unbacked_symint()
    shape = [i0, x.dim()]
    result = x.new_empty(shape, dtype=torch.long)
    return result


def sample_inputs_numpy_nonzero(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    shape = 10
    result = make_arg(shape, low=0.9, high=2)
    # Integer mask with values in {0, 1} zeroes out a random subset of entries.
    mask = make_tensor(shape, low=0, high=2, device=device, dtype=torch.long)
    with torch.no_grad():
        result *= mask

    yield SampleInput(result, args=())


def numpy_nonzero_vmap(info, in_dims, x):
    raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")


numpy_nonzero.register_vmap(numpy_nonzero_vmap)


@torch.library.custom_op("_torch_testing::numpy_view_copy", mutates_args=())
def numpy_view_copy(x: Tensor, shape: Sequence[int]) -> Tensor:
    """Return a copy of ``x`` reshaped to ``shape`` (computed in NumPy)."""
    return torch.tensor(np.copy(to_numpy(x).reshape(shape)), device=x.device)


@numpy_view_copy.register_fake
def _(x, shape) -> Tensor:
    return x.clone().view(shape).clone()


def numpy_view_copy_setup_context(ctx, inputs, output) -> None:
    ctx.x_shape = inputs[0].shape


def numpy_view_copy_backward(ctx, grad_out):
    # Reshape the gradient back to the input's original shape.
    return torch.ops._torch_testing.numpy_view_copy(grad_out, ctx.x_shape), None


numpy_view_copy.register_autograd(numpy_view_copy_backward, setup_context=numpy_view_copy_setup_context)


def numpy_view_copy_vmap(info, in_dims, x, shape):
    x_bdim, _ = in_dims
    x = x.movedim(x_bdim, 0)
    # Prepend the batch size to the requested per-example shape.
    x_shape = x.shape[0]
    batch_shape = (x_shape, *shape)
    result = numpy_view_copy(x, batch_shape)
    return result, 0


numpy_view_copy.register_vmap(numpy_view_copy_vmap)


def sample_inputs_numpy_view_copy(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    result = make_arg(2, 3, 4, low=0.9, high=2)
    yield SampleInput(result, args=([2, 12],))


@torch.library.custom_op('_torch_testing::numpy_cat', mutates_args=())
def numpy_cat(xs: Sequence[Tensor], dim: int) -> Tensor:
    """Concatenate ``xs`` along ``dim`` with NumPy; all inputs must share device and dtype."""
    assert len(xs) > 0
    assert all(x.device == xs[0].device for x in xs)
    assert all(x.dtype == xs[0].dtype for x in xs)
    np_xs = [to_numpy(x) for x in xs]
    np_out = np.concatenate(np_xs, axis=dim)
    return torch.tensor(np_out, device=xs[0].device)


@numpy_cat.register_fake
def _(xs, dim):
    assert len(xs) > 0
    assert all(x.device == xs[0].device for x in xs)
    assert all(x.dtype == xs[0].dtype for x in xs)
    return torch.cat(xs, dim=dim)


def numpy_cat_setup_context(ctx, inputs, output):
    xs, dim = inputs
    ctx.dim_sizes = [x.shape[dim] for x in xs]
    ctx.dim = dim


def numpy_cat_backward(ctx, grad_out):
    dim_sizes = ctx.dim_sizes
    dim = ctx.dim

    # Split points are the running end offsets of each input along ``dim``
    # (excluding the last). Convert NumPy integers to plain Python ints for
    # the op's int[] argument.
    splits = [int(s) for s in np.cumsum(dim_sizes)[:-1]]
    grad_xs = torch.ops._torch_testing.numpy_split_copy(grad_out, splits, dim)
    return grad_xs, None


numpy_cat.register_autograd(numpy_cat_backward, setup_context=numpy_cat_setup_context)


def numpy_cat_vmap(info, in_dims, x, dim):
    """Batching rule for ``numpy_cat``.

    ``in_dims`` mirrors the positional arguments: a list with one batch dim
    per tensor in ``x``, plus ``None`` for the int ``dim``.
    """
    x_bdims, _ = in_dims
    # Bring every tensor's batch dim to the front; expand unbatched ones.
    batched = [
        xi.expand(info.batch_size, *xi.shape) if bdim is None else xi.movedim(bdim, 0)
        for xi, bdim in zip(x, x_bdims)
    ]
    # Wrap a negative dim against the logical rank, then skip the batch dim.
    logical_rank = batched[0].dim() - 1
    dim = dim if dim >= 0 else dim + logical_rank
    result = numpy_cat(batched, dim + 1)
    return result, 0


numpy_cat.register_vmap(numpy_cat_vmap)


def sample_inputs_numpy_cat(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    r0 = make_arg(2, 3, 4, low=0.9, high=2)
    r1 = make_arg(4, 3, 4, low=0.9, high=2)
    r2 = make_arg(5, 3, 4, low=0.9, high=2)
    yield SampleInput([r0, r1, r2], args=(0,))


@torch.library.custom_op('_torch_testing::numpy_split_copy', mutates_args=())
def numpy_split_copy(x: Tensor, splits: Sequence[int], dim: int) -> List[Tensor]:
    """Split ``x`` at the indices ``splits`` along ``dim`` with NumPy; returns copies."""
    x_np = to_numpy(x)
    arrs = np.split(x_np, splits, axis=dim)
    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs]


@numpy_split_copy.register_fake
def _(x, splits, dim):
    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)]


def numpy_split_copy_setup_context(ctx, inputs, output):
    _, _, dim = inputs
    ctx.dim = dim


def numpy_split_copy_backward(ctx, grad_out):
    # Gradient of a split is the concatenation of the pieces' gradients.
    result = torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim)
    return result, None, None


numpy_split_copy.register_autograd(numpy_split_copy_backward, setup_context=numpy_split_copy_setup_context)


def numpy_split_copy_vmap(info, in_dims, x, splits, dim):
    x_bdim, _, _ = in_dims
    x = x.movedim(x_bdim, 0)
    # Wrap a negative dim against the logical rank before skipping the batch dim.
    dim = dim if dim >= 0 else dim + x.dim() - 1
    result = numpy_split_copy(x, splits, dim + 1)
    return result, 0


numpy_split_copy.register_vmap(numpy_split_copy_vmap)


def sample_inputs_numpy_split_copy(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    x = make_arg(2, 9, low=0.9, high=2)
    yield SampleInput(x, args=([1, 3, 6], 1))


@torch.library.custom_op('_torch_testing::numpy_split_copy_with_int', mutates_args=())
def numpy_split_copy_with_int(x: Tensor, splits: Sequence[int], dim: int) -> Tuple[List[Tensor], int]:
    """Like ``numpy_split_copy`` but also returns ``len(splits)`` as a non-tensor output."""
    x_np = to_numpy(x)
    arrs = np.split(x_np, splits, axis=dim)
    return [torch.tensor(arr, device=x.device, dtype=x.dtype) for arr in arrs], len(splits)


@numpy_split_copy_with_int.register_fake
def _(x, splits, dim):
    return [xi.clone() for xi in torch.tensor_split(x, splits, dim)], len(splits)


def numpy_split_copy_with_int_setup_context(ctx, inputs, output):
    _, _, dim = inputs
    ctx.dim = dim


def numpy_split_copy_with_int_backward(ctx, grad_out, _):
    # The int output has no gradient; only the tensor pieces are concatenated.
    return torch.ops._torch_testing.numpy_cat(grad_out, dim=ctx.dim), None, None


numpy_split_copy_with_int.register_autograd(
    numpy_split_copy_with_int_backward,
    setup_context=numpy_split_copy_with_int_setup_context)


def numpy_split_copy_with_int_vmap(info, in_dims, x, splits, dim):
    x_bdim, _, _ = in_dims
    x = x.movedim(x_bdim, 0)
    # Wrap a negative dim against the logical rank before skipping the batch dim.
    dim = dim if dim >= 0 else dim + x.dim() - 1
    result, len_split = numpy_split_copy_with_int(x, splits, dim + 1)
    # Every tensor piece is batched at dim 0; the int output is unbatched.
    return (result, len_split), ([0 for _ in range(len(result))], None)


numpy_split_copy_with_int.register_vmap(numpy_split_copy_with_int_vmap)


@torch.library.custom_op("_torch_testing::numpy_nms", mutates_args=())
def numpy_nms(boxes: Tensor, scores: Tensor, iou_threshold: Number) -> Tensor:
    """Greedy non-maximum suppression in NumPy; returns the kept box indices.

    ``boxes`` is (N, 4) as (x1, y1, x2, y2) columns and ``scores`` is (N,).
    """
    # Adapted from Ross Girshick's fast-rcnn implementation at
    # https://github.com/rbgirshick/fast-rcnn/blob/master/lib/utils/nms.py
    assert boxes.device == scores.device
    device = boxes.device

    boxes = to_numpy(boxes)
    scores = to_numpy(scores)

    N = boxes.shape[0]
    assert boxes.shape == (N, 4)
    assert scores.shape == (N,)

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Visit boxes from highest to lowest score.
    order = scores.argsort()[::-1]

    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas[i] + areas[order[1:]] - inter)

        # Drop every remaining box overlapping the kept one above the threshold.
        inds = np.where(ovr <= iou_threshold)[0]
        order = order[inds + 1]

    result = torch.tensor(np.stack(keep), device=device)
    # Needed for data-dependent condition :(
    assert result.size(0) >= 2
    return result


@numpy_nms.register_fake
def _(boxes, scores, iou_threshold):
    assert boxes.device == scores.device
    N = boxes.shape[0]
    assert boxes.shape == (N, 4)
    assert scores.shape == (N,)

    # The number of kept boxes is data-dependent: use an unbacked symint.
    ctx = torch._custom_op.impl.get_ctx()
    i0 = ctx.create_unbacked_symint()
    result = boxes.new_empty([i0], dtype=torch.int64)
    return result


def numpy_nms_vmap(info, in_dims, boxes, scores, iou_threshold):
    raise NotImplementedError("Operator is data-dependent and cannot be vmapped.")


numpy_nms.register_vmap(numpy_nms_vmap)


def sample_inputs_numpy_nms(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = functools.partial(make_tensor, device=device, dtype=dtype)
    N = 64
    xs = make_arg([N], low=0, high=28)
    dx = make_arg([N], low=0, high=4)
    ys = make_arg([N], low=0, high=28)
    dy = make_arg([N], low=0, high=4)
    # Boxes are built as (x1, y1, x1 + dx, y1 + dy).
    boxes = torch.stack([xs, ys, xs + dx, ys + dy], dim=1).requires_grad_(requires_grad)
    scores = make_arg([N], low=0, high=1, requires_grad=requires_grad)
    iou_threshold = make_arg([], low=0, high=1).item()

    yield SampleInput(boxes, args=(scores, iou_threshold))


custom_op_db = [
    OpInfo(
        'NumpyCubeCustomOp',
        op=numpy_cube._opoverload,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyMulCustomOp',
        op=numpy_mul._opoverload,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyMulScalarCustomOp',
        op=numpy_mul_scalar._opoverload,
        sample_inputs_func=sample_inputs_numpy_mul_scalar,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpySortCustomOp',
        op=numpy_sort._opoverload,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyTakeCustomOp',
        op=numpy_take._opoverload,
        sample_inputs_func=sample_inputs_numpy_take,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyNonzeroCustomOp',
        op=numpy_nonzero._opoverload,
        sample_inputs_func=sample_inputs_numpy_nonzero,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_autograd=False,
        supports_out=False,
    ),
    OpInfo(
        'NumpyNMSCustomOp',
        op=torch.ops._torch_testing.numpy_nms,
        sample_inputs_func=sample_inputs_numpy_nms,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_autograd=False,
        supports_out=False,
    ),
    OpInfo(
        'NumpyViewCopyCustomOp',
        op=torch.ops._torch_testing.numpy_view_copy,
        sample_inputs_func=sample_inputs_numpy_view_copy,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_autograd=True,
        supports_out=False,
    ),
    OpInfo(
        'NumpyCatCustomOp',
        op=torch.ops._torch_testing.numpy_cat,
        sample_inputs_func=sample_inputs_numpy_cat,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_autograd=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_out=False,
    ),
    OpInfo(
        'NumpySplitCopyCustomOp',
        op=torch.ops._torch_testing.numpy_split_copy,
        sample_inputs_func=sample_inputs_numpy_split_copy,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_autograd=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_out=False,
    ),
    OpInfo(
        'NumpySplitCopyWithIntCustomOp',
        op=torch.ops._torch_testing.numpy_split_copy_with_int,
        sample_inputs_func=sample_inputs_numpy_split_copy,
        dtypes=all_types_and(torch.bool, torch.half),
        # Only the tensor-list output participates in gradcheck.
        gradcheck_wrapper=lambda op, *args, **kwargs: op(*args, **kwargs)[0],
        supports_autograd=True,
        check_batched_grad=False,
        check_batched_gradgrad=False,
        supports_out=False,
    ),
]


# ==============================================================
# some mechanical test cases
# ==============================================================

lib = torch.library.Library("_torch_testing", "FRAGMENT")  # noqa: TOR901

lib.define("source0(Tensor x) -> Tensor")

@torch.library.register_fake("_torch_testing::source0", lib=lib)
def _(x):
    return x.clone()

lib.define("source1(Tensor x) -> Tensor")

def source1_fake(x):
    return x.clone()

torch.library.register_fake("_torch_testing::source1", source1_fake, lib=lib)

lib.define("source2(Tensor x) -> Tensor")

@torch.library.register_fake("_torch_testing::source2", lib=lib)
def _(x):
    return x.clone()

lib.define("source3(Tensor x) -> Tensor")

def source3_fake(x):
    return x.clone()

torch.library.register_fake("_torch_testing::source3", source3_fake, lib=lib)


@torch.library.custom_op("_torch_testing::source4", mutates_args=())
def source4(x: Tensor) -> Tensor:
    return x.clone()

@source4.register_fake
def _(x):
    return x.clone()

@torch.library.custom_op("_torch_testing::source5", mutates_args=())
def source5(x: Tensor) -> Tensor:
    return x.clone()
def source5_fake(x):
    """Fake implementation for ``_torch_testing::source5``: a fresh copy of ``x``."""
    copied = x.clone()
    return copied


source5.register_fake(source5_fake)