#include #include using torch::CppFunction; namespace at { // Note: [DispatchKey::VmapMode usage] // Whenever we're inside a vmap, all Tensors dispatch on this key. At the moment, // this key is used to disable random operations inside of vmap. If you are looking // for Batching Rules, those are registered with DispatchKey::Batched instead. // // Note: [Ambiguity of random operations inside vmap] // Random operations have an ambiguity where it isn't clear if they should // apply the same randomness or apply different randomness. For example: // // >>> vmap(lambda t: torch.rand(1))(torch.zeros(5)) // Should the above return the same random number 5 times, or a different one? // // We haven't made a decision on that yet so we are temporarily banning random // operations inside of vmap while we gather user feedback. template Tensor unsupportedRandomOp(Args... args) { TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ", "Please perform random operations outside of vmap as a workaround"); } template Tensor& unsupportedRandomOp_(Args... args) { TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ", "Please perform random operations outside of vmap as a workaround"); } TORCH_LIBRARY_IMPL(_, VmapMode, m) { m.fallback(torch::CppFunction::makeFallthrough()); } TORCH_LIBRARY_IMPL(aten, VmapMode, m) { // NB: I'd really like to register a special kernel like // CppFunction::makeNamedNotSupported() to avoid listing out the types of everything. // However, registering e.g. CppFunction::makeNamedNotSupported() as an implementation // only works for operators that support boxing. #define TENSOROPTIONS std::optional, std::optional, std::optional, std::optional // random operations (out-of-place) m.impl("bernoulli", unsupportedRandomOp>); m.impl("bernoulli.out", unsupportedRandomOp_, Tensor&>); m.impl("bernoulli.p", unsupportedRandomOp>); m.impl("bernoulli_.Tensor", unsupportedRandomOp_>); m.impl("bernoulli_.float", unsupportedRandomOp_>); m.impl("cauchy_", unsupportedRandomOp_>); m.impl("exponential_", unsupportedRandomOp_>); m.impl("geometric_", unsupportedRandomOp_>); m.impl("log_normal_", unsupportedRandomOp_>); m.impl("multinomial", unsupportedRandomOp>); m.impl("multinomial.out", unsupportedRandomOp_, Tensor&>); m.impl("normal.Tensor_float", unsupportedRandomOp>); m.impl("normal.Tensor_float_out", unsupportedRandomOp_, Tensor&>); m.impl("normal.float_Tensor_out", unsupportedRandomOp_, Tensor&>); m.impl("normal.float_Tensor", unsupportedRandomOp>); m.impl("normal.Tensor_Tensor", unsupportedRandomOp>); m.impl("normal.Tensor_Tensor_out", unsupportedRandomOp_, Tensor&>); m.impl("normal.float_float", unsupportedRandomOp, TENSOROPTIONS>); m.impl("normal.float_float_out", unsupportedRandomOp_, Tensor&>); m.impl("normal_", unsupportedRandomOp_>); m.impl("poisson", unsupportedRandomOp>); m.impl("random_.from", unsupportedRandomOp_, std::optional>); m.impl("random_.to", unsupportedRandomOp_>); m.impl("random_", unsupportedRandomOp_>); m.impl("rand_like", unsupportedRandomOp>); m.impl("randn_like", unsupportedRandomOp>); m.impl("randint_like", unsupportedRandomOp>); m.impl("randint_like.low_dtype", unsupportedRandomOp>); m.impl("rand", unsupportedRandomOp); m.impl("rand.generator", unsupportedRandomOp, TENSOROPTIONS>); m.impl("rand.names", unsupportedRandomOp, TENSOROPTIONS>); m.impl("rand.generator_with_names", unsupportedRandomOp, std::optional, TENSOROPTIONS>); m.impl("rand.out", unsupportedRandomOp_); m.impl("rand.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randn", unsupportedRandomOp); m.impl("randn.generator", unsupportedRandomOp, TENSOROPTIONS>); m.impl("randn.names", unsupportedRandomOp, TENSOROPTIONS>); m.impl("randn.generator_with_names", unsupportedRandomOp, std::optional, TENSOROPTIONS>); m.impl("randn.out", unsupportedRandomOp_); m.impl("randn.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randperm", unsupportedRandomOp); m.impl("randperm.generator", unsupportedRandomOp, TENSOROPTIONS>); m.impl("randperm.out", unsupportedRandomOp_); m.impl("randperm.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randint", unsupportedRandomOp); m.impl("randint.generator", unsupportedRandomOp, TENSOROPTIONS>); m.impl("randint.low", unsupportedRandomOp); m.impl("randint.low_generator", unsupportedRandomOp, TENSOROPTIONS>); m.impl("randint.out", unsupportedRandomOp_); m.impl("randint.generator_out", unsupportedRandomOp_, Tensor&>); m.impl("randint.low_out", unsupportedRandomOp_); m.impl("randint.low_generator_out", unsupportedRandomOp_, Tensor&>); m.impl("uniform_", unsupportedRandomOp_>); #undef TENSOROPTIONS } } // namespace at