# Owner(s): ["module: functorch"]
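"""Consistency checks between the batching rules registered for vmap
(FuncTorchBatched / FuncTorchBatchedDecomposition) and the
CompositeImplicitAutograd dispatch key."""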
import typing
import unittest

from torch._C import (
    _dispatch_get_registrations_for_dispatch_key as get_registrations_for_dispatch_key,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)


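# CompositeImplicitAutograd ops that also have a direct FuncTorchBatched
# batching rule; expected failures for
# test_register_a_batching_rule_for_composite_implicit_autograd below.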
xfail_functorch_batched = {
    "aten::is_nonzero",
    "aten::item",
    "aten::linalg_slogdet",
    "aten::masked_select_backward",
    "aten::one_hot",
    "aten::silu_backward",
    "aten::where",
}

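# FuncTorchBatchedDecomposition registrations that are not
# CompositeImplicitAutograd; expected failures for
# test_register_functorch_batched_decomposition below.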
xfail_functorch_batched_decomposition = {
    "aten::alias_copy",
    "aten::as_strided_copy",
    "aten::diagonal_copy",
    "aten::is_same_size",
    "aten::unfold_copy",
}

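# vmap-implementable CompositeImplicitAutograd ops that do not yet have a
# BatchedDecompositions.cpp registration; expected failures for
# test_unimplemented_batched_registrations below.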
xfail_not_implemented = {
    "aten::affine_grid_generator_backward",
    "aten::align_as",
    "aten::align_tensors",
    "aten::align_to",
    "aten::align_to.ellipsis_idx",
    "aten::alpha_dropout",
    "aten::alpha_dropout_",
    "aten::argwhere",
    "aten::bilinear",
    "aten::can_cast",
    "aten::cat.names",
    "aten::chain_matmul",
    "aten::chalf",
    "aten::choose_qparams_optimized",
    "aten::clip_",
    "aten::clip_.Tensor",
    "aten::coalesce",
    "aten::column_stack",
    "aten::concat.names",
    "aten::concatenate.names",
    "aten::conj",
    "aten::conv_tbc_backward",
    "aten::ctc_loss.IntList",
    "aten::ctc_loss.Tensor",
    "aten::cudnn_is_acceptable",
    "aten::cummaxmin_backward",
    "aten::data",
    "aten::diagflat",
    "aten::divide.out_mode",
    "aten::divide_.Scalar",
    "aten::dropout_",
    "aten::embedding_bag",
    "aten::embedding_bag.padding_idx",
    "aten::feature_alpha_dropout",
    "aten::feature_alpha_dropout_",
    "aten::feature_dropout",
    "aten::feature_dropout_",
    "aten::fft_ihfft2",
    "aten::fft_ihfftn",
    "aten::fill_diagonal_",
    "aten::fix_",
    "aten::flatten.named_out_dim",
    "aten::flatten.using_names",
    "aten::flatten_dense_tensors",
    "aten::float_power_.Scalar",
    "aten::float_power_.Tensor",
    "aten::floor_divide_.Scalar",
    "aten::frobenius_norm",
    "aten::fused_moving_avg_obs_fake_quant",
    "aten::get_gradients",
    "aten::greater_.Scalar",
    "aten::greater_.Tensor",
    "aten::greater_equal_.Scalar",
    "aten::greater_equal_.Tensor",
    "aten::gru.data",
    "aten::gru.input",
    "aten::gru_cell",
    "aten::histogramdd",
    "aten::histogramdd.TensorList_bins",
    "aten::histogramdd.int_bins",
    "aten::infinitely_differentiable_gelu_backward",
    "aten::isclose",
    "aten::istft",
    "aten::item",
    "aten::kl_div",
    "aten::ldexp_",
    "aten::less_.Scalar",
    "aten::less_.Tensor",
    "aten::less_equal_.Scalar",
    "aten::less_equal_.Tensor",
    "aten::linalg_cond.p_str",
    "aten::linalg_eigh.eigvals",
    "aten::linalg_matrix_rank",
    "aten::linalg_matrix_rank.out_tol_tensor",
    "aten::linalg_matrix_rank.tol_tensor",
    "aten::linalg_pinv.out_rcond_tensor",
    "aten::linalg_pinv.rcond_tensor",
    "aten::linalg_slogdet",
    "aten::linalg_svd.U",
    "aten::linalg_tensorsolve",
    "aten::logsumexp.names",
    "aten::lstm.data",
    "aten::lstm.input",
    "aten::lstm_cell",
    "aten::lu_solve",
    "aten::margin_ranking_loss",
    "aten::masked_select_backward",
    "aten::matrix_exp",
    "aten::matrix_exp_backward",
    "aten::max.names_dim",
    "aten::max.names_dim_max",
    "aten::mean.names_dim",
    "aten::median.names_dim",
    "aten::median.names_dim_values",
    "aten::min.names_dim",
    "aten::min.names_dim_min",
    "aten::mish_backward",
    "aten::moveaxis.int",
    "aten::multilabel_margin_loss",
    "aten::nanmedian.names_dim",
    "aten::nanmedian.names_dim_values",
    "aten::nanquantile",
    "aten::nanquantile.scalar",
    "aten::narrow.Tensor",
    "aten::native_channel_shuffle",
    "aten::negative_",
    "aten::nested_to_padded_tensor",
    "aten::nonzero_numpy",
    "aten::norm.names_ScalarOpt_dim",
    "aten::norm.names_ScalarOpt_dim_dtype",
    "aten::norm_except_dim",
    "aten::not_equal_.Scalar",
    "aten::not_equal_.Tensor",
    "aten::one_hot",
    "aten::output_nr",
    "aten::pad_sequence",
    "aten::pdist",
    "aten::pin_memory",
    "aten::promote_types",
    "aten::qr.Q",
    "aten::quantile",
    "aten::quantile.scalar",
    "aten::refine_names",
    "aten::rename",
    "aten::rename_",
    "aten::requires_grad_",
    "aten::retain_grad",
    "aten::retains_grad",
    "aten::rnn_relu.data",
    "aten::rnn_relu.input",
    "aten::rnn_relu_cell",
    "aten::rnn_tanh.data",
    "aten::rnn_tanh.input",
    "aten::rnn_tanh_cell",
    "aten::set_.source_Tensor_storage_offset",
    "aten::set_data",
    "aten::silu_backward",
    "aten::slow_conv3d",
    "aten::smm",
    "aten::special_chebyshev_polynomial_t.n_scalar",
    "aten::special_chebyshev_polynomial_t.x_scalar",
    "aten::special_chebyshev_polynomial_u.n_scalar",
    "aten::special_chebyshev_polynomial_u.x_scalar",
    "aten::special_chebyshev_polynomial_v.n_scalar",
    "aten::special_chebyshev_polynomial_v.x_scalar",
    "aten::special_chebyshev_polynomial_w.n_scalar",
    "aten::special_chebyshev_polynomial_w.x_scalar",
    "aten::special_hermite_polynomial_h.n_scalar",
    "aten::special_hermite_polynomial_h.x_scalar",
    "aten::special_hermite_polynomial_he.n_scalar",
    "aten::special_hermite_polynomial_he.x_scalar",
    "aten::special_laguerre_polynomial_l.n_scalar",
    "aten::special_laguerre_polynomial_l.x_scalar",
    "aten::special_legendre_polynomial_p.n_scalar",
    "aten::special_legendre_polynomial_p.x_scalar",
    "aten::special_shifted_chebyshev_polynomial_t.n_scalar",
    "aten::special_shifted_chebyshev_polynomial_t.x_scalar",
    "aten::special_shifted_chebyshev_polynomial_u.n_scalar",
    "aten::special_shifted_chebyshev_polynomial_u.x_scalar",
    "aten::special_shifted_chebyshev_polynomial_v.n_scalar",
    "aten::special_shifted_chebyshev_polynomial_v.x_scalar",
    "aten::special_shifted_chebyshev_polynomial_w.n_scalar",
    "aten::special_shifted_chebyshev_polynomial_w.x_scalar",
    "aten::square_",
    "aten::sspaddmm",
    "aten::std.correction_names",
    "aten::std.names_dim",
    "aten::std_mean.correction_names",
    "aten::std_mean.names_dim",
    "aten::stft",
    "aten::stft.center",
    "aten::stride.int",
    "aten::subtract.Scalar",
    "aten::subtract_.Scalar",
    "aten::subtract_.Tensor",
    "aten::svd.U",
    "aten::sym_numel",
    "aten::sym_size.int",
    "aten::sym_storage_offset",
    "aten::sym_stride.int",
    "aten::tensor_split.tensor_indices_or_sections",
    "aten::thnn_conv2d",
    "aten::to_dense",
    "aten::to_dense_backward",
    "aten::to_mkldnn_backward",
    "aten::trace_backward",
    "aten::triplet_margin_loss",
    "aten::unflatten_dense_tensors",
    "aten::vander",
    "aten::var.correction_names",
    "aten::var.names_dim",
    "aten::var_mean.correction_names",
    "aten::var_mean.names_dim",
    "aten::where",
    "aten::wrapped_linear_prepack",
    "aten::wrapped_quantized_linear_prepacked",
}


def dispatch_registrations(
    dispatch_key: str, xfails: set, filter_func: typing.Callable = lambda reg: True
):
    """Parametrize a test over every registration for `dispatch_key` that
    passes `filter_func`, marking the names listed in `xfails` as expected
    failures."""
    registrations = sorted(get_registrations_for_dispatch_key(dispatch_key))
    subtests = [
        subtest(
            reg,
            name=f"[{reg}]",
            decorators=([unittest.expectedFailure] if reg in xfails else []),
        )
        for reg in registrations
        if filter_func(reg)
    ]
    return parametrize("registration", subtests)


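# Snapshot the registrations for each relevant dispatch key once; the tests
# below assert set membership against these.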
CompositeImplicitAutogradRegistrations = set(
    get_registrations_for_dispatch_key("CompositeImplicitAutograd")
)
FuncTorchBatchedRegistrations = set(
    get_registrations_for_dispatch_key("FuncTorchBatched")
)
FuncTorchBatchedDecompositionRegistrations = set(
    get_registrations_for_dispatch_key("FuncTorchBatchedDecomposition")
)


def filter_vmap_implementable(reg):
    """Keep only the registrations for which a vmap batching rule is
    plausibly implementable: public (non-underscore-prefixed) aten ops,
    excluding out= overloads, named-tensor (dimname) overloads,
    fbgemm/quantized/sparse operators, and is_* predicates."""
    reg = reg.lower()
    if not reg.startswith("aten::"):
        return False
    if reg.startswith("aten::_"):
        return False
    if reg.endswith(".out"):
        return False
    if reg.endswith("_out"):
        return False
    if ".dimname" in reg:
        return False
    if "_dimname" in reg:
        return False
    if "fbgemm" in reg:
        return False
    if "quantize" in reg:
        return False
    if "sparse" in reg:
        return False
    if "::is_" in reg:
        return False
    return True


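# Each test is parametrized over one dispatch key's registrations and asserts
# a consistency invariant; known violations live in the xfail sets above.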
class TestFunctorchDispatcher(TestCase):
    @dispatch_registrations("CompositeImplicitAutograd", xfail_functorch_batched)
    def test_register_a_batching_rule_for_composite_implicit_autograd(
        self, registration
    ):
        assert registration not in FuncTorchBatchedRegistrations, (
            f"You've added a batching rule for a CompositeImplicitAutograd operator {registration}. "
            "The correct way to add vmap support for it is to put it into BatchRulesDecomposition to "
            "reuse the CompositeImplicitAutograd decomposition."
        )

    @dispatch_registrations(
        "FuncTorchBatchedDecomposition", xfail_functorch_batched_decomposition
    )
    def test_register_functorch_batched_decomposition(self, registration):
        assert registration in CompositeImplicitAutogradRegistrations, (
            "The registrations in BatchedDecompositions.cpp must be for CompositeImplicitAutograd "
            f"operations. If your operation {registration} is not CompositeImplicitAutograd, "
            "then please register it to the FuncTorchBatched key in another file."
        )

    @dispatch_registrations(
        "CompositeImplicitAutograd", xfail_not_implemented, filter_vmap_implementable
    )
    def test_unimplemented_batched_registrations(self, registration):
        assert registration in FuncTorchBatchedDecompositionRegistrations, (
            f"Please check that there is an OpInfo that covers the operator {registration} "
            "and add a registration in BatchedDecompositions.cpp. "
            "If your operator isn't user-facing, please add it to the xfail list."
        )


instantiate_parametrized_tests(TestFunctorchDispatcher)

if __name__ == "__main__":
    run_tests()