// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <ATen/functorch/BatchRulesHelper.h>
#include <ATen/WrapDimUtils.h>

namespace at::functorch {

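// If `maybe_batch_dim` is set, returns `tensor` with that dim moved to the front (dim 0);
// otherwise returns `tensor` unchanged.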
Tensor moveBatchDimToFront(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  if (maybe_batch_dim.value() == 0) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}

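// Returns the logical rank of `tensor`: its physical rank, minus one if a batch dim is present.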
int64_t rankWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}

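// Returns the number of elements in `tensor`, not counting the batch dim (if any).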
int64_t numelWithoutBatchDim(const Tensor& tensor, std::optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim) {
    return tensor.numel();
  }
  return tensor.numel() / tensor.size(*maybe_batch_dim);
}

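// Returns `new_val` if `maybe_empty` holds a value, otherwise nullopt.
// Typically used by batch rules to report an output batch dim only when the input had one.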
std::optional<int64_t> valIfNonempty(std::optional<int64_t> maybe_empty, int64_t new_val) {
  if (maybe_empty.has_value()) {
    return new_val;
  }
  return std::nullopt;
}

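// Maps a logical (batch-dim-excluded) dim of `tensor` to a physical dim,
// e.g. with a batch dim present, logical dim -1 of a [B, 2, 3] tensor maps to physical dim 2.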
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
  // NB: assumes the batch dim is at the front of the tensor
  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : std::nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
  if (has_batch_dim) {
    return wrapped_dim + 1;
  }
  return wrapped_dim;
}

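// Same as getPhysicalDim, but maps a whole list of logical dims at once.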
VmapDimVector getPhysicalDims(const Tensor& tensor, bool has_batch_dim, IntArrayRef logical_dims) {
  // NB: assumes the batch dim is at the front of the tensor
  std::optional<int64_t> bdim = has_batch_dim ? std::optional<int64_t>(0) : std::nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  VmapDimVector result;
  result.reserve(logical_dims.size());
  for (auto d : logical_dims) {
    if (has_batch_dim) {
      result.push_back(maybe_wrap_dim(d, rank) + 1);
    } else {
      result.push_back(maybe_wrap_dim(d, rank));
    }
  }
  return result;
}

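// If `tensor` is batched and its logical rank is below `logical_rank`, inserts size-1 dims
// right after the (front) batch dim until the logical ranks match,
// e.g. padding [B, 3] to logical rank 3 gives [B, 1, 1, 3].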
Tensor maybePadToLogicalRank(const Tensor& tensor, std::optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapSymDimVector new_sizes(tensor.sym_sizes().begin(), tensor.sym_sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view_symint(SymIntArrayRef{new_sizes.begin(), new_sizes.end()});
}

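// Validates the vmap randomness setting for a random op: "error" mode always throws,
// and "same" mode is rejected when any input tensor is batched.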
void check_randomness(RandomnessType randomness, bool any_tensor_batched) {
  TORCH_CHECK(
    randomness != RandomnessType::Error,
    "vmap: called random operation while in randomness error mode. Please either use the "
    "'same' or 'different' randomness flags on vmap or perform the randomness operation out of vmap"
  );

  TORCH_CHECK(
    !(randomness == RandomnessType::Same && any_tensor_batched),
    "Vmap does not currently support same randomness with a batched tensor input. ",
    "Please file an issue with functorch"
  );
}

void check_randomness(RandomnessType randomness) {
  check_randomness(randomness, false); // ops that don't take any tensor inputs can't hit the "same randomness" error
}

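// Collapses dim `src` of `x` into dim `dst` (where `dst` indexes the shape with `src` removed),
// e.g. for x of shape [2, 3, 5], reshape_dim_into(0, 0, x) returns a tensor of shape [6, 5].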
Tensor reshape_dim_into(int64_t src, int64_t dst, const Tensor& x) {
  auto x_dim = x.dim();
  src = maybe_wrap_dim(src, x_dim);
  dst = maybe_wrap_dim(dst, x_dim - 1); // Returned Tensor has one fewer dim
  VmapDimVector new_shape(x.sizes().begin(), x.sizes().end());
  new_shape.erase(new_shape.begin() + src);
  new_shape[dst] *= x.sizes()[src];
  return at::reshape(x.movedim(src, dst), new_shape);
}

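// Splits dim `src` of `x` into two dims of sizes (size1, shape[src] / size1),
// e.g. for x of shape [6, 5], reshape_dim_outof(0, 2, x) returns a tensor of shape [2, 3, 5].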
Tensor reshape_dim_outof(int64_t src, int64_t size1, const Tensor& x) {
  src = maybe_wrap_dim(src, x.dim());
  VmapDimVector shape(x.sizes().begin(), x.sizes().end());
  if (shape[src] != 0) {
    // NOTE: 0 % 0 leads to FPE
    TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0);
  }
  // split any size out of `0`-sized dim
  int64_t size2 = 0;
  if (shape[src] != 0) {
    size2 = shape[src] / size1;
  }
  shape[src] = size1;
  shape.insert(shape.begin() + src + 1, size2);
  return at::reshape(x, shape);
}

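// SymInt variant of reshape_dim_outof, for symbolic shapes.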
Tensor reshape_dim_outof_symint(int64_t src, const c10::SymInt& size1, const Tensor& x) {
  src = maybe_wrap_dim(src, x.dim());
  c10::SymDimVector shape(x.sym_sizes().begin(), x.sym_sizes().end());
  if (shape[src] != 0) {
    // NOTE: 0 % 0 leads to FPE
    TORCH_INTERNAL_ASSERT(shape[src] % size1 == 0);
  }
  c10::SymInt size2;
  // split any size out of `0`-sized dim
  if (shape[src] == 0) {
    size2 = 0;
  } else {
    size2 = shape[src] / size1;
  }
  shape[src] = size1;
  shape.insert(shape.begin() + src + 1, size2);
  return at::reshape_symint(x, shape);
}

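// Error raised when an in-place op can't work under vmap because `self` was not vmapped
// over but some other argument was, so `self` has fewer elements than required.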
void vmapIncompatibleInplaceError(const char* schema_name) {
  TORCH_CHECK(false,
    "vmap: ", schema_name, "(self, *extra_args) is not possible because ",
    "there exists a Tensor `other` in extra_args that has more elements ",
    "than `self`. This happened due to `other` being vmapped over but ",
    "`self` not being vmapped over in a vmap. ",
    "Please try to use out-of-place operators instead of ", schema_name, ". ",
    "If said operator is being called inside the PyTorch framework, ",
    "please file a bug report instead.");
}

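// Computes the promoted dtype from one element of the (batched, logically 0-dim)
// `logical_scalar_tensor` and `second`, and casts both tensors to it if needed.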
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}

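// Common preprocessing for binary pointwise batch rules: moves any batch dims to the front,
// optionally applies scalar type promotion, and pads both operands to the same logical rank
// so that they broadcast correctly against each other.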
std::tuple<Tensor, Tensor> _binary_pointwise_helper(
    const Tensor& tensor, std::optional<int64_t> tensor_batch_dim,
    const Tensor& other, std::optional<int64_t> other_batch_dim,
    bool do_type_promotion) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  if (do_type_promotion) {
    auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
    auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
    if (tensor_is_logical_scalar && !other_is_logical_scalar) {
      handleScalarTypePromotion(tensor_, other_);
    }
    if (other_is_logical_scalar && !tensor_is_logical_scalar) {
      handleScalarTypePromotion(other_, tensor_);
    }
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  return std::make_tuple(tensor_, other_);
}

} // namespace at::functorch