# mypy: ignore-errors

import torch
from functools import partial
from torch.testing import make_tensor
from torch.testing._internal.opinfo.core import (
    OpInfo,
    SampleInput,
)
from torch.testing._internal.common_dtype import all_types_and
import numpy as np

# Note: [autograd.Function db]
#
# This is a collection of autograd.Function test cases written as OpInfos
# so they can easily be consumed by OpInfo-based tests to check if a subsystem
# supports autograd.Function.
#
# Axes:
# - saves {output, input, intermediate, non-tensor}
# - {inputs, output} x {single tensor, tensors, arbitrary objects}
# - Uses {mark_dirty, mark_non_differentiable, once_differentiable}
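#
# The entries collected in `autograd_function_db` at the bottom of this file
# are consumed like any other OpInfo list. A minimal sketch of a consumer,
# assuming the standard OpInfo test machinery in
# torch.testing._internal.common_device_type (the test class and test method
# names here are hypothetical):
#
#     import torch
#     from torch.testing._internal.common_device_type import (
#         instantiate_device_type_tests,
#         ops,
#     )
#     from torch.testing._internal.common_utils import TestCase, run_tests
#     from torch.testing._internal.autograd_function_db import autograd_function_db
#
#     class TestMySubsystem(TestCase):
#         @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
#         def test_my_subsystem(self, device, dtype, op):
#             for sample in op.sample_inputs(device, dtype, requires_grad=True):
#                 # op.op calls the autograd.Function (usually via .apply)
#                 result = op.op(sample.input, *sample.args, **sample.kwargs)
#                 # ... assert the subsystem under test handled it as expected
#
#     instantiate_device_type_tests(TestMySubsystem, globals())
#
#     if __name__ == "__main__":
#         run_tests()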


def to_numpy(tensor):
    return tensor.cpu().numpy()


class NumpyCube(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        dinput = torch.tensor(3 * input_np ** 2, device=input.device)
        return torch.tensor(input_np ** 3, device=input.device), dinput

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(inputs[0], output[1])
        ctx.save_for_forward(inputs[0], output[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(grad_output, dinput) + 6 * NumpyMul.apply(grad_saved, input)

    @staticmethod
    def vmap(info, in_dims, input):
        result = NumpyCube.apply(input)
        return result, (in_dims[0], in_dims[0])

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return NumpyMul.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


class CubeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x ** 3, 3 * x ** 2

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(inputs[0], outputs[1])
        ctx.save_for_forward(inputs[0], outputs[1])

    @staticmethod
    def backward(ctx, grad_output, grad_saved):
        input, dinput = ctx.saved_tensors
        result = grad_output * dinput + 6 * grad_saved * input
        return result

    @staticmethod
    def jvp(ctx, input_tangent):
        input, dinput = ctx.saved_tensors
        return MulGenVmap.apply(input_tangent, dinput), 6 * NumpyMul.apply(input_tangent, input)


def sample_inputs_numpy_cube(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(1, low=0.8, high=2), args=())


class NumpyCubeNotComposable(torch.autograd.Function):
    @staticmethod
    def forward(input):
        input_np = to_numpy(input)
        return torch.tensor(input_np ** 3, device=input.device), input_np

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, input_np = output
        ctx.input_np = input_np
        ctx.device = inputs[0].device

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output, grad_saved):
        result_np = 3 * (ctx.input_np ** 2)
        return torch.tensor(result_np, device=ctx.device) * grad_output


class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


def sample_inputs_numpy_mul(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # Broadcasting
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(make_arg(3, 4, low=0.9, high=2),))


def sample_inputs_numpy_mul_scalar(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(4, low=0.9, high=2), args=(), kwargs={"scalar": 3.14})


class MulGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x * y

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = MulGenVmap.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = MulGenVmap.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x


class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(output)
        ctx.save_for_forward(output)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        # Doesn't call numpy operations because I didn't want to write NumpyMul_
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent


class NumpySort(torch.autograd.Function):
    @staticmethod
    def forward(x, dim):
        device = x.device
        x = to_numpy(x)
        ind = np.argsort(x, axis=dim)
        ind_inv = np.argsort(ind, axis=dim)
        result = np.take_along_axis(x, ind, axis=dim)
        return (
            torch.tensor(result, device=device),
            torch.tensor(ind, device=device),
            torch.tensor(ind_inv, device=device),
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, dim = inputs
        _, ind, ind_inv = output
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def vmap(info, in_dims, x, dim):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 0)
        # wrap dim
        dim = dim if dim >= 0 else dim + x.dim() - 1
        return NumpySort.apply(x, dim + 1), (0, 0, 0)

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


class SortGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, dim):
        device = x.device
        ind = torch.argsort(x, dim=dim)
        ind_inv = torch.argsort(ind, dim=dim)
        result = torch.take_along_dim(x, ind, dim=dim)
        return result, ind, ind_inv

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, dim = inputs
        _, ind, ind_inv = outputs
        ctx.mark_non_differentiable(ind, ind_inv)
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output, _0, _1):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim), None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim), None, None


def sample_inputs_numpy_sort(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(1,))


def sample_inputs_numpy_take(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    tensor = make_arg(3, 5)
    dim = 1
    _, ind, ind_inv = NumpySort.apply(tensor, 1)
    yield SampleInput(tensor, args=(ind, ind_inv, dim))


class NumpyTake(torch.autograd.Function):
    @staticmethod
    def forward(x, ind, ind_inv, dim):
        device = x.device
        x = to_numpy(x)
        ind = to_numpy(ind)
        return torch.tensor(np.take_along_axis(x, ind, dim), device=device)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = NumpyTake.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def vmap(info, in_dims, x, ind, ind_inv, dim):
        x_bdim, ind_bdim, ind_inv_bdim, _ = in_dims

        # wrap dim
        logical_dim = x.dim() if x_bdim is None else x.dim() - 1
        dim = dim if dim >= 0 else dim + logical_dim

        def expand_bdim(x, x_bdim):
            if x_bdim is None:
                return x.expand(info.batch_size, *x.shape)
            return x.movedim(x_bdim, 0)

        x = expand_bdim(x, x_bdim)
        ind = expand_bdim(ind, ind_bdim)
        ind_inv = expand_bdim(ind_inv, ind_inv_bdim)

        return NumpyTake.apply(x, ind, ind_inv, dim + 1), 0

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        assert ind_tangent is None
        assert ind_inv_tangent is None
        ind, ind_inv = ctx.saved_tensors
        return NumpyTake.apply(x_tangent, ind, ind_inv, ctx.dim)


class TakeGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, ind, ind_inv, dim):
        return torch.take_along_dim(x, ind, dim)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, ind, ind_inv, dim = inputs
        ctx.save_for_backward(ind, ind_inv)
        ctx.save_for_forward(ind, ind_inv)
        ctx.dim = dim

    @staticmethod
    def backward(ctx, grad_output):
        ind, ind_inv = ctx.saved_tensors
        result = TakeGenVmap.apply(grad_output, ind_inv, ind, ctx.dim)
        return result, None, None, None

    @staticmethod
    def jvp(ctx, x_tangent, ind_tangent, ind_inv_tangent, _):
        ind, ind_inv = ctx.saved_tensors
        return TakeGenVmap.apply(x_tangent, ind, ind_inv, ctx.dim)


class Select(torch.autograd.Function):
    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return Select.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return Select.apply(x_tangent, ctx.idx)


class SelectGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, idx):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return SelectGenVmap.apply(x_tangent, ctx.idx)


def sample_inputs_select(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5), args=(2,))


class ScaleGradGenVmap(torch.autograd.Function):
    generate_vmap_rule = True
    scale = 3.14

    @staticmethod
    def forward(x):
        return x.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ScaleGradGenVmap.scale

    @staticmethod
    def jvp(ctx, x_tangent):
        return x_tangent * ScaleGradGenVmap.scale


class ZeroGradientsGenVmap(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x, y):
        return x.clone(), y.clone()

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        pass

    @staticmethod
    def backward(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            # Intentionally too-large gradient
            torch.zeros(3, 4, *gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )

    @staticmethod
    def jvp(ctx, gx, gy):
        # Intentionally returning torch.zeros instead of zeros_like or new_zeros.
        # Also intentionally not None.
        return (
            torch.zeros(gx.shape, dtype=gx.dtype, device=gx.device),
            torch.zeros(gy.shape, dtype=gy.dtype, device=gy.device),
        )


def sample_inputs_forward_default_args(opinfo, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    yield SampleInput(make_arg(3, 5))


class ForwardHasDefaultArgs(torch.autograd.Function):
    @staticmethod
    def forward(x, idx=(2,)):
        return x[idx]

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, idx = inputs
        ctx.x_shape = x.shape
        ctx.idx = idx

    @staticmethod
    def backward(ctx, grad_output):
        result = grad_output.new_zeros(ctx.x_shape)
        result[ctx.idx] = grad_output
        return result, None

    @staticmethod
    def vmap(info, in_dims, x, idx):
        x_bdim, _ = in_dims
        x = x.movedim(x_bdim, 1)
        return ForwardHasDefaultArgs.apply(x, idx), 0

    @staticmethod
    def jvp(ctx, x_tangent, _):
        return ForwardHasDefaultArgs.apply(x_tangent, ctx.idx)


autograd_function_db = [
    OpInfo(
        'NumpyCubeAutogradFunction',
        op=NumpyCube.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyExpMarkDirtyAutogradFunction',
        op=lambda x: NumpyExp_.apply(x.clone()),
        inplace_variant=NumpyExp_.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyMulAutogradFunction',
        op=NumpyMul.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpyCubeNotComposableAutogradFunction',
        op=lambda x: NumpyCubeNotComposable.apply(x)[0],
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'NumpySortAutogradFunction',
        op=NumpySort.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'NumpyTakeAutogradFunction',
        op=NumpyTake.apply,
        supports_forward_ad=False,
        supports_fwgrad_bwgrad=False,
        sample_inputs_func=sample_inputs_numpy_take,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SelectAutogradFunction',
        op=Select.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'CubeGenVmapAutogradFunction',
        op=CubeGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'MulGenVmapAutogradFunction',
        op=MulGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'SortGenVmapAutogradFunction',
        op=SortGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_sort,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
        gradcheck_wrapper=lambda y, ind: y,
    ),
    OpInfo(
        'SelectGenVmapAutogradFunction',
        op=SelectGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_select,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ScaleGradGenVmapAutogradFunction',
        op=ScaleGradGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_cube,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ZeroGradientsGenVmapAutogradFunction',
        op=ZeroGradientsGenVmap.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_numpy_mul,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
    OpInfo(
        'ForwardHasDefaultArgsAutogradFunction',
        op=ForwardHasDefaultArgs.apply,
        supports_forward_ad=True,
        supports_fwgrad_bwgrad=True,
        sample_inputs_func=sample_inputs_forward_default_args,
        dtypes=all_types_and(torch.bool, torch.half),
        supports_out=False,
    ),
]