1# Owner(s): ["module: functorch"] 2import typing 3import unittest 4 5from torch._C import ( 6 _dispatch_get_registrations_for_dispatch_key as get_registrations_for_dispatch_key, 7) 8from torch.testing._internal.common_utils import ( 9 instantiate_parametrized_tests, 10 parametrize, 11 run_tests, 12 subtest, 13 TestCase, 14) 15 16 17xfail_functorch_batched = { 18 "aten::is_nonzero", 19 "aten::item", 20 "aten::linalg_slogdet", 21 "aten::masked_select_backward", 22 "aten::one_hot", 23 "aten::silu_backward", 24 "aten::where", 25} 26 27xfail_functorch_batched_decomposition = { 28 "aten::alias_copy", 29 "aten::as_strided_copy", 30 "aten::diagonal_copy", 31 "aten::is_same_size", 32 "aten::unfold_copy", 33} 34 35xfail_not_implemented = { 36 "aten::affine_grid_generator_backward", 37 "aten::align_as", 38 "aten::align_tensors", 39 "aten::align_to", 40 "aten::align_to.ellipsis_idx", 41 "aten::alpha_dropout", 42 "aten::alpha_dropout_", 43 "aten::argwhere", 44 "aten::bilinear", 45 "aten::can_cast", 46 "aten::cat.names", 47 "aten::chain_matmul", 48 "aten::chalf", 49 "aten::choose_qparams_optimized", 50 "aten::clip_", 51 "aten::clip_.Tensor", 52 "aten::coalesce", 53 "aten::column_stack", 54 "aten::concat.names", 55 "aten::concatenate.names", 56 "aten::conj", 57 "aten::conv_tbc_backward", 58 "aten::ctc_loss.IntList", 59 "aten::ctc_loss.Tensor", 60 "aten::cudnn_is_acceptable", 61 "aten::cummaxmin_backward", 62 "aten::data", 63 "aten::diagflat", 64 "aten::divide.out_mode", 65 "aten::divide_.Scalar", 66 "aten::dropout_", 67 "aten::embedding_bag", 68 "aten::embedding_bag.padding_idx", 69 "aten::feature_alpha_dropout", 70 "aten::feature_alpha_dropout_", 71 "aten::feature_dropout", 72 "aten::feature_dropout_", 73 "aten::fft_ihfft2", 74 "aten::fft_ihfftn", 75 "aten::fill_diagonal_", 76 "aten::fix_", 77 "aten::flatten.named_out_dim", 78 "aten::flatten.using_names", 79 "aten::flatten_dense_tensors", 80 "aten::float_power_.Scalar", 81 "aten::float_power_.Tensor", 82 "aten::floor_divide_.Scalar", 83 "aten::frobenius_norm", 84 "aten::fused_moving_avg_obs_fake_quant", 85 "aten::get_gradients", 86 "aten::greater_.Scalar", 87 "aten::greater_.Tensor", 88 "aten::greater_equal_.Scalar", 89 "aten::greater_equal_.Tensor", 90 "aten::gru.data", 91 "aten::gru.input", 92 "aten::gru_cell", 93 "aten::histogramdd", 94 "aten::histogramdd.TensorList_bins", 95 "aten::histogramdd.int_bins", 96 "aten::infinitely_differentiable_gelu_backward", 97 "aten::isclose", 98 "aten::istft", 99 "aten::item", 100 "aten::kl_div", 101 "aten::ldexp_", 102 "aten::less_.Scalar", 103 "aten::less_.Tensor", 104 "aten::less_equal_.Scalar", 105 "aten::less_equal_.Tensor", 106 "aten::linalg_cond.p_str", 107 "aten::linalg_eigh.eigvals", 108 "aten::linalg_matrix_rank", 109 "aten::linalg_matrix_rank.out_tol_tensor", 110 "aten::linalg_matrix_rank.tol_tensor", 111 "aten::linalg_pinv.out_rcond_tensor", 112 "aten::linalg_pinv.rcond_tensor", 113 "aten::linalg_slogdet", 114 "aten::linalg_svd.U", 115 "aten::linalg_tensorsolve", 116 "aten::logsumexp.names", 117 "aten::lstm.data", 118 "aten::lstm.input", 119 "aten::lstm_cell", 120 "aten::lu_solve", 121 "aten::margin_ranking_loss", 122 "aten::masked_select_backward", 123 "aten::matrix_exp", 124 "aten::matrix_exp_backward", 125 "aten::max.names_dim", 126 "aten::max.names_dim_max", 127 "aten::mean.names_dim", 128 "aten::median.names_dim", 129 "aten::median.names_dim_values", 130 "aten::min.names_dim", 131 "aten::min.names_dim_min", 132 "aten::mish_backward", 133 "aten::moveaxis.int", 134 "aten::multilabel_margin_loss", 135 "aten::nanmedian.names_dim", 136 "aten::nanmedian.names_dim_values", 137 "aten::nanquantile", 138 "aten::nanquantile.scalar", 139 "aten::narrow.Tensor", 140 "aten::native_channel_shuffle", 141 "aten::negative_", 142 "aten::nested_to_padded_tensor", 143 "aten::nonzero_numpy", 144 "aten::norm.names_ScalarOpt_dim", 145 "aten::norm.names_ScalarOpt_dim_dtype", 146 "aten::norm_except_dim", 147 "aten::not_equal_.Scalar", 148 "aten::not_equal_.Tensor", 149 "aten::one_hot", 150 "aten::output_nr", 151 "aten::pad_sequence", 152 "aten::pdist", 153 "aten::pin_memory", 154 "aten::promote_types", 155 "aten::qr.Q", 156 "aten::quantile", 157 "aten::quantile.scalar", 158 "aten::refine_names", 159 "aten::rename", 160 "aten::rename_", 161 "aten::requires_grad_", 162 "aten::retain_grad", 163 "aten::retains_grad", 164 "aten::rnn_relu.data", 165 "aten::rnn_relu.input", 166 "aten::rnn_relu_cell", 167 "aten::rnn_tanh.data", 168 "aten::rnn_tanh.input", 169 "aten::rnn_tanh_cell", 170 "aten::set_.source_Tensor_storage_offset", 171 "aten::set_data", 172 "aten::silu_backward", 173 "aten::slow_conv3d", 174 "aten::smm", 175 "aten::special_chebyshev_polynomial_t.n_scalar", 176 "aten::special_chebyshev_polynomial_t.x_scalar", 177 "aten::special_chebyshev_polynomial_u.n_scalar", 178 "aten::special_chebyshev_polynomial_u.x_scalar", 179 "aten::special_chebyshev_polynomial_v.n_scalar", 180 "aten::special_chebyshev_polynomial_v.x_scalar", 181 "aten::special_chebyshev_polynomial_w.n_scalar", 182 "aten::special_chebyshev_polynomial_w.x_scalar", 183 "aten::special_hermite_polynomial_h.n_scalar", 184 "aten::special_hermite_polynomial_h.x_scalar", 185 "aten::special_hermite_polynomial_he.n_scalar", 186 "aten::special_hermite_polynomial_he.x_scalar", 187 "aten::special_laguerre_polynomial_l.n_scalar", 188 "aten::special_laguerre_polynomial_l.x_scalar", 189 "aten::special_legendre_polynomial_p.n_scalar", 190 "aten::special_legendre_polynomial_p.x_scalar", 191 "aten::special_shifted_chebyshev_polynomial_t.n_scalar", 192 "aten::special_shifted_chebyshev_polynomial_t.x_scalar", 193 "aten::special_shifted_chebyshev_polynomial_u.n_scalar", 194 "aten::special_shifted_chebyshev_polynomial_u.x_scalar", 195 "aten::special_shifted_chebyshev_polynomial_v.n_scalar", 196 "aten::special_shifted_chebyshev_polynomial_v.x_scalar", 197 "aten::special_shifted_chebyshev_polynomial_w.n_scalar", 198 "aten::special_shifted_chebyshev_polynomial_w.x_scalar", 199 "aten::square_", 200 "aten::sspaddmm", 201 "aten::std.correction_names", 202 "aten::std.names_dim", 203 "aten::std_mean.correction_names", 204 "aten::std_mean.names_dim", 205 "aten::stft", 206 "aten::stft.center", 207 "aten::stride.int", 208 "aten::subtract.Scalar", 209 "aten::subtract_.Scalar", 210 "aten::subtract_.Tensor", 211 "aten::svd.U", 212 "aten::sym_size.int", 213 "aten::sym_stride.int", 214 "aten::sym_numel", 215 "aten::sym_storage_offset", 216 "aten::tensor_split.tensor_indices_or_sections", 217 "aten::thnn_conv2d", 218 "aten::to_dense", 219 "aten::to_dense_backward", 220 "aten::to_mkldnn_backward", 221 "aten::trace_backward", 222 "aten::triplet_margin_loss", 223 "aten::unflatten_dense_tensors", 224 "aten::vander", 225 "aten::var.correction_names", 226 "aten::var.names_dim", 227 "aten::var_mean.correction_names", 228 "aten::var_mean.names_dim", 229 "aten::where", 230 "aten::wrapped_linear_prepack", 231 "aten::wrapped_quantized_linear_prepacked", 232} 233 234 235def dispatch_registrations( 236 dispatch_key: str, xfails: set, filter_func: typing.Callable = lambda reg: True 237): 238 registrations = sorted(get_registrations_for_dispatch_key(dispatch_key)) 239 subtests = [ 240 subtest( 241 reg, 242 name=f"[{reg}]", 243 decorators=([unittest.expectedFailure] if reg in xfails else []), 244 ) 245 for reg in registrations 246 if filter_func(reg) 247 ] 248 return parametrize("registration", subtests) 249 250 251CompositeImplicitAutogradRegistrations = set( 252 get_registrations_for_dispatch_key("CompositeImplicitAutograd") 253) 254FuncTorchBatchedRegistrations = set( 255 get_registrations_for_dispatch_key("FuncTorchBatched") 256) 257FuncTorchBatchedDecompositionRegistrations = set( 258 get_registrations_for_dispatch_key("FuncTorchBatchedDecomposition") 259) 260 261 262def filter_vmap_implementable(reg): 263 reg = reg.lower() 264 if not reg.startswith("aten::"): 265 return False 266 if reg.startswith("aten::_"): 267 return False 268 if reg.endswith(".out"): 269 return False 270 if reg.endswith("_out"): 271 return False 272 if ".dimname" in reg: 273 return False 274 if "_dimname" in reg: 275 return False 276 if "fbgemm" in reg: 277 return False 278 if "quantize" in reg: 279 return False 280 if "sparse" in reg: 281 return False 282 if "::is_" in reg: 283 return False 284 return True 285 286 287class TestFunctorchDispatcher(TestCase): 288 @dispatch_registrations("CompositeImplicitAutograd", xfail_functorch_batched) 289 def test_register_a_batching_rule_for_composite_implicit_autograd( 290 self, registration 291 ): 292 assert registration not in FuncTorchBatchedRegistrations, ( 293 f"You've added a batching rule for a CompositeImplicitAutograd operator {registration}. " 294 "The correct way to add vmap support for it is to put it into BatchRulesDecomposition to " 295 "reuse the CompositeImplicitAutograd decomposition" 296 ) 297 298 @dispatch_registrations( 299 "FuncTorchBatchedDecomposition", xfail_functorch_batched_decomposition 300 ) 301 def test_register_functorch_batched_decomposition(self, registration): 302 assert registration in CompositeImplicitAutogradRegistrations, ( 303 f"The registrations in BatchedDecompositions.cpp must be for CompositeImplicitAutograd " 304 f"operations. If your operation {registration} is not CompositeImplicitAutograd, " 305 "then please register it to the FuncTorchBatched key in another file." 306 ) 307 308 @dispatch_registrations( 309 "CompositeImplicitAutograd", xfail_not_implemented, filter_vmap_implementable 310 ) 311 def test_unimplemented_batched_registrations(self, registration): 312 assert registration in FuncTorchBatchedDecompositionRegistrations, ( 313 f"Please check that there is an OpInfo that covers the operator {registration} " 314 "and add a registration in BatchedDecompositions.cpp. " 315 "If your operator isn't user facing, please add it to the xfail list" 316 ) 317 318 319instantiate_parametrized_tests(TestFunctorchDispatcher) 320 321if __name__ == "__main__": 322 run_tests() 323