1*da0073e9SAndroid Build Coastguard Worker #include <torch/library.h>
2*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/boxing/KernelFunction.h>
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Worker using torch::CppFunction;
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker namespace at {
7*da0073e9SAndroid Build Coastguard Worker
// Note: [DispatchKey::VmapMode usage]
// Whenever we're inside a vmap, all Tensors dispatch on this key. At the moment,
// this key is used to disable random operations inside of vmap. If you are looking
// for batching rules, those are registered with DispatchKey::Batched instead.
//
// Note: [Ambiguity of random operations inside vmap]
// Random operations are ambiguous under vmap: it isn't clear whether every
// element of the batch should see the same randomness or independent randomness.
// For example:
//
// >>> vmap(lambda t: torch.rand(1))(torch.zeros(5))
// Should the above return the same random number 5 times, or a different number
// each time?
//
// We haven't made a decision on that yet, so we are temporarily banning random
// operations inside of vmap while we gather user feedback.
22*da0073e9SAndroid Build Coastguard Worker
unsupportedRandomOp(Args...args)23*da0073e9SAndroid Build Coastguard Worker template <typename... Args> Tensor unsupportedRandomOp(Args... args) {
24*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
25*da0073e9SAndroid Build Coastguard Worker "Please perform random operations outside of vmap as a workaround");
26*da0073e9SAndroid Build Coastguard Worker }
27*da0073e9SAndroid Build Coastguard Worker
unsupportedRandomOp_(Args...args)28*da0073e9SAndroid Build Coastguard Worker template <typename... Args> Tensor& unsupportedRandomOp_(Args... args) {
29*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(false, "vmap: We do not yet support calling random operations inside of vmap. ",
30*da0073e9SAndroid Build Coastguard Worker "Please perform random operations outside of vmap as a workaround");
31*da0073e9SAndroid Build Coastguard Worker }
32*da0073e9SAndroid Build Coastguard Worker
// Catch-all registration for the VmapMode key, covering every namespace ("_").
// The fallback is a fallthrough, so operators without a VmapMode-specific
// kernel skip this key and keep dispatching to the next one.
TORCH_LIBRARY_IMPL(_, VmapMode, m) {
  auto fallthrough = torch::CppFunction::makeFallthrough();
  m.fallback(std::move(fallthrough));
}
36*da0073e9SAndroid Build Coastguard Worker
TORCH_LIBRARY_IMPL(aten, VmapMode, m) {
  // Each aten random operator gets an explicit VmapMode kernel that raises an
  // error; see Note: [Ambiguity of random operations inside vmap].
  //
  // NB: Ideally we would register one generic kernel such as
  // CppFunction::makeNamedNotSupported() instead of listing every signature.
  // That only works for operators that support boxing, though, so each entry
  // below instantiates unsupportedRandomOp / unsupportedRandomOp_ with the
  // operator's argument types spelled out.
  //
  // unsupportedRandomOp  (returns Tensor)  -> overloads that return a new tensor.
  // unsupportedRandomOp_ (returns Tensor&) -> in-place (trailing `_`) and out=
  //                                           overloads.

  // Shorthand for the optional generator argument that almost every entry takes.
  using OptionalGenerator = std::optional<Generator>;

  // The optional ScalarType / Layout / Device / bool arguments, in that order,
  // that these signatures use in place of a TensorOptions argument. This must
  // stay a macro: it expands to several template arguments at once.
#define VMAP_TENSOR_OPTIONS_ARGS std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>

  // bernoulli: functional, out=, and in-place overloads.
  m.impl("bernoulli", unsupportedRandomOp<const Tensor&, OptionalGenerator>);
  m.impl("bernoulli.out", unsupportedRandomOp_<const Tensor&, OptionalGenerator, Tensor&>);
  m.impl("bernoulli.p", unsupportedRandomOp<const Tensor&, double, OptionalGenerator>);
  m.impl("bernoulli_.Tensor", unsupportedRandomOp_<Tensor&, const Tensor&, OptionalGenerator>);
  m.impl("bernoulli_.float", unsupportedRandomOp_<Tensor&, double, OptionalGenerator>);

  // In-place distribution samplers, plus multinomial.
  m.impl("cauchy_", unsupportedRandomOp_<Tensor&, double, double, OptionalGenerator>);
  m.impl("exponential_", unsupportedRandomOp_<Tensor&, double, OptionalGenerator>);
  m.impl("geometric_", unsupportedRandomOp_<Tensor&, double, OptionalGenerator>);
  m.impl("log_normal_", unsupportedRandomOp_<Tensor&, double, double, OptionalGenerator>);
  m.impl("multinomial", unsupportedRandomOp<const Tensor&, int64_t, bool, OptionalGenerator>);
  m.impl("multinomial.out", unsupportedRandomOp_<const Tensor&, int64_t, bool, OptionalGenerator, Tensor&>);

  // normal: every mean/std combination, with out= and in-place forms.
  m.impl("normal.Tensor_float", unsupportedRandomOp<const Tensor&, double, OptionalGenerator>);
  m.impl("normal.Tensor_float_out", unsupportedRandomOp_<const Tensor&, double, OptionalGenerator, Tensor&>);
  m.impl("normal.float_Tensor_out", unsupportedRandomOp_<double, const Tensor&, OptionalGenerator, Tensor&>);
  m.impl("normal.float_Tensor", unsupportedRandomOp<double, const Tensor&, OptionalGenerator>);
  m.impl("normal.Tensor_Tensor", unsupportedRandomOp<const Tensor&, const Tensor&, OptionalGenerator>);
  m.impl("normal.Tensor_Tensor_out", unsupportedRandomOp_<const Tensor&, const Tensor&, OptionalGenerator, Tensor&>);
  m.impl("normal.float_float", unsupportedRandomOp<double, double, IntArrayRef, OptionalGenerator, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("normal.float_float_out", unsupportedRandomOp_<double, double, IntArrayRef, OptionalGenerator, Tensor&>);
  m.impl("normal_", unsupportedRandomOp_<Tensor&, double, double, OptionalGenerator>);

  // poisson (functional only).
  m.impl("poisson", unsupportedRandomOp<const Tensor&, OptionalGenerator>);

  // random_: in-place overloads.
  m.impl("random_.from", unsupportedRandomOp_<Tensor&, int64_t, std::optional<int64_t>, OptionalGenerator>);
  m.impl("random_.to", unsupportedRandomOp_<Tensor&, int64_t, OptionalGenerator>);
  m.impl("random_", unsupportedRandomOp_<Tensor&, OptionalGenerator>);

  // *_like factories that take an input tensor plus an optional MemoryFormat.
  m.impl("rand_like", unsupportedRandomOp<const Tensor&, VMAP_TENSOR_OPTIONS_ARGS, std::optional<MemoryFormat>>);
  m.impl("randn_like", unsupportedRandomOp<const Tensor&, VMAP_TENSOR_OPTIONS_ARGS, std::optional<MemoryFormat>>);

  m.impl("randint_like", unsupportedRandomOp<const Tensor&, int64_t, VMAP_TENSOR_OPTIONS_ARGS, std::optional<MemoryFormat>>);
  m.impl("randint_like.low_dtype", unsupportedRandomOp<const Tensor&, int64_t, int64_t, VMAP_TENSOR_OPTIONS_ARGS, std::optional<MemoryFormat>>);

  // rand: size-based factory, with generator / names / out= overloads.
  m.impl("rand", unsupportedRandomOp<IntArrayRef, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("rand.generator", unsupportedRandomOp<IntArrayRef, OptionalGenerator, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("rand.names", unsupportedRandomOp<IntArrayRef, std::optional<DimnameList>, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("rand.generator_with_names", unsupportedRandomOp<IntArrayRef, OptionalGenerator, std::optional<DimnameList>, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("rand.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
  m.impl("rand.generator_out", unsupportedRandomOp_<IntArrayRef, OptionalGenerator, Tensor&>);

  // randn: same overload set as rand.
  m.impl("randn", unsupportedRandomOp<IntArrayRef, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randn.generator", unsupportedRandomOp<IntArrayRef, OptionalGenerator, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randn.names", unsupportedRandomOp<IntArrayRef, std::optional<DimnameList>, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randn.generator_with_names", unsupportedRandomOp<IntArrayRef, OptionalGenerator, std::optional<DimnameList>, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randn.out", unsupportedRandomOp_<IntArrayRef, Tensor&>);
  m.impl("randn.generator_out", unsupportedRandomOp_<IntArrayRef, OptionalGenerator, Tensor&>);

  // randperm: length-based factory.
  m.impl("randperm", unsupportedRandomOp<int64_t, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randperm.generator", unsupportedRandomOp<int64_t, OptionalGenerator, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randperm.out", unsupportedRandomOp_<int64_t, Tensor&>);
  m.impl("randperm.generator_out", unsupportedRandomOp_<int64_t, OptionalGenerator, Tensor&>);

  // randint: high-only and low/high factories, each with generator / out= forms.
  m.impl("randint", unsupportedRandomOp<int64_t, IntArrayRef, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randint.generator", unsupportedRandomOp<int64_t, IntArrayRef, OptionalGenerator, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randint.low", unsupportedRandomOp<int64_t, int64_t, IntArrayRef, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randint.low_generator", unsupportedRandomOp<int64_t, int64_t, IntArrayRef, OptionalGenerator, VMAP_TENSOR_OPTIONS_ARGS>);
  m.impl("randint.out", unsupportedRandomOp_<int64_t, IntArrayRef, Tensor&>);
  m.impl("randint.generator_out", unsupportedRandomOp_<int64_t, IntArrayRef, OptionalGenerator, Tensor&>);
  m.impl("randint.low_out", unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, Tensor&>);
  m.impl("randint.low_generator_out", unsupportedRandomOp_<int64_t, int64_t, IntArrayRef, OptionalGenerator, Tensor&>);

  // uniform_: in-place.
  m.impl("uniform_", unsupportedRandomOp_<Tensor&, double, double, OptionalGenerator>);

#undef VMAP_TENSOR_OPTIONS_ARGS
}
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker } // namespace at
115